Merge remote-tracking branch 'origin/master' into patch-2

This commit is contained in:
ramdoys 2025-02-20 21:54:31 -05:00
commit ceed23d2b9
338 changed files with 13549 additions and 3674 deletions

View file

@ -65,10 +65,7 @@ TypeId makeFunction( // Polymorphic
bool checked = false
);
void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn);
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn);
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn);
void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> fn);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName);

View file

@ -4,6 +4,7 @@
#include <Luau/NotNull.h>
#include "Luau/TypeArena.h"
#include "Luau/Type.h"
#include "Luau/Scope.h"
#include <unordered_map>
@ -26,13 +27,17 @@ struct CloneState
* while `clone` will make a deep copy of the entire type and its every component.
*
* Be mindful about which behavior you actually _want_.
*
* Persistent types are not cloned as an optimization.
* If a type is cloned in order to mutate it, 'ignorePersistent' has to be set
*/
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState);
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false);
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false);
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState);
} // namespace Luau

View file

@ -166,7 +166,7 @@ struct ConstraintSolver
**/
void finalizeTypeFunctions();
bool isDone();
bool isDone() const;
private:
/**
@ -298,10 +298,10 @@ public:
// FIXME: This use of a boolean for the return result is an appalling
// interface.
bool blockOnPendingTypes(TypeId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypePackId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypePackId targetPack, NotNull<const Constraint> constraint);
void unblock(NotNull<const Constraint> progressed);
void unblock(TypeId progressed, Location location);
void unblock(TypeId ty, Location location);
void unblock(TypePackId progressed, Location location);
void unblock(const std::vector<TypeId>& types, Location location);
void unblock(const std::vector<TypePackId>& packs, Location location);
@ -336,7 +336,7 @@ public:
* @param location the location where the require is taking place; used for
* error locations.
**/
TypeId resolveModule(const ModuleInfo& module, const Location& location);
TypeId resolveModule(const ModuleInfo& info, const Location& location);
void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e);
@ -420,6 +420,11 @@ public:
void throwUserCancelError() const;
ToStringOptions opts;
void fillInDiscriminantTypes(
NotNull<const Constraint> constraint,
const std::vector<std::optional<TypeId>>& discriminantTypes
);
};
void dump(NotNull<Scope> rootScope, struct ToStringOptions& opts);

View file

@ -6,6 +6,7 @@
#include "Luau/ControlFlow.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Symbol.h"
#include "Luau/TypedAllocator.h"
@ -48,13 +49,13 @@ struct DataFlowGraph
const RefinementKey* getRefinementKey(const AstExpr* expr) const;
private:
DataFlowGraph() = default;
DataFlowGraph(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena defArena;
RefinementKeyArena keyArena;
NotNull<DefArena> defArena;
NotNull<RefinementKeyArena> keyArena;
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
@ -110,30 +111,22 @@ using ScopeStack = std::vector<DfgScope*>;
struct DataFlowGraphBuilder
{
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
/**
* This method is identical to the build method above, but returns a pair of dfg, scopes as the data flow graph
* here is intended to live on the module between runs of typechecking. Before, the DFG only needed to live as
* long as the typecheck, but in a world with incremental typechecking, we need the information on the dfg to incrementally
* typecheck small fragments of code.
* @param block - pointer to the ast to build the dfg for
* @param handle - for raising internal errors while building the dfg
*/
static std::pair<std::shared_ptr<DataFlowGraph>, std::vector<std::unique_ptr<DfgScope>>> buildShared(
static DataFlowGraph build(
AstStatBlock* block,
NotNull<InternalErrorReporter> handle
NotNull<DefArena> defArena,
NotNull<RefinementKeyArena> keyArena,
NotNull<struct InternalErrorReporter> handle
);
private:
DataFlowGraphBuilder() = default;
DataFlowGraphBuilder(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete;
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph;
NotNull<DefArena> defArena{&graph.defArena};
NotNull<RefinementKeyArena> keyArena{&graph.keyArena};
NotNull<DefArena> defArena;
NotNull<RefinementKeyArena> keyArena;
struct InternalErrorReporter* handle = nullptr;

View file

@ -53,7 +53,7 @@ LUAU_EQSAT_NODE_SET(Intersection);
LUAU_EQSAT_NODE_ARRAY(Negation, 1);
LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*);
LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, std::shared_ptr<const TypeFunctionInstanceType>);
LUAU_EQSAT_UNIT(TNoRefine);
LUAU_EQSAT_UNIT(Invalid);
@ -105,6 +105,9 @@ private:
std::vector<Id> storage;
};
template <typename L>
using Node = EqSat::Node<L>;
using EType = EqSat::Language<
TNil,
TBoolean,
@ -146,7 +149,7 @@ using EType = EqSat::Language<
struct StringCache
{
Allocator allocator;
DenseHashMap<size_t, StringId> strings{{}};
DenseHashMap<std::string_view, StringId> strings{{}};
std::vector<std::string_view> views;
StringId add(std::string_view s);
@ -171,6 +174,9 @@ struct Subst
Id eclass;
Id newClass;
// The node into eclass which is boring, if any
std::optional<size_t> boringIndex;
std::string desc;
Subst(Id eclass, Id newClass, std::string desc = "");
@ -211,6 +217,7 @@ struct Simplifier
void subst(Id from, Id to);
void subst(Id from, Id to, const std::string& ruleName);
void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void unionClasses(std::vector<Id>& hereParts, Id there);
@ -218,6 +225,7 @@ struct Simplifier
void simplifyUnion(Id id);
void uninhabitedIntersection(Id id);
void intersectWithNegatedClass(Id id);
void intersectWithNegatedAtom(Id id);
void intersectWithNoRefine(Id id);
void cyclicIntersectionOfUnion(Id id);
void cyclicUnionOfIntersection(Id id);
@ -228,6 +236,7 @@ struct Simplifier
void unneededTableModification(Id id);
void builtinTypeFunctions(Id id);
void iffyTypeFunctions(Id id);
void strictMetamethods(Id id);
};
template<typename Tag>
@ -293,13 +302,13 @@ QueryIterator<Tag>::QueryIterator(EGraph* egraph_, Id eclass)
for (const auto& enode : ecl.nodes)
{
if (enode.index() < idx)
if (enode.node.index() < idx)
++index;
else
break;
}
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx)
if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != idx)
{
egraph = nullptr;
index = 0;
@ -329,7 +338,7 @@ std::pair<const Tag*, size_t> QueryIterator<Tag>::operator*() const
EGraph::EClassT& ecl = (*egraph)[eclass];
LUAU_ASSERT(index < ecl.nodes.size());
auto& enode = ecl.nodes[index];
auto& enode = ecl.nodes[index].node;
Tag* result = enode.template get<Tag>();
LUAU_ASSERT(result);
return {result, index};
@ -341,12 +350,16 @@ QueryIterator<Tag>& QueryIterator<Tag>::operator++()
{
const auto& ecl = (*egraph)[eclass];
++index;
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId<Tag>())
do
{
egraph = nullptr;
index = 0;
}
++index;
if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != EType::VariantTy::getTypeId<Tag>())
{
egraph = nullptr;
index = 0;
break;
}
} while (ecl.nodes[index].boring);
return *this;
}

View file

@ -15,6 +15,12 @@ namespace Luau
{
struct FrontendOptions;
enum class FragmentTypeCheckStatus
{
SkipAutocomplete,
Success,
};
struct FragmentAutocompleteAncestryResult
{
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
@ -29,6 +35,7 @@ struct FragmentParseResult
AstStatBlock* root = nullptr;
std::vector<AstNode*> ancestry;
AstStat* nearestStatement = nullptr;
std::vector<Comment> commentLocations;
std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>();
};
@ -49,14 +56,14 @@ struct FragmentAutocompleteResult
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
FragmentParseResult parseFragment(
std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,
std::optional<Position> fragmentEndPosition
);
FragmentTypeCheckResult typecheckFragment(
std::pair<FragmentTypeCheckStatus, FragmentTypeCheckResult> typecheckFragment(
Frontend& frontend,
const ModuleName& moduleName,
const Position& cursorPos,

View file

@ -7,6 +7,7 @@
#include "Luau/ModuleResolver.h"
#include "Luau/RequireTracer.h"
#include "Luau/Scope.h"
#include "Luau/Set.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/Variant.h"
#include "Luau/AnyTypeSummary.h"
@ -56,13 +57,32 @@ struct SourceNode
return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule;
}
bool hasInvalidModuleDependency(bool forAutocomplete) const
{
return forAutocomplete ? invalidModuleDependencyForAutocomplete : invalidModuleDependency;
}
void setInvalidModuleDependency(bool value, bool forAutocomplete)
{
if (forAutocomplete)
invalidModuleDependencyForAutocomplete = value;
else
invalidModuleDependency = value;
}
ModuleName name;
std::string humanReadableName;
DenseHashSet<ModuleName> requireSet{{}};
std::vector<std::pair<ModuleName, Location>> requireLocations;
Set<ModuleName> dependents{{}};
bool dirtySourceModule = true;
bool dirtyModule = true;
bool dirtyModuleForAutocomplete = true;
bool invalidModuleDependency = true;
bool invalidModuleDependencyForAutocomplete = true;
double autocompleteLimitsMult = 1.0;
};
@ -117,7 +137,7 @@ struct FrontendModuleResolver : ModuleResolver
std::optional<ModuleInfo> resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override;
std::string getHumanReadableModuleName(const ModuleName& moduleName) const override;
void setModule(const ModuleName& moduleName, ModulePtr module);
bool setModule(const ModuleName& moduleName, ModulePtr module);
void clearModules();
private:
@ -151,9 +171,13 @@ struct Frontend
// Parse and typecheck module graph
CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
bool allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete = false) const;
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;
void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr);
void traverseDependents(const ModuleName& name, std::function<bool(SourceNode&)> processSubtree);
/** Borrow a pointer into the SourceModule cache.
*
* Returns nullptr if we don't have it. This could mean that the script

View file

@ -16,6 +16,8 @@
#include <unordered_map>
#include <optional>
LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection)
namespace Luau
{
@ -55,6 +57,7 @@ struct SourceModule
}
};
bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos);
bool isWithinComment(const SourceModule& sourceModule, Position pos);
bool isWithinComment(const ParseResult& result, Position pos);
@ -136,6 +139,11 @@ struct Module
TypePackId returnType = nullptr;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
// Arenas related to the DFG must persist after the DFG no longer exists, as
// Module objects maintain raw pointers to objects in these arenas.
DefArena defArena;
RefinementKeyArena keyArena;
bool hasModuleScope() const;
ScopePtr getModuleScope() const;

View file

@ -15,6 +15,7 @@ struct TypeCheckLimits;
void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/EqSatSimplification.h"
#include "Luau/NotNull.h"
#include "Luau/Set.h"
#include "Luau/TypeFwd.h"
@ -21,8 +22,22 @@ struct Scope;
using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isSubtype(
TypeId subTy,
TypeId superTy,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
);
bool isSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
);
class TypeIds
{

View file

@ -2,12 +2,13 @@
#pragma once
#include "Luau/Ast.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
#include "Luau/Location.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Location.h"
#include "Luau/NotNull.h"
#include "Luau/Subtyping.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
@ -34,6 +35,7 @@ struct OverloadResolver
OverloadResolver(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope,
@ -44,6 +46,7 @@ struct OverloadResolver
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<Scope> scope;
@ -110,6 +113,7 @@ struct SolveResult
SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter,

View file

@ -85,12 +85,18 @@ struct Scope
void inheritAssignments(const ScopePtr& childScope);
void inheritRefinements(const ScopePtr& childScope);
// Track globals that should emit warnings during type checking.
DenseHashSet<std::string> globalsToWarn{""};
bool shouldWarnGlobal(std::string name) const;
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasTypeParameters;
std::unordered_map<Name, TypePackId> typeAliasTypePackParameters;
std::optional<std::vector<TypeId>> interiorFreeTypes;
};
// Returns true iff the left scope encloses the right scope. A Scope* equal to

View file

@ -19,10 +19,10 @@ struct SimplifyResult
DenseHashSet<TypeId> blockedTypes;
};
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right);
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts);
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right);
enum class Relation
{

View file

@ -1,13 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Set.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypePairHash.h"
#include "Luau/TypePath.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/DenseHash.h"
#include <vector>
#include <optional>
@ -134,6 +135,7 @@ struct Subtyping
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> iceReporter;
@ -155,6 +157,7 @@ struct Subtyping
Subtyping(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter

View file

@ -69,12 +69,16 @@ using Name = std::string;
// A free type is one whose exact shape has yet to be fully determined.
struct FreeType
{
// New constructors
explicit FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound);
// This one got promoted to explicit
explicit FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound);
explicit FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound);
// Old constructors
explicit FreeType(TypeLevel level);
explicit FreeType(Scope* scope);
FreeType(Scope* scope, TypeLevel level);
FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound);
int index;
TypeLevel level;
Scope* scope = nullptr;
@ -131,14 +135,14 @@ struct BlockedType
BlockedType();
int index;
Constraint* getOwner() const;
void setOwner(Constraint* newOwner);
void replaceOwner(Constraint* newOwner);
const Constraint* getOwner() const;
void setOwner(const Constraint* newOwner);
void replaceOwner(const Constraint* newOwner);
private:
// The constraint that is intended to unblock this type. Other constraints
// should block on this constraint if present.
Constraint* owner = nullptr;
const Constraint* owner = nullptr;
};
struct PrimitiveType
@ -279,9 +283,6 @@ struct WithPredicate
}
};
using MagicFunction = std::function<std::optional<
WithPredicate<TypePackId>>(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
struct MagicFunctionCallContext
{
NotNull<struct ConstraintSolver> solver;
@ -291,7 +292,6 @@ struct MagicFunctionCallContext
TypePackId result;
};
using DcrMagicFunction = std::function<bool(MagicFunctionCallContext)>;
struct MagicRefinementContext
{
NotNull<Scope> scope;
@ -308,8 +308,29 @@ struct MagicFunctionTypeCheckContext
NotNull<Scope> checkScope;
};
using DcrMagicRefinement = void (*)(const MagicRefinementContext&);
using DcrMagicFunctionTypeCheck = std::function<void(const MagicFunctionTypeCheckContext&)>;
struct MagicFunction
{
virtual std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) = 0;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
virtual bool infer(const MagicFunctionCallContext&) = 0;
virtual void refine(const MagicRefinementContext&) {}
// If a magic function needs to do its own special typechecking, do it here.
// Returns true if magic typechecking was performed. Return false if the
// default typechecking logic should run.
virtual bool typeCheck(const MagicFunctionTypeCheckContext&)
{
return false;
}
virtual ~MagicFunction() {}
};
struct FunctionType
{
// Global monomorphic function
@ -367,16 +388,7 @@ struct FunctionType
Scope* scope = nullptr;
TypePackId argTypes;
TypePackId retTypes;
MagicFunction magicFunction = nullptr;
DcrMagicFunction dcrMagicFunction = nullptr;
DcrMagicRefinement dcrMagicRefinement = nullptr;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
DcrMagicFunctionTypeCheck dcrMagicTypeCheck = nullptr;
std::shared_ptr<MagicFunction> magic = nullptr;
bool hasSelf;
// `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it.
@ -608,7 +620,8 @@ struct UserDefinedFunctionData
// References to AST elements are owned by the Module allocator which also stores this type
AstStatTypeFunction* definition = nullptr;
DenseHashMap<Name, AstStatTypeFunction*> environment{""};
DenseHashMap<Name, std::pair<AstStatTypeFunction*, size_t>> environment{""};
DenseHashMap<Name, AstStatTypeFunction*> environment_DEPRECATED{""};
};
/**
@ -625,7 +638,7 @@ struct TypeFunctionInstanceType
std::vector<TypeId> typeArguments;
std::vector<TypePackId> packArguments;
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
UserDefinedFunctionData userFuncData;
TypeFunctionInstanceType(

View file

@ -32,9 +32,13 @@ struct TypeArena
TypeId addTV(Type&& tv);
TypeId freshType(TypeLevel level);
TypeId freshType(Scope* scope);
TypeId freshType(Scope* scope, TypeLevel level);
TypeId freshType(NotNull<BuiltinTypes> builtins, TypeLevel level);
TypeId freshType(NotNull<BuiltinTypes> builtins, Scope* scope);
TypeId freshType(NotNull<BuiltinTypes> builtins, Scope* scope, TypeLevel level);
TypeId freshType_DEPRECATED(TypeLevel level);
TypeId freshType_DEPRECATED(Scope* scope);
TypeId freshType_DEPRECATED(Scope* scope, TypeLevel level);
TypePackId freshTypePack(Scope* scope);

View file

@ -2,15 +2,16 @@
#pragma once
#include "Luau/Error.h"
#include "Luau/NotNull.h"
#include "Luau/Common.h"
#include "Luau/TypeUtils.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/Normalize.h"
#include "Luau/NotNull.h"
#include "Luau/Subtyping.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeOrPack.h"
#include "Luau/Normalize.h"
#include "Luau/Subtyping.h"
#include "Luau/TypeUtils.h"
namespace Luau
{
@ -60,8 +61,9 @@ struct Reasonings
void check(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> sharedState,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
DcrLogger* logger,
const SourceModule& sourceModule,
@ -71,6 +73,7 @@ void check(
struct TypeChecker2
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
DcrLogger* logger;
const NotNull<TypeCheckLimits> limits;
@ -90,6 +93,7 @@ struct TypeChecker2
TypeChecker2(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
@ -112,14 +116,14 @@ private:
std::optional<StackPusher> pushStack(AstNode* node);
void checkForInternalTypeFunction(TypeId ty, Location location);
TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location);
TypePackId lookupPack(AstExpr* expr);
TypePackId lookupPack(AstExpr* expr) const;
TypeId lookupType(AstExpr* expr);
TypeId lookupAnnotation(AstType* annotation);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation);
TypeId lookupExpectedType(AstExpr* expr);
TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation) const;
TypeId lookupExpectedType(AstExpr* expr) const;
TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena) const;
TypePackId reconstructPack(AstArray<AstExpr*> exprs, TypeArena& arena);
Scope* findInnermostScope(Location location);
Scope* findInnermostScope(Location location) const;
void visit(AstStat* stat);
void visit(AstStatIf* ifStatement);
void visit(AstStatWhile* whileStatement);
@ -156,7 +160,7 @@ private:
void visit(AstExprVarargs* expr);
void visitCall(AstExprCall* call);
void visit(AstExprCall* call);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty) const;
TypeId stripFromNilAndReport(TypeId ty, const Location& location);
void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy);
void visit(AstExprIndexName* indexName, ValueContext context);
@ -213,6 +217,9 @@ private:
std::vector<TypeError>& errors
);
// Avoid duplicate warnings being emitted for the same global variable.
DenseHashSet<std::string> warnedGlobals{""};
void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const;
bool isErrorSuppressing(Location loc, TypeId ty);
bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2);

View file

@ -2,6 +2,7 @@
#pragma once
#include "Luau/Constraint.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/NotNull.h"
#include "Luau/TypeCheckLimits.h"
@ -41,9 +42,15 @@ struct TypeFunctionRuntime
StateRef state;
// Set of functions which have their environment table initialized
DenseHashSet<AstStatTypeFunction*> initialized{nullptr};
// Evaluation of type functions should only be performed in the absence of parse errors in the source module
bool allowEvaluation = true;
// Output created by 'print' function
std::vector<std::string> messages;
private:
void prepareState();
};
@ -53,6 +60,7 @@ struct TypeFunctionContext
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtins;
NotNull<Scope> scope;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> ice;
@ -63,7 +71,7 @@ struct TypeFunctionContext
// The constraint being reduced in this run of the reduction
const Constraint* constraint;
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint);
@ -71,6 +79,7 @@ struct TypeFunctionContext
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtins,
NotNull<Scope> scope,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice,
@ -79,6 +88,7 @@ struct TypeFunctionContext
: arena(arena)
, builtins(builtins)
, scope(scope)
, simplifier(simplifier)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, ice(ice)
@ -91,19 +101,31 @@ struct TypeFunctionContext
NotNull<Constraint> pushConstraint(ConstraintV&& c) const;
};
enum class Reduction
{
// The type function is either known to be reducible or the determination is blocked.
MaybeOk,
// The type function is known to be irreducible, but maybe not be erroneous, e.g. when it's over generics or free types.
Irreducible,
// The type function is known to be irreducible, and is definitely erroneous.
Erroneous,
};
/// Represents a reduction result, which may have successfully reduced the type,
/// may have concretely failed to reduce the type, or may simply be stuck
/// without more information.
template<typename Ty>
struct TypeFunctionReductionResult
{
/// The result of the reduction, if any. If this is nullopt, the type function
/// could not be reduced.
std::optional<Ty> result;
/// Whether the result is uninhabited: whether we know, unambiguously and
/// permanently, whether this type function reduction results in an
/// uninhabitable type. This will trigger an error to be reported.
bool uninhabited;
/// Indicates the status of this reduction: is `Reduction::Irreducible` if
/// the this result indicates the type function is irreducible, and
/// `Reduction::Erroneous` if this result indicates the type function is
/// erroneous. `Reduction::MaybeOk` otherwise.
Reduction reductionStatus;
/// Any types that need to be progressed or mutated before the reduction may
/// proceed.
std::vector<TypeId> blockedTypes;
@ -112,6 +134,8 @@ struct TypeFunctionReductionResult
std::vector<TypePackId> blockedPacks;
/// A runtime error message from user-defined type functions
std::optional<std::string> error;
/// Messages printed out from user-defined type functions
std::vector<std::string> messages;
};
template<typename T>
@ -145,6 +169,7 @@ struct TypePackFunction
struct FunctionGraphReductionResult
{
ErrorVec errors;
ErrorVec messages;
DenseHashSet<TypeId> blockedTypes{nullptr};
DenseHashSet<TypePackId> blockedPacks{nullptr};
DenseHashSet<TypeId> reducedTypes{nullptr};
@ -216,6 +241,9 @@ struct BuiltinTypeFunctions
TypeFunction indexFunc;
TypeFunction rawgetFunc;
TypeFunction setmetatableFunc;
TypeFunction getmetatableFunc;
void addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const;
};

View file

@ -119,7 +119,14 @@ struct TypeFunctionVariadicTypePack
TypeFunctionTypeId type;
};
using TypeFunctionTypePackVariant = Variant<TypeFunctionTypePack, TypeFunctionVariadicTypePack>;
struct TypeFunctionGenericTypePack
{
bool isNamed = false;
std::string name;
};
using TypeFunctionTypePackVariant = Variant<TypeFunctionTypePack, TypeFunctionVariadicTypePack, TypeFunctionGenericTypePack>;
struct TypeFunctionTypePackVar
{
@ -135,6 +142,9 @@ struct TypeFunctionTypePackVar
struct TypeFunctionFunctionType
{
std::vector<TypeFunctionTypeId> generics;
std::vector<TypeFunctionTypePackId> genericPacks;
TypeFunctionTypePackId argTypes;
TypeFunctionTypePackId retTypes;
};
@ -210,6 +220,14 @@ struct TypeFunctionClassType
std::string name;
};
struct TypeFunctionGenericType
{
bool isNamed = false;
bool isPack = false;
std::string name;
};
using TypeFunctionTypeVariant = Luau::Variant<
TypeFunctionPrimitiveType,
TypeFunctionAnyType,
@ -221,7 +239,8 @@ using TypeFunctionTypeVariant = Luau::Variant<
TypeFunctionNegationType,
TypeFunctionFunctionType,
TypeFunctionTableType,
TypeFunctionClassType>;
TypeFunctionClassType,
TypeFunctionGenericType>;
struct TypeFunctionType
{

View file

@ -40,7 +40,7 @@ struct InConditionalContext
TypeContext* typeContext;
TypeContext oldValue;
InConditionalContext(TypeContext* c)
explicit InConditionalContext(TypeContext* c)
: typeContext(c)
, oldValue(*c)
{
@ -269,8 +269,8 @@ bool isLiteral(const AstExpr* expr);
std::vector<TypeId> findBlockedTypesIn(AstExprTable* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes);
/**
* Given a function call and a mapping from expression to type, determine
* whether the type of any argument in said call in depends on a blocked types.
* Given a function call and a mapping from expression to type, determine
* whether the type of any argument in said call in depends on a blocked types.
* This is used as a precondition for bidirectional inference: be warned that
* the behavior of this algorithm is tightly coupled to that of bidirectional
* inference.
@ -280,4 +280,13 @@ std::vector<TypeId> findBlockedTypesIn(AstExprTable* expr, NotNull<DenseHashMap<
*/
std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes);
/**
* Given a scope and a free type, find the closest parent that has a present
* `interiorFreeTypes` and append the given type to said list. This list will
* be generalized when the requiste `GeneralizationConstraint` is resolved.
* @param scope Initial scope this free type was attached to
* @param ty Free type to track.
*/
void trackInteriorFreeType(Scope* scope, TypeId ty);
} // namespace Luau

View file

@ -85,6 +85,8 @@ struct GenericTypeVisitor
{
}
virtual ~GenericTypeVisitor() {}
virtual void cycle(TypeId) {}
virtual void cycle(TypePackId) {}

View file

@ -177,7 +177,6 @@ void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module*
}
}
}
}
void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes)

View file

@ -1161,6 +1161,19 @@ struct AstJsonEncoder : public AstVisitor
);
}
bool visit(class AstTypeGroup* node) override
{
writeNode(
node,
"AstTypeGroup",
[&]()
{
write("type", node->type);
}
);
return false;
}
bool visit(class AstTypeSingletonBool* node) override
{
writeNode(

View file

@ -13,7 +13,7 @@
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauDocumentationAtPosition)
LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon)
namespace Luau
{
@ -43,11 +43,26 @@ struct AutocompleteNodeFinder : public AstVisitor
bool visit(AstStat* stat) override
{
if (stat->location.begin < pos && pos <= stat->location.end)
if (FFlag::LuauExtendStatEndPosWithSemicolon)
{
ancestry.push_back(stat);
return true;
// Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal
// to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case
// (no semicolon) we are still part of the AstStatLocal, hence the different comparison check.
if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end))
{
ancestry.push_back(stat);
return true;
}
}
else
{
if (stat->location.begin < pos && pos <= stat->location.end)
{
ancestry.push_back(stat);
return true;
}
}
return false;
}
@ -518,7 +533,6 @@ static std::optional<DocumentationSymbol> getMetatableDocumentation(
const AstName& index
)
{
LUAU_ASSERT(FFlag::LuauDocumentationAtPosition);
auto indexIt = mtable->props.find("__index");
if (indexIt == mtable->props.end())
return std::nullopt;
@ -575,26 +589,7 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
}
else if (const ClassType* ctv = get<ClassType>(parentTy))
{
if (FFlag::LuauDocumentationAtPosition)
{
while (ctv)
{
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
{
if (FFlag::LuauSolverV2)
{
if (auto ty = propIt->second.readTy)
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
}
else
return checkOverloadedDocumentationSymbol(
module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol
);
}
ctv = ctv->parent ? Luau::get<Luau::ClassType>(*ctv->parent) : nullptr;
}
}
else
while (ctv)
{
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
{
@ -608,17 +603,15 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol
);
}
ctv = ctv->parent ? Luau::get<Luau::ClassType>(*ctv->parent) : nullptr;
}
}
else if (FFlag::LuauDocumentationAtPosition)
else if (const PrimitiveType* ptv = get<PrimitiveType>(parentTy); ptv && ptv->metatable)
{
if (const PrimitiveType* ptv = get<PrimitiveType>(parentTy); ptv && ptv->metatable)
if (auto mtable = get<TableType>(*ptv->metatable))
{
if (auto mtable = get<TableType>(*ptv->metatable))
{
if (std::optional<std::string> docSymbol = getMetatableDocumentation(module, parentExpr, mtable, indexName->index))
return docSymbol;
}
if (std::optional<std::string> docSymbol = getMetatableDocumentation(module, parentExpr, mtable, indexName->index))
return docSymbol;
}
}
}

View file

@ -25,6 +25,7 @@ LUAU_FASTINT(LuauTypeInferIterationLimit)
LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteUseLimits)
static const std::unordered_set<std::string> kStatementStartingKeywords =
{"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -150,6 +151,7 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
{
InternalErrorReporter iceReporter;
UnifierSharedState unifierState(&iceReporter);
SimplifierPtr simplifier = newSimplifier(NotNull{typeArena}, builtinTypes);
Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}};
if (FFlag::LuauSolverV2)
@ -162,7 +164,9 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}};
Subtyping subtyping{
builtinTypes, NotNull{typeArena}, NotNull{simplifier.get()}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}
};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
}
@ -174,6 +178,12 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
unifier.normalize = false;
unifier.checkInhabited = false;
if (FFlag::LuauAutocompleteUseLimits)
{
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
}
return unifier.canUnify(subTy, superTy).empty();
}
}

View file

