Merge branch 'upstream' into merge

This commit is contained in:
Junseo Yoo 2024-08-01 16:26:06 -07:00
commit 1adbd45b20
218 changed files with 11074 additions and 5108 deletions

View file

@ -16,10 +16,14 @@ SortIncludes: false
IndentWidth: 4
TabWidth: 4
ObjCBlockIndentWidth: 4
AlignAfterOpenBracket: DontAlign
UseTab: Never
PointerAlignment: Left
SpaceAfterTemplateKeyword: false
AlignEscapedNewlines: DontAlign
AlwaysBreakTemplateDeclarations: Yes
MaxEmptyLinesToKeep: 10
AllowAllParametersOfDeclarationOnNextLine: false
AlignAfterOpenBracket: BlockIndent
BinPackArguments: false
BinPackParameters: false
PenaltyReturnTypeOnItsOwnLine: 10000

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/AstQuery.h"
#include "Luau/Config.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Scope.h"
@ -36,6 +37,7 @@ struct AnyTypeSummary
{
TypeArena arena;
AstStatBlock* rootSrc = nullptr;
DenseHashSet<TypeId> seenTypeFamilyInstances{nullptr};
int recursionCount = 0;
@ -47,33 +49,30 @@ struct AnyTypeSummary
AnyTypeSummary();
void traverse(Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes);
void traverse(const Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes);
std::pair<bool, TypeId> checkForAnyCast(Scope* scope, AstExprTypeAssertion* expr);
// Todo: errors resolved by anys
void reportError(Location location, TypeErrorData err);
std::pair<bool, TypeId> checkForAnyCast(const Scope* scope, AstExprTypeAssertion* expr);
bool containsAny(TypePackId typ);
bool containsAny(TypeId typ);
bool isAnyCast(Scope* scope, AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes);
bool isAnyCall(Scope* scope, AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes);
bool isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasVariadicAnys(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasArgAnys(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasAnyReturns(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
TypeId checkForFamilyInhabitance(TypeId instance, Location location);
TypeId lookupType(AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes);
TypePackId reconstructTypePack(AstArray<AstExpr*> exprs, Module* module, NotNull<BuiltinTypes> builtinTypes);
TypeId checkForFamilyInhabitance(const TypeId instance, Location location);
TypeId lookupType(const AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
TypePackId reconstructTypePack(const AstArray<AstExpr*> exprs, const Module* module, NotNull<BuiltinTypes> builtinTypes);
DenseHashSet<TypeId> seenTypeFunctionInstances{nullptr};
TypeId lookupAnnotation(AstType* annotation, Module* module, NotNull<BuiltinTypes> builtintypes);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation, Module* module);
TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location);
TypeId lookupAnnotation(AstType* annotation, const Module* module, NotNull<BuiltinTypes> builtintypes);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation, const Module* module);
TypeId checkForTypeFunctionInhabitance(const TypeId instance, const Location location);
enum Pattern : uint64_t
enum Pattern: uint64_t
{
Casts,
FuncArg,
@ -91,11 +90,25 @@ struct AnyTypeSummary
Pattern code;
std::string node;
TelemetryTypePair type;
std::string debug;
explicit TypeInfo(Pattern code, std::string node, TelemetryTypePair type);
};
struct FindReturnAncestry final : public AstVisitor
{
AstNode* currNode{nullptr};
AstNode* stat{nullptr};
Position rootEnd;
bool found = false;
explicit FindReturnAncestry(AstNode* stat, Position rootEnd);
bool visit(AstType* node) override;
bool visit(AstNode* node) override;
bool visit(AstStatFunction* node) override;
bool visit(AstStatLocalFunction* node) override;
};
std::vector<TypeInfo> typeInfo;
/**
@ -103,29 +116,32 @@ struct AnyTypeSummary
* @param node the lexical node that the scope belongs to.
* @param parent the parent scope of the new scope. Must not be null.
*/
Scope* childScope(AstNode* node, const Scope* parent);
const Scope* childScope(const AstNode* node, const Scope* parent);
Scope* findInnerMostScope(Location location, Module* module);
std::optional<AstExpr*> matchRequire(const AstExprCall& call);
AstNode* getNode(AstStatBlock* root, AstNode* node);
const Scope* findInnerMostScope(const Location location, const Module* module);
const AstNode* findAstAncestryAtLocation(const AstStatBlock* root, AstNode* node);
void visit(Scope* scope, AstStat* stat, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatBlock* block, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatIf* ifStatement, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatWhile* while_, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatRepeat* repeat, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatReturn* ret, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatLocal* local, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatFor* for_, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatForIn* forIn, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatAssign* assign, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatCompoundAssign* assign, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatFunction* function, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatLocalFunction* function, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatTypeAlias* alias, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatDeclareGlobal* declareGlobal, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatDeclareClass* declareClass, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatDeclareFunction* declareFunction, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatError* error, Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStat* stat, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatError* error, const Module* module, NotNull<BuiltinTypes> builtinTypes);
};
} // namespace Luau

View file

@ -19,10 +19,22 @@ using ScopePtr = std::shared_ptr<Scope>;
// A substitution which replaces free types by any
struct Anyification : Substitution
{
Anyification(TypeArena* arena, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType,
TypePackId anyTypePack);
Anyification(TypeArena* arena, const ScopePtr& scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType,
TypePackId anyTypePack);
Anyification(
TypeArena* arena,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler,
TypeId anyType,
TypePackId anyTypePack
);
Anyification(
TypeArena* arena,
const ScopePtr& scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler,
TypeId anyType,
TypePackId anyTypePack
);
NotNull<Scope> scope;
NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter* iceHandler;

View file

@ -25,21 +25,42 @@ TypeId makeOption(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId t
/** Small utility function for building up type definitions from C++.
*/
TypeId makeFunction( // Monomorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes,
bool checked = false);
TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<TypeId> retTypes,
bool checked = false
);
TypeId makeFunction( // Polymorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, bool checked = false);
TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<TypeId> retTypes,
bool checked = false
);
TypeId makeFunction( // Monomorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes, bool checked = false);
TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes,
bool checked = false
);
TypeId makeFunction( // Polymorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes,
bool checked = false);
TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes,
bool checked = false
);
void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn);

View file

@ -256,9 +256,24 @@ struct ReducePackConstraint
TypePackId tp;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, IterableConstraint, NameConstraint,
TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, HasIndexerConstraint,
AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;
using ConstraintV = Variant<
SubtypeConstraint,
PackSubtypeConstraint,
GeneralizationConstraint,
IterableConstraint,
NameConstraint,
TypeAliasExpansionConstraint,
FunctionCallConstraint,
FunctionCheckConstraint,
PrimitiveTypeConstraint,
HasPropConstraint,
HasIndexerConstraint,
AssignPropConstraint,
AssignIndexConstraint,
UnpackConstraint,
ReduceConstraint,
ReducePackConstraint,
EqualityConstraint>;
struct Constraint
{

View file

@ -122,9 +122,18 @@ struct ConstraintGenerator
DcrLogger* logger;
ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
DcrLogger* logger, NotNull<DataFlowGraph> dfg, std::vector<RequireCycle> requireCycles);
ConstraintGenerator(
ModulePtr module,
NotNull<Normalizer> normalizer,
NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
DcrLogger* logger,
NotNull<DataFlowGraph> dfg,
std::vector<RequireCycle> requireCycles
);
/**
* The entry point to the ConstraintGenerator. This will construct a set
@ -195,10 +204,23 @@ private:
};
using RefinementContext = InsertionOrderedMap<DefId, RefinementPartition>;
void unionRefinements(const ScopePtr& scope, Location location, const RefinementContext& lhs, const RefinementContext& rhs,
RefinementContext& dest, std::vector<ConstraintV>* constraints);
void computeRefinement(const ScopePtr& scope, Location location, RefinementId refinement, RefinementContext* refis, bool sense, bool eq,
std::vector<ConstraintV>* constraints);
void unionRefinements(
const ScopePtr& scope,
Location location,
const RefinementContext& lhs,
const RefinementContext& rhs,
RefinementContext& dest,
std::vector<ConstraintV>* constraints
);
void computeRefinement(
const ScopePtr& scope,
Location location,
RefinementId refinement,
RefinementContext* refis,
bool sense,
bool eq,
std::vector<ConstraintV>* constraints
);
void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement);
ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
@ -217,6 +239,7 @@ private:
ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign);
ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement);
ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias);
ControlFlow visit(const ScopePtr& scope, AstStatTypeFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
@ -224,7 +247,11 @@ private:
InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<std::optional<TypeId>>& expectedTypes = {});
InferencePack checkPack(
const ScopePtr& scope, AstExpr* expr, const std::vector<std::optional<TypeId>>& expectedTypes = {}, bool generalize = true);
const ScopePtr& scope,
AstExpr* expr,
const std::vector<std::optional<TypeId>>& expectedTypes = {},
bool generalize = true
);
InferencePack checkPack(const ScopePtr& scope, AstExprCall* call);
@ -238,7 +265,12 @@ private:
* @return the type of the expression.
*/
Inference check(
const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}, bool forceSingleton = false, bool generalize = true);
const ScopePtr& scope,
AstExpr* expr,
std::optional<TypeId> expectedType = {},
bool forceSingleton = false,
bool generalize = true
);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
@ -276,7 +308,11 @@ private:
};
FunctionSignature checkFunctionSignature(
const ScopePtr& parent, AstExprFunction* fn, std::optional<TypeId> expectedType = {}, std::optional<Location> originalName = {});
const ScopePtr& parent,
AstExprFunction* fn,
std::optional<TypeId> expectedType = {},
std::optional<Location> originalName = {}
);
/**
* Checks the body of a function expression.
@ -323,7 +359,11 @@ private:
* privateTypeBindings map.
**/
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache = false, bool addTypes = true);
const ScopePtr& scope,
AstArray<AstGenericType> generics,
bool useCache = false,
bool addTypes = true
);
/**
* Creates generic type packs given a list of AST definitions, resolving
@ -336,7 +376,11 @@ private:
* privateTypePackBindings map.
**/
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(
const ScopePtr& scope, AstArray<AstGenericTypePack> packs, bool useCache = false, bool addTypes = true);
const ScopePtr& scope,
AstArray<AstGenericTypePack> packs,
bool useCache = false,
bool addTypes = true
);
Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);
@ -371,7 +415,12 @@ private:
std::vector<std::optional<TypeId>> getExpectedCallTypesForFunctionOverloads(const TypeId fnType);
TypeId createTypeFunctionInstance(
const TypeFunction& function, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments, const ScopePtr& scope, Location location);
const TypeFunction& function,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments,
const ScopePtr& scope,
Location location
);
};
/** Borrow a vector of pointers from a vector of owning pointers to constraints.

View file

@ -109,9 +109,16 @@ struct ConstraintSolver
DenseHashMap<TypeId, const Constraint*> typeFunctionsToFinalize{nullptr};
explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger,
TypeCheckLimits limits);
explicit ConstraintSolver(
NotNull<Normalizer> normalizer,
NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName,
NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles,
DcrLogger* logger,
TypeCheckLimits limits
);
// Randomize the order in which to dispatch constraints
void randomize(unsigned seed);
@ -170,7 +177,13 @@ public:
bool tryDispatchHasIndexer(
int& recursionDepth, NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set<TypeId>& seen);
int& recursionDepth,
NotNull<const Constraint> constraint,
TypeId subjectType,
TypeId indexType,
TypeId resultType,
Set<TypeId>& seen
);
bool tryDispatch(const HasIndexerConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint);
@ -187,10 +200,23 @@ public:
// for a, ... in next_function, t, ... do
bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification, DenseHashSet<TypeId>& seen);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
NotNull<const Constraint> constraint,
TypeId subjectType,
const std::string& propName,
ValueContext context,
bool inConditional = false,
bool suppressSimplification = false
);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
NotNull<const Constraint> constraint,
TypeId subjectType,
const std::string& propName,
ValueContext context,
bool inConditional,
bool suppressSimplification,
DenseHashSet<TypeId>& seen
);
/**
* Generate constraints to unpack the types of srcTypes and assign each

View file

@ -162,6 +162,7 @@ private:
ControlFlow visit(DfgScope* scope, AstStatFunction* f);
ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l);
ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t);
ControlFlow visit(DfgScope* scope, AstStatTypeFunction* f);
ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d);

View file

@ -126,7 +126,11 @@ struct DcrLogger
void captureInitialSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);
StepSnapshot prepareStepSnapshot(
const Scope* rootScope, NotNull<const Constraint> current, bool force, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);
const Scope* rootScope,
NotNull<const Constraint> current,
bool force,
const std::vector<NotNull<const Constraint>>& unsolvedConstraints
);
void commitStepSnapshot(StepSnapshot snapshot);
void captureFinalSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);

View file

@ -62,7 +62,12 @@ struct DiffPathNodeLeaf
// TODO: Rename to anonymousIndex, for both union and Intersection
std::optional<size_t> unionIndex;
DiffPathNodeLeaf(
std::optional<TypeId> ty, std::optional<Name> tableProperty, std::optional<int> minLength, bool isVariadic, std::optional<size_t> unionIndex)
std::optional<TypeId> ty,
std::optional<Name> tableProperty,
std::optional<int> minLength,
bool isVariadic,
std::optional<size_t> unionIndex
)
: ty(ty)
, tableProperty(tableProperty)
, minLength(minLength)
@ -159,7 +164,11 @@ struct DifferEnvironment
DenseHashMap<TypePackId, TypePackId> genericTpMatchedPairs;
DifferEnvironment(
TypeId rootLeft, TypeId rootRight, std::optional<std::string> externalSymbolLeft, std::optional<std::string> externalSymbolRight)
TypeId rootLeft,
TypeId rootRight,
std::optional<std::string> externalSymbolLeft,
std::optional<std::string> externalSymbolRight
)
: rootLeft(rootLeft)
, rootRight(rootRight)
, externalSymbolLeft(externalSymbolLeft)

View file

@ -194,6 +194,11 @@ struct InternalError
bool operator==(const InternalError& rhs) const;
};
struct ConstraintSolvingIncompleteError
{
bool operator==(const ConstraintSolvingIncompleteError& rhs) const;
};
struct CannotCallNonFunction
{
TypeId ty;
@ -443,15 +448,55 @@ struct UnexpectedTypePackInSubtyping
bool operator==(const UnexpectedTypePackInSubtyping& rhs) const;
};
using TypeErrorData =
Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods, DuplicateTypeDefinition,
CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire, IncorrectGenericParameterCount, SyntaxError,
CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError, CannotCallNonFunction, ExtraInformation,
DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning, DuplicateGenericParameter, CannotAssignToNever,
CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, TypesAreUnrelated,
NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe, UninhabitedTypeFunction, UninhabitedTypePackFunction,
WhereClauseNeeded, PackWhereClauseNeeded, CheckedFunctionCallError, NonStrictFunctionDefinitionError, PropertyAccessViolation,
CheckedFunctionIncorrectArgs, UnexpectedTypeInSubtyping, UnexpectedTypePackInSubtyping, ExplicitFunctionAnnotationRecommended>;
using TypeErrorData = Variant<
TypeMismatch,
UnknownSymbol,
UnknownProperty,
NotATable,
CannotExtendTable,
OnlyTablesCanHaveMethods,
DuplicateTypeDefinition,
CountMismatch,
FunctionDoesNotTakeSelf,
FunctionRequiresSelf,
OccursCheckFailed,
UnknownRequire,
IncorrectGenericParameterCount,
SyntaxError,
CodeTooComplex,
UnificationTooComplex,
UnknownPropButFoundLikeProp,
GenericError,
InternalError,
ConstraintSolvingIncompleteError,
CannotCallNonFunction,
ExtraInformation,
DeprecatedApiUsed,
ModuleHasCyclicDependency,
IllegalRequire,
FunctionExitsWithoutReturning,
DuplicateGenericParameter,
CannotAssignToNever,
CannotInferBinaryOperation,
MissingProperties,
SwappedGenericTypeParameter,
OptionalValueAccess,
MissingUnionProperty,
TypesAreUnrelated,
NormalizationTooComplex,
TypePackMismatch,
DynamicPropertyLookupOnClassesUnsafe,
UninhabitedTypeFunction,
UninhabitedTypePackFunction,
WhereClauseNeeded,
PackWhereClauseNeeded,
CheckedFunctionCallError,
NonStrictFunctionDefinitionError,
PropertyAccessViolation,
CheckedFunctionIncorrectArgs,
UnexpectedTypeInSubtyping,
UnexpectedTypePackInSubtyping,
ExplicitFunctionAnnotationRecommended>;
struct TypeErrorSummary
{

View file

@ -185,30 +185,55 @@ struct Frontend
void registerBuiltinDefinition(const std::string& name, std::function<void(Frontend&, GlobalTypes&, ScopePtr)>);
void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName);
LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName,
bool captureComments, bool typeCheckForAutocomplete = false);
LoadDefinitionFileResult loadDefinitionFile(
GlobalTypes& globals,
ScopePtr targetScope,
std::string_view source,
const std::string& packageName,
bool captureComments,
bool typeCheckForAutocomplete = false
);
// Batch module checking. Queue modules and check them together, retrieve results with 'getCheckResult'
// If provided, 'executeTask' function is allowed to call the 'task' function on any thread and return without waiting for 'task' to complete
void queueModuleCheck(const std::vector<ModuleName>& names);
void queueModuleCheck(const ModuleName& name);
std::vector<ModuleName> checkQueuedModules(std::optional<FrontendOptions> optionOverride = {},
std::function<void(std::function<void()> task)> executeTask = {}, std::function<bool(size_t done, size_t total)> progress = {});
std::vector<ModuleName> checkQueuedModules(
std::optional<FrontendOptions> optionOverride = {},
std::function<void(std::function<void()> task)> executeTask = {},
std::function<bool(size_t done, size_t total)> progress = {}
);
std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);
private:
ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector<RequireCycle> requireCycles, std::optional<ScopePtr> environmentScope,
bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits);
ModulePtr check(
const SourceModule& sourceModule,
Mode mode,
std::vector<RequireCycle> requireCycles,
std::optional<ScopePtr> environmentScope,
bool forAutocomplete,
bool recordJsonLog,
TypeCheckLimits typeCheckLimits
);
std::pair<SourceNode*, SourceModule*> getSourceNode(const ModuleName& name);
SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions);
bool parseGraph(
std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete, std::function<bool(const ModuleName&)> canSkip = {});
std::vector<ModuleName>& buildQueue,
const ModuleName& root,
bool forAutocomplete,
std::function<bool(const ModuleName&)> canSkip = {}
);
void addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vector<ModuleName>& buildQueue, bool cycleDetected,
DenseHashSet<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions);
void addBuildQueueItems(
std::vector<BuildQueueItem>& items,
std::vector<ModuleName>& buildQueue,
bool cycleDetected,
DenseHashSet<Luau::ModuleName>& seen,
const FrontendOptions& frontendOptions
);
void checkBuildQueueItem(BuildQueueItem& item);
void checkBuildQueueItems(std::vector<BuildQueueItem>& items);
void recordItemResult(const BuildQueueItem& item);
@ -248,14 +273,34 @@ public:
std::vector<ModuleName> moduleQueue;
};
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options,
TypeCheckLimits limits);
ModulePtr check(
const SourceModule& sourceModule,
Mode mode,
const std::vector<RequireCycle>& requireCycles,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler,
NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options,
TypeCheckLimits limits
);
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options,
TypeCheckLimits limits, bool recordJsonLog, std::function<void(const ModuleName&, std::string)> writeJsonLog);
ModulePtr check(
const SourceModule& sourceModule,
Mode mode,
const std::vector<RequireCycle>& requireCycles,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler,
NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options,
TypeCheckLimits limits,
bool recordJsonLog,
std::function<void(const ModuleName&, std::string)> writeJsonLog
);
} // namespace Luau

View file

@ -8,6 +8,12 @@
namespace Luau
{
std::optional<TypeId> generalize(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> bakedTypes, TypeId ty, /* avoid sealing tables*/ bool avoidSealingTables = false);
std::optional<TypeId> generalize(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> bakedTypes,
TypeId ty,
/* avoid sealing tables*/ bool avoidSealingTables = false
);
}

View file

@ -17,8 +17,15 @@ struct TypeCheckLimits;
// A substitution which replaces generic types in a given set by free types.
struct ReplaceGenerics : Substitution
{
ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks)
ReplaceGenerics(
const TxnLog* log,
TypeArena* arena,
NotNull<BuiltinTypes> builtinTypes,
TypeLevel level,
Scope* scope,
const std::vector<TypeId>& generics,
const std::vector<TypePackId>& genericPacks
)
: Substitution(log, arena)
, builtinTypes(builtinTypes)
, level(level)
@ -28,8 +35,15 @@ struct ReplaceGenerics : Substitution
{
}
void resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks);
void resetState(
const TxnLog* log,
TypeArena* arena,
NotNull<BuiltinTypes> builtinTypes,
TypeLevel level,
Scope* scope,
const std::vector<TypeId>& generics,
const std::vector<TypePackId>& genericPacks
);
NotNull<BuiltinTypes> builtinTypes;
@ -141,6 +155,11 @@ struct GenericTypeFinder : TypeOnceVisitor
* limits to be exceeded.
*/
std::optional<TypeId> instantiate(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty);
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
TypeId ty
);
} // namespace Luau

View file

@ -75,8 +75,16 @@ struct Instantiation2 : Substitution
};
std::optional<TypeId> instantiate2(
TypeArena* arena, DenseHashMap<TypeId, TypeId> genericSubstitutions, DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions, TypeId ty);
std::optional<TypePackId> instantiate2(TypeArena* arena, DenseHashMap<TypeId, TypeId> genericSubstitutions,
DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions, TypePackId tp);
TypeArena* arena,
DenseHashMap<TypeId, TypeId> genericSubstitutions,
DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions,
TypeId ty
);
std::optional<TypePackId> instantiate2(
TypeArena* arena,
DenseHashMap<TypeId, TypeId> genericSubstitutions,
DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions,
TypePackId tp
);
} // namespace Luau

View file

@ -25,8 +25,14 @@ struct LintResult
std::vector<LintWarning> warnings;
};
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
const std::vector<HotComment>& hotcomments, const LintOptions& options);
std::vector<LintWarning> lint(
AstStat* root,
const AstNameTable& names,
const ScopePtr& env,
const Module* module,
const std::vector<HotComment>& hotcomments,
const LintOptions& options
);
std::vector<AstName> getDeprecatedGlobals(const AstNameTable& names);

View file

@ -12,8 +12,15 @@ struct BuiltinTypes;
struct UnifierSharedState;
struct TypeCheckLimits;
void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, const SourceModule& sourceModule, Module* module);
void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg,
NotNull<TypeCheckLimits> limits,
const SourceModule& sourceModule,
Module* module
);
} // namespace Luau

View file

@ -31,8 +31,15 @@ struct OverloadResolver
OverloadIsNonviable, // Arguments were incompatible with the overloads parameters but were otherwise compatible by arity
};
OverloadResolver(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Normalizer> normalizer, NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter, NotNull<TypeCheckLimits> limits, Location callLocation);
OverloadResolver(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits,
Location callLocation
);
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
@ -58,11 +65,21 @@ private:
std::optional<ErrorVec> testIsSubtype(const Location& location, TypeId subTy, TypeId superTy);
std::optional<ErrorVec> testIsSubtype(const Location& location, TypePackId subTy, TypePackId superTy);
std::pair<Analysis, ErrorVec> checkOverload(
TypeId fnTy, const TypePack* args, AstExpr* fnLoc, const std::vector<AstExpr*>* argExprs, bool callMetamethodOk = true);
TypeId fnTy,
const TypePack* args,
AstExpr* fnLoc,
const std::vector<AstExpr*>* argExprs,
bool callMetamethodOk = true
);
static bool isLiteral(AstExpr* expr);
LUAU_NOINLINE
std::pair<Analysis, ErrorVec> checkOverload_(
TypeId fnTy, const FunctionType* fn, const TypePack* args, AstExpr* fnExpr, const std::vector<AstExpr*>* argExprs);
TypeId fnTy,
const FunctionType* fn,
const TypePack* args,
AstExpr* fnExpr,
const std::vector<AstExpr*>* argExprs
);
size_t indexof(Analysis analysis);
void add(Analysis analysis, TypeId ty, ErrorVec&& errors);
};
@ -88,8 +105,16 @@ struct SolveResult
// Helper utility, presently used for binary operator type functions.
//
// Given a function and a set of arguments, select a suitable overload.
SolveResult solveFunctionCall(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, const Location& location, TypeId fn,
TypePackId argsPack);
SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
const Location& location,
TypeId fn,
TypePackId argsPack
);
} // namespace Luau

View file

@ -140,8 +140,13 @@ struct Subtyping
SeenSet seenTypes{{}};
Subtyping(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> typeArena, NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter, NotNull<Scope> scope);
Subtyping(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<Scope> scope
);
Subtyping(const Subtyping&) = delete;
Subtyping& operator=(const Subtyping&) = delete;
@ -209,13 +214,19 @@ private:
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const Property& subProperty, const Property& superProperty, const std::string& name);
SubtypingResult isCovariantWith(
SubtypingEnvironment& env, const std::shared_ptr<const NormalizedType>& subNorm, const std::shared_ptr<const NormalizedType>& superNorm);
SubtypingEnvironment& env,
const std::shared_ptr<const NormalizedType>& subNorm,
const std::shared_ptr<const NormalizedType>& superNorm
);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const NormalizedClassType& superClass);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const TypeIds& superTables);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const NormalizedStringType& superString);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const TypeIds& superTables);
SubtypingResult isCovariantWith(
SubtypingEnvironment& env, const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction);
SubtypingEnvironment& env,
const NormalizedFunctionType& subFunction,
const NormalizedFunctionType& superFunction
);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeIds& subTypes, const TypeIds& superTypes);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic);

View file

@ -14,7 +14,15 @@ struct BuiltinTypes;
struct Unifier2;
class AstExpr;
TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes, NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Unifier2> unifier, TypeId expectedType, TypeId exprType,
const AstExpr* expr, std::vector<TypeId>& toBlock);
TypeId matchLiteralType(
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Unifier2> unifier,
TypeId expectedType,
TypeId exprType,
const AstExpr* expr,
std::vector<TypeId>& toBlock
);
} // namespace Luau

View file

@ -276,8 +276,8 @@ struct WithPredicate
}
};
using MagicFunction = std::function<std::optional<WithPredicate<TypePackId>>(
struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
using MagicFunction = std::function<std::optional<
WithPredicate<TypePackId>>(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
struct MagicFunctionCallContext
{
@ -305,19 +305,46 @@ struct FunctionType
FunctionType(TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
// Global polymorphic function
FunctionType(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes,
std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
FunctionType(
std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn = {},
bool hasSelf = false
);
// Local monomorphic function
FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
FunctionType(
TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
TypeLevel level,
Scope* scope,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn = {},
bool hasSelf = false
);
// Local polymorphic function
FunctionType(TypeLevel level, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes,
std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
FunctionType(TypeLevel level, Scope* scope, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes,
TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
FunctionType(
TypeLevel level,
std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn = {},
bool hasSelf = false
);
FunctionType(
TypeLevel level,
Scope* scope,
std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn = {},
bool hasSelf = false
);
std::optional<FunctionDefinition> definition;
/// These should all be generic
@ -398,9 +425,15 @@ struct Property
// DEPRECATED
// TODO: Kill all constructors in favor of `Property::rw(TypeId read, TypeId write)` and friends.
Property();
Property(TypeId readTy, bool deprecated = false, const std::string& deprecatedSuggestion = "", std::optional<Location> location = std::nullopt,
const Tags& tags = {}, const std::optional<std::string>& documentationSymbol = std::nullopt,
std::optional<Location> typeLocation = std::nullopt);
Property(
TypeId readTy,
bool deprecated = false,
const std::string& deprecatedSuggestion = "",
std::optional<Location> location = std::nullopt,
const Tags& tags = {},
const std::optional<std::string>& documentationSymbol = std::nullopt,
std::optional<Location> typeLocation = std::nullopt
);
// DEPRECATED: Should only be called in non-RWP! We assert that the `readTy` is not nullopt.
// TODO: Kill once we don't have non-RWP.
@ -502,8 +535,16 @@ struct ClassType
std::optional<Location> definitionLocation;
std::optional<TableIndexer> indexer;
ClassType(Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags,
std::shared_ptr<ClassUserData> userData, ModuleName definitionModuleName, std::optional<Location> definitionLocation)
ClassType(
Name name,
Props props,
std::optional<TypeId> parent,
std::optional<TypeId> metatable,
Tags tags,
std::shared_ptr<ClassUserData> userData,
ModuleName definitionModuleName,
std::optional<Location> definitionLocation
)
: name(name)
, props(props)
, parent(parent)
@ -515,9 +556,17 @@ struct ClassType
{
}
ClassType(Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags,
std::shared_ptr<ClassUserData> userData, ModuleName definitionModuleName, std::optional<Location> definitionLocation,
std::optional<TableIndexer> indexer)
ClassType(
Name name,
Props props,
std::optional<TypeId> parent,
std::optional<TypeId> metatable,
Tags tags,
std::shared_ptr<ClassUserData> userData,
ModuleName definitionModuleName,
std::optional<Location> definitionLocation,
std::optional<TableIndexer> indexer
)
: name(name)
, props(props)
, parent(parent)
@ -661,9 +710,26 @@ struct NegationType
using ErrorType = Unifiable::Error;
using TypeVariant =
Unifiable::Variant<TypeId, FreeType, GenericType, PrimitiveType, SingletonType, BlockedType, PendingExpansionType, FunctionType, TableType,
MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType, TypeFunctionInstanceType>;
using TypeVariant = Unifiable::Variant<
TypeId,
FreeType,
GenericType,
PrimitiveType,
SingletonType,
BlockedType,
PendingExpansionType,
FunctionType,
TableType,
MetatableType,
ClassType,
AnyType,
UnionType,
IntersectionType,
LazyType,
UnknownType,
NeverType,
NegationType,
TypeFunctionInstanceType>;
struct Type final
{

View file

@ -14,7 +14,13 @@ struct UnifierSharedState;
struct SourceModule;
struct Module;
void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> sharedState, NotNull<TypeCheckLimits> limits, DcrLogger* logger,
const SourceModule& sourceModule, Module* module);
void check(
NotNull<BuiltinTypes> builtinTypes,
NotNull<UnifierSharedState> sharedState,
NotNull<TypeCheckLimits> limits,
DcrLogger* logger,
const SourceModule& sourceModule,
Module* module
);
} // namespace Luau

View file

@ -44,8 +44,14 @@ struct TypeFunctionContext
{
}
TypeFunctionContext(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtins, NotNull<Scope> scope, NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> ice, NotNull<TypeCheckLimits> limits)
TypeFunctionContext(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtins,
NotNull<Scope> scope,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> ice,
NotNull<TypeCheckLimits> limits
)
: arena(arena)
, builtins(builtins)
, scope(scope)

View file

@ -62,7 +62,11 @@ struct HashBoolNamePair
struct TypeChecker
{
explicit TypeChecker(
const ScopePtr& globalScope, ModuleResolver* resolver, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler);
const ScopePtr& globalScope,
ModuleResolver* resolver,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler
);
TypeChecker(const TypeChecker&) = delete;
TypeChecker& operator=(const TypeChecker&) = delete;
@ -85,6 +89,7 @@ struct TypeChecker
ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
ControlFlow check(const ScopePtr& scope, const AstStatTypeAlias& typealias);
ControlFlow check(const ScopePtr& scope, const AstStatTypeFunction& typefunction);
ControlFlow check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
ControlFlow check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
@ -96,7 +101,11 @@ struct TypeChecker
void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted);
WithPredicate<TypeId> checkExpr(
const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType = std::nullopt, bool forceSingleton = false);
const ScopePtr& scope,
const AstExpr& expr,
std::optional<TypeId> expectedType = std::nullopt,
bool forceSingleton = false
);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprLocal& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprGlobal& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprVarargs& expr);
@ -107,17 +116,31 @@ struct TypeChecker
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprUnary& expr);
TypeId checkRelationalOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {});
const ScopePtr& scope,
const AstExprBinary& expr,
TypeId lhsType,
TypeId rhsType,
const PredicateVec& predicates = {}
);
TypeId checkBinaryOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {});
const ScopePtr& scope,
const AstExprBinary& expr,
TypeId lhsType,
TypeId rhsType,
const PredicateVec& predicates = {}
);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprInterpString& expr);
TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector<std::pair<TypeId, TypeId>>& fieldTypes,
std::optional<TypeId> expectedType);
TypeId checkExprTable(
const ScopePtr& scope,
const AstExprTable& expr,
const std::vector<std::pair<TypeId, TypeId>>& fieldTypes,
std::optional<TypeId> expectedType
);
// Returns the type of the lvalue.
TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx);
@ -130,34 +153,79 @@ struct TypeChecker
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr, ValueContext ctx);
TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level);
std::pair<TypeId, ScopePtr> checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr,
std::optional<Location> originalNameLoc, std::optional<TypeId> selfType, std::optional<TypeId> expectedType);
std::pair<TypeId, ScopePtr> checkFunctionSignature(
const ScopePtr& scope,
int subLevel,
const AstExprFunction& expr,
std::optional<Location> originalNameLoc,
std::optional<TypeId> selfType,
std::optional<TypeId> expectedType
);
void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function);
void checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId paramPack, TypePackId argPack,
const std::vector<Location>& argLocations);
void checkArgumentList(
const ScopePtr& scope,
const AstExpr& funName,
Unifier& state,
TypePackId paramPack,
TypePackId argPack,
const std::vector<Location>& argLocations
);
WithPredicate<TypePackId> checkExprPack(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr);
WithPredicate<TypePackId> checkExprPackHelper2(
const ScopePtr& scope, const AstExprCall& expr, TypeId selfType, TypeId actualFunctionType, TypeId functionType, TypePackId retPack);
const ScopePtr& scope,
const AstExprCall& expr,
TypeId selfType,
TypeId actualFunctionType,
TypeId functionType,
TypePackId retPack
);
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::unique_ptr<WithPredicate<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors);
bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors);
void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<OverloadErrorEntry>& errors);
std::unique_ptr<WithPredicate<TypePackId>> checkCallOverload(
const ScopePtr& scope,
const AstExprCall& expr,
TypeId fn,
TypePackId retPack,
TypePackId argPack,
TypePack* args,
const std::vector<Location>* argLocations,
const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<TypeId>& overloadsThatDont,
std::vector<OverloadErrorEntry>& errors
);
bool handleSelfCallMismatch(
const ScopePtr& scope,
const AstExprCall& expr,
TypePack* args,
const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors
);
void reportOverloadResolutionError(
const ScopePtr& scope,
const AstExprCall& expr,
TypePackId retPack,
TypePackId argPack,
const std::vector<Location>& argLocations,
const std::vector<TypeId>& overloads,
const std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<OverloadErrorEntry>& errors
);
WithPredicate<TypePackId> checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil = false, const std::vector<bool>& lhsAnnotations = {},
const std::vector<std::optional<TypeId>>& expectedTypes = {});
WithPredicate<TypePackId> checkExprList(
const ScopePtr& scope,
const Location& location,
const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil = false,
const std::vector<bool>& lhsAnnotations = {},
const std::vector<std::optional<TypeId>>& expectedTypes = {}
);
static std::optional<AstExpr*> matchRequire(const AstExprCall& call);
TypeId checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location);
@ -175,8 +243,13 @@ struct TypeChecker
*/
bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options);
bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location,
CountMismatch::Context ctx = CountMismatch::Context::Arg);
bool unify(
TypePackId subTy,
TypePackId superTy,
const ScopePtr& scope,
const Location& location,
CountMismatch::Context ctx = CountMismatch::Context::Arg
);
/** Attempt to unify the types.
* If this fails, and the subTy type can be instantiated, do so and try unification again.
@ -313,12 +386,23 @@ private:
TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams, const Location& location);
TypeId instantiateTypeFun(
const ScopePtr& scope,
const TypeFun& tf,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams,
const Location& location
);
// Note: `scope` must be a fresh scope.
GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node,
const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericTypePack>& genericPackNames, bool useCache = false);
GenericTypeDefinitions createGenericTypes(
const ScopePtr& scope,
std::optional<TypeLevel> levelOpt,
const AstNode& node,
const AstArray<AstGenericType>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames,
bool useCache = false
);
public:
void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);

View file

@ -56,14 +56,35 @@ struct InConditionalContext
using ScopePtr = std::shared_ptr<struct Scope>;
std::optional<Property> findTableProperty(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location);
NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
Location location
);
std::optional<TypeId> findMetatableEntry(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location);
NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId type,
const std::string& entry,
Location location
);
std::optional<TypeId> findTablePropertyRespectingMeta(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location);
NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
Location location
);
std::optional<TypeId> findTablePropertyRespectingMeta(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, ValueContext context, Location location);
NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
ValueContext context,
Location location
);
bool occursCheck(TypeId needle, TypeId haystack);
@ -73,7 +94,12 @@ std::pair<size_t, std::optional<size_t>> getParameterExtents(const TxnLog* log,
// Extend the provided pack to at least `length` types.
// Returns a temporary TypePack that contains those types plus a tail.
TypePack extendTypePack(
TypeArena& arena, NotNull<BuiltinTypes> builtinTypes, TypePackId pack, size_t length, std::vector<std::optional<TypeId>> overrides = {});
TypeArena& arena,
NotNull<BuiltinTypes> builtinTypes,
TypePackId pack,
size_t length,
std::vector<std::optional<TypeId>> overrides = {}
);
/**
* Reduces a union by decomposing to the any/error type if it appears in the

View file

@ -106,11 +106,21 @@ struct Unifier
* Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt.
*/
void tryUnify(
TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr);
TypeId subTy,
TypeId superTy,
bool isFunctionCall = false,
bool isIntersection = false,
const LiteralProperties* aliasableMap = nullptr
);
private:
void tryUnify_(
TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr);
TypeId subTy,
TypeId superTy,
bool isFunctionCall = false,
bool isIntersection = false,
const LiteralProperties* aliasableMap = nullptr
);
void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy);
// Traverse the two types provided and block on any BlockedTypes we find.
@ -120,8 +130,14 @@ private:
void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall);
void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv);
void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall);
void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason,
std::optional<TypeError> error = std::nullopt);
void tryUnifyNormalizedTypes(
TypeId subTy,
TypeId superTy,
const NormalizedType& subNorm,
const NormalizedType& superNorm,
std::string reason,
std::optional<TypeError> error = std::nullopt
);
void tryUnifyPrimitives(TypeId subTy, TypeId superTy);
void tryUnifySingletons(TypeId subTy, TypeId superTy);
void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false);