@ -29,50 +29,81 @@
*/
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2)
LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix)
LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression)
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2)
LUAU_FASTFLAG(LuauVectorDefinitionsExtra)
LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType3)
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope)
LUAU_FASTFLAGVARIABLE(LuauFreezeIgnorePersistent)
namespace Luau
{
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
struct MagicSelect final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicSetMetatable final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context);
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context);
static bool dcrMagicFunctionPack(MagicFunctionCallContext context);
static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context);
struct MagicAssert final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicPack final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicRequire final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicClone final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFreeze final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFormat final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override;
};
struct MagicMatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicGmatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFind final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{
@ -167,34 +198,10 @@ TypeId makeFunction(
return arena.addType(std::move(ftv));
}
void attachMagicFunction(TypeId ty, MagicFunction fn)
void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> magic)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->magicFunction = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicFunction = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicRefinement = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicTypeCheck = fn;
ftv->magic = std::move(magic);
else
LUAU_ASSERT(!"Got a non functional type");
}
@ -301,28 +308,25 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
addGlobalBinding(globals, "string", it->second.type(), "@luau");
// Setup 'vector' metatable
if (FFlag::LuauVectorDefinitionsExtra)
if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end())
{
if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end())
{
TypeId vectorTy = it->second.type;
ClassType* vectorCls = getMutable<ClassType>(vectorTy);
TypeId vectorTy = it->second.type;
ClassType* vectorCls = getMutable<ClassType>(vectorTy);
vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed});
TableType* metatableTy = Luau::getMutable<TableType>(vectorCls->metatable);
vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed});
TableType* metatableTy = Luau::getMutable<TableType>(vectorCls->metatable);
metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})};
metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})};
std::initializer_list<TypeId> mulOverloads{
makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}),
makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}),
};
metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)};
}
std::initializer_list<TypeId> mulOverloads{
makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}),
makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}),
};
metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)};
}
// next<K, V>(t: Table<K, V>, i: K?) -> (K?, V)
@ -395,7 +399,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
}
}
attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert);
attachMagicFunction(getGlobalBinding(globals, "assert"), std::make_shared<MagicAssert>());
if (FFlag::LuauSolverV2)
{
@ -411,9 +415,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
addGlobalBinding(globals, "assert", assertTy, "@luau");
}
attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable);
attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect);
attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect);
attachMagicFunction(getGlobalBinding(globals, "setmetatable"), std::make_shared<MagicSetMetatable>());
attachMagicFunction(getGlobalBinding(globals, "select"), std::make_shared<MagicSelect>());
if (TableType* ttv = getMutable<TableType>(getGlobalBinding(globals, "table")))
{
@ -444,23 +447,21 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
ttv->props["foreach"].deprecated = true;
ttv->props["foreachi"].deprecated = true;
attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack);
attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack);
if (FFlag::LuauTypestateBuiltins2)
attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze);
attachMagicFunction(ttv->props["pack"].type(), std::make_shared<MagicPack>());
if (FFlag::LuauTableCloneClonesType3)
attachMagicFunction(ttv->props["clone"].type(), std::make_shared<MagicClone>());
attachMagicFunction(ttv->props["freeze"].type(), std::make_shared<MagicFreeze>());
}
if (FFlag::AutocompleteRequirePathSuggestions2)
{
TypeId requireTy = getGlobalBinding(globals, "require");
attachTag(requireTy, kRequireTagName);
attachMagicFunction(requireTy, magicFunctionRequire);
attachDcrMagicFunction(requireTy, dcrMagicFunctionRequire);
attachMagicFunction(requireTy, std::make_shared<MagicRequire>());
}
else
{
attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire);
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire);
attachMagicFunction(getGlobalBinding(globals, "require"), std::make_shared<MagicRequire>());
}
}
@ -500,7 +501,7 @@ static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes,
return result;
}
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
std::optional<WithPredicate<TypePackId>> MagicFormat::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -550,7 +551,7 @@ std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
bool MagicFormat::infer(const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
@ -594,7 +595,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
return true;
}
static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext context)
bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context)
{
AstExprConstantString* fmt = nullptr;
if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
@ -610,9 +611,8 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
if (!fmt)
{
if (FFlag::LuauStringFormatArityFix)
context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location);
return;
context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location);
return true;
}
std::vector<TypeId> expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size);
@ -629,12 +629,33 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
Location location = context.callSite->args.data[i + (calledWithSelf ? 0 : paramOffset)]->location;
// use subtyping instead here
SubtypingResult result = context.typechecker->subtyping->isSubtype(actualTy, expectedTy, context.checkScope);
if (!result.isSubtype)
{
Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result);
context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location);
if (FFlag::LuauStringFormatErrorSuppression)
{
switch (shouldSuppressErrors(NotNull{&context.typechecker->normalizer}, actualTy))
{
case ErrorSuppression::Suppress:
break;
case ErrorSuppression::NormalizationFailed:
break;
case ErrorSuppression::DoNotSuppress:
Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result);
if (!reasonings.suppressed)
context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location);
}
}
else
{
Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result);
context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location);
}
}
}
return true;
}
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
@ -697,7 +718,7 @@ static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes
return result;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
std::optional<WithPredicate<TypePackId>> MagicGmatch::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -733,7 +754,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
}
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
bool MagicGmatch::infer(const MagicFunctionCallContext& context)
{
const auto& [params, tail] = flatten(context.arguments);
@ -766,7 +787,7 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
std::optional<WithPredicate<TypePackId>> MagicMatch::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -806,7 +827,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
bool MagicMatch::infer(const MagicFunctionCallContext& context)
{
const auto& [params, tail] = flatten(context.arguments);
@ -842,7 +863,7 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
std::optional<WithPredicate<TypePackId>> MagicFind::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -900,7 +921,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionFind(MagicFunctionCallContext context)
bool MagicFind::infer(const MagicFunctionCallContext& context)
{
const auto& [params, tail] = flatten(context.arguments);
@ -977,11 +998,9 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
formatFTV.isCheckedFunction = true;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
attachDcrMagicFunctionTypeCheck(formatFn, dcrMagicFunctionTypeCheckFormat);
attachMagicFunction(formatFn, std::make_shared<MagicFormat>());
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
@ -995,16 +1014,14 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true);
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
attachMagicFunction(gmatchFunc, std::make_shared<MagicGmatch>());
FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})
};
matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
attachMagicFunction(matchFunc, std::make_shared<MagicMatch>());
FunctionType findFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
@ -1012,8 +1029,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
};
findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
attachMagicFunction(findFunc, std::make_shared<MagicFind>());
// string.byte : string -> number? -> number? -> ...number
FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList};
@ -1074,7 +1090,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
std::optional<WithPredicate<TypePackId>> MagicSelect::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1119,7 +1135,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
return std::nullopt;
}
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
bool MagicSelect::infer(const MagicFunctionCallContext& context)
{
if (context.callSite->args.size <= 0)
{
@ -1164,7 +1180,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
return false;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
std::optional<WithPredicate<TypePackId>> MagicSetMetatable::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1246,7 +1262,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
return WithPredicate<TypePackId>{arena.addTypePack({target})};
}
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
bool MagicSetMetatable::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicAssert::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1280,7 +1301,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
return WithPredicate<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})};
}
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
bool MagicAssert::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicPack::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1323,7 +1349,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
return WithPredicate<TypePackId>{arena.addTypePack({packedTable})};
}
static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
bool MagicPack::infer(const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
@ -1363,7 +1389,74 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
return true;
}
static std::optional<TypeId> freezeTable(TypeId inputType, MagicFunctionCallContext& context)
std::optional<WithPredicate<TypePackId>> MagicClone::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType3);
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
const auto& [paramTypes, paramTail] = flatten(paramPack);
if (paramTypes.empty() || expr.args.size == 0)
{
typechecker.reportError(expr.argLocation, CountMismatch{1, std::nullopt, 0});
return std::nullopt;
}
TypeId inputType = follow(paramTypes[0]);
if (!get<TableType>(inputType))
return std::nullopt;
CloneState cloneState{typechecker.builtinTypes};
TypeId resultType = shallowClone(inputType, arena, cloneState);
TypePackId clonedTypePack = arena.addTypePack({resultType});
return WithPredicate<TypePackId>{clonedTypePack};
}
bool MagicClone::infer(const MagicFunctionCallContext& context)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType3);
TypeArena* arena = context.solver->arena;
const auto& [paramTypes, paramTail] = flatten(context.arguments);
if (paramTypes.empty() || context.callSite->args.size == 0)
{
context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation);
return false;
}
TypeId inputType = follow(paramTypes[0]);
if (!get<TableType>(inputType))
return false;
CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent);
if (auto tableType = getMutable<TableType>(resultType))
{
tableType->scope = context.constraint->scope.get();
}
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(context.constraint->scope.get(), resultType);
TypePackId clonedTypePack = arena->addTypePack({resultType});
asMutable(context.result)->ty.emplace<BoundTypePack>(clonedTypePack);
return true;
}
static std::optional<TypeId> freezeTable(TypeId inputType, const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
@ -1383,7 +1476,7 @@ static std::optional<TypeId> freezeTable(TypeId inputType, MagicFunctionCallCont
{
// Clone the input type, this will become our final result type after we mutate it.
CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState);
TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent);
auto tableTy = getMutable<TableType>(resultType);
// `clone` should not break this.
LUAU_ASSERT(tableTy);
@ -1408,10 +1501,13 @@ static std::optional<TypeId> freezeTable(TypeId inputType, MagicFunctionCallCont
return std::nullopt;
}
static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context)
std::optional<WithPredicate<TypePackId>> MagicFreeze::handleOldSolver(struct TypeChecker &, const std::shared_ptr<struct Scope> &, const class AstExprCall &, WithPredicate<TypePackId>)
{
LUAU_ASSERT(FFlag::LuauTypestateBuiltins2);
return std::nullopt;
}
bool MagicFreeze::infer(const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
const DataFlowGraph* dfg = context.solver->dfg.get();
Scope* scope = context.constraint->scope.get();
@ -1469,7 +1565,7 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
return good;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
std::optional<WithPredicate<TypePackId>> MagicRequire::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1515,7 +1611,7 @@ static bool checkRequirePathDcr(NotNull<ConstraintSolver> solver, AstExpr* expr)
return good;
}
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context)
bool MagicRequire::infer(const MagicFunctionCallContext& context)
{
if (context.callSite->args.size != 1)
{

View file

@ -7,6 +7,7 @@
#include "Luau/Unifiable.h"
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauFreezeIgnorePersistent)
// For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit.
LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000)
@ -38,14 +39,26 @@ class TypeCloner
NotNull<SeenTypes> types;
NotNull<SeenTypePacks> packs;
TypeId forceTy = nullptr;
TypePackId forceTp = nullptr;
int steps = 0;
public:
TypeCloner(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<SeenTypes> types, NotNull<SeenTypePacks> packs)
TypeCloner(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<SeenTypes> types,
NotNull<SeenTypePacks> packs,
TypeId forceTy,
TypePackId forceTp
)
: arena(arena)
, builtinTypes(builtinTypes)
, types(types)
, packs(packs)
, forceTy(forceTy)
, forceTp(forceTp)
{
}
@ -112,7 +125,7 @@ private:
ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto it = types->find(ty); it != types->end())
return it->second;
else if (ty->persistent)
else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy))
return ty;
return std::nullopt;
}
@ -122,7 +135,7 @@ private:
tp = follow(tp);
if (auto it = packs->find(tp); it != packs->end())
return it->second;
else if (tp->persistent)
else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp))
return tp;
return std::nullopt;
}
@ -148,7 +161,7 @@ public:
if (auto clone = find(ty))
return *clone;
else if (ty->persistent)
else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy))
return ty;
TypeId target = arena->addType(ty->ty);
@ -174,7 +187,7 @@ public:
if (auto clone = find(tp))
return *clone;
else if (tp->persistent)
else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp))
return tp;
TypePackId target = arena->addTypePack(tp->ty);
@ -458,21 +471,37 @@ private:
} // namespace
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent)
{
if (tp->persistent)
if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent))
return tp;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
TypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
nullptr,
FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? tp : nullptr
};
return cloner.shallowClone(tp);
}
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent)
{
if (typeId->persistent)
if (typeId->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent))
return typeId;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
TypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? typeId : nullptr,
nullptr
};
return cloner.shallowClone(typeId);
}
@ -481,7 +510,7 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
if (tp->persistent)
return tp;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
return cloner.clone(tp);
}
@ -490,13 +519,13 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
if (typeId->persistent)
return typeId;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
return cloner.clone(typeId);
}
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
TypeFun copy = typeFun;
@ -521,4 +550,18 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
return copy;
}
Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState)
{
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
Binding b;
b.deprecated = binding.deprecated;
b.deprecatedSuggestion = binding.deprecatedSuggestion;
b.documentationSymbol = binding.documentationSymbol;
b.location = binding.location;
b.typeId = cloner.clone(binding.typeId);
return b;
}
} // namespace Luau

View file

@ -3,8 +3,6 @@
#include "Luau/Constraint.h"
#include "Luau/VisitType.h"
LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions)
namespace Luau
{
@ -60,9 +58,8 @@ struct ReferenceCountInitializer : TypeOnceVisitor
//
// The default behavior here is `true` for "visit the child types"
// of this type, hence:
return !FFlag::LuauDontRefCountTypesInTypeFunctions;
return false;
}
};
bool isReferenceCountedType(const TypeId typ)

View file

@ -31,13 +31,14 @@
LUAU_FASTINT(LuauCheckRecursionLimit)
LUAU_FASTFLAG(DebugLuauLogSolverToJson)
LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(DebugLuauEqSatSimplification)
LUAU_FASTFLAG(LuauTypestateBuiltins2)
LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal)
LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations)
LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTypesOnScope)
LUAU_FASTFLAGVARIABLE(InferGlobalTypes)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -229,8 +230,17 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
Checkpoint end = checkpoint(this);
TypeId result = arena->addType(BlockedType{});
NotNull<Constraint> genConstraint =
addConstraint(scope, block->location, GeneralizationConstraint{result, moduleFnTy, std::move(interiorTypes.back())});
NotNull<Constraint> genConstraint = addConstraint(
scope,
block->location,
GeneralizationConstraint{
result, moduleFnTy, FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector<TypeId>{} : std::move(interiorTypes.back())
}
);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
scope->interiorFreeTypes = std::move(interiorTypes.back());
getMutable<BlockedType>(result)->setOwner(genConstraint);
forEachConstraint(
start,
@ -299,9 +309,19 @@ void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStat
}
}
TypeId ConstraintGenerator::freshType(const ScopePtr& scope)
{
return Luau::freshType(arena, builtinTypes, scope.get());
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
{
auto ft = Luau::freshType(arena, builtinTypes, scope.get());
interiorTypes.back().push_back(ft);
return ft;
}
else
{
return Luau::freshType(arena, builtinTypes, scope.get());
}
}
TypePackId ConstraintGenerator::freshTypePack(const ScopePtr& scope)
@ -720,12 +740,6 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
continue;
}
if (!FFlag::LuauUserTypeFunExportedAndLocal && scope->parent != globalScope)
{
reportError(function->location, GenericError{"Local user-defined functions are not supported yet"});
continue;
}
ScopePtr defnScope = childScope(function, scope);
// Create TypeFunctionInstanceType
@ -751,11 +765,8 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
UserDefinedFunctionData udtfData;
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
udtfData.owner = module;
udtfData.definition = function;
}
udtfData.owner = module;
udtfData.definition = function;
TypeId typeFunctionTy = arena->addType(
TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, udtfData}
@ -764,7 +775,7 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy};
// Set type bindings and definition locations for this user-defined type function
if (FFlag::LuauUserTypeFunExportedAndLocal && function->exported)
if (function->exported)
scope->exportedTypeBindings[function->name.value] = std::move(typeFunction);
else
scope->privateTypeBindings[function->name.value] = std::move(typeFunction);
@ -799,49 +810,74 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
}
}
if (FFlag::LuauUserTypeFunExportedAndLocal)
// Additional pass for user-defined type functions to fill in their environments completely
for (AstStat* stat : block->body)
{
// Additional pass for user-defined type functions to fill in their environments completely
for (AstStat* stat : block->body)
if (auto function = stat->as<AstStatTypeFunction>())
{
if (auto function = stat->as<AstStatTypeFunction>())
// Find the type function we have already created
TypeFunctionInstanceType* mainTypeFun = nullptr;
if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end())
mainTypeFun = getMutable<TypeFunctionInstanceType>(it->second.type);
if (!mainTypeFun)
{
// Find the type function we have already created
TypeFunctionInstanceType* mainTypeFun = nullptr;
if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end())
if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end())
mainTypeFun = getMutable<TypeFunctionInstanceType>(it->second.type);
}
if (!mainTypeFun)
// Fill it with all visible type functions
if (mainTypeFun)
{
UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData;
size_t level = 0;
for (Scope* curr = scope.get(); curr; curr = curr->parent.get())
{
if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end())
mainTypeFun = getMutable<TypeFunctionInstanceType>(it->second.type);
}
// Fill it with all visible type functions
if (mainTypeFun)
{
UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData;
for (Scope* curr = scope.get(); curr; curr = curr->parent.get())
for (auto& [name, tf] : curr->privateTypeBindings)
{
for (auto& [name, tf] : curr->privateTypeBindings)
{
if (userFuncData.environment.find(name))
continue;
if (userFuncData.environment.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition;
}
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level);
}
for (auto& [name, tf] : curr->exportedTypeBindings)
{
if (userFuncData.environment.find(name))
continue;
for (auto& [name, tf] : curr->exportedTypeBindings)
{
if (userFuncData.environment.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition;
}
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level);
}
level++;
}
}
else if (mainTypeFun)
{
UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData;
for (Scope* curr = scope.get(); curr; curr = curr->parent.get())
{
for (auto& [name, tf] : curr->privateTypeBindings)
{
if (userFuncData.environment_DEPRECATED.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition;
}
for (auto& [name, tf] : curr->exportedTypeBindings)
{
if (userFuncData.environment_DEPRECATED.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition;
}
}
}
@ -1053,18 +1089,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true});
else if (const AstExprCall* call = value->as<AstExprCall>())
{
if (FFlag::LuauTypestateBuiltins2)
{
if (matchSetMetatable(*call))
addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true});
}
else
{
if (const AstExprGlobal* global = call->func->as<AstExprGlobal>(); global && global->name == "setmetatable")
{
addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true});
}
}
if (matchSetMetatable(*call))
addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true});
}
}
@ -1571,20 +1597,6 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias*
ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function)
{
// If a type function with the same name was already defined, we skip over
auto bindingIt = scope->privateTypeBindings.find(function->name.value);
if (bindingIt == scope->privateTypeBindings.end())
return ControlFlow::None;
TypeFun typeFunction = bindingIt->second;
// Adding typeAliasExpansionConstraint on user-defined type function for the constraint solver
if (auto typeFunctionTy = get<TypeFunctionInstanceType>(follow(typeFunction.type)))
{
TypeId expansionTy = arena->addType(PendingExpansionType{{}, function->name, typeFunctionTy->typeArguments, typeFunctionTy->packArguments});
addConstraint(scope, function->location, TypeAliasExpansionConstraint{/* target */ expansionTy});
}
return ControlFlow::None;
}
@ -2047,7 +2059,7 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}};
}
if (FFlag::LuauTypestateBuiltins2 && shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0]))
if (shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0]))
{
AstExpr* targetExpr = call->args.data[0];
auto resultTy = arena->addType(BlockedType{});
@ -2196,7 +2208,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantStrin
if (forceSingleton)
return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})};
FreeType ft = FreeType{scope.get()};
FreeType ft =
FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()};
ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}});
ft.upperBound = builtinTypes->stringType;
const TypeId freeTy = arena->addType(ft);
@ -2210,7 +2223,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool*
if (forceSingleton)
return Inference{singletonType};
FreeType ft = FreeType{scope.get()};
FreeType ft =
FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()};
ft.lowerBound = singletonType;
ft.upperBound = builtinTypes->booleanType;
const TypeId freeTy = arena->addType(ft);
@ -2372,8 +2386,17 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun
Checkpoint endCheckpoint = checkpoint(this);
TypeId generalizedTy = arena->addType(BlockedType{});
NotNull<Constraint> gc =
addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature, std::move(interiorTypes.back())});
NotNull<Constraint> gc = addConstraint(
sig.signatureScope,
func->location,
GeneralizationConstraint{
generalizedTy, sig.signature, FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector<TypeId>{} : std::move(interiorTypes.back())
}
);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
sig.signatureScope->interiorFreeTypes = std::move(interiorTypes.back());
getMutable<BlockedType>(generalizedTy)->setOwner(gc);
interiorTypes.pop_back();
@ -2721,15 +2744,12 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExpr* expr, Type
visitLValue(scope, e, rhsType);
else if (auto e = expr->as<AstExprError>())
{
if (FFlag::LuauNewSolverVisitErrorExprLvalues)
// If we end up with some sort of error expression in an lvalue
// position, at least go and check the expressions so that when
// we visit them later, there aren't any invalid assumptions.
for (auto subExpr : e->expressions)
{
// If we end up with some sort of error expression in an lvalue
// position, at least go and check the expressions so that when
// we visit them later, there aren't any invalid assumptions.
for (auto subExpr : e->expressions)
{
check(scope, subExpr);
}
check(scope, subExpr);
}
}
else
@ -2790,6 +2810,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob
DefId def = dfg->getDef(global);
rootScope->lvalueTypes[def] = rhsType;
if (FFlag::InferGlobalTypes)
{
// Sketchy: We're specifically looking for BlockedTypes that were
// initially created by ConstraintGenerator::prepopulateGlobalScope.
if (auto bt = get<BlockedType>(follow(*annotatedTy)); bt && !bt->getOwner())
emplaceType<BoundType>(asMutable(*annotatedTy), rhsType);
}
addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy});
}
}
@ -2931,11 +2959,11 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
ty,
expr,
toBlock
);
// The visitor we ran prior should ensure that there are no
// blocked types that we would encounter while matching on
// this expression.
LUAU_ASSERT(toBlock.empty());
);
// The visitor we ran prior should ensure that there are no
// blocked types that we would encounter while matching on
// this expression.
LUAU_ASSERT(toBlock.empty());
}
}
@ -3182,9 +3210,8 @@ TypeId ConstraintGenerator::resolveReferenceType(
if (alias.has_value())
{
// If the alias is not generic, we don't need to set up a blocked
// type and an instantiation constraint.
if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty())
// If the alias is not generic, we don't need to set up a blocked type and an instantiation constraint
if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty() && !ref->hasParameterList)
{
result = alias->type;
}
@ -3393,6 +3420,12 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool
}
else if (auto unionAnnotation = ty->as<AstTypeUnion>())
{
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (unionAnnotation->types.size == 1)
return resolveType(scope, unionAnnotation->types.data[0], inTypeArguments);
}
std::vector<TypeId> parts;
for (AstType* part : unionAnnotation->types)
{
@ -3403,6 +3436,12 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool
}
else if (auto intersectionAnnotation = ty->as<AstTypeIntersection>())
{
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (intersectionAnnotation->types.size == 1)
return resolveType(scope, intersectionAnnotation->types.data[0], inTypeArguments);
}
std::vector<TypeId> parts;
for (AstType* part : intersectionAnnotation->types)
{
@ -3411,6 +3450,10 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool
result = arena->addType(IntersectionType{parts});
}
else if (auto typeGroupAnnotation = ty->as<AstTypeGroup>())
{
result = resolveType(scope, typeGroupAnnotation->type, inTypeArguments);
}
else if (auto boolAnnotation = ty->as<AstTypeSingletonBool>())
{
if (boolAnnotation->value)
@ -3694,6 +3737,26 @@ struct GlobalPrepopulator : AstVisitor
return true;
}
bool visit(AstStatAssign* assign) override
{
if (FFlag::InferGlobalTypes)
{
for (const Luau::AstExpr* expr : assign->vars)
{
if (const AstExprGlobal* g = expr->as<AstExprGlobal>())
{
if (!globalScope->lookup(g->name))
globalScope->globalsToWarn.insert(g->name.value);
TypeId bt = arena->addType(BlockedType{});
globalScope->bindings[g->name] = Binding{bt, g->location};
}
}
}
return true;
}
bool visit(AstStatFunction* function) override
{
if (AstExprGlobal* g = function->name->as<AstExprGlobal>())
@ -3877,20 +3940,7 @@ TypeId ConstraintGenerator::createTypeFunctionInstance(
TypeId ConstraintGenerator::simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(UnionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId tyFun : res->newTypeFunctions)
addConstraint(scope, location, ReduceConstraint{tyFun});
return res->result;
}
else
return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result;
return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result;
}
std::vector<NotNull<Constraint>> borrowConstraints(const std::vector<ConstraintPtr>& constraints)

View file

@ -31,10 +31,12 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies)
LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings)
LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500)
LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification)
LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations)
LUAU_FASTFLAGVARIABLE(LuauAllowNilAssignmentToIndexer)
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope)
LUAU_FASTFLAGVARIABLE(LuauAlwaysFillInFunctionCallDiscriminantTypes)
LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTablesOnScope)
namespace Luau
{
@ -72,7 +74,7 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const
{
if (auto blocked = get<BlockedType>(ty))
{
Constraint* owner = blocked->getOwner();
const Constraint* owner = blocked->getOwner();
LUAU_ASSERT(owner);
return owner == constraint;
}
@ -443,7 +445,7 @@ void ConstraintSolver::run()
if (success)
{
unblock(c);
unsolvedConstraints.erase(unsolvedConstraints.begin() + i);
unsolvedConstraints.erase(unsolvedConstraints.begin() + ptrdiff_t(i));
// decrement the referenced free types for this constraint if we dispatched successfully!
for (auto ty : c->getMaybeMutatedFreeTypes())
@ -550,7 +552,7 @@ void ConstraintSolver::finalizeTypeFunctions()
}
}
bool ConstraintSolver::isDone()
bool ConstraintSolver::isDone() const
{
return unsolvedConstraints.empty();
}
@ -723,8 +725,20 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull<co
bind(constraint, c.generalizedType, builtinTypes->errorRecoveryType());
}
for (TypeId ty : c.interiorTypes)
generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
{
// We check if this member is initialized and then access it, but
// clang-tidy doesn't understand this is safe.
if (constraint->scope->interiorFreeTypes)
for (TypeId ty : *constraint->scope->interiorFreeTypes) // NOLINT(bugprone-unchecked-optional-access)
generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false);
}
else
{
for (TypeId ty : c.interiorTypes)
generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false);
}
return true;
}
@ -800,9 +814,17 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
{
TypeId keyTy = freshType(arena, builtinTypes, constraint->scope);
TypeId valueTy = freshType(arena, builtinTypes, constraint->scope);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
{
trackInteriorFreeType(constraint->scope, keyTy);
trackInteriorFreeType(constraint->scope, valueTy);
}
TypeId tableTy =
arena->addType(TableType{TableType::Props{}, TableIndexer{keyTy, valueTy}, TypeLevel{}, constraint->scope, TableState::Free});
if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope)
trackInteriorFreeType(constraint->scope, tableTy);
unify(constraint, nextTy, tableTy);
auto it = begin(c.variables);
@ -939,14 +961,6 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
if (auto typeFn = get<TypeFunctionInstanceType>(follow(tf->type)))
pushConstraint(NotNull(constraint->scope.get()), constraint->location, ReduceConstraint{tf->type});
// If there are no parameters to the type function we can just use the type
// directly.
if (tf->typeParams.empty() && tf->typePackParams.empty())
{
bindResult(tf->type);
return true;
}
// Due to how pending expansion types and TypeFun's are created
// If this check passes, we have created a cyclic / corecursive type alias
// of size 0
@ -959,6 +973,13 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
return true;
}
// If there are no parameters to the type function we can just use the type directly
if (tf->typeParams.empty() && tf->typePackParams.empty())
{
bindResult(tf->type);
return true;
}
auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments);
bool sameTypes = std::equal(
@ -1122,6 +1143,28 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
return true;
}
void ConstraintSolver::fillInDiscriminantTypes(
NotNull<const Constraint> constraint,
const std::vector<std::optional<TypeId>>& discriminantTypes
)
{
for (std::optional<TypeId> ty : discriminantTypes)
{
if (!ty)
continue;
// If the discriminant type has been transmuted, we need to unblock them.
if (!isBlocked(*ty))
{
unblock(*ty, constraint->location);
continue;
}
// We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored.
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->noRefineType);
}
}
bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint)
{
TypeId fn = follow(c.fn);
@ -1137,6 +1180,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
{
emplaceTypePack<BoundTypePack>(asMutable(c.result), builtinTypes->anyTypePack);
unblock(c.result, constraint->location);
if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes)
fillInDiscriminantTypes(constraint, c.discriminantTypes);
return true;
}
@ -1144,12 +1189,16 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
if (get<ErrorType>(fn))
{
bind(constraint, c.result, builtinTypes->errorRecoveryTypePack());
if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes)
fillInDiscriminantTypes(constraint, c.discriminantTypes);
return true;
}
if (get<NeverType>(fn))
{
bind(constraint, c.result, builtinTypes->neverTypePack);
if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes)
fillInDiscriminantTypes(constraint, c.discriminantTypes);
return true;
}
@ -1219,50 +1268,46 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
if (ftv)
{
if (ftv->dcrMagicFunction)
usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result});
if (ftv->dcrMagicRefinement)
ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes});
if (ftv->magic)
{
usedMagic = ftv->magic->infer(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result});
ftv->magic->refine(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes});
}
}
if (!usedMagic)
emplace<FreeTypePack>(constraint, c.result, constraint->scope);
}
for (std::optional<TypeId> ty : c.discriminantTypes)
if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes)
{
if (!ty)
continue;
// If the discriminant type has been transmuted, we need to unblock them.
if (!isBlocked(*ty))
fillInDiscriminantTypes(constraint, c.discriminantTypes);
}
else
{
// NOTE: This is the body of the `fillInDiscriminantTypes` helper.
for (std::optional<TypeId> ty : c.discriminantTypes)
{
unblock(*ty, constraint->location);
continue;
}
if (!ty)
continue;
// If the discriminant type has been transmuted, we need to unblock them.
if (!isBlocked(*ty))
{
unblock(*ty, constraint->location);
continue;
}
if (FFlag::LuauRemoveNotAnyHack)
{
// We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored.
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->noRefineType);
}
else
{
// We use `any` here because the discriminant type may be pointed at by both branches,
// where the discriminant type is not negated, and the other where it is negated, i.e.
// `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never`
// v.s.
// `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T`
//
// In practice, users cannot negate `any`, so this is an implementation detail we can always change.
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->anyType);
}
}
OverloadResolver resolver{
builtinTypes,
NotNull{arena},
simplifier,
normalizer,
typeFunctionRuntime,
constraint->scope,
@ -1618,7 +1663,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(
for (TypeId part : parts)
{
TypeId r = arena->addType(BlockedType{});
getMutable<BlockedType>(r)->setOwner(const_cast<Constraint*>(constraint.get()));
getMutable<BlockedType>(r)->setOwner(constraint.get());
bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen);
// If we've cut a recursive loop short, skip it.
@ -1650,7 +1695,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(
for (TypeId part : parts)
{
TypeId r = arena->addType(BlockedType{});
getMutable<BlockedType>(r)->setOwner(const_cast<Constraint*>(constraint.get()));
getMutable<BlockedType>(r)->setOwner(constraint.get());
bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen);
// If we've cut a recursive loop short, skip it.
@ -1770,6 +1815,10 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
else
{
TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope});
if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope)
trackInteriorFreeType(constraint->scope, newUpperBound);
TableType* upperTable = getMutable<TableType>(newUpperBound);
LUAU_ASSERT(upperTable);
@ -2048,6 +2097,8 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Cons
// constitute any meaningful constraint, so we replace it
// with a free type.
TypeId f = freshType(arena, builtinTypes, constraint->scope);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(constraint->scope, f);
shiftReferences(resultTy, f);
emplaceType<BoundType>(asMutable(resultTy), f);
}
@ -2103,6 +2154,11 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull<const Cons
if (force || reductionFinished)
{
for (auto& message : result.messages)
{
reportError(std::move(message));
}
// if we're completely dispatching this constraint, we want to record any uninhabited type functions to unblock.
for (auto error : result.errors)
{
@ -2178,6 +2234,11 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
{
TypeId keyTy = freshType(arena, builtinTypes, constraint->scope);
TypeId valueTy = freshType(arena, builtinTypes, constraint->scope);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
{
trackInteriorFreeType(constraint->scope, keyTy);
trackInteriorFreeType(constraint->scope, valueTy);
}
TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope});
getMutable<TableType>(tableTy)->indexer = TableIndexer{keyTy, valueTy};
@ -2434,6 +2495,8 @@ TablePropLookupResult ConstraintSolver::lookupTableProp(
if (ttv->state == TableState::Free)
{
TypeId result = freshType(arena, builtinTypes, ttv->scope);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(ttv->scope, result);
switch (context)
{
case ValueContext::RValue:
@ -2539,10 +2602,17 @@ TablePropLookupResult ConstraintSolver::lookupTableProp(
NotNull<Scope> scope{ft->scope};
const TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, scope});
if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope)
trackInteriorFreeType(constraint->scope, newUpperBound);
TableType* tt = getMutable<TableType>(newUpperBound);
LUAU_ASSERT(tt);
TypeId propType = freshType(arena, builtinTypes, scope);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(scope, propType);
switch (context)
{
case ValueContext::RValue:
@ -2773,10 +2843,10 @@ bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull<const Constrai
return !blocker.blocked;
}
bool ConstraintSolver::blockOnPendingTypes(TypePackId pack, NotNull<const Constraint> constraint)
bool ConstraintSolver::blockOnPendingTypes(TypePackId targetPack, NotNull<const Constraint> constraint)
{
Blocker blocker{NotNull{this}, constraint};
blocker.traverse(pack);
blocker.traverse(targetPack);
return !blocker.blocked;
}

View file

@ -13,7 +13,6 @@
LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauTypestateBuiltins2)
namespace Luau
{
@ -62,6 +61,12 @@ const RefinementKey* RefinementKeyArena::node(const RefinementKey* parent, DefId
return allocator.allocate(RefinementKey{parent, def, propName});
}
DataFlowGraph::DataFlowGraph(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena)
: defArena{defArena}
, keyArena{keyArena}
{
}
DefId DataFlowGraph::getDef(const AstExpr* expr) const
{
auto def = astDefs.find(expr);
@ -178,11 +183,23 @@ bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const
return true;
}
DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle)
DataFlowGraphBuilder::DataFlowGraphBuilder(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena)
: graph{defArena, keyArena}
, defArena{defArena}
, keyArena{keyArena}
{
}
DataFlowGraph DataFlowGraphBuilder::build(
AstStatBlock* block,
NotNull<DefArena> defArena,
NotNull<RefinementKeyArena> keyArena,
NotNull<struct InternalErrorReporter> handle
)
{
LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking");
DataFlowGraphBuilder builder;
DataFlowGraphBuilder builder(defArena, keyArena);
builder.handle = handle;
DfgScope* moduleScope = builder.makeChildScope();
PushScope ps{builder.scopeStack, moduleScope};
@ -198,30 +215,6 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalE
return std::move(builder.graph);
}
std::pair<std::shared_ptr<DataFlowGraph>, std::vector<std::unique_ptr<DfgScope>>> DataFlowGraphBuilder::buildShared(
AstStatBlock* block,
NotNull<InternalErrorReporter> handle
)
{
LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking");
DataFlowGraphBuilder builder;
builder.handle = handle;
DfgScope* moduleScope = builder.makeChildScope();
PushScope ps{builder.scopeStack, moduleScope};
builder.visitBlockWithoutChildScope(block);
builder.resolveCaptures();
if (FFlag::DebugLuauFreezeArena)
{
builder.defArena->allocator.freeze();
builder.keyArena->allocator.freeze();
}
return {std::make_shared<DataFlowGraph>(std::move(builder.graph)), std::move(builder.scopes)};
}
void DataFlowGraphBuilder::resolveCaptures()
{
for (const auto& [_, capture] : captures)
@ -885,7 +878,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c)
{
visitExpr(c->func);
if (FFlag::LuauTypestateBuiltins2 && shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin()))
if (shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin()))
{
AstExpr* firstArg = *c->args.begin();
@ -1176,6 +1169,8 @@ void DataFlowGraphBuilder::visitType(AstType* t)
return; // ok
else if (auto s = t->as<AstTypeSingletonString>())
return; // ok
else if (auto g = t->as<AstTypeGroup>())
return visitType(g->type);
else
handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType");
}

View file

@ -13,7 +13,6 @@
namespace Luau
{
std::string DiffPathNode::toString() const
{
switch (kind)
@ -945,14 +944,12 @@ std::vector<std::pair<TypeId, TypeId>>::const_reverse_iterator DifferEnvironment
return visitingStack.crend();
}
DifferResult diff(TypeId ty1, TypeId ty2)
{
DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt};
return diffUsingEnv(differEnv, ty1, ty2);
}
DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional<std::string> symbol1, std::optional<std::string> symbol2)
{
DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2};

View file