View file

@ -52,8 +52,13 @@ struct Unifier2
DenseHashSet<const void*>* uninhabitedTypeFunctions;
Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<InternalErrorReporter> ice);
Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<InternalErrorReporter> ice,
DenseHashSet<const void*>* uninhabitedTypeFunctions);
Unifier2(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> ice,
DenseHashSet<const void*>* uninhabitedTypeFunctions
);
/** Attempt to commit the subtype relation subTy <: superTy to the type
* graph.

View file

@ -46,33 +46,12 @@ LUAU_FASTFLAG(DebugLuauMagicTypes);
namespace Luau
{
// TODO: instead of pair just type for solver? generated type
// TODO: see lookupAnnotation in typechecker2. is cleaner than resolvetype
// or delay containsAny() check and do not return pair.
// quick flag in typeid saying was annotation or inferred, would be solid
std::optional<TypeOrPack> getInferredType(AstExpr* expr, Module* module)
void AnyTypeSummary::traverse(const Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes)
{
std::optional<TypeOrPack> inferredType;
if (module->astTypePacks.contains(expr))
{
inferredType = *module->astTypePacks.find(expr);
}
else if (module->astTypes.contains(expr))
{
inferredType = *module->astTypes.find(expr);
}
return inferredType;
visit(findInnerMostScope(src->location, module), src, module, builtinTypes);
}
void AnyTypeSummary::traverse(Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes)
{
Scope* scope = findInnerMostScope(src->location, module);
visit(scope, src, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStat* stat, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStat* stat, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
RecursionLimiter limiter{&recursionCount, FInt::LuauAnySummaryRecursionLimit};
@ -114,7 +93,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStat* stat, Module* module, NotNull<
return visit(scope, s, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStatBlock* block, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
RecursionCounter counter{&recursionCount};
@ -125,37 +104,38 @@ void AnyTypeSummary::visit(Scope* scope, AstStatBlock* block, Module* module, No
visit(scope, stat, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStatIf* ifStatement, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (ifStatement->thenbody)
{
Scope* thenScope = findInnerMostScope(ifStatement->thenbody->location, module);
const Scope* thenScope = findInnerMostScope(ifStatement->thenbody->location, module);
visit(thenScope, ifStatement->thenbody, module, builtinTypes);
}
if (ifStatement->elsebody)
{
Scope* elseScope = findInnerMostScope(ifStatement->elsebody->location, module);
const Scope* elseScope = findInnerMostScope(ifStatement->elsebody->location, module);
visit(elseScope, ifStatement->elsebody, module, builtinTypes);
}
}
void AnyTypeSummary::visit(Scope* scope, AstStatWhile* while_, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
Scope* whileScope = findInnerMostScope(while_->location, module);
const Scope* whileScope = findInnerMostScope(while_->location, module);
visit(whileScope, while_->body, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStatRepeat* repeat, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
Scope* repeatScope = findInnerMostScope(repeat->location, module);
const Scope* repeatScope = findInnerMostScope(repeat->location, module);
visit(repeatScope, repeat->body, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStatReturn* ret, Module* module, NotNull<BuiltinTypes> builtinTypes)
{
// Scope* outScope = findOuterScope(ret->location, module);
Scope* retScope = findInnerMostScope(ret->location, module);
void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
const Scope* retScope = findInnerMostScope(ret->location, module);
auto ctxNode = getNode(rootSrc, ret);
for (auto val : ret->list)
{
@ -163,7 +143,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatReturn* ret, Module* module, Not
{
TelemetryTypePair types;
types.inferredType = toString(lookupType(val, module, builtinTypes));
TypeInfo ti{Pattern::FuncApp, toString(ret), types};
TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti);
}
@ -174,19 +154,19 @@ void AnyTypeSummary::visit(Scope* scope, AstStatReturn* ret, Module* module, Not
TelemetryTypePair types;
types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes));
auto inf = getInferredType(cast->expr, module);
if (inf)
types.inferredType = toString(*inf);
types.inferredType = toString(lookupType(cast->expr, module, builtinTypes));
TypeInfo ti{Pattern::Casts, toString(ret), types};
TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
}
void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, local);
TypePackId values = reconstructTypePack(local->values, module, builtinTypes);
auto [head, tail] = flatten(values);
@ -203,18 +183,30 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, No
TelemetryTypePair types;
types.annotatedType = toString(annot);
types.inferredType = toString(lookupType(local->values.data[posn], module, builtinTypes));
auto inf = getInferredType(local->values.data[posn], module);
if (inf)
types.inferredType = toString(*inf);
TypeInfo ti{Pattern::VarAnnot, toString(local), types};
TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
const AstExprTypeAssertion* maybeRequire = local->values.data[posn]->as<AstExprTypeAssertion>();
if (!maybeRequire)
continue;
if (isAnyCast(scope, local->values.data[posn], module, builtinTypes))
{
TelemetryTypePair types;
types.inferredType = toString(head[std::min(local->values.size - 1, posn)]);
TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
else
{
if (std::min(local->values.size - 1, posn) < head.size())
{
if (loc->annotation)
@ -227,7 +219,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, No
types.annotatedType = toString(annot);
types.inferredType = toString(head[std::min(local->values.size - 1, posn)]);
TypeInfo ti{Pattern::VarAnnot, toString(local), types};
TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
@ -242,7 +234,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, No
types.inferredType = toString(*tail);
TypeInfo ti{Pattern::VarAny, toString(local), types};
TypeInfo ti{Pattern::VarAny, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
@ -253,20 +245,22 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, No
}
}
void AnyTypeSummary::visit(Scope* scope, AstStatFor* for_, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
Scope* forScope = findInnerMostScope(for_->location, module);
const Scope* forScope = findInnerMostScope(for_->location, module);
visit(forScope, for_->body, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStatForIn* forIn, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
Scope* loopScope = findInnerMostScope(forIn->location, module);
const Scope* loopScope = findInnerMostScope(forIn->location, module);
visit(loopScope, forIn->body, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, assign);
TypePackId values = reconstructTypePack(assign->values, module, builtinTypes);
auto [head, tail] = flatten(values);
@ -290,7 +284,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module,
else
types.inferredType = toString(builtinTypes->nilType);
TypeInfo ti{Pattern::Assign, toString(assign), types};
TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti);
}
++posn;
@ -302,11 +296,9 @@ void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module,
{
TelemetryTypePair types;
auto inf = getInferredType(val, module);
if (inf)
types.inferredType = toString(*inf);
types.inferredType = toString(lookupType(val, module, builtinTypes));
TypeInfo ti{Pattern::FuncApp, toString(assign), types};
TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti);
}
@ -317,17 +309,15 @@ void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module,
TelemetryTypePair types;
types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes));
auto inf = getInferredType(val, module);
if (inf)
types.inferredType = toString(*inf);
types.inferredType = toString(lookupType(val, module, builtinTypes));
TypeInfo ti{Pattern::Casts, toString(assign), types};
TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
if (tail)
if (tail)
{
if (containsAny(*tail))
{
@ -335,14 +325,16 @@ void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module,
types.inferredType = toString(*tail);
TypeInfo ti{Pattern::Assign, toString(assign), types};
TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
void AnyTypeSummary::visit(Scope* scope, AstStatCompoundAssign* assign, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, assign);
TelemetryTypePair types;
types.inferredType = toString(lookupType(assign->value, module, builtinTypes));
@ -352,7 +344,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatCompoundAssign* assign, Module*
{
if (containsAny(*module->astTypes.find(assign->var)))
{
TypeInfo ti{Pattern::Assign, toString(assign), types};
TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
@ -360,14 +352,14 @@ void AnyTypeSummary::visit(Scope* scope, AstStatCompoundAssign* assign, Module*
{
if (containsAny(*module->astTypePacks.find(assign->var)))
{
TypeInfo ti{Pattern::Assign, toString(assign), types};
TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
if (isAnyCall(scope, assign->value, module, builtinTypes))
{
TypeInfo ti{Pattern::FuncApp, toString(assign), types};
TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti);
}
@ -376,17 +368,15 @@ void AnyTypeSummary::visit(Scope* scope, AstStatCompoundAssign* assign, Module*
if (auto cast = assign->value->as<AstExprTypeAssertion>())
{
types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes));
auto inf = getInferredType(cast->expr, module);
if (inf)
types.inferredType = toString(*inf);
types.inferredType = toString(lookupType(cast->expr, module, builtinTypes));
TypeInfo ti{Pattern::Casts, toString(assign), types};
TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
void AnyTypeSummary::visit(Scope* scope, AstStatFunction* function, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
TelemetryTypePair types;
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
@ -413,25 +403,27 @@ void AnyTypeSummary::visit(Scope* scope, AstStatFunction* function, Module* modu
visit(scope, function->func->body, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStatLocalFunction* function, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
TelemetryTypePair types;
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
if (hasVariadicAnys(scope, function->func, module, builtinTypes))
{
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
TypeInfo ti{Pattern::VarAny, toString(function), types};
typeInfo.push_back(ti);
}
if (hasArgAnys(scope, function->func, module, builtinTypes))
{
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
TypeInfo ti{Pattern::FuncArg, toString(function), types};
typeInfo.push_back(ti);
}
if (hasAnyReturns(scope, function->func, module, builtinTypes))
{
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
TypeInfo ti{Pattern::FuncRet, toString(function), types};
typeInfo.push_back(ti);
}
@ -440,8 +432,9 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocalFunction* function, Module*
visit(scope, function->func->body, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStatTypeAlias* alias, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, alias);
auto annot = lookupAnnotation(alias->type, module, builtinTypes);
if (containsAny(annot))
@ -450,33 +443,34 @@ void AnyTypeSummary::visit(Scope* scope, AstStatTypeAlias* alias, Module* module
TelemetryTypePair types;
types.annotatedType = toString(annot);
TypeInfo ti{Pattern::Alias, toString(alias), types};
TypeInfo ti{Pattern::Alias, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
void AnyTypeSummary::visit(Scope* scope, AstStatExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes)
void AnyTypeSummary::visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, expr);
if (isAnyCall(scope, expr->expr, module, builtinTypes))
{
TelemetryTypePair types;
types.inferredType = toString(lookupType(expr->expr, module, builtinTypes));
TypeInfo ti{Pattern::FuncApp, toString(expr), types};
TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
void AnyTypeSummary::visit(Scope* scope, AstStatDeclareGlobal* declareGlobal, Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(Scope* scope, AstStatDeclareClass* declareClass, Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(Scope* scope, AstStatDeclareFunction* declareFunction, Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(Scope* scope, AstStatError* error, Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(const Scope* scope, AstStatError* error, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
TypeId AnyTypeSummary::checkForFamilyInhabitance(TypeId instance, Location location)
TypeId AnyTypeSummary::checkForFamilyInhabitance(const TypeId instance, const Location location)
{
if (seenTypeFamilyInstances.find(instance))
return instance;
@ -485,13 +479,13 @@ TypeId AnyTypeSummary::checkForFamilyInhabitance(TypeId instance, Location locat
return instance;
}
TypeId AnyTypeSummary::lookupType(AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes)
TypeId AnyTypeSummary::lookupType(const AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
TypeId* ty = module->astTypes.find(expr);
const TypeId* ty = module->astTypes.find(expr);
if (ty)
return checkForFamilyInhabitance(follow(*ty), expr->location);
TypePackId* tp = module->astTypePacks.find(expr);
const TypePackId* tp = module->astTypePacks.find(expr);
if (tp)
{
if (auto fst = first(*tp, /*ignoreHiddenVariadics*/ false))
@ -503,7 +497,7 @@ TypeId AnyTypeSummary::lookupType(AstExpr* expr, Module* module, NotNull<Builtin
return builtinTypes->errorRecoveryType();
}
TypePackId AnyTypeSummary::reconstructTypePack(AstArray<AstExpr*> exprs, Module* module, NotNull<BuiltinTypes> builtinTypes)
TypePackId AnyTypeSummary::reconstructTypePack(AstArray<AstExpr*> exprs, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (exprs.size == 0)
return arena.addTypePack(TypePack{{}, std::nullopt});
@ -515,14 +509,14 @@ TypePackId AnyTypeSummary::reconstructTypePack(AstArray<AstExpr*> exprs, Module*
head.push_back(lookupType(exprs.data[i], module, builtinTypes));
}
TypePackId* tail = module->astTypePacks.find(exprs.data[exprs.size - 1]);
const TypePackId* tail = module->astTypePacks.find(exprs.data[exprs.size - 1]);
if (tail)
return arena.addTypePack(TypePack{std::move(head), follow(*tail)});
else
return arena.addTypePack(TypePack{std::move(head), builtinTypes->errorRecoveryTypePack()});
}
bool AnyTypeSummary::isAnyCall(Scope* scope, AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes)
bool AnyTypeSummary::isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (auto call = expr->as<AstExprCall>())
{
@ -537,7 +531,7 @@ bool AnyTypeSummary::isAnyCall(Scope* scope, AstExpr* expr, Module* module, NotN
return false;
}
bool AnyTypeSummary::hasVariadicAnys(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes)
bool AnyTypeSummary::hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (expr->vararg && expr->varargAnnotation)
{
@ -550,7 +544,7 @@ bool AnyTypeSummary::hasVariadicAnys(Scope* scope, AstExprFunction* expr, Module
return false;
}
bool AnyTypeSummary::hasArgAnys(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes)
bool AnyTypeSummary::hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (expr->args.size > 0)
{
@ -569,7 +563,7 @@ bool AnyTypeSummary::hasArgAnys(Scope* scope, AstExprFunction* expr, Module* mod
return false;
}
bool AnyTypeSummary::hasAnyReturns(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes)
bool AnyTypeSummary::hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (!expr->returnAnnotation)
{
@ -596,7 +590,7 @@ bool AnyTypeSummary::hasAnyReturns(Scope* scope, AstExprFunction* expr, Module*
return false;
}
bool AnyTypeSummary::isAnyCast(Scope* scope, AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes)
bool AnyTypeSummary::isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (auto cast = expr->as<AstExprTypeAssertion>())
{
@ -609,7 +603,7 @@ bool AnyTypeSummary::isAnyCast(Scope* scope, AstExpr* expr, Module* module, NotN
return false;
}
TypeId AnyTypeSummary::lookupAnnotation(AstType* annotation, Module* module, NotNull<BuiltinTypes> builtintypes)
TypeId AnyTypeSummary::lookupAnnotation(AstType* annotation, const Module* module, NotNull<BuiltinTypes> builtintypes)
{
if (FFlag::DebugLuauMagicTypes)
{
@ -623,14 +617,14 @@ TypeId AnyTypeSummary::lookupAnnotation(AstType* annotation, Module* module, Not
}
}
TypeId* ty = module->astResolvedTypes.find(annotation);
const TypeId* ty = module->astResolvedTypes.find(annotation);
if (ty)
return checkForTypeFunctionInhabitance(follow(*ty), annotation->location);
else
return checkForTypeFunctionInhabitance(builtintypes->errorRecoveryType(), annotation->location);
}
TypeId AnyTypeSummary::checkForTypeFunctionInhabitance(TypeId instance, Location location)
TypeId AnyTypeSummary::checkForTypeFunctionInhabitance(const TypeId instance, const Location location)
{
if (seenTypeFunctionInstances.find(instance))
return instance;
@ -639,9 +633,9 @@ TypeId AnyTypeSummary::checkForTypeFunctionInhabitance(TypeId instance, Location
return instance;
}
std::optional<TypePackId> AnyTypeSummary::lookupPackAnnotation(AstTypePack* annotation, Module* module)
std::optional<TypePackId> AnyTypeSummary::lookupPackAnnotation(AstTypePack* annotation, const Module* module)
{
TypePackId* tp = module->astResolvedTypePacks.find(annotation);
const TypePackId* tp = module->astResolvedTypePacks.find(annotation);
if (tp != nullptr)
return {follow(*tp)};
return {};
@ -786,9 +780,9 @@ bool AnyTypeSummary::containsAny(TypePackId typ)
return found;
}
Scope* AnyTypeSummary::findInnerMostScope(Location location, Module* module)
const Scope* AnyTypeSummary::findInnerMostScope(const Location location, const Module* module)
{
Scope* bestScope = module->getModuleScope().get();
const Scope* bestScope = module->getModuleScope().get();
bool didNarrow = false;
do
@ -808,6 +802,69 @@ Scope* AnyTypeSummary::findInnerMostScope(Location location, Module* module)
return bestScope;
}
std::optional<AstExpr*> AnyTypeSummary::matchRequire(const AstExprCall& call)
{
const char* require = "require";
if (call.args.size != 1)
return std::nullopt;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != require)
return std::nullopt;
if (call.args.size != 1)
return std::nullopt;
return call.args.data[0];
}
AstNode* AnyTypeSummary::getNode(AstStatBlock* root, AstNode* node)
{
FindReturnAncestry finder(node, root->location.end);
root->visit(&finder);
if (!finder.currNode)
finder.currNode = node;
LUAU_ASSERT(finder.found && finder.currNode);
return finder.currNode;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstStatLocalFunction* node)
{
currNode = node;
return !found;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstStatFunction* node)
{
currNode = node;
return !found;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstType* node)
{
return !found;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstNode* node)
{
if (node == stat)
{
found = true;
}
if (node->location.end == rootEnd && stat->location.end >= rootEnd)
{
currNode = node;
found = true;
}
return !found;
}
AnyTypeSummary::TypeInfo::TypeInfo(Pattern code, std::string node, TelemetryTypePair type)
: code(code)
, node(node)
@ -815,6 +872,12 @@ AnyTypeSummary::TypeInfo::TypeInfo(Pattern code, std::string node, TelemetryType
{
}
AnyTypeSummary::FindReturnAncestry::FindReturnAncestry(AstNode* stat, Position rootEnd)
: stat(stat)
, rootEnd(rootEnd)
{
}
AnyTypeSummary::AnyTypeSummary() {}
} // namespace Luau

View file

@ -9,8 +9,14 @@
namespace Luau
{
Anyification::Anyification(TypeArena* arena, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler,
TypeId anyType, TypePackId anyTypePack)
Anyification::Anyification(
TypeArena* arena,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler,
TypeId anyType,
TypePackId anyTypePack
)
: Substitution(TxnLog::empty(), arena)
, scope(scope)
, builtinTypes(builtinTypes)
@ -20,8 +26,14 @@ Anyification::Anyification(TypeArena* arena, NotNull<Scope> scope, NotNull<Built
{
}
Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler,
TypeId anyType, TypePackId anyTypePack)
Anyification::Anyification(
TypeArena* arena,
const ScopePtr& scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler,
TypeId anyType,
TypePackId anyTypePack
)
: Anyification(arena, NotNull{scope.get()}, builtinTypes, iceHandler, anyType, anyTypePack)
{
}

View file

@ -273,9 +273,14 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprGroup* node)
{
writeNode(node, "AstExprGroup", [&]() {
write("expr", node->expr);
});
writeNode(
node,
"AstExprGroup",
[&]()
{
write("expr", node->expr);
}
);
}
void write(class AstExprConstantNil* node)
@ -285,37 +290,62 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprConstantBool* node)
{
writeNode(node, "AstExprConstantBool", [&]() {
write("value", node->value);
});
writeNode(
node,
"AstExprConstantBool",
[&]()
{
write("value", node->value);
}
);
}
void write(class AstExprConstantNumber* node)
{
writeNode(node, "AstExprConstantNumber", [&]() {
write("value", node->value);
});
writeNode(
node,
"AstExprConstantNumber",
[&]()
{
write("value", node->value);
}
);
}
void write(class AstExprConstantString* node)
{
writeNode(node, "AstExprConstantString", [&]() {
write("value", node->value);
});
writeNode(
node,
"AstExprConstantString",
[&]()
{
write("value", node->value);
}
);
}
void write(class AstExprLocal* node)
{
writeNode(node, "AstExprLocal", [&]() {
write("local", node->local);
});
writeNode(
node,
"AstExprLocal",
[&]()
{
write("local", node->local);
}
);
}
void write(class AstExprGlobal* node)
{
writeNode(node, "AstExprGlobal", [&]() {
write("global", node->name);
});
writeNode(
node,
"AstExprGlobal",
[&]()
{
write("global", node->name);
}
);
}
void write(class AstExprVarargs* node)
@ -349,51 +379,71 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprCall* node)
{
writeNode(node, "AstExprCall", [&]() {
PROP(func);
PROP(args);
PROP(self);
PROP(argLocation);
});
writeNode(
node,
"AstExprCall",
[&]()
{
PROP(func);
PROP(args);
PROP(self);
PROP(argLocation);
}
);
}
void write(class AstExprIndexName* node)
{
writeNode(node, "AstExprIndexName", [&]() {
PROP(expr);
PROP(index);
PROP(indexLocation);
PROP(op);
});
writeNode(
node,
"AstExprIndexName",
[&]()
{
PROP(expr);
PROP(index);
PROP(indexLocation);
PROP(op);
}
);
}
void write(class AstExprIndexExpr* node)
{
writeNode(node, "AstExprIndexExpr", [&]() {
PROP(expr);
PROP(index);
});
writeNode(
node,
"AstExprIndexExpr",
[&]()
{
PROP(expr);
PROP(index);
}
);
}
void write(class AstExprFunction* node)
{
writeNode(node, "AstExprFunction", [&]() {
PROP(generics);
PROP(genericPacks);
if (node->self)
PROP(self);
PROP(args);
if (node->returnAnnotation)
PROP(returnAnnotation);
PROP(vararg);
PROP(varargLocation);
if (node->varargAnnotation)
PROP(varargAnnotation);
writeNode(
node,
"AstExprFunction",
[&]()
{
PROP(generics);
PROP(genericPacks);
if (node->self)
PROP(self);
PROP(args);
if (node->returnAnnotation)
PROP(returnAnnotation);
PROP(vararg);
PROP(varargLocation);
if (node->varargAnnotation)
PROP(varargAnnotation);
PROP(body);
PROP(functionDepth);
PROP(debugname);
});
PROP(body);
PROP(functionDepth);
PROP(debugname);
}
);
}
void write(const std::optional<AstTypeList>& typeList)
@ -475,28 +525,43 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprIfElse* node)
{
writeNode(node, "AstExprIfElse", [&]() {
PROP(condition);
PROP(hasThen);
PROP(trueExpr);
PROP(hasElse);
PROP(falseExpr);
});
writeNode(
node,
"AstExprIfElse",
[&]()
{
PROP(condition);
PROP(hasThen);
PROP(trueExpr);
PROP(hasElse);
PROP(falseExpr);
}
);
}
void write(class AstExprInterpString* node)
{
writeNode(node, "AstExprInterpString", [&]() {
PROP(strings);
PROP(expressions);
});
writeNode(
node,
"AstExprInterpString",
[&]()
{
PROP(strings);
PROP(expressions);
}
);
}
void write(class AstExprTable* node)
{
writeNode(node, "AstExprTable", [&]() {
PROP(items);
});
writeNode(
node,
"AstExprTable",
[&]()
{
PROP(items);
}
);
}
void write(AstExprUnary::Op op)
@ -514,10 +579,15 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprUnary* node)
{
writeNode(node, "AstExprUnary", [&]() {
PROP(op);
PROP(expr);
});
writeNode(
node,
"AstExprUnary",
[&]()
{
PROP(op);
PROP(expr);
}
);
}
void write(AstExprBinary::Op op)
@ -563,75 +633,110 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprBinary* node)
{
writeNode(node, "AstExprBinary", [&]() {
PROP(op);
PROP(left);
PROP(right);
});
writeNode(
node,
"AstExprBinary",
[&]()
{
PROP(op);
PROP(left);
PROP(right);
}
);
}
void write(class AstExprTypeAssertion* node)
{
writeNode(node, "AstExprTypeAssertion", [&]() {
PROP(expr);
PROP(annotation);
});
writeNode(
node,
"AstExprTypeAssertion",
[&]()
{
PROP(expr);
PROP(annotation);
}
);
}
void write(class AstExprError* node)
{
writeNode(node, "AstExprError", [&]() {
PROP(expressions);
PROP(messageIndex);
});
writeNode(
node,
"AstExprError",
[&]()
{
PROP(expressions);
PROP(messageIndex);
}
);
}
void write(class AstStatBlock* node)
{
writeNode(node, "AstStatBlock", [&]() {
writeRaw(",\"hasEnd\":");
write(node->hasEnd);
writeRaw(",\"body\":[");
bool comma = false;
for (AstStat* stat : node->body)
writeNode(
node,
"AstStatBlock",
[&]()
{
if (comma)
writeRaw(",");
else
comma = true;
writeRaw(",\"hasEnd\":");
write(node->hasEnd);
writeRaw(",\"body\":[");
bool comma = false;
for (AstStat* stat : node->body)
{
if (comma)
writeRaw(",");
else
comma = true;
write(stat);
write(stat);
}
writeRaw("]");
}
writeRaw("]");
});
);
}
void write(class AstStatIf* node)
{
writeNode(node, "AstStatIf", [&]() {
PROP(condition);
PROP(thenbody);
if (node->elsebody)
PROP(elsebody);
write("hasThen", node->thenLocation.has_value());
});
writeNode(
node,
"AstStatIf",
[&]()
{
PROP(condition);
PROP(thenbody);
if (node->elsebody)
PROP(elsebody);
write("hasThen", node->thenLocation.has_value());
}
);
}
void write(class AstStatWhile* node)
{
writeNode(node, "AstStatWhile", [&]() {
PROP(condition);
PROP(body);
PROP(hasDo);
});
writeNode(
node,
"AstStatWhile",
[&]()
{
PROP(condition);
PROP(body);
PROP(hasDo);
}
);
}
void write(class AstStatRepeat* node)
{
writeNode(node, "AstStatRepeat", [&]() {
PROP(condition);
PROP(body);
});
writeNode(
node,
"AstStatRepeat",
[&]()
{
PROP(condition);
PROP(body);
}
);
}
void write(class AstStatBreak* node)
@ -646,128 +751,188 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstStatReturn* node)
{
writeNode(node, "AstStatReturn", [&]() {
PROP(list);
});
writeNode(
node,
"AstStatReturn",
[&]()
{
PROP(list);
}
);
}
void write(class AstStatExpr* node)
{
writeNode(node, "AstStatExpr", [&]() {
PROP(expr);
});
writeNode(
node,
"AstStatExpr",
[&]()
{
PROP(expr);
}
);
}
void write(class AstStatLocal* node)
{
writeNode(node, "AstStatLocal", [&]() {
PROP(vars);
PROP(values);
});
writeNode(
node,
"AstStatLocal",
[&]()
{
PROP(vars);
PROP(values);
}
);
}
void write(class AstStatFor* node)
{
writeNode(node, "AstStatFor", [&]() {
PROP(var);
PROP(from);
PROP(to);
if (node->step)
PROP(step);
PROP(body);
PROP(hasDo);
});
writeNode(
node,
"AstStatFor",
[&]()
{
PROP(var);
PROP(from);
PROP(to);
if (node->step)
PROP(step);
PROP(body);
PROP(hasDo);
}
);
}
void write(class AstStatForIn* node)
{
writeNode(node, "AstStatForIn", [&]() {
PROP(vars);
PROP(values);
PROP(body);
PROP(hasIn);
PROP(hasDo);
});
writeNode(
node,
"AstStatForIn",
[&]()
{
PROP(vars);
PROP(values);
PROP(body);
PROP(hasIn);
PROP(hasDo);
}
);
}
void write(class AstStatAssign* node)
{
writeNode(node, "AstStatAssign", [&]() {
PROP(vars);
PROP(values);
});
writeNode(
node,
"AstStatAssign",
[&]()
{
PROP(vars);
PROP(values);
}
);
}
void write(class AstStatCompoundAssign* node)
{
writeNode(node, "AstStatCompoundAssign", [&]() {
PROP(op);
PROP(var);
PROP(value);
});
writeNode(
node,
"AstStatCompoundAssign",
[&]()
{
PROP(op);
PROP(var);
PROP(value);
}
);
}
void write(class AstStatFunction* node)
{
writeNode(node, "AstStatFunction", [&]() {
PROP(name);
PROP(func);
});
writeNode(
node,
"AstStatFunction",
[&]()
{
PROP(name);
PROP(func);
}
);
}
void write(class AstStatLocalFunction* node)
{
writeNode(node, "AstStatLocalFunction", [&]() {
PROP(name);
PROP(func);
});
writeNode(
node,
"AstStatLocalFunction",
[&]()
{
PROP(name);
PROP(func);
}
);
}
void write(class AstStatTypeAlias* node)
{
writeNode(node, "AstStatTypeAlias", [&]() {
PROP(name);
PROP(generics);
PROP(genericPacks);
PROP(type);
PROP(exported);
});
writeNode(
node,
"AstStatTypeAlias",
[&]()
{
PROP(name);
PROP(generics);
PROP(genericPacks);
PROP(type);
PROP(exported);
}
);
}
void write(class AstStatDeclareFunction* node)
{
writeNode(node, "AstStatDeclareFunction", [&]() {
// TODO: attributes
PROP(name);
if (FFlag::LuauDeclarationExtraPropData)
PROP(nameLocation);
PROP(params);
if (FFlag::LuauDeclarationExtraPropData)
writeNode(
node,
"AstStatDeclareFunction",
[&]()
{
PROP(paramNames);
PROP(vararg);
PROP(varargLocation);
}
// TODO: attributes
PROP(name);
PROP(retTypes);
PROP(generics);
PROP(genericPacks);
});
if (FFlag::LuauDeclarationExtraPropData)
PROP(nameLocation);
PROP(params);
if (FFlag::LuauDeclarationExtraPropData)
{
PROP(paramNames);
PROP(vararg);
PROP(varargLocation);
}
PROP(retTypes);
PROP(generics);
PROP(genericPacks);
}
);
}
void write(class AstStatDeclareGlobal* node)
{
writeNode(node, "AstStatDeclareGlobal", [&]() {
PROP(name);
writeNode(
node,
"AstStatDeclareGlobal",
[&]()
{
PROP(name);
if (FFlag::LuauDeclarationExtraPropData)
PROP(nameLocation);
if (FFlag::LuauDeclarationExtraPropData)
PROP(nameLocation);
PROP(type);
});
PROP(type);
}
);
}
void write(const AstDeclaredClassProp& prop)
@ -791,21 +956,31 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstStatDeclareClass* node)
{
writeNode(node, "AstStatDeclareClass", [&]() {
PROP(name);
if (node->superName)
write("superName", *node->superName);
PROP(props);
PROP(indexer);
});
writeNode(
node,
"AstStatDeclareClass",
[&]()
{
PROP(name);
if (node->superName)
write("superName", *node->superName);
PROP(props);
PROP(indexer);
}
);
}
void write(class AstStatError* node)
{
writeNode(node, "AstStatError", [&]() {
PROP(expressions);
PROP(statements);
});
writeNode(
node,
"AstStatError",
[&]()
{
PROP(expressions);
PROP(statements);
}
);
}
void write(struct AstTypeOrPack node)
@ -818,15 +993,20 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstTypeReference* node)
{
writeNode(node, "AstTypeReference", [&]() {
if (node->prefix)
PROP(prefix);
if (node->prefixLocation)
write("prefixLocation", *node->prefixLocation);
PROP(name);
PROP(nameLocation);
PROP(parameters);
});
writeNode(
node,
"AstTypeReference",
[&]()
{
if (node->prefix)
PROP(prefix);
if (node->prefixLocation)
write("prefixLocation", *node->prefixLocation);
PROP(name);
PROP(nameLocation);
PROP(parameters);
}
);
}
void write(const AstTableProp& prop)
@ -845,10 +1025,15 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstTypeTable* node)
{
writeNode(node, "AstTypeTable", [&]() {
PROP(props);
PROP(indexer);
});
writeNode(
node,
"AstTypeTable",
[&]()
{
PROP(props);
PROP(indexer);
}
);
}
void write(struct AstTableIndexer* indexer)
@ -871,78 +1056,128 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstTypeFunction* node)
{
writeNode(node, "AstTypeFunction", [&]() {
PROP(generics);
PROP(genericPacks);
PROP(argTypes);
PROP(argNames);
PROP(returnTypes);
});
writeNode(
node,
"AstTypeFunction",
[&]()
{
PROP(generics);
PROP(genericPacks);
PROP(argTypes);
PROP(argNames);
PROP(returnTypes);
}
);
}
void write(class AstTypeTypeof* node)
{
writeNode(node, "AstTypeTypeof", [&]() {
PROP(expr);
});
writeNode(
node,
"AstTypeTypeof",
[&]()
{
PROP(expr);
}
);
}
void write(class AstTypeUnion* node)
{
writeNode(node, "AstTypeUnion", [&]() {
PROP(types);
});
writeNode(
node,
"AstTypeUnion",
[&]()
{
PROP(types);
}
);
}
void write(class AstTypeIntersection* node)
{
writeNode(node, "AstTypeIntersection", [&]() {
PROP(types);
});
writeNode(
node,
"AstTypeIntersection",
[&]()
{
PROP(types);
}
);
}
void write(class AstTypeError* node)
{
writeNode(node, "AstTypeError", [&]() {
PROP(types);
PROP(messageIndex);
});
writeNode(
node,
"AstTypeError",
[&]()
{
PROP(types);
PROP(messageIndex);
}
);
}
void write(class AstTypePackExplicit* node)
{
writeNode(node, "AstTypePackExplicit", [&]() {
PROP(typeList);
});
writeNode(
node,
"AstTypePackExplicit",
[&]()
{
PROP(typeList);
}
);
}
void write(class AstTypePackVariadic* node)
{
writeNode(node, "AstTypePackVariadic", [&]() {
PROP(variadicType);
});
writeNode(
node,
"AstTypePackVariadic",
[&]()
{
PROP(variadicType);
}
);
}
void write(class AstTypePackGeneric* node)
{
writeNode(node, "AstTypePackGeneric", [&]() {
PROP(genericName);
});
writeNode(
node,
"AstTypePackGeneric",
[&]()
{
PROP(genericName);
}
);
}
bool visit(class AstTypeSingletonBool* node) override
{
writeNode(node, "AstTypeSingletonBool", [&]() {
write("value", node->value);
});
writeNode(
node,
"AstTypeSingletonBool",
[&]()
{
write("value", node->value);
}
);
return false;
}
bool visit(class AstTypeSingletonString* node) override
{
writeNode(node, "AstTypeSingletonString", [&]() {
write("value", node->value);
});
writeNode(
node,
"AstTypeSingletonString",
[&]()
{
write("value", node->value);
}
);
return false;
}

View file