@ -1,15 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAG(LuauMathMap)
LUAU_FASTFLAGVARIABLE(LuauVectorDefinitions)
LUAU_FASTFLAGVARIABLE(LuauVectorDefinitionsExtra)
LUAU_FASTFLAG(LuauBufferBitMethods2)
LUAU_FASTFLAGVARIABLE(LuauMathMapDefinition)
LUAU_FASTFLAG(LuauVector2Constructor)
namespace Luau
{
// TODO: there has to be a better way, like splitting up per library
static const std::string kBuiltinDefinitionLuaSrcChecked_DEPRECATED = R"BUILTIN_SRC(
declare bit32: {
@ -30,227 +28,6 @@ declare bit32: {
byteswap: @checked (n: number) -> number,
}
declare math: {
frexp: @checked (n: number) -> (number, number),
ldexp: @checked (s: number, e: number) -> number,
fmod: @checked (x: number, y: number) -> number,
modf: @checked (n: number) -> (number, number),
pow: @checked (x: number, y: number) -> number,
exp: @checked (n: number) -> number,
ceil: @checked (n: number) -> number,
floor: @checked (n: number) -> number,
abs: @checked (n: number) -> number,
sqrt: @checked (n: number) -> number,
log: @checked (n: number, base: number?) -> number,
log10: @checked (n: number) -> number,
rad: @checked (n: number) -> number,
deg: @checked (n: number) -> number,
sin: @checked (n: number) -> number,
cos: @checked (n: number) -> number,
tan: @checked (n: number) -> number,
sinh: @checked (n: number) -> number,
cosh: @checked (n: number) -> number,
tanh: @checked (n: number) -> number,
atan: @checked (n: number) -> number,
acos: @checked (n: number) -> number,
asin: @checked (n: number) -> number,
atan2: @checked (y: number, x: number) -> number,
min: @checked (number, ...number) -> number,
max: @checked (number, ...number) -> number,
pi: number,
huge: number,
randomseed: @checked (seed: number) -> (),
random: @checked (number?, number?) -> number,
sign: @checked (n: number) -> number,
clamp: @checked (n: number, min: number, max: number) -> number,
noise: @checked (x: number, y: number?, z: number?) -> number,
round: @checked (n: number) -> number,
}
type DateTypeArg = {
year: number,
month: number,
day: number,
hour: number?,
min: number?,
sec: number?,
isdst: boolean?,
}
type DateTypeResult = {
year: number,
month: number,
wday: number,
yday: number,
day: number,
hour: number,
min: number,
sec: number,
isdst: boolean,
}
declare os: {
time: (time: DateTypeArg?) -> number,
date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string),
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
@checked declare function require(target: any): any
@checked declare function getfenv(target: any): { [string]: any }
declare _G: any
declare _VERSION: string
declare function gcinfo(): number
declare function print<T...>(...: T...)
declare function type<T>(value: T): string
declare function typeof<T>(value: T): string
-- `assert` has a magic function attached that will give more detailed type information
declare function assert<T>(value: T, errorMessage: string?): T
declare function error<T>(message: T, level: number?): never
declare function tostring<T>(value: T): string
declare function tonumber<T>(value: T, radix: number?): number?
declare function rawequal<T1, T2>(a: T1, b: T2): boolean
declare function rawget<K, V>(tab: {[K]: V}, k: K): V
declare function rawset<K, V>(tab: {[K]: V}, k: K, v: V): {[K]: V}
declare function rawlen<K, V>(obj: {[K]: V} | string): number
declare function setfenv<T..., R...>(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)?
declare function ipairs<V>(tab: {V}): (({V}, number) -> (number?, V), {V}, number)
declare function pcall<A..., R...>(f: (A...) -> R..., ...: A...): (boolean, R...)
-- FIXME: The actual type of `xpcall` is:
-- <E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...)
-- Since we can't represent the return value, we use (boolean, R1...).
declare function xpcall<E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...)
-- `select` has a magic function attached to provide more detailed type information
declare function select<A...>(i: string | number, ...: A...): ...any
-- FIXME: This type is not entirely correct - `loadstring` returns a function or
-- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)
@checked declare function newproxy(mt: boolean?): any
declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread,
status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended",
wrap: <A..., R...>(f: (A...) -> R...) -> ((A...) -> R...),
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: @checked (co: thread) -> (boolean, any)
}
declare table: {
concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
maxn: <V>(t: {V}) -> number,
remove: <V>(t: {V}, number?) -> V?,
sort: <V>(t: {V}, comp: ((V, V) -> boolean)?) -> (),
create: <V>(count: number, value: V?) -> {V},
find: <V>(haystack: {V}, needle: V, init: number?) -> number?,
unpack: <V>(list: {V}, i: number?, j: number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>(t: {V}) -> number,
foreach: <K, V>(t: {[K]: V}, f: (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>(src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V},
clear: <K, V>(table: {[K]: V}) -> (),
isfrozen: <K, V>(t: {[K]: V}) -> boolean,
}
declare debug: {
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
declare utf8: {
char: @checked (...number) -> string,
charpattern: string,
codes: @checked (str: string) -> ((string, number) -> (number, number), string, number),
codepoint: @checked (str: string, i: number?, j: number?) -> ...number,
len: @checked (s: string, i: number?, j: number?) -> (number?, number?),
offset: @checked (s: string, n: number?, i: number?) -> number,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
--- Buffer API
declare buffer: {
create: @checked (size: number) -> buffer,
fromstring: @checked (str: string) -> buffer,
tostring: @checked (b: buffer) -> string,
len: @checked (b: buffer) -> number,
copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (),
fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (),
readi8: @checked (b: buffer, offset: number) -> number,
readu8: @checked (b: buffer, offset: number) -> number,
readi16: @checked (b: buffer, offset: number) -> number,
readu16: @checked (b: buffer, offset: number) -> number,
readi32: @checked (b: buffer, offset: number) -> number,
readu32: @checked (b: buffer, offset: number) -> number,
readf32: @checked (b: buffer, offset: number) -> number,
readf64: @checked (b: buffer, offset: number) -> number,
writei8: @checked (b: buffer, offset: number, value: number) -> (),
writeu8: @checked (b: buffer, offset: number, value: number) -> (),
writei16: @checked (b: buffer, offset: number, value: number) -> (),
writeu16: @checked (b: buffer, offset: number, value: number) -> (),
writei32: @checked (b: buffer, offset: number, value: number) -> (),
writeu32: @checked (b: buffer, offset: number, value: number) -> (),
writef32: @checked (b: buffer, offset: number, value: number) -> (),
writef64: @checked (b: buffer, offset: number, value: number) -> (),
readstring: @checked (b: buffer, offset: number, count: number) -> string,
writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC(
declare bit32: {
band: @checked (...number) -> number,
bor: @checked (...number) -> number,
bxor: @checked (...number) -> number,
btest: @checked (number, ...number) -> boolean,
rrotate: @checked (x: number, disp: number) -> number,
lrotate: @checked (x: number, disp: number) -> number,
lshift: @checked (x: number, disp: number) -> number,
arshift: @checked (x: number, disp: number) -> number,
rshift: @checked (x: number, disp: number) -> number,
bnot: @checked (x: number) -> number,
extract: @checked (n: number, field: number, width: number?) -> number,
replace: @checked (n: number, v: number, field: number, width: number?) -> number,
countlz: @checked (n: number) -> number,
countrz: @checked (n: number) -> number,
byteswap: @checked (n: number) -> number,
}
declare math: {
frexp: @checked (n: number) -> (number, number),
ldexp: @checked (s: number, e: number) -> number,
@ -422,7 +199,231 @@ declare utf8: {
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBaseSrc = R"BUILTIN_SRC(
@checked declare function require(target: any): any
@checked declare function getfenv(target: any): { [string]: any }
declare _G: any
declare _VERSION: string
declare function gcinfo(): number
declare function print<T...>(...: T...)
declare function type<T>(value: T): string
declare function typeof<T>(value: T): string
-- `assert` has a magic function attached that will give more detailed type information
declare function assert<T>(value: T, errorMessage: string?): T
declare function error<T>(message: T, level: number?): never
declare function tostring<T>(value: T): string
declare function tonumber<T>(value: T, radix: number?): number?
declare function rawequal<T1, T2>(a: T1, b: T2): boolean
declare function rawget<K, V>(tab: {[K]: V}, k: K): V
declare function rawset<K, V>(tab: {[K]: V}, k: K, v: V): {[K]: V}
declare function rawlen<K, V>(obj: {[K]: V} | string): number
declare function setfenv<T..., R...>(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)?
declare function ipairs<V>(tab: {V}): (({V}, number) -> (number?, V), {V}, number)
declare function pcall<A..., R...>(f: (A...) -> R..., ...: A...): (boolean, R...)
-- FIXME: The actual type of `xpcall` is:
-- <E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...)
-- Since we can't represent the return value, we use (boolean, R1...).
declare function xpcall<E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...)
-- `select` has a magic function attached to provide more detailed type information
declare function select<A...>(i: string | number, ...: A...): ...any
-- FIXME: This type is not entirely correct - `loadstring` returns a function or
-- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)
@checked declare function newproxy(mt: boolean?): any
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBit32Src = R"BUILTIN_SRC(
declare bit32: {
band: @checked (...number) -> number,
bor: @checked (...number) -> number,
bxor: @checked (...number) -> number,
btest: @checked (number, ...number) -> boolean,
rrotate: @checked (x: number, disp: number) -> number,
lrotate: @checked (x: number, disp: number) -> number,
lshift: @checked (x: number, disp: number) -> number,
arshift: @checked (x: number, disp: number) -> number,
rshift: @checked (x: number, disp: number) -> number,
bnot: @checked (x: number) -> number,
extract: @checked (n: number, field: number, width: number?) -> number,
replace: @checked (n: number, v: number, field: number, width: number?) -> number,
countlz: @checked (n: number) -> number,
countrz: @checked (n: number) -> number,
byteswap: @checked (n: number) -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionMathSrc = R"BUILTIN_SRC(
declare math: {
frexp: @checked (n: number) -> (number, number),
ldexp: @checked (s: number, e: number) -> number,
fmod: @checked (x: number, y: number) -> number,
modf: @checked (n: number) -> (number, number),
pow: @checked (x: number, y: number) -> number,
exp: @checked (n: number) -> number,
ceil: @checked (n: number) -> number,
floor: @checked (n: number) -> number,
abs: @checked (n: number) -> number,
sqrt: @checked (n: number) -> number,
log: @checked (n: number, base: number?) -> number,
log10: @checked (n: number) -> number,
rad: @checked (n: number) -> number,
deg: @checked (n: number) -> number,
sin: @checked (n: number) -> number,
cos: @checked (n: number) -> number,
tan: @checked (n: number) -> number,
sinh: @checked (n: number) -> number,
cosh: @checked (n: number) -> number,
tanh: @checked (n: number) -> number,
atan: @checked (n: number) -> number,
acos: @checked (n: number) -> number,
asin: @checked (n: number) -> number,
atan2: @checked (y: number, x: number) -> number,
min: @checked (number, ...number) -> number,
max: @checked (number, ...number) -> number,
pi: number,
huge: number,
randomseed: @checked (seed: number) -> (),
random: @checked (number?, number?) -> number,
sign: @checked (n: number) -> number,
clamp: @checked (n: number, min: number, max: number) -> number,
noise: @checked (x: number, y: number?, z: number?) -> number,
round: @checked (n: number) -> number,
map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number,
lerp: @checked (a: number, b: number, t: number) -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionOsSrc = R"BUILTIN_SRC(
type DateTypeArg = {
year: number,
month: number,
day: number,
hour: number?,
min: number?,
sec: number?,
isdst: boolean?,
}
type DateTypeResult = {
year: number,
month: number,
wday: number,
yday: number,
day: number,
hour: number,
min: number,
sec: number,
isdst: boolean,
}
declare os: {
time: (time: DateTypeArg?) -> number,
date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string),
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionCoroutineSrc = R"BUILTIN_SRC(
declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread,
status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended",
wrap: <A..., R...>(f: (A...) -> R...) -> ((A...) -> R...),
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: @checked (co: thread) -> (boolean, any)
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionTableSrc = R"BUILTIN_SRC(
declare table: {
concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
maxn: <V>(t: {V}) -> number,
remove: <V>(t: {V}, number?) -> V?,
sort: <V>(t: {V}, comp: ((V, V) -> boolean)?) -> (),
create: <V>(count: number, value: V?) -> {V},
find: <V>(haystack: {V}, needle: V, init: number?) -> number?,
unpack: <V>(list: {V}, i: number?, j: number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>(t: {V}) -> number,
foreach: <K, V>(t: {[K]: V}, f: (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>(src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V},
clear: <K, V>(table: {[K]: V}) -> (),
isfrozen: <K, V>(t: {[K]: V}) -> boolean,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionDebugSrc = R"BUILTIN_SRC(
declare debug: {
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionUtf8Src = R"BUILTIN_SRC(
declare utf8: {
char: @checked (...number) -> string,
charpattern: string,
codes: @checked (str: string) -> ((string, number) -> (number, number), string, number),
codepoint: @checked (str: string, i: number?, j: number?) -> ...number,
len: @checked (s: string, i: number?, j: number?) -> (number?, number?),
offset: @checked (s: string, n: number?, i: number?) -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBufferSrc_DEPRECATED = R"BUILTIN_SRC(
--- Buffer API
declare buffer: {
create: @checked (size: number) -> buffer,
@ -453,10 +454,47 @@ declare buffer: {
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionVectorSrc_DEPRECATED = R"BUILTIN_SRC(
static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC(
--- Buffer API
declare buffer: {
create: @checked (size: number) -> buffer,
fromstring: @checked (str: string) -> buffer,
tostring: @checked (b: buffer) -> string,
len: @checked (b: buffer) -> number,
copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (),
fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (),
readi8: @checked (b: buffer, offset: number) -> number,
readu8: @checked (b: buffer, offset: number) -> number,
readi16: @checked (b: buffer, offset: number) -> number,
readu16: @checked (b: buffer, offset: number) -> number,
readi32: @checked (b: buffer, offset: number) -> number,
readu32: @checked (b: buffer, offset: number) -> number,
readf32: @checked (b: buffer, offset: number) -> number,
readf64: @checked (b: buffer, offset: number) -> number,
writei8: @checked (b: buffer, offset: number, value: number) -> (),
writeu8: @checked (b: buffer, offset: number, value: number) -> (),
writei16: @checked (b: buffer, offset: number, value: number) -> (),
writeu16: @checked (b: buffer, offset: number, value: number) -> (),
writei32: @checked (b: buffer, offset: number, value: number) -> (),
writeu32: @checked (b: buffer, offset: number, value: number) -> (),
writef32: @checked (b: buffer, offset: number, value: number) -> (),
writef64: @checked (b: buffer, offset: number, value: number) -> (),
readstring: @checked (b: buffer, offset: number, count: number) -> string,
writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (),
readbits: @checked (b: buffer, bitOffset: number, bitCount: number) -> number,
writebits: @checked (b: buffer, bitOffset: number, bitCount: number, value: number) -> (),
}
-- TODO: this will be replaced with a built-in primitive type
declare class vector end
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED = R"BUILTIN_SRC(
-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties
declare class vector
x: number
y: number
z: number
end
declare vector: {
create: @checked (x: number, y: number, z: number) -> vector,
@ -489,7 +527,7 @@ declare class vector
end
declare vector: {
create: @checked (x: number, y: number, z: number) -> vector,
create: @checked (x: number, y: number, z: number?) -> vector,
magnitude: @checked (vec: vector) -> number,
normalize: @checked (vec: vector) -> vector,
cross: @checked (vec1: vector, vec2: vector) -> vector,
@ -511,12 +549,25 @@ declare vector: {
std::string getBuiltinDefinitionSource()
{
std::string result = FFlag::LuauMathMap ? kBuiltinDefinitionLuaSrcChecked : kBuiltinDefinitionLuaSrcChecked_DEPRECATED;
std::string result = FFlag::LuauMathMapDefinition ? kBuiltinDefinitionBaseSrc : kBuiltinDefinitionLuaSrcChecked_DEPRECATED;
if (FFlag::LuauVectorDefinitionsExtra)
if (FFlag::LuauMathMapDefinition)
{
result += kBuiltinDefinitionBit32Src;
result += kBuiltinDefinitionMathSrc;
result += kBuiltinDefinitionOsSrc;
result += kBuiltinDefinitionCoroutineSrc;
result += kBuiltinDefinitionTableSrc;
result += kBuiltinDefinitionDebugSrc;
result += kBuiltinDefinitionUtf8Src;
}
result += FFlag::LuauBufferBitMethods2 ? kBuiltinDefinitionBufferSrc : kBuiltinDefinitionBufferSrc_DEPRECATED;
if (FFlag::LuauVector2Constructor)
result += kBuiltinDefinitionVectorSrc;
else if (FFlag::LuauVectorDefinitions)
result += kBuiltinDefinitionVectorSrc_DEPRECATED;
else
result += kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED;
return result;
}

View file

@ -92,18 +92,24 @@ size_t TTable::Hash::operator()(const TTable& value) const
return hash;
}
uint32_t StringCache::add(std::string_view s)
StringId StringCache::add(std::string_view s)
{
size_t hash = std::hash<std::string_view>()(s);
if (uint32_t* it = strings.find(hash))
/* Important subtlety: This use of DenseHashMap<std::string_view, StringId>
* is okay because std::hash<std::string_view> works solely on the bytes
* referred by the string_view.
*
* In other words, two string views which contain the same bytes will have
* the same hash whether or not their addresses are the same.
*/
if (StringId* it = strings.find(s))
return *it;
char* storage = static_cast<char*>(allocator.allocate(s.size()));
memcpy(storage, s.data(), s.size());
uint32_t result = uint32_t(views.size());
StringId result = StringId(views.size());
views.emplace_back(storage, s.size());
strings[hash] = result;
strings[s] = result;
return result;
}
@ -143,6 +149,61 @@ static bool isTerminal(const EType& node)
node.get<TNever>() || node.get<TNoRefine>();
}
static bool areTerminalAndDefinitelyDisjoint(const EType& lhs, const EType& rhs)
{
// If either node is non-terminal, then we early exit: we're not going to
// do a state space search for whether something like:
// (A | B | C | D) & (E | F | G | H)
// ... is a disjoint intersection.
if (!isTerminal(lhs) || !isTerminal(rhs))
return false;
// Special case some types that aren't strict, disjoint subsets.
if (lhs.get<TTopClass>() || lhs.get<TClass>())
return !(rhs.get<TTopClass>() || rhs.get<TClass>());
// Handling strings / booleans: these are the types for which we
// expect something like:
//
// "foo" & ~"bar"
//
// ... to simplify to "foo".
if (lhs.get<TString>())
return !(rhs.get<TString>() || rhs.get<SString>());
if (lhs.get<TBoolean>())
return !(rhs.get<TBoolean>() || rhs.get<SBoolean>());
if (auto lhsSString = lhs.get<SString>())
{
auto rhsSString = rhs.get<SString>();
if (!rhsSString)
return !rhs.get<TString>();
return lhsSString->value() != rhsSString->value();
}
if (auto lhsSBoolean = lhs.get<SBoolean>())
{
auto rhsSBoolean = rhs.get<SBoolean>();
if (!rhsSBoolean)
return !rhs.get<TBoolean>();
return lhsSBoolean->value() != rhsSBoolean->value();
}
// At this point:
// - We know both nodes are terminal
// - We know that the LHS is not any boolean, string, or class
// At this point, we have two classes of checks left:
// - Whether the two enodes are exactly the same set (now that the static
// sets have been covered).
// - Whether one of the enodes is a large semantic set such as TAny,
// TUnknown, or TError.
return !(
lhs.index() == rhs.index() || lhs.get<TUnknown>() || rhs.get<TUnknown>() || lhs.get<TAny>() || rhs.get<TAny>() || lhs.get<TNoRefine>() ||
rhs.get<TNoRefine>() || lhs.get<TError>() || rhs.get<TError>() || lhs.get<TOpaque>() || rhs.get<TOpaque>()
);
}
static bool isTerminal(const EGraph& egraph, Id eclass)
{
const auto& nodes = egraph[eclass].nodes;
@ -151,7 +212,7 @@ static bool isTerminal(const EGraph& egraph, Id eclass)
nodes.end(),
[](auto& a)
{
return isTerminal(a);
return isTerminal(a.node);
}
);
}
@ -335,11 +396,31 @@ Id toId(
{
LUAU_ASSERT(tfun->packArguments.empty());
if (tfun->userFuncName) {
// TODO: User defined type functions are pseudo-effectful: error
// reporting is done via the `print` statement, so running a
// UDTF multiple times may end up double erroring. egraphs
// currently may induce type functions to be reduced multiple
// times. We should probably opt _not_ to process user defined
// type functions at all.
return egraph.add(TOpaque{ty});
}
std::vector<Id> parts;
parts.reserve(tfun->typeArguments.size());
for (TypeId part : tfun->typeArguments)
parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part));
return cache(egraph.add(TTypeFun{tfun->function.get(), std::move(parts)}));
// This looks sily, but we're making a copy of the specific
// `TypeFunctionInstanceType` outside of the provided arena so that
// we can access the members without fear of the specific TFIT being
// overwritten with a bound type.
return cache(egraph.add(TTypeFun{
std::make_shared<const TypeFunctionInstanceType>(
tfun->function, tfun->typeArguments, tfun->packArguments, tfun->userFuncName, tfun->userFuncData
),
std::move(parts)
}));
}
else if (get<NoRefineType>(ty))
return egraph.add(TNoRefine{});
@ -399,7 +480,7 @@ static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGrap
if (auto it = costs.find(id); it != costs.end())
return it->second;
const std::vector<EType>& nodes = egraph[id].nodes;
const std::vector<Node<EType>>& nodes = egraph[id].nodes;
size_t minCost = std::numeric_limits<size_t>::max();
size_t bestNode = std::numeric_limits<size_t>::max();
@ -416,7 +497,7 @@ static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGrap
// First, quickly scan for a terminal type. If we can find one, it is obviously the best.
for (size_t index = 0; index < nodes.size(); ++index)
{
if (isTerminal(nodes[index]))
if (isTerminal(nodes[index].node))
{
minCost = 1;
bestNode = index;
@ -468,44 +549,44 @@ static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGrap
{
const auto& node = nodes[index];
if (node.get<TBound>())
if (node.node.get<TBound>())
updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound.
else if (node.get<TFunction>())
else if (node.node.get<TFunction>())
{
minCost = 1;
bestNode = index;
}
else if (auto tbl = node.get<TTable>())
else if (auto tbl = node.node.get<TTable>())
{
// TODO: We could make the penalty a parameter to computeChildren.
std::optional<size_t> maybeCost = computeChildren(tbl->operands(), minCost);
if (maybeCost)
updateCost(TABLE_TYPE_PENALTY + *maybeCost, index);
}
else if (node.get<TImportedTable>())
else if (node.node.get<TImportedTable>())
{
minCost = IMPORTED_TABLE_PENALTY;
bestNode = index;
}
else if (auto u = node.get<Union>())
else if (auto u = node.node.get<Union>())
{
std::optional<size_t> maybeCost = computeChildren(u->operands(), minCost);
if (maybeCost)
updateCost(SET_TYPE_PENALTY + *maybeCost, index);
}
else if (auto i = node.get<Intersection>())
else if (auto i = node.node.get<Intersection>())
{
std::optional<size_t> maybeCost = computeChildren(i->operands(), minCost);
if (maybeCost)
updateCost(SET_TYPE_PENALTY + *maybeCost, index);
}
else if (auto negation = node.get<Negation>())
else if (auto negation = node.node.get<Negation>())
{
std::optional<size_t> maybeCost = computeChildren(negation->operands(), minCost);
if (maybeCost)
updateCost(NEGATION_PENALTY + *maybeCost, index);
}
else if (auto tfun = node.get<TTypeFun>())
else if (auto tfun = node.node.get<TTypeFun>())
{
std::optional<size_t> maybeCost = computeChildren(tfun->operands(), minCost);
if (maybeCost)
@ -574,28 +655,34 @@ TypeId flattenTableNode(
// If a TTable is its own basis, it must be the case that some other
// node on this eclass is a TImportedTable. Let's use that.
bool found = false;
for (size_t i = 0; i < eclass.nodes.size(); ++i)
{
if (eclass.nodes[i].get<TImportedTable>())
if (eclass.nodes[i].node.get<TImportedTable>())
{
found = true;
index = i;
break;
}
}
// If we couldn't find one, we don't know what to do. Use ErrorType.
LUAU_ASSERT(0);
return builtinTypes->errorType;
if (!found)
{
// If we couldn't find one, we don't know what to do. Use ErrorType.
LUAU_ASSERT(0);
return builtinTypes->errorType;
}
}
const auto& node = eclass.nodes[index];
if (const TTable* ttable = node.get<TTable>())
if (const TTable* ttable = node.node.get<TTable>())
{
stack.push_back(ttable);
id = ttable->getBasis();
continue;
}
else if (const TImportedTable* ti = node.get<TImportedTable>())
else if (const TImportedTable* ti = node.node.get<TImportedTable>())
{
importedTable = ti;
break;
@ -622,7 +709,8 @@ TypeId flattenTableNode(
StringId propName = t->propNames[i];
const Id propType = t->propTypes()[i];
resultTable.props[strings.asString(propName)] = Property{fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, propType)};
resultTable.props[strings.asString(propName)] =
Property{fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, propType)};
}
}
@ -646,7 +734,7 @@ TypeId fromId(
size_t index = bestNodes.at(rootId);
LUAU_ASSERT(index <= egraph[rootId].nodes.size());
const EType& node = egraph[rootId].nodes[index];
const EType& node = egraph[rootId].nodes[index].node;
if (node.get<TNil>())
return builtinTypes->nilType;
@ -703,7 +791,20 @@ TypeId fromId(
if (parts.empty())
return builtinTypes->neverType;
else if (parts.size() == 1)
return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]);
{
TypeId placeholder = arena->addType(BlockedType{});
seen[rootId] = placeholder;
auto result = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]);
if (follow(result) == placeholder)
{
emplaceType<GenericType>(asMutable(placeholder), "EGRAPH-SINGLETON-CYCLE");
}
else
{
emplaceType<BoundType>(asMutable(placeholder), result);
}
return result;
}
else
{
TypeId res = arena->addType(BlockedType{});
@ -768,7 +869,11 @@ TypeId fromId(
for (Id part : tfun->operands())
args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part));
asMutable(res)->ty.emplace<TypeFunctionInstanceType>(*tfun->value(), std::move(args));
auto oldInstance = tfun->value();
asMutable(res)->ty.emplace<TypeFunctionInstanceType>(
oldInstance->function, std::move(args), std::vector<TypePackId>(), oldInstance->userFuncName, oldInstance->userFuncData
);
newTypeFunctions.push_back(res);
@ -848,12 +953,20 @@ std::string mkDesc(
const int RULE_PADDING = 35;
const std::string rulePadding(std::max<size_t>(0, RULE_PADDING - rule.size()), ' ');
const std::string fromIdStr = ""; // "(" + std::to_string(uint32_t(from)) + ") ";
const std::string toIdStr = ""; // "(" + std::to_string(uint32_t(to)) + ") ";
const std::string toIdStr = ""; // "(" + std::to_string(uint32_t(to)) + ") ";
return rule + ":" + rulePadding + fromIdStr + toString(fromTy, opts) + " <=> " + toIdStr + toString(toTy, opts);
}
std::string mkDesc(EGraph& egraph, const StringCache& strings, NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, Id from, Id to, const std::string& rule)
std::string mkDesc(
EGraph& egraph,
const StringCache& strings,
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
Id from,
Id to,
const std::string& rule
)
{
if (!FFlag::DebugLuauLogSimplification)
return "";
@ -906,7 +1019,7 @@ static std::string getNodeName(const StringCache& strings, const EType& node)
else if (node.get<TNever>())
return "never";
else if (auto tfun = node.get<TTypeFun>())
return "tfun " + tfun->value()->name;
return "tfun " + tfun->value()->function->name;
else if (node.get<Negation>())
return "~";
else if (node.get<Invalid>())
@ -928,8 +1041,9 @@ std::string toDot(const StringCache& strings, const EGraph& egraph)
for (const auto& [id, eclass] : egraph.getAllClasses())
{
for (const auto& node : eclass.nodes)
for (const auto& n : eclass.nodes)
{
const EType& node = n.node;
if (!node.operands().empty())
populated.insert(id);
for (Id op : node.operands())
@ -950,7 +1064,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph)
for (size_t index = 0; index < eclass.nodes.size(); ++index)
{
const auto& node = eclass.nodes[index];
const auto& node = eclass.nodes[index].node;
const std::string label = getNodeName(strings, node);
const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index);
@ -965,7 +1079,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph)
{
for (size_t index = 0; index < eclass.nodes.size(); ++index)
{
const auto& node = eclass.nodes[index];
const auto& node = eclass.nodes[index].node;
const std::string label = getNodeName(strings, node);
const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index);
@ -1001,7 +1115,7 @@ static Tag const* isTag(const EGraph& egraph, Id id)
{
for (const auto& node : egraph[id].nodes)
{
if (auto n = isTag<Tag>(node))
if (auto n = isTag<Tag>(node.node))
return n;
}
return nullptr;
@ -1037,7 +1151,7 @@ protected:
{
for (const auto& node : (*egraph)[id].nodes)
{
if (auto n = node.get<Tag>())
if (auto n = node.node.get<Tag>())
return n;
}
return nullptr;
@ -1225,8 +1339,10 @@ const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set<Id>&
const EType* bestUnion = nullptr;
std::optional<size_t> unionSize;
for (const auto& node : egraph[id].nodes)
for (const auto& n : egraph[id].nodes)
{
const EType& node = n.node;
if (isTerminal(node))
return &node;
@ -1342,14 +1458,14 @@ bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part)
return true;
}
Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct)
static std::pair<Id, size_t> fromCanonicalized(EGraph& egraph, CanonicalizedType& ct)
{
if (ct.isUnknown())
{
if (ct.errorPart)
return egraph.add(TAny{});
return {egraph.add(TAny{}), 1};
else
return egraph.add(TUnknown{});
return {egraph.add(TUnknown{}), 1};
}
std::vector<Id> parts;
@ -1387,7 +1503,12 @@ Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct)
parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end());
parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end());
return mkUnion(egraph, std::move(parts));
std::sort(parts.begin(), parts.end());
auto it = std::unique(parts.begin(), parts.end());
parts.erase(it, parts.end());
const size_t size = parts.size();
return {mkUnion(egraph, std::move(parts)), size};
}
void addChildren(const EGraph& egraph, const EType* enode, VecDeque<Id>& worklist)
@ -1433,7 +1554,7 @@ const Tag* Simplifier::isTag(Id id) const
{
for (const auto& node : get(id).nodes)
{
if (const Tag* ty = node.get<Tag>())
if (const Tag* ty = node.node.get<Tag>())
return ty;
}
@ -1467,6 +1588,16 @@ void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::u
substs.emplace_back(from, to, desc);
}
void Simplifier::subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes)
{
std::string desc;
if (FFlag::DebugLuauLogSimplification)
desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName);
egraph.markBoring(from, boringIndex);
substs.emplace_back(from, to, desc);
}
void Simplifier::unionClasses(std::vector<Id>& hereParts, Id there)
{
if (1 == hereParts.size() && isTag<TTopClass>(hereParts[0]))
@ -1517,9 +1648,12 @@ void Simplifier::simplifyUnion(Id id)
for (Id part : u->operands())
unionWithType(egraph, canonicalized, find(part));
Id resultId = fromCanonicalized(egraph, canonicalized);
const auto [resultId, newSize] = fromCanonicalized(egraph, canonicalized);
subst(id, resultId, "simplifyUnion", {{id, unionIndex}});
if (newSize < u->operands().size())
subst(id, unionIndex, resultId, "simplifyUnion", {{id, unionIndex}});
else
subst(id, resultId, "simplifyUnion", {{id, unionIndex}});
}
}
@ -1552,11 +1686,6 @@ std::optional<EType> intersectOne(EGraph& egraph, Id hereId, const EType* hereNo
thereNode->get<Intersection>() || thereNode->get<Negation>() || hereNode->get<TOpaque>() || thereNode->get<TOpaque>())
return std::nullopt;
if (hereNode->get<TAny>())
return *thereNode;
if (thereNode->get<TAny>())
return *hereNode;
if (hereNode->get<TUnknown>())
return *thereNode;
if (thereNode->get<TUnknown>())
@ -1732,7 +1861,7 @@ void Simplifier::uninhabitedIntersection(Id id)
const auto& partNodes = egraph[partId].nodes;
for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex)
{
const EType& N = partNodes[partIndex];
const EType& N = partNodes[partIndex].node;
if (std::optional<EType> intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N))
{
if (isTag<TNever>(*intersection))
@ -1755,9 +1884,14 @@ void Simplifier::uninhabitedIntersection(Id id)
if ((unsimplified.empty() || !isTag<TUnknown>(accumulator)) && find(accumulator) != id)
unsimplified.push_back(accumulator);
const bool isSmaller = unsimplified.size() < parts.size();
const Id result = mkIntersection(egraph, std::move(unsimplified));
subst(id, result, "uninhabitedIntersection", {{id, index}});
if (isSmaller)
subst(id, index, result, "uninhabitedIntersection", {{id, index}});
else
subst(id, result, "uninhabitedIntersection", {{id, index}});
}
}
@ -1788,14 +1922,19 @@ void Simplifier::intersectWithNegatedClass(Id id)
const auto& iNodes = egraph[iId].nodes;
for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex)
{
const EType& iNode = iNodes[iIndex];
const EType& iNode = iNodes[iIndex].node;
if (isTag<TNil>(iNode) || isTag<TBoolean>(iNode) || isTag<TNumber>(iNode) || isTag<TString>(iNode) || isTag<TThread>(iNode) ||
isTag<TTopFunction>(iNode) ||
// isTag<TTopTable>(iNode) || // I'm not sure about this one.
isTag<SBoolean>(iNode) || isTag<SString>(iNode) || isTag<TFunction>(iNode) || isTag<TNever>(iNode))
{
// eg string & ~SomeClass
subst(id, iId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}});
subst(
id,
iId,
"intersectClassWithNegatedClass",
{{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}
);
return;
}
@ -1803,27 +1942,37 @@ void Simplifier::intersectWithNegatedClass(Id id)
{
switch (relateClasses(class_, negatedClass))
{
case LeftSuper:
// eg Instance & ~Part
// This cannot be meaningfully reduced.
continue;
case RightSuper:
subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}});
return;
case Unrelated:
// Part & ~Folder == Part
case LeftSuper:
// eg Instance & ~Part
// This cannot be meaningfully reduced.
continue;
case RightSuper:
subst(
id,
egraph.add(TNever{}),
"intersectClassWithNegatedClass",
{{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}
);
return;
case Unrelated:
// Part & ~Folder == Part
{
std::vector<Id> newParts;
newParts.reserve(intersection->operands().size() - 1);
for (Id part : intersection->operands())
{
std::vector<Id> newParts;
newParts.reserve(intersection->operands().size() - 1);
for (Id part : intersection->operands())
{
if (part != jId)
newParts.push_back(part);
}
Id substId = egraph.add(Intersection{newParts.begin(), newParts.end()});
subst(id, substId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}});
if (part != jId)
newParts.push_back(part);
}
Id substId = mkIntersection(egraph, newParts);
subst(
id,
substId,
"intersectClassWithNegatedClass",
{{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}
);
}
}
}
}
@ -1839,6 +1988,74 @@ void Simplifier::intersectWithNegatedClass(Id id)
}
}
void Simplifier::intersectWithNegatedAtom(Id id)
{
// Let I and ~J be two arbitrary distinct operands of an intersection where
// I and J are terminal but are not type variables. (free, generic, or
// otherwise opaque)
//
// If I and J are equal, then the whole intersection is equivalent to never.
//
// If I and J are inequal, then J & ~I == J
for (const auto [intersection, intersectionIndex] : Query<Intersection>(&egraph, id))
{
const Slice<const Id>& intersectionOperands = intersection->operands();
for (size_t i = 0; i < intersectionOperands.size(); ++i)
{
for (const auto [negation, negationIndex] : Query<Negation>(&egraph, intersectionOperands[i]))
{
for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex)
{
const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex].node;
if (!isTerminal(*negationOperand) || negationOperand->get<TOpaque>())
continue;
for (size_t j = 0; j < intersectionOperands.size(); ++j)
{
if (j == i)
continue;
for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex)
{
const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex].node;
if (!isTerminal(*jNode) || jNode->get<TOpaque>())
continue;
if (*negationOperand == *jNode)
{
// eg "Hello" & ~"Hello"
// or boolean & ~boolean
subst(
id,
egraph.add(TNever{}),
"intersectWithNegatedAtom",
{{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}}
);
return;
}
else if (areTerminalAndDefinitelyDisjoint(*jNode, *negationOperand))
{
// eg "Hello" & ~"World"
// or boolean & ~string
std::vector<Id> newOperands(intersectionOperands.begin(), intersectionOperands.end());
newOperands.erase(newOperands.begin() + std::vector<Id>::difference_type(i));
subst(
id,
mkIntersection(egraph, std::move(newOperands)),
"intersectWithNegatedAtom",
{{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}}
);
}
}
}
}
}
}
}
}
void Simplifier::intersectWithNoRefine(Id id)
{
for (const auto pair : Query<Intersection>(&egraph, id))
@ -2003,7 +2220,7 @@ void Simplifier::expandNegation(Id id)
if (!ok)
continue;
subst(id, fromCanonicalized(egraph, canonicalized), "expandNegation", {{id, index}});
subst(id, fromCanonicalized(egraph, canonicalized).first, "expandNegation", {{id, index}});
}
}
@ -2160,7 +2377,7 @@ void Simplifier::intersectTableProperty(Id id)
subst(
id,
egraph.add(Intersection{std::move(newIntersectionParts)}),
mkIntersection(egraph, std::move(newIntersectionParts)),
"intersectTableProperty",
{{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}}
);
@ -2250,7 +2467,7 @@ void Simplifier::builtinTypeFunctions(Id id)
if (args.size() != 2)
continue;
const std::string& name = tfun->value()->name;
const std::string& name = tfun->value()->function->name;
if (name == "add" || name == "sub" || name == "mul" || name == "div" || name == "idiv" || name == "pow" || name == "mod")
{
if (isTag<TNumber>(args[0]) && isTag<TNumber>(args[1]))
@ -2272,15 +2489,43 @@ void Simplifier::iffyTypeFunctions(Id id)
{
const Slice<const Id>& args = tfun->operands();
const std::string& name = tfun->value()->name;
const std::string& name = tfun->value()->function->name;
if (name == "union")
subst(id, add(Union{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}});
else if (name == "intersect" || name == "refine")
else if (name == "intersect")
subst(id, add(Intersection{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}});
}
}
// Replace instances of `lt<X, Y>` and `le<X, Y>` when either X or Y is `number`
// or `string` with `boolean`. Lua semantics are that if we see the expression:
//
// x < y
//
// ... we error if `x` and `y` don't have the same type. We know that for
// `string` and `number`, comparisons will always return a boolean. So if either
// of the arguments to `lt<>` are equivalent to `number` or `string`, then the
// type is effectively `boolean`: either the other type is equivalent, in which
// case we eval to `boolean`, or we diverge (raise an error).
void Simplifier::strictMetamethods(Id id)
{
for (const auto [tfun, index] : Query<TTypeFun>(&egraph, id))
{
const Slice<const Id>& args = tfun->operands();
const std::string& name = tfun->value()->function->name;
if (!(name == "lt" || name == "le") || args.size() != 2)
continue;
if (isTag<TNumber>(args[0]) || isTag<TString>(args[0]) || isTag<TNumber>(args[1]) || isTag<TString>(args[1]))
{
subst(id, add(TBoolean{}), __FUNCTION__, {{id, index}});
}
}
}
static void deleteSimplifier(Simplifier* s)
{
delete s;
@ -2308,6 +2553,7 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
&Simplifier::simplifyUnion,
&Simplifier::uninhabitedIntersection,
&Simplifier::intersectWithNegatedClass,
&Simplifier::intersectWithNegatedAtom,
&Simplifier::intersectWithNoRefine,
&Simplifier::cyclicIntersectionOfUnion,
&Simplifier::cyclicUnionOfIntersection,
@ -2318,6 +2564,7 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
&Simplifier::unneededTableModification,
&Simplifier::builtinTypeFunctions,
&Simplifier::iffyTypeFunctions,
&Simplifier::strictMetamethods,
};
std::unordered_set<Id> seen;
@ -2371,9 +2618,9 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
// try to run any rules on it.
bool shouldAbort = false;
for (const EType& enode : egraph[id].nodes)
for (const auto& enode : egraph[id].nodes)
{
if (isTerminal(enode))
if (isTerminal(enode.node))
{
shouldAbort = true;
break;
@ -2383,8 +2630,8 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
if (shouldAbort)
continue;
for (const EType& enode : egraph[id].nodes)
addChildren(egraph, &enode, worklist);
for (const auto& enode : egraph[id].nodes)
addChildren(egraph, &enode.node, worklist);
for (Simplifier::RewriteRuleFn rule : rules)
(simplifier.get()->*rule)(id);

View file

@ -6,6 +6,7 @@
#include "Luau/Autocomplete.h"
#include "Luau/Common.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
@ -19,16 +20,21 @@
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
#include "Luau/Clone.h"
#include "AutocompleteCore.h"
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauAllowFragmentParsing);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteBugfixes)
LUAU_FASTFLAG(LuauReferenceAllocatorInNewSolver)
LUAU_FASTFLAGVARIABLE(LuauMixedModeDefFinderTraversesTypeOf)
LUAU_FASTFLAG(LuauBetterReverseDependencyTracking)
LUAU_FASTFLAGVARIABLE(LuauCloneIncrementalModule)
namespace
{
template<typename T>
@ -49,6 +55,96 @@ void copyModuleMap(Luau::DenseHashMap<K, V>& result, const Luau::DenseHashMap<K,
namespace Luau
{
template<typename K, typename V>
void cloneModuleMap(TypeArena& destArena, CloneState& cloneState, const Luau::DenseHashMap<K, V>& source, Luau::DenseHashMap<K, V>& dest)
{
for (auto [k, v] : source)
{
dest[k] = Luau::clone(v, destArena, cloneState);
}
}
struct MixedModeIncrementalTCDefFinder : public AstVisitor
{
bool visit(AstExprLocal* local) override
{
referencedLocalDefs.emplace_back(local->local, local);
return true;
}
bool visit(AstTypeTypeof* node) override
{
// We need to traverse typeof expressions because they may refer to locals that we need
// to populate the local environment for fragment typechecking. For example, `typeof(m)`
// requires that we find the local/global `m` and place it in the environment.
// The default behaviour here is to return false, and have individual visitors override
// the specific behaviour they need.
return FFlag::LuauMixedModeDefFinderTraversesTypeOf;
}
// ast defs is just a mapping from expr -> def in general
// will get built up by the dfg builder
// localDefs, we need to copy over
std::vector<std::pair<AstLocal*, AstExpr*>> referencedLocalDefs;
};
void cloneAndSquashScopes(
CloneState& cloneState,
const Scope* staleScope,
const ModulePtr& staleModule,
NotNull<TypeArena> destArena,
NotNull<DataFlowGraph> dfg,
AstStatBlock* program,
Scope* destScope
)
{
std::vector<const Scope*> scopes;
for (const Scope* current = staleScope; current; current = current->parent.get())
{
scopes.emplace_back(current);
}
// in reverse order (we need to clone the parents and override defs as we go down the list)
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
const Scope* curr = *it;
// Clone the lvalue types
for (const auto& [def, ty] : curr->lvalueTypes)
destScope->lvalueTypes[def] = Luau::clone(ty, *destArena, cloneState);
// Clone the rvalueRefinements
for (const auto& [def, ty] : curr->rvalueRefinements)
destScope->rvalueRefinements[def] = Luau::clone(ty, *destArena, cloneState);
for (const auto& [n, m] : curr->importedTypeBindings)
{
std::unordered_map<Name, TypeFun> importedBindingTypes;
for (const auto& [v, tf] : m)
importedBindingTypes[v] = Luau::clone(tf, *destArena, cloneState);
destScope->importedTypeBindings[n] = m;
}
// Finally, clone up the bindings
for (const auto& [s, b] : curr->bindings)
{
destScope->bindings[s] = Luau::clone(b, *destArena, cloneState);
}
}
// The above code associates defs with TypeId's in the scope
// so that lookup to locals will succeed.
MixedModeIncrementalTCDefFinder finder;
program->visit(&finder);
std::vector<std::pair<AstLocal*, AstExpr*>> locals = std::move(finder.referencedLocalDefs);
for (auto [loc, expr] : locals)
{
if (std::optional<Binding> binding = staleScope->linearSearchForBinding(loc->name.value, true))
{
destScope->lvalueTypes[dfg->getDef(expr)] = Luau::clone(binding->typeId, *destArena, cloneState);
}
}
return;
}
static FrontendModuleResolver& getModuleResolver(Frontend& frontend, std::optional<FrontendOptions> options)
{
if (FFlag::LuauSolverV2 || !options)
@ -200,7 +296,7 @@ ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStateme
return closest;
}
FragmentParseResult parseFragment(
std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,
@ -245,6 +341,9 @@ FragmentParseResult parseFragment(
opts.captureComments = true;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos};
ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts);
// This means we threw a ParseError and we should decline to offer autocomplete here.
if (p.root == nullptr)
return std::nullopt;
std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
@ -258,16 +357,39 @@ FragmentParseResult parseFragment(
fragmentResult.root = std::move(p.root);
fragmentResult.ancestry = std::move(fabricatedAncestry);
fragmentResult.nearestStatement = nearestStatement;
fragmentResult.commentLocations = std::move(p.commentLocations);
return fragmentResult;
}
ModulePtr cloneModule(CloneState& cloneState, const ModulePtr& source, std::unique_ptr<Allocator> alloc)
{
freeze(source->internalTypes);
freeze(source->interfaceTypes);
ModulePtr incremental = std::make_shared<Module>();
incremental->name = source->name;
incremental->humanReadableName = source->humanReadableName;
incremental->allocator = std::move(alloc);
// Clone types
cloneModuleMap(incremental->internalTypes, cloneState, source->astTypes, incremental->astTypes);
cloneModuleMap(incremental->internalTypes, cloneState, source->astTypePacks, incremental->astTypePacks);
cloneModuleMap(incremental->internalTypes, cloneState, source->astExpectedTypes, incremental->astExpectedTypes);
cloneModuleMap(incremental->internalTypes, cloneState, source->astOverloadResolvedTypes, incremental->astOverloadResolvedTypes);
cloneModuleMap(incremental->internalTypes, cloneState, source->astForInNextTypes, incremental->astForInNextTypes);
copyModuleMap(incremental->astScopes, source->astScopes);
return incremental;
}
ModulePtr copyModule(const ModulePtr& result, std::unique_ptr<Allocator> alloc)
{
freeze(result->internalTypes);
freeze(result->interfaceTypes);
ModulePtr incrementalModule = std::make_shared<Module>();
incrementalModule->name = result->name;
incrementalModule->humanReadableName = result->humanReadableName;
incrementalModule->humanReadableName = "Incremental$" + result->humanReadableName;
incrementalModule->internalTypes.owningModule = incrementalModule.get();
incrementalModule->interfaceTypes.owningModule = incrementalModule.get();
incrementalModule->allocator = std::move(alloc);
// Don't need to keep this alive (it's already on the source module)
copyModuleVec(incrementalModule->scopes, result->scopes);
@ -286,21 +408,6 @@ ModulePtr copyModule(const ModulePtr& result, std::unique_ptr<Allocator> alloc)
return incrementalModule;
}
struct MixedModeIncrementalTCDefFinder : public AstVisitor
{
bool visit(AstExprLocal* local) override
{
referencedLocalDefs.push_back({local->local, local});
return true;
}
// ast defs is just a mapping from expr -> def in general
// will get built up by the dfg builder
// localDefs, we need to copy over
std::vector<std::pair<AstLocal*, AstExpr*>> referencedLocalDefs;
};
void mixedModeCompatibility(
const ScopePtr& bottomScopeStale,
const ScopePtr& myFakeScope,
@ -339,7 +446,9 @@ FragmentTypeCheckResult typecheckFragment_(
{
freeze(stale->internalTypes);
freeze(stale->interfaceTypes);
ModulePtr incrementalModule = copyModule(stale, std::move(astAllocator));
CloneState cloneState{frontend.builtinTypes};
ModulePtr incrementalModule =
FFlag::LuauCloneIncrementalModule ? cloneModule(cloneState, stale, std::move(astAllocator)) : copyModule(stale, std::move(astAllocator));
incrementalModule->checkedInNewSolver = true;
unfreeze(incrementalModule->internalTypes);
unfreeze(incrementalModule->interfaceTypes);
@ -366,7 +475,8 @@ FragmentTypeCheckResult typecheckFragment_(
TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits});
/// Create a DataFlowGraph just for the surrounding context
auto dfg = DataFlowGraphBuilder::build(root, iceHandler);
DataFlowGraph dfg = DataFlowGraphBuilder::build(root, NotNull{&incrementalModule->defArena}, NotNull{&incrementalModule->keyArena}, iceHandler);
SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes);
FrontendModuleResolver& resolver = getModuleResolver(frontend, opts);
@ -386,25 +496,34 @@ FragmentTypeCheckResult typecheckFragment_(
NotNull{&dfg},
{}
};
std::shared_ptr<Scope> freshChildOfNearestScope = nullptr;
if (FFlag::LuauCloneIncrementalModule)
{
freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
cg.rootScope = freshChildOfNearestScope.get();
cg.rootScope = stale->getModuleScope().get();
// Any additions to the scope must occur in a fresh scope
auto freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
// Update freshChildOfNearestScope with the appropriate lvalueTypes
mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root);
// closest Scope -> children = { ...., freshChildOfNearestScope}
// We need to trim nearestChild from the scope hierarcy
closestScope->children.push_back(NotNull{freshChildOfNearestScope.get()});
// Visit just the root - we know the scope it should be in
cg.visitFragmentRoot(freshChildOfNearestScope, root);
// Trim nearestChild from the closestScope
Scope* back = closestScope->children.back().get();
LUAU_ASSERT(back == freshChildOfNearestScope.get());
closestScope->children.pop_back();
cloneAndSquashScopes(
cloneState, closestScope.get(), stale, NotNull{&incrementalModule->internalTypes}, NotNull{&dfg}, root, freshChildOfNearestScope.get()
);
cg.visitFragmentRoot(freshChildOfNearestScope, root);
}
else
{
// Any additions to the scope must occur in a fresh scope
cg.rootScope = stale->getModuleScope().get();
freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root);
// closest Scope -> children = { ...., freshChildOfNearestScope}
// We need to trim nearestChild from the scope hierarcy
closestScope->children.emplace_back(freshChildOfNearestScope.get());
cg.visitFragmentRoot(freshChildOfNearestScope, root);
// Trim nearestChild from the closestScope
Scope* back = closestScope->children.back().get();
LUAU_ASSERT(back == freshChildOfNearestScope.get());
closestScope->children.pop_back();
}
/// Initialize the constraint solver and run it
ConstraintSolver cs{
@ -444,7 +563,7 @@ FragmentTypeCheckResult typecheckFragment_(
}
FragmentTypeCheckResult typecheckFragment(
std::pair<FragmentTypeCheckStatus, FragmentTypeCheckResult> typecheckFragment(
Frontend& frontend,
const ModuleName& moduleName,
const Position& cursorPos,
@ -453,6 +572,13 @@ FragmentTypeCheckResult typecheckFragment(
std::optional<Position> fragmentEndPosition
)
{
if (FFlag::LuauBetterReverseDependencyTracking)
{
if (!frontend.allModuleDependenciesValid(moduleName, opts && opts->forAutocomplete))
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
}
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
{
@ -468,13 +594,30 @@ FragmentTypeCheckResult typecheckFragment(
return {};
}
FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition);
if (FFlag::LuauIncrementalAutocompleteBugfixes && FFlag::LuauReferenceAllocatorInNewSolver)
{
if (sourceModule->allocator.get() != module->allocator.get())
{
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
}
}
auto tryParse = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition);
if (!tryParse)
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
FragmentParseResult& parseResult = *tryParse;
if (isWithinComment(parseResult.commentLocations, fragmentEndPosition.value_or(cursorPos)))
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
FrontendOptions frontendOptions = opts.value_or(frontend.options);
const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement);
FragmentTypeCheckResult result =
typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions);
result.ancestry = std::move(parseResult.ancestry);
return result;
return {FragmentTypeCheckStatus::Success, result};
}
@ -498,7 +641,14 @@ FragmentAutocompleteResult fragmentAutocomplete(
return {};
}
auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src, fragmentEndPosition);
// If the cursor is within a comment in the stale source module we should avoid providing a recommendation
if (isWithinComment(*sourceModule, fragmentEndPosition.value_or(cursorPosition)))
return {};
auto [tcStatus, tcResult] = typecheckFragment(frontend, moduleName, cursorPosition, opts, src, fragmentEndPosition);
if (tcStatus == FragmentTypeCheckStatus::SkipAutocomplete)
return {};
auto globalScope = (opts && opts->forAutocomplete) ? frontend.globalsForAutocomplete.globalScope.get() : frontend.globals.globalScope.get();
TypeArena arenaForFragmentAutocomplete;