@ -331,9 +331,14 @@ static std::optional<AstStatLocal*> findBindingLocalStatement(const SourceModule
return std::nullopt;
std::vector<AstNode*> nodes = findAstAncestryOfPosition(source, binding.location.begin);
auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) {
return node->is<AstStatLocal>();
});
auto iter = std::find_if(
nodes.rbegin(),
nodes.rend(),
[](AstNode* node)
{
return node->is<AstStatLocal>();
}
);
return iter != nodes.rend() ? std::make_optional((*iter)->as<AstStatLocal>()) : std::nullopt;
}
@ -472,7 +477,11 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos)
}
static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional<DocumentationSymbol> documentationSymbol)
const Module& module,
const TypeId ty,
const AstExpr* parentExpr,
const std::optional<DocumentationSymbol> documentationSymbol
)
{
if (!documentationSymbol)
return std::nullopt;

View file

@ -15,8 +15,8 @@
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
static const std::unordered_set<std::string> kStatementStartingKeywords =
{"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
namespace Luau
{
@ -161,7 +161,13 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
}
static TypeCorrectKind checkTypeCorrectKind(
const Module& module, TypeArena* typeArena, NotNull<BuiltinTypes> builtinTypes, AstNode* node, Position position, TypeId ty)
const Module& module,
TypeArena* typeArena,
NotNull<BuiltinTypes> builtinTypes,
AstNode* node,
Position position,
TypeId ty
)
{
ty = follow(ty);
@ -176,7 +182,8 @@ static TypeCorrectKind checkTypeCorrectKind(
TypeId expectedType = follow(*typeAtPosition);
auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) {
auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv)
{
if (std::optional<TypeId> firstRetTy = first(ftv->retTypes))
return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes);
@ -209,9 +216,18 @@ enum class PropIndexType
Key,
};
static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull<BuiltinTypes> builtinTypes, TypeId rootTy, TypeId ty,
PropIndexType indexType, const std::vector<AstNode*>& nodes, AutocompleteEntryMap& result, std::unordered_set<TypeId>& seen,
std::optional<const ClassType*> containingClass = std::nullopt)
static void autocompleteProps(
const Module& module,
TypeArena* typeArena,
NotNull<BuiltinTypes> builtinTypes,
TypeId rootTy,
TypeId ty,
PropIndexType indexType,
const std::vector<AstNode*>& nodes,
AutocompleteEntryMap& result,
std::unordered_set<TypeId>& seen,
std::optional<const ClassType*> containingClass = std::nullopt
)
{
rootTy = follow(rootTy);
ty = follow(ty);
@ -220,13 +236,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
return;
seen.insert(ty);
auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) {
auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type)
{
if (indexType == PropIndexType::Key)
return false;
bool calledWithSelf = indexType == PropIndexType::Colon;
auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) {
auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv)
{
// Strong match with definition is a success
if (calledWithSelf == ftv->hasSelf)
return true;
@ -265,7 +283,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
return calledWithSelf;
};
auto fillProps = [&](const ClassType::Props& props) {
auto fillProps = [&](const ClassType::Props& props)
{
for (const auto& [name, prop] : props)
{
// We are walking up the class hierarchy, so if we encounter a property that we have
@ -291,13 +310,26 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
ParenthesesRecommendation parens =
indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect);
result[name] = AutocompleteEntry{AutocompleteEntryKind::Property, type, prop.deprecated, isWrongIndexer(type), typeCorrect,
containingClass, &prop, prop.documentationSymbol, {}, parens, {}, indexType == PropIndexType::Colon};
result[name] = AutocompleteEntry{
AutocompleteEntryKind::Property,
type,
prop.deprecated,
isWrongIndexer(type),
typeCorrect,
containingClass,
&prop,
prop.documentationSymbol,
{},
parens,
{},
indexType == PropIndexType::Colon
};
}
}
};
auto fillMetatableProps = [&](const TableType* mtable) {
auto fillMetatableProps = [&](const TableType* mtable)
{
auto indexIt = mtable->props.find("__index");
if (indexIt != mtable->props.end())
{
@ -409,7 +441,11 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
}
static void autocompleteKeywords(
const SourceModule& sourceModule, const std::vector<AstNode*>& ancestry, Position position, AutocompleteEntryMap& result)
const SourceModule& sourceModule,
const std::vector<AstNode*>& ancestry,
Position position,
AutocompleteEntryMap& result
)
{
LUAU_ASSERT(!ancestry.empty());
@ -429,15 +465,28 @@ static void autocompleteKeywords(
}
}
static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull<BuiltinTypes> builtinTypes, TypeId ty, PropIndexType indexType,
const std::vector<AstNode*>& nodes, AutocompleteEntryMap& result)
static void autocompleteProps(
const Module& module,
TypeArena* typeArena,
NotNull<BuiltinTypes> builtinTypes,
TypeId ty,
PropIndexType indexType,
const std::vector<AstNode*>& nodes,
AutocompleteEntryMap& result
)
{
std::unordered_set<TypeId> seen;
autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen);
}
AutocompleteEntryMap autocompleteProps(const Module& module, TypeArena* typeArena, NotNull<BuiltinTypes> builtinTypes, TypeId ty,
PropIndexType indexType, const std::vector<AstNode*>& nodes)
AutocompleteEntryMap autocompleteProps(
const Module& module,
TypeArena* typeArena,
NotNull<BuiltinTypes> builtinTypes,
TypeId ty,
PropIndexType indexType,
const std::vector<AstNode*>& nodes
)
{
AutocompleteEntryMap result;
autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result);
@ -472,7 +521,8 @@ static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node
return;
}
auto formatKey = [addQuotes](const std::string& key) {
auto formatKey = [addQuotes](const std::string& key)
{
if (addQuotes)
return "\"" + escape(key) + "\"";
@ -705,9 +755,14 @@ static std::optional<bool> functionIsExpectedAt(const Module& module, AstNode* n
if (const IntersectionType* itv = get<IntersectionType>(expectedType))
{
return std::all_of(begin(itv->parts), end(itv->parts), [](auto&& ty) {
return get<FunctionType>(Luau::follow(ty)) != nullptr;
});
return std::all_of(
begin(itv->parts),
end(itv->parts),
[](auto&& ty)
{
return get<FunctionType>(Luau::follow(ty)) != nullptr;
}
);
}
if (const UnionType* utv = get<UnionType>(expectedType))
@ -727,15 +782,31 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
for (const auto& [name, ty] : scope->exportedTypeBindings)
{
if (!result.count(name))
result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt,
std::nullopt, ty.type->documentationSymbol};
result[name] = AutocompleteEntry{
AutocompleteEntryKind::Type,
ty.type,
false,
false,
TypeCorrectKind::None,
std::nullopt,
std::nullopt,
ty.type->documentationSymbol
};
}
for (const auto& [name, ty] : scope->privateTypeBindings)
{
if (!result.count(name))
result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt,
std::nullopt, ty.type->documentationSymbol};
result[name] = AutocompleteEntry{
AutocompleteEntryKind::Type,
ty.type,
false,
false,
TypeCorrectKind::None,
std::nullopt,
std::nullopt,
ty.type->documentationSymbol
};
}
for (const auto& [name, _] : scope->importedTypeBindings)
@ -825,7 +896,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
else if (AstExprFunction* node = parent->as<AstExprFunction>())
{
// For lookup inside expected function type if that's available
auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* {
auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType*
{
auto it = module.astExpectedTypes.find(expr);
if (!it)
@ -1029,7 +1101,11 @@ static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding&
}
static AutocompleteEntryMap autocompleteStatement(
const SourceModule& sourceModule, const Module& module, const std::vector<AstNode*>& ancestry, Position position)
const SourceModule& sourceModule,
const Module& module,
const std::vector<AstNode*>& ancestry,
Position position
)
{
// This is inefficient. :(
ScopePtr scope = findScopeAtPosition(module, position);
@ -1051,8 +1127,18 @@ static AutocompleteEntryMap autocompleteStatement(
std::string n = toString(name);
if (!result.count(n))
result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, TypeCorrectKind::None, std::nullopt,
std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None)};
result[n] = {
AutocompleteEntryKind::Binding,
binding.typeId,
binding.deprecated,
false,
TypeCorrectKind::None,
std::nullopt,
std::nullopt,
binding.documentationSymbol,
{},
getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None)
};
}
scope = scope->parent;
@ -1122,7 +1208,11 @@ static AutocompleteEntryMap autocompleteStatement(
// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`)
static bool autocompleteIfElseExpression(
const AstNode* node, const std::vector<AstNode*>& ancestry, const Position& position, AutocompleteEntryMap& outResult)
const AstNode* node,
const std::vector<AstNode*>& ancestry,
const Position& position,
AutocompleteEntryMap& outResult
)
{
AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr;
if (!parent)
@ -1161,8 +1251,15 @@ static bool autocompleteIfElseExpression(
}
}
static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena, const std::vector<AstNode*>& ancestry, Position position, AutocompleteEntryMap& result)
static AutocompleteContext autocompleteExpression(
const SourceModule& sourceModule,
const Module& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
const std::vector<AstNode*>& ancestry,
Position position,
AutocompleteEntryMap& result
)
{
LUAU_ASSERT(!ancestry.empty());
@ -1197,8 +1294,18 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu
{
TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId);
result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt,
binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)};
result[n] = {
AutocompleteEntryKind::Binding,
binding.typeId,
binding.deprecated,
false,
typeCorrect,
std::nullopt,
std::nullopt,
binding.documentationSymbol,
{},
getParenRecommendation(binding.typeId, ancestry, typeCorrect)
};
}
}
@ -1225,8 +1332,14 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu
return AutocompleteContext::Expression;
}
static AutocompleteResult autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena, const std::vector<AstNode*>& ancestry, Position position)
static AutocompleteResult autocompleteExpression(
const SourceModule& sourceModule,
const Module& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
const std::vector<AstNode*>& ancestry,
Position position
)
{
AutocompleteEntryMap result;
AutocompleteContext context = autocompleteExpression(sourceModule, module, builtinTypes, typeArena, ancestry, position, result);
@ -1312,8 +1425,13 @@ static std::optional<std::string> getStringContents(const AstNode* node)
}
}
static std::optional<AutocompleteEntryMap> autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module,
const std::vector<AstNode*>& nodes, Position position, StringCompletionCallback callback)
static std::optional<AutocompleteEntryMap> autocompleteStringParams(
const SourceModule& sourceModule,
const ModulePtr& module,
const std::vector<AstNode*>& nodes,
Position position,
StringCompletionCallback callback
)
{
if (nodes.size() < 2)
{
@ -1354,7 +1472,8 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
std::optional<std::string> candidateString = getStringContents(nodes.back());
auto performCallback = [&](const FunctionType* funcType) -> std::optional<AutocompleteEntryMap> {
auto performCallback = [&](const FunctionType* funcType) -> std::optional<AutocompleteEntryMap>
{
for (const std::string& tag : funcType->tags)
{
if (std::optional<AutocompleteEntryMap> ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString))
@ -1463,7 +1582,11 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func
}
static std::optional<AutocompleteEntry> makeAnonymousAutofilled(
const ModulePtr& module, Position position, const AstNode* node, const std::vector<AstNode*>& ancestry)
const ModulePtr& module,
Position position,
const AstNode* node,
const std::vector<AstNode*>& ancestry
)
{
const AstExprCall* call = node->as<AstExprCall>();
if (!call && ancestry.size() > 1)
@ -1530,8 +1653,15 @@ static std::optional<AutocompleteEntry> makeAnonymousAutofilled(
return std::make_optional(std::move(entry));
}
static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena, Scope* globalScope, Position position, StringCompletionCallback callback)
static AutocompleteResult autocomplete(
const SourceModule& sourceModule,
const ModulePtr& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
Scope* globalScope,
Position position,
StringCompletionCallback callback
)
{
if (isWithinComment(sourceModule, position))
return {};
@ -1656,14 +1786,17 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
else if (AstStatWhile* statWhile = extractStat<AstStatWhile>(ancestry);
(statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition &&
!statWhile->condition->location.containsClosed(position)))
!statWhile->condition->location.containsClosed(position)))
{
return autocompleteWhileLoopKeywords(ancestry);
}
else if (AstStatIf* statIf = node->as<AstStatIf>(); statIf && !statIf->elseLocation.has_value())
{
return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}},
ancestry, AutocompleteContext::Keyword};
return {
{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}},
ancestry,
AutocompleteContext::Keyword
};
}
else if (AstStatIf* statIf = parent->as<AstStatIf>(); statIf && node->is<AstStatBlock>())
{

View file

@ -29,15 +29,35 @@ namespace Luau
{
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context);
@ -61,26 +81,51 @@ TypeId makeOption(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId t
}
TypeId makeFunction(
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, bool checked)
TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<TypeId> retTypes,
bool checked
)
{
return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes, checked);
}
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, bool checked)
TypeId makeFunction(
TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<TypeId> retTypes,
bool checked
)
{
return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes, checked);
}
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes, bool checked)
TypeId makeFunction(
TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes,
bool checked
)
{
return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes, checked);
}
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes, bool checked)
TypeId makeFunction(
TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes,
bool checked
)
{
std::vector<TypeId> params;
if (selfType)
@ -219,7 +264,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()});
LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(
globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete);
globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete
);
LUAU_ASSERT(loadResult.success);
TypeId genericK = arena.addType(GenericType{"K"});
@ -313,10 +359,12 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
// declare function assert<T>(value: T, errorMessage: string?): intersect<T, ~(false?)>
TypeId genericT = arena.addType(GenericType{"T"});
TypeId refinedTy = arena.addType(TypeFunctionInstanceType{
NotNull{&builtinTypeFunctions().intersectFunc}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}});
NotNull{&builtinTypeFunctions().intersectFunc}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}
});
TypeId assertTy = arena.addType(FunctionType{
{genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})});
{genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})
});
addGlobalBinding(globals, "assert", assertTy, "@luau");
}
@ -380,7 +428,11 @@ static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes,
}
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
auto [paramPack, _predicates] = withPredicate;
@ -529,7 +581,11 @@ static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes
}
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
@ -594,7 +650,11 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
}
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
@ -666,7 +726,11 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
@ -804,9 +868,11 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
const TypeId replArgType =
arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}});
const TypeId replArgType = arena->addType(UnionType{
{stringType,
arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}
});
const TypeId gsubFunc =
makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
const TypeId gmatchFunc =
@ -815,14 +881,17 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})};
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})
};
matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})};
FunctionType findFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})
};
findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind);
@ -857,13 +926,22 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})},
/* checked */ true)}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack,
})}},
{"split",
{makeFunction(
*arena,
stringType,
{},
{},
{optionalString},
{},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})},
/* checked */ true
)}},
{"pack",
{arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}},
{"unpack", {arena->addType(stringDotUnpack)}},
};
@ -879,7 +957,11 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
auto [paramPack, _predicates] = withPredicate;
@ -965,7 +1047,11 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
auto [paramPack, _predicates] = withPredicate;
@ -1043,7 +1129,11 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
}
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
auto [paramPack, predicates] = withPredicate;
@ -1073,7 +1163,11 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
}
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
auto [paramPack, _predicates] = withPredicate;
@ -1174,7 +1268,11 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
}
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
TypeArena& arena = typechecker.currentModule->internalTypes;

View file

@ -227,19 +227,23 @@ private:
void cloneChildren(TypeId ty)
{
return visit(
[&](auto&& t) {
[&](auto&& t)
{
return cloneChildren(&t);
},
asMutable(ty)->ty);
asMutable(ty)->ty
);
}
void cloneChildren(TypePackId tp)
{
return visit(
[&](auto&& t) {
[&](auto&& t)
{
return cloneChildren(&t);
},
asMutable(tp)->ty);
asMutable(tp)->ty
);
}
void cloneChildren(Kind kind)

View file

@ -189,10 +189,18 @@ bool hasFreeType(TypeId ty)
} // namespace
ConstraintGenerator::ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, DcrLogger* logger, NotNull<DataFlowGraph> dfg,
std::vector<RequireCycle> requireCycles)
ConstraintGenerator::ConstraintGenerator(
ModulePtr module,
NotNull<Normalizer> normalizer,
NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
DcrLogger* logger,
NotNull<DataFlowGraph> dfg,
std::vector<RequireCycle> requireCycles
)
: module(module)
, builtinTypes(builtinTypes)
, arena(normalizer->arena)
@ -240,9 +248,15 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
NotNull<Constraint> genConstraint =
addConstraint(scope, block->location, GeneralizationConstraint{result, moduleFnTy, std::move(interiorTypes.back())});
getMutable<BlockedType>(result)->setOwner(genConstraint);
forEachConstraint(start, end, this, [genConstraint](const ConstraintPtr& c) {
genConstraint->dependencies.push_back(NotNull{c.get()});
});
forEachConstraint(
start,
end,
this,
[genConstraint](const ConstraintPtr& c)
{
genConstraint->dependencies.push_back(NotNull{c.get()});
}
);
interiorTypes.pop_back();
@ -354,10 +368,17 @@ NotNull<Constraint> ConstraintGenerator::addConstraint(const ScopePtr& scope, st
return NotNull{constraints.emplace_back(std::move(c)).get()};
}
void ConstraintGenerator::unionRefinements(const ScopePtr& scope, Location location, const RefinementContext& lhs, const RefinementContext& rhs,
RefinementContext& dest, std::vector<ConstraintV>* constraints)
void ConstraintGenerator::unionRefinements(
const ScopePtr& scope,
Location location,
const RefinementContext& lhs,
const RefinementContext& rhs,
RefinementContext& dest,
std::vector<ConstraintV>* constraints
)
{
const auto intersect = [&](const std::vector<TypeId>& types) {
const auto intersect = [&](const std::vector<TypeId>& types)
{
if (1 == types.size())
return types[0];
else if (2 == types.size())
@ -386,8 +407,15 @@ void ConstraintGenerator::unionRefinements(const ScopePtr& scope, Location locat
}
}
void ConstraintGenerator::computeRefinement(const ScopePtr& scope, Location location, RefinementId refinement, RefinementContext* refis, bool sense,
bool eq, std::vector<ConstraintV>* constraints)
void ConstraintGenerator::computeRefinement(
const ScopePtr& scope,
Location location,
RefinementId refinement,
RefinementContext* refis,
bool sense,
bool eq,
std::vector<ConstraintV>* constraints
)
{
if (!refinement)
return;
@ -555,8 +583,11 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat
switch (shouldSuppressErrors(normalizer, ty))
{
case ErrorSuppression::DoNotSuppress:
ty = makeIntersect(scope, location, ty, dt);
{
if (!get<NeverType>(follow(ty)))
ty = makeIntersect(scope, location, ty, dt);
break;
}
case ErrorSuppression::Suppress:
ty = makeIntersect(scope, location, ty, dt);
ty = makeUnion(scope, location, ty, builtinTypes->errorType);
@ -688,6 +719,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStat* stat)
return visit(scope, f);
else if (auto a = stat->as<AstStatTypeAlias>())
return visit(scope, a);
else if (auto f = stat->as<AstStatTypeFunction>())
return visit(scope, f);
else if (auto s = stat->as<AstStatDeclareGlobal>())
return visit(scope, s);
else if (auto s = stat->as<AstStatDeclareFunction>())
@ -792,11 +825,15 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack});
forEachConstraint(start, end, this,
forEachConstraint(
start,
end,
this,
[&uc](const ConstraintPtr& runBefore)
{
uc->dependencies.push_back(NotNull{runBefore.get()});
});
}
);
for (TypeId t : valueTypes)
getMutable<BlockedType>(t)->setOwner(uc);
@ -875,7 +912,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFor* for_)
if (for_->var->annotation)
annotationTy = resolveType(scope, for_->var->annotation, /* inTypeArguments */ false);
auto inferNumber = [&](AstExpr* expr) {
auto inferNumber = [&](AstExpr* expr)
{
if (!expr)
return;
@ -929,7 +967,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI
}
auto iterable = addConstraint(
loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes});
loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes}
);
for (TypeId var : variableTypes)
{
@ -943,9 +982,15 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI
Checkpoint end = checkpoint(this);
// This iter constraint must dispatch first.
forEachConstraint(start, end, this, [&iterable](const ConstraintPtr& runLater) {
runLater->dependencies.push_back(iterable);
});
forEachConstraint(
start,
end,
this,
[&iterable](const ConstraintPtr& runLater)
{
runLater->dependencies.push_back(iterable);
}
);
return ControlFlow::None;
}
@ -1011,17 +1056,23 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocalFuncti
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature});
Constraint* previous = nullptr;
forEachConstraint(start, end, this, [&c, &previous](const ConstraintPtr& constraint) {
c->dependencies.push_back(NotNull{constraint.get()});
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
forEachConstraint(
start,
end,
this,
[&c, &previous](const ConstraintPtr& constraint)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
c->dependencies.push_back(NotNull{constraint.get()});
previous = constraint.get();
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
previous = constraint.get();
}
}
});
);
getMutable<BlockedType>(functionType)->setOwner(addConstraint(scope, std::move(c)));
module->astTypes[function->func] = functionType;
@ -1055,17 +1106,23 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f
getMutable<BlockedType>(generalizedType)->setOwner(c);
Constraint* previous = nullptr;
forEachConstraint(start, end, this, [&c, &previous](const ConstraintPtr& constraint) {
c->dependencies.push_back(NotNull{constraint.get()});
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
forEachConstraint(
start,
end,
this,
[&c, &previous](const ConstraintPtr& constraint)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
c->dependencies.push_back(NotNull{constraint.get()});
previous = constraint.get();
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
previous = constraint.get();
}
}
});
);
}
DefId def = dfg->getDef(function->name);
@ -1211,7 +1268,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAss
ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatIf* ifStatement)
{
RefinementId refinement = [&]() {
RefinementId refinement = [&]()
{
InConditionalContext flipper{&typeContext};
return check(scope, ifStatement->condition, std::nullopt).refinement;
}();
@ -1293,18 +1351,26 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias*
for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false))
typePackParams.push_back(tpParam.second.tp);
addConstraint(scope, alias->type->location,
addConstraint(
scope,
alias->type->location,
NameConstraint{
ty,
alias->name.value,
/*synthetic=*/false,
std::move(typeParams),
std::move(typePackParams),
});
}
);
return ControlFlow::None;
}
ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function)
{
return ControlFlow::None;
}
ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareGlobal* global)
{
LUAU_ASSERT(global->type);
@ -1350,8 +1416,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas
if (!get<ClassType>(follow(*superTy)))
{
reportError(declaredClass->location,
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)});
reportError(
declaredClass->location,
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)}
);
return ControlFlow::None;
}
@ -1579,7 +1647,11 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstArray<Ast
}
InferencePack ConstraintGenerator::checkPack(
const ScopePtr& scope, AstExpr* expr, const std::vector<std::optional<TypeId>>& expectedTypes, bool generalize)
const ScopePtr& scope,
AstExpr* expr,
const std::vector<std::optional<TypeId>>& expectedTypes,
bool generalize
)
{
RecursionCounter counter{&recursionCount};
@ -1661,7 +1733,6 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
std::vector<std::optional<TypeId>> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType);
module->astOriginalCallTypes[call->func] = fnType;
module->astOriginalCallTypes[call] = fnType;
Checkpoint argBeginCheckpoint = checkpoint(this);
@ -1796,14 +1867,25 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
* 4. Solve the call
*/
NotNull<Constraint> checkConstraint = addConstraint(scope, call->func->location,
FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}});
NotNull<Constraint> checkConstraint = addConstraint(
scope,
call->func->location,
FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}}
);
forEachConstraint(funcBeginCheckpoint, funcEndCheckpoint, this, [checkConstraint](const ConstraintPtr& constraint) {
checkConstraint->dependencies.emplace_back(constraint.get());
});
forEachConstraint(
funcBeginCheckpoint,
funcEndCheckpoint,
this,
[checkConstraint](const ConstraintPtr& constraint)
{
checkConstraint->dependencies.emplace_back(constraint.get());
}
);
NotNull<Constraint> callConstraint = addConstraint(scope, call->func->location,
NotNull<Constraint> callConstraint = addConstraint(
scope,
call->func->location,
FunctionCallConstraint{
fnType,
argPack,
@ -1811,17 +1893,24 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
call,
std::move(discriminantTypes),
&module->astOverloadResolvedTypes,
});
}
);
getMutable<BlockedTypePack>(rets)->owner = callConstraint.get();
callConstraint->dependencies.push_back(checkConstraint);
forEachConstraint(argBeginCheckpoint, argEndCheckpoint, this, [checkConstraint, callConstraint](const ConstraintPtr& constraint) {
constraint->dependencies.emplace_back(checkConstraint);
forEachConstraint(
argBeginCheckpoint,
argEndCheckpoint,
this,
[checkConstraint, callConstraint](const ConstraintPtr& constraint)
{
constraint->dependencies.emplace_back(checkConstraint);
callConstraint->dependencies.emplace_back(constraint.get());
});
callConstraint->dependencies.emplace_back(constraint.get());
}
);
return InferencePack{rets, {refinementArena.variadic(returnRefinements)}};
}
@ -1974,7 +2063,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* globa
}
Inference ConstraintGenerator::checkIndexName(
const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation)
const ScopePtr& scope,
const RefinementKey* key,
AstExpr* indexee,
const std::string& index,
Location indexLocation
)
{
TypeId obj = check(scope, indexee).ty;
TypeId result = nullptr;
@ -2005,7 +2099,8 @@ Inference ConstraintGenerator::checkIndexName(
result = arena->addType(BlockedType{});
auto c = addConstraint(
scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)});
scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}
);
getMutable<BlockedType>(result)->setOwner(c);
}
@ -2076,17 +2171,23 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun
interiorTypes.pop_back();
Constraint* previous = nullptr;
forEachConstraint(startCheckpoint, endCheckpoint, this, [gc, &previous](const ConstraintPtr& constraint) {
gc->dependencies.emplace_back(constraint.get());
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
forEachConstraint(
startCheckpoint,
endCheckpoint,
this,
[gc, &previous](const ConstraintPtr& constraint)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
gc->dependencies.emplace_back(constraint.get());
previous = constraint.get();
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
previous = constraint.get();
}
}
});
);
if (generalize && hasFreeType(sig.signature))
{
@ -2187,9 +2288,13 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar
}
case AstExprBinary::Op::CompareGe:
{
TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().ltFunc,
TypeId resultType = createTypeFunctionInstance(
builtinTypeFunctions().ltFunc,
{rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)`
{}, scope, binary->location);
{},
scope,
binary->location
);
return Inference{resultType, std::move(refinement)};
}
case AstExprBinary::Op::CompareLe:
@ -2200,9 +2305,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar
case AstExprBinary::Op::CompareGt:
{
TypeId resultType = createTypeFunctionInstance(
builtinTypeFunctions().leFunc,
builtinTypeFunctions().leFunc,
{rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)`
{}, scope, binary->location);
{},
scope,
binary->location
);
return Inference{resultType, std::move(refinement)};
}
case AstExprBinary::Op::CompareEq:
@ -2234,7 +2342,8 @@ builtinTypeFunctions().leFunc,
Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType)
{
RefinementId refinement = [&]() {
RefinementId refinement = [&]()
{
InConditionalContext flipper{&typeContext};
ScopePtr condScope = childScope(ifElse->condition, scope);
return check(condScope, ifElse->condition).refinement;
@ -2266,7 +2375,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprInterpString*
}
std::tuple<TypeId, TypeId, RefinementId> ConstraintGenerator::checkBinary(
const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType)
const ScopePtr& scope,
AstExprBinary* binary,
std::optional<TypeId> expectedType
)
{
if (binary->op == AstExprBinary::And)
{
@ -2460,7 +2572,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* e
bool incremented = recordPropertyAssignment(lhsTy);
auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, expr->indexLocation, propTy, incremented});
auto apc =
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, expr->indexLocation, propTy, incremented});
getMutable<BlockedType>(propTy)->setOwner(apc);
}
@ -2476,7 +2589,9 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e
bool incremented = recordPropertyAssignment(lhsTy);
auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, expr->index->location, propTy, incremented});
auto apc = addConstraint(
scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, expr->index->location, propTy, incremented}
);
getMutable<BlockedType>(propTy)->setOwner(apc);
return;
@ -2505,7 +2620,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
TypeIds indexKeyLowerBound;
TypeIds indexValueLowerBound;
auto createIndexer = [&indexKeyLowerBound, &indexValueLowerBound](const Location& location, TypeId currentIndexType, TypeId currentResultType) {
auto createIndexer = [&indexKeyLowerBound, &indexValueLowerBound](const Location& location, TypeId currentIndexType, TypeId currentResultType)
{
indexKeyLowerBound.insert(follow(currentIndexType));
indexValueLowerBound.insert(follow(currentResultType));
};
@ -2565,14 +2681,19 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice};
std::vector<TypeId> toBlock;
matchLiteralType(
NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr, toBlock);
NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr, toBlock
);
}
return Inference{ty};
}
ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignature(
const ScopePtr& parent, AstExprFunction* fn, std::optional<TypeId> expectedType, std::optional<Location> originalName)
const ScopePtr& parent,
AstExprFunction* fn,
std::optional<TypeId> expectedType,
std::optional<Location> originalName
)
{
ScopePtr signatureScope = nullptr;
ScopePtr bodyScope = nullptr;
@ -3076,7 +3197,11 @@ TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, const Ast
}
std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGenerator::createGenerics(
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache, bool addTypes)
const ScopePtr& scope,
AstArray<AstGenericType> generics,
bool useCache,
bool addTypes
)
{
std::vector<std::pair<Name, GenericTypeDefinition>> result;
for (const auto& generic : generics)
@ -3106,7 +3231,11 @@ std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGenerator::createG
}
std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGenerator::createGenericPacks(
const ScopePtr& scope, AstArray<AstGenericTypePack> generics, bool useCache, bool addTypes)
const ScopePtr& scope,
AstArray<AstGenericTypePack> generics,
bool useCache,
bool addTypes
)
{
std::vector<std::pair<Name, GenericTypePackDefinition>> result;
for (const auto& generic : generics)
@ -3323,7 +3452,8 @@ std::vector<std::optional<TypeId>> ConstraintGenerator::getExpectedCallTypesForF
// For a list of functions f_0 : e_0 -> r_0, ... f_n : e_n -> r_n,
// emit a list of arguments that the function could take at each position
// by unioning the arguments at each place
auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) {
auto assignOption = [this, &expectedTypes](size_t index, TypeId ty)
{
if (index == expectedTypes.size())
expectedTypes.push_back(ty);
else if (ty)
@ -3372,7 +3502,12 @@ std::vector<std::optional<TypeId>> ConstraintGenerator::getExpectedCallTypesForF
}
TypeId ConstraintGenerator::createTypeFunctionInstance(
const TypeFunction& function, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments, const ScopePtr& scope, Location location)
const TypeFunction& function,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments,
const ScopePtr& scope,
Location location
)
{
TypeId result = arena->addTypeFunction(function, typeArguments, packArguments);
addConstraint(scope, location, ReduceConstraint{result});

View file

@ -89,8 +89,13 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const
return true;
}
static std::pair<std::vector<TypeId>, std::vector<TypePackId>> saturateArguments(TypeArena* arena, NotNull<BuiltinTypes> builtinTypes,
const TypeFun& fn, const std::vector<TypeId>& rawTypeArguments, const std::vector<TypePackId>& rawPackArguments)
static std::pair<std::vector<TypeId>, std::vector<TypePackId>> saturateArguments(
TypeArena* arena,
NotNull<BuiltinTypes> builtinTypes,
const TypeFun& fn,
const std::vector<TypeId>& rawTypeArguments,
const std::vector<TypePackId>& rawPackArguments
)
{
std::vector<TypeId> saturatedTypeArguments;
std::vector<TypeId> extraTypes;
@ -310,8 +315,16 @@ struct InstantiationQueuer : TypeOnceVisitor
}
};
ConstraintSolver::ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger, TypeCheckLimits limits)
ConstraintSolver::ConstraintSolver(
NotNull<Normalizer> normalizer,
NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName,
NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles,
DcrLogger* logger,
TypeCheckLimits limits
)
: arena(normalizer->arena)
, builtinTypes(normalizer->builtinTypes)
, normalizer(normalizer)
@ -374,7 +387,8 @@ void ConstraintSolver::run()
if (FFlag::DebugLuauLogSolver)
{
printf(
"Starting solver for module %s (%s)\n", moduleResolver->getHumanReadableModuleName(currentModuleName).c_str(), currentModuleName.c_str());
"Starting solver for module %s (%s)\n", moduleResolver->getHumanReadableModuleName(currentModuleName).c_str(), currentModuleName.c_str()
);
dump(this, opts);
printf("Bindings:\n");
dumpBindings(rootScope, opts);
@ -385,7 +399,8 @@ void ConstraintSolver::run()
logger->captureInitialSolverState(rootScope, unsolvedConstraints);
}
auto runSolverPass = [&](bool force) {
auto runSolverPass = [&](bool force)
{
bool progress = false;
size_t i = 0;
@ -489,7 +504,7 @@ void ConstraintSolver::run()
} while (progress);
if (!unsolvedConstraints.empty())
reportError(InternalError{"Type inference failed to complete, you may see some confusing types and type errors."}, Location{});
reportError(ConstraintSolvingIncompleteError{}, Location{});
// After we have run all the constraints, type functions should be generalized
// At this point, we can try to perform one final simplification to suss out
@ -730,7 +745,8 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
* applies constraints to the types of the iterators.
*/
auto block_ = [&](auto&& t) {
auto block_ = [&](auto&& t)
{
if (force)
{
// If we haven't figured out the type of the iteratee by now,
@ -891,7 +907,8 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
return true;
}
auto bindResult = [this, &c, constraint](TypeId result) {
auto bindResult = [this, &c, constraint](TypeId result)
{
LUAU_ASSERT(get<PendingExpansionType>(c.target));
shiftReferences(c.target, result);
bind(constraint, c.target, result);
@ -929,14 +946,27 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments);
bool sameTypes = std::equal(typeArguments.begin(), typeArguments.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& p) {
return itp == p.ty;
});
bool sameTypes = std::equal(
typeArguments.begin(),
typeArguments.end(),
tf->typeParams.begin(),
tf->typeParams.end(),
[](auto&& itp, auto&& p)
{
return itp == p.ty;
}
);
bool samePacks =
std::equal(packArguments.begin(), packArguments.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itp, auto&& p) {
bool samePacks = std::equal(
packArguments.begin(),
packArguments.end(),
tf->typePackParams.begin(),
tf->typePackParams.end(),
[](auto&& itp, auto&& p)
{
return itp == p.tp;
});
}
);
// If we're instantiating the type with its generic saturatedTypeArguments we are
// performing the identity substitution. We can just short-circuit and bind
@ -1023,9 +1053,14 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
//clang-format off
bool needsClone = follow(tf->type) == target || (tfTable != nullptr && tfTable == getTableType(target)) ||
std::any_of(typeArguments.begin(), typeArguments.end(), [&](const auto& other) {
return other == target;
});
std::any_of(
typeArguments.begin(),
typeArguments.end(),
[&](const auto& other)
{
return other == target;
}
);
//clang-format on
// Only tables have the properties we're trying to set.
@ -1120,7 +1155,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
if (blocked)
return false;
auto collapse = [](const auto* t) -> std::optional<TypeId> {
auto collapse = [](const auto* t) -> std::optional<TypeId>
{
auto it = begin(t);
auto endIt = end(t);
@ -1145,6 +1181,9 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
// We don't support magic __call metamethods.
if (std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location))
{
if (isBlocked(*callMm))
return block(*callMm, constraint);
argsHead.insert(argsHead.begin(), fn);
if (argsTail && isBlocked(*argsTail))
@ -1195,7 +1234,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
}
OverloadResolver resolver{
builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location};
builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location
};
auto [status, overload] = resolver.selectOverload(fn, argsPack);
TypeId overloadToUse = fn;
if (status == OverloadResolver::Analysis::Ok)
@ -1334,8 +1374,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
}
}
}
else if (expr->is<AstExprConstantBool>() || expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() ||
expr->is<AstExprConstantNil>())
else if (expr->is<AstExprConstantBool>() || expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantNil>())
{
Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
u2.unify(actualArgTy, expectedArgTy);
@ -1421,7 +1460,13 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
}
bool ConstraintSolver::tryDispatchHasIndexer(
int& recursionDepth, NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set<TypeId>& seen)
int& recursionDepth,
NotNull<const Constraint> constraint,
TypeId subjectType,
TypeId indexType,
TypeId resultType,
Set<TypeId>& seen
)
{
RecursionLimiter _rl{&recursionDepth, FInt::LuauSolverRecursionLimit};
@ -1455,7 +1500,8 @@ bool ConstraintSolver::tryDispatchHasIndexer(
FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType};
emplace<FreeType>(constraint, resultType, freeResult);
TypeId upperBound = arena->addType(TableType{/* props */ {}, TableIndexer{indexType, resultType}, TypeLevel{}, TableState::Unsealed});
TypeId upperBound =
arena->addType(TableType{/* props */ {}, TableIndexer{indexType, resultType}, TypeLevel{}, ft->scope, TableState::Unsealed});
unify(constraint, subjectType, upperBound);
@ -1777,7 +1823,8 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull<const
// Important: In every codepath through this function, the type `c.propType`
// must be bound to something, even if it's just the errorType.
auto tableStuff = [&](TableType* lhsTable) -> std::optional<bool> {
auto tableStuff = [&](TableType* lhsTable) -> std::optional<bool>
{
if (lhsTable->indexer)
{
unify(constraint, indexType, lhsTable->indexer->indexType);
@ -2074,7 +2121,8 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
return true;
}
auto unpack = [&](TypeId ty) {
auto unpack = [&](TypeId ty)
{
for (TypeId varTy : c.variables)
{
LUAU_ASSERT(get<BlockedType>(varTy));
@ -2200,7 +2248,12 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
}
bool ConstraintSolver::tryDispatchIterableFunction(
TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
TypeId nextTy,
TypeId tableTy,
const IterableConstraint& c,
NotNull<const Constraint> constraint,
bool force
)
{
const FunctionType* nextFn = get<FunctionType>(nextTy);
// If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place.
@ -2237,7 +2290,10 @@ bool ConstraintSolver::tryDispatchIterableFunction(
}
NotNull<const Constraint> ConstraintSolver::unpackAndAssign(
const std::vector<TypeId> destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint)
const std::vector<TypeId> destTypes,
TypePackId srcTypes,
NotNull<const Constraint> constraint
)
{
auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes});
@ -2251,15 +2307,28 @@ NotNull<const Constraint> ConstraintSolver::unpackAndAssign(
return c;
}
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification)
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(
NotNull<const Constraint> constraint,
TypeId subjectType,
const std::string& propName,
ValueContext context,
bool inConditional,
bool suppressSimplification
)
{
DenseHashSet<TypeId> seen{nullptr};
return lookupTableProp(constraint, subjectType, propName, context, inConditional, suppressSimplification, seen);
}
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification, DenseHashSet<TypeId>& seen)
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(
NotNull<const Constraint> constraint,
TypeId subjectType,
const std::string& propName,
ValueContext context,
bool inConditional,
bool suppressSimplification,
DenseHashSet<TypeId>& seen
)
{
if (seen.contains(subjectType))
return {};

View file

@ -204,7 +204,8 @@ void DataFlowGraphBuilder::joinBindings(DfgScope* p, const DfgScope& a, const Df
void DataFlowGraphBuilder::joinProps(DfgScope* result, const DfgScope& a, const DfgScope& b)
{
auto phinodify = [this](DfgScope* scope, const auto& a, const auto& b, DefId parent) mutable {
auto phinodify = [this](DfgScope* scope, const auto& a, const auto& b, DefId parent) mutable
{
auto& p = scope->props[parent];
for (const auto& [k, defA] : a)
{
@ -373,6 +374,8 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s)
return visit(scope, l);
else if (auto t = s->as<AstStatTypeAlias>())
return visit(scope, t);
else if (auto f = s->as<AstStatTypeFunction>())
return visit(scope, f);
else if (auto d = s->as<AstStatDeclareGlobal>())
return visit(scope, d);
else if (auto d = s->as<AstStatDeclareFunction>())
@ -631,6 +634,14 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t)
return ControlFlow::None;
}
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeFunction* f)
{
DfgScope* unreachable = childScope(scope);
visitExpr(unreachable, f->body);
return ControlFlow::None;
}
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d)
{
DefId def = defArena->freshCell();
@ -691,7 +702,8 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e)
return {NotNull{*def}, key ? *key : nullptr};
}
auto go = [&]() -> DataFlowResult {
auto go = [&]() -> DataFlowResult
{
if (auto g = e->as<AstExprGroup>())
return visitExpr(scope, g);
else if (auto c = e->as<AstExprConstantNil>())
@ -910,7 +922,8 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* er
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment)
{
auto go = [&]() {
auto go = [&]()
{
if (auto l = e->as<AstExprLocal>())
return visitLValue(scope, l, incomingDef, isCompoundAssignment);
else if (auto g = e->as<AstExprGlobal>())

View file

@ -124,7 +124,8 @@ void write(JsonEmitter& emitter, const ConstraintBlock& block)
ObjectEmitter o = emitter.writeObject();
o.writePair("stringification", block.stringification);
auto go = [&o](auto&& t) {
auto go = [&o](auto&& t)
{
using T = std::decay_t<decltype(t)>;
o.writePair("id", toPointerId(t));
@ -350,8 +351,12 @@ void DcrLogger::popBlock(NotNull<const Constraint> block)
}
}
static void snapshotTypeStrings(const std::vector<ExprTypesAtLocation>& interestedExprs,
const std::vector<AnnotationTypesAtLocation>& interestedAnnots, DenseHashMap<const void*, std::string>& map, ToStringOptions& opts)
static void snapshotTypeStrings(
const std::vector<ExprTypesAtLocation>& interestedExprs,
const std::vector<AnnotationTypesAtLocation>& interestedAnnots,
DenseHashMap<const void*, std::string>& map,
ToStringOptions& opts
)
{
for (const ExprTypesAtLocation& tys : interestedExprs)
{
@ -368,7 +373,10 @@ static void snapshotTypeStrings(const std::vector<ExprTypesAtLocation>& interest
}
void DcrLogger::captureBoundaryState(
BoundarySnapshot& target, const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
BoundarySnapshot& target,
const Scope* rootScope,
const std::vector<NotNull<const Constraint>>& unsolvedConstraints
)
{
target.rootScope = snapshotScope(rootScope, opts);
target.unsolvedConstraints.clear();
@ -391,7 +399,11 @@ void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vec
}
StepSnapshot DcrLogger::prepareStepSnapshot(
const Scope* rootScope, NotNull<const Constraint> current, bool force, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
const Scope* rootScope,
NotNull<const Constraint> current,
bool force,
const std::vector<NotNull<const Constraint>>& unsolvedConstraints
)
{
ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts);
DenseHashMap<const Constraint*, ConstraintSnapshot> constraints{nullptr};

View file

@ -286,15 +286,22 @@ struct FindSeteqCounterexampleResult
bool inLeft;
};
static FindSeteqCounterexampleResult findSeteqCounterexample(
DifferEnvironment& env, const std::vector<TypeId>& left, const std::vector<TypeId>& right);
DifferEnvironment& env,
const std::vector<TypeId>& left,
const std::vector<TypeId>& right
);
static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right);
/**
* The last argument gives context info on which complex type contained the TypePack.
*/
static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right);
static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left, const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right);
static DifferResult diffCanonicalTpShape(
DifferEnvironment& env,
DiffError::Kind possibleNonNormalErrorKind,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right
);
static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right);
static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right);
@ -324,8 +331,13 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right)
if (leftTable->props.find(field) == leftTable->props.end())
{
// right has a field the left doesn't
return DifferResult{DiffError{DiffError::Kind::MissingTableProperty, DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::detailsTableProperty(value.type(), field), env.getDevFixFriendlyNameLeft(), env.getDevFixFriendlyNameRight()}};
return DifferResult{DiffError{
DiffError::Kind::MissingTableProperty,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::detailsTableProperty(value.type(), field),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight()
}};
}
}
// left and right have the same set of keys
@ -491,7 +503,10 @@ static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right)
}
static FindSeteqCounterexampleResult findSeteqCounterexample(
DifferEnvironment& env, const std::vector<TypeId>& left, const std::vector<TypeId>& right)
DifferEnvironment& env,
const std::vector<TypeId>& left,
const std::vector<TypeId>& right
)
{
std::unordered_set<size_t> unmatchedRightIdxes;
for (size_t i = 0; i < right.size(); i++)
@ -760,8 +775,12 @@ static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonN
return diffHandleFlattenedTail(env, possibleNonNormalErrorKind, *leftFlatTpi.second, *rightFlatTpi.second);
}
static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left, const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right)
static DifferResult diffCanonicalTpShape(
DifferEnvironment& env,
DiffError::Kind possibleNonNormalErrorKind,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right
)
{
if (left.first.size() == right.first.size() && left.second.has_value() == right.second.has_value())
return DifferResult{};

View file

@ -21,7 +21,12 @@ LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauImproveNonFunctionCallError, false)
static std::string wrongNumberOfArgsString(
size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
size_t expectedCount,
std::optional<size_t> maximumCount,
size_t actualCount,
const char* argPrefix = nullptr,
bool isVariadic = false
)
{
std::string s = "expects ";
@ -65,8 +70,21 @@ namespace Luau
{
// this list of binary operator type functions is used for better stringification of type functions errors
static const std::unordered_map<std::string, const char*> kBinaryOps{{"add", "+"}, {"sub", "-"}, {"mul", "*"}, {"div", "/"}, {"idiv", "//"},
{"pow", "^"}, {"mod", "%"}, {"concat", ".."}, {"and", "and"}, {"or", "or"}, {"lt", "< or >="}, {"le", "<= or >"}, {"eq", "== or ~="}};
static const std::unordered_map<std::string, const char*> kBinaryOps{
{"add", "+"},
{"sub", "-"},
{"mul", "*"},
{"div", "/"},
{"idiv", "//"},
{"pow", "^"},
{"mod", "%"},
{"concat", ".."},
{"and", "and"},
{"or", "or"},
{"lt", "< or >="},
{"le", "<= or >"},
{"eq", "== or ~="}
};
// this list of unary operator type functions is used for better stringification of type functions errors
static const std::unordered_map<std::string, const char*> kUnaryOps{{"unm", "-"}, {"len", "#"}, {"not", "not"}};
@ -86,12 +104,15 @@ struct ErrorConverter
std::string result;
auto quote = [&](std::string s) {
auto quote = [&](std::string s)
{
return "'" + s + "'";
};
auto constructErrorMessage = [&](std::string givenType, std::string wantedType, std::optional<std::string> givenModule,
std::optional<std::string> wantedModule) -> std::string {
auto constructErrorMessage =
[&](std::string givenType, std::string wantedType, std::optional<std::string> givenModule, std::optional<std::string> wantedModule
) -> std::string
{
std::string given = givenModule ? quote(givenType) + " from " + quote(*givenModule) : quote(givenType);
std::string wanted = wantedModule ? quote(wantedType) + " from " + quote(*wantedModule) : quote(wantedType);
size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength);
@ -351,6 +372,11 @@ struct ErrorConverter
return e.message;
}
std::string operator()(const Luau::ConstraintSolvingIncompleteError& e) const
{
return "Type inference failed to complete, you may see some confusing types and type errors.";
}
std::optional<TypeId> findCallMetamethod(TypeId type) const
{
type = follow(type);
@ -987,6 +1013,11 @@ bool InternalError::operator==(const InternalError& rhs) const
return message == rhs.message;
}
bool ConstraintSolvingIncompleteError::operator==(const ConstraintSolvingIncompleteError& rhs) const
{
return true;
}
bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const
{
return ty == rhs.ty;
@ -1177,11 +1208,13 @@ bool containsParseErrorName(const TypeError& error)
template<typename T>
void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
{
auto clone = [&](auto&& ty) {
auto clone = [&](auto&& ty)
{
return ::Luau::clone(ty, destArena, cloneState);
};
auto visitErrorData = [&](auto&& e) {
auto visitErrorData = [&](auto&& e)
{
copyError(e, destArena, cloneState);
};
@ -1256,6 +1289,9 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
else if constexpr (std::is_same_v<T, InternalError>)
{
}
else if constexpr (std::is_same_v<T, ConstraintSolvingIncompleteError>)
{
}
else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
{
e.ty = clone(e.ty);
@ -1363,7 +1399,8 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena, NotNull<BuiltinTypes> bu
{
CloneState cloneState{builtinTypes};
auto visitErrorData = [&](auto&& e) {
auto visitErrorData = [&](auto&& e)
{
copyError(e, destArena, cloneState);
};

View file

@ -176,8 +176,14 @@ static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, S
}
}
LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source,
const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete)
LoadDefinitionFileResult Frontend::loadDefinitionFile(
GlobalTypes& globals,
ScopePtr targetScope,
std::string_view source,
const std::string& packageName,
bool captureComments,
bool typeCheckForAutocomplete
)
{
LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend");
@ -269,7 +275,10 @@ namespace
{
static ErrorVec accumulateErrors(
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name)
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes,
ModuleResolver& moduleResolver,
const ModuleName& name
)
{
DenseHashSet<ModuleName> seen{{}};
std::vector<ModuleName> queue{name};
@ -301,9 +310,14 @@ static ErrorVec accumulateErrors(
Module& module = *modulePtr;
std::sort(module.errors.begin(), module.errors.end(), [](const TypeError& e1, const TypeError& e2) -> bool {
return e1.location.begin > e2.location.begin;
});
std::sort(
module.errors.begin(),
module.errors.end(),
[](const TypeError& e1, const TypeError& e2) -> bool
{
return e1.location.begin > e2.location.begin;
}
);
result.insert(result.end(), module.errors.begin(), module.errors.end());
}
@ -334,8 +348,12 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector<HotCom
// For each such path, record the full path and the location of the require in the starting module.
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true)
std::vector<RequireCycle> getRequireCycles(const FileResolver* resolver,
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
std::vector<RequireCycle> getRequireCycles(
const FileResolver* resolver,
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes,
const SourceNode* start,
bool stopAtFirst = false
)
{
std::vector<RequireCycle> result;
@ -503,7 +521,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
{
item.module->ats.root = toString(sourceModule.root);
}
item.module->ats.rootSrc = sourceModule.root;
item.module->ats.traverse(item.module.get(), sourceModule.root, NotNull{&builtinTypes_});
}
}
@ -522,8 +540,11 @@ void Frontend::queueModuleCheck(const ModuleName& name)
moduleQueue.push_back(name);
}
std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptions> optionOverride,
std::function<void(std::function<void()> task)> executeTask, std::function<bool(size_t done, size_t total)> progress)
std::vector<ModuleName> Frontend::checkQueuedModules(
std::optional<FrontendOptions> optionOverride,
std::function<void(std::function<void()> task)> executeTask,
std::function<bool(size_t done, size_t total)> progress
)
{
FrontendOptions frontendOptions = optionOverride.value_or(options);
if (FFlag::DebugLuauDeferredConstraintResolution)
@ -548,9 +569,15 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
}
std::vector<ModuleName> queue;
bool cycleDetected = parseGraph(queue, name, frontendOptions.forAutocomplete, [&seen](const ModuleName& name) {
return seen.contains(name);
});
bool cycleDetected = parseGraph(
queue,
name,
frontendOptions.forAutocomplete,
[&seen](const ModuleName& name)
{
return seen.contains(name);
}
);
addBuildQueueItems(buildQueueItems, queue, cycleDetected, seen, frontendOptions);
}
@ -570,7 +597,8 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
// Default task execution is single-threaded and immediate
if (!executeTask)
{
executeTask = [](std::function<void()> task) {
executeTask = [](std::function<void()> task)
{
task();
};
}
@ -582,7 +610,8 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
size_t processing = 0;
size_t remaining = buildQueueItems.size();
auto itemTask = [&](size_t i) {
auto itemTask = [&](size_t i)
{
BuildQueueItem& item = buildQueueItems[i];
try
@ -602,18 +631,23 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
cv.notify_one();
};
auto sendItemTask = [&](size_t i) {
auto sendItemTask = [&](size_t i)
{
BuildQueueItem& item = buildQueueItems[i];
item.processing = true;
processing++;
executeTask([&itemTask, i]() {
itemTask(i);
});
executeTask(
[&itemTask, i]()
{
itemTask(i);
}
);
};
auto sendCycleItemTask = [&] {
auto sendCycleItemTask = [&]
{
for (size_t i = 0; i < buildQueueItems.size(); i++)
{
BuildQueueItem& item = buildQueueItems[i];
@ -662,9 +696,13 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
std::unique_lock guard(mtx);
// If nothing is ready yet, wait
cv.wait(guard, [&readyQueueItems] {
return !readyQueueItems.empty();
});
cv.wait(
guard,
[&readyQueueItems]
{
return !readyQueueItems.empty();
}
);
// Handle checked items
for (size_t i : readyQueueItems)
@ -782,7 +820,11 @@ std::optional<CheckResult> Frontend::getCheckResult(const ModuleName& name, bool
}
bool Frontend::parseGraph(
std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete, std::function<bool(const ModuleName&)> canSkip)
std::vector<ModuleName>& buildQueue,
const ModuleName& root,
bool forAutocomplete,
std::function<bool(const ModuleName&)> canSkip
)
{
LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend");
LUAU_TIMETRACE_ARGUMENT("root", root.c_str());
@ -884,8 +926,13 @@ bool Frontend::parseGraph(
return cyclic;
}
void Frontend::addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vector<ModuleName>& buildQueue, bool cycleDetected,
DenseHashSet<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions)
void Frontend::addBuildQueueItems(
std::vector<BuildQueueItem>& items,
std::vector<ModuleName>& buildQueue,
bool cycleDetected,
DenseHashSet<Luau::ModuleName>& seen,
const FrontendOptions& frontendOptions
)
{
for (const ModuleName& moduleName : buildQueue)
{
@ -981,8 +1028,15 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
if (item.options.forAutocomplete)
{
// The autocomplete typecheck is always in strict mode with DM awareness to provide better type information for IDE features
ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true,
/*recordJsonLog*/ false, typeCheckLimits);
ModulePtr moduleForAutocomplete = check(
sourceModule,
Mode::Strict,
requireCycles,
environmentScope,
/*forAutocomplete*/ true,
/*recordJsonLog*/ false,
typeCheckLimits
);
double duration = getTimestamp() - timestamp;
@ -1209,14 +1263,37 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons
return const_cast<Frontend*>(this)->getSourceModule(moduleName);
}
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& parentScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options,
TypeCheckLimits limits, std::function<void(const ModuleName&, std::string)> writeJsonLog)
ModulePtr check(
const SourceModule& sourceModule,
Mode mode,
const std::vector<RequireCycle>& requireCycles,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler,
NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver,
const ScopePtr& parentScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options,
TypeCheckLimits limits,
std::function<void(const ModuleName&, std::string)> writeJsonLog
)
{
const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson;
return check(sourceModule, mode, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, parentScope,
std::move(prepareModuleScope), options, limits, recordJsonLog, writeJsonLog);
return check(
sourceModule,
mode,
requireCycles,
builtinTypes,
iceHandler,
moduleResolver,
fileResolver,
parentScope,
std::move(prepareModuleScope),
options,
limits,
recordJsonLog,
writeJsonLog
);
}
struct InternalTypeFinder : TypeOnceVisitor
@ -1263,10 +1340,21 @@ struct InternalTypeFinder : TypeOnceVisitor
}
};
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& parentScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options,
TypeCheckLimits limits, bool recordJsonLog, std::function<void(const ModuleName&, std::string)> writeJsonLog)
ModulePtr check(
const SourceModule& sourceModule,
Mode mode,
const std::vector<RequireCycle>& requireCycles,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler,
NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver,
const ScopePtr& parentScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options,
TypeCheckLimits limits,
bool recordJsonLog,
std::function<void(const ModuleName&, std::string)> writeJsonLog
)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Typechecking");
LUAU_TIMETRACE_ARGUMENT("module", sourceModule.name.c_str());
@ -1300,14 +1388,32 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
ConstraintGenerator cg{result, NotNull{&normalizer}, moduleResolver, builtinTypes, iceHandler, parentScope, std::move(prepareModuleScope),
logger.get(), NotNull{&dfg}, requireCycles};
ConstraintGenerator cg{
result,
NotNull{&normalizer},
moduleResolver,
builtinTypes,
iceHandler,
parentScope,
std::move(prepareModuleScope),
logger.get(),
NotNull{&dfg},
requireCycles
};
cg.visitModuleRoot(sourceModule.root);
result->errors = std::move(cg.errors);
ConstraintSolver cs{NotNull{&normalizer}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), result->name, moduleResolver, requireCycles,
logger.get(), limits};
ConstraintSolver cs{
NotNull{&normalizer},
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),
result->name,
moduleResolver,
requireCycles,
logger.get(),
limits
};
if (options.randomizeConstraintResolutionSeed)
cs.randomize(*options.randomizeConstraintResolutionSeed);
@ -1419,22 +1525,41 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
return result;
}
ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vector<RequireCycle> requireCycles,
std::optional<ScopePtr> environmentScope, bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits)
ModulePtr Frontend::check(
const SourceModule& sourceModule,
Mode mode,
std::vector<RequireCycle> requireCycles,
std::optional<ScopePtr> environmentScope,
bool forAutocomplete,
bool recordJsonLog,
TypeCheckLimits typeCheckLimits
)
{
if (FFlag::DebugLuauDeferredConstraintResolution)
{
auto prepareModuleScopeWrap = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) {
auto prepareModuleScopeWrap = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope)
{
if (prepareModuleScope)
prepareModuleScope(name, scope, forAutocomplete);
};
try
{
return Luau::check(sourceModule, mode, requireCycles, builtinTypes, NotNull{&iceHandler},
NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver},
environmentScope ? *environmentScope : globals.globalScope, prepareModuleScopeWrap, options, typeCheckLimits, recordJsonLog,
writeJsonLog);
return Luau::check(
sourceModule,
mode,
requireCycles,
builtinTypes,
NotNull{&iceHandler},
NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver},
NotNull{fileResolver},
environmentScope ? *environmentScope : globals.globalScope,
prepareModuleScopeWrap,
options,
typeCheckLimits,
recordJsonLog,
writeJsonLog
);
}
catch (const InternalCompilerError& err)
{
@ -1445,12 +1570,17 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect
}
else
{
TypeChecker typeChecker(forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope,
forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver, builtinTypes, &iceHandler);
TypeChecker typeChecker(
forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope,
forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver,
builtinTypes,
&iceHandler
);
if (prepareModuleScope)
{
typeChecker.prepareModuleScope = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) {
typeChecker.prepareModuleScope = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope)
{
prepareModuleScope(name, scope, forAutocomplete);
};
}

View file

@ -26,8 +26,14 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool isWithinFunction = false;
bool avoidSealingTables = false;
MutatingGeneralizer(NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<DenseHashSet<TypeId>> cachedTypes,
DenseHashMap<const void*, size_t> positiveTypes, DenseHashMap<const void*, size_t> negativeTypes, bool avoidSealingTables)
MutatingGeneralizer(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes,
DenseHashMap<const void*, size_t> positiveTypes,
DenseHashMap<const void*, size_t> negativeTypes,
bool avoidSealingTables
)
: TypeOnceVisitor(/* skipBoundTypes */ true)
, builtinTypes(builtinTypes)
, scope(scope)
@ -867,8 +873,14 @@ struct TypeCacher : TypeOnceVisitor
}
};
std::optional<TypeId> generalize(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes, TypeId ty, bool avoidSealingTables)
std::optional<TypeId> generalize(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes,
TypeId ty,
bool avoidSealingTables
)
{
ty = follow(ty);

View file

@ -102,8 +102,15 @@ TypePackId Instantiation::clean(TypePackId tp)
return tp;
}
void ReplaceGenerics::resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks)
void ReplaceGenerics::resetState(
const TxnLog* log,
TypeArena* arena,
NotNull<BuiltinTypes> builtinTypes,
TypeLevel level,
Scope* scope,
const std::vector<TypeId>& generics,
const std::vector<TypePackId>& genericPacks
)
{
LUAU_ASSERT(FFlag::LuauReusableSubstitutions);
@ -187,7 +194,12 @@ TypePackId ReplaceGenerics::clean(TypePackId tp)
}
std::optional<TypeId> instantiate(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty)
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
TypeId ty
)
{
ty = follow(ty);

View file

@ -8,6 +8,23 @@ bool Instantiation2::ignoreChildren(TypeId ty)
{
if (get<ClassType>(ty))
return true;
if (auto ftv = get<FunctionType>(ty))
{
if (ftv->hasNoFreeOrGenericTypes)
return false;
// If this function type quantifies over these generics, we don't want substitution to
// go any further into them because it's being shadowed in this case.
for (auto generic : ftv->generics)
if (genericSubstitutions.contains(generic))
return true;
for (auto generic : ftv->genericPacks)
if (genericPackSubstitutions.contains(generic))
return true;
}
return false;
}
@ -47,14 +64,22 @@ TypePackId Instantiation2::clean(TypePackId tp)
}
std::optional<TypeId> instantiate2(
TypeArena* arena, DenseHashMap<TypeId, TypeId> genericSubstitutions, DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions, TypeId ty)
TypeArena* arena,
DenseHashMap<TypeId, TypeId> genericSubstitutions,
DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions,
TypeId ty
)
{
Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)};
return instantiation.substitute(ty);
}
std::optional<TypePackId> instantiate2(
TypeArena* arena, DenseHashMap<TypeId, TypeId> genericSubstitutions, DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions, TypePackId tp)
TypeArena* arena,
DenseHashMap<TypeId, TypeId> genericSubstitutions,
DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions,
TypePackId tp
)
{
Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)};
return instantiation.substitute(tp);

View file

@ -114,6 +114,8 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "GenericError { " << err.message << " }";
else if constexpr (std::is_same_v<T, InternalError>)
stream << "InternalError { " << err.message << " }";
else if constexpr (std::is_same_v<T, ConstraintSolvingIncompleteError>)
stream << "ConstraintSolvingIncompleteError {}";
else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
stream << "CannotCallNonFunction { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, ExtraInformation>)
@ -259,7 +261,8 @@ std::ostream& operator<<(std::ostream& stream, const CannotAssignToNever::Reason
std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data)
{
auto cb = [&](const auto& e) {
auto cb = [&](const auto& e)
{
return errorToString(stream, e);
};
visit(cb, data);

View file

@ -275,8 +275,14 @@ private:
else if (g->deprecated)
{
if (const char* replacement = *g->deprecated; replacement && strlen(replacement))
emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead",
gv->name.value, replacement);
emitWarning(
*context,
LintWarning::Code_DeprecatedGlobal,
gv->location,
"Global '%s' is deprecated, use '%s' instead",
gv->name.value,
replacement
);
else
emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value);
}
@ -291,18 +297,33 @@ private:
AstExprFunction* top = g.functionRef.back();
if (top->debugname.value)
emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location,
"Global '%s' is only used in the enclosing function '%s'; consider changing it to local", g.firstRef->name.value,
top->debugname.value);
emitWarning(
*context,
LintWarning::Code_GlobalUsedAsLocal,
g.firstRef->location,
"Global '%s' is only used in the enclosing function '%s'; consider changing it to local",
g.firstRef->name.value,
top->debugname.value
);
else
emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location,
emitWarning(
*context,
LintWarning::Code_GlobalUsedAsLocal,
g.firstRef->location,
"Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local",
g.firstRef->name.value, top->location.begin.line + 1);
g.firstRef->name.value,
top->location.begin.line + 1
);
}
else if (g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && g.firstRef->name != context->placeholder)
{
emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location,
"Global '%s' is never read before being written. Consider changing it to local", g.firstRef->name.value);
emitWarning(
*context,
LintWarning::Code_GlobalUsedAsLocal,
g.firstRef->location,
"Global '%s' is never read before being written. Consider changing it to local",
g.firstRef->name.value
);
}
}
}
@ -329,7 +350,8 @@ private:
if (node->name == context->placeholder)
emitWarning(
*context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable");
*context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"
);
return true;
}
@ -338,7 +360,8 @@ private:
{
if (node->local->name == context->placeholder)
emitWarning(
*context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable");
*context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"
);
return true;
}
@ -366,8 +389,13 @@ private:
}
if (g.builtin)
emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location,
"Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value);
emitWarning(
*context,
LintWarning::Code_BuiltinGlobalWrite,
gv->location,
"Built-in global '%s' is overwritten here; consider using a local or changing the name",
gv->name.value
);
else
g.assigned = true;
@ -396,8 +424,13 @@ private:
Global& g = globals[gv->name];
if (g.builtin)
emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location,
"Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value);
emitWarning(
*context,
LintWarning::Code_BuiltinGlobalWrite,
gv->location,
"Built-in global '%s' is overwritten here; consider using a local or changing the name",
gv->name.value
);
else
{
g.assigned = true;
@ -565,8 +598,12 @@ private:
if (node->body.data[i - 1]->hasSemicolon)
continue;
emitWarning(*context, LintWarning::Code_SameLineStatement, location,
"A new statement is on the same line; add semi-colon on previous statement to silence");
emitWarning(
*context,
LintWarning::Code_SameLineStatement,
location,
"A new statement is on the same line; add semi-colon on previous statement to silence"
);
lastLine = location.begin.line;
}
@ -613,7 +650,8 @@ private:
if (location.begin.column <= top.start.begin.column)
{
emitWarning(
*context, LintWarning::Code_MultiLineStatement, location, "Statement spans multiple lines; use indentation to silence");
*context, LintWarning::Code_MultiLineStatement, location, "Statement spans multiple lines; use indentation to silence"
);
top.flagged = true;
}
@ -727,8 +765,14 @@ private:
// don't warn on inter-function shadowing since it is much more fragile wrt refactoring
if (shadow->functionDepth == local->functionDepth)
emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows previous declaration at line %d",
local->name.value, shadow->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_LocalShadow,
local->location,
"Variable '%s' shadows previous declaration at line %d",
local->name.value,
shadow->location.begin.line + 1
);
}
else if (Global* global = globals.find(local->name))
{
@ -736,8 +780,14 @@ private:
; // there are many builtins with common names like 'table'; some of them are deprecated as well
else if (global->firstRef)
{
emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows a global variable used at line %d",
local->name.value, global->firstRef->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_LocalShadow,
local->location,
"Variable '%s' shadows a global variable used at line %d",
local->name.value,
global->firstRef->location.begin.line + 1
);
}
else
{
@ -752,14 +802,21 @@ private:
return;
if (info.function)
emitWarning(*context, LintWarning::Code_FunctionUnused, local->location, "Function '%s' is never used; prefix with '_' to silence",
local->name.value);
emitWarning(
*context,
LintWarning::Code_FunctionUnused,
local->location,
"Function '%s' is never used; prefix with '_' to silence",
local->name.value
);
else if (info.import)
emitWarning(*context, LintWarning::Code_ImportUnused, local->location, "Import '%s' is never used; prefix with '_' to silence",
local->name.value);
emitWarning(
*context, LintWarning::Code_ImportUnused, local->location, "Import '%s' is never used; prefix with '_' to silence", local->name.value
);
else
emitWarning(*context, LintWarning::Code_LocalUnused, local->location, "Variable '%s' is never used; prefix with '_' to silence",
local->name.value);
emitWarning(
*context, LintWarning::Code_LocalUnused, local->location, "Variable '%s' is never used; prefix with '_' to silence", local->name.value
);
}
bool isRequireCall(AstExpr* expr)
@ -913,8 +970,13 @@ private:
for (auto& g : globals)
{
if (g.second.function && !g.second.used && g.first.value[0] != '_')
emitWarning(*context, LintWarning::Code_FunctionUnused, g.second.location, "Function '%s' is never used; prefix with '_' to silence",
g.first.value);
emitWarning(
*context,
LintWarning::Code_FunctionUnused,
g.second.location,
"Function '%s' is never used; prefix with '_' to silence",
g.first.value
);
}
}
@ -1013,8 +1075,13 @@ private:
if (step == Error && si->is<AstStatExpr>() && next->is<AstStatReturn>() && i + 2 == stat->body.size)
return Error;
emitWarning(*context, LintWarning::Code_UnreachableCode, next->location, "Unreachable code (previous statement always %ss)",
getReason(step));
emitWarning(
*context,
LintWarning::Code_UnreachableCode,
next->location,
"Unreachable code (previous statement always %ss)",
getReason(step)
);
return step;
}
}
@ -1209,22 +1276,34 @@ private:
// for i=#t,1 do
if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 1.0)
emitWarning(
*context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?");
*context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"
);
// for i=8,1 do
else if (fc && tc && fc->value > tc->value)
emitWarning(
*context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?");
*context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"
);
// for i=1,8.75 do
else if (fc && tc && getLoopEnd(fc->value, tc->value) != tc->value)
emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop ends at %g instead of %g; did you forget to specify step?",
getLoopEnd(fc->value, tc->value), tc->value);
emitWarning(
*context,
LintWarning::Code_ForRange,
rangeLocation,
"For loop ends at %g instead of %g; did you forget to specify step?",
getLoopEnd(fc->value, tc->value),
tc->value
);
// for i=0,#t do
else if (fc && tu && fc->value == 0.0 && tu->op == AstExprUnary::Len)
emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop starts at 0, but arrays start at 1");
// for i=#t,0 do
else if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 0.0)
emitWarning(*context, LintWarning::Code_ForRange, rangeLocation,
"For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1");
emitWarning(
*context,
LintWarning::Code_ForRange,
rangeLocation,
"For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1"
);
}
return true;
@ -1252,16 +1331,27 @@ private:
AstExpr* last = values.data[values.size - 1];
if (vars < values.size)
emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location,
"Assigning %d values to %d variables leaves some values unused", int(values.size), int(vars));
emitWarning(
*context,
LintWarning::Code_UnbalancedAssignment,
location,
"Assigning %d values to %d variables leaves some values unused",
int(values.size),
int(vars)
);
else if (last->is<AstExprCall>() || last->is<AstExprVarargs>())
; // we don't know how many values the last expression returns
else if (last->is<AstExprConstantNil>())
; // last expression is nil which explicitly silences the nil-init warning
else
emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location,
"Assigning %d values to %d variables initializes extra variables with nil; add 'nil' to value list to silence", int(values.size),
int(vars));
emitWarning(
*context,
LintWarning::Code_UnbalancedAssignment,
location,
"Assigning %d values to %d variables initializes extra variables with nil; add 'nil' to value list to silence",
int(values.size),
int(vars)
);
}
}
@ -1344,13 +1434,22 @@ private:
Location location = getEndLocation(bodyf);
if (node->debugname.value)
emitWarning(*context, LintWarning::Code_ImplicitReturn, location,
emitWarning(
*context,
LintWarning::Code_ImplicitReturn,
location,
"Function '%s' can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence",
node->debugname.value, vret->location.begin.line + 1);
node->debugname.value,
vret->location.begin.line + 1
);
else
emitWarning(*context, LintWarning::Code_ImplicitReturn, location,
emitWarning(
*context,
LintWarning::Code_ImplicitReturn,
location,
"Function can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence",
vret->location.begin.line + 1);
vret->location.begin.line + 1
);
}
return true;
@ -1821,23 +1920,41 @@ private:
int& line = names[&expr->value];
if (line)
emitWarning(*context, LintWarning::Code_TableLiteral, expr->location,
"Table field '%.*s' is a duplicate; previously defined at line %d", int(expr->value.size), expr->value.data, line);
emitWarning(
*context,
LintWarning::Code_TableLiteral,
expr->location,
"Table field '%.*s' is a duplicate; previously defined at line %d",
int(expr->value.size),
expr->value.data,
line
);
else
line = expr->location.begin.line + 1;
}
else if (AstExprConstantNumber* expr = item.key->as<AstExprConstantNumber>())
{
if (expr->value >= 1 && expr->value <= double(count) && double(int(expr->value)) == expr->value)
emitWarning(*context, LintWarning::Code_TableLiteral, expr->location,
"Table index %d is a duplicate; previously defined as a list entry", int(expr->value));
emitWarning(
*context,
LintWarning::Code_TableLiteral,
expr->location,
"Table index %d is a duplicate; previously defined as a list entry",
int(expr->value)
);
else if (expr->value >= 0 && expr->value <= double(INT_MAX) && double(int(expr->value)) == expr->value)
{
int& line = indices[int(expr->value)];
if (line)
emitWarning(*context, LintWarning::Code_TableLiteral, expr->location,
"Table index %d is a duplicate; previously defined at line %d", int(expr->value), line);
emitWarning(
*context,
LintWarning::Code_TableLiteral,
expr->location,
"Table index %d is a duplicate; previously defined at line %d",
int(expr->value),
line
);
else
line = expr->location.begin.line + 1;
}
@ -1875,18 +1992,41 @@ private:
if (int(rec->access) & int(item.access))
{
if (rec->access == item.access)
emitWarning(*context, LintWarning::Code_TableLiteral, item.location,
"Table type field '%s' is a duplicate; previously defined at line %d", item.name.value, rec->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_TableLiteral,
item.location,
"Table type field '%s' is a duplicate; previously defined at line %d",
item.name.value,
rec->location.begin.line + 1
);
else if (rec->access == AstTableAccess::ReadWrite)
emitWarning(*context, LintWarning::Code_TableLiteral, item.location,
"Table type field '%s' is already read-write; previously defined at line %d", item.name.value,
rec->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_TableLiteral,
item.location,
"Table type field '%s' is already read-write; previously defined at line %d",
item.name.value,
rec->location.begin.line + 1
);
else if (rec->access == AstTableAccess::Read)
emitWarning(*context, LintWarning::Code_TableLiteral, rec->location,
"Table type field '%s' already has a read type defined at line %d", item.name.value, rec->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_TableLiteral,
rec->location,
"Table type field '%s' already has a read type defined at line %d",
item.name.value,
rec->location.begin.line + 1
);
else if (rec->access == AstTableAccess::Write)
emitWarning(*context, LintWarning::Code_TableLiteral, rec->location,
"Table type field '%s' already has a write type defined at line %d", item.name.value, rec->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_TableLiteral,
rec->location,
"Table type field '%s' already has a write type defined at line %d",
item.name.value,
rec->location.begin.line + 1
);
else
LUAU_ASSERT(!"Unreachable");
}
@ -1904,8 +2044,14 @@ private:
int& line = names[item.name];
if (line)
emitWarning(*context, LintWarning::Code_TableLiteral, item.location,
"Table type field '%s' is a duplicate; previously defined at line %d", item.name.value, line);
emitWarning(
*context,
LintWarning::Code_TableLiteral,
item.location,
"Table type field '%s' is a duplicate; previously defined at line %d",
item.name.value,
line
);
else
line = item.location.begin.line + 1;
}
@ -1966,9 +2112,14 @@ private:
if (l.defined && !l.initialized && !l.assigned && l.firstUse)
{
emitWarning(*context, LintWarning::Code_UninitializedLocal, l.firstUse->location,
"Variable '%s' defined at line %d is never initialized or assigned; initialize with 'nil' to silence", local->name.value,
local->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_UninitializedLocal,
l.firstUse->location,
"Variable '%s' defined at line %d is never initialized or assigned; initialize with 'nil' to silence",
local->name.value,
local->location.begin.line + 1
);
}
}
}
@ -2102,8 +2253,14 @@ private:
void report(const std::string& name, Location location, Location otherLocation)
{
emitWarning(*context, LintWarning::Code_DuplicateFunction, location, "Duplicate function definition: '%s' also defined on line %d",
name.c_str(), otherLocation.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_DuplicateFunction,
location,
"Duplicate function definition: '%s' also defined on line %d",
name.c_str(),
otherLocation.begin.line + 1
);
}
};
@ -2152,7 +2309,8 @@ private:
const char* suggestion = (fenv->name == "getfenv") ? "; consider using 'debug.info' instead" : "";
emitWarning(
*context, LintWarning::Code_DeprecatedApi, node->location, "Function '%s' is deprecated%s", fenv->name.value, suggestion);
*context, LintWarning::Code_DeprecatedApi, node->location, "Function '%s' is deprecated%s", fenv->name.value, suggestion
);
}
}
}
@ -2265,7 +2423,8 @@ private:
if (!tty->indexer && !tty->props.empty() && tty->state != TableState::Generic)
emitWarning(
*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table without an array part is likely a bug", op);
*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table without an array part is likely a bug", op
);
else if (tty->indexer && isString(tty->indexer->indexType)) // note: to avoid complexity of subtype tests we just check if the key is a string
emitWarning(*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table with string keys is likely a bug", op);
}
@ -2283,9 +2442,13 @@ private:
size_t ret = getReturnCount(follow(*funty));
if (ret > 1)
emitWarning(*context, LintWarning::Code_TableOperations, tail->location,
emitWarning(
*context,
LintWarning::Code_TableOperations,
tail->location,
"table.insert may change behavior if the call returns more than one result; consider adding parentheses around second "
"argument");
"argument"
);
}
}
}
@ -2294,28 +2457,44 @@ private:
{
// table.insert(t, 0, ?)
if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?");
emitWarning(
*context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?"
);
// table.insert(t, #t, ?)
if (isLength(args[1], args[0]))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
emitWarning(
*context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.insert will insert the value before the last element, which is likely a bug; consider removing the second argument or "
"wrap it in parentheses to silence");
"wrap it in parentheses to silence"
);
// table.insert(t, #t+1, ?)
if (AstExprBinary* add = args[1]->as<AstExprBinary>();
add && add->op == AstExprBinary::Add && isLength(add->left, args[0]) && isConstant(add->right, 1.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.insert will append the value to the table; consider removing the second argument for efficiency");
emitWarning(
*context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.insert will append the value to the table; consider removing the second argument for efficiency"
);
}
if (func->index == "remove" && node->args.size >= 2)
{
// table.remove(t, 0)
if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?");
emitWarning(
*context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?"
);
// note: it's tempting to check for table.remove(t, #t), which is equivalent to table.remove(t), but it's correct, occurs frequently,
// and also reads better.
@ -2323,35 +2502,55 @@ private:
// table.remove(t, #t-1)
if (AstExprBinary* sub = args[1]->as<AstExprBinary>();
sub && sub->op == AstExprBinary::Sub && isLength(sub->left, args[0]) && isConstant(sub->right, 1.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
emitWarning(
*context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.remove will remove the value before the last element, which is likely a bug; consider removing the second argument or "
"wrap it in parentheses to silence");
"wrap it in parentheses to silence"
);
}
if (func->index == "move" && node->args.size >= 4)
{
// table.move(t, 0, _, _)
if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
emitWarning(
*context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"
);
// table.move(t, _, _, 0)
else if (isConstant(args[3], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
emitWarning(
*context,
LintWarning::Code_TableOperations,
args[3]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"
);
}
if (func->index == "create" && node->args.size == 2)
{
// table.create(n, {...})
if (args[1]->is<AstExprTable>())
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead");
emitWarning(
*context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"
);
// table.create(n, {...} :: ?)
if (AstExprTypeAssertion* as = args[1]->as<AstExprTypeAssertion>(); as && as->expr->is<AstExprTable>())
emitWarning(*context, LintWarning::Code_TableOperations, as->expr->location,
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead");
emitWarning(
*context,
LintWarning::Code_TableOperations,
as->expr->location,
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"
);
}
}
@ -2543,11 +2742,21 @@ private:
if (similar(conditions[j], conditions[i]))
{
if (conditions[i]->location.begin.line == conditions[j]->location.begin.line)
emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location,
"Condition has already been checked on column %d", conditions[j]->location.begin.column + 1);
emitWarning(
*context,
LintWarning::Code_DuplicateCondition,
conditions[i]->location,
"Condition has already been checked on column %d",
conditions[j]->location.begin.column + 1
);
else
emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location,
"Condition has already been checked on line %d", conditions[j]->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_DuplicateCondition,
conditions[i]->location,
"Condition has already been checked on line %d",
conditions[j]->location.begin.line + 1
);
break;
}
}
@ -2592,11 +2801,23 @@ private:
if (local->shadow && locals[local->shadow] == node && !ignoreDuplicate(local))
{
if (local->shadow->location.begin.line == local->location.begin.line)
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on column %d",
local->name.value, local->shadow->location.begin.column + 1);
emitWarning(
*context,
LintWarning::Code_DuplicateLocal,
local->location,
"Variable '%s' already defined on column %d",
local->name.value,
local->shadow->location.begin.column + 1
);
else
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on line %d",
local->name.value, local->shadow->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_DuplicateLocal,
local->location,
"Variable '%s' already defined on line %d",
local->name.value,
local->shadow->location.begin.line + 1
);
}
}
@ -2620,11 +2841,23 @@ private:
if (local->shadow == node->self)
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter 'self' already defined implicitly");
else if (local->shadow->location.begin.line == local->location.begin.line)
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on column %d",
local->name.value, local->shadow->location.begin.column + 1);
emitWarning(
*context,
LintWarning::Code_DuplicateLocal,
local->location,
"Function parameter '%s' already defined on column %d",
local->name.value,
local->shadow->location.begin.column + 1
);
else
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on line %d",
local->name.value, local->shadow->location.begin.line + 1);
emitWarning(
*context,
LintWarning::Code_DuplicateLocal,
local->location,
"Function parameter '%s' already defined on line %d",
local->name.value,
local->shadow->location.begin.line + 1
);
}
}
@ -2668,10 +2901,14 @@ private:
alt = "false";
if (alt)
emitWarning(*context, LintWarning::Code_MisleadingAndOr, node->location,
emitWarning(
*context,
LintWarning::Code_MisleadingAndOr,
node->location,
"The and-or expression always evaluates to the second alternative because the first alternative is %s; consider using if-then-else "
"expression instead",
alt);
alt
);
return true;
}
@ -2699,16 +2936,28 @@ private:
case ConstantNumberParseResult::Malformed:
break;
case ConstantNumberParseResult::Imprecise:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Number literal exceeded available precision and was truncated to closest representable number");
emitWarning(
*context,
LintWarning::Code_IntegerParsing,
node->location,
"Number literal exceeded available precision and was truncated to closest representable number"
);
break;
case ConstantNumberParseResult::BinOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Binary number literal exceeded available precision and was truncated to 2^64");
emitWarning(
*context,
LintWarning::Code_IntegerParsing,
node->location,
"Binary number literal exceeded available precision and was truncated to 2^64"
);
break;
case ConstantNumberParseResult::HexOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Hexadecimal number literal exceeded available precision and was truncated to 2^64");
emitWarning(
*context,
LintWarning::Code_IntegerParsing,
node->location,
"Hexadecimal number literal exceeded available precision and was truncated to 2^64"
);
break;
}
@ -2759,12 +3008,24 @@ private:
std::string op = toString(node->op);
if (isEquality(node->op))
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location,
"not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence", op.c_str(), op.c_str(),
node->op == AstExprBinary::CompareEq ? "~=" : "==");
emitWarning(
*context,
LintWarning::Code_ComparisonPrecedence,
node->location,
"not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence",
op.c_str(),
op.c_str(),
node->op == AstExprBinary::CompareEq ? "~=" : "=="
);
else
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location,
"not X %s Y is equivalent to (not X) %s Y; add parentheses to silence", op.c_str(), op.c_str());
emitWarning(
*context,
LintWarning::Code_ComparisonPrecedence,
node->location,
"not X %s Y is equivalent to (not X) %s Y; add parentheses to silence",
op.c_str(),
op.c_str()
);
}
else if (AstExprBinary* left = node->left->as<AstExprBinary>(); left && isComparison(left->op))
{
@ -2772,12 +3033,29 @@ private:
std::string rop = toString(node->op);
if (isEquality(left->op) || isEquality(node->op))
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location,
"X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str());
emitWarning(
*context,
LintWarning::Code_ComparisonPrecedence,
node->location,
"X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence",
lop.c_str(),
rop.c_str(),
lop.c_str(),
rop.c_str()
);
else
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location,
"X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str(),
lop.c_str(), rop.c_str());
emitWarning(
*context,
LintWarning::Code_ComparisonPrecedence,
node->location,
"X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?",
lop.c_str(),
rop.c_str(),
lop.c_str(),
rop.c_str(),
lop.c_str(),
rop.c_str()
);
}
return true;
@ -2843,8 +3121,12 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
if (!hc.header)
{
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"Comment directive is ignored because it is placed after the first non-comment token");
emitWarning(
context,
LintWarning::Code_CommentDirective,
hc.location,
"Comment directive is ignored because it is placed after the first non-comment token"
);
}
else
{
@ -2865,21 +3147,36 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
// skip Unknown
if (const char* suggestion = fuzzyMatch(rule, kWarningNames + 1, LintWarning::Code__Count - 1))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", rule, suggestion);
emitWarning(
context,
LintWarning::Code_CommentDirective,
hc.location,
"nolint directive refers to unknown lint rule '%s'; did you mean '%s'?",
rule,
suggestion
);
else
emitWarning(
context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule);
context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule
);
}
}
else if (first == "nocheck" || first == "nonstrict" || first == "strict")
{
if (space != std::string::npos)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"Comment directive with the type checking mode has extra symbols at the end of the line");
emitWarning(
context,
LintWarning::Code_CommentDirective,
hc.location,
"Comment directive with the type checking mode has extra symbols at the end of the line"
);
else if (seenMode)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"Comment directive with the type checking mode has already been used");
emitWarning(
context,
LintWarning::Code_CommentDirective,
hc.location,
"Comment directive with the type checking mode has already been used"
);
else
seenMode = true;
}
@ -2894,15 +3191,21 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
const char* level = hc.content.c_str() + notspace;
if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2"))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"optimize directive uses unknown optimization level '%s', 0..2 expected", level);
emitWarning(
context,
LintWarning::Code_CommentDirective,
hc.location,
"optimize directive uses unknown optimization level '%s', 0..2 expected",
level
);
}
}
else if (first == "native")
{
if (space != std::string::npos)
emitWarning(
context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line");
context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line"
);
}
else
{
@ -2916,11 +3219,19 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
};
if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments)))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'; did you mean '%s'?",
int(first.size()), first.data(), suggestion);
emitWarning(
context,
LintWarning::Code_CommentDirective,
hc.location,
"Unknown comment directive '%.*s'; did you mean '%s'?",
int(first.size()),
first.data(),
suggestion
);
else
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()),
first.data());
emitWarning(
context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), first.data()
);
}
}
}
@ -2973,8 +3284,12 @@ private:
{
if (attribute->type == AstAttr::Type::Native)
{
emitWarning(*context, LintWarning::Code_RedundantNativeAttribute, attribute->location,
"native attribute on a function is redundant in a native module; consider removing it");
emitWarning(
*context,
LintWarning::Code_RedundantNativeAttribute,
attribute->location,
"native attribute on a function is redundant in a native module; consider removing it"
);
}
}
@ -2982,8 +3297,14 @@ private:
}
};
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
const std::vector<HotComment>& hotcomments, const LintOptions& options)
std::vector<LintWarning> lint(
AstStat* root,
const AstNameTable& names,
const ScopePtr& env,
const Module* module,
const std::vector<HotComment>& hotcomments,
const LintOptions& options
)
{
LintContext context;
@ -3068,8 +3389,7 @@ std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const Sc
if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence))
LintComparisonPrecedence::process(context);
if (FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute &&
context.warningEnabled(LintWarning::Code_RedundantNativeAttribute))
if (FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute))
{
if (hasNativeCommentDirective(hotcomments))
LintRedundantNativeAttribute::process(context);

View file

@ -24,8 +24,8 @@ static bool contains(Position pos, Comment comment)
{
if (comment.location.contains(pos))
return true;
else if (comment.type == Lexeme::BrokenComment &&
comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end
else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't
// have an end
return true;
else if (comment.type == Lexeme::Comment && comment.location.end == pos)
return true;
@ -36,9 +36,14 @@ static bool contains(Position pos, Comment comment)
static bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos)
{
auto iter = std::lower_bound(
commentLocations.begin(), commentLocations.end(), Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) {
commentLocations.begin(),
commentLocations.end(),
Comment{Lexeme::Comment, Location{pos, pos}},
[](const Comment& a, const Comment& b)
{
return a.location.end < b.location.end;
});
}
);
if (iter == commentLocations.end())
return false;

View file

@ -69,7 +69,11 @@ struct NonStrictContext
NonStrictContext& operator=(NonStrictContext&&) = default;
static NonStrictContext disjunction(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
const NonStrictContext& left,
const NonStrictContext& right
)
{
// disjunction implements union over the domain of keys
// if the default value for a defId not in the map is `never`
@ -94,7 +98,11 @@ struct NonStrictContext
}
static NonStrictContext conjunction(
NotNull<BuiltinTypes> builtins, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
NotNull<BuiltinTypes> builtins,
NotNull<TypeArena> arena,
const NonStrictContext& left,
const NonStrictContext& right
)
{
NonStrictContext conj{};
@ -160,8 +168,15 @@ struct NonStrictTypeChecker
const NotNull<TypeCheckLimits> limits;
NonStrictTypeChecker(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, const NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState, NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, Module* module)
NonStrictTypeChecker(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
const NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg,
NotNull<TypeCheckLimits> limits,
Module* module
)
: builtinTypes(builtinTypes)
, ice(ice)
, arena(arena)
@ -213,7 +228,8 @@ struct NonStrictTypeChecker
return instance;
ErrorVec errors =
reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true).errors;
reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
.errors;
if (errors.empty())
noTypeFunctionErrors.insert(instance);
@ -271,6 +287,8 @@ struct NonStrictTypeChecker
return visit(s);
else if (auto s = stat->as<AstStatTypeAlias>())
return visit(s);
else if (auto f = stat->as<AstStatTypeFunction>())
return visit(f);
else if (auto s = stat->as<AstStatDeclareFunction>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareGlobal>())
@ -395,6 +413,12 @@ struct NonStrictTypeChecker
return {};
}
NonStrictContext visit(AstStatTypeFunction* typeFunc)
{
reportError(GenericError{"This syntax is not supported"}, typeFunc->location);
return {};
}
NonStrictContext visit(AstStatDeclareFunction* declFn)
{
return {};
@ -726,8 +750,15 @@ private:
};
};
void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, const SourceModule& sourceModule, Module* module)
void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg,
NotNull<TypeCheckLimits> limits,
const SourceModule& sourceModule,
Module* module
)
{
LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking");

View file

@ -159,10 +159,15 @@ size_t TypeIds::getHash() const
bool TypeIds::isNever() const
{
return std::all_of(begin(), end(), [&](TypeId i) {
// If each typeid is never, then I guess typeid's is also never?
return get<NeverType>(i) != nullptr;
});
return std::all_of(
begin(),
end(),
[&](TypeId i)
{
// If each typeid is never, then I guess typeid's is also never?
return get<NeverType>(i) != nullptr;
}
);
}
bool TypeIds::operator==(const TypeIds& there) const
@ -371,10 +376,15 @@ bool NormalizedType::shouldSuppressErrors() const
bool NormalizedType::hasTopTable() const
{
return hasTables() && std::any_of(tables.begin(), tables.end(), [&](TypeId ty) {
auto primTy = get<PrimitiveType>(ty);
return primTy && primTy->type == PrimitiveType::Type::Table;
});
return hasTables() && std::any_of(
tables.begin(),
tables.end(),
[&](TypeId ty)
{
auto primTy = get<PrimitiveType>(ty);
return primTy && primTy->type == PrimitiveType::Type::Table;
}
);
}
bool NormalizedType::hasTops() const
@ -449,7 +459,7 @@ bool NormalizedType::isFalsy() const
}
return (hasAFalse || hasNils()) && (!hasTops() && !hasClasses() && !hasErrors() && !hasNumbers() && !hasStrings() && !hasThreads() &&
!hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars());
!hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars());
}
bool NormalizedType::isTruthy() const
@ -806,7 +816,8 @@ static bool areNormalizedClasses(const NormalizedClassType& tys)
if (isSubclass(ctv, octv))
{
auto iss = [ctv](TypeId t) {
auto iss = [ctv](TypeId t)
{
const ClassType* c = get<ClassType>(t);
if (!c)
return false;
@ -970,7 +981,6 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>
NormalizedType norm{builtinTypes};
norm.tops = builtinTypes->anyType;
// Now we need to intersect the two types
Set<TypeId> seenSetTypes{nullptr};
for (auto ty : intersections)
{
NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet);
@ -1417,8 +1427,9 @@ std::optional<TypePackId> Normalizer::unionOfTypePacks(TypePackId here, TypePack
itt++;
}
auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere,
bool& thereSubHere) {
auto dealWithDifferentArities =
[&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere)
{
if (ith != end(here))
{
TypeId tty = builtinTypes->nilType;
@ -1803,8 +1814,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
}
else if (get<UnknownType>(here.tops))
return NormalizationResult::True;
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFunctionInstanceType>(there))
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there))
{
if (tyvarIndex(there) <= ignoreSmallerTyvars)
return NormalizationResult::True;
@ -2379,8 +2389,9 @@ std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, T
itt++;
}
auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere,
bool& thereSubHere) {
auto dealWithDifferentArities =
[&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere)
{
if (ith != end(here))
{
TypeId tty = builtinTypes->nilType;
@ -2570,7 +2581,7 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
}
}
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenSet);
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy);
// Cleanup
if (fixCyclicTablesBlowingStack())
@ -3088,8 +3099,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
}
return NormalizationResult::True;
}
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFunctionInstanceType>(there))
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there))
{
NormalizedType thereNorm{builtinTypes};
NormalizedType topNorm{builtinTypes};
@ -3441,7 +3451,12 @@ bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, Not
}
bool isConsistentSubtype(
TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter& ice
)
{
LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution);