View file

@ -13,6 +13,7 @@
#include "Luau/EqSatSimplification.h"
#include "Luau/FileResolver.h"
#include "Luau/NonStrictTypeChecker.h"
#include "Luau/NotNull.h"
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h"
@ -38,7 +39,6 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(LuauStoreCommentsForDefinitionFiles)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile)
@ -47,9 +47,14 @@ LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode)
LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAGVARIABLE(LuauBetterReverseDependencyTracking)
LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule)
LUAU_FASTFLAGVARIABLE(LuauReferenceAllocatorInNewSolver)
LUAU_FASTFLAGVARIABLE(LuauSelectivelyRetainDFGArena)
namespace Luau
{
@ -135,7 +140,7 @@ static ParseResult parseSourceForModule(std::string_view source, Luau::SourceMod
sourceModule.root = parseResult.root;
sourceModule.mode = Mode::Definition;
if (FFlag::LuauStoreCommentsForDefinitionFiles && options.captureComments)
if (options.captureComments)
{
sourceModule.hotcomments = parseResult.hotcomments;
sourceModule.commentLocations = parseResult.commentLocations;
@ -817,6 +822,16 @@ bool Frontend::parseGraph(
topseen = Permanent;
buildQueue.push_back(top->name);
if (FFlag::LuauBetterReverseDependencyTracking)
{
// at this point we know all valid dependencies are processed into SourceNodes
for (const ModuleName& dep : top->requireSet)
{
if (auto it = sourceNodes.find(dep); it != sourceNodes.end())
it->second->dependents.insert(top->name);
}
}
}
else
{
@ -1046,6 +1061,11 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
freeze(module->interfaceTypes);
module->internalTypes.clear();
if (FFlag::LuauSelectivelyRetainDFGArena)
{
module->defArena.allocator.clear();
module->keyArena.allocator.clear();
}
module->astTypes.clear();
module->astTypePacks.clear();
@ -1099,15 +1119,49 @@ void Frontend::recordItemResult(const BuildQueueItem& item)
if (item.exception)
std::rethrow_exception(item.exception);
if (item.options.forAutocomplete)
if (FFlag::LuauBetterReverseDependencyTracking)
{
moduleResolverForAutocomplete.setModule(item.name, item.module);
item.sourceNode->dirtyModuleForAutocomplete = false;
bool replacedModule = false;
if (item.options.forAutocomplete)
{
replacedModule = moduleResolverForAutocomplete.setModule(item.name, item.module);
item.sourceNode->dirtyModuleForAutocomplete = false;
}
else
{
replacedModule = moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false;
}
if (replacedModule)
{
LUAU_TIMETRACE_SCOPE("Frontend::invalidateDependentModules", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", item.name.c_str());
traverseDependents(
item.name,
[forAutocomplete = item.options.forAutocomplete](SourceNode& sourceNode)
{
bool traverseSubtree = !sourceNode.hasInvalidModuleDependency(forAutocomplete);
sourceNode.setInvalidModuleDependency(true, forAutocomplete);
return traverseSubtree;
}
);
}
item.sourceNode->setInvalidModuleDependency(false, item.options.forAutocomplete);
}
else
{
moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false;
if (item.options.forAutocomplete)
{
moduleResolverForAutocomplete.setModule(item.name, item.module);
item.sourceNode->dirtyModuleForAutocomplete = false;
}
else
{
moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false;
}
}
stats.timeCheck += item.stats.timeCheck;
@ -1144,6 +1198,13 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config
return result;
}
bool Frontend::allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete) const
{
LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking);
auto it = sourceNodes.find(name);
return it != sourceNodes.end() && !it->second->hasInvalidModuleDependency(forAutocomplete);
}
bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
{
auto it = sourceNodes.find(name);
@ -1158,16 +1219,80 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
*/
void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty)
{
LUAU_TIMETRACE_SCOPE("Frontend::markDirty", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
if (FFlag::LuauBetterReverseDependencyTracking)
{
traverseDependents(
name,
[markedDirty](SourceNode& sourceNode)
{
if (markedDirty)
markedDirty->push_back(sourceNode.name);
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
return false;
sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
return true;
}
);
}
else
{
if (sourceNodes.count(name) == 0)
return;
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes)
{
for (const auto& dep : module.second->requireSet)
reverseDeps[dep].push_back(module.first);
}
std::vector<ModuleName> queue{name};
while (!queue.empty())
{
ModuleName next = std::move(queue.back());
queue.pop_back();
LUAU_ASSERT(sourceNodes.count(next) > 0);
SourceNode& sourceNode = *sourceNodes[next];
if (markedDirty)
markedDirty->push_back(next);
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
continue;
sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
if (0 == reverseDeps.count(next))
continue;
sourceModules.erase(next);
const std::vector<ModuleName>& dependents = reverseDeps[next];
queue.insert(queue.end(), dependents.begin(), dependents.end());
}
}
}
void Frontend::traverseDependents(const ModuleName& name, std::function<bool(SourceNode&)> processSubtree)
{
LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking);
LUAU_TIMETRACE_SCOPE("Frontend::traverseDependents", "Frontend");
if (sourceNodes.count(name) == 0)
return;
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes)
{
for (const auto& dep : module.second->requireSet)
reverseDeps[dep].push_back(module.first);
}
std::vector<ModuleName> queue{name};
while (!queue.empty())
@ -1178,22 +1303,10 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
LUAU_ASSERT(sourceNodes.count(next) > 0);
SourceNode& sourceNode = *sourceNodes[next];
if (markedDirty)
markedDirty->push_back(next);
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
if (!processSubtree(sourceNode))
continue;
sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
if (0 == reverseDeps.count(next))
continue;
sourceModules.erase(next);
const std::vector<ModuleName>& dependents = reverseDeps[next];
const Set<ModuleName>& dependents = sourceNode.dependents;
queue.insert(queue.end(), dependents.begin(), dependents.end());
}
}
@ -1317,6 +1430,11 @@ ModulePtr check(
result->mode = mode;
result->internalTypes.owningModule = result.get();
result->interfaceTypes.owningModule = result.get();
if (FFlag::LuauReferenceAllocatorInNewSolver)
{
result->allocator = sourceModule.allocator;
result->names = sourceModule.names;
}
iceHandler->moduleName = sourceModule.name;
@ -1331,7 +1449,7 @@ ModulePtr check(
}
}
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler);
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&result->defArena}, NotNull{&result->keyArena}, iceHandler);
UnifierSharedState unifierState{iceHandler};
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
@ -1427,6 +1545,7 @@ ModulePtr check(
case Mode::Nonstrict:
Luau::checkNonStrict(
builtinTypes,
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
iceHandler,
NotNull{&unifierState},
@ -1440,7 +1559,14 @@ ModulePtr check(
// fallthrough intentional
case Mode::Strict:
Luau::check(
builtinTypes, NotNull{&typeFunctionRuntime}, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get()
builtinTypes,
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{&unifierState},
NotNull{&limits},
logger.get(),
sourceModule,
result.get()
);
break;
case Mode::NoCheck:
@ -1622,6 +1748,17 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(const ModuleName&
sourceNode->name = sourceModule->name;
sourceNode->humanReadableName = sourceModule->humanReadableName;
if (FFlag::LuauBetterReverseDependencyTracking)
{
// clear all prior dependents. we will re-add them after parsing the rest of the graph
for (const auto& [moduleName, _] : sourceNode->requireLocations)
{
if (auto depIt = sourceNodes.find(moduleName); depIt != sourceNodes.end())
depIt->second->dependents.erase(sourceNode->name);
}
}
sourceNode->requireSet.clear();
sourceNode->requireLocations.clear();
sourceNode->dirtySourceModule = false;
@ -1743,11 +1880,21 @@ std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName&
return frontend->fileResolver->getHumanReadableModuleName(moduleName);
}
void FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module)
bool FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module)
{
std::scoped_lock lock(moduleMutex);
modules[moduleName] = std::move(module);
if (FFlag::LuauBetterReverseDependencyTracking)
{
bool replaced = modules.count(moduleName) > 0;
modules[moduleName] = std::move(module);
return replaced;
}
else
{
modules[moduleName] = std::move(module);
return false;
}
}
void FrontendModuleResolver::clearModules()

View file

@ -10,12 +10,14 @@
#include "Luau/VisitType.h"
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound)
namespace Luau
{
struct MutatingGeneralizer : TypeOnceVisitor
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
NotNull<Scope> scope;
@ -29,6 +31,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool avoidSealingTables = false;
MutatingGeneralizer(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes,
@ -37,6 +40,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool avoidSealingTables
)
: TypeOnceVisitor(/* skipBoundTypes */ true)
, arena(arena)
, builtinTypes(builtinTypes)
, scope(scope)
, cachedTypes(cachedTypes)
@ -229,6 +233,53 @@ struct MutatingGeneralizer : TypeOnceVisitor
else
{
TypeId ub = follow(ft->upperBound);
if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound)
{
// If the upper bound is a union type or an intersection type,
// and one of it's members is the free type we're
// generalizing, don't include it in the upper bound. For a
// free type such as:
//
// t1 where t1 = D <: 'a <: (A | B | C | t1)
//
// Naively replacing it with it's upper bound creates:
//
// t1 where t1 = A | B | C | t1
//
// It makes sense to just optimize this and exclude the
// recursive component by semantic subtyping rules.
if (auto itv = get<IntersectionType>(ub))
{
std::vector<TypeId> newIds;
newIds.reserve(itv->parts.size());
for (auto part : itv)
{
if (part != ty)
newIds.push_back(part);
}
if (newIds.size() == 1)
ub = newIds[0];
else if (newIds.size() > 0)
ub = arena->addType(IntersectionType{std::move(newIds)});
}
else if (auto utv = get<UnionType>(ub))
{
std::vector<TypeId> newIds;
newIds.reserve(utv->options.size());
for (auto part : utv)
{
if (part != ty)
newIds.push_back(part);
}
if (newIds.size() == 1)
ub = newIds[0];
else if (newIds.size() > 0)
ub = arena->addType(UnionType{std::move(newIds)});
}
}
if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty)
upperFree->lowerBound = builtinTypes->neverType;
else
@ -926,7 +977,8 @@ struct TypeCacher : TypeOnceVisitor
return false;
}
bool visit(TypePackId tp, const BoundTypePack& btp) override {
bool visit(TypePackId tp, const BoundTypePack& btp) override
{
traverse(btp.boundTo);
if (isUncacheable(btp.boundTo))
markUncacheable(tp);
@ -969,7 +1021,7 @@ std::optional<TypeId> generalize(
FreeTypeSearcher fts{scope, cachedTypes};
fts.traverse(ty);
MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables};
MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables};
gen.traverse(ty);

View file

@ -11,6 +11,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -61,9 +62,7 @@ TypeId Instantiation::clean(TypeId ty)
LUAU_ASSERT(ftv);
FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.magicFunction = ftv->magicFunction;
clone.dcrMagicFunction = ftv->dcrMagicFunction;
clone.dcrMagicRefinement = ftv->dcrMagicRefinement;
clone.magic = ftv->magic;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
TypeId result = addType(std::move(clone));
@ -165,7 +164,7 @@ TypeId ReplaceGenerics::clean(TypeId ty)
}
else
{
return addType(FreeType{scope, level});
return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, scope, level) : addType(FreeType{scope, level});
}
}

View file

@ -15,11 +15,12 @@
#include <algorithm>
LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteCommentDetection)
namespace Luau
{
static bool contains(Position pos, Comment comment)
static bool contains_DEPRECATED(Position pos, Comment comment)
{
if (comment.location.contains(pos))
return true;
@ -32,7 +33,22 @@ static bool contains(Position pos, Comment comment)
return false;
}
static bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos)
static bool contains(Position pos, Comment comment)
{
if (comment.location.contains(pos))
return true;
else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't
// have an end
return true;
// comments actually span the whole line - in incremental mode, we could pass a cursor outside of the current parsed comment range span, but it
// would still be 'within' the comment So, the cursor must be on the same line and the comment itself must come strictly after the `begin`
else if (comment.type == Lexeme::Comment && comment.location.end.line == pos.line && comment.location.begin <= pos)
return true;
else
return false;
}
bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos)
{
auto iter = std::lower_bound(
commentLocations.begin(),
@ -40,6 +56,11 @@ static bool isWithinComment(const std::vector<Comment>& commentLocations, Positi
Comment{Lexeme::Comment, Location{pos, pos}},
[](const Comment& a, const Comment& b)
{
if (FFlag::LuauIncrementalAutocompleteCommentDetection)
{
if (a.type == Lexeme::Comment)
return a.location.end.line < b.location.end.line;
}
return a.location.end < b.location.end;
}
);
@ -47,7 +68,7 @@ static bool isWithinComment(const std::vector<Comment>& commentLocations, Positi
if (iter == commentLocations.end())
return false;
if (contains(pos, *iter))
if (FFlag::LuauIncrementalAutocompleteCommentDetection ? contains(pos, *iter) : contains_DEPRECATED(pos, *iter))
return true;
// Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends
@ -149,9 +170,9 @@ struct ClonePublicInterface : Substitution
freety->scope->location,
module->name,
InternalError{"Free type is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
);
result = builtinTypes->errorRecoveryType();
"https://github.com/luau-lang/luau/issues"}
);
result = builtinTypes->errorRecoveryType();
}
else if (auto genericty = getMutable<GenericType>(result))
{
@ -173,8 +194,8 @@ struct ClonePublicInterface : Substitution
ftp->scope->location,
module->name,
InternalError{"Free type pack is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
);
"https://github.com/luau-lang/luau/issues"}
);
clonedTp = builtinTypes->errorRecoveryTypePack();
}
else if (auto gtp = getMutable<GenericTypePack>(clonedTp))

View file

@ -19,8 +19,8 @@
#include <iostream>
#include <iterator>
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunNonstrict)
LUAU_FASTFLAGVARIABLE(LuauCountSelfCallsNonstrict)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -158,6 +158,7 @@ private:
struct NonStrictTypeChecker
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
const NotNull<InternalErrorReporter> ice;
NotNull<TypeArena> arena;
@ -174,6 +175,7 @@ struct NonStrictTypeChecker
NonStrictTypeChecker(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
const NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
@ -182,12 +184,13 @@ struct NonStrictTypeChecker
Module* module
)
: builtinTypes(builtinTypes)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime)
, ice(ice)
, arena(arena)
, module(module)
, normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true}
, subtyping{builtinTypes, arena, NotNull(&normalizer), typeFunctionRuntime, ice}
, subtyping{builtinTypes, arena, simplifier, NotNull(&normalizer), typeFunctionRuntime, ice}
, dfg(dfg)
, limits(limits)
{
@ -209,7 +212,7 @@ struct NonStrictTypeChecker
return *fst;
else if (auto ftp = get<FreeTypePack>(pack))
{
TypeId result = arena->addType(FreeType{ftp->scope});
TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, ftp->scope) : arena->addType(FreeType{ftp->scope});
TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope});
TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack));
@ -232,13 +235,14 @@ struct NonStrictTypeChecker
if (noTypeFunctionErrors.find(instance))
return instance;
ErrorVec errors = reduceTypeFunctions(
instance,
location,
TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits},
true
)
.errors;
ErrorVec errors =
reduceTypeFunctions(
instance,
location,
TypeFunctionContext{arena, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits},
true
)
.errors;
if (errors.empty())
noTypeFunctionErrors.insert(instance);
@ -424,9 +428,6 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatTypeFunction* typeFunc)
{
if (!FFlag::LuauUserTypeFunNonstrict)
reportError(GenericError{"This syntax is not supported"}, typeFunc->location);
return {};
}
@ -888,6 +889,7 @@ private:
void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
@ -899,7 +901,9 @@ void checkNonStrict(
{
LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking");
NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, typeFunctionRuntime, ice, unifierState, dfg, limits, module};
NonStrictTypeChecker typeChecker{
NotNull{&module->internalTypes}, builtinTypes, simplifier, typeFunctionRuntime, ice, unifierState, dfg, limits, module
};
typeChecker.visit(sourceModule.root);
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes, builtinTypes);

View file

@ -17,12 +17,11 @@
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant)
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000)
LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200)
LUAU_FASTFLAGVARIABLE(LuauNormalizationTracksCyclicPairsThroughInhabitance);
LUAU_FASTFLAGVARIABLE(LuauIntersectNormalsNeedsToTrackResourceLimits);
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauFixInfiniteRecursionInNormalization)
LUAU_FASTFLAGVARIABLE(LuauFixNormalizedIntersectionOfNegatedClass)
namespace Luau
{
@ -1809,7 +1808,8 @@ NormalizationResult Normalizer::unionNormalWithTy(
}
else if (get<UnknownType>(here.tops))
return NormalizationResult::True;
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there))
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFunctionInstanceType>(there))
{
if (tyvarIndex(there) <= ignoreSmallerTyvars)
return NormalizationResult::True;
@ -2284,9 +2284,24 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th
else if (isSubclass(there, hereTy))
{
TypeIds negations = std::move(hereNegations);
bool emptyIntersectWithNegation = false;
for (auto nIt = negations.begin(); nIt != negations.end();)
{
if (FFlag::LuauFixNormalizedIntersectionOfNegatedClass && isSubclass(there, *nIt))
{
// Hitting this block means that the incoming class is a
// subclass of this type, _and_ one of its negations is a
// superclass of this type, e.g.:
//
// Dog & ~Animal
//
// Clearly this intersects to never, so we mark this class as
// being removed from the normalized class type.
emptyIntersectWithNegation = true;
break;
}
if (!isSubclass(*nIt, there))
{
nIt = negations.erase(nIt);
@ -2299,7 +2314,8 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th
it = heres.ordering.erase(it);
heres.classes.erase(hereTy);
heres.pushPair(there, std::move(negations));
if (!emptyIntersectWithNegation)
heres.pushPair(there, std::move(negations));
break;
}
// If the incoming class is a superclass of the current class, we don't
@ -2584,11 +2600,31 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
{
if (tprop.readTy.has_value())
{
// if the intersection of the read types of a property is uninhabited, the whole table is `never`.
// We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited
if (FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance)
if (FFlag::LuauFixInfiniteRecursionInNormalization)
{
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
// If any property is going to get mapped to `never`, we can just call the entire table `never`.
// Since this check is syntactic, we may sometimes miss simplifying tables with complex uninhabited properties.
// Prior versions of this code attempted to do this semantically using the normalization machinery, but this
// mistakenly causes infinite loops when giving more complex recursive table types. As it stands, this approach
// will continue to scale as simplification is improved, but we may wish to reintroduce the semantic approach
// once we have revisited the usage of seen sets systematically (and possibly with some additional guarding to recognize
// when types are infinitely-recursive with non-pointer identical instances of them, or some guard to prevent that
// construction altogether). See also: `gh1632_no_infinite_recursion_in_normalization`
if (get<NeverType>(ty))
return {builtinTypes->neverType};
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
}
else
{
// if the intersection of the read types of a property is uninhabited, the whole table is `never`.
// We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited
auto pair1 = std::pair{*hprop.readTy, *tprop.readTy};
auto pair2 = std::pair{*tprop.readTy, *hprop.readTy};
if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2))
@ -2603,6 +2639,8 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
seenTablePropPairs.insert(pair2);
}
// FIXME(ariel): this is being added in a flag removal, so not changing the semantics here, but worth noting that this
// fresh `seenSet` is definitely a bug. we already have `seenSet` from the parameter that _should_ have been used here.
Set<TypeId> seenSet{nullptr};
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet);
@ -2616,34 +2654,6 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
}
else
{
if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy))
{
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
return {builtinTypes->neverType};
}
else
{
seenSet.insert(*hprop.readTy);
seenSet.insert(*tprop.readTy);
}
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy);
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
if (NormalizationResult::True != res)
return {builtinTypes->neverType};
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
}
}
else
{
@ -3042,12 +3052,9 @@ NormalizationResult Normalizer::intersectTyvarsWithTy(
// See above for an explaination of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars)
{
if (FFlag::LuauIntersectNormalsNeedsToTrackResourceLimits)
{
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
return NormalizationResult::HitLimits;
}
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
return NormalizationResult::HitLimits;
if (!get<NeverType>(there.tops))
{
@ -3162,7 +3169,8 @@ NormalizationResult Normalizer::intersectNormalWithTy(
}
return NormalizationResult::True;
}
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there))
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFunctionInstanceType>(there))
{
NormalizedType thereNorm{builtinTypes};
NormalizedType topNorm{builtinTypes};
@ -3465,7 +3473,14 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
return arena->addType(UnionType{std::move(result)});
}
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
bool isSubtype(
TypeId subTy,
TypeId superTy,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
)
{
UnifierSharedState sharedState{&ice};
TypeArena arena;
@ -3478,7 +3493,7 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<Built
// Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2)
{
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
}
@ -3491,7 +3506,14 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<Built
}
}
bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
bool isSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
)
{
UnifierSharedState sharedState{&ice};
TypeArena arena;
@ -3504,7 +3526,7 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, N
// Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2)
{
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subPack, superPack, scope).isSubtype;
}

View file

@ -16,6 +16,7 @@ namespace Luau
OverloadResolver::OverloadResolver(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope,
@ -25,12 +26,13 @@ OverloadResolver::OverloadResolver(
)
: builtinTypes(builtinTypes)
, arena(arena)
, simplifier(simplifier)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, scope(scope)
, ice(reporter)
, limits(limits)
, subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, ice})
, subtyping({builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, ice})
, callLoc(callLocation)
{
}
@ -202,7 +204,7 @@ std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload_
)
{
FunctionGraphReductionResult result = reduceTypeFunctions(
fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true
fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true
);
if (!result.errors.empty())
return {OverloadIsNonviable, result.errors};
@ -404,9 +406,10 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors)
// we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`.
// this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed.
std::optional<TypeId> selectOverload(
static std::optional<TypeId> selectOverload(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope,
@ -417,7 +420,8 @@ std::optional<TypeId> selectOverload(
TypePackId argsPack
)
{
auto resolver = std::make_unique<OverloadResolver>(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location);
auto resolver =
std::make_unique<OverloadResolver>(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location);
auto [status, overload] = resolver->selectOverload(fn, argsPack);
if (status == OverloadResolver::Analysis::Ok)
@ -432,6 +436,7 @@ std::optional<TypeId> selectOverload(
SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter,
@ -443,7 +448,7 @@ SolveResult solveFunctionCall(
)
{
std::optional<TypeId> overloadToUse =
selectOverload(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack);
selectOverload(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack);
if (!overloadToUse)
return {SolveResult::NoMatchingOverload};

View file

@ -211,6 +211,16 @@ void Scope::inheritRefinements(const ScopePtr& childScope)
}
}
bool Scope::shouldWarnGlobal(std::string name) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
if (current->globalsToWarn.contains(name))
return true;
}
return false;
}
bool subsumesStrict(Scope* left, Scope* right)
{
while (right)

View file

@ -31,16 +31,16 @@ struct TypeSimplifier
int recursionDepth = 0;
TypeId mkNegation(TypeId ty);
TypeId mkNegation(TypeId ty) const;
TypeId intersectFromParts(std::set<TypeId> parts);
TypeId intersectUnionWithType(TypeId unionTy, TypeId right);
TypeId intersectUnionWithType(TypeId left, TypeId right);
TypeId intersectUnions(TypeId left, TypeId right);
TypeId intersectNegatedUnion(TypeId unionTy, TypeId right);
TypeId intersectNegatedUnion(TypeId left, TypeId right);
TypeId intersectTypeWithNegation(TypeId a, TypeId b);
TypeId intersectNegations(TypeId a, TypeId b);
TypeId intersectTypeWithNegation(TypeId left, TypeId right);
TypeId intersectNegations(TypeId left, TypeId right);
TypeId intersectIntersectionWithType(TypeId left, TypeId right);
@ -48,8 +48,8 @@ struct TypeSimplifier
// unions, intersections, or negations.
std::optional<TypeId> basicIntersect(TypeId left, TypeId right);
TypeId intersect(TypeId ty, TypeId discriminant);
TypeId union_(TypeId ty, TypeId discriminant);
TypeId intersect(TypeId left, TypeId right);
TypeId union_(TypeId left, TypeId right);
TypeId simplify(TypeId ty);
TypeId simplify(TypeId ty, DenseHashSet<TypeId>& seen);
@ -573,7 +573,7 @@ Relation relate(TypeId left, TypeId right)
return relate(left, right, seen);
}
TypeId TypeSimplifier::mkNegation(TypeId ty)
TypeId TypeSimplifier::mkNegation(TypeId ty) const
{
TypeId result = nullptr;

View file

@ -98,9 +98,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf};
clone.generics = a.generics;
clone.genericPacks = a.genericPacks;
clone.magicFunction = a.magicFunction;
clone.dcrMagicFunction = a.dcrMagicFunction;
clone.dcrMagicRefinement = a.dcrMagicRefinement;
clone.magic = a.magic;
clone.tags = a.tags;
clone.argNames = a.argNames;
clone.isCheckedFunction = a.isCheckedFunction;

View file

@ -22,7 +22,6 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity)
LUAU_FASTFLAGVARIABLE(LuauRetrySubtypingWithoutHiddenPack)
namespace Luau
{
@ -396,12 +395,14 @@ TypePackId* SubtypingEnvironment::getMappedPackBounds(TypePackId tp)
Subtyping::Subtyping(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter
)
: builtinTypes(builtinTypes)
, arena(typeArena)
, simplifier(simplifier)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, iceReporter(iceReporter)
@ -1472,15 +1473,14 @@ SubtypingResult Subtyping::isCovariantWith(
// If subtyping failed in the argument packs, we should check if there's a hidden variadic tail and try ignoring it.
// This might cause subtyping correctly because the sub type here may not have a hidden variadic tail or equivalent.
if (FFlag::LuauRetrySubtypingWithoutHiddenPack && !result.isSubtype)
if (!result.isSubtype)
{
auto [arguments, tail] = flatten(superFunction->argTypes);
if (auto variadic = get<VariadicTypePack>(tail); variadic && variadic->hidden)
{
result.orElse(
isContravariantWith(env, subFunction->argTypes, arena->addTypePack(TypePack{arguments}), scope).withBothComponent(TypePath::PackField::Arguments)
);
result.orElse(isContravariantWith(env, subFunction->argTypes, arena->addTypePack(TypePack{arguments}), scope)
.withBothComponent(TypePath::PackField::Arguments));
}
}
}
@ -1861,7 +1861,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse)
std::pair<TypeId, ErrorVec> Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull<Scope> scope)
{
TypeFunctionContext context{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}};
TypeFunctionContext context{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}};
TypeId function = arena->addType(*functionInstance);
FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true);
ErrorVec errors;