View file

@ -13,8 +13,15 @@
namespace Luau
{
OverloadResolver::OverloadResolver(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Normalizer> normalizer, NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter, NotNull<TypeCheckLimits> limits, Location callLocation)
OverloadResolver::OverloadResolver(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits,
Location callLocation
)
: builtinTypes(builtinTypes)
, arena(arena)
, normalizer(normalizer)
@ -28,10 +35,15 @@ OverloadResolver::OverloadResolver(NotNull<BuiltinTypes> builtinTypes, NotNull<T
std::pair<OverloadResolver::Analysis, TypeId> OverloadResolver::selectOverload(TypeId ty, TypePackId argsPack)
{
auto tryOne = [&](TypeId f) {
auto tryOne = [&](TypeId f)
{
if (auto ftv = get<FunctionType>(f))
{
Subtyping::Variance variance = subtyping.variance;
subtyping.variance = Subtyping::Variance::Contravariant;
SubtypingResult r = subtyping.isSubtype(argsPack, ftv->argTypes);
subtyping.variance = variance;
if (r.isSubtype)
return true;
}
@ -137,7 +149,12 @@ std::optional<ErrorVec> OverloadResolver::testIsSubtype(const Location& location
}
std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload(
TypeId fnTy, const TypePack* args, AstExpr* fnLoc, const std::vector<AstExpr*>* argExprs, bool callMetamethodOk)
TypeId fnTy,
const TypePack* args,
AstExpr* fnLoc,
const std::vector<AstExpr*>* argExprs,
bool callMetamethodOk
)
{
fnTy = follow(fnTy);
@ -173,7 +190,12 @@ bool OverloadResolver::isLiteral(AstExpr* expr)
}
std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload_(
TypeId fnTy, const FunctionType* fn, const TypePack* args, AstExpr* fnExpr, const std::vector<AstExpr*>* argExprs)
TypeId fnTy,
const FunctionType* fn,
const TypePack* args,
AstExpr* fnExpr,
const std::vector<AstExpr*>* argExprs
)
{
FunctionGraphReductionResult result =
reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true);
@ -373,9 +395,17 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors)
// we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`.
// this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed.
std::optional<TypeId> selectOverload(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Normalizer> normalizer,
NotNull<Scope> scope, NotNull<InternalErrorReporter> iceReporter, NotNull<TypeCheckLimits> limits, const Location& location, TypeId fn,
TypePackId argsPack)
std::optional<TypeId> selectOverload(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
const Location& location,
TypeId fn,
TypePackId argsPack
)
{
OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location};
auto [status, overload] = resolver.selectOverload(fn, argsPack);
@ -389,9 +419,17 @@ std::optional<TypeId> selectOverload(NotNull<BuiltinTypes> builtinTypes, NotNull
return {};
}
SolveResult solveFunctionCall(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, const Location& location, TypeId fn,
TypePackId argsPack)
SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
const Location& location,
TypeId fn,
TypePackId argsPack
)
{
std::optional<TypeId> overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack);
if (!overloadToUse)

View file

@ -8,8 +8,6 @@
#include "Luau/Type.h"
#include "Luau/VisitType.h"
LUAU_FASTFLAG(DebugLuauSharedSelf)
namespace Luau
{
@ -100,53 +98,13 @@ struct Quantifier final : TypeOnceVisitor
void quantify(TypeId ty, TypeLevel level)
{
if (FFlag::DebugLuauSharedSelf)
{
ty = follow(ty);
Quantifier q{level};
q.traverse(ty);
if (auto ttv = getTableType(ty); ttv && ttv->selfTy)
{
Quantifier selfQ{level};
selfQ.traverse(*ttv->selfTy);
Quantifier q{level};
q.traverse(ty);
for (const auto& [_, prop] : ttv->props)
{
auto ftv = getMutable<FunctionType>(follow(prop.type()));
if (!ftv || !ftv->hasSelf)
continue;
if (Luau::first(ftv->argTypes) == ttv->selfTy)
{
ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end());
}
}
}
else if (auto ftv = getMutable<FunctionType>(ty))
{
Quantifier q{level};
q.traverse(ty);
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoFreeOrGenericTypes = true;
}
}
else
{
Quantifier q{level};
q.traverse(ty);
FunctionType* ftv = getMutable<FunctionType>(ty);
LUAU_ASSERT(ftv);
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
}
FunctionType* ftv = getMutable<FunctionType>(ty);
LUAU_ASSERT(ftv);
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
}
struct PureQuantifier : Substitution

View file

@ -238,12 +238,22 @@ Relation relateTables(TypeId left, TypeId right, SimplifierSeenSet& seen)
LUAU_ASSERT(1 == rightTable->props.size());
// Disjoint props have nothing in common
// t1 with props p1's cannot appear in t2 and t2 with props p2's cannot appear in t1
bool foundPropFromLeftInRight = std::any_of(begin(leftTable->props), end(leftTable->props), [&](auto prop) {
return rightTable->props.count(prop.first) > 0;
});
bool foundPropFromRightInLeft = std::any_of(begin(rightTable->props), end(rightTable->props), [&](auto prop) {
return leftTable->props.count(prop.first) > 0;
});
bool foundPropFromLeftInRight = std::any_of(
begin(leftTable->props),
end(leftTable->props),
[&](auto prop)
{
return rightTable->props.count(prop.first) > 0;
}
);
bool foundPropFromRightInLeft = std::any_of(
begin(rightTable->props),
end(rightTable->props),
[&](auto prop)
{
return leftTable->props.count(prop.first) > 0;
}
);
if (!foundPropFromLeftInRight && !foundPropFromRightInLeft && leftTable->props.size() >= 1 && rightTable->props.size() >= 1)
return Relation::Disjoint;
@ -1112,8 +1122,13 @@ std::optional<TypeId> TypeSimplifier::basicIntersect(TypeId left, TypeId right)
{
case Relation::Disjoint:
return builtinTypes->neverType;
case Relation::Superset:
case Relation::Coincident:
return right;
case Relation::Subset:
if (1 == rt->props.size())
return left;
break;
default:
break;
}
@ -1121,6 +1136,40 @@ std::optional<TypeId> TypeSimplifier::basicIntersect(TypeId left, TypeId right)
}
else if (1 == rt->props.size())
return basicIntersect(right, left);
// If two tables have disjoint properties and indexers, we can combine them.
if (!lt->indexer && !rt->indexer && lt->state == TableState::Sealed && rt->state == TableState::Sealed)
{
if (rt->props.empty())
return left;
bool areDisjoint = true;
for (const auto& [name, leftProp]: lt->props)
{
if (rt->props.count(name))
{
areDisjoint = false;
break;
}
}
if (areDisjoint)
{
TableType::Props mergedProps = lt->props;
for (const auto& [name, rightProp]: rt->props)
mergedProps[name] = rightProp;
return arena->addType(TableType{
mergedProps,
std::nullopt,
TypeLevel{},
lt->scope,
TableState::Sealed
});
}
}
return std::nullopt;
}
Relation relation = relate(left, right);

View file

@ -18,7 +18,8 @@ namespace Luau
static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone)
{
auto go = [ty, &dest, alwaysClone](auto&& a) {
auto go = [ty, &dest, alwaysClone](auto&& a)
{
using T = std::decay_t<decltype(a)>;
// The pointer identities of free and local types is very important.
@ -672,7 +673,8 @@ TypePackId Substitution::clone(TypePackId tp)
else if (const TypeFunctionInstanceTypePack* tfitp = get<TypeFunctionInstanceTypePack>(tp))
{
TypeFunctionInstanceTypePack clone{
tfitp->function, std::vector<TypeId>(tfitp->typeArguments.size()), std::vector<TypePackId>(tfitp->packArguments.size())};
tfitp->function, std::vector<TypeId>(tfitp->typeArguments.size()), std::vector<TypePackId>(tfitp->packArguments.size())
};
clone.typeArguments.assign(tfitp->typeArguments.begin(), tfitp->typeArguments.end());
clone.packArguments.assign(tfitp->packArguments.begin(), tfitp->packArguments.end());
return addTypePack(std::move(clone));

View file

@ -91,7 +91,8 @@ static SubtypingReasonings mergeReasonings(const SubtypingReasonings& a, const S
else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant)
{
SubtypingReasoning inverseReasoning = SubtypingReasoning{
r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant};
r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant
};
if (b.contains(inverseReasoning))
result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant});
else
@ -106,7 +107,8 @@ static SubtypingReasonings mergeReasonings(const SubtypingReasonings& a, const S
else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant)
{
SubtypingReasoning inverseReasoning = SubtypingReasoning{
r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant};
r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant
};
if (a.contains(inverseReasoning))
result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant});
else
@ -267,7 +269,11 @@ struct ApplyMappedGenerics : Substitution
ApplyMappedGenerics(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, MappedGenerics& mappedGenerics, MappedGenericPacks& mappedGenericPacks)
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
MappedGenerics& mappedGenerics,
MappedGenericPacks& mappedGenericPacks
)
: Substitution(TxnLog::empty(), arena)
, builtinTypes(builtinTypes)
, arena(arena)
@ -323,8 +329,13 @@ std::optional<TypeId> SubtypingEnvironment::applyMappedGenerics(NotNull<BuiltinT
return amg.substitute(ty);
}
Subtyping::Subtyping(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> typeArena, NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter, NotNull<Scope> scope)
Subtyping::Subtyping(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<Scope> scope
)
: builtinTypes(builtinTypes)
, arena(typeArena)
, normalizer(normalizer)
@ -1243,8 +1254,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl
std::vector<SubtypingResult> results;
if (auto subIter = subTable->props.find(name); subIter != subTable->props.end())
results.push_back(isCovariantWith(env, subIter->second, superProp, name));
if (subTable->indexer)
else if (subTable->indexer)
{
if (isCovariantWith(env, builtinTypes->stringType, subTable->indexer->indexType).isSubtype)
{
@ -1317,7 +1327,12 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Clas
}
SubtypingResult Subtyping::isCovariantWith(
SubtypingEnvironment& env, TypeId subTy, const ClassType* subClass, TypeId superTy, const TableType* superTable)
SubtypingEnvironment& env,
TypeId subTy,
const ClassType* subClass,
TypeId superTy,
const TableType* superTable
)
{
SubtypingResult result{true};
@ -1366,7 +1381,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Prim
{
if (auto stringTable = get<TableType>(it->second.type()))
result.orElse(
isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build()));
isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build())
);
}
}
}
@ -1388,7 +1404,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Sing
{
if (auto stringTable = get<TableType>(it->second.type()))
result.orElse(
isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build()));
isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build())
);
}
}
}
@ -1429,7 +1446,10 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Prop
}
SubtypingResult Subtyping::isCovariantWith(
SubtypingEnvironment& env, const std::shared_ptr<const NormalizedType>& subNorm, const std::shared_ptr<const NormalizedType>& superNorm)
SubtypingEnvironment& env,
const std::shared_ptr<const NormalizedType>& subNorm,
const std::shared_ptr<const NormalizedType>& superNorm
)
{
if (!subNorm || !superNorm)
return {false, true};
@ -1540,7 +1560,10 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Norm
}
SubtypingResult Subtyping::isCovariantWith(
SubtypingEnvironment& env, const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction)
SubtypingEnvironment& env,
const NormalizedFunctionType& subFunction,
const NormalizedFunctionType& superFunction
)
{
if (subFunction.isNever())
return {true};

View file

@ -13,8 +13,10 @@ namespace Luau
static bool isLiteral(const AstExpr* expr)
{
return (expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() ||
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>());
return (
expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() ||
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>()
);
}
// A fast approximation of subTy <: superTy
@ -108,9 +110,17 @@ static std::optional<TypeId> extractMatchingTableType(std::vector<TypeId>& table
return std::nullopt;
}
TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes, NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Unifier2> unifier, TypeId expectedType, TypeId exprType,
const AstExpr* expr, std::vector<TypeId>& toBlock)
TypeId matchLiteralType(
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Unifier2> unifier,
TypeId expectedType,
TypeId exprType,
const AstExpr* expr,
std::vector<TypeId>& toBlock
)
{
/*
* Table types that arise from literal table expressions have some
@ -208,7 +218,7 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
if (auto exprTable = expr->as<AstExprTable>())
{
TableType* tableTy = getMutable<TableType>(exprType);
TableType* const tableTy = getMutable<TableType>(exprType);
LUAU_ASSERT(tableTy);
const TableType* expectedTableTy = get<TableType>(expectedType);
@ -260,8 +270,17 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
(*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType;
(*astExpectedTypes)[item.value] = expectedTableTy->indexer->indexResultType;
TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier,
expectedTableTy->indexer->indexResultType, propTy, item.value, toBlock);
TypeId matchedType = matchLiteralType(
astTypes,
astExpectedTypes,
builtinTypes,
arena,
unifier,
expectedTableTy->indexer->indexResultType,
propTy,
item.value,
toBlock
);
if (tableTy->indexer)
unifier->unify(matchedType, tableTy->indexer->indexResultType);
@ -334,8 +353,17 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
LUAU_ASSERT(propTy);
unifier->unify(expectedTableTy->indexer->indexType, builtinTypes->numberType);
TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier,
expectedTableTy->indexer->indexResultType, *propTy, item.value, toBlock);
TypeId matchedType = matchLiteralType(
astTypes,
astExpectedTypes,
builtinTypes,
arena,
unifier,
expectedTableTy->indexer->indexResultType,
*propTy,
item.value,
toBlock
);
// if the index result type is the prop type, we can replace it with the matched type here.
if (tableTy->indexer->indexResultType == *propTy)
@ -410,6 +438,15 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
if (exprProp.readTy || exprProp.writeTy)
tableTy->props[*key] = std::move(exprProp);
}
// If the expected table has an indexer, then the provided table can
// have one too.
// TODO: If the expected table also has an indexer, we might want to
// push the expected indexer's types into it.
if (expectedTableTy->indexer && !tableTy->indexer)
{
tableTy->indexer = expectedTableTy->indexer;
}
}
return exprType;

View file

@ -146,7 +146,8 @@ void StateDot::visitChildren(TypeId ty, int index)
startNode(index);
startNodeLabel();
auto go = [&](auto&& t) {
auto go = [&](auto&& t)
{
using T = std::decay_t<decltype(t)>;
if constexpr (std::is_same_v<T, BoundType>)

View file

@ -168,7 +168,8 @@ struct StringifierState
DenseHashMap<TypeId, std::string> cycleNames{{}};
DenseHashMap<TypePackId, std::string> cycleTpNames{{}};
Set<void*> seen{{}};
// `$$$` was chosen as the tombstone for `usedNames` since it is not a valid name syntactically and is relatively short for string comparison reasons.
// `$$$` was chosen as the tombstone for `usedNames` since it is not a valid name syntactically and is relatively short for string comparison
// reasons.
DenseHashSet<std::string> usedNames{"$$$"};
size_t indentation = 0;
@ -356,10 +357,12 @@ struct TypeStringifier
}
Luau::visit(
[this, tv](auto&& t) {
[this, tv](auto&& t)
{
return (*this)(tv, t);
},
tv->ty);
tv->ty
);
}
void emitKey(const std::string& name)
@ -1104,10 +1107,12 @@ struct TypePackStringifier
}
Luau::visit(
[this, tp](auto&& t) {
[this, tp](auto&& t)
{
return (*this)(tp, t);
},
tp->ty);
tp->ty
);
}
void operator()(TypePackId, const TypePack& tp)
@ -1272,8 +1277,13 @@ void TypeStringifier::stringify(TypePackId tpid, const std::vector<std::optional
tps.stringify(tpid);
}
static void assignCycleNames(const std::set<TypeId>& cycles, const std::set<TypePackId>& cycleTPs, DenseHashMap<TypeId, std::string>& cycleNames,
DenseHashMap<TypePackId, std::string>& cycleTpNames, bool exhaustive)
static void assignCycleNames(
const std::set<TypeId>& cycles,
const std::set<TypePackId>& cycleTPs,
DenseHashMap<TypeId, std::string>& cycleNames,
DenseHashMap<TypePackId, std::string>& cycleTpNames,
bool exhaustive
)
{
int nextIndex = 1;
@ -1285,9 +1295,14 @@ static void assignCycleNames(const std::set<TypeId>& cycles, const std::set<Type
if (auto ttv = get<TableType>(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name))
{
// If we have a cycle type in type parameters, assign a cycle name for this named table
if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) {
return cycles.count(follow(el));
}) != ttv->instantiatedTypeParams.end())
if (std::find_if(
ttv->instantiatedTypeParams.begin(),
ttv->instantiatedTypeParams.end(),
[&](auto&& el)
{
return cycles.count(follow(el));
}
) != ttv->instantiatedTypeParams.end())
cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName;
continue;
@ -1381,9 +1396,14 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
state.exhaustive = true;
std::vector<std::pair<TypeId, std::string>> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()};
std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) {
return a.second < b.second;
});
std::sort(
sortedCycleNames.begin(),
sortedCycleNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second;
}
);
bool semi = false;
for (const auto& [cycleTy, name] : sortedCycleNames)
@ -1394,18 +1414,25 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
state.emit(name);
state.emit(" = ");
Luau::visit(
[&tvs, cycleTy = cycleTy](auto&& t) {
[&tvs, cycleTy = cycleTy](auto&& t)
{
return tvs(cycleTy, t);
},
cycleTy->ty);
cycleTy->ty
);
semi = true;
}
std::vector<std::pair<TypePackId, std::string>> sortedCycleTpNames(state.cycleTpNames.begin(), state.cycleTpNames.end());
std::sort(sortedCycleTpNames.begin(), sortedCycleTpNames.end(), [](const auto& a, const auto& b) {
return a.second < b.second;
});
std::sort(
sortedCycleTpNames.begin(),
sortedCycleTpNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second;
}
);
TypePackStringifier tps{state};
@ -1417,10 +1444,12 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
state.emit(name);
state.emit(" = ");
Luau::visit(
[&tps, cycleTy = cycleTp](auto&& t) {
[&tps, cycleTy = cycleTp](auto&& t)
{
return tps(cycleTy, t);
},
cycleTp->ty);
cycleTp->ty
);
semi = true;
}
@ -1474,9 +1503,14 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
state.exhaustive = true;
std::vector<std::pair<TypeId, std::string>> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()};
std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) {
return a.second < b.second;
});
std::sort(
sortedCycleNames.begin(),
sortedCycleNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second;
}
);
bool semi = false;
for (const auto& [cycleTy, name] : sortedCycleNames)
@ -1487,18 +1521,25 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
state.emit(name);
state.emit(" = ");
Luau::visit(
[&tvs, cycleTy = cycleTy](auto t) {
[&tvs, cycleTy = cycleTy](auto t)
{
return tvs(cycleTy, t);
},
cycleTy->ty);
cycleTy->ty
);
semi = true;
}
std::vector<std::pair<TypePackId, std::string>> sortedCycleTpNames{state.cycleTpNames.begin(), state.cycleTpNames.end()};
std::sort(sortedCycleTpNames.begin(), sortedCycleTpNames.end(), [](const auto& a, const auto& b) {
return a.second < b.second;
});
std::sort(
sortedCycleTpNames.begin(),
sortedCycleTpNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second;
}
);
TypePackStringifier tps{tvs.state};
@ -1510,10 +1551,12 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
state.emit(name);
state.emit(" = ");
Luau::visit(
[&tps, cycleTp = cycleTp](auto t) {
[&tps, cycleTp = cycleTp](auto t)
{
return tps(cycleTp, t);
},
cycleTp->ty);
cycleTp->ty
);
semi = true;
}
@ -1713,10 +1756,12 @@ std::string toStringVector(const std::vector<TypeId>& types, ToStringOptions& op
std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
auto go = [&opts](auto&& c) -> std::string {
auto go = [&opts](auto&& c) -> std::string
{
using T = std::decay_t<decltype(c)>;
auto tos = [&opts](auto&& a) {
auto tos = [&opts](auto&& a)
{
return toString(a, opts);
};

View file

@ -28,8 +28,8 @@ bool isIdentifierChar(char c)
return isIdentifierStartChar(c) || isDigit(c);
}
const std::vector<std::string> keywords = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil",
"not", "or", "repeat", "return", "then", "true", "until", "while"};
const std::vector<std::string> keywords = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in",
"local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"};
} // namespace
@ -844,6 +844,15 @@ struct Printer
visualizeTypeAnnotation(*a->type);
}
}
else if (const auto& t = program.as<AstStatTypeFunction>())
{
if (writeTypes)
{
writer.keyword("type function");
writer.identifier(t->name.value);
visualizeFunctionBody(*t->body);
}
}
else if (const auto& a = program.as<AstStatError>())
{
writer.symbol("(error-stat");

View file

@ -469,34 +469,44 @@ std::optional<TypeLevel> TxnLog::getLevel(TypeId ty) const
TypeId TxnLog::follow(TypeId ty) const
{
return Luau::follow(ty, this, [](const void* ctx, TypeId ty) -> TypeId {
const TxnLog* self = static_cast<const TxnLog*>(ctx);
PendingType* state = self->pending(ty);
return Luau::follow(
ty,
this,
[](const void* ctx, TypeId ty) -> TypeId
{
const TxnLog* self = static_cast<const TxnLog*>(ctx);
PendingType* state = self->pending(ty);
if (state == nullptr)
return ty;
if (state == nullptr)
return ty;
// Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants
// that normally apply. This is safe because follow will only call get<>
// on the returned pointer.
return const_cast<const Type*>(&state->pending);
});
// Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants
// that normally apply. This is safe because follow will only call get<>
// on the returned pointer.
return const_cast<const Type*>(&state->pending);
}
);
}
TypePackId TxnLog::follow(TypePackId tp) const
{
return Luau::follow(tp, this, [](const void* ctx, TypePackId tp) -> TypePackId {
const TxnLog* self = static_cast<const TxnLog*>(ctx);
PendingTypePack* state = self->pending(tp);
return Luau::follow(
tp,
this,
[](const void* ctx, TypePackId tp) -> TypePackId
{
const TxnLog* self = static_cast<const TxnLog*>(ctx);
PendingTypePack* state = self->pending(tp);
if (state == nullptr)
return tp;
if (state == nullptr)
return tp;
// Ugly: Fabricate a TypePackId that doesn't adhere to most of the
// invariants that normally apply. This is safe because follow will
// only call get<> on the returned pointer.
return const_cast<const TypePackVar*>(&state->pending);
});
// Ugly: Fabricate a TypePackId that doesn't adhere to most of the
// invariants that normally apply. This is safe because follow will
// only call get<> on the returned pointer.
return const_cast<const TypePackVar*>(&state->pending);
}
);
}
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TxnLog::getChanges() const

View file

@ -58,9 +58,15 @@ TypeId follow(TypeId t)
TypeId follow(TypeId t, FollowOption followOption)
{
return follow(t, followOption, nullptr, [](const void*, TypeId t) -> TypeId {
return t;
});
return follow(
t,
followOption,
nullptr,
[](const void*, TypeId t) -> TypeId
{
return t;
}
);
}
TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId))
@ -70,7 +76,8 @@ TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeI
TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId))
{
auto advance = [followOption, context, mapper](TypeId ty) -> std::optional<TypeId> {
auto advance = [followOption, context, mapper](TypeId ty) -> std::optional<TypeId>
{
TypeId mapped = mapper(context, ty);
if (auto btv = get<Unifiable::Bound<TypeId>>(mapped))
@ -259,7 +266,8 @@ bool isOverloadedFunction(TypeId ty)
if (!get<IntersectionType>(follow(ty)))
return false;
auto isFunction = [](TypeId part) -> bool {
auto isFunction = [](TypeId part) -> bool
{
return get<FunctionType>(part);
};
@ -567,7 +575,11 @@ void BlockedType::replaceOwner(Constraint* newOwner)
}
PendingExpansionType::PendingExpansionType(
std::optional<AstName> prefix, AstName name, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
std::optional<AstName> prefix,
AstName name,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments
)
: prefix(prefix)
, name(name)
, typeArguments(typeArguments)
@ -596,7 +608,13 @@ FunctionType::FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retT
}
FunctionType::FunctionType(
TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf)
TypeLevel level,
Scope* scope,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn,
bool hasSelf
)
: definition(std::move(defn))
, level(level)
, scope(scope)
@ -606,8 +624,14 @@ FunctionType::FunctionType(
{
}
FunctionType::FunctionType(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes,
std::optional<FunctionDefinition> defn, bool hasSelf)
FunctionType::FunctionType(
std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn,
bool hasSelf
)
: definition(std::move(defn))
, generics(generics)
, genericPacks(genericPacks)
@ -617,8 +641,15 @@ FunctionType::FunctionType(std::vector<TypeId> generics, std::vector<TypePackId>
{
}
FunctionType::FunctionType(TypeLevel level, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes,
TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf)
FunctionType::FunctionType(
TypeLevel level,
std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn,
bool hasSelf
)
: definition(std::move(defn))
, generics(generics)
, genericPacks(genericPacks)
@ -629,8 +660,16 @@ FunctionType::FunctionType(TypeLevel level, std::vector<TypeId> generics, std::v
{
}
FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes,
TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf)
FunctionType::FunctionType(
TypeLevel level,
Scope* scope,
std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn,
bool hasSelf
)
: definition(std::move(defn))
, generics(generics)
, genericPacks(genericPacks)
@ -644,8 +683,15 @@ FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector<TypeId> ge
Property::Property() {}
Property::Property(TypeId readTy, bool deprecated, const std::string& deprecatedSuggestion, std::optional<Location> location, const Tags& tags,
const std::optional<std::string>& documentationSymbol, std::optional<Location> typeLocation)
Property::Property(
TypeId readTy,
bool deprecated,
const std::string& deprecatedSuggestion,
std::optional<Location> location,
const Tags& tags,
const std::optional<std::string>& documentationSymbol,
std::optional<Location> typeLocation
)
: deprecated(deprecated)
, deprecatedSuggestion(deprecatedSuggestion)
, location(location)
@ -953,9 +999,15 @@ Type& Type::operator=(const Type& rhs)
return *this;
}
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes);
TypeId makeFunction(
TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes
);
TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes); // BuiltinDefinitions.cpp

View file

@ -166,7 +166,8 @@ public:
}
return allocator->alloc<AstTypeReference>(
Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location(), parameters.size != 0, parameters);
Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location(), parameters.size != 0, parameters
);
}
if (hasSeen(&ttv))
@ -319,7 +320,8 @@ public:
retTailAnnotation = rehydrate(*retTail);
return allocator->alloc<AstTypeFunction>(
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation});
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}
);
}
AstType* operator()(const Unifiable::Error&)
{
@ -328,7 +330,8 @@ public:
AstType* operator()(const GenericType& gtv)
{
return allocator->alloc<AstTypeReference>(
Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)), std::nullopt, Location());
Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)), std::nullopt, Location()
);
}
AstType* operator()(const Unifiable::Bound<TypeId>& bound)
{

View file

@ -442,8 +442,12 @@ struct TypeChecker2
return instance;
seenTypeFunctionInstances.insert(instance);
ErrorVec errors = reduceTypeFunctions(instance, location,
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
ErrorVec errors = reduceTypeFunctions(
instance,
location,
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits},
true
)
.errors;
if (!isErrorSuppressing(location, instance))
reportErrors(std::move(errors));
@ -488,7 +492,8 @@ struct TypeChecker2
{
TypeId argTy = lookupAnnotation(ref->parameters.data[0].type);
luauPrintLine(format(
"_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str()));
"_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str()
));
return follow(argTy);
}
}
@ -597,6 +602,8 @@ struct TypeChecker2
return visit(s);
else if (auto s = stat->as<AstStatTypeAlias>())
return visit(s);
else if (auto f = stat->as<AstStatTypeFunction>())
return visit(f);
else if (auto s = stat->as<AstStatDeclareFunction>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareGlobal>())
@ -728,7 +735,8 @@ struct TypeChecker2
local->values.data[local->values.size - 1]->is<AstExprCall>() ? CountMismatch::FunctionResult
: CountMismatch::ExprListResult,
},
errorLocation);
errorLocation
);
}
}
}
@ -744,7 +752,8 @@ struct TypeChecker2
testIsSubtype(builtinTypes->numberType, annotatedType, forStatement->var->location);
}
auto checkNumber = [this](AstExpr* expr) {
auto checkNumber = [this](AstExpr* expr)
{
if (!expr)
return;
@ -839,7 +848,8 @@ struct TypeChecker2
}
TypeId iteratorTy = follow(iteratorTypes.head[0]);
auto checkFunction = [this, &arena, &forInStatement, &variableTypes](const FunctionType* iterFtv, std::vector<TypeId> iterTys, bool isMm) {
auto checkFunction = [this, &arena, &forInStatement, &variableTypes](const FunctionType* iterFtv, std::vector<TypeId> iterTys, bool isMm)
{
if (iterTys.size() < 1 || iterTys.size() > 3)
{
if (isMm)
@ -856,7 +866,8 @@ struct TypeChecker2
{
if (isMm)
reportError(
GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values));
GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)
);
else
reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location);
}
@ -1143,6 +1154,13 @@ struct TypeChecker2
visit(stat->type);
}
void visit(AstStatTypeFunction* stat)
{
// TODO: add type checking for user-defined type functions
reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}});
}
void visit(AstTypeList types)
{
for (AstType* type : types.types)
@ -1349,11 +1367,6 @@ struct TypeChecker2
args.head.push_back(lookupType(indexExpr->expr));
argExprs.push_back(indexExpr->expr);
}
else if (findMetatableEntry(builtinTypes, module->errors, *originalCallTy, "__call", call->func->location))
{
args.head.insert(args.head.begin(), lookupType(call->func));
argExprs.push_back(call->func);
}
for (size_t i = 0; i < call->args.size; ++i)
{
@ -1698,12 +1711,17 @@ struct TypeChecker2
// together. For now, this will work.
reportError(
GenericError{format(
"Parameter '%s' has been reduced to never. This function is not callable with any possible value.", arg->name.value)},
arg->location);
"Parameter '%s' has been reduced to never. This function is not callable with any possible value.", arg->name.value
)},
arg->location
);
for (const auto& [site, component] : *contributors)
reportError(ExtraInformation{format("Parameter '%s' is required to be a subtype of '%s' here.", arg->name.value,
toString(component).c_str())},
site);
reportError(
ExtraInformation{
format("Parameter '%s' is required to be a subtype of '%s' here.", arg->name.value, toString(component).c_str())
},
site
);
}
}
@ -1739,8 +1757,10 @@ struct TypeChecker2
{
TypeFunctionReductionGuessResult result = guesser.guessTypeFunctionReductionForFunctionExpr(*fn, inferredFtv, retTy);
if (result.shouldRecommendAnnotation)
reportError(ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType},
fn->location);
reportError(
ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType},
fn->location
);
}
}
}
@ -1881,9 +1901,12 @@ struct TypeChecker2
if ((get<BlockedType>(leftType) || get<FreeType>(leftType) || get<GenericType>(leftType)) && !isEquality && !isLogical)
{
auto name = getIdentifierOfBaseVar(expr->left);
reportError(CannotInferBinaryOperation{expr->op, name,
isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation},
expr->location);
reportError(
CannotInferBinaryOperation{
expr->op, name, isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation
},
expr->location
);
return leftType;
}
@ -1897,7 +1920,8 @@ struct TypeChecker2
if (isEquality && !matches)
{
auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional<TypeId> otherMt) {
auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional<TypeId> otherMt)
{
for (TypeId option : utv)
{
if (getMetatable(follow(option), builtinTypes) == otherMt)
@ -1929,9 +1953,15 @@ struct TypeChecker2
if (!matches && isComparison)
{
reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
reportError(
GenericError{format(
"Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(leftType).c_str(),
toString(rightType).c_str(),
toString(expr->op).c_str()
)},
expr->location
);
return builtinTypes->errorRecoveryType();
}
@ -2034,17 +2064,29 @@ struct TypeChecker2
{
if (isComparison)
{
reportError(GenericError{format(
"Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)},
expr->location);
reportError(
GenericError{format(
"Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod",
toString(leftType).c_str(),
toString(rightType).c_str(),
toString(expr->op).c_str(),
it->second
)},
expr->location
);
}
else
{
reportError(GenericError{format(
"Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod",
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)},
expr->location);
reportError(
GenericError{format(
"Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod",
toString(expr->op).c_str(),
toString(leftType).c_str(),
toString(rightType).c_str(),
it->second
)},
expr->location
);
}
return builtinTypes->errorRecoveryType();
@ -2053,15 +2095,27 @@ struct TypeChecker2
{
if (isComparison)
{
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
reportError(
GenericError{format(
"Types '%s' and '%s' cannot be compared with %s because neither type has a metatable",
toString(leftType).c_str(),
toString(rightType).c_str(),
toString(expr->op).c_str()
)},
expr->location
);
}
else
{
reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable",
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())},
expr->location);
reportError(
GenericError{format(
"Operator %s is not applicable for '%s' and '%s' because neither type has a metatable",
toString(expr->op).c_str(),
toString(leftType).c_str(),
toString(rightType).c_str()
)},
expr->location
);
}
return builtinTypes->errorRecoveryType();
@ -2111,9 +2165,15 @@ struct TypeChecker2
return builtinTypes->booleanType;
}
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(),
toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
reportError(
GenericError{format(
"Types '%s' and '%s' cannot be compared with relational operator %s",
toString(leftType).c_str(),
toString(rightType).c_str(),
toString(expr->op).c_str()
)},
expr->location
);
return builtinTypes->errorRecoveryType();
}
@ -2297,13 +2357,23 @@ struct TypeChecker2
size_t typesRequired = alias->typeParams.size();
size_t packsRequired = alias->typePackParams.size();
bool hasDefaultTypes = std::any_of(alias->typeParams.begin(), alias->typeParams.end(), [](auto&& el) {
return el.defaultValue.has_value();
});
bool hasDefaultTypes = std::any_of(
alias->typeParams.begin(),
alias->typeParams.end(),
[](auto&& el)
{
return el.defaultValue.has_value();
}
);
bool hasDefaultPacks = std::any_of(alias->typePackParams.begin(), alias->typePackParams.end(), [](auto&& el) {
return el.defaultValue.has_value();
});
bool hasDefaultPacks = std::any_of(
alias->typePackParams.begin(),
alias->typePackParams.end(),
[](auto&& el)
{
return el.defaultValue.has_value();
}
);
if (!ty->hasParameterList)
{
@ -2385,13 +2455,15 @@ struct TypeChecker2
if (typesProvided != typesRequired || packsProvided != packsRequired)
{
reportError(IncorrectGenericParameterCount{
/* name */ ty->name.value,
/* typeFun */ *alias,
/* actualParameters */ typesProvided,
/* actualPackParameters */ packsProvided,
},
ty->location);
reportError(
IncorrectGenericParameterCount{
/* name */ ty->name.value,
/* typeFun */ *alias,
/* actualParameters */ typesProvided,
/* actualPackParameters */ packsProvided,
},
ty->location
);
}
}
else
@ -2403,7 +2475,8 @@ struct TypeChecker2
ty->name.value,
SwappedGenericTypeParameter::Kind::Type,
},
ty->location);
ty->location
);
}
else
{
@ -2501,7 +2574,8 @@ struct TypeChecker2
tp->genericName.value,
SwappedGenericTypeParameter::Kind::Pack,
},
tp->location);
tp->location
);
}
else
{
@ -2715,8 +2789,14 @@ struct TypeChecker2
* contains the prop, and
* * A vector of types that do not contain the prop.
*/
PropertyTypes lookupProp(const NormalizedType* norm, const std::string& prop, ValueContext context, const Location& location,
TypeId astIndexExprType, std::vector<TypeError>& errors)
PropertyTypes lookupProp(
const NormalizedType* norm,
const std::string& prop,
ValueContext context,
const Location& location,
TypeId astIndexExprType,
std::vector<TypeError>& errors
)
{
std::vector<TypeId> typesOfProp;
std::vector<TypeId> typesMissingTheProp;
@ -2724,7 +2804,8 @@ struct TypeChecker2
// this is `false` if we ever hit the resource limits during any of our uses of `fetch`.
bool normValid = true;
auto fetch = [&](TypeId ty) {
auto fetch = [&](TypeId ty)
{
NormalizationResult result = normalizer.isInhabited(ty);
if (result == NormalizationResult::HitLimits)
normValid = false;
@ -2875,8 +2956,15 @@ struct TypeChecker2
std::optional<TypeId> result;
};
PropertyType hasIndexTypeFromType(TypeId ty, const std::string& prop, ValueContext context, const Location& location, DenseHashSet<TypeId>& seen,
TypeId astIndexExprType, std::vector<TypeError>& errors)
PropertyType hasIndexTypeFromType(
TypeId ty,
const std::string& prop,
ValueContext context,
const Location& location,
DenseHashSet<TypeId>& seen,
TypeId astIndexExprType,
std::vector<TypeError>& errors
)
{
// If we have already encountered this type, we must assume that some
// other codepath will do the right thing and signal false if the
@ -2982,7 +3070,8 @@ struct TypeChecker2
std::string_view sv(utk->key);
std::set<Name> candidates;
auto accumulate = [&](const TableType::Props& props) {
auto accumulate = [&](const TableType::Props& props)
{
for (const auto& [name, ty] : props)
{
if (sv != name && equalsLower(sv, name))
@ -3055,8 +3144,14 @@ struct TypeChecker2
}
};
void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> unifierState, NotNull<TypeCheckLimits> limits, DcrLogger* logger,
const SourceModule& sourceModule, Module* module)
void check(
NotNull<BuiltinTypes> builtinTypes,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
DcrLogger* logger,
const SourceModule& sourceModule,
Module* module
)
{
LUAU_TIMETRACE_SCOPE("check", "Typechecking");
@ -3064,6 +3159,12 @@ void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> unifi
typeChecker.visit(sourceModule.root);
// if the only error we're producing is one about constraint solving being incomplete, we can silence it.
// this means we won't give this warning if types seem totally nonsensical, but there are no other errors.
// this is probably, on the whole, a good decision to not annoy users though.
if (module->errors.size() == 1 && get<ConstraintSolvingIncompleteError>(module->errors[0]))
module->errors.clear();
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes, builtinTypes);
freeze(module->interfaceTypes);

View file

@ -112,8 +112,15 @@ struct TypeFunctionReducer
// Local to the constraint being reduced.
Location location;
TypeFunctionReducer(VecDeque<TypeId> queuedTys, VecDeque<TypePackId> queuedTps, TypeOrTypePackIdSet shouldGuess, std::vector<TypeId> cyclicTypes,
Location location, TypeFunctionContext ctx, bool force = false)
TypeFunctionReducer(
VecDeque<TypeId> queuedTys,
VecDeque<TypePackId> queuedTps,
TypeOrTypePackIdSet shouldGuess,
std::vector<TypeId> cyclicTypes,
Location location,
TypeFunctionContext ctx,
bool force = false
)
: ctx(ctx)
, queuedTys(std::move(queuedTys))
, queuedTps(std::move(queuedTps))
@ -218,8 +225,12 @@ struct TypeFunctionReducer
else if (!reduction.uninhabited && !force)
{
if (FFlag::DebugLuauLogTypeFamilies)
printf("%s is irreducible; blocked on %zu types, %zu packs\n", toString(subject, {true}).c_str(), reduction.blockedTypes.size(),
reduction.blockedPacks.size());
printf(
"%s is irreducible; blocked on %zu types, %zu packs\n",
toString(subject, {true}).c_str(),
reduction.blockedTypes.size(),
reduction.blockedPacks.size()
);
for (TypeId b : reduction.blockedTypes)
result.blockedTypes.insert(b);
@ -371,7 +382,8 @@ struct TypeFunctionReducer
if (tryGuessing(subject))
return;
TypeFunctionReductionResult<TypePackId> result = tfit->function->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx});
TypeFunctionReductionResult<TypePackId> result =
tfit->function->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx});
handleTypeFunctionReduction(subject, result);
}
}
@ -385,8 +397,15 @@ struct TypeFunctionReducer
}
};
static FunctionGraphReductionResult reduceFunctionsInternal(VecDeque<TypeId> queuedTys, VecDeque<TypePackId> queuedTps, TypeOrTypePackIdSet shouldGuess,
std::vector<TypeId> cyclics, Location location, TypeFunctionContext ctx, bool force)
static FunctionGraphReductionResult reduceFunctionsInternal(
VecDeque<TypeId> queuedTys,
VecDeque<TypePackId> queuedTps,
TypeOrTypePackIdSet shouldGuess,
std::vector<TypeId> cyclics,
Location location,
TypeFunctionContext ctx,
bool force
)
{
TypeFunctionReducer reducer{std::move(queuedTys), std::move(queuedTps), std::move(shouldGuess), std::move(cyclics), location, ctx, force};
int iterationCount = 0;
@ -422,8 +441,15 @@ FunctionGraphReductionResult reduceTypeFunctions(TypeId entrypoint, Location loc
if (collector.tys.empty() && collector.tps.empty())
return {};
return reduceFunctionsInternal(std::move(collector.tys), std::move(collector.tps), std::move(collector.shouldGuess),
std::move(collector.cyclicInstance), location, ctx, force);
return reduceFunctionsInternal(
std::move(collector.tys),
std::move(collector.tps),
std::move(collector.shouldGuess),
std::move(collector.cyclicInstance),
location,
ctx,
force
);
}
FunctionGraphReductionResult reduceTypeFunctions(TypePackId entrypoint, Location location, TypeFunctionContext ctx, bool force)
@ -442,8 +468,15 @@ FunctionGraphReductionResult reduceTypeFunctions(TypePackId entrypoint, Location
if (collector.tys.empty() && collector.tps.empty())
return {};
return reduceFunctionsInternal(std::move(collector.tys), std::move(collector.tps), std::move(collector.shouldGuess),
std::move(collector.cyclicInstance), location, ctx, force);
return reduceFunctionsInternal(
std::move(collector.tys),
std::move(collector.tps),
std::move(collector.shouldGuess),
std::move(collector.cyclicInstance),
location,
ctx,
force
);
}
bool isPending(TypeId ty, ConstraintSolver* solver)
@ -452,8 +485,14 @@ bool isPending(TypeId ty, ConstraintSolver* solver)
}
template<typename F, typename... Args>
static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunctionApp(F f, TypeId instance, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, Args&&... args)
static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunctionApp(
F f,
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
Args&&... args
)
{
// op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d)
bool uninhabited = false;
@ -529,7 +568,11 @@ static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunct
}
TypeFunctionReductionResult<TypeId> notTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 1 || !packParams.empty())
{
@ -553,7 +596,11 @@ TypeFunctionReductionResult<TypeId> notTypeFunction(
}
TypeFunctionReductionResult<TypeId> lenTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 1 || !packParams.empty())
{
@ -645,7 +692,11 @@ TypeFunctionReductionResult<TypeId> lenTypeFunction(
}
TypeFunctionReductionResult<TypeId> unmTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 1 || !packParams.empty())
{
@ -744,8 +795,13 @@ NotNull<Constraint> TypeFunctionContext::pushConstraint(ConstraintV&& c)
return newConstraint;
}
TypeFunctionReductionResult<TypeId> numericBinopTypeFunction(TypeId instance, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, const std::string metamethod)
TypeFunctionReductionResult<TypeId> numericBinopTypeFunction(
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
const std::string metamethod
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -848,7 +904,11 @@ TypeFunctionReductionResult<TypeId> numericBinopTypeFunction(TypeId instance, co
}
TypeFunctionReductionResult<TypeId> addTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -860,7 +920,11 @@ TypeFunctionReductionResult<TypeId> addTypeFunction(
}
TypeFunctionReductionResult<TypeId> subTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -872,7 +936,11 @@ TypeFunctionReductionResult<TypeId> subTypeFunction(
}
TypeFunctionReductionResult<TypeId> mulTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -884,7 +952,11 @@ TypeFunctionReductionResult<TypeId> mulTypeFunction(
}
TypeFunctionReductionResult<TypeId> divTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -896,7 +968,11 @@ TypeFunctionReductionResult<TypeId> divTypeFunction(
}
TypeFunctionReductionResult<TypeId> idivTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -908,7 +984,11 @@ TypeFunctionReductionResult<TypeId> idivTypeFunction(
}
TypeFunctionReductionResult<TypeId> powTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -920,7 +1000,11 @@ TypeFunctionReductionResult<TypeId> powTypeFunction(
}
TypeFunctionReductionResult<TypeId> modTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -932,7 +1016,11 @@ TypeFunctionReductionResult<TypeId> modTypeFunction(
}
TypeFunctionReductionResult<TypeId> concatTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -1040,7 +1128,11 @@ TypeFunctionReductionResult<TypeId> concatTypeFunction(
}
TypeFunctionReductionResult<TypeId> andTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -1091,7 +1183,11 @@ TypeFunctionReductionResult<TypeId> andTypeFunction(
}
TypeFunctionReductionResult<TypeId> orTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -1141,8 +1237,13 @@ TypeFunctionReductionResult<TypeId> orTypeFunction(
return {overallResult.result, false, std::move(blockedTypes), {}};
}
static TypeFunctionReductionResult<TypeId> comparisonTypeFunction(TypeId instance, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, const std::string metamethod)
static TypeFunctionReductionResult<TypeId> comparisonTypeFunction(
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
const std::string metamethod
)
{
if (typeParams.size() != 2 || !packParams.empty())
@ -1281,7 +1382,11 @@ static TypeFunctionReductionResult<TypeId> comparisonTypeFunction(TypeId instanc
}
TypeFunctionReductionResult<TypeId> ltTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -1293,7 +1398,11 @@ TypeFunctionReductionResult<TypeId> ltTypeFunction(
}
TypeFunctionReductionResult<TypeId> leTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -1305,7 +1414,11 @@ TypeFunctionReductionResult<TypeId> leTypeFunction(
}
TypeFunctionReductionResult<TypeId> eqTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -1436,7 +1549,11 @@ struct FindRefinementBlockers : TypeOnceVisitor
TypeFunctionReductionResult<TypeId> refineTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -1521,7 +1638,11 @@ TypeFunctionReductionResult<TypeId> refineTypeFunction(
}
TypeFunctionReductionResult<TypeId> singletonTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 1 || !packParams.empty())
{
@ -1558,7 +1679,11 @@ TypeFunctionReductionResult<TypeId> singletonTypeFunction(
}
TypeFunctionReductionResult<TypeId> unionTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (!packParams.empty())
{
@ -1619,7 +1744,11 @@ TypeFunctionReductionResult<TypeId> unionTypeFunction(
TypeFunctionReductionResult<TypeId> intersectTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (!packParams.empty())
{
@ -1726,7 +1855,11 @@ bool computeKeysOf(TypeId ty, Set<std::string>& result, DenseHashSet<TypeId>& se
}
TypeFunctionReductionResult<TypeId> keyofFunctionImpl(
const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, bool isRaw)
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
bool isRaw
)
{
if (typeParams.size() != 1 || !packParams.empty())
{
@ -1843,7 +1976,11 @@ TypeFunctionReductionResult<TypeId> keyofFunctionImpl(
}
TypeFunctionReductionResult<TypeId> keyofTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 1 || !packParams.empty())
{
@ -1855,7 +1992,11 @@ TypeFunctionReductionResult<TypeId> keyofTypeFunction(
}
TypeFunctionReductionResult<TypeId> rawkeyofTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 1 || !packParams.empty())
{
@ -1870,7 +2011,12 @@ TypeFunctionReductionResult<TypeId> rawkeyofTypeFunction(
If found, appends that property to `result` and returns true
Else, returns false */
bool searchPropsAndIndexer(
TypeId ty, TableType::Props tblProps, std::optional<TableIndexer> tblIndexer, DenseHashSet<TypeId>& result, NotNull<TypeFunctionContext> ctx)
TypeId ty,
TableType::Props tblProps,
std::optional<TableIndexer> tblIndexer,
DenseHashSet<TypeId>& result,
NotNull<TypeFunctionContext> ctx
)
{
ty = follow(ty);
@ -1961,7 +2107,11 @@ bool tblIndexInto(TypeId indexer, TypeId indexee, DenseHashSet<TypeId>& result,
indexer refers to the type that is used to access indexee
Example: index<Person, "name"> => `Person` is the indexee and `"name"` is the indexer */
TypeFunctionReductionResult<TypeId> indexFunctionImpl(
const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, bool isRaw)
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
bool isRaw
)
{
TypeId indexeeTy = follow(typeParams.at(0));
std::shared_ptr<const NormalizedType> indexeeNormTy = ctx->normalizer->normalize(indexeeTy);
@ -2053,9 +2203,15 @@ TypeFunctionReductionResult<TypeId> indexFunctionImpl(
}
// Call `follow()` on each element to resolve all Bound types before returning
std::transform(properties.begin(), properties.end(), properties.begin(), [](TypeId ty) {
return follow(ty);
});
std::transform(
properties.begin(),
properties.end(),
properties.begin(),
[](TypeId ty)
{
return follow(ty);
}
);
// If the type being reduced to is a single type, no need to union
if (properties.size() == 1)
@ -2065,7 +2221,11 @@ TypeFunctionReductionResult<TypeId> indexFunctionImpl(
}
TypeFunctionReductionResult<TypeId> indexTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -2077,7 +2237,11 @@ TypeFunctionReductionResult<TypeId> indexTypeFunction(
}
TypeFunctionReductionResult<TypeId> rawgetTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx)
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{
if (typeParams.size() != 2 || !packParams.empty())
{
@ -2119,7 +2283,8 @@ BuiltinTypeFunctions::BuiltinTypeFunctions()
void BuiltinTypeFunctions::addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const
{
// make a type function for a one-argument type function
auto mkUnaryTypeFunction = [&](const TypeFunction* tf) {
auto mkUnaryTypeFunction = [&](const TypeFunction* tf)
{
TypeId t = arena->addType(GenericType{"T"});
GenericTypeDefinition genericT{t};
@ -2127,7 +2292,8 @@ void BuiltinTypeFunctions::addToScope(NotNull<TypeArena> arena, NotNull<Scope> s
};
// make a type function for a two-argument type function
auto mkBinaryTypeFunction = [&](const TypeFunction* tf) {
auto mkBinaryTypeFunction = [&](const TypeFunction* tf)
{
TypeId t = arena->addType(GenericType{"T"});
TypeId u = arena->addType(GenericType{"U"});
GenericTypeDefinition genericT{t};

View file

@ -128,7 +128,10 @@ std::optional<TypePackId> TypeFunctionReductionGuesser::guess(TypePackId tp)
}
TypeFunctionReductionGuessResult TypeFunctionReductionGuesser::guessTypeFunctionReductionForFunctionExpr(
const AstExprFunction& expr, const FunctionType* ftv, TypeId retTy)
const AstExprFunction& expr,
const FunctionType* ftv,
TypeId retTy
)
{
InstanceCollector2 collector;
collector.traverse(retTy);
@ -204,8 +207,9 @@ std::optional<TypeId> TypeFunctionReductionGuesser::guessType(TypeId arg)
bool TypeFunctionReductionGuesser::isNumericBinopFunction(const TypeFunctionInstanceType& instance)
{
return instance.function->name == "add" || instance.function->name == "sub" || instance.function->name == "mul" || instance.function->name == "div" ||
instance.function->name == "idiv" || instance.function->name == "pow" || instance.function->name == "mod";
return instance.function->name == "add" || instance.function->name == "sub" || instance.function->name == "mul" ||
instance.function->name == "div" || instance.function->name == "idiv" || instance.function->name == "pow" ||
instance.function->name == "mod";
}
bool TypeFunctionReductionGuesser::isComparisonFunction(const TypeFunctionInstanceType& instance)
@ -350,7 +354,8 @@ TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferComparisonFunctio
TypeId lhsTy = follow(instance->typeArguments[0]);
TypeId rhsTy = follow(instance->typeArguments[1]);
auto comparisonInference = [&](TypeId op) -> TypeFunctionInferenceResult {
auto comparisonInference = [&](TypeId op) -> TypeFunctionInferenceResult
{
return TypeFunctionInferenceResult{{op, op}, builtins->booleanType};
};

View file

@ -31,9 +31,7 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300)
LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false)
LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false)
LUAU_FASTFLAGVARIABLE(LuauReusableSubstitutions, false)
@ -294,13 +292,6 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
currentModule->cancelled = true;
}
if (FFlag::DebugLuauSharedSelf)
{
for (auto& [ty, scope] : deferredQuantification)
Luau::quantify(ty, scope->level);
deferredQuantification.clear();
}
if (get<FreeTypePack>(follow(moduleScope->returnType)))
moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt});
else
@ -379,6 +370,8 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program)
ice("Should not be calling two-argument check() on a function statement", program.location);
else if (auto typealias = program.as<AstStatTypeAlias>())
return check(scope, *typealias);
else if (auto typefunction = program.as<AstStatTypeFunction>())
return check(scope, *typefunction);
else if (auto global = program.as<AstStatDeclareGlobal>())
{
TypeId globalType = resolveType(scope, *global->type);
@ -517,7 +510,8 @@ ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope,
std::unordered_map<AstStat*, std::pair<TypeId, ScopePtr>> functionDecls;
auto checkBody = [&](AstStat* stat) {
auto checkBody = [&](AstStat* stat)
{
if (auto fun = stat->as<AstStatFunction>())
{
LUAU_ASSERT(functionDecls.count(stat));
@ -581,39 +575,15 @@ ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope,
}
else if (auto fun = (*protoIter)->as<AstStatFunction>())
{
std::optional<TypeId> selfType;
std::optional<TypeId> selfType; // TODO clip
std::optional<TypeId> expectedType;
if (FFlag::DebugLuauSharedSelf)
if (!fun->func->self)
{
if (auto name = fun->name->as<AstExprIndexName>())
{
TypeId baseTy = checkExpr(scope, *name->expr).type;
tablify(baseTy);
if (!fun->func->self)
expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, /* addErrors= */ false);
else if (auto ttv = getMutableTableType(baseTy))
{
if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy)
{
ttv->selfTy = anyIfNonstrict(freshType(ttv->level));
deferredQuantification.push_back({baseTy, scope});
}
selfType = ttv->selfTy;
}
}
}
else
{
if (!fun->func->self)
{
if (auto name = fun->name->as<AstExprIndexName>())
{
TypeId exprTy = checkExpr(scope, *name->expr).type;
expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false);
}
TypeId exprTy = checkExpr(scope, *name->expr).type;
expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false);
}
}
@ -1563,14 +1533,26 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty
// Additionally, we can't modify types that come from other modules
if (ttv->name || follow(ty)->owningArena != &currentModule->internalTypes)
{
bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(),
binding->typeParams.end(), [](auto&& itp, auto&& tp) {
bool sameTys = std::equal(
ttv->instantiatedTypeParams.begin(),
ttv->instantiatedTypeParams.end(),
binding->typeParams.begin(),
binding->typeParams.end(),
[](auto&& itp, auto&& tp)
{
return itp == tp.ty;
});
bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), binding->typePackParams.begin(),
binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) {
}
);
bool sameTps = std::equal(
ttv->instantiatedTypePackParams.begin(),
ttv->instantiatedTypePackParams.end(),
binding->typePackParams.begin(),
binding->typePackParams.end(),
[](auto&& itpp, auto&& tpp)
{
return itpp == tpp.tp;
});
}
);
// Copy can be skipped if this is an identical alias
if (!ttv->name || ttv->name != name || !sameTys || !sameTps)
@ -1630,6 +1612,13 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty
return ControlFlow::None;
}
ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeFunction& typefunction)
{
reportError(TypeError{typefunction.location, GenericError{"This syntax is not supported"}});
return ControlFlow::None;
}
void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel)
{
Name name = typealias.name.value;
@ -1704,8 +1693,10 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de
if (!get<ClassType>(follow(*superTy)))
{
reportError(declaredClass.location,
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)});
reportError(
declaredClass.location,
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}
);
incorrectClassDefinitions.insert(&declaredClass);
return;
}
@ -1852,15 +1843,27 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFuncti
std::vector<TypeId> genericTys;
genericTys.reserve(generics.size());
std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) {
return el.ty;
});
std::transform(
generics.begin(),
generics.end(),
std::back_inserter(genericTys),
[](auto&& el)
{
return el.ty;
}
);
std::vector<TypePackId> genericTps;
genericTps.reserve(genericPacks.size());
std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) {
return el.tp;
});
std::transform(
genericPacks.begin(),
genericPacks.end(),
std::back_inserter(genericTps),
[](auto&& el)
{
return el.tp;
}
);
TypePackId argPack = resolveTypePack(funScope, global.params);
TypePackId retPack = resolveTypePack(funScope, global.retTypes);
@ -2085,7 +2088,12 @@ std::optional<TypeId> TypeChecker::findMetatableEntry(TypeId type, std::string e
}
std::optional<TypeId> TypeChecker::getIndexTypeFromType(
const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors)
const ScopePtr& scope,
TypeId type,
const Name& name,
const Location& location,
bool addErrors
)
{
size_t errorCount = currentModule->errors.size();
@ -2098,7 +2106,12 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
}
std::optional<TypeId> TypeChecker::getIndexTypeFromTypeImpl(
const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors)
const ScopePtr& scope,
TypeId type,
const Name& name,
const Location& location,
bool addErrors
)
{
type = follow(type);
@ -2297,7 +2310,11 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
}
TypeId TypeChecker::checkExprTable(
const ScopePtr& scope, const AstExprTable& expr, const std::vector<std::pair<TypeId, TypeId>>& fieldTypes, std::optional<TypeId> expectedType)
const ScopePtr& scope,
const AstExprTable& expr,
const std::vector<std::pair<TypeId, TypeId>>& fieldTypes,
std::optional<TypeId> expectedType
)
{
TableType::Props props;
std::optional<TableIndexer> indexer;
@ -2526,8 +2543,10 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
return WithPredicate{retType};
}
reportError(expr.location,
GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())});
reportError(
expr.location,
GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}
);
return WithPredicate{errorRecoveryType(scope)};
}
@ -2674,7 +2693,8 @@ static std::optional<bool> areEqComparable(NotNull<TypeArena> arena, NotNull<Nor
a = follow(a);
b = follow(b);
auto isExempt = [](TypeId t) {
auto isExempt = [](TypeId t)
{
return isNil(t) || get<FreeType>(t);
};
@ -2705,9 +2725,15 @@ static std::optional<bool> areEqComparable(NotNull<TypeArena> arena, NotNull<Nor
}
TypeId TypeChecker::checkRelationalOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates)
const ScopePtr& scope,
const AstExprBinary& expr,
TypeId lhsType,
TypeId rhsType,
const PredicateVec& predicates
)
{
auto stripNil = [this](TypeId ty, bool isOrOp = false) {
auto stripNil = [this](TypeId ty, bool isOrOp = false)
{
ty = follow(ty);
if (!isNonstrictMode() && !isOrOp)
return ty;
@ -2788,7 +2814,8 @@ TypeId TypeChecker::checkRelationalOperation(
if (!*eqTestResult)
{
reportError(
expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())});
expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())}
);
return errorRecoveryType(booleanType);
}
}
@ -2821,16 +2848,24 @@ TypeId TypeChecker::checkRelationalOperation(
// we need to be conservative in the old solver to deliver a reasonable developer experience.
if (!isEquality && state.errors.empty() && isBoolean(leftType))
{
reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s",
toString(leftType).c_str(), toString(expr.op).c_str())});
reportError(
expr.location,
GenericError{
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
}
}
else
{
if (!isEquality && state.errors.empty() && (get<UnionType>(leftType) || isBoolean(leftType)))
{
reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s",
toString(leftType).c_str(), toString(expr.op).c_str())});
reportError(
expr.location,
GenericError{
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
}
}
@ -2879,8 +2914,14 @@ TypeId TypeChecker::checkRelationalOperation(
if (!matches)
{
reportError(
expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())});
expr.location,
GenericError{format(
"Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(lhsType).c_str(),
toString(rhsType).c_str(),
toString(expr.op).c_str()
)}
);
return errorRecoveryType(booleanType);
}
}
@ -2911,7 +2952,8 @@ TypeId TypeChecker::checkRelationalOperation(
TypeId actualFunctionType = addType(FunctionType(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType})));
state.tryUnify(
instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true);
instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true
);
state.log.commit();
@ -2921,7 +2963,8 @@ TypeId TypeChecker::checkRelationalOperation(
else if (needsMetamethod)
{
reportError(
expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())});
expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}
);
return errorRecoveryType(booleanType);
}
}
@ -2935,8 +2978,12 @@ TypeId TypeChecker::checkRelationalOperation(
if (needsMetamethod)
{
reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable",
toString(lhsType).c_str(), toString(expr.op).c_str())});
reportError(
expr.location,
GenericError{
format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str())
}
);
return errorRecoveryType(booleanType);
}
@ -3006,7 +3053,12 @@ TypeId TypeChecker::checkRelationalOperation(
}
TypeId TypeChecker::checkBinaryOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates)
const ScopePtr& scope,
const AstExprBinary& expr,
TypeId lhsType,
TypeId rhsType,
const PredicateVec& predicates
)
{
switch (expr.op)
{
@ -3057,7 +3109,8 @@ TypeId TypeChecker::checkBinaryOperation(
if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType))
{
auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId {
auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId
{
TypeId actualFunctionType = instantiate(scope, fnt, expr.location);
TypePackId arguments = addTypePack({lhst, rhst});
TypePackId retTypePack = freshTypePack(scope);
@ -3104,8 +3157,15 @@ TypeId TypeChecker::checkBinaryOperation(
return checkMetatableCall(*fnt, rhsType, lhsType);
}
reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(),
toString(lhsType).c_str(), toString(rhsType).c_str())});
reportError(
expr.location,
GenericError{format(
"Binary operator '%s' not supported by types '%s' and '%s'",
toString(expr.op).c_str(),
toString(lhsType).c_str(),
toString(rhsType).c_str()
)}
);
return errorRecoveryType(scope);
}
@ -3537,7 +3597,8 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
// Primarily about detecting duplicates.
TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level)
{
auto freshTy = [&]() {
auto freshTy = [&]()
{
return freshType(level);
};
@ -3610,8 +3671,14 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T
// `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X`
// to get type `(X) -> X`, then we quantify the free types to get the final
// generic type `<a>(a) -> a`.
std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr,
std::optional<Location> originalName, std::optional<TypeId> selfType, std::optional<TypeId> expectedType)
std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
const ScopePtr& scope,
int subLevel,
const AstExprFunction& expr,
std::optional<Location> originalName,
std::optional<TypeId> selfType,
std::optional<TypeId> expectedType
)
{
ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel);
@ -3704,25 +3771,11 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(const ScopePtr&
funScope->returnType = retPack;
if (FFlag::DebugLuauSharedSelf)
if (expr.self)
{
if (expr.self)
{
// TODO: generic self types: CLI-39906
TypeId selfTy = anyIfNonstrict(selfType ? *selfType : freshType(funScope));
funScope->bindings[expr.self] = {selfTy, expr.self->location};
argTypes.push_back(selfTy);
}
}
else
{
if (expr.self)
{
// TODO: generic self types: CLI-39906
TypeId selfType = anyIfNonstrict(freshType(funScope));
funScope->bindings[expr.self] = {selfType, expr.self->location};
argTypes.push_back(selfType);
}
TypeId selfType = anyIfNonstrict(freshType(funScope));
funScope->bindings[expr.self] = {selfType, expr.self->location};
argTypes.push_back(selfType);
}
// Prepare expected argument type iterators if we have an expected function type
@ -3911,8 +3964,14 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
}
}
void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId argPack, TypePackId paramPack,
const std::vector<Location>& argLocations)
void TypeChecker::checkArgumentList(
const ScopePtr& scope,
const AstExpr& funName,
Unifier& state,
TypePackId argPack,
TypePackId paramPack,
const std::vector<Location>& argLocations
)
{
/* Important terminology refresher:
* A function requires parameters.
@ -3924,7 +3983,8 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
size_t paramIndex = 0;
auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]() {
auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]()
{
// For this case, we want the error span to cover every errant extra parameter
Location location = state.location;
if (!argLocations.empty())
@ -3936,8 +3996,10 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
namePath = *path;
auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack);
state.reportError(TypeError{location,
CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}});
state.reportError(TypeError{
location,
CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}
});
};
while (true)
@ -4044,7 +4106,8 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
namePath = *path;
state.reportError(TypeError{
funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}});
funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}
});
return;
}
++paramIter;
@ -4188,7 +4251,8 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
// We break this function up into a lambda here to limit our stack footprint.
// The vectors used by this function aren't allocated until the lambda is actually called.
auto the_rest = [&]() -> WithPredicate<TypePackId> {
auto the_rest = [&]() -> WithPredicate<TypePackId>
{
// checkExpr will log the pre-instantiated type of the function.
// That's not nearly as interesting as the instantiated type, which will include details about how
// generic functions are being instantiated for this particular callsite.
@ -4231,7 +4295,8 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
fn = follow(fn);
if (auto ret = checkCallOverload(
scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors))
scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors
))
return *ret;
}
@ -4258,7 +4323,8 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
{
std::vector<std::optional<TypeId>> expectedTypes;
auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) {
auto assignOption = [this, &expectedTypes](size_t index, TypeId ty)
{
if (index == expectedTypes.size())
{
expectedTypes.push_back(ty);
@ -4317,9 +4383,19 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
* If this was an optional, callers would have to pay the stack cost for the result. This is problematic
* for functions that need to support recursion up to 600 levels deep.
*/
std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn,
TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors)
std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(
const ScopePtr& scope,
const AstExprCall& expr,
TypeId fn,
TypePackId retPack,
TypePackId argPack,
TypePack* args,
const std::vector<Location>* argLocations,
const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<TypeId>& overloadsThatDont,
std::vector<OverloadErrorEntry>& errors
)
{
LUAU_ASSERT(argLocations);
@ -4453,8 +4529,13 @@ std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const
return nullptr;
}
bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors)
bool TypeChecker::handleSelfCallMismatch(
const ScopePtr& scope,
const AstExprCall& expr,
TypePack* args,
const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors
)
{
// No overloads succeeded: Scan for one that would have worked had the user
// used a.b() rather than a:b() or vice versa.
@ -4521,14 +4602,20 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal
return false;
}
void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<OverloadErrorEntry>& errors)
void TypeChecker::reportOverloadResolutionError(
const ScopePtr& scope,
const AstExprCall& expr,
TypePackId retPack,
TypePackId argPack,
const std::vector<Location>& argLocations,
const std::vector<TypeId>& overloads,
const std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<OverloadErrorEntry>& errors
)
{
if (overloads.size() == 1)
{
if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls)
errors.front().log.commit();
errors.front().log.commit();
reportErrors(errors.front().errors);
return;
@ -4551,14 +4638,18 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast
const FunctionType* ftv = get<FunctionType>(overload);
auto error = std::find_if(errors.begin(), errors.end(), [ftv](const OverloadErrorEntry& e) {
return ftv == e.fnTy;
});
auto error = std::find_if(
errors.begin(),
errors.end(),
[ftv](const OverloadErrorEntry& e)
{
return ftv == e.fnTy;
}
);
LUAU_ASSERT(error != errors.end());
if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls)
error->log.commit();
error->log.commit();
reportErrors(error->errors);
@ -4601,14 +4692,21 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast
return;
}
WithPredicate<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil, const std::vector<bool>& instantiateGenerics, const std::vector<std::optional<TypeId>>& expectedTypes)
WithPredicate<TypePackId> TypeChecker::checkExprList(
const ScopePtr& scope,
const Location& location,
const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil,
const std::vector<bool>& instantiateGenerics,
const std::vector<std::optional<TypeId>>& expectedTypes
)
{
bool uninhabitable = false;
TypePackId pack = addTypePack(TypePack{});
PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up?
auto insert = [&predicates](PredicateVec& vec) {
auto insert = [&predicates](PredicateVec& vec)
{
for (Predicate& c : vec)
predicates.push_back(std::move(c));
};
@ -4875,20 +4973,10 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
{
ty = follow(ty);
if (FFlag::DebugLuauSharedSelf)
{
if (auto ftv = get<FunctionType>(ty))
Luau::quantify(ty, scope->level);
else if (auto ttv = getTableType(ty); ttv && ttv->selfTy)
Luau::quantify(ty, scope->level);
}
else
{
const FunctionType* ftv = get<FunctionType>(ty);
const FunctionType* ftv = get<FunctionType>(ty);
if (ftv)
Luau::quantify(ty, scope->level);
}
if (ftv)
Luau::quantify(ty, scope->level);
return ty;
}
@ -5031,11 +5119,17 @@ LUAU_NOINLINE void TypeChecker::throwUserCancelError()
void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec)
{
// Remove errors with names that were generated by recovery from a parse error
errVec.erase(std::remove_if(errVec.begin(), errVec.end(),
[](auto& err) {
return containsParseErrorName(err);
}),
errVec.end());
errVec.erase(
std::remove_if(
errVec.begin(),
errVec.end(),
[](auto& err)
{
return containsParseErrorName(err);
}
),
errVec.end()
);
for (auto& err : errVec)
{
@ -5049,7 +5143,8 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d
std::string_view sv(utk->key);
std::set<Name> candidates;
auto accumulate = [&](const TableType::Props& props) {
auto accumulate = [&](const TableType::Props& props)
{
for (const auto& [name, ty] : props)
{
if (sv != name && equalsLower(sv, name))
@ -5103,25 +5198,30 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio
void TypeChecker::merge(RefinementMap& l, const RefinementMap& r)
{
Luau::merge(l, r, [this](TypeId a, TypeId b) {
// TODO: normalize(UnionType{{a, b}})
std::unordered_set<TypeId> set;
Luau::merge(
l,
r,
[this](TypeId a, TypeId b)
{
// TODO: normalize(UnionType{{a, b}})
std::unordered_set<TypeId> set;
if (auto utv = get<UnionType>(follow(a)))
set.insert(begin(utv), end(utv));
else
set.insert(a);
if (auto utv = get<UnionType>(follow(a)))
set.insert(begin(utv), end(utv));
else
set.insert(a);
if (auto utv = get<UnionType>(follow(b)))
set.insert(begin(utv), end(utv));
else
set.insert(b);
if (auto utv = get<UnionType>(follow(b)))
set.insert(begin(utv), end(utv));
else
set.insert(b);
std::vector<TypeId> options(set.begin(), set.end());
if (set.size() == 1)
return options[0];
return addType(UnionType{std::move(options)});
});
std::vector<TypeId> options(set.begin(), set.end());
if (set.size() == 1)
return options[0];
return addType(UnionType{std::move(options)});
}
);
}
Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location)
@ -5172,7 +5272,8 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess)
TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy)
{
return [this, sense, emptySetTy](TypeId ty) -> std::optional<TypeId> {
return [this, sense, emptySetTy](TypeId ty) -> std::optional<TypeId>
{
// any/error/free gets a special pass unconditionally because they can't be decided.
if (get<AnyType>(ty) || get<ErrorType>(ty) || get<FreeType>(ty))
return ty;
@ -5314,12 +5415,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
return tf->type;
bool parameterCountErrorReported = false;
bool hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) {
return el.defaultValue.has_value();
});
bool hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) {
return el.defaultValue.has_value();
});
bool hasDefaultTypes = std::any_of(
tf->typeParams.begin(),
tf->typeParams.end(),
[](auto&& el)
{
return el.defaultValue.has_value();
}
);
bool hasDefaultPacks = std::any_of(
tf->typePackParams.begin(),
tf->typePackParams.end(),
[](auto&& el)
{
return el.defaultValue.has_value();
}
);
if (!lit->hasParameterList)
{
@ -5442,7 +5553,8 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
{
if (!parameterCountErrorReported)
reportError(
TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}});
TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}
);
// Pad the types out with error recovery types
while (typeParams.size() < tf->typeParams.size())
@ -5451,13 +5563,26 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
typePackParams.push_back(errorRecoveryTypePack(scope));
}
bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) {
return itp == tp.ty;
});
bool sameTys = std::equal(
typeParams.begin(),
typeParams.end(),
tf->typeParams.begin(),
tf->typeParams.end(),
[](auto&& itp, auto&& tp)
{
return itp == tp.ty;
}
);
bool sameTps = std::equal(
typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) {
typePackParams.begin(),
typePackParams.end(),
tf->typePackParams.begin(),
tf->typePackParams.end(),
[](auto&& itpp, auto&& tpp)
{
return itpp == tpp.tp;
});
}
);
// If the generic parameters and the type arguments are the same, we are about to
// perform an identity substitution, which we can just short-circuit.
@ -5512,15 +5637,27 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
std::vector<TypeId> genericTys;
genericTys.reserve(generics.size());
std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) {
return el.ty;
});
std::transform(
generics.begin(),
generics.end(),
std::back_inserter(genericTys),
[](auto&& el)
{
return el.ty;
}
);
std::vector<TypePackId> genericTps;
genericTps.reserve(genericPacks.size());
std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) {
return el.tp;
});
std::transform(
genericPacks.begin(),
genericPacks.end(),
std::back_inserter(genericTps),
[](auto&& el)
{
return el.tp;
}
);
TypeId fnType = addType(FunctionType{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes});
@ -5641,8 +5778,13 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack
return result;
}
TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams, const Location& location)
TypeId TypeChecker::instantiateTypeFun(
const ScopePtr& scope,
const TypeFun& tf,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams,
const Location& location
)
{
if (tf.typeParams.empty() && tf.typePackParams.empty())
return tf.type;
@ -5706,8 +5848,14 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
return instantiated;
}
GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node,
const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericTypePack>& genericPackNames, bool useCache)
GenericTypeDefinitions TypeChecker::createGenericTypes(
const ScopePtr& scope,
std::optional<TypeLevel> levelOpt,
const AstNode& node,
const AstArray<AstGenericType>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames,
bool useCache
)
{
LUAU_ASSERT(scope->parent);
@ -5835,7 +5983,8 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const
}
}
auto intoType = [this](const std::unordered_set<TypeId>& s) -> std::optional<TypeId> {
auto intoType = [this](const std::unordered_set<TypeId>& s) -> std::optional<TypeId>
{
if (s.empty())
return std::nullopt;
@ -6022,7 +6171,8 @@ void TypeChecker::resolve(const OrPredicate& orP, RefinementMap& refis, const Sc
void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense)
{
auto predicate = [&](TypeId option) -> std::optional<TypeId> {
auto predicate = [&](TypeId option) -> std::optional<TypeId>
{
// This by itself is not truly enough to determine that A is stronger than B or vice versa.
bool optionIsSubtype = canUnify(option, isaP.ty, scope, isaP.location).empty();
bool targetIsSubtype = canUnify(isaP.ty, option, scope, isaP.location).empty();
@ -6085,8 +6235,10 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
return;
}
auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional<TypeId> mapsTo = std::nullopt) {
TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional<TypeId> {
auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional<TypeId> mapsTo = std::nullopt)
{
TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional<TypeId>
{
if (sense && get<UnknownType>(ty))
return mapsTo.value_or(ty);
@ -6117,22 +6269,31 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
return refine(isBuffer, bufferType);
else if (typeguardP.kind == "table")
{
return refine([](TypeId ty) -> bool {
return isTableIntersection(ty) || get<TableType>(ty) || get<MetatableType>(ty);
});
return refine(
[](TypeId ty) -> bool
{
return isTableIntersection(ty) || get<TableType>(ty) || get<MetatableType>(ty);
}
);
}
else if (typeguardP.kind == "function")
{
return refine([](TypeId ty) -> bool {
return isOverloadedFunction(ty) || get<FunctionType>(ty);
});
return refine(
[](TypeId ty) -> bool
{
return isOverloadedFunction(ty) || get<FunctionType>(ty);
}
);
}
else if (typeguardP.kind == "userdata")
{
// For now, we don't really care about being accurate with userdata if the typeguard was using typeof.
return refine([](TypeId ty) -> bool {
return get<ClassType>(ty);
});
return refine(
[](TypeId ty) -> bool
{
return get<ClassType>(ty);
}
);
}
if (!typeguardP.isTypeof)
@ -6162,7 +6323,8 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense)
{
// This refinement will require success typing to do everything correctly. For now, we can get most of the way there.
auto options = [](TypeId ty) -> std::vector<TypeId> {
auto options = [](TypeId ty) -> std::vector<TypeId>
{
if (auto utv = get<UnionType>(follow(ty)))
return std::vector<TypeId>(begin(utv), end(utv));
return {ty};
@ -6173,7 +6335,8 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc
if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable))
return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here.
auto predicate = [&](TypeId option) -> std::optional<TypeId> {
auto predicate = [&](TypeId option) -> std::optional<TypeId>
{
if (!sense && isNil(eqP.type))
return (isUndecidable(option) || !isNil(option)) ? std::optional<TypeId>(option) : std::nullopt;

View file

@ -257,14 +257,20 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs)
TypePackId follow(TypePackId tp)
{
return follow(tp, nullptr, [](const void*, TypePackId t) {
return t;
});
return follow(
tp,
nullptr,
[](const void*, TypePackId t)
{
return t;
}
);
}
TypePackId follow(TypePackId tp, const void* context, TypePackId (*mapper)(const void*, TypePackId))
{
auto advance = [context, mapper](TypePackId ty) -> std::optional<TypePackId> {
auto advance = [context, mapper](TypePackId ty) -> std::optional<TypePackId>
{
TypePackId mapped = mapper(context, ty);
if (const Unifiable::Bound<TypePackId>* btv = get<Unifiable::Bound<TypePackId>>(mapped))

View file

@ -534,7 +534,8 @@ std::string toString(const TypePath::Path& path, bool prefixDot)
std::stringstream result;
bool first = true;
auto strComponent = [&](auto&& c) {
auto strComponent = [&](auto&& c)
{
using T = std::decay_t<decltype(c)>;
if constexpr (std::is_same_v<T, TypePath::Property>)
{
@ -626,7 +627,8 @@ std::string toString(const TypePath::Path& path, bool prefixDot)
static bool traverse(TraversalState& state, const Path& path)
{
auto step = [&state](auto&& c) {
auto step = [&state](auto&& c)
{
return state.traverse(c);
};

View file

@ -24,7 +24,8 @@ bool occursCheck(TypeId needle, TypeId haystack)
LUAU_ASSERT(get<BlockedType>(needle) || get<PendingExpansionType>(needle));
haystack = follow(haystack);
auto checkHaystack = [needle](TypeId haystack) {
auto checkHaystack = [needle](TypeId haystack)
{
return occursCheck(needle, haystack);
};
@ -92,7 +93,12 @@ std::optional<Property> findTableProperty(NotNull<BuiltinTypes> builtinTypes, Er
}
std::optional<TypeId> findMetatableEntry(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location)
NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId type,
const std::string& entry,
Location location
)
{
type = follow(type);
@ -120,13 +126,24 @@ std::optional<TypeId> findMetatableEntry(
}
std::optional<TypeId> findTablePropertyRespectingMeta(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location)
NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
Location location
)
{
return findTablePropertyRespectingMeta(builtinTypes, errors, ty, name, ValueContext::RValue, location);
}
std::optional<TypeId> findTablePropertyRespectingMeta(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, ValueContext context, Location location)
NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
ValueContext context,
Location location
)
{
if (get<AnyType>(ty))
return ty;
@ -217,7 +234,12 @@ std::pair<size_t, std::optional<size_t>> getParameterExtents(const TxnLog* log,
}
TypePack extendTypePack(
TypeArena& arena, NotNull<BuiltinTypes> builtinTypes, TypePackId pack, size_t length, std::vector<std::optional<TypeId>> overrides)
TypeArena& arena,
NotNull<BuiltinTypes> builtinTypes,
TypePackId pack,
size_t length,
std::vector<std::optional<TypeId>> overrides
)
{
TypePack result;

View file

@ -19,7 +19,6 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit)
LUAU_FASTFLAG(LuauErrorRecoveryType)
LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false)
LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false)
LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false)
@ -329,7 +328,8 @@ TypePackId Widen::operator()(TypePackId tp)
std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
{
auto isUnificationTooComplex = [](const TypeError& te) {
auto isUnificationTooComplex = [](const TypeError& te)
{
return nullptr != get<UnificationTooComplex>(te);
};
@ -342,7 +342,8 @@ std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
std::optional<TypeError> hasCountMismatch(const ErrorVec& errors)
{
auto isCountMismatch = [](const TypeError& te) {
auto isCountMismatch = [](const TypeError& te)
{
return nullptr != get<CountMismatch>(te);
};
@ -771,47 +772,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ
}
}
if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls)
log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types});
else
{
// even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option.
auto tryBind = [this, subTy](TypeId superOption) {
superOption = log.follow(superOption);
// just skip if the superOption is not free-ish.
auto ttv = log.getMutable<TableType>(superOption);
if (!log.is<FreeType>(superOption) && (!ttv || ttv->state != TableState::Free))
return;
// If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype
// test is successful.
if (auto subUnion = get<UnionType>(subTy))
{
if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption))
return;
}
// Since we have already checked if S <: T, checking it again will not queue up the type for replacement.
// So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set.
if (log.haveSeen(subTy, superOption))
{
// TODO: would it be nice for TxnLog::replace to do this?
if (log.is<TableType>(superOption))
log.bindTable(superOption, subTy);
else
log.replace(superOption, *subTy);
}
};
if (auto superUnion = log.getMutable<UnionType>(superTy))
{
for (TypeId ty : superUnion)
tryBind(ty);
}
else
tryBind(superTy);
}
log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types});
if (unificationTooComplex)
reportError(*unificationTooComplex);
@ -954,7 +915,8 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
return reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
innerState.tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption
);
else
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
@ -985,7 +947,8 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
failure = true;
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
reportError(
location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()});
location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()}
);
else
reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible", mismatchContext()});
}
@ -1151,7 +1114,13 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
}
void Unifier::tryUnifyNormalizedTypes(
TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional<TypeError> error)
TypeId subTy,
TypeId superTy,
const NormalizedType& subNorm,
const NormalizedType& superNorm,
std::string reason,
std::optional<TypeError> error
)
{
if (get<AnyType>(superNorm.tops))
return;
@ -1394,7 +1363,8 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy)
if (subTyInfo && *subTyInfo)
return false;
auto skipCacheFor = [this](TypeId ty) {
auto skipCacheFor = [this](TypeId ty)
{
SkipCacheForType visitor{sharedState.skipCacheForType, types};
visitor.traverse(ty);
@ -1674,7 +1644,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
superIter.scope = scope.get();
subIter.scope = scope.get();
auto mkFreshType = [this](Scope* scope, TypeLevel level) {
auto mkFreshType = [this](Scope* scope, TypeLevel level)
{
if (FFlag::DebugLuauDeferredConstraintResolution)
return freshType(NotNull{types}, builtinTypes, scope);
else
@ -1970,8 +1941,16 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e);
else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front(), mismatchContext()});
reportError(
location,
TypeMismatch{
superTy,
subTy,
format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front(),
mismatchContext()
}
);
else if (!innerState.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()});
@ -1985,8 +1964,16 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes))
reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()});
else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front(), mismatchContext()});
reportError(
location,
TypeMismatch{
superTy,
subTy,
format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front(),
mismatchContext()
}
);
else if (!innerState.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()});
}
@ -2402,7 +2389,8 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
if (!superTable || superTable->state != TableState::Free)
return reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()});
auto fail = [&](std::optional<TypeError> e) {
auto fail = [&](std::optional<TypeError> e)
{
std::string reason = "The former's metatable does not satisfy the requirements.";
if (e)
reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e, mismatchContext()});
@ -2497,7 +2485,8 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
reportError(*e);
else if (!innerState.errors.empty())
reportError(
location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()});
location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}
);
log.concat(std::move(innerState.log));
failure |= innerState.failure;
@ -2535,8 +2524,10 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e);
else if (!innerState.errors.empty())
reportError(TypeError{location,
TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}});
reportError(TypeError{
location,
TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}
});
else if (!missingProperty)
{
log.concat(std::move(innerState.log));
@ -2574,7 +2565,8 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
if (reversed)
std::swap(superTy, subTy);
auto fail = [&]() {
auto fail = [&]()
{
if (!reversed)
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
else
@ -2770,8 +2762,15 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
}
}
static void tryUnifyWithAny(std::vector<TypeId>& queue, Unifier& state, DenseHashSet<TypeId>& seen, DenseHashSet<TypePackId>& seenTypePacks,
const TypeArena* typeArena, TypeId anyType, TypePackId anyTypePack)
static void tryUnifyWithAny(
std::vector<TypeId>& queue,
Unifier& state,
DenseHashSet<TypeId>& seen,
DenseHashSet<TypePackId>& seenTypePacks,
const TypeArena* typeArena,
TypeId anyType,
TypePackId anyTypePack
)
{
while (!queue.empty())
{
@ -2927,7 +2926,8 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
bool occurrence = false;
auto check = [&](TypeId tv) {
auto check = [&](TypeId tv)
{
if (occursCheck(seen, needle, tv))
occurrence = true;
};
@ -3064,8 +3064,10 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const s
if (auto e = hasUnificationTooComplex(innerErrors))
reportError(*e);
else if (!innerErrors.empty())
reportError(TypeError{location,
TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()}});
reportError(TypeError{
location,
TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()}
});
}
void Unifier::ice(const std::string& message, const Location& location)

View file

@ -33,7 +33,8 @@ static bool areCompatible(TypeId left, TypeId right)
const TableType* rightTable = p.second;
LUAU_ASSERT(rightTable);
const auto missingPropIsCompatible = [](const Property& leftProp, const TableType* rightTable) {
const auto missingPropIsCompatible = [](const Property& leftProp, const TableType* rightTable)
{
// Two tables may be compatible even if their shapes aren't exactly the
// same if the extra property is optional, free (and therefore
// potentially optional), or if the right table has an indexer. Or if
@ -96,8 +97,13 @@ Unifier2::Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes,
{
}
Unifier2::Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<InternalErrorReporter> ice,
DenseHashSet<const void*>* uninhabitedTypeFunctions)
Unifier2::Unifier2(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> ice,
DenseHashSet<const void*>* uninhabitedTypeFunctions
)
: arena(arena)
, builtinTypes(builtinTypes)
, scope(scope)
@ -251,7 +257,8 @@ bool Unifier2::unifyFreeWithType(TypeId subTy, TypeId superTy)
FreeType* subFree = getMutable<FreeType>(subTy);
LUAU_ASSERT(subFree);
auto doDefault = [&]() {
auto doDefault = [&]()
{
subFree->upperBound = mkIntersection(subFree->upperBound, superTy);
expandedFreeTypes[subTy].push_back(superTy);
return true;
@ -841,7 +848,8 @@ OccursCheckResult Unifier2::occursCheck(DenseHashSet<TypeId>& seen, TypeId needl
OccursCheckResult occurrence = OccursCheckResult::Pass;
auto check = [&](TypeId ty) {
auto check = [&](TypeId ty)
{
if (occursCheck(seen, needle, ty) == OccursCheckResult::Fail)
occurrence = OccursCheckResult::Fail;
};

View file

@ -384,7 +384,13 @@ public:
LUAU_RTTI(AstExprIndexName)
AstExprIndexName(
const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op);
const Location& location,
AstExpr* expr,
const AstName& index,
const Location& indexLocation,
const Position& opPosition,
char op
);
void visit(AstVisitor* visitor) override;
@ -413,11 +419,22 @@ class AstExprFunction : public AstExpr
public:
LUAU_RTTI(AstExprFunction)
AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg,
const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname,
const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr,
const std::optional<Location>& argLocation = std::nullopt);
AstExprFunction(
const Location& location,
const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self,
const AstArray<AstLocal*>& args,
bool vararg,
const Location& varargLocation,
AstStatBlock* body,
size_t functionDepth,
const AstName& debugname,
const std::optional<AstTypeList>& returnAnnotation = {},
AstTypePack* varargAnnotation = nullptr,
const std::optional<Location>& argLocation = std::nullopt
);
void visit(AstVisitor* visitor) override;
@ -603,8 +620,14 @@ class AstStatIf : public AstStat
public:
LUAU_RTTI(AstStatIf)
AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, const std::optional<Location>& thenLocation,
const std::optional<Location>& elseLocation);
AstStatIf(
const Location& location,
AstExpr* condition,
AstStatBlock* thenbody,
AstStat* elsebody,
const std::optional<Location>& thenLocation,
const std::optional<Location>& elseLocation
);
void visit(AstVisitor* visitor) override;
@ -698,8 +721,12 @@ class AstStatLocal : public AstStat
public:
LUAU_RTTI(AstStatLocal)
AstStatLocal(const Location& location, const AstArray<AstLocal*>& vars, const AstArray<AstExpr*>& values,
const std::optional<Location>& equalsSignLocation);
AstStatLocal(
const Location& location,
const AstArray<AstLocal*>& vars,
const AstArray<AstExpr*>& values,
const std::optional<Location>& equalsSignLocation
);
void visit(AstVisitor* visitor) override;
@ -714,8 +741,16 @@ class AstStatFor : public AstStat
public:
LUAU_RTTI(AstStatFor)
AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo,
const Location& doLocation);
AstStatFor(
const Location& location,
AstLocal* var,
AstExpr* from,
AstExpr* to,
AstExpr* step,
AstStatBlock* body,
bool hasDo,
const Location& doLocation
);
void visit(AstVisitor* visitor) override;
@ -734,8 +769,16 @@ class AstStatForIn : public AstStat
public:
LUAU_RTTI(AstStatForIn)
AstStatForIn(const Location& location, const AstArray<AstLocal*>& vars, const AstArray<AstExpr*>& values, AstStatBlock* body, bool hasIn,
const Location& inLocation, bool hasDo, const Location& doLocation);
AstStatForIn(
const Location& location,
const AstArray<AstLocal*>& vars,
const AstArray<AstExpr*>& values,
AstStatBlock* body,
bool hasIn,
const Location& inLocation,
bool hasDo,
const Location& doLocation
);
void visit(AstVisitor* visitor) override;
@ -808,8 +851,15 @@ class AstStatTypeAlias : public AstStat
public:
LUAU_RTTI(AstStatTypeAlias)
AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, AstType* type, bool exported);
AstStatTypeAlias(
const Location& location,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
AstType* type,
bool exported
);
void visit(AstVisitor* visitor) override;
@ -821,6 +871,20 @@ public:
bool exported;
};
class AstStatTypeFunction : public AstStat
{
public:
LUAU_RTTI(AstStatTypeFunction);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body);
void visit(AstVisitor* visitor) override;
AstName name;
Location nameLocation;
AstExprFunction* body;
};
class AstStatDeclareGlobal : public AstStat
{
public:
@ -840,13 +904,32 @@ class AstStatDeclareFunction : public AstStat
public:
LUAU_RTTI(AstStatDeclareFunction)
AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, bool vararg,
const Location& varargLocation, const AstTypeList& retTypes);
AstStatDeclareFunction(
const Location& location,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
const Location& varargLocation,
const AstTypeList& retTypes
);
AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes);
AstStatDeclareFunction(
const Location& location,
const AstArray<AstAttr*>& attributes,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
const Location& varargLocation,
const AstTypeList& retTypes
);
void visit(AstVisitor* visitor) override;
@ -896,8 +979,13 @@ class AstStatDeclareClass : public AstStat
public:
LUAU_RTTI(AstStatDeclareClass)
AstStatDeclareClass(const Location& location, const AstName& name, std::optional<AstName> superName, const AstArray<AstDeclaredClassProp>& props,
AstTableIndexer* indexer = nullptr);
AstStatDeclareClass(
const Location& location,
const AstName& name,
std::optional<AstName> superName,
const AstArray<AstDeclaredClassProp>& props,
AstTableIndexer* indexer = nullptr
);
void visit(AstVisitor* visitor) override;
@ -934,8 +1022,15 @@ class AstTypeReference : public AstType
public:
LUAU_RTTI(AstTypeReference)
AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, std::optional<Location> prefixLocation,
const Location& nameLocation, bool hasParameterList = false, const AstArray<AstTypeOrPack>& parameters = {});
AstTypeReference(
const Location& location,
std::optional<AstName> prefix,
AstName name,
std::optional<Location> prefixLocation,
const Location& nameLocation,
bool hasParameterList = false,
const AstArray<AstTypeOrPack>& parameters = {}
);
void visit(AstVisitor* visitor) override;
@ -974,12 +1069,24 @@ class AstTypeFunction : public AstType
public:
LUAU_RTTI(AstTypeFunction)
AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes);
AstTypeFunction(
const Location& location,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
);
AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes);
AstTypeFunction(
const Location& location,
const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
);
void visit(AstVisitor* visitor) override;
@ -1413,4 +1520,4 @@ struct hash<Luau::AstName>
}
};
} // namespace std
} // namespace std