View file

@ -9,6 +9,8 @@
#include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h"
LUAU_FASTFLAGVARIABLE(LuauDontInPlaceMutateTableType)
namespace Luau
{
@ -236,6 +238,8 @@ TypeId matchLiteralType(
return exprType;
}
DenseHashSet<AstExprConstantString*> keysToDelete{nullptr};
for (const AstExprTable::Item& item : exprTable->items)
{
if (isRecord(item))
@ -280,7 +284,10 @@ TypeId matchLiteralType(
else
tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType};
tableTy->props.erase(keyStr);
if (FFlag::LuauDontInPlaceMutateTableType)
keysToDelete.insert(item.key->as<AstExprConstantString>());
else
tableTy->props.erase(keyStr);
}
// If it's just an extra property and the expected type
@ -387,6 +394,16 @@ TypeId matchLiteralType(
LUAU_ASSERT(!"Unexpected");
}
if (FFlag::LuauDontInPlaceMutateTableType)
{
for (const auto& key: keysToDelete)
{
const AstArray<char>& s = key->value;
std::string keyStr{s.data, s.data + s.size};
tableTy->props.erase(keyStr);
}
}
// Keys that the expectedType says we should have, but that aren't
// specified by the AST fragment.
//

File diff suppressed because it is too large Load diff

View file

@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -478,24 +479,12 @@ bool hasLength(TypeId ty, DenseHashSet<TypeId>& seen, int* recursionCount)
return false;
}
FreeType::FreeType(TypeLevel level)
// New constructors
FreeType::FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound)
: index(Unifiable::freshIndex())
, level(level)
, scope(nullptr)
{
}
FreeType::FreeType(Scope* scope)
: index(Unifiable::freshIndex())
, level{}
, scope(scope)
{
}
FreeType::FreeType(Scope* scope, TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
, lowerBound(lowerBound)
, upperBound(upperBound)
{
}
@ -507,6 +496,40 @@ FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound)
{
}
FreeType::FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
, lowerBound(lowerBound)
, upperBound(upperBound)
{
}
// Old constructors
FreeType::FreeType(TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(nullptr)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
FreeType::FreeType(Scope* scope)
: index(Unifiable::freshIndex())
, level{}
, scope(scope)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
FreeType::FreeType(Scope* scope, TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
GenericType::GenericType()
: index(Unifiable::freshIndex())
, name("g" + std::to_string(index))
@ -554,12 +577,12 @@ BlockedType::BlockedType()
{
}
Constraint* BlockedType::getOwner() const
const Constraint* BlockedType::getOwner() const
{
return owner;
}
void BlockedType::setOwner(Constraint* newOwner)
void BlockedType::setOwner(const Constraint* newOwner)
{
LUAU_ASSERT(owner == nullptr);
@ -569,7 +592,7 @@ void BlockedType::setOwner(Constraint* newOwner)
owner = newOwner;
}
void BlockedType::replaceOwner(Constraint* newOwner)
void BlockedType::replaceOwner(const Constraint* newOwner)
{
owner = newOwner;
}

View file

@ -3,6 +3,7 @@
#include "Luau/TypeArena.h"
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena);
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -22,7 +23,34 @@ TypeId TypeArena::addTV(Type&& tv)
return allocated;
}
TypeId TypeArena::freshType(TypeLevel level)
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{level, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, Scope* scope)
{
TypeId allocated = types.allocate(FreeType{scope, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, Scope* scope, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{scope, level, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType_DEPRECATED(TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{level});
@ -31,7 +59,7 @@ TypeId TypeArena::freshType(TypeLevel level)
return allocated;
}
TypeId TypeArena::freshType(Scope* scope)
TypeId TypeArena::freshType_DEPRECATED(Scope* scope)
{
TypeId allocated = types.allocate(FreeType{scope});
@ -40,7 +68,7 @@ TypeId TypeArena::freshType(Scope* scope)
return allocated;
}
TypeId TypeArena::freshType(Scope* scope, TypeLevel level)
TypeId TypeArena::freshType_DEPRECATED(Scope* scope, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{scope, level});

View file

@ -386,8 +386,12 @@ public:
}
AstType* operator()(const NegationType& ntv)
{
// FIXME: do the same thing we do with ErrorType
throw InternalCompilerError("Cannot convert NegationType into AstNode");
AstArray<AstTypeOrPack> params;
params.size = 1;
params.data = static_cast<AstTypeOrPack*>(allocator->allocate(sizeof(AstType*)));
params.data[0] = AstTypeOrPack{Luau::visit(*this, ntv.ty->ty), nullptr};
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("negate"), std::nullopt, Location(), true, params);
}
AstType* operator()(const TypeFunctionInstanceType& tfit)
{

View file

@ -7,7 +7,6 @@
#include "Luau/DcrLogger.h"
#include "Luau/DenseHash.h"
#include "Luau/Error.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Instantiation.h"
#include "Luau/Metamethods.h"
#include "Luau/Normalize.h"
@ -27,12 +26,12 @@
#include "Luau/VisitType.h"
#include <algorithm>
#include <iostream>
#include <ostream>
LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(InferGlobalTypes)
LUAU_FASTFLAGVARIABLE(LuauTableKeysAreRValues)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -175,7 +174,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor
DenseHashSet<TypeId> mentionedFunctions{nullptr};
DenseHashSet<TypePackId> mentionedFunctionPacks{nullptr};
InternalTypeFunctionFinder(std::vector<TypeId>& declStack)
explicit InternalTypeFunctionFinder(std::vector<TypeId>& declStack)
{
TypeFunctionFinder f;
for (TypeId fn : declStack)
@ -268,6 +267,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor
void check(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
@ -278,7 +278,7 @@ void check(
{
LUAU_TIMETRACE_SCOPE("check", "Typechecking");
TypeChecker2 typeChecker{builtinTypes, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module};
TypeChecker2 typeChecker{builtinTypes, simplifier, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module};
typeChecker.visit(sourceModule.root);
@ -295,6 +295,7 @@ void check(
TypeChecker2::TypeChecker2(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
@ -303,6 +304,7 @@ TypeChecker2::TypeChecker2(
Module* module
)
: builtinTypes(builtinTypes)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime)
, logger(logger)
, limits(limits)
@ -310,7 +312,7 @@ TypeChecker2::TypeChecker2(
, sourceModule(sourceModule)
, module(module)
, normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true}
, _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}}
, _subtyping{builtinTypes, NotNull{&module->internalTypes}, simplifier, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}}
, subtyping(&_subtyping)
{
}
@ -492,7 +494,9 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l
reduceTypeFunctions(
instance,
location,
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits},
TypeFunctionContext{
NotNull{&module->internalTypes}, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits
},
true
)
.errors;
@ -501,7 +505,7 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l
return instance;
}
TypePackId TypeChecker2::lookupPack(AstExpr* expr)
TypePackId TypeChecker2::lookupPack(AstExpr* expr) const
{
// If a type isn't in the type graph, it probably means that a recursion limit was exceeded.
// We'll just return anyType in these cases. Typechecking against any is very fast and this
@ -551,7 +555,7 @@ TypeId TypeChecker2::lookupAnnotation(AstType* annotation)
return checkForTypeFunctionInhabitance(follow(*ty), annotation->location);
}
std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annotation)
std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) const
{
TypePackId* tp = module->astResolvedTypePacks.find(annotation);
if (tp != nullptr)
@ -559,7 +563,7 @@ std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annota
return {};
}
TypeId TypeChecker2::lookupExpectedType(AstExpr* expr)
TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) const
{
if (TypeId* ty = module->astExpectedTypes.find(expr))
return follow(*ty);
@ -567,7 +571,7 @@ TypeId TypeChecker2::lookupExpectedType(AstExpr* expr)
return builtinTypes->anyType;
}
TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena)
TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) const
{
if (TypeId* ty = module->astExpectedTypes.find(expr))
return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt});
@ -591,7 +595,7 @@ TypePackId TypeChecker2::reconstructPack(AstArray<AstExpr*> exprs, TypeArena& ar
return arena.addTypePack(TypePack{head, tail});
}
Scope* TypeChecker2::findInnermostScope(Location location)
Scope* TypeChecker2::findInnermostScope(Location location) const
{
Scope* bestScope = module->getModuleScope().get();
@ -1014,7 +1018,8 @@ void TypeChecker2::visit(AstStatForIn* forInStatement)
{
reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location);
}
else if (std::optional<TypeId> iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location))
else if (std::optional<TypeId> iterMmTy =
findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location))
{
Instantiation instantiation{TxnLog::empty(), &arena, builtinTypes, TypeLevel{}, scope};
@ -1349,7 +1354,17 @@ void TypeChecker2::visit(AstExprGlobal* expr)
{
NotNull<Scope> scope = stack.back();
if (!scope->lookup(expr->name))
{
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
}
else if (FFlag::InferGlobalTypes)
{
if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value))
{
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
warnedGlobals.insert(expr->name.value);
}
}
}
void TypeChecker2::visit(AstExprVarargs* expr)
@ -1437,10 +1452,11 @@ void TypeChecker2::visitCall(AstExprCall* call)
TypePackId argsTp = module->internalTypes.addTypePack(args);
if (auto ftv = get<FunctionType>(follow(*originalCallTy)))
{
if (ftv->dcrMagicTypeCheck)
if (ftv->magic)
{
ftv->dcrMagicTypeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope});
return;
bool usedMagic = ftv->magic->typeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope});
if (usedMagic)
return;
}
}
@ -1448,6 +1464,7 @@ void TypeChecker2::visitCall(AstExprCall* call)
OverloadResolver resolver{
builtinTypes,
NotNull{&module->internalTypes},
simplifier,
NotNull{&normalizer},
typeFunctionRuntime,
NotNull{stack.back()},
@ -1545,7 +1562,7 @@ void TypeChecker2::visit(AstExprCall* call)
visitCall(call);
}
std::optional<TypeId> TypeChecker2::tryStripUnionFromNil(TypeId ty)
std::optional<TypeId> TypeChecker2::tryStripUnionFromNil(TypeId ty) const
{
if (const UnionType* utv = get<UnionType>(ty))
{
@ -2089,7 +2106,10 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey)
}
else
{
expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})});
expectedRets = module->internalTypes.addTypePack(
{FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, scope, TypeLevel{})
: module->internalTypes.freshType_DEPRECATED(scope, TypeLevel{})}
);
}
TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets));
@ -2341,7 +2361,8 @@ TypeId TypeChecker2::flattenPack(TypePackId pack)
return *fst;
else if (auto ftp = get<FreeTypePack>(pack))
{
TypeId result = module->internalTypes.addType(FreeType{ftp->scope});
TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, ftp->scope)
: module->internalTypes.addType(FreeType{ftp->scope});
TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope});
TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack));
@ -2403,6 +2424,8 @@ void TypeChecker2::visit(AstType* ty)
return visit(t);
else if (auto t = ty->as<AstTypeIntersection>())
return visit(t);
else if (auto t = ty->as<AstTypeGroup>())
return visit(t->type);
}
void TypeChecker2::visit(AstTypeReference* ty)
@ -3024,7 +3047,7 @@ PropertyType TypeChecker2::hasIndexTypeFromType(
{
TypeId indexType = follow(tt->indexer->indexType);
TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}});
if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, *ice))
if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, simplifier, *ice))
return {NormalizationResult::True, {tt->indexer->indexResultType}};
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -20,8 +20,7 @@
// currently, controls serialization, deserialization, and `type.copy`
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000);
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixMetatable)
LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer)
LUAU_FASTFLAG(LuauUserTypeFunGenerics)
namespace Luau
{
@ -161,26 +160,10 @@ private:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String));
break;
case PrimitiveType::Thread:
if (FFlag::LuauUserTypeFunThreadBuffer)
{
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread));
}
else
{
std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread));
break;
case PrimitiveType::Buffer:
if (FFlag::LuauUserTypeFunThreadBuffer)
{
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer));
}
else
{
std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer));
break;
case PrimitiveType::Function:
case PrimitiveType::Table:
@ -222,13 +205,22 @@ private:
else if (auto f = get<FunctionType>(ty))
{
TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{});
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack});
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{{}, {}, emptyTypePack, emptyTypePack});
}
else if (auto c = get<ClassType>(ty))
{
state->classesSerialized[c->name] = ty;
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, c->name});
}
else if (auto g = get<GenericType>(ty); FFlag::LuauUserTypeFunGenerics && g)
{
Name name = g->name;
if (!g->explicitName)
name = format("g%d", g->index);
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionGenericType{g->explicitName, false, name});
}
else
{
std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str());
@ -253,6 +245,15 @@ private:
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}});
else if (auto vPack = get<VariadicTypePack>(tp))
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{});
else if (auto gPack = get<GenericTypePack>(tp); FFlag::LuauUserTypeFunGenerics && gPack)
{
Name name = gPack->name;
if (!gPack->explicitName)
name = format("g%d", gPack->index);
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionGenericTypePack{gPack->explicitName, name});
}
else
{
std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str());
@ -290,6 +291,9 @@ private:
serializeChildren(f1, f2);
else if (auto [c1, c2] = std::tuple{get<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2)
serializeChildren(c1, c2);
else if (auto [g1, g2] = std::tuple{get<GenericType>(ty), getMutable<TypeFunctionGenericType>(tfti)};
FFlag::LuauUserTypeFunGenerics && g1 && g2)
serializeChildren(g1, g2);
else
{ // Either this or ty and tfti do not represent the same type
std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str());
@ -303,6 +307,9 @@ private:
serializeChildren(tPack1, tPack2);
else if (auto [vPack1, vPack2] = std::tuple{get<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)}; vPack1 && vPack2)
serializeChildren(vPack1, vPack2);
else if (auto [gPack1, gPack2] = std::tuple{get<GenericTypePack>(tp), getMutable<TypeFunctionGenericTypePack>(tftp)};
FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2)
serializeChildren(gPack1, gPack2);
else
{ // Either this or ty and tfti do not represent the same type
std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str());
@ -383,27 +390,26 @@ private:
void serializeChildren(const MetatableType* m1, TypeFunctionTableType* m2)
{
if (FFlag::LuauUserTypeFunFixMetatable)
{
// Serialize main part of the metatable immediately
if (auto tableTy = get<TableType>(m1->table))
serializeChildren(tableTy, m2);
}
else
{
auto tmpTable = get<TypeFunctionTableType>(shallowSerialize(m1->table));
if (!tmpTable)
state->ctx->ice->ice("Serializing user defined type function arguments: metatable's table is not a TableType");
m2->props = tmpTable->props;
m2->indexer = tmpTable->indexer;
}
// Serialize main part of the metatable immediately
if (auto tableTy = get<TableType>(m1->table))
serializeChildren(tableTy, m2);
m2->metatable = shallowSerialize(m1->metatable);
}
void serializeChildren(const FunctionType* f1, TypeFunctionFunctionType* f2)
{
if (FFlag::LuauUserTypeFunGenerics)
{
f2->generics.reserve(f1->generics.size());
for (auto ty : f1->generics)
f2->generics.push_back(shallowSerialize(ty));
f2->genericPacks.reserve(f1->genericPacks.size());
for (auto tp : f1->genericPacks)
f2->genericPacks.push_back(shallowSerialize(tp));
}
f2->argTypes = shallowSerialize(f1->argTypes);
f2->retTypes = shallowSerialize(f1->retTypes);
}
@ -433,6 +439,11 @@ private:
c2->parent = shallowSerialize(*c1->parent);
}
void serializeChildren(const GenericType* g1, TypeFunctionGenericType* g2)
{
// noop.
}
void serializeChildren(const TypePack* t1, TypeFunctionTypePack* t2)
{
for (const TypeId& ty : t1->head)
@ -446,6 +457,25 @@ private:
{
v2->type = shallowSerialize(v1->ty);
}
void serializeChildren(const GenericTypePack* v1, TypeFunctionGenericTypePack* v2)
{
// noop.
}
};
template<typename T>
struct SerializedGeneric
{
bool isNamed = false;
std::string name;
T type = nullptr;
};
struct SerializedFunctionScope
{
size_t oldQueueSize = 0;
TypeFunctionFunctionType* function = nullptr;
};
// Complete inverse of TypeFunctionSerializer
@ -466,6 +496,15 @@ class TypeFunctionDeserializer
// second must be PrimitiveType; else there should be an error
std::vector<std::tuple<TypeFunctionKind, Kind>> queue;
// Generic types and packs currently in scope
// Generics are resolved by name even if runtime generic type pointers are different
// Multiple names mapping to the same generic can be in scope for nested generic functions
std::vector<SerializedGeneric<TypeId>> genericTypes;
std::vector<SerializedGeneric<TypePackId>> genericPacks;
// To track when generics go out of scope, we have a list of queue positions at which a specific function has introduced generics
std::vector<SerializedFunctionScope> functionScopes;
SeenTypes types; // Mapping of TypeFunctionTypeIds that have been shallow deserialized to TypeIds
SeenTypePacks packs; // Mapping of TypeFunctionTypePackIds that have been shallow deserialized to TypePackIds
@ -477,7 +516,9 @@ public:
, typeFunctionRuntime(state->ctx->typeFunctionRuntime)
, queue({})
, types({})
, packs({}){};
, packs({})
{
}
TypeId deserialize(TypeFunctionTypeId ty)
{
@ -531,6 +572,16 @@ private:
queue.pop_back();
deserializeChildren(tfti, ty);
if (FFlag::LuauUserTypeFunGenerics)
{
// If we have completed working on all children of a function, remove the generic parameters from scope
if (!functionScopes.empty() && queue.size() == functionScopes.back().oldQueueSize && state->errors.empty())
{
closeFunctionScope(functionScopes.back().function);
functionScopes.pop_back();
}
}
}
}
@ -563,6 +614,21 @@ private:
}
}
void closeFunctionScope(TypeFunctionFunctionType* f)
{
if (!f->generics.empty())
{
LUAU_ASSERT(genericTypes.size() >= f->generics.size());
genericTypes.erase(genericTypes.begin() + int(genericTypes.size() - f->generics.size()), genericTypes.end());
}
if (!f->genericPacks.empty())
{
LUAU_ASSERT(genericPacks.size() >= f->genericPacks.size());
genericPacks.erase(genericPacks.begin() + int(genericPacks.size() - f->genericPacks.size()), genericPacks.end());
}
}
TypeId shallowDeserialize(TypeFunctionTypeId ty)
{
if (auto it = find(ty))
@ -587,16 +653,10 @@ private:
target = state->ctx->builtins->stringType;
break;
case TypeFunctionPrimitiveType::Type::Thread:
if (FFlag::LuauUserTypeFunThreadBuffer)
target = state->ctx->builtins->threadType;
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
target = state->ctx->builtins->threadType;
break;
case TypeFunctionPrimitiveType::Type::Buffer:
if (FFlag::LuauUserTypeFunThreadBuffer)
target = state->ctx->builtins->bufferType;
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
target = state->ctx->builtins->bufferType;
break;
default:
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
@ -642,6 +702,33 @@ private:
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized");
}
else if (auto g = get<TypeFunctionGenericType>(ty); FFlag::LuauUserTypeFunGenerics && g)
{
if (g->isPack)
{
state->errors.push_back(format("Generic type pack '%s...' cannot be placed in a type position", g->name.c_str()));
return nullptr;
}
else
{
auto it = std::find_if(
genericTypes.rbegin(),
genericTypes.rend(),
[&](const SerializedGeneric<TypeId>& el)
{
return g->isNamed == el.isNamed && g->name == el.name;
}
);
if (it == genericTypes.rend())
{
state->errors.push_back(format("Generic type '%s' is not in a scope of the active generic function", g->name.c_str()));
return nullptr;
}
target = it->type;
}
}
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
@ -658,11 +745,36 @@ private:
// Create a shallow deserialization
TypePackId target = {};
if (auto tPack = get<TypeFunctionTypePack>(tp))
{
target = state->ctx->arena->addTypePack(TypePack{});
}
else if (auto vPack = get<TypeFunctionVariadicTypePack>(tp))
{
target = state->ctx->arena->addTypePack(VariadicTypePack{});
}
else if (auto gPack = get<TypeFunctionGenericTypePack>(tp); FFlag::LuauUserTypeFunGenerics && gPack)
{
auto it = std::find_if(
genericPacks.rbegin(),
genericPacks.rend(),
[&](const SerializedGeneric<TypePackId>& el)
{
return gPack->isNamed == el.isNamed && gPack->name == el.name;
}
);
if (it == genericPacks.rend())
{
state->errors.push_back(format("Generic type pack '%s...' is not in a scope of the active generic function", gPack->name.c_str()));
return nullptr;
}
target = it->type;
}
else
{
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
packs[tp] = target;
queue.emplace_back(tp, target);
@ -697,6 +809,9 @@ private:
deserializeChildren(f2, f1);
else if (auto [c1, c2] = std::tuple{getMutable<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2)
deserializeChildren(c2, c1);
else if (auto [g1, g2] = std::tuple{getMutable<GenericType>(ty), getMutable<TypeFunctionGenericType>(tfti)};
FFlag::LuauUserTypeFunGenerics && g1 && g2)
deserializeChildren(g2, g1);
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
@ -708,6 +823,9 @@ private:
else if (auto [vPack1, vPack2] = std::tuple{getMutable<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)};
vPack1 && vPack2)
deserializeChildren(vPack2, vPack1);
else if (auto [gPack1, gPack2] = std::tuple{getMutable<GenericTypePack>(tp), getMutable<TypeFunctionGenericTypePack>(tftp)};
FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2)
deserializeChildren(gPack2, gPack1);
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
@ -791,6 +909,64 @@ private:
void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1)
{
if (FFlag::LuauUserTypeFunGenerics)
{
functionScopes.push_back({queue.size(), f2});
std::set<std::pair<bool, std::string>> genericNames;
// Introduce generic function parameters into scope
for (auto ty : f2->generics)
{
auto gty = get<TypeFunctionGenericType>(ty);
LUAU_ASSERT(gty && !gty->isPack);
std::pair<bool, std::string> nameKey = std::make_pair(gty->isNamed, gty->name);
// Duplicates are not allowed
if (genericNames.find(nameKey) != genericNames.end())
{
state->errors.push_back(format("Duplicate type parameter '%s'", gty->name.c_str()));
return;
}
genericNames.insert(nameKey);
TypeId mapping = state->ctx->arena->addTV(Type(gty->isNamed ? GenericType{state->ctx->scope.get(), gty->name} : GenericType{}));
genericTypes.push_back({gty->isNamed, gty->name, mapping});
}
for (auto tp : f2->genericPacks)
{
auto gtp = get<TypeFunctionGenericTypePack>(tp);
LUAU_ASSERT(gtp);
std::pair<bool, std::string> nameKey = std::make_pair(gtp->isNamed, gtp->name);
// Duplicates are not allowed
if (genericNames.find(nameKey) != genericNames.end())
{
state->errors.push_back(format("Duplicate type parameter '%s'", gtp->name.c_str()));
return;
}
genericNames.insert(nameKey);
TypePackId mapping =
state->ctx->arena->addTypePack(TypePackVar(gtp->isNamed ? GenericTypePack{state->ctx->scope.get(), gtp->name} : GenericTypePack{})
);
genericPacks.push_back({gtp->isNamed, gtp->name, mapping});
}
f1->generics.reserve(f2->generics.size());
for (auto ty : f2->generics)
f1->generics.push_back(shallowDeserialize(ty));
f1->genericPacks.reserve(f2->genericPacks.size());
for (auto tp : f2->genericPacks)
f1->genericPacks.push_back(shallowDeserialize(tp));
}
if (f2->argTypes)
f1->argTypes = shallowDeserialize(f2->argTypes);
@ -803,6 +979,11 @@ private:
// noop.
}
void deserializeChildren(TypeFunctionGenericType* g2, GenericType* g1)
{
// noop.
}
void deserializeChildren(TypeFunctionTypePack* t2, TypePack* t1)
{
for (TypeFunctionTypeId& ty : t2->head)
@ -816,6 +997,11 @@ private:
{
v1->ty = shallowDeserialize(v2->type);
}
void deserializeChildren(TypeFunctionGenericTypePack* v2, GenericTypePack* v1)
{
// noop.
}
};
TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state)

View file

@ -32,7 +32,9 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauMetatableFollow)
LUAU_FASTFLAGVARIABLE(LuauOldSolverCreatesChildScopePointers)
LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -761,8 +763,12 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& state
struct Demoter : Substitution
{
Demoter(TypeArena* arena)
TypeArena* arena = nullptr;
NotNull<BuiltinTypes> builtins;
Demoter(TypeArena* arena, NotNull<BuiltinTypes> builtins)
: Substitution(TxnLog::empty(), arena)
, arena(arena)
, builtins(builtins)
{
}
@ -788,7 +794,8 @@ struct Demoter : Substitution
{
auto ftv = get<FreeType>(ty);
LUAU_ASSERT(ftv);
return addType(FreeType{demotedLevel(ftv->level)});
return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtins, demotedLevel(ftv->level))
: addType(FreeType{demotedLevel(ftv->level)});
}
TypePackId clean(TypePackId tp) override
@ -835,7 +842,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& retur
}
}
Demoter demoter{&currentModule->internalTypes};
Demoter demoter{&currentModule->internalTypes, builtinTypes};
demoter.demote(expectedTypes);
TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type;
@ -2799,10 +2806,10 @@ TypeId TypeChecker::checkRelationalOperation(
reportError(
expr.location,
GenericError{
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
}
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
}
return booleanType;
}
@ -2866,7 +2873,7 @@ TypeId TypeChecker::checkRelationalOperation(
std::optional<TypeId> metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true);
if (metamethod)
{
if (const FunctionType* ftv = get<FunctionType>(FFlag::LuauMetatableFollow ? follow(*metamethod) : *metamethod))
if (const FunctionType* ftv = get<FunctionType>(follow(*metamethod)))
{
if (isEquality)
{
@ -4408,7 +4415,7 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
}
}
Demoter demoter{&currentModule->internalTypes};
Demoter demoter{&currentModule->internalTypes, builtinTypes};
demoter.demote(expectedTypes);
return expectedTypes;
@ -4506,10 +4513,10 @@ std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(
// When this function type has magic functions and did return something, we select that overload instead.
// TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution.
if (ftv->magicFunction)
if (ftv->magic)
{
// TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magicFunction(*this, scope, expr, argListResult))
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magic->handleOldSolver(*this, scope, expr, argListResult))
return std::make_unique<WithPredicate<TypePackId>>(std::move(*ret));
}
@ -5205,6 +5212,13 @@ LUAU_NOINLINE void TypeChecker::reportErrorCodeTooComplex(const Location& locati
ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel)
{
ScopePtr scope = std::make_shared<Scope>(parent, subLevel);
if (FFlag::LuauOldSolverCreatesChildScopePointers)
{
scope->location = location;
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope));
return scope;
}
@ -5215,6 +5229,12 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio
ScopePtr scope = std::make_shared<Scope>(parent);
scope->level = parent->level;
scope->varargPack = parent->varargPack;
if (FFlag::LuauOldSolverCreatesChildScopePointers)
{
scope->location = location;
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope));
return scope;
@ -5260,7 +5280,8 @@ TypeId TypeChecker::freshType(const ScopePtr& scope)
TypeId TypeChecker::freshType(TypeLevel level)
{
return currentModule->internalTypes.addType(Type(FreeType(level)));
return FFlag::LuauFreeTypesMustHaveBounds ? currentModule->internalTypes.freshType(builtinTypes, level)
: currentModule->internalTypes.addType(Type(FreeType(level)));
}
TypeId TypeChecker::singletonType(bool value)
@ -5705,6 +5726,12 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
}
else if (const auto& un = annotation.as<AstTypeUnion>())
{
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (un->types.size == 1)
return resolveType(scope, *un->types.data[0]);
}
std::vector<TypeId> types;
for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann));
@ -5713,12 +5740,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
}
else if (const auto& un = annotation.as<AstTypeIntersection>())
{
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (un->types.size == 1)
return resolveType(scope, *un->types.data[0]);
}
std::vector<TypeId> types;
for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann));
return addType(IntersectionType{types});
}
else if (const auto& g = annotation.as<AstTypeGroup>())
{
return resolveType(scope, *g->type);
}
else if (const auto& tsb = annotation.as<AstTypeSingletonBool>())
{
return singletonType(tsb->value);

View file

@ -5,12 +5,15 @@
#include "Luau/Normalize.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeInfer.h"
#include <algorithm>
LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete);
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope);
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -318,9 +321,11 @@ TypePack extendTypePack(
{
FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType};
t = arena.addType(ft);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(ftp->scope, t);
}
else
t = arena.freshType(ftp->scope);
t = FFlag::LuauFreeTypesMustHaveBounds ? arena.freshType(builtinTypes, ftp->scope) : arena.freshType_DEPRECATED(ftp->scope);
}
newPack.head.push_back(t);
@ -533,7 +538,7 @@ std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMa
{
std::vector<TypeId> toBlock;
BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}};
for (auto arg: expr->args)
for (auto arg : expr->args)
{
if (isLiteral(arg) || arg->is<AstExprGroup>())
{
@ -543,5 +548,21 @@ std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMa
return toBlock;
}
void trackInteriorFreeType(Scope* scope, TypeId ty)
{
LUAU_ASSERT(FFlag::LuauSolverV2 && FFlag::LuauTrackInteriorFreeTypesOnScope);
for (; scope; scope = scope->parent.get())
{
if (scope->interiorFreeTypes)
{
scope->interiorFreeTypes->push_back(ty);
return;
}
}
// There should at least be *one* generalization constraint per module
// where `interiorFreeTypes` is present, which would be the one made
// by ConstraintGenerator::visitModuleRoot.
LUAU_ASSERT(!"No scopes in parent chain had a present `interiorFreeTypes` member.");
}
} // namespace Luau