View file

@ -55,7 +55,12 @@ class Parser
{
public:
static ParseResult parse(
const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options = ParseOptions());
const char* buffer,
std::size_t bufferSize,
AstNameTable& names,
Allocator& allocator,
ParseOptions options = ParseOptions()
);
private:
struct Name;
@ -140,6 +145,9 @@ private:
// type Name `=' Type
AstStat* parseTypeAlias(const Location& start, bool exported);
// type function Name ... end
AstStat* parseTypeFunction(const Location& start);
AstDeclaredClassProp parseDeclaredClassMethod();
// `declare global' Name: Type |
@ -157,7 +165,12 @@ private:
// funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type]
// funcbody ::= funcbodyhead block end
std::pair<AstExprFunction*, AstLocal*> parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes);
bool hasself,
const Lexeme& matchFunction,
const AstName& debugname,
const Name* localName,
const AstArray<AstAttr*>& attributes
);
// explist ::= {exp `,'} exp
void parseExprList(TempVector<AstExpr*>& result);
@ -191,9 +204,15 @@ private:
AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation);
AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes);
AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation);
AstType* parseFunctionTypeTail(
const Lexeme& begin,
const AstArray<AstAttr*>& attributes,
AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*> params,
AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation
);
AstType* parseTableType(bool inDeclarationContext = false);
AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false);
@ -315,8 +334,13 @@ private:
void reportNameError(const char* context);
AstStatError* reportStatError(const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements,
const char* format, ...) LUAU_PRINTF_ATTR(5, 6);
AstStatError* reportStatError(
const Location& location,
const AstArray<AstExpr*>& expressions,
const AstArray<AstStat*>& statements,
const char* format,
...
) LUAU_PRINTF_ATTR(5, 6);
AstExprError* reportExprError(const Location& location, const AstArray<AstExpr*>& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5);
AstTypeError* reportTypeError(const Location& location, const AstArray<AstType*>& types, const char* format, ...) LUAU_PRINTF_ATTR(4, 5);
// `parseErrorLocation` is associated with the parser error
@ -428,4 +452,4 @@ private:
std::string scratchData;
};
} // namespace Luau
} // namespace Luau

View file

@ -141,7 +141,13 @@ void AstExprCall::visit(AstVisitor* visitor)
}
AstExprIndexName::AstExprIndexName(
const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op)
const Location& location,
AstExpr* expr,
const AstName& index,
const Location& indexLocation,
const Position& opPosition,
char op
)
: AstExpr(ClassIndex(), location)
, expr(expr)
, index(index)
@ -173,10 +179,22 @@ void AstExprIndexExpr::visit(AstVisitor* visitor)
}
}
AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation,
AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation,
AstTypePack* varargAnnotation, const std::optional<Location>& argLocation)
AstExprFunction::AstExprFunction(
const Location& location,
const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self,
const AstArray<AstLocal*>& args,
bool vararg,
const Location& varargLocation,
AstStatBlock* body,
size_t functionDepth,
const AstName& debugname,
const std::optional<AstTypeList>& returnAnnotation,
AstTypePack* varargAnnotation,
const std::optional<Location>& argLocation
)
: AstExpr(ClassIndex(), location)
, attributes(attributes)
, generics(generics)
@ -418,8 +436,14 @@ void AstStatBlock::visit(AstVisitor* visitor)
}
}
AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody,
const std::optional<Location>& thenLocation, const std::optional<Location>& elseLocation)
AstStatIf::AstStatIf(
const Location& location,
AstExpr* condition,
AstStatBlock* thenbody,
AstStat* elsebody,
const std::optional<Location>& thenLocation,
const std::optional<Location>& elseLocation
)
: AstStat(ClassIndex(), location)
, condition(condition)
, thenbody(thenbody)
@ -524,7 +548,11 @@ void AstStatExpr::visit(AstVisitor* visitor)
}
AstStatLocal::AstStatLocal(
const Location& location, const AstArray<AstLocal*>& vars, const AstArray<AstExpr*>& values, const std::optional<Location>& equalsSignLocation)
const Location& location,
const AstArray<AstLocal*>& vars,
const AstArray<AstExpr*>& values,
const std::optional<Location>& equalsSignLocation
)
: AstStat(ClassIndex(), location)
, vars(vars)
, values(values)
@ -548,7 +576,15 @@ void AstStatLocal::visit(AstVisitor* visitor)
}
AstStatFor::AstStatFor(
const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, const Location& doLocation)
const Location& location,
AstLocal* var,
AstExpr* from,
AstExpr* to,
AstExpr* step,
AstStatBlock* body,
bool hasDo,
const Location& doLocation
)
: AstStat(ClassIndex(), location)
, var(var)
, from(from)
@ -577,8 +613,16 @@ void AstStatFor::visit(AstVisitor* visitor)
}
}
AstStatForIn::AstStatForIn(const Location& location, const AstArray<AstLocal*>& vars, const AstArray<AstExpr*>& values, AstStatBlock* body,
bool hasIn, const Location& inLocation, bool hasDo, const Location& doLocation)
AstStatForIn::AstStatForIn(
const Location& location,
const AstArray<AstLocal*>& vars,
const AstArray<AstExpr*>& values,
AstStatBlock* body,
bool hasIn,
const Location& inLocation,
bool hasDo,
const Location& doLocation
)
: AstStat(ClassIndex(), location)
, vars(vars)
, values(values)
@ -672,8 +716,15 @@ void AstStatLocalFunction::visit(AstVisitor* visitor)
func->visit(visitor);
}
AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstType* type, bool exported)
AstStatTypeAlias::AstStatTypeAlias(
const Location& location,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
AstType* type,
bool exported
)
: AstStat(ClassIndex(), location)
, name(name)
, nameLocation(nameLocation)
@ -704,6 +755,20 @@ void AstStatTypeAlias::visit(AstVisitor* visitor)
}
}
AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body)
: AstStat(ClassIndex(), location)
, name(name)
, nameLocation(nameLocation)
, body(body)
{
}
void AstStatTypeFunction::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
body->visit(visitor);
}
AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type)
: AstStat(ClassIndex(), location)
, name(name)
@ -718,9 +783,18 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor)
type->visit(visitor);
}
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes)
AstStatDeclareFunction::AstStatDeclareFunction(
const Location& location,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
const Location& varargLocation,
const AstTypeList& retTypes
)
: AstStat(ClassIndex(), location)
, attributes()
, name(name)
@ -735,9 +809,19 @@ AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const A
{
}
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name,
const Location& nameLocation, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes)
AstStatDeclareFunction::AstStatDeclareFunction(
const Location& location,
const AstArray<AstAttr*>& attributes,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
const Location& varargLocation,
const AstTypeList& retTypes
)
: AstStat(ClassIndex(), location)
, attributes(attributes)
, name(name)
@ -772,8 +856,13 @@ bool AstStatDeclareFunction::isCheckedFunction() const
return false;
}
AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional<AstName> superName,
const AstArray<AstDeclaredClassProp>& props, AstTableIndexer* indexer)
AstStatDeclareClass::AstStatDeclareClass(
const Location& location,
const AstName& name,
std::optional<AstName> superName,
const AstArray<AstDeclaredClassProp>& props,
AstTableIndexer* indexer
)
: AstStat(ClassIndex(), location)
, name(name)
, superName(superName)
@ -792,7 +881,11 @@ void AstStatDeclareClass::visit(AstVisitor* visitor)
}
AstStatError::AstStatError(
const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements, unsigned messageIndex)
const Location& location,
const AstArray<AstExpr*>& expressions,
const AstArray<AstStat*>& statements,
unsigned messageIndex
)
: AstStat(ClassIndex(), location)
, expressions(expressions)
, statements(statements)
@ -812,8 +905,15 @@ void AstStatError::visit(AstVisitor* visitor)
}
}
AstTypeReference::AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, std::optional<Location> prefixLocation,
const Location& nameLocation, bool hasParameterList, const AstArray<AstTypeOrPack>& parameters)
AstTypeReference::AstTypeReference(
const Location& location,
std::optional<AstName> prefix,
AstName name,
std::optional<Location> prefixLocation,
const Location& nameLocation,
bool hasParameterList,
const AstArray<AstTypeOrPack>& parameters
)
: AstType(ClassIndex(), location)
, hasParameterList(hasParameterList)
, prefix(prefix)
@ -860,8 +960,14 @@ void AstTypeTable::visit(AstVisitor* visitor)
}
}
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes)
AstTypeFunction::AstTypeFunction(
const Location& location,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
)
: AstType(ClassIndex(), location)
, attributes()
, generics(generics)
@ -873,9 +979,15 @@ AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGen
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
}
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes)
AstTypeFunction::AstTypeFunction(
const Location& location,
const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
)
: AstType(ClassIndex(), location)
, attributes(attributes)
, generics(generics)
@ -1053,4 +1165,4 @@ Location getLocation(const AstTypeList& typeList)
return result;
}
} // namespace Luau
} // namespace Luau

View file

@ -1808,9 +1808,15 @@ static const Confusable kConfusables[] =
const char* findConfusable(uint32_t codepoint)
{
auto it = std::lower_bound(std::begin(kConfusables), std::end(kConfusables), codepoint, [](const Confusable& lhs, uint32_t rhs) {
return lhs.codepoint < rhs;
});
auto it = std::lower_bound(
std::begin(kConfusables),
std::end(kConfusables),
codepoint,
[](const Confusable& lhs, uint32_t rhs)
{
return lhs.codepoint < rhs;
}
);
return (it != std::end(kConfusables) && it->codepoint == codepoint) ? it->text : nullptr;
}

View file