View file

@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering)
LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -1648,7 +1649,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
if (FFlag::LuauSolverV2)
return freshType(NotNull{types}, builtinTypes, scope);
else
return types->freshType(scope, level);
return FFlag::LuauFreeTypesMustHaveBounds ? types->freshType(builtinTypes, scope, level) : types->freshType_DEPRECATED(scope, level);
};
const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt});

View file

@ -45,4 +45,4 @@ private:
size_t offset;
};
}
} // namespace Luau

View file

@ -1204,6 +1204,18 @@ public:
const AstArray<char> value;
};
class AstTypeGroup : public AstType
{
public:
LUAU_RTTI(AstTypeGroup)
explicit AstTypeGroup(const Location& location, AstType* type);
void visit(AstVisitor* visitor) override;
AstType* type;
};
class AstTypePack : public AstNode
{
public:
@ -1470,6 +1482,10 @@ public:
{
return visit(static_cast<AstType*>(node));
}
virtual bool visit(class AstTypeGroup* node)
{
return visit(static_cast<AstType*>(node));
}
virtual bool visit(class AstTypeError* node)
{
return visit(static_cast<AstType*>(node));

334
Ast/include/Luau/Cst.h Normal file
View file

@ -0,0 +1,334 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Location.h"
#include <string>
namespace Luau
{
extern int gCstRttiIndex;
template<typename T>
struct CstRtti
{
static const int value;
};
template<typename T>
const int CstRtti<T>::value = ++gCstRttiIndex;
#define LUAU_CST_RTTI(Class) \
static int CstClassIndex() \
{ \
return CstRtti<Class>::value; \
}
class CstNode
{
public:
explicit CstNode(int classIndex)
: classIndex(classIndex)
{
}
template<typename T>
bool is() const
{
return classIndex == T::CstClassIndex();
}
template<typename T>
T* as()
{
return classIndex == T::CstClassIndex() ? static_cast<T*>(this) : nullptr;
}
template<typename T>
const T* as() const
{
return classIndex == T::CstClassIndex() ? static_cast<const T*>(this) : nullptr;
}
const int classIndex;
};
class CstExprConstantNumber : public CstNode
{
public:
LUAU_CST_RTTI(CstExprConstantNumber)
explicit CstExprConstantNumber(const AstArray<char>& value);
AstArray<char> value;
};
class CstExprConstantString : public CstNode
{
public:
LUAU_CST_RTTI(CstExprConstantNumber)
enum QuoteStyle
{
QuotedSingle,
QuotedDouble,
QuotedRaw,
QuotedInterp,
};
CstExprConstantString(AstArray<char> sourceString, QuoteStyle quoteStyle, unsigned int blockDepth);
AstArray<char> sourceString;
QuoteStyle quoteStyle;
unsigned int blockDepth;
};
class CstExprCall : public CstNode
{
public:
LUAU_CST_RTTI(CstExprCall)
CstExprCall(std::optional<Position> openParens, std::optional<Position> closeParens, AstArray<Position> commaPositions);
std::optional<Position> openParens;
std::optional<Position> closeParens;
AstArray<Position> commaPositions;
};
class CstExprIndexExpr : public CstNode
{
public:
LUAU_CST_RTTI(CstExprIndexExpr)
CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition);
Position openBracketPosition;
Position closeBracketPosition;
};
class CstExprTable : public CstNode
{
public:
LUAU_CST_RTTI(CstExprTable)
enum Separator
{
Comma,
Semicolon,
};
struct Item
{
std::optional<Position> indexerOpenPosition; // '[', only if Kind == General
std::optional<Position> indexerClosePosition; // ']', only if Kind == General
std::optional<Position> equalsPosition; // only if Kind != List
std::optional<Separator> separator; // may be missing for last Item
std::optional<Position> separatorPosition;
};
explicit CstExprTable(const AstArray<Item>& items);
AstArray<Item> items;
};
// TODO: Shared between unary and binary, should we split?
class CstExprOp : public CstNode
{
public:
LUAU_CST_RTTI(CstExprOp)
explicit CstExprOp(Position opPosition);
Position opPosition;
};
class CstExprIfElse : public CstNode
{
public:
LUAU_CST_RTTI(CstExprIfElse)
CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf);
Position thenPosition;
Position elsePosition;
bool isElseIf;
};
class CstExprInterpString : public CstNode
{
public:
LUAU_CST_RTTI(CstExprInterpString)
explicit CstExprInterpString(AstArray<AstArray<char>> sourceStrings, AstArray<Position> stringPositions);
AstArray<AstArray<char>> sourceStrings;
AstArray<Position> stringPositions;
};
class CstStatDo : public CstNode
{
public:
LUAU_CST_RTTI(CstStatDo)
explicit CstStatDo(Position endPosition);
Position endPosition;
};
class CstStatRepeat : public CstNode
{
public:
LUAU_CST_RTTI(CstStatRepeat)
explicit CstStatRepeat(Position untilPosition);
Position untilPosition;
};
class CstStatReturn : public CstNode
{
public:
LUAU_CST_RTTI(CstStatReturn)
explicit CstStatReturn(AstArray<Position> commaPositions);
AstArray<Position> commaPositions;
};
class CstStatLocal : public CstNode
{
public:
LUAU_CST_RTTI(CstStatLocal)
CstStatLocal(AstArray<Position> varsCommaPositions, AstArray<Position> valuesCommaPositions);
AstArray<Position> varsCommaPositions;
AstArray<Position> valuesCommaPositions;
};
class CstStatFor : public CstNode
{
public:
LUAU_CST_RTTI(CstStatFor)
CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional<Position> stepCommaPosition);
Position equalsPosition;
Position endCommaPosition;
std::optional<Position> stepCommaPosition;
};
class CstStatForIn : public CstNode
{
public:
LUAU_CST_RTTI(CstStatForIn)
CstStatForIn(AstArray<Position> varsCommaPositions, AstArray<Position> valuesCommaPositions);
AstArray<Position> varsCommaPositions;
AstArray<Position> valuesCommaPositions;
};
class CstStatAssign : public CstNode
{
public:
LUAU_CST_RTTI(CstStatAssign)
CstStatAssign(AstArray<Position> varsCommaPositions, Position equalsPosition, AstArray<Position> valuesCommaPositions);
AstArray<Position> varsCommaPositions;
Position equalsPosition;
AstArray<Position> valuesCommaPositions;
};
class CstStatCompoundAssign : public CstNode
{
public:
LUAU_CST_RTTI(CstStatCompoundAssign)
explicit CstStatCompoundAssign(Position opPosition);
Position opPosition;
};
class CstStatLocalFunction : public CstNode
{
public:
LUAU_CST_RTTI(CstStatLocalFunction)
explicit CstStatLocalFunction(Position functionKeywordPosition);
Position functionKeywordPosition;
};
class CstTypeReference : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeReference)
CstTypeReference(
std::optional<Position> prefixPointPosition,
Position openParametersPosition,
AstArray<Position> parametersCommaPositions,
Position closeParametersPosition
);
std::optional<Position> prefixPointPosition;
Position openParametersPosition;
AstArray<Position> parametersCommaPositions;
Position closeParametersPosition;
};
class CstTypeTable : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeTable)
struct Item
{
enum struct Kind
{
Indexer,
Property,
StringProperty,
};
Kind kind;
Position indexerOpenPosition; // '[', only if Kind != Property
Position indexerClosePosition; // ']' only if Kind != Property
Position colonPosition;
std::optional<CstExprTable::Separator> separator; // may be missing for last Item
std::optional<Position> separatorPosition;
CstExprConstantString* stringInfo = nullptr; // only if Kind == StringProperty
};
CstTypeTable(AstArray<Item> items, bool isArray);
AstArray<Item> items;
bool isArray = false;
};
class CstTypeTypeof : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeTypeof)
CstTypeTypeof(Position openPosition, Position closePosition);
Position openPosition;
Position closePosition;
};
class CstTypeSingletonString : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeSingletonString)
CstTypeSingletonString(AstArray<char> sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth);
AstArray<char> sourceString;
CstExprConstantString::QuoteStyle quoteStyle;
unsigned int blockDepth;
};
} // namespace Luau

View file

@ -87,6 +87,12 @@ struct Lexeme
Reserved_END
};
enum struct QuoteStyle
{
Single,
Double,
};
Type type;
Location location;
@ -111,6 +117,8 @@ public:
Lexeme(const Location& location, Type type, const char* name);
unsigned int getLength() const;
unsigned int getBlockDepth() const;
QuoteStyle getQuoteStyle() const;
std::string toString() const;
};
@ -230,17 +238,6 @@ private:
bool skipComments;
bool readNames;
// This offset represents a column offset to be applied to any positions created by the lexer until the next new line.
// For example:
// local x = 4
// local y = 5
// If we start lexing from the position of `l` in `local x = 4`, the line number will be 1, and the column will be 4
// However, because the lexer calculates line offsets by 'index in source buffer where there is a newline', the column
// count will start at 0. For this reason, for just the first line, we'll need to store the offset.
unsigned int lexResumeOffset;
enum class BraceType
{
InterpolatedString,

View file

@ -14,12 +14,37 @@ struct Position
{
}
bool operator==(const Position& rhs) const;
bool operator!=(const Position& rhs) const;
bool operator<(const Position& rhs) const;
bool operator>(const Position& rhs) const;
bool operator<=(const Position& rhs) const;
bool operator>=(const Position& rhs) const;
bool operator==(const Position& rhs) const
{
return this->column == rhs.column && this->line == rhs.line;
}
bool operator!=(const Position& rhs) const
{
return !(*this == rhs);
}
bool operator<(const Position& rhs) const
{
if (line == rhs.line)
return column < rhs.column;
else
return line < rhs.line;
}
bool operator>(const Position& rhs) const
{
if (line == rhs.line)
return column > rhs.column;
else
return line > rhs.line;
}
bool operator<=(const Position& rhs) const
{
return *this == rhs || *this < rhs;
}
bool operator>=(const Position& rhs) const
{
return *this == rhs || *this > rhs;
}
void shift(const Position& start, const Position& oldEnd, const Position& newEnd);
};
@ -52,8 +77,14 @@ struct Location
{
}
bool operator==(const Location& rhs) const;
bool operator!=(const Location& rhs) const;
bool operator==(const Location& rhs) const
{
return this->begin == rhs.begin && this->end == rhs.end;
}
bool operator!=(const Location& rhs) const
{
return !(*this == rhs);
}
bool encloses(const Location& l) const;
bool overlaps(const Location& l) const;

View file

@ -29,6 +29,8 @@ struct ParseOptions
bool allowDeclarationSyntax = false;
bool captureComments = false;
std::optional<FragmentParseResumeSettings> parseFragment = std::nullopt;
bool storeCstData = false;
bool noErrorLimit = false;
};
} // namespace Luau

View file

@ -10,6 +10,7 @@ namespace Luau
{
class AstStatBlock;
class CstNode;
class ParseError : public std::exception
{
@ -55,6 +56,8 @@ struct Comment
Location location;
};
using CstNodeMap = DenseHashMap<AstNode*, CstNode*>;
struct ParseResult
{
AstStatBlock* root;
@ -64,6 +67,8 @@ struct ParseResult
std::vector<ParseError> errors;
std::vector<Comment> commentLocations;
CstNodeMap cstNodeMap{nullptr};
};
static constexpr const char* kParseNameError = "%error-id%";

View file

@ -8,6 +8,7 @@
#include "Luau/StringUtils.h"
#include "Luau/DenseHash.h"
#include "Luau/Common.h"
#include "Luau/Cst.h"
#include <initializer_list>
#include <optional>
@ -116,7 +117,7 @@ private:
AstStat* parseFor();
// funcname ::= Name {`.' Name} [`:' Name]
AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname);
AstExpr* parseFunctionName(Location start_DEPRECATED, bool& hasself, AstName& debugname);
// function funcname funcbody
LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray<AstAttr*>& attributes = {nullptr, 0});
@ -173,14 +174,18 @@ private:
);
// explist ::= {exp `,'} exp
void parseExprList(TempVector<AstExpr*>& result);
void parseExprList(TempVector<AstExpr*>& result, TempVector<Position>* commaPositions = nullptr);
// binding ::= Name [`:` Type]
Binding parseBinding();
// bindinglist ::= (binding | `...') {`,' bindinglist}
// Returns the location of the vararg ..., or std::nullopt if the function is not vararg.
std::tuple<bool, Location, AstTypePack*> parseBindingList(TempVector<Binding>& result, bool allowDot3 = false);
std::tuple<bool, Location, AstTypePack*> parseBindingList(
TempVector<Binding>& result,
bool allowDot3 = false,
TempVector<Position>* commaPositions = nullptr
);
AstType* parseOptionalType();
@ -201,7 +206,17 @@ private:
std::optional<AstTypeList> parseOptionalReturnType();
std::pair<Location, AstTypeList> parseReturnType();
AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation);
struct TableIndexerResult
{
AstTableIndexer* node;
Position indexerOpenPosition;
Position indexerClosePosition;
Position colonPosition;
};
TableIndexerResult parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation);
// Remove with FFlagLuauStoreCSTData
AstTableIndexer* parseTableIndexer_DEPRECATED(AstTableAccess access, std::optional<Location> accessLocation);
AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes);
AstType* parseFunctionTypeTail(
@ -259,6 +274,8 @@ private:
// args ::= `(' [explist] `)' | tableconstructor | String
AstExpr* parseFunctionArgs(AstExpr* func, bool self);
std::optional<CstExprTable::Separator> tableSeparator();
// tableconstructor ::= `{' [fieldlist] `}'
// fieldlist ::= field {fieldsep field} [fieldsep]
// field ::= `[' exp `]' `=' exp | Name `=' exp | exp
@ -280,9 +297,13 @@ private:
std::pair<AstArray<AstGenericType>, AstArray<AstGenericTypePack>> parseGenericTypeList(bool withDefaultValues);
// `<' Type[, ...] `>'
AstArray<AstTypeOrPack> parseTypeParams();
AstArray<AstTypeOrPack> parseTypeParams(
Position* openingPosition = nullptr,
TempVector<Position>* commaPositions = nullptr,
Position* closingPosition = nullptr
);
std::optional<AstArray<char>> parseCharArray();
std::optional<AstArray<char>> parseCharArray(AstArray<char>* originalString = nullptr);
AstExpr* parseString();
AstExpr* parseNumber();
@ -292,6 +313,9 @@ private:
void restoreLocals(unsigned int offset);
/// Returns string quote style and block depth
std::pair<CstExprConstantString::QuoteStyle, unsigned int> extractStringDetails();
// check that parser is at lexeme/symbol, move to next lexeme/symbol on success, report failure and continue on failure
bool expectAndConsume(char value, const char* context = nullptr);
bool expectAndConsume(Lexeme::Type type, const char* context = nullptr);
@ -435,6 +459,7 @@ private:
std::vector<AstAttr*> scratchAttr;
std::vector<AstStat*> scratchStat;
std::vector<AstArray<char>> scratchString;
std::vector<AstArray<char>> scratchString2;
std::vector<AstExpr*> scratchExpr;
std::vector<AstExpr*> scratchExprAux;
std::vector<AstName> scratchName;
@ -442,15 +467,20 @@ private:
std::vector<Binding> scratchBinding;
std::vector<AstLocal*> scratchLocal;
std::vector<AstTableProp> scratchTableTypeProps;
std::vector<CstTypeTable::Item> scratchCstTableTypeProps;
std::vector<AstType*> scratchType;
std::vector<AstTypeOrPack> scratchTypeOrPack;
std::vector<AstDeclaredClassProp> scratchDeclaredClassProps;
std::vector<AstExprTable::Item> scratchItem;
std::vector<CstExprTable::Item> scratchCstItem;
std::vector<AstArgumentName> scratchArgName;
std::vector<AstGenericType> scratchGenericTypes;
std::vector<AstGenericTypePack> scratchGenericTypePacks;
std::vector<std::optional<AstArgumentName>> scratchOptArgName;
std::vector<Position> scratchPosition;
std::string scratchData;
CstNodeMap cstNodeMap;
};
} // namespace Luau

View file

@ -63,4 +63,4 @@ void* Allocator::allocate(size_t size)
return page->data;
}
}
} // namespace Luau

View file

@ -1091,6 +1091,18 @@ void AstTypeSingletonString::visit(AstVisitor* visitor)
visitor->visit(this);
}
AstTypeGroup::AstTypeGroup(const Location& location, AstType* type)
: AstType(ClassIndex(), location)
, type(type)
{
}
void AstTypeGroup::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
type->visit(visitor);
}
AstTypeError::AstTypeError(const Location& location, const AstArray<AstType*>& types, bool isMissing, unsigned messageIndex)
: AstType(ClassIndex(), location)
, types(types)
@ -1151,10 +1163,7 @@ void AstTypePackGeneric::visit(AstVisitor* visitor)
bool isLValue(const AstExpr* expr)
{
return expr->is<AstExprLocal>()
|| expr->is<AstExprGlobal>()
|| expr->is<AstExprIndexName>()
|| expr->is<AstExprIndexExpr>();
return expr->is<AstExprLocal>() || expr->is<AstExprGlobal>() || expr->is<AstExprIndexName>() || expr->is<AstExprIndexExpr>();
}
AstName getIdentifier(AstExpr* node)

169
Ast/src/Cst.cpp Normal file
View file

@ -0,0 +1,169 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Ast.h"
#include "Luau/Cst.h"
#include "Luau/Common.h"
namespace Luau
{
int gCstRttiIndex = 0;
CstExprConstantNumber::CstExprConstantNumber(const AstArray<char>& value)
: CstNode(CstClassIndex())
, value(value)
{
}
CstExprConstantString::CstExprConstantString(AstArray<char> sourceString, QuoteStyle quoteStyle, unsigned int blockDepth)
: CstNode(CstClassIndex())
, sourceString(sourceString)
, quoteStyle(quoteStyle)
, blockDepth(blockDepth)
{
LUAU_ASSERT(blockDepth == 0 || quoteStyle == QuoteStyle::QuotedRaw);
}
CstExprCall::CstExprCall(std::optional<Position> openParens, std::optional<Position> closeParens, AstArray<Position> commaPositions)
: CstNode(CstClassIndex())
, openParens(openParens)
, closeParens(closeParens)
, commaPositions(commaPositions)
{
}
CstExprIndexExpr::CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition)
: CstNode(CstClassIndex())
, openBracketPosition(openBracketPosition)
, closeBracketPosition(closeBracketPosition)
{
}
CstExprTable::CstExprTable(const AstArray<Item>& items)
: CstNode(CstClassIndex())
, items(items)
{
}
CstExprOp::CstExprOp(Position opPosition)
: CstNode(CstClassIndex())
, opPosition(opPosition)
{
}
CstExprIfElse::CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf)
: CstNode(CstClassIndex())
, thenPosition(thenPosition)
, elsePosition(elsePosition)
, isElseIf(isElseIf)
{
}
CstExprInterpString::CstExprInterpString(AstArray<AstArray<char>> sourceStrings, AstArray<Position> stringPositions)
: CstNode(CstClassIndex())
, sourceStrings(sourceStrings)
, stringPositions(stringPositions)
{
}
CstStatDo::CstStatDo(Position endPosition)
: CstNode(CstClassIndex())
, endPosition(endPosition)
{
}
CstStatRepeat::CstStatRepeat(Position untilPosition)
: CstNode(CstClassIndex())
, untilPosition(untilPosition)
{
}
CstStatReturn::CstStatReturn(AstArray<Position> commaPositions)
: CstNode(CstClassIndex())
, commaPositions(commaPositions)
{
}
CstStatLocal::CstStatLocal(AstArray<Position> varsCommaPositions, AstArray<Position> valuesCommaPositions)
: CstNode(CstClassIndex())
, varsCommaPositions(varsCommaPositions)
, valuesCommaPositions(valuesCommaPositions)
{
}
CstStatFor::CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional<Position> stepCommaPosition)
: CstNode(CstClassIndex())
, equalsPosition(equalsPosition)
, endCommaPosition(endCommaPosition)
, stepCommaPosition(stepCommaPosition)
{
}
CstStatForIn::CstStatForIn(AstArray<Position> varsCommaPositions, AstArray<Position> valuesCommaPositions)
: CstNode(CstClassIndex())
, varsCommaPositions(varsCommaPositions)
, valuesCommaPositions(valuesCommaPositions)
{
}
CstStatAssign::CstStatAssign(
AstArray<Position> varsCommaPositions,
Position equalsPosition,
AstArray<Position> valuesCommaPositions
)
: CstNode(CstClassIndex())
, varsCommaPositions(varsCommaPositions)
, equalsPosition(equalsPosition)
, valuesCommaPositions(valuesCommaPositions)
{
}
CstStatCompoundAssign::CstStatCompoundAssign(Position opPosition)
: CstNode(CstClassIndex())
, opPosition(opPosition)
{
}
CstStatLocalFunction::CstStatLocalFunction(Position functionKeywordPosition)
: CstNode(CstClassIndex())
, functionKeywordPosition(functionKeywordPosition)
{
}
CstTypeReference::CstTypeReference(
std::optional<Position> prefixPointPosition,
Position openParametersPosition,
AstArray<Position> parametersCommaPositions,
Position closeParametersPosition
)
: CstNode(CstClassIndex())
, prefixPointPosition(prefixPointPosition)
, openParametersPosition(openParametersPosition)
, parametersCommaPositions(parametersCommaPositions)
, closeParametersPosition(closeParametersPosition)
{
}
CstTypeTable::CstTypeTable(AstArray<Item> items, bool isArray)
: CstNode(CstClassIndex())
, items(items)
, isArray(isArray)
{
}
CstTypeTypeof::CstTypeTypeof(Position openPosition, Position closePosition)
: CstNode(CstClassIndex())
, openPosition(openPosition)
, closePosition(closePosition)
{
}
CstTypeSingletonString::CstTypeSingletonString(AstArray<char> sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth)
: CstNode(CstClassIndex())
, sourceString(sourceString)
, quoteStyle(quoteStyle)
, blockDepth(blockDepth)
{
LUAU_ASSERT(quoteStyle != CstExprConstantString::QuotedInterp);
}
} // namespace Luau

View file

@ -8,7 +8,9 @@
#include <limits.h>
LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition)
LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition2)
LUAU_FASTFLAGVARIABLE(LexerFixInterpStringStart)
namespace Luau
{
@ -304,20 +306,51 @@ static char unescape(char ch)
}
}
unsigned int Lexeme::getBlockDepth() const
{
LUAU_ASSERT(type == Lexeme::RawString || type == Lexeme::BlockComment);
// If we have a well-formed string, we are guaranteed to see 2 `]` characters after the end of the string contents
LUAU_ASSERT(*(data + length) == ']');
unsigned int depth = 0;
do
{
depth++;
} while (*(data + length + depth) != ']');
return depth - 1;
}
Lexeme::QuoteStyle Lexeme::getQuoteStyle() const
{
LUAU_ASSERT(type == Lexeme::QuotedString);
// If we have a well-formed string, we are guaranteed to see a closing delimiter after the string
LUAU_ASSERT(data);
char quote = *(data + length);
if (quote == '\'')
return Lexeme::QuoteStyle::Single;
else if (quote == '"')
return Lexeme::QuoteStyle::Double;
LUAU_ASSERT(!"Unknown quote style");
return Lexeme::QuoteStyle::Double; // unreachable, but required due to compiler warning
}
Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names, Position startPosition)
: buffer(buffer)
, bufferSize(bufferSize)
, offset(0)
, line(FFlag::LexerResumesFromPosition ? startPosition.line : 0)
, lineOffset(0)
, line(FFlag::LexerResumesFromPosition2 ? startPosition.line : 0)
, lineOffset(FFlag::LexerResumesFromPosition2 ? 0u - startPosition.column : 0)
, lexeme(
(FFlag::LexerResumesFromPosition ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)),
(FFlag::LexerResumesFromPosition2 ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)),
Lexeme::Eof
)
, names(names)
, skipComments(false)
, readNames(true)
, lexResumeOffset(FFlag::LexerResumesFromPosition ? startPosition.column : 0)
{
}
@ -372,7 +405,6 @@ Lexeme Lexer::lookahead()
Location currentPrevLocation = prevLocation;
size_t currentBraceStackSize = braceStack.size();
BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back();
unsigned int currentLexResumeOffset = lexResumeOffset;
Lexeme result = next();
@ -381,7 +413,6 @@ Lexeme Lexer::lookahead()
lineOffset = currentLineOffset;
lexeme = currentLexeme;
prevLocation = currentPrevLocation;
lexResumeOffset = currentLexResumeOffset;
if (braceStack.size() < currentBraceStackSize)
braceStack.push_back(currentBraceType);
@ -412,9 +443,10 @@ char Lexer::peekch(unsigned int lookahead) const
return (offset + lookahead < bufferSize) ? buffer[offset + lookahead] : 0;
}
LUAU_FORCEINLINE
Position Lexer::position() const
{
return Position(line, offset - lineOffset + (FFlag::LexerResumesFromPosition ? lexResumeOffset : 0));
return Position(line, offset - lineOffset);
}
LUAU_FORCEINLINE
@ -433,9 +465,6 @@ void Lexer::consumeAny()
{
line++;
lineOffset = offset + 1;
// every new line, we reset
if (FFlag::LexerResumesFromPosition)
lexResumeOffset = 0;
}
offset++;
@ -764,7 +793,7 @@ Lexeme Lexer::readNext()
return Lexeme(Location(start, 1), '}');
}
return readInterpolatedStringSection(position(), Lexeme::InterpStringMid, Lexeme::InterpStringEnd);
return readInterpolatedStringSection(FFlag::LexerFixInterpStringStart ? start : position(), Lexeme::InterpStringMid, Lexeme::InterpStringEnd);
}
case '=':

View file