@ -92,8 +92,10 @@ Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t siz
, length(unsigned(size))
, data(data)
{
LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment);
LUAU_ASSERT(
type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment
);
}
Lexeme::Lexeme(const Location& location, Type type, const char* name)
@ -107,14 +109,16 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name)
unsigned int Lexeme::getLength() const
{
LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment);
LUAU_ASSERT(
type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment
);
return length;
}
static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or",
"repeat", "return", "then", "true", "until", "while"};
static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in",
"local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"};
std::string Lexeme::toString() const
{

View file

@ -20,6 +20,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false)
LUAU_FASTFLAGVARIABLE(LuauDeclarationExtraPropData, false)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false)
namespace Luau
{
@ -785,9 +786,13 @@ AstStat* Parser::parseAttributeStat()
return parseDeclaration(expr->location, attributes);
}
default:
return reportStatError(lexer.current().location, {}, {},
return reportStatError(
lexer.current().location,
{},
{},
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s instead",
lexer.current().toString().c_str());
lexer.current().toString().c_str()
);
}
}
@ -825,8 +830,13 @@ AstStat* Parser::parseLocal(const AstArray<AstAttr*>& attributes)
{
if (attributes.size != 0)
{
return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s instead",
lexer.current().toString().c_str());
return reportStatError(
lexer.current().location,
{},
{},
"Expected 'function' after local declaration with attribute, but got %s instead",
lexer.current().toString().c_str()
);
}
matchRecoveryStopOnToken['=']++;
@ -880,6 +890,15 @@ AstStat* Parser::parseReturn()
// type Name [`<' varlist `>'] `=' Type
AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
{
// parsing a type function
if (FFlag::LuauUserDefinedTypeFunctions)
{
if (lexer.current().type == Lexeme::ReservedFunction)
return parseTypeFunction(start);
}
// parsing a type alias
// note: `type` token is already parsed for us, so we just need to parse the rest
std::optional<Name> name = parseNameOpt("type name");
@ -897,6 +916,26 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
return allocator.alloc<AstStatTypeAlias>(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported);
}
// type function Name `(' arglist `)' `=' funcbody `end'
AstStat* Parser::parseTypeFunction(const Location& start)
{
Lexeme matchFn = lexer.current();
nextLexeme();
// parse the name of the type function
std::optional<Name> fnName = parseNameOpt("type function name");
if (!fnName)
fnName = Name(nameError, lexer.current().location);
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
AstExprFunction* body = parseFunctionBody(/* hasself */ false, matchFn, fnName->name, nullptr, AstArray<AstAttr*>({nullptr, 0})).first;
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
return allocator.alloc<AstStatTypeFunction>(Location(start, body->location), fnName->name, fnName->location, body);
}
AstDeclaredClassProp Parser::parseDeclaredClassMethod()
{
Location start;
@ -940,8 +979,12 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr)
{
return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{},
reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true};
return AstDeclaredClassProp{
fnName.name,
FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{},
reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"),
true
};
}
// Skip the first index.
@ -959,10 +1002,16 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
report(start, "All declaration parameters aside from 'self' must be annotated");
AstType* fnType = allocator.alloc<AstTypeFunction>(
Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes);
Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes
);
return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, fnType, true,
FFlag::LuauDeclarationExtraPropData ? Location(start, end) : Location{}};
return AstDeclaredClassProp{
fnName.name,
FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{},
fnType,
true,
FFlag::LuauDeclarationExtraPropData ? Location(start, end) : Location{}
};
}
AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes)
@ -970,8 +1019,13 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
// `declare` token is already parsed at this point
if ((attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction))
return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s instead",
lexer.current().toString().c_str());
return reportStatError(
lexer.current().location,
{},
{},
"Expected a function type declaration after attribute, but got %s instead",
lexer.current().toString().c_str()
);
if (lexer.current().type == Lexeme::ReservedFunction)
{
@ -1014,11 +1068,33 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated");
if (FFlag::LuauDeclarationExtraPropData)
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, globalName.location, generics,
genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), vararg, varargLocation, retTypes);
return allocator.alloc<AstStatDeclareFunction>(
Location(start, end),
attributes,
globalName.name,
globalName.location,
generics,
genericPacks,
AstTypeList{copy(vars), varargAnnotation},
copy(varNames),
vararg,
varargLocation,
retTypes
);
else
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, Location{}, generics, genericPacks,
AstTypeList{copy(vars), varargAnnotation}, copy(varNames), false, Location{}, retTypes);
return allocator.alloc<AstStatDeclareFunction>(
Location(start, end),
attributes,
globalName.name,
Location{},
generics,
genericPacks,
AstTypeList{copy(vars), varargAnnotation},
copy(varNames),
false,
Location{},
retTypes
);
}
else if (AstName(lexer.current().name) == "class")
{
@ -1064,7 +1140,8 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
if (chars && !containsNull)
props.push_back(AstDeclaredClassProp{
AstName(chars->data), Location(nameBegin, nameEnd), type, false, Location(begin.location, lexer.previousLocation())});
AstName(chars->data), Location(nameBegin, nameEnd), type, false, Location(begin.location, lexer.previousLocation())
});
else
report(begin.location, "String literal contains malformed escape sequence or \\0");
}
@ -1107,8 +1184,8 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
Name propName = parseName("property name");
expectAndConsume(':', "property type annotation");
AstType* propType = parseType();
props.push_back(
AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())});
props.push_back(AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())}
);
}
else
{
@ -1130,7 +1207,8 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
AstType* type = parseType(/* in declaration context */ true);
return allocator.alloc<AstStatDeclareGlobal>(
Location(start, type->location), globalName->name, FFlag::LuauDeclarationExtraPropData ? globalName->location : Location{}, type);
Location(start, type->location), globalName->name, FFlag::LuauDeclarationExtraPropData ? globalName->location : Location{}, type
);
}
else
{
@ -1205,7 +1283,12 @@ std::pair<AstLocal*, AstArray<AstLocal*>> Parser::prepareFunctionArguments(const
// funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end
// parlist ::= bindinglist [`,' `...'] | `...'
std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes)
bool hasself,
const Lexeme& matchFunction,
const AstName& debugname,
const Name* localName,
const AstArray<AstAttr*>& attributes
)
{
Location start = matchFunction.location;
@ -1257,9 +1340,25 @@ std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction);
body->hasEnd = hasEnd;
return {allocator.alloc<AstExprFunction>(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body,
functionStack.size(), debugname, typelist, varargAnnotation, argLocation),
funLocal};
return {
allocator.alloc<AstExprFunction>(
Location(start, end),
attributes,
generics,
genericPacks,
self,
vars,
vararg,
varargLocation,
body,
functionStack.size(),
debugname,
typelist,
varargAnnotation,
argLocation
),
funLocal
};
}
// explist ::= {exp `,'} exp
@ -1656,9 +1755,15 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray<AstAttr*>
return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
}
AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation)
AstType* Parser::parseFunctionTypeTail(
const Lexeme& begin,
const AstArray<AstAttr*>& attributes,
AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*> params,
AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation
)
{
incrementRecursionCounter("type annotation");
@ -1683,7 +1788,8 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAt
AstTypeList paramTypes = AstTypeList{params, varargAnnotation};
return allocator.alloc<AstTypeFunction>(
Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList);
Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList
);
}
// Type ::=
@ -1760,8 +1866,11 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
if (isUnion && isIntersection)
{
return reportTypeError(Location(begin, parts.back()->location), copy(parts),
"Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
return reportTypeError(
Location(begin, parts.back()->location),
copy(parts),
"Mixing union and intersection types is not allowed; consider wrapping in parentheses."
);
}
location.end = parts.back()->location.end;
@ -1922,7 +2031,8 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
Location end = lexer.previousLocation();
return {
allocator.alloc<AstTypeReference>(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {}};
allocator.alloc<AstTypeReference>(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {}
};
}
else if (lexer.current().type == '{')
{
@ -1936,10 +2046,15 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
{
nextLexeme();
return {reportTypeError(start, {},
"Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> "
"...any'"),
{}};
return {
reportTypeError(
start,
{},
"Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> "
"...any'"
),
{}
};
}
else
{
@ -2114,8 +2229,7 @@ std::optional<AstExprBinary::Op> Parser::checkBinaryConfusables(const BinaryOpPr
report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?");
return AstExprBinary::Or;
}
else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin &&
binaryPriority[AstExprBinary::CompareNe].left > limit)
else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::CompareNe].left > limit)
{
nextLexeme();
report(Location(start, next.location), "Unexpected '!='; did you mean '~='?");
@ -2129,6 +2243,7 @@ std::optional<AstExprBinary::Op> Parser::checkBinaryConfusables(const BinaryOpPr
// where `binop' is any binary operator with a priority higher than `limit'
AstExpr* Parser::parseExpr(unsigned int limit)
{
// clang-format off
static const BinaryOpPriority binaryPriority[] = {
{6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `//' `%'
{10, 9}, {5, 4}, // power and concat (right associative)
@ -2136,6 +2251,8 @@ AstExpr* Parser::parseExpr(unsigned int limit)
{3, 3}, {3, 3}, {3, 3}, {3, 3}, // order
{2, 2}, {1, 1} // logical (and/or)
};
// clang-format on
static_assert(sizeof(binaryPriority) / sizeof(binaryPriority[0]) == size_t(AstExprBinary::Op__Count), "binaryPriority needs an entry per op");
unsigned int oldRecursionCount = recursionCounter;
@ -2414,7 +2531,8 @@ AstExpr* Parser::parseSimpleExpr()
if (lexer.current().type != Lexeme::ReservedFunction)
{
return reportExprError(
start, {}, "Expected 'function' declaration after attribute, but got %s instead", lexer.current().toString().c_str());
start, {}, "Expected 'function' declaration after attribute, but got %s instead", lexer.current().toString().c_str()
);
}
}
@ -2447,8 +2565,7 @@ AstExpr* Parser::parseSimpleExpr()
{
return parseNumber();
}
else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString ||
lexer.current().type == Lexeme::InterpStringSimple)
else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple)
{
return parseString();
}
@ -2548,15 +2665,22 @@ LUAU_NOINLINE AstExpr* Parser::reportFunctionArgsError(AstExpr* func, bool self)
}
else
{
return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}),
"Expected '(', '{' or <string> when parsing function call, got %s", lexer.current().toString().c_str());
return reportExprError(
Location(func->location.begin, lexer.current().location.begin),
copy({func}),
"Expected '(', '{' or <string> when parsing function call, got %s",
lexer.current().toString().c_str()
);
}
}
LUAU_NOINLINE void Parser::reportAmbiguousCallError()
{
report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of "
"new statement; use ';' to separate statements");
report(
lexer.current().location,
"Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of "
"new statement; use ';' to separate statements"
);
}
// tableconstructor ::= `{' [fieldlist] `}'
@ -2868,8 +2992,10 @@ AstArray<AstTypeOrPack> Parser::parseTypeParams()
std::optional<AstArray<char>> Parser::parseCharArray()
{
LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString ||
lexer.current().type == Lexeme::InterpStringSimple);
LUAU_ASSERT(
lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString ||
lexer.current().type == Lexeme::InterpStringSimple
);
scratchData.assign(lexer.current().data, lexer.current().getLength());
@ -2911,8 +3037,10 @@ AstExpr* Parser::parseInterpString()
do
{
Lexeme currentLexeme = lexer.current();
LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid ||
currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple);
LUAU_ASSERT(
currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid ||
currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple
);
endLocation = currentLexeme.location;
@ -3013,7 +3141,8 @@ AstLocal* Parser::pushLocal(const Binding& binding)
AstLocal*& local = localMap[name.name];
local = allocator.alloc<AstLocal>(
name.name, name.location, /* shadow= */ local, functionStack.size() - 1, functionStack.back().loopDepth, binding.annotation);
name.name, name.location, /* shadow= */ local, functionStack.size() - 1, functionStack.back().loopDepth, binding.annotation
);
localStack.push_back(local);
@ -3146,11 +3275,25 @@ LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const Ma
std::string matchString = Lexeme(Location(Position(0, 0), 0), begin.type).toString();
if (lexer.current().location.begin.line == begin.position.line)
report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), matchString.c_str(),
begin.position.column + 1, lexer.current().toString().c_str(), extra ? extra : "");
report(
lexer.current().location,
"Expected %s (to close %s at column %d), got %s%s",
typeString.c_str(),
matchString.c_str(),
begin.position.column + 1,
lexer.current().toString().c_str(),
extra ? extra : ""
);
else
report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), matchString.c_str(),
begin.position.line + 1, lexer.current().toString().c_str(), extra ? extra : "");
report(
lexer.current().location,
"Expected %s (to close %s at line %d), got %s%s",
typeString.c_str(),
matchString.c_str(),
begin.position.line + 1,
lexer.current().toString().c_str(),
extra ? extra : ""
);
}
bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin)
@ -3287,7 +3430,12 @@ LUAU_NOINLINE void Parser::reportNameError(const char* context)
}
AstStatError* Parser::reportStatError(
const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements, const char* format, ...)
const Location& location,
const AstArray<AstExpr*>& expressions,
const AstArray<AstStat*>& statements,
const char* format,
...
)
{
va_list args;
va_start(args, format);
@ -3359,4 +3507,4 @@ void Parser::nextLexeme()
}
}
} // namespace Luau
} // namespace Luau

View file

@ -141,7 +141,8 @@ size_t editDistance(std::string_view a, std::string_view b)
size_t maxDistance = a.size() + b.size();
std::vector<size_t> distances((a.size() + 2) * (b.size() + 2), 0);
auto getPos = [b](size_t x, size_t y) -> size_t {
auto getPos = [b](size_t x, size_t y) -> size_t
{
return (x * (b.size() + 2)) + y;
};

View file

@ -184,8 +184,14 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Ev
Token& token = context.tokens[ev.token];
formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category,
ev.data.microsec, threadId);
formatAppend(
temp,
R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)",
token.name,
token.category,
ev.data.microsec,
threadId
);
unfinishedEnter = true;
}
break;
@ -201,10 +207,13 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Ev
unfinishedEnter = false;
}
formatAppend(temp,
formatAppend(
temp,
R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)"
"\n",
ev.data.microsec, threadId);
ev.data.microsec,
threadId
);
break;
case EventType::ArgName:
LUAU_ASSERT(unfinishedEnter);

View file

@ -64,8 +64,13 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con
if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data))
report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str());
else
report(format, humanReadableName.c_str(), error.location, "TypeError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str());
report(
format,
humanReadableName.c_str(),
error.location,
"TypeError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()
);
}
static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning)
@ -235,9 +240,12 @@ struct TaskScheduler
{
for (unsigned i = 0; i < threadCount; i++)
{
workers.emplace_back([this] {
workerFunction();
});
workers.emplace_back(
[this]
{
workerFunction();
}
);
}
}
@ -254,9 +262,13 @@ struct TaskScheduler
{
std::unique_lock guard(mtx);
cv.wait(guard, [this] {
return !tasks.empty();
});
cv.wait(
guard,
[this]
{
return !tasks.empty();
}
);
std::function<void()> task = tasks.front();
tasks.pop();
@ -351,7 +363,8 @@ int main(int argc, char** argv)
if (FFlag::DebugLuauLogSolverToJsonFile)
{
frontend.writeJsonLog = [&basePath](const Luau::ModuleName& moduleName, std::string log) {
frontend.writeJsonLog = [&basePath](const Luau::ModuleName& moduleName, std::string log)
{
std::string path = moduleName + ".log.json";
size_t pos = moduleName.find_last_of('/');
if (pos != std::string::npos)
@ -390,9 +403,13 @@ int main(int argc, char** argv)
{
TaskScheduler scheduler(threadCount);
checkedModules = frontend.checkQueuedModules(std::nullopt, [&](std::function<void()> f) {
scheduler.push(std::move(f));
});
checkedModules = frontend.checkQueuedModules(
std::nullopt,
[&](std::function<void()> f)
{
scheduler.push(std::move(f));
}
);
}
catch (const Luau::InternalCompilerError& ice)
{
@ -403,8 +420,13 @@ int main(int argc, char** argv)
Luau::TypeError error(location, moduleName, Luau::InternalError{ice.message});
report(format, humanReadableName.c_str(), location, "InternalCompilerError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str());
report(
format,
humanReadableName.c_str(),
location,
"InternalCompilerError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()
);
return 1;
}

View file

@ -231,7 +231,10 @@ static void serializeScriptSummary(const std::string& file, const std::vector<Fu
}
static bool serializeSummaries(
const std::vector<std::string>& files, const std::vector<std::vector<FunctionBytecodeSummary>>& scriptSummaries, const std::string& summaryFile)
const std::vector<std::string>& files,
const std::vector<std::vector<FunctionBytecodeSummary>>& scriptSummaries,
const std::string& summaryFile
)
{
FILE* fp = fopen(summaryFile.c_str(), "w");

View file

@ -108,7 +108,11 @@ static void reportError(const char* name, const Luau::CompileError& error)
}
static std::string getCodegenAssembly(
const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options, Luau::CodeGen::LoweringStats* stats)
const char* name,
const std::string& bytecode,
Luau::CodeGen::AssemblyOptions options,
Luau::CodeGen::LoweringStats* stats
)
{
std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get();
@ -326,8 +330,10 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
if (format == CompileFormat::Text)
{
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks | Luau::BytecodeBuilder::Dump_Types);
bcb.setDumpFlags(
Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks | Luau::BytecodeBuilder::Dump_Types
);
bcb.setDumpSource(*source);
}
else if (format == CompileFormat::Remarks)
@ -335,11 +341,12 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source);
}
else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr ||
format == CompileFormat::CodegenVerbose)
else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || format == CompileFormat::CodegenVerbose)
{
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpFlags(
Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks
);
bcb.setDumpSource(*source);
}
@ -623,19 +630,37 @@ int main(int argc, char** argv)
if (compileFormat == CompileFormat::Null)
{
printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), int(stats.bytecode / 1024),
stats.readTime, stats.parseTime, stats.compileTime);
printf(
"Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n",
int(stats.lines / 1000),
int(stats.bytecode / 1024),
stats.readTime,
stats.parseTime,
stats.compileTime
);
}
else if (compileFormat == CompileFormat::CodegenNull)
{
printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n",
int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024),
stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime,
stats.codegenTime);
printf(
"Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n",
int(stats.lines / 1000),
int(stats.bytecode / 1024),
int(stats.codegen / 1024),
stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode),
stats.readTime,
stats.parseTime,
stats.compileTime,
stats.codegenTime
);
printf("Lowering: regalloc failed: %d, lowering failed %d; spills to stack: %d, spills to restore: %d, max spill slot %u\n",
stats.lowerStats.regAllocErrors, stats.lowerStats.loweringErrors, stats.lowerStats.spillsToSlot, stats.lowerStats.spillsToRestore,
stats.lowerStats.maxSpillSlotsUsed);
printf(
"Lowering: regalloc failed: %d, lowering failed %d; spills to stack: %d, spills to restore: %d, max spill slot %u\n",
stats.lowerStats.regAllocErrors,
stats.lowerStats.loweringErrors,
stats.lowerStats.spillsToSlot,
stats.lowerStats.spillsToRestore,
stats.lowerStats.maxSpillSlotsUsed
);
}
if (recordStats != RecordStats::None)

View file

@ -442,12 +442,16 @@ std::vector<std::string> getSourceFiles(int argc, char** argv)
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
std::string ext = getExtension(name);
traverseDirectory(
argv[i],
[&](const std::string& name)
{
std::string ext = getExtension(name);
if (ext == ".lua" || ext == ".luau")
files.push_back(name);
});
if (ext == ".lua" || ext == ".luau")
files.push_back(name);
}
);
}
else
{

View file

@ -54,8 +54,9 @@ void setLuauFlags(const char* list)
else if (value == "false" || value == "False")
setLuauFlag(key, false);
else
fprintf(stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()),
key.data());
fprintf(
stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()), key.data()
);
}
else
{

View file

@ -131,8 +131,13 @@ void profilerDump(const char* path)
fclose(f);
printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", path, double(total) / 1e6,
static_cast<long long>(gProfiler.samples.load()), static_cast<long long>(gProfiler.data.size()));
printf(
"Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n",
path,
double(total) / 1e6,
static_cast<long long>(gProfiler.samples.load()),
static_cast<long long>(gProfiler.data.size())
);
uint64_t totalgc = 0;
for (uint64_t p : gProfiler.gc)

View file

@ -184,7 +184,8 @@ struct Reducer
{
std::vector<AstStat*> result;
auto append = [&](AstStatBlock* block) {
auto append = [&](AstStatBlock* block)
{
if (block)
result.insert(result.end(), block->body.data, block->body.data + block->body.size);
};
@ -250,7 +251,8 @@ struct Reducer
std::vector<std::pair<Span, Span>> result;
auto append = [&result](Span a, Span b) {
auto append = [&result](Span a, Span b)
{
if (a.first == a.second && b.first == b.second)
return;
else

View file

@ -388,8 +388,13 @@ static void safeGetTable(lua_State* L, int tableIndex)
// completePartialMatches finds keys that match the specified 'prefix'
// Note: the table/object to be searched must be on the top of the Lua stack
static void completePartialMatches(lua_State* L, bool completeOnlyFunctions, const std::string& editBuffer, std::string_view prefix,
const AddCompletionCallback& addCompletionCallback)
static void completePartialMatches(
lua_State* L,
bool completeOnlyFunctions,
const std::string& editBuffer,
std::string_view prefix,
const AddCompletionCallback& addCompletionCallback
)
{
for (int i = 0; i < MaxTraversalLimit && lua_istable(L, -1); i++)
{
@ -483,9 +488,14 @@ static void icGetCompletions(ic_completion_env_t* cenv, const char* editBuffer)
{
auto* L = reinterpret_cast<lua_State*>(ic_completion_arg(cenv));
getCompletions(L, std::string(editBuffer), [cenv](const std::string& completion, const std::string& display) {
ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr);
});
getCompletions(
L,
std::string(editBuffer),
[cenv](const std::string& completion, const std::string& display)
{
ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr);
}
);
}
static bool isMethodOrFunctionChar(const char* s, long len)
@ -788,9 +798,13 @@ int replMain(int argc, char** argv)
// note, there's no need to close the log explicitly as it will be closed when the process exits
FILE* codegenPerfLog = fopen(path, "w");
Luau::CodeGen::setPerfLog(codegenPerfLog, [](void* context, uintptr_t addr, unsigned size, const char* symbol) {
fprintf(static_cast<FILE*>(context), "%016lx %08x %s\n", long(addr), size, symbol);
});
Luau::CodeGen::setPerfLog(
codegenPerfLog,
[](void* context, uintptr_t addr, unsigned size, const char* symbol)
{
fprintf(static_cast<FILE*>(context), "%016lx %08x %s\n", long(addr), size, symbol);
}
);
#else
fprintf(stderr, "--codegen-perf option is only supported on Linux\n");
return 1;

View file

@ -223,9 +223,15 @@ void RequireResolver::substituteAliasIfPresent(std::string& path)
std::optional<std::string> RequireResolver::getAlias(std::string alias)
{
std::transform(alias.begin(), alias.end(), alias.begin(), [](unsigned char c) {
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
});
std::transform(
alias.begin(),
alias.end(),
alias.begin(),
[](unsigned char c)
{
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
}
);
while (!config.aliases.count(alias) && !isConfigFullyResolved)
{
parseNextConfig();

View file

@ -212,8 +212,19 @@ public:
private:
// Instruction archetypes
void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev,
uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg);
void placeBinary(
const char* name,
OperandX64 lhs,
OperandX64 rhs,
uint8_t codeimm8,
uint8_t codeimm,
uint8_t codeimmImm8,
uint8_t code8rev,
uint8_t coderev,
uint8_t code8,
uint8_t code,
uint8_t opreg
);
void placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code, uint8_t codeImm8, uint8_t opreg);
void placeBinaryRegAndRegMem(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code);
void placeBinaryRegMemAndReg(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code);
@ -228,7 +239,16 @@ private:
void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix);
void placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix);
void placeAvx(
const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t imm8, uint8_t code, bool setW, uint8_t mode, uint8_t prefix);
const char* name,
OperandX64 dst,
OperandX64 src1,
OperandX64 src2,
uint8_t imm8,
uint8_t code,
bool setW,
uint8_t mode,
uint8_t prefix
);
// Instruction components
void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs, int32_t extraCodeBytes = 0);

View file

@ -25,7 +25,14 @@ struct CodeAllocator
// To allow allocation while previously allocated code is already running, allocation has page granularity
// It's important to group functions together so that page alignment won't result in a lot of wasted space
bool allocate(
const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart);
const uint8_t* data,
size_t dataSize,
const uint8_t* code,
size_t codeSize,
uint8_t*& result,
size_t& resultSize,
uint8_t*& resultCodeStart
);
// Provided to unwind info callbacks
void* context = nullptr;

View file

@ -77,8 +77,8 @@ struct IrOp;
using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength);
using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostVectorNamecallHandler = bool (*)(
IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
using HostVectorNamecallHandler =
bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
enum class HostMetamethod
{
@ -99,12 +99,21 @@ enum class HostMetamethod
using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength);
using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method);
using HostUserdataAccessHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostUserdataMetamethodHandler = bool (*)(
IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos);
using HostUserdataAccessHandler =
bool (*)(IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostUserdataMetamethodHandler =
bool (*)(IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos);
using HostUserdataNamecallHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
IrBuilder& builder,
uint8_t type,
const char* member,
size_t memberLength,
int argResReg,
int sourceReg,
int params,
int results,
int pcpos
);
struct HostIrHooks
{
@ -196,7 +205,11 @@ using UniqueSharedCodeGenContext = std::unique_ptr<SharedCodeGenContext, SharedC
[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext);
[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(
size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext);
size_t blockSize,
size_t maxTotalSize,
AllocationCallback* allocationCallback,
void* allocationCallbackContext
);
// Destroys the provided SharedCodeGenContext. All Luau VMs using the
// SharedCodeGenContext must be destroyed before this function is called.

View file

@ -135,7 +135,11 @@ struct IdfContext
// 'Iterated' comes from the definition where we recompute the IDFn+1 = DF(S) while adding IDFn to S until a fixed point is reached
// Iterated dominance frontier has been shown to be equal to the set of nodes where phi instructions have to be inserted
void computeIteratedDominanceFrontierForDefs(
IdfContext& ctx, const IrFunction& function, const std::vector<uint32_t>& defBlocks, const std::vector<uint32_t>& liveInBlocks);
IdfContext& ctx,
const IrFunction& function,
const std::vector<uint32_t>& defBlocks,
const std::vector<uint32_t>& liveInBlocks
);
// Function used to update all CFG data
void computeCfgInfo(IrFunction& function);

View file

@ -36,9 +36,21 @@ const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes);
void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes);
void toStringDetailed(
IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, IncludeUseInfo includeUseInfo);
void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, IncludeUseInfo includeUseInfo, IncludeCfgInfo includeCfgInfo,
IncludeRegFlowInfo includeRegFlowInfo);
IrToStringContext& ctx,
const IrBlock& block,
uint32_t blockIdx,
const IrInst& inst,
uint32_t instIdx,
IncludeUseInfo includeUseInfo
);
void toStringDetailed(
IrToStringContext& ctx,
const IrBlock& block,
uint32_t blockIdx,
IncludeUseInfo includeUseInfo,
IncludeCfgInfo includeCfgInfo,
IncludeRegFlowInfo includeRegFlowInfo
);
std::string toString(const IrFunction& function, IncludeUseInfo includeUseInfo);

View file

@ -42,8 +42,12 @@ class SharedCodeAllocator;
class NativeModule
{
public:
NativeModule(SharedCodeAllocator* allocator, const std::optional<ModuleId>& moduleId, const uint8_t* moduleBaseAddress,
std::vector<NativeProtoExecDataPtr> nativeProtos) noexcept;
NativeModule(
SharedCodeAllocator* allocator,
const std::optional<ModuleId>& moduleId,
const uint8_t* moduleBaseAddress,
std::vector<NativeProtoExecDataPtr> nativeProtos
) noexcept;
NativeModule(const NativeModule&) = delete;
NativeModule(NativeModule&&) = delete;
@ -132,11 +136,22 @@ public:
// data and code such that it can be executed). Like std::map::insert, the
// bool result is true if a new module was created; false if an existing
// module is being returned.
std::pair<NativeModuleRef, bool> getOrInsertNativeModule(const ModuleId& moduleId, std::vector<NativeProtoExecDataPtr> nativeProtos,
const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize);
std::pair<NativeModuleRef, bool> getOrInsertNativeModule(
const ModuleId& moduleId,
std::vector<NativeProtoExecDataPtr> nativeProtos,
const uint8_t* data,
size_t dataSize,
const uint8_t* code,
size_t codeSize
);
NativeModuleRef insertAnonymousNativeModule(
std::vector<NativeProtoExecDataPtr> nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize);
std::vector<NativeProtoExecDataPtr> nativeProtos,
const uint8_t* data,
size_t dataSize,
const uint8_t* code,
size_t codeSize
);
// If a NativeModule exists for the given ModuleId and that NativeModule
// is no longer referenced, the NativeModule is destroyed. This should

View file

@ -49,8 +49,13 @@ public:
// mov rbp, rsp
// push reg in the order specified in regs
// sub rsp, stackSize
virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd) = 0;
virtual void prologueX64(
uint32_t prologueSize,
uint32_t stackSize,
bool setupFrame,
std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd
) = 0;
virtual size_t getUnwindInfoSize(size_t blockSize) const = 0;

View file

@ -30,8 +30,13 @@ public:
void finishInfo() override;
void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override;
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd) override;
void prologueX64(
uint32_t prologueSize,
uint32_t stackSize,
bool setupFrame,
std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd
) override;
size_t getUnwindInfoSize(size_t blockSize = 0) const override;

View file

@ -50,8 +50,13 @@ public:
void finishInfo() override;
void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override;
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd) override;
void prologueX64(
uint32_t prologueSize,
uint32_t stackSize,
bool setupFrame,
std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd
) override;
size_t getUnwindInfoSize(size_t blockSize = 0) const override;

View file

@ -17,8 +17,8 @@ namespace A64
static const uint8_t codeForCondition[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
static const char* textForCondition[] = {
"b.eq", "b.ne", "b.cs", "b.cc", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "b.al"};
static const char* textForCondition[] =
{"b.eq", "b.ne", "b.cs", "b.cc", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "b.al"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
const unsigned kMaxAlign = 32;
@ -968,8 +968,10 @@ void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | ((shift < 0 ? -shift : shift) << 10) | (src2.index << 16) | (N << 21) | (int(shift < 0) << 22) |
(op << 24) | sf);
place(
dst.index | (src1.index << 5) | ((shift < 0 ? -shift : shift) << 10) | (src2.index << 16) | (N << 21) | (int(shift < 0) << 22) | (op << 24) |
sf
);
commit();
}
@ -1173,7 +1175,15 @@ void AssemblyBuilderA64::placeP(const char* name, RegisterA64 src1, RegisterA64
}
void AssemblyBuilderA64::placeCS(
const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc, int invert)
const char* name,
RegisterA64 dst,
RegisterA64 src1,
RegisterA64 src2,
ConditionA64 cond,
uint8_t op,
uint8_t opc,
int invert
)
{
if (logText)
log(name, dst, src1, src2, cond);

View file

@ -15,21 +15,22 @@ namespace X64
// TODO: more assertions on operand sizes
static const uint8_t codeForCondition[] = {
0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd, 0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb};
static const uint8_t codeForCondition[] = {0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd,
0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
static const char* jccTextForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna",
"jnae", "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
static const char* jccTextForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge",
"jnb", "jnbe", "jna", "jnae", "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
static_assert(sizeof(jccTextForCondition) / sizeof(jccTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
static const char* setccTextForCondition[] = {"seto", "setno", "setc", "setnc", "setb", "setbe", "seta", "setae", "sete", "setl", "setle", "setg",
"setge", "setnb", "setnbe", "setna", "setnae", "setne", "setnl", "setnle", "setng", "setnge", "setz", "setnz", "setp", "setnp"};
static const char* setccTextForCondition[] = {"seto", "setno", "setc", "setnc", "setb", "setbe", "seta", "setae", "sete",
"setl", "setle", "setg", "setge", "setnb", "setnbe", "setna", "setnae", "setne",
"setnl", "setnle", "setng", "setnge", "setz", "setnz", "setp", "setnp"};
static_assert(sizeof(setccTextForCondition) / sizeof(setccTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
static const char* cmovTextForCondition[] = {"cmovo", "cmovno", "cmovc", "cmovnc", "cmovb", "cmovbe", "cmova", "cmovae", "cmove", "cmovl", "cmovle",
"cmovg", "cmovge", "cmovnb", "cmovnbe", "cmovna", "cmovnae", "cmovne", "cmovnl", "cmovnle", "cmovng", "cmovnge", "cmovz", "cmovnz", "cmovp",
"cmovnp"};
static const char* cmovTextForCondition[] = {"cmovo", "cmovno", "cmovc", "cmovnc", "cmovb", "cmovbe", "cmova", "cmovae", "cmove",
"cmovl", "cmovle", "cmovg", "cmovge", "cmovnb", "cmovnbe", "cmovna", "cmovnae", "cmovne",
"cmovnl", "cmovnle", "cmovng", "cmovnge", "cmovz", "cmovnz", "cmovp", "cmovnp"};
static_assert(sizeof(cmovTextForCondition) / sizeof(cmovTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
#define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7))
@ -50,8 +51,8 @@ static_assert(sizeof(cmovTextForCondition) / sizeof(cmovTextForCondition[0]) ==
#define AVX_3_2(r, x, b, m) (AVX_R(r) | AVX_X(x) | AVX_B(b) | (m))
#define AVX_3_3(w, v, l, p) (AVX_W(w) | ((~(v.index) & 0xf) << 3) | ((l) << 2) | (p))
#define MOD_RM(mod, reg, rm) (((mod) << 6) | (((reg)&0x7) << 3) | ((rm)&0x7))
#define SIB(scale, index, base) ((getScaleEncoding(scale) << 6) | (((index)&0x7) << 3) | ((base)&0x7))
#define MOD_RM(mod, reg, rm) (((mod) << 6) | (((reg) & 0x7) << 3) | ((rm) & 0x7))
#define SIB(scale, index, base) ((getScaleEncoding(scale) << 6) | (((index) & 0x7) << 3) | ((base) & 0x7))
const unsigned AVX_0F = 0b0001;
[[maybe_unused]] const unsigned AVX_0F38 = 0b0010;
@ -1136,8 +1137,19 @@ unsigned AssemblyBuilderX64::getInstructionCount() const
return instructionCount;
}
void AssemblyBuilderX64::placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8,
uint8_t code8rev, uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg)
void AssemblyBuilderX64::placeBinary(
const char* name,
OperandX64 lhs,
OperandX64 rhs,
uint8_t codeimm8,
uint8_t codeimm,
uint8_t codeimmImm8,
uint8_t code8rev,
uint8_t coderev,
uint8_t code8,
uint8_t code,
uint8_t opreg
)
{
if (logText)
log(name, lhs, rhs);
@ -1292,7 +1304,15 @@ void AssemblyBuilderX64::placeAvx(const char* name, OperandX64 dst, OperandX64 s
}
void AssemblyBuilderX64::placeAvx(
const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix)
const char* name,
OperandX64 dst,
OperandX64 src,
uint8_t code,
uint8_t coderev,
bool setW,
uint8_t mode,
uint8_t prefix
)
{
CODEGEN_ASSERT((dst.cat == CategoryX64::mem && src.cat == CategoryX64::reg) || (dst.cat == CategoryX64::reg && src.cat == CategoryX64::mem));
@ -1316,7 +1336,15 @@ void AssemblyBuilderX64::placeAvx(
}
void AssemblyBuilderX64::placeAvx(
const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix)
const char* name,
OperandX64 dst,
OperandX64 src1,
OperandX64 src2,
uint8_t code,
bool setW,
uint8_t mode,
uint8_t prefix
)
{
CODEGEN_ASSERT(dst.cat == CategoryX64::reg);
CODEGEN_ASSERT(src1.cat == CategoryX64::reg);
@ -1332,8 +1360,8 @@ void AssemblyBuilderX64::placeAvx(
commit();
}
void AssemblyBuilderX64::placeAvx(
const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t imm8, uint8_t code, bool setW, uint8_t mode, uint8_t prefix)
void AssemblyBuilderX64::
placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t imm8, uint8_t code, bool setW, uint8_t mode, uint8_t prefix)
{
CODEGEN_ASSERT(dst.cat == CategoryX64::reg);
CODEGEN_ASSERT(src1.cat == CategoryX64::reg);
@ -1735,13 +1763,15 @@ const char* AssemblyBuilderX64::getSizeName(SizeX64 size) const
const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg) const
{
static const char* names[][16] = {{"rip", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""},
static const char* names[][16] = {
{"rip", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""},
{"al", "cl", "dl", "bl", "spl", "bpl", "sil", "dil", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b"},
{"ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w"},
{"eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d"},
{"rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15"},
{"xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15"},
{"ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15"}};
{"ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15"}
};
CODEGEN_ASSERT(reg.index < 16);
CODEGEN_ASSERT(reg.size <= SizeX64::ymmword);

View file

@ -116,12 +116,17 @@ void loadBytecodeTypeInfo(IrFunction& function)
static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo)
{
// Sort by register first, then by end PC
std::sort(typeInfo.regTypes.begin(), typeInfo.regTypes.end(), [](const BytecodeRegTypeInfo& a, const BytecodeRegTypeInfo& b) {
if (a.reg != b.reg)
return a.reg < b.reg;
std::sort(
typeInfo.regTypes.begin(),
typeInfo.regTypes.end(),
[](const BytecodeRegTypeInfo& a, const BytecodeRegTypeInfo& b)
{
if (a.reg != b.reg)
return a.reg < b.reg;
return a.endpc < b.endpc;
});
return a.endpc < b.endpc;
}
);
// Prepare data for all registers as 'regTypes' might be missing temporaries
typeInfo.regTypeOffsets.resize(256 + 1);
@ -805,8 +810,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
@ -837,8 +841,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
@ -860,8 +863,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
@ -883,8 +885,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
@ -915,8 +916,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
@ -938,8 +938,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
@ -960,8 +959,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
@ -990,8 +988,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}

View file

@ -143,7 +143,14 @@ CodeAllocator::~CodeAllocator()
}
bool CodeAllocator::allocate(
const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart)
const uint8_t* data,
size_t dataSize,
const uint8_t* code,
size_t codeSize,
uint8_t*& result,
size_t& resultSize,
uint8_t*& resultCodeStart
)
{
// 'Round up' to preserve code alignment
size_t alignedDataSize = (dataSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1);

Some files were not shown because too many files have changed in this diff Show more