@ -4,42 +4,6 @@
namespace Luau
{
bool Position::operator==(const Position& rhs) const
{
return this->column == rhs.column && this->line == rhs.line;
}
bool Position::operator!=(const Position& rhs) const
{
return !(*this == rhs);
}
bool Position::operator<(const Position& rhs) const
{
if (line == rhs.line)
return column < rhs.column;
else
return line < rhs.line;
}
bool Position::operator>(const Position& rhs) const
{
if (line == rhs.line)
return column > rhs.column;
else
return line > rhs.line;
}
bool Position::operator<=(const Position& rhs) const
{
return *this == rhs || *this < rhs;
}
bool Position::operator>=(const Position& rhs) const
{
return *this == rhs || *this > rhs;
}
void Position::shift(const Position& start, const Position& oldEnd, const Position& newEnd)
{
if (*this >= start)
@ -54,16 +18,6 @@ void Position::shift(const Position& start, const Position& oldEnd, const Positi
}
}
bool Location::operator==(const Location& rhs) const
{
return this->begin == rhs.begin && this->end == rhs.end;
}
bool Location::operator!=(const Location& rhs) const
{
return !(*this == rhs);
}
bool Location::encloses(const Location& l) const
{
return begin <= l.begin && end >= l.end;

File diff suppressed because it is too large Load diff

View file

@ -10,7 +10,7 @@
std::optional<std::string> getCurrentWorkingDirectory();
std::string normalizePath(std::string_view path);
std::string resolvePath(std::string_view relativePath, std::string_view baseFilePath);
std::optional<std::string> resolvePath(std::string_view relativePath, std::string_view baseFilePath);
std::optional<std::string> readFile(const std::string& name);
std::optional<std::string> readStdin();
@ -23,7 +23,7 @@ bool isDirectory(const std::string& path);
bool traverseDirectory(const std::string& path, const std::function<void(const std::string& name)>& callback);
std::vector<std::string_view> splitPath(std::string_view path);
std::string joinPaths(const std::string& lhs, const std::string& rhs);
std::optional<std::string> getParentPath(const std::string& path);
std::string joinPaths(std::string_view lhs, std::string_view rhs);
std::optional<std::string> getParentPath(std::string_view path);
std::vector<std::string> getSourceFiles(int argc, char** argv);

View file

@ -7,9 +7,9 @@
#include "Luau/TypeAttach.h"
#include "Luau/Transpiler.h"
#include "FileUtils.h"
#include "Flags.h"
#include "Require.h"
#include "Luau/FileUtils.h"
#include "Luau/Flags.h"
#include "Luau/Require.h"
#include <condition_variable>
#include <functional>

View file

@ -8,7 +8,7 @@
#include "Luau/ParseOptions.h"
#include "Luau/ToString.h"
#include "FileUtils.h"
#include "Luau/FileUtils.h"
static void displayHelp(const char* argv0)
{

View file

@ -7,8 +7,8 @@
#include "Luau/BytecodeBuilder.h"
#include "Luau/Parser.h"
#include "Luau/BytecodeSummary.h"
#include "FileUtils.h"
#include "Flags.h"
#include "Luau/FileUtils.h"
#include "Luau/Flags.h"
#include <memory>

View file

@ -8,8 +8,8 @@
#include "Luau/Parser.h"
#include "Luau/TimeTrace.h"
#include "FileUtils.h"
#include "Flags.h"
#include "Luau/FileUtils.h"
#include "Luau/Flags.h"
#include <memory>
@ -341,7 +341,8 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source);
}
else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || format == CompileFormat::CodegenVerbose)
else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr ||
format == CompileFormat::CodegenVerbose)
{
bcb.setDumpFlags(
Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |

View file

@ -1,5 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Coverage.h"
#include "Luau/Coverage.h"
#include "lua.h"

View file

@ -1,5 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "FileUtils.h"
#include "Luau/FileUtils.h"
#include "Luau/Common.h"
@ -20,6 +20,7 @@
#endif
#include <string.h>
#include <string_view>
#ifdef _WIN32
static std::wstring fromUtf8(const std::string& path)
@ -90,108 +91,76 @@ std::optional<std::string> getCurrentWorkingDirectory()
return std::nullopt;
}
// Returns the normal/canonical form of a path (e.g. "../subfolder/../module.luau" -> "../module.luau")
std::string normalizePath(std::string_view path)
{
return resolvePath(path, "");
}
const std::vector<std::string_view> components = splitPath(path);
std::vector<std::string_view> normalizedComponents;
// Takes a path that is relative to the file at baseFilePath and returns the path explicitly rebased onto baseFilePath.
// For absolute paths, baseFilePath will be ignored, and this function will resolve the path to a canonical path:
// (e.g. "/Users/.././Users/johndoe" -> "/Users/johndoe").
std::string resolvePath(std::string_view path, std::string_view baseFilePath)
{
std::vector<std::string_view> pathComponents;
std::vector<std::string_view> baseFilePathComponents;
const bool isAbsolute = isAbsolutePath(path);
// Dependent on whether the final resolved path is absolute or relative
// - if relative (when path and baseFilePath are both relative), resolvedPathPrefix remains empty
// - if absolute (if either path or baseFilePath are absolute), resolvedPathPrefix is "C:\", "/", etc.
std::string resolvedPathPrefix;
bool isResolvedPathRelative = false;
if (isAbsolutePath(path))
{
// path is absolute, we use path's prefix and ignore baseFilePath
size_t afterPrefix = path.find_first_of("\\/") + 1;
resolvedPathPrefix = path.substr(0, afterPrefix);
pathComponents = splitPath(path.substr(afterPrefix));
}
else
{
size_t afterPrefix = baseFilePath.find_first_of("\\/") + 1;
baseFilePathComponents = splitPath(baseFilePath.substr(afterPrefix));
if (isAbsolutePath(baseFilePath))
{
// path is relative and baseFilePath is absolute, we use baseFilePath's prefix
resolvedPathPrefix = baseFilePath.substr(0, afterPrefix);
}
else
{
// path and baseFilePath are both relative, we do not set a prefix (resolved path will be relative)
isResolvedPathRelative = true;
}
pathComponents = splitPath(path);
}
// Remove filename from components
if (!baseFilePathComponents.empty())
baseFilePathComponents.pop_back();
// Resolve the path by applying pathComponents to baseFilePathComponents
int numPrependedParents = 0;
for (std::string_view component : pathComponents)
// 1. Normalize path components
const size_t startIndex = isAbsolute ? 1 : 0;
for (size_t i = startIndex; i < components.size(); i++)
{
std::string_view component = components[i];
if (component == "..")
{
if (baseFilePathComponents.empty())
if (normalizedComponents.empty())
{
if (isResolvedPathRelative)
numPrependedParents++; // "../" will later be added to the beginning of the resolved path
if (!isAbsolute)
{
normalizedComponents.emplace_back("..");
}
}
else if (baseFilePathComponents.back() != "..")
else if (normalizedComponents.back() == "..")
{
baseFilePathComponents.pop_back(); // Resolve cases like "folder/subfolder/../../file" to "file"
normalizedComponents.emplace_back("..");
}
else
{
normalizedComponents.pop_back();
}
}
else if (component != "." && !component.empty())
else if (!component.empty() && component != ".")
{
baseFilePathComponents.push_back(component);
normalizedComponents.emplace_back(component);
}
}
// Create resolved path prefix for relative paths
if (isResolvedPathRelative)
std::string normalizedPath;
// 2. Add correct prefix to formatted path
if (isAbsolute)
{
if (numPrependedParents > 0)
{
resolvedPathPrefix.reserve(numPrependedParents * 3);
for (int i = 0; i < numPrependedParents; i++)
{
resolvedPathPrefix += "../";
}
}
else
{
resolvedPathPrefix = "./";
}
normalizedPath += components[0];
normalizedPath += "/";
}
else if (normalizedComponents.empty() || normalizedComponents[0] != "..")
{
normalizedPath += "./";
}
// Join baseFilePathComponents to form the resolved path
std::string resolvedPath = resolvedPathPrefix;
for (auto iter = baseFilePathComponents.begin(); iter != baseFilePathComponents.end(); ++iter)
// 3. Join path components to form the normalized path
for (auto iter = normalizedComponents.begin(); iter != normalizedComponents.end(); ++iter)
{
if (iter != baseFilePathComponents.begin())
resolvedPath += "/";
if (iter != normalizedComponents.begin())
normalizedPath += "/";
resolvedPath += *iter;
normalizedPath += *iter;
}
if (resolvedPath.size() > resolvedPathPrefix.size() && resolvedPath.back() == '/')
{
// Remove trailing '/' if present
resolvedPath.pop_back();
}
return resolvedPath;
if (normalizedPath.size() >= 2 && normalizedPath[normalizedPath.size() - 1] == '.' && normalizedPath[normalizedPath.size() - 2] == '.')
normalizedPath += "/";
return normalizedPath;
}
std::optional<std::string> resolvePath(std::string_view path, std::string_view baseFilePath)
{
std::optional<std::string> baseFilePathParent = getParentPath(baseFilePath);
if (!baseFilePathParent)
return std::nullopt;
return normalizePath(joinPaths(*baseFilePathParent, path));
}
bool hasFileExtension(std::string_view name, const std::vector<std::string>& extensions)
@ -416,16 +385,16 @@ std::vector<std::string_view> splitPath(std::string_view path)
return components;
}
std::string joinPaths(const std::string& lhs, const std::string& rhs)
std::string joinPaths(std::string_view lhs, std::string_view rhs)
{
std::string result = lhs;
std::string result = std::string(lhs);
if (!result.empty() && result.back() != '/' && result.back() != '\\')
result += '/';
result += rhs;
return result;
}
std::optional<std::string> getParentPath(const std::string& path)
std::optional<std::string> getParentPath(std::string_view path)
{
if (path == "" || path == "." || path == "/")
return std::nullopt;
@ -441,7 +410,7 @@ std::optional<std::string> getParentPath(const std::string& path)
return "/";
if (slash != std::string::npos)
return path.substr(0, slash);
return std::string(path.substr(0, slash));
return "";
}
@ -471,10 +440,12 @@ std::vector<std::string> getSourceFiles(int argc, char** argv)
if (argv[i][0] == '-' && argv[i][1] != '\0')
continue;
if (isDirectory(argv[i]))
std::string normalized = normalizePath(argv[i]);
if (isDirectory(normalized))
{
traverseDirectory(
argv[i],
normalized,
[&](const std::string& name)
{
std::string ext = getExtension(name);
@ -486,7 +457,7 @@ std::vector<std::string> getSourceFiles(int argc, char** argv)
}
else
{
files.push_back(argv[i]);
files.push_back(normalized);
}
}

View file

@ -5,7 +5,7 @@
#include "Luau/Parser.h"
#include "Luau/Transpiler.h"
#include "FileUtils.h"
#include "Luau/FileUtils.h"
#include <algorithm>
#include <stdio.h>

View file

@ -1,5 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Repl.h"
#include "Luau/Repl.h"
#include "Luau/Common.h"
#include "lua.h"
@ -10,11 +10,11 @@
#include "Luau/Parser.h"
#include "Luau/TimeTrace.h"
#include "Coverage.h"
#include "FileUtils.h"
#include "Flags.h"
#include "Profiler.h"
#include "Require.h"
#include "Luau/Coverage.h"
#include "Luau/FileUtils.h"
#include "Luau/Flags.h"
#include "Luau/Profiler.h"
#include "Luau/Require.h"
#include "isocline.h"

View file

@ -1,5 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Repl.h"
#include "Luau/Repl.h"
int main(int argc, char** argv)
{

View file

@ -1,7 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Require.h"
#include "Luau/Require.h"
#include "FileUtils.h"
#include "Luau/FileUtils.h"
#include "Luau/Common.h"
#include "Luau/Config.h"
@ -141,8 +141,17 @@ bool RequireResolver::resolveAndStoreDefaultPaths()
return false;
// resolvePath automatically sanitizes/normalizes the paths
resolvedRequire.identifier = resolvePath(pathToResolve, identifierContext);
resolvedRequire.absolutePath = resolvePath(pathToResolve, *absolutePathContext);
std::optional<std::string> identifier = resolvePath(pathToResolve, identifierContext);
std::optional<std::string> absolutePath = resolvePath(pathToResolve, *absolutePathContext);
if (!identifier || !absolutePath)
{
errorHandler.reportError("could not resolve require path");
return false;
}
resolvedRequire.identifier = std::move(*identifier);
resolvedRequire.absolutePath = std::move(*absolutePath);
}
else
{
@ -181,7 +190,7 @@ std::optional<std::string> RequireResolver::getRequiringContextAbsolute()
else
{
// Require statement is being executed in a file, must resolve relative to CWD
requiringFile = resolvePath(requireContext.getPath(), joinPaths(*cwd, "stdin"));
requiringFile = normalizePath(joinPaths(*cwd, requireContext.getPath()));
}
}
std::replace(requiringFile.begin(), requiringFile.end(), '\\', '/');
@ -190,7 +199,7 @@ std::optional<std::string> RequireResolver::getRequiringContextAbsolute()
std::string RequireResolver::getRequiringContextRelative()
{
return requireContext.isStdin() ? "" : requireContext.getPath();
return requireContext.isStdin() ? "./" : requireContext.getPath();
}
bool RequireResolver::substituteAliasIfPresent(std::string& path)
@ -301,4 +310,4 @@ bool RequireResolver::parseConfigInDirectory(const std::string& directory)
}
return true;
}
}

View file

@ -68,11 +68,12 @@ include(Sources.cmake)
target_include_directories(Luau.Common INTERFACE Common/include)
target_compile_features(Luau.CLI.lib PUBLIC cxx_std_17)
target_link_libraries(Luau.CLI.lib PRIVATE Luau.Common)
target_include_directories(Luau.CLI.lib PUBLIC CLI/include)
target_link_libraries(Luau.CLI.lib PRIVATE Luau.Common Luau.Config)
target_compile_features(Luau.Ast PUBLIC cxx_std_17)
target_include_directories(Luau.Ast PUBLIC Ast/include)
target_link_libraries(Luau.Ast PUBLIC Luau.Common Luau.CLI.lib)
target_link_libraries(Luau.Ast PUBLIC Luau.Common)
target_compile_features(Luau.Compiler PUBLIC cxx_std_17)
target_include_directories(Luau.Compiler PUBLIC Compiler/include)

View file

@ -160,6 +160,7 @@ public:
void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3);

View file

@ -1,7 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/CodeGen.h"
#include "Luau/CodeGenOptions.h"
#include <vector>

View file

@ -1,7 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <algorithm>
#include "Luau/CodeGenCommon.h"
#include "Luau/CodeGenOptions.h"
#include "Luau/LoweringStats.h"
#include <array>
#include <memory>
#include <string>
@ -12,25 +15,11 @@
struct lua_State;
#if defined(__x86_64__) || defined(_M_X64)
#define CODEGEN_TARGET_X64
#elif defined(__aarch64__) || defined(_M_ARM64)
#define CODEGEN_TARGET_A64
#endif
namespace Luau
{
namespace CodeGen
{
enum CodeGenFlags
{
// Only run native codegen for modules that have been marked with --!native
CodeGen_OnlyNativeModules = 1 << 0,
// Run native codegen for functions that the compiler considers not profitable
CodeGen_ColdFunctions = 1 << 1,
};
// These enum values can be reported through telemetry.
// To ensure consistency, changes should be additive.
enum class CodeGenCompilationResult
@ -72,106 +61,6 @@ struct CompilationResult
}
};
struct IrBuilder;
struct IrOp;
using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength);
using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostVectorNamecallHandler =
bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
enum class HostMetamethod
{
Add,
Sub,
Mul,
Div,
Idiv,
Mod,
Pow,
Minus,
Equal,
LessThan,
LessEqual,
Length,
Concat,
};
using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength);
using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method);
using HostUserdataAccessHandler =
bool (*)(IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostUserdataMetamethodHandler =
bool (*)(IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos);
using HostUserdataNamecallHandler = bool (*)(
IrBuilder& builder,
uint8_t type,
const char* member,
size_t memberLength,
int argResReg,
int sourceReg,
int params,
int results,
int pcpos
);
struct HostIrHooks
{
// Suggest result type of a vector field access
HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr;
// Suggest result type of a vector function namecall
HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr;
// Handle vector value field access
// 'sourceReg' is guaranteed to be a vector
// Guards should take a VM exit to 'pcpos'
HostVectorAccessHandler vectorAccess = nullptr;
// Handle namecall performed on a vector value
// 'sourceReg' (self argument) is guaranteed to be a vector
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostVectorNamecallHandler vectorNamecall = nullptr;
// Suggest result type of a userdata field access
HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr;
// Suggest result type of a metamethod call
HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr;
// Suggest result type of a userdata namecall
HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr;
// Handle userdata value field access
// 'sourceReg' is guaranteed to be a userdata, but tag has to be checked
// Write to 'resultReg' might invalidate 'sourceReg'
// Guards should take a VM exit to 'pcpos'
HostUserdataAccessHandler userdataAccess = nullptr;
// Handle metamethod operation on a userdata value
// 'lhs' and 'rhs' operands can be VM registers of constants
// Operand types have to be checked and userdata operand tags have to be checked
// Write to 'resultReg' might invalidate source operands
// Guards should take a VM exit to 'pcpos'
HostUserdataMetamethodHandler userdataMetamethod = nullptr;
// Handle namecall performed on a userdata value
// 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostUserdataNamecallHandler userdataNamecall = nullptr;
};
struct CompilationOptions
{
unsigned int flags = 0;
HostIrHooks hooks;
// null-terminated array of userdata types names that might have custom lowering
const char* const* userdataTypes = nullptr;
};
struct CompilationStats
{
size_t bytecodeSizeBytes = 0;
@ -184,8 +73,6 @@ struct CompilationStats
uint32_t functionsBound = 0;
};
using AllocationCallback = void(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize);
bool isSupported();
class SharedCodeGenContext;
@ -249,153 +136,6 @@ CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsig
CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr);
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr);
using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos);
// Output "#" before IR blocks and instructions
enum class IncludeIrPrefix
{
No,
Yes
};
// Output user count and last use information of blocks and instructions
enum class IncludeUseInfo
{
No,
Yes
};
// Output CFG informations like block predecessors, successors and etc
enum class IncludeCfgInfo
{
No,
Yes
};
// Output VM register live in/out information for blocks
enum class IncludeRegFlowInfo
{
No,
Yes
};
struct AssemblyOptions
{
enum Target
{
Host,
A64,
A64_NoFeatures,
X64_Windows,
X64_SystemV,
};
Target target = Host;
CompilationOptions compilationOptions;
bool outputBinary = false;
bool includeAssembly = false;
bool includeIr = false;
bool includeOutlinedCode = false;
bool includeIrTypes = false;
IncludeIrPrefix includeIrPrefix = IncludeIrPrefix::Yes;
IncludeUseInfo includeUseInfo = IncludeUseInfo::Yes;
IncludeCfgInfo includeCfgInfo = IncludeCfgInfo::Yes;
IncludeRegFlowInfo includeRegFlowInfo = IncludeRegFlowInfo::Yes;
// Optional annotator function can be provided to describe each instruction, it takes function id and sequential instruction id
AnnotatorFn annotator = nullptr;
void* annotatorContext = nullptr;
};
struct BlockLinearizationStats
{
unsigned int constPropInstructionCount = 0;
double timeSeconds = 0.0;
BlockLinearizationStats& operator+=(const BlockLinearizationStats& that)
{
this->constPropInstructionCount += that.constPropInstructionCount;
this->timeSeconds += that.timeSeconds;
return *this;
}
BlockLinearizationStats operator+(const BlockLinearizationStats& other) const
{
BlockLinearizationStats result(*this);
result += other;
return result;
}
};
enum FunctionStatsFlags
{
// Enable stats collection per function
FunctionStats_Enable = 1 << 0,
// Compute function bytecode summary
FunctionStats_BytecodeSummary = 1 << 1,
};
struct FunctionStats
{
std::string name;
int line = -1;
unsigned bcodeCount = 0;
unsigned irCount = 0;
unsigned asmCount = 0;
unsigned asmSize = 0;
std::vector<std::vector<unsigned>> bytecodeSummary;
};
struct LoweringStats
{
unsigned totalFunctions = 0;
unsigned skippedFunctions = 0;
int spillsToSlot = 0;
int spillsToRestore = 0;
unsigned maxSpillSlotsUsed = 0;
unsigned blocksPreOpt = 0;
unsigned blocksPostOpt = 0;
unsigned maxBlockInstructions = 0;
int regAllocErrors = 0;
int loweringErrors = 0;
BlockLinearizationStats blockLinearizationStats;
unsigned functionStatsFlags = 0;
std::vector<FunctionStats> functions;
LoweringStats operator+(const LoweringStats& other) const
{
LoweringStats result(*this);
result += other;
return result;
}
LoweringStats& operator+=(const LoweringStats& that)
{
this->totalFunctions += that.totalFunctions;
this->skippedFunctions += that.skippedFunctions;
this->spillsToSlot += that.spillsToSlot;
this->spillsToRestore += that.spillsToRestore;
this->maxSpillSlotsUsed = std::max(this->maxSpillSlotsUsed, that.maxSpillSlotsUsed);
this->blocksPreOpt += that.blocksPreOpt;
this->blocksPostOpt += that.blocksPostOpt;
this->maxBlockInstructions = std::max(this->maxBlockInstructions, that.maxBlockInstructions);
this->regAllocErrors += that.regAllocErrors;
this->loweringErrors += that.loweringErrors;
this->blockLinearizationStats += that.blockLinearizationStats;
if (this->functionStatsFlags & FunctionStats_Enable)
this->functions.insert(this->functions.end(), that.functions.begin(), that.functions.end());
return *this;
}
};
// Generates assembly for target function and all inner functions
std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {}, LoweringStats* stats = nullptr);

View file

@ -10,3 +10,9 @@
#else
#define CODEGEN_ASSERT(expr) (void)sizeof(!!(expr))
#endif
#if defined(__x86_64__) || defined(_M_X64)
#define CODEGEN_TARGET_X64
#elif defined(__aarch64__) || defined(_M_ARM64)
#define CODEGEN_TARGET_A64
#endif

View file

@ -0,0 +1,188 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <string>
#include <stddef.h>
#include <stdint.h>
namespace Luau
{
namespace CodeGen
{
enum CodeGenFlags
{
// Only run native codegen for modules that have been marked with --!native
CodeGen_OnlyNativeModules = 1 << 0,
// Run native codegen for functions that the compiler considers not profitable
CodeGen_ColdFunctions = 1 << 1,
};
using AllocationCallback = void(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize);
struct IrBuilder;
struct IrOp;
using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength);
using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostVectorNamecallHandler =
bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
enum class HostMetamethod
{
Add,
Sub,
Mul,
Div,
Idiv,
Mod,
Pow,
Minus,
Equal,
LessThan,
LessEqual,
Length,
Concat,
};
using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength);
using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method);
using HostUserdataAccessHandler =
bool (*)(IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostUserdataMetamethodHandler =
bool (*)(IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos);
using HostUserdataNamecallHandler = bool (*)(
IrBuilder& builder,
uint8_t type,
const char* member,
size_t memberLength,
int argResReg,
int sourceReg,
int params,
int results,
int pcpos
);
struct HostIrHooks
{
// Suggest result type of a vector field access
HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr;
// Suggest result type of a vector function namecall
HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr;
// Handle vector value field access
// 'sourceReg' is guaranteed to be a vector
// Guards should take a VM exit to 'pcpos'
HostVectorAccessHandler vectorAccess = nullptr;
// Handle namecall performed on a vector value
// 'sourceReg' (self argument) is guaranteed to be a vector
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostVectorNamecallHandler vectorNamecall = nullptr;
// Suggest result type of a userdata field access
HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr;
// Suggest result type of a metamethod call
HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr;
// Suggest result type of a userdata namecall
HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr;
// Handle userdata value field access
// 'sourceReg' is guaranteed to be a userdata, but tag has to be checked
// Write to 'resultReg' might invalidate 'sourceReg'
// Guards should take a VM exit to 'pcpos'
HostUserdataAccessHandler userdataAccess = nullptr;
// Handle metamethod operation on a userdata value
// 'lhs' and 'rhs' operands can be VM registers of constants
// Operand types have to be checked and userdata operand tags have to be checked
// Write to 'resultReg' might invalidate source operands
// Guards should take a VM exit to 'pcpos'
HostUserdataMetamethodHandler userdataMetamethod = nullptr;
// Handle namecall performed on a userdata value
// 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostUserdataNamecallHandler userdataNamecall = nullptr;
};
struct CompilationOptions
{
unsigned int flags = 0;
HostIrHooks hooks;
// null-terminated array of userdata types names that might have custom lowering
const char* const* userdataTypes = nullptr;
};
using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos);
// Output "#" before IR blocks and instructions
enum class IncludeIrPrefix
{
No,
Yes
};
// Output user count and last use information of blocks and instructions
enum class IncludeUseInfo
{
No,
Yes
};
// Output CFG informations like block predecessors, successors and etc
enum class IncludeCfgInfo
{
No,
Yes
};
// Output VM register live in/out information for blocks
enum class IncludeRegFlowInfo
{
No,
Yes
};
struct AssemblyOptions
{
enum Target
{
Host,
A64,
A64_NoFeatures,
X64_Windows,
X64_SystemV,
};
Target target = Host;
CompilationOptions compilationOptions;
bool outputBinary = false;
bool includeAssembly = false;
bool includeIr = false;
bool includeOutlinedCode = false;
bool includeIrTypes = false;
IncludeIrPrefix includeIrPrefix = IncludeIrPrefix::Yes;
IncludeUseInfo includeUseInfo = IncludeUseInfo::Yes;
IncludeCfgInfo includeCfgInfo = IncludeCfgInfo::Yes;
IncludeRegFlowInfo includeRegFlowInfo = IncludeRegFlowInfo::Yes;
// Optional annotator function can be provided to describe each instruction, it takes function id and sequential instruction id
AnnotatorFn annotator = nullptr;
void* annotatorContext = nullptr;
};
} // namespace CodeGen
} // namespace Luau

View file

@ -20,6 +20,8 @@ namespace Luau
namespace CodeGen
{
struct LoweringStats;
// IR extensions to LuauBuiltinFunction enum (these only exist inside IR, and start from 256 to avoid collisions)
enum
{
@ -67,18 +69,18 @@ enum class IrCmd : uint8_t
LOAD_ENV,
// Get pointer (TValue) to table array at index
// A: pointer (Table)
// A: pointer (LuaTable)
// B: int
GET_ARR_ADDR,
// Get pointer (LuaNode) to table node element at the active cached slot index
// A: pointer (Table)
// A: pointer (LuaTable)
// B: unsigned int (pcpos)
// C: Kn
GET_SLOT_NODE_ADDR,
// Get pointer (LuaNode) to table node element at the main position of the specified key hash
// A: pointer (Table)
// A: pointer (LuaTable)
// B: unsigned int (hash)
GET_HASH_NODE_ADDR,
@ -185,6 +187,11 @@ enum class IrCmd : uint8_t
// A: double
SIGN_NUM,
// Select B if C == D, otherwise select A
// A, B: double (endpoints)
// C, D: double (condition arguments)
SELECT_NUM,
// Add/Sub/Mul/Div/Idiv two vectors
// A, B: TValue
ADD_VEC,
@ -268,7 +275,7 @@ enum class IrCmd : uint8_t
JUMP_SLOT_MATCH,
// Get table length
// A: pointer (Table)
// A: pointer (LuaTable)
TABLE_LEN,
// Get string length
@ -281,11 +288,11 @@ enum class IrCmd : uint8_t
NEW_TABLE,
// Duplicate a table
// A: pointer (Table)
// A: pointer (LuaTable)
DUP_TABLE,
// Insert an integer key into a table and return the pointer to inserted value (TValue)
// A: pointer (Table)
// A: pointer (LuaTable)
// B: int (key)
TABLE_SETNUM,
@ -425,13 +432,13 @@ enum class IrCmd : uint8_t
CHECK_TRUTHY,
// Guard against readonly table
// A: pointer (Table)
// A: pointer (LuaTable)
// B: block/vmexit/undef
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_READONLY,
// Guard against table having a metatable
// A: pointer (Table)
// A: pointer (LuaTable)
// B: block/vmexit/undef
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_NO_METATABLE,
@ -442,7 +449,7 @@ enum class IrCmd : uint8_t
CHECK_SAFE_ENV,
// Guard against index overflowing the table array size
// A: pointer (Table)
// A: pointer (LuaTable)
// B: int (index)
// C: block/vmexit/undef
// When undef is specified instead of a block, execution is aborted on check failure
@ -498,11 +505,11 @@ enum class IrCmd : uint8_t
BARRIER_OBJ,
// Handle GC write barrier (backwards) for a write into a table
// A: pointer (Table)
// A: pointer (LuaTable)
BARRIER_TABLE_BACK,
// Handle GC write barrier (forward) for a write into a table
// A: pointer (Table)
// A: pointer (LuaTable)
// B: Rn (TValue that was written to the object)
// C: tag/undef (tag of the value that was written)
BARRIER_TABLE_FORWARD,
@ -1044,6 +1051,8 @@ struct IrFunction
CfgInfo cfg;
LoweringStats* stats = nullptr;
IrBlock& blockOp(IrOp op)
{
CODEGEN_ASSERT(op.kind == IrOpKind::Block);

View file

@ -2,11 +2,13 @@
#pragma once
#include "Luau/IrData.h"
#include "Luau/CodeGen.h"
#include "Luau/CodeGenOptions.h"
#include <string>
#include <vector>
struct Proto;
namespace Luau
{
namespace CodeGen
@ -23,6 +25,7 @@ struct IrToStringContext
const std::vector<IrBlock>& blocks;
const std::vector<IrConst>& constants;
const CfgInfo& cfg;
Proto* proto = nullptr;
};
void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index);

View file

@ -174,6 +174,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM:
case IrCmd::SIGN_NUM:
case IrCmd::SELECT_NUM:
case IrCmd::ADD_VEC:
case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC:

Some files were not shown because too many files have changed in this diff Show more