From 1e7b23fbfc3a8681f867d613bb4845db83b3715f Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 24 Feb 2023 10:24:22 -0800 Subject: [PATCH] Sync to upstream/release/565 --- Analysis/include/Luau/Constraint.h | 26 +- .../include/Luau/ConstraintGraphBuilder.h | 27 +- Analysis/include/Luau/ConstraintSolver.h | 17 +- Analysis/include/Luau/DcrLogger.h | 41 +- Analysis/include/Luau/Scope.h | 2 +- Analysis/include/Luau/Symbol.h | 5 + Analysis/include/Luau/Type.h | 23 + Analysis/include/Luau/TypeInfer.h | 4 +- Analysis/include/Luau/TypeReduction.h | 41 +- Analysis/include/Luau/Unifier.h | 6 + Analysis/src/ConstraintGraphBuilder.cpp | 163 ++-- Analysis/src/ConstraintSolver.cpp | 358 +++++++- Analysis/src/DcrLogger.cpp | 227 +++-- Analysis/src/Frontend.cpp | 4 +- Analysis/src/Normalize.cpp | 2 + Analysis/src/Quantify.cpp | 2 +- Analysis/src/Scope.cpp | 6 +- Analysis/src/ToString.cpp | 6 + Analysis/src/TypeChecker2.cpp | 248 ++++- Analysis/src/TypeInfer.cpp | 205 +++-- Analysis/src/TypeReduction.cpp | 362 ++++---- Analysis/src/Unifier.cpp | 7 +- CodeGen/include/Luau/IrBuilder.h | 3 + CodeGen/include/Luau/IrData.h | 95 +- CodeGen/include/Luau/IrUtils.h | 9 +- CodeGen/include/Luau/OptimizeConstProp.h | 16 + CodeGen/src/CodeGen.cpp | 496 +--------- CodeGen/src/EmitBuiltinsX64.cpp | 24 +- CodeGen/src/EmitBuiltinsX64.h | 13 +- CodeGen/src/EmitCommonX64.h | 8 - CodeGen/src/EmitInstructionX64.cpp | 867 +----------------- CodeGen/src/EmitInstructionX64.h | 61 +- CodeGen/src/IrAnalysis.cpp | 2 + CodeGen/src/IrBuilder.cpp | 63 +- CodeGen/src/IrDump.cpp | 55 +- CodeGen/src/IrLoweringX64.cpp | 373 ++++---- CodeGen/src/IrLoweringX64.h | 32 +- CodeGen/src/IrRegAllocX64.cpp | 181 ++++ CodeGen/src/IrRegAllocX64.h | 51 ++ CodeGen/src/IrTranslateBuiltins.cpp | 40 + CodeGen/src/IrTranslateBuiltins.h | 27 + CodeGen/src/IrTranslation.cpp | 81 +- CodeGen/src/IrTranslation.h | 3 + CodeGen/src/IrUtils.cpp | 70 +- CodeGen/src/OptimizeConstProp.cpp | 565 ++++++++++++ CodeGen/src/OptimizeFinalX64.cpp | 5 +- Common/include/Luau/Bytecode.h | 2 +- Compiler/src/Compiler.cpp | 22 +- Sources.cmake | 6 + tests/Compiler.test.cpp | 61 +- tests/ConstraintGraphBuilderFixture.cpp | 3 +- tests/IrBuilder.test.cpp | 689 +++++++++++++- tests/Module.test.cpp | 47 +- tests/NonstrictMode.test.cpp | 2 +- tests/ToString.test.cpp | 14 +- tests/TypeInfer.aliases.test.cpp | 12 +- tests/TypeInfer.functions.test.cpp | 15 +- tests/TypeInfer.refinements.test.cpp | 34 + tests/TypeInfer.tables.test.cpp | 25 +- tests/TypeInfer.tryUnify.test.cpp | 9 +- tests/TypeInfer.unknownnever.test.cpp | 2 +- tools/faillist.txt | 64 +- 62 files changed, 3500 insertions(+), 2429 deletions(-) create mode 100644 CodeGen/include/Luau/OptimizeConstProp.h create mode 100644 CodeGen/src/IrRegAllocX64.cpp create mode 100644 CodeGen/src/IrRegAllocX64.h create mode 100644 CodeGen/src/IrTranslateBuiltins.cpp create mode 100644 CodeGen/src/IrTranslateBuiltins.h create mode 100644 CodeGen/src/OptimizeConstProp.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 65599e49..1c41bbb7 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -159,6 +159,20 @@ struct SetPropConstraint TypeId propType; }; +// result ~ setIndexer subjectType indexType propType +// +// If the subject is a table or table-like thing that already has an indexer, +// unify its indexType and propType with those from this constraint. +// +// If the table is a free or unsealed table, we augment it with a new indexer. +struct SetIndexerConstraint +{ + TypeId resultType; + TypeId subjectType; + TypeId indexType; + TypeId propType; +}; + // if negation: // result ~ if isSingleton D then ~D else unknown where D = discriminantType // if not negation: @@ -170,9 +184,19 @@ struct SingletonOrTopTypeConstraint bool negated; }; +// resultType ~ unpack sourceTypePack +// +// Similar to PackSubtypeConstraint, but with one important difference: If the +// sourcePack is blocked, this constraint blocks. +struct UnpackConstraint +{ + TypePackId resultPack; + TypePackId sourcePack; +}; + using ConstraintV = Variant; + HasPropConstraint, SetPropConstraint, SetIndexerConstraint, SingletonOrTopTypeConstraint, UnpackConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 085b6732..7b2711f8 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -191,7 +191,7 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); - TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); + std::vector checkLValues(const ScopePtr& scope, AstArray exprs); TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); @@ -244,10 +244,31 @@ struct ConstraintGraphBuilder **/ TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments); + /** + * Creates generic types given a list of AST definitions, resolving default + * types as required. + * @param scope the scope that the generics should belong to. + * @param generics the AST generics to create types for. + * @param useCache whether to use the generic type cache for the given + * scope. + * @param addTypes whether to add the types to the scope's + * privateTypeBindings map. + **/ std::vector> createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache = false); + const ScopePtr& scope, AstArray generics, bool useCache = false, bool addTypes = true); + + /** + * Creates generic type packs given a list of AST definitions, resolving + * default type packs as required. + * @param scope the scope that the generic packs should belong to. + * @param generics the AST generics to create type packs for. + * @param useCache whether to use the generic type pack cache for the given + * scope. + * @param addTypes whether to add the types to the scope's + * privateTypePackBindings map. + **/ std::vector> createGenericPacks( - const ScopePtr& scope, AstArray packs, bool useCache = false); + const ScopePtr& scope, AstArray packs, bool useCache = false, bool addTypes = true); Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index de7b3a04..62687ae4 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -8,6 +8,7 @@ #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/Type.h" +#include "Luau/TypeReduction.h" #include "Luau/Variant.h" #include @@ -19,7 +20,12 @@ struct DcrLogger; // TypeId, TypePackId, or Constraint*. It is impossible to know which, but we // never dereference this pointer. -using BlockedConstraintId = const void*; +using BlockedConstraintId = Variant; + +struct HashBlockedConstraintId +{ + size_t operator()(const BlockedConstraintId& bci) const; +}; struct ModuleResolver; @@ -47,6 +53,7 @@ struct ConstraintSolver NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; + NotNull reducer; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; @@ -65,7 +72,7 @@ struct ConstraintSolver // anything. std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map>> blocked; + std::unordered_map>, HashBlockedConstraintId> blocked; // Memoized instantiations of type aliases. DenseHashMap instantiatedAliases{{}}; @@ -78,7 +85,8 @@ struct ConstraintSolver DcrLogger* logger; explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); + ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, + DcrLogger* logger); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); @@ -112,7 +120,9 @@ struct ConstraintSolver bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); bool tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); + bool tryDispatch(const UnpackConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod @@ -123,6 +133,7 @@ struct ConstraintSolver TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); std::optional lookupTableProp(TypeId subjectType, const std::string& propName); + std::optional lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); void block(NotNull target, NotNull constraint); /** diff --git a/Analysis/include/Luau/DcrLogger.h b/Analysis/include/Luau/DcrLogger.h index 45c84c66..1e170d5b 100644 --- a/Analysis/include/Luau/DcrLogger.h +++ b/Analysis/include/Luau/DcrLogger.h @@ -4,6 +4,7 @@ #include "Luau/Constraint.h" #include "Luau/NotNull.h" #include "Luau/Scope.h" +#include "Luau/Module.h" #include "Luau/ToString.h" #include "Luau/Error.h" #include "Luau/Variant.h" @@ -34,11 +35,26 @@ struct TypeBindingSnapshot std::string typeString; }; +struct ExprTypesAtLocation +{ + Location location; + TypeId ty; + std::optional expectedTy; +}; + +struct AnnotationTypesAtLocation +{ + Location location; + TypeId resolvedTy; +}; + struct ConstraintGenerationLog { std::string source; - std::unordered_map constraintLocations; std::vector errors; + + std::vector exprTypeLocations; + std::vector annotationTypeLocations; }; struct ScopeSnapshot @@ -49,16 +65,11 @@ struct ScopeSnapshot std::vector children; }; -enum class ConstraintBlockKind -{ - TypeId, - TypePackId, - ConstraintId, -}; +using ConstraintBlockTarget = Variant>; struct ConstraintBlock { - ConstraintBlockKind kind; + ConstraintBlockTarget target; std::string stringification; }; @@ -71,16 +82,18 @@ struct ConstraintSnapshot struct BoundarySnapshot { - std::unordered_map constraints; + DenseHashMap unsolvedConstraints{nullptr}; ScopeSnapshot rootScope; + DenseHashMap typeStrings{nullptr}; }; struct StepSnapshot { - std::string currentConstraint; + const Constraint* currentConstraint; bool forced; - std::unordered_map unsolvedConstraints; + DenseHashMap unsolvedConstraints{nullptr}; ScopeSnapshot rootScope; + DenseHashMap typeStrings{nullptr}; }; struct TypeSolveLog @@ -95,8 +108,6 @@ struct TypeCheckLog std::vector errors; }; -using ConstraintBlockTarget = Variant>; - struct DcrLogger { std::string compileOutput(); @@ -104,6 +115,7 @@ struct DcrLogger void captureSource(std::string source); void captureGenerationError(const TypeError& error); void captureConstraintLocation(NotNull constraint, Location location); + void captureGenerationModule(const ModulePtr& module); void pushBlock(NotNull constraint, TypeId block); void pushBlock(NotNull constraint, TypePackId block); @@ -126,9 +138,10 @@ private: TypeSolveLog solveLog; TypeCheckLog checkLog; - ToStringOptions opts; + ToStringOptions opts{true}; std::vector snapshotBlocks(NotNull constraint); + void captureBoundaryState(BoundarySnapshot& target, const Scope* rootScope, const std::vector>& unsolvedConstraints); }; } // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index a8f83e2f..85a36fc9 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -52,7 +52,7 @@ struct Scope std::optional lookup(Symbol sym) const; std::optional lookup(DefId def) const; - std::optional> lookupEx(Symbol sym); + std::optional> lookupEx(Symbol sym); std::optional lookupType(const Name& name); std::optional lookupImportedType(const Name& moduleAlias, const Name& name); diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index 0432946c..b47554e0 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -37,6 +37,11 @@ struct Symbol AstLocal* local; AstName global; + explicit operator bool() const + { + return local != nullptr || global.value != nullptr; + } + bool operator==(const Symbol& rhs) const { if (local) diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 00e6d6c6..d009001b 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -246,6 +246,18 @@ struct WithPredicate { T type; PredicateVec predicates; + + WithPredicate() = default; + explicit WithPredicate(T type) + : type(type) + { + } + + WithPredicate(T type, PredicateVec predicates) + : type(type) + , predicates(std::move(predicates)) + { + } }; using MagicFunction = std::function>( @@ -853,4 +865,15 @@ bool hasTag(TypeId ty, const std::string& tagName); bool hasTag(const Property& prop, const std::string& tagName); bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. +/* + * Use this to change the kind of a particular type. + * + * LUAU_NOINLINE so that the calling frame doesn't have to pay the stack storage for the new variant. + */ +template +LUAU_NOINLINE T* emplaceType(Type* ty, Args&&... args) +{ + return &ty->ty.emplace(std::forward(args)...); +} + } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index d748a1f5..678bd419 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -146,10 +146,12 @@ struct TypeChecker WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExprPackHelper2( + const ScopePtr& scope, const AstExprCall& expr, TypeId selfType, TypeId actualFunctionType, TypeId functionType, TypePackId retPack); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + std::unique_ptr> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h index 0ad034a4..80a7ac59 100644 --- a/Analysis/include/Luau/TypeReduction.h +++ b/Analysis/include/Luau/TypeReduction.h @@ -12,11 +12,36 @@ namespace Luau namespace detail { template -struct ReductionContext +struct ReductionEdge { T type = nullptr; bool irreducible = false; }; + +struct TypeReductionMemoization +{ + TypeReductionMemoization() = default; + + TypeReductionMemoization(const TypeReductionMemoization&) = delete; + TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete; + + TypeReductionMemoization(TypeReductionMemoization&&) = default; + TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default; + + DenseHashMap> types{nullptr}; + DenseHashMap> typePacks{nullptr}; + + bool isIrreducible(TypeId ty); + bool isIrreducible(TypePackId tp); + + TypeId memoize(TypeId ty, TypeId reducedTy); + TypePackId memoize(TypePackId tp, TypePackId reducedTp); + + // Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C. + // Because reduction should always be transitive, A should point to C if A points to B and B points to C. + std::optional> memoizedof(TypeId ty) const; + std::optional> memoizedof(TypePackId tp) const; +}; } // namespace detail struct TypeReductionOptions @@ -42,29 +67,19 @@ struct TypeReduction std::optional reduce(TypePackId tp); std::optional reduce(const TypeFun& fun); - /// Creating a child TypeReduction will allow the parent TypeReduction to share its memoization with the child TypeReductions. - /// This is safe as long as the parent's TypeArena continues to outlive both TypeReduction memoization. - TypeReduction fork(NotNull arena, const TypeReductionOptions& opts = {}) const; - private: - const TypeReduction* parent = nullptr; - NotNull arena; NotNull builtinTypes; NotNull handle; - TypeReductionOptions options; - DenseHashMap> memoizedTypes{nullptr}; - DenseHashMap> memoizedTypePacks{nullptr}; + TypeReductionOptions options; + detail::TypeReductionMemoization memoization; // Computes an *estimated length* of the cartesian product of the given type. size_t cartesianProductSize(TypeId ty) const; bool hasExceededCartesianProductLimit(TypeId ty) const; bool hasExceededCartesianProductLimit(TypePackId tp) const; - - std::optional memoizedof(TypeId ty) const; - std::optional memoizedof(TypePackId tp) const; }; } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 988ad9c6..ebfff4c2 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -67,6 +67,12 @@ struct Unifier UnifierSharedState& sharedState; + // When the Unifier is forced to unify two blocked types (or packs), they + // get added to these vectors. The ConstraintSolver can use this to know + // when it is safe to reattempt dispatching a constraint. + std::vector blockedTypes; + std::vector blockedTypePacks; + Unifier( NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index aa605bdf..fe412632 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -320,6 +320,9 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) prepopulateGlobalScope(scope, block); visitBlockWithoutChildScope(scope, block); + + if (FFlag::DebugLuauLogSolverToJson) + logger->captureGenerationModule(module); } void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) @@ -357,13 +360,11 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true)) { initialFun.typeParams.push_back(gen); - defnScope->privateTypeBindings[name] = TypeFun{gen.ty}; } for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true)) { initialFun.typePackParams.push_back(genPack); - defnScope->privateTypePackBindings[name] = genPack.tp; } if (alias->exported) @@ -503,13 +504,13 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (j - i < packTypes.head.size()) varTypes[j] = packTypes.head[j - i]; else - varTypes[j] = freshType(scope); + varTypes[j] = arena->addType(BlockedType{}); } } std::vector tailValues{varTypes.begin() + i, varTypes.end()}; TypePackId tailPack = arena->addTypePack(std::move(tailValues)); - addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); + addConstraint(scope, local->location, UnpackConstraint{tailPack, exprPack}); } } } @@ -686,6 +687,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct Checkpoint start = checkpoint(this); FunctionSignature sig = checkFunctionSignature(scope, function->func); + std::unordered_set excludeList; + if (AstExprLocal* localName = function->name->as()) { std::optional existingFunctionTy = scope->lookup(localName->local); @@ -716,9 +719,20 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprIndexName* indexName = function->name->as()) { + Checkpoint check1 = checkpoint(this); TypeId lvalueType = checkLValue(scope, indexName); + Checkpoint check2 = checkpoint(this); + + forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) { + excludeList.insert(c.get()); + }); + // TODO figure out how to populate the location field of the table Property. - addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); + + if (get(lvalueType)) + asMutable(lvalueType)->ty.emplace(generalizedType); + else + addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); } else if (AstExprError* err = function->name->as()) { @@ -735,8 +749,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); - forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { - c->dependencies.push_back(NotNull{constraint.get()}); + forEachConstraint(start, end, this, [&c, &excludeList](const ConstraintPtr& constraint) { + if (!excludeList.count(constraint.get())) + c->dependencies.push_back(NotNull{constraint.get()}); }); addConstraint(scope, std::move(c)); @@ -763,16 +778,31 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) visitBlockWithoutChildScope(innerScope, block); } +static void bindFreeType(TypeId a, TypeId b) +{ + FreeType* af = getMutable(a); + FreeType* bf = getMutable(b); + + LUAU_ASSERT(af || bf); + + if (!bf) + asMutable(a)->ty.emplace(b); + else if (!af) + asMutable(b)->ty.emplace(a); + else if (subsumes(bf->scope, af->scope)) + asMutable(a)->ty.emplace(b); + else if (subsumes(af->scope, bf->scope)) + asMutable(b)->ty.emplace(a); +} + void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { - TypePackId varPackId = checkLValues(scope, assign->vars); - - TypePack expectedPack = extendTypePack(*arena, builtinTypes, varPackId, assign->values.size); + std::vector varTypes = checkLValues(scope, assign->vars); std::vector> expectedTypes; - expectedTypes.reserve(expectedPack.head.size()); + expectedTypes.reserve(varTypes.size()); - for (TypeId ty : expectedPack.head) + for (TypeId ty : varTypes) { ty = follow(ty); if (get(ty)) @@ -781,9 +811,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) expectedTypes.push_back(ty); } - TypePackId valuePack = checkPack(scope, assign->values, expectedTypes).tp; + TypePackId exprPack = checkPack(scope, assign->values, expectedTypes).tp; + TypePackId varPack = arena->addTypePack({varTypes}); - addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); + addConstraint(scope, assign->location, PackSubtypeConstraint{exprPack, varPack}); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) @@ -865,11 +896,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia asMutable(aliasTy)->ty.emplace(ty); std::vector typeParams; - for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true)) + for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true, /* addTypes */ false)) typeParams.push_back(tyParam.second.ty); std::vector typePackParams; - for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true)) + for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false)) typePackParams.push_back(tpParam.second.tp); addConstraint(scope, alias->type->location, @@ -1010,7 +1041,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction for (auto& [name, generic] : generics) { genericTys.push_back(generic.ty); - scope->privateTypeBindings[name] = TypeFun{generic.ty}; } std::vector genericTps; @@ -1018,7 +1048,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction for (auto& [name, generic] : genericPacks) { genericTps.push_back(generic.tp); - scope->privateTypePackBindings[name] = generic.tp; } ScopePtr funScope = scope; @@ -1161,7 +1190,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypePackId expectedArgPack = arena->freshTypePack(scope.get()); TypePackId expectedRetPack = arena->freshTypePack(scope.get()); - TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack}); + TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack, std::nullopt, call->self}); TypeId instantiatedFnType = arena->addType(BlockedType{}); addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType}); @@ -1264,7 +1293,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa // TODO: How do expectedTypes play into this? Do they? TypePackId rets = arena->addTypePack(BlockedTypePack{}); TypePackId argPack = arena->addTypePack(TypePack{args, argTail}); - FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets); + FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); NotNull fcc = addConstraint(scope, call->func->location, FunctionCallConstraint{ @@ -1457,7 +1486,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* gl Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr).ty; - TypeId result = freshType(scope); + TypeId result = arena->addType(BlockedType{}); std::optional def = dfg->getDef(indexName); if (def) @@ -1468,13 +1497,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* scope->dcrRefinements[*def] = result; } - TableType::Props props{{indexName->index.value, Property{result}}}; - const std::optional indexer; - TableType ttv{std::move(props), indexer, TypeLevel{}, scope.get(), TableState::Free}; - - TypeId expectedTableType = arena->addType(std::move(ttv)); - - addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); + addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value}); if (def) return Inference{result, refinementArena.proposition(*def, builtinTypes->truthyType)}; @@ -1589,6 +1612,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( else if (typeguard->type == "number") discriminantTy = builtinTypes->numberType; else if (typeguard->type == "boolean") + discriminantTy = builtinTypes->booleanType; + else if (typeguard->type == "thread") discriminantTy = builtinTypes->threadType; else if (typeguard->type == "table") discriminantTy = builtinTypes->tableType; @@ -1596,8 +1621,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( discriminantTy = builtinTypes->functionType; else if (typeguard->type == "userdata") { - // For now, we don't really care about being accurate with userdata if the typeguard was using typeof - discriminantTy = builtinTypes->neverType; // TODO: replace with top class type + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. + discriminantTy = builtinTypes->classType; } else if (!typeguard->isTypeof && typeguard->type == "vector") discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type @@ -1649,18 +1674,15 @@ std::tuple ConstraintGraphBuilder::checkBinary( } } -TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) +std::vector ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) { std::vector types; types.reserve(exprs.size); - for (size_t i = 0; i < exprs.size; ++i) - { - AstExpr* const expr = exprs.data[i]; + for (AstExpr* expr : exprs) types.push_back(checkLValue(scope, expr)); - } - return arena->addTypePack(std::move(types)); + return types; } /** @@ -1679,6 +1701,28 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'}; return checkLValue(scope, &synthetic); } + + // An indexer is only interesting in an lvalue-ey way if it is at the + // tail of an expression. + // + // If the indexer is not at the tail, then we are not interested in + // augmenting the lhs data structure with a new indexer. Constraint + // generation can treat it as an ordinary lvalue. + // + // eg + // + // a.b.c[1] = 44 -- lvalue + // a.b[4].c = 2 -- rvalue + + TypeId resultType = arena->addType(BlockedType{}); + TypeId subjectType = check(scope, indexExpr->expr).ty; + TypeId indexType = check(scope, indexExpr->index).ty; + TypeId propType = arena->addType(BlockedType{}); + addConstraint(scope, expr->location, SetIndexerConstraint{resultType, subjectType, indexType, propType}); + + module->astTypes[expr] = propType; + + return propType; } else if (!expr->is()) return check(scope, expr).ty; @@ -1718,7 +1762,8 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) auto lookupResult = scope->lookupEx(sym); if (!lookupResult) return check(scope, expr).ty; - const auto [subjectType, symbolScope] = std::move(*lookupResult); + const auto [subjectBinding, symbolScope] = std::move(*lookupResult); + TypeId subjectType = subjectBinding->typeId; TypeId propTy = freshType(scope); @@ -1739,14 +1784,17 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) module->astTypes[expr] = prevSegmentTy; module->astTypes[e] = updatedType; - symbolScope->bindings[sym].typeId = updatedType; - - std::optional def = dfg->getDef(sym); - if (def) + if (!subjectType->persistent) { - // This can fail if the user is erroneously trying to augment a builtin - // table like os or string. - symbolScope->dcrRefinements[*def] = updatedType; + symbolScope->bindings[sym].typeId = updatedType; + + std::optional def = dfg->getDef(sym); + if (def) + { + // This can fail if the user is erroneously trying to augment a builtin + // table like os or string. + symbolScope->dcrRefinements[*def] = updatedType; + } } return propTy; @@ -1904,13 +1952,11 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS for (const auto& [name, g] : genericDefinitions) { genericTypes.push_back(g.ty); - signatureScope->privateTypeBindings[name] = TypeFun{g.ty}; } for (const auto& [name, g] : genericPackDefinitions) { genericTypePacks.push_back(g.tp); - signatureScope->privateTypePackBindings[name] = g.tp; } // Local variable works around an odd gcc 11.3 warning: may be used uninitialized @@ -2023,15 +2069,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS actualFunction.generics = std::move(genericTypes); actualFunction.genericPacks = std::move(genericTypePacks); actualFunction.argNames = std::move(argNames); + actualFunction.hasSelf = fn->self != nullptr; TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); module->astTypes[fn] = actualFunctionType; if (expectedType && get(*expectedType)) - { - asMutable(*expectedType)->ty.emplace(actualFunctionType); - } + bindFreeType(*expectedType, actualFunctionType); return { /* signature */ actualFunctionType, @@ -2179,13 +2224,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (const auto& [name, g] : genericDefinitions) { genericTypes.push_back(g.ty); - signatureScope->privateTypeBindings[name] = TypeFun{g.ty}; } for (const auto& [name, g] : genericPackDefinitions) { genericTypePacks.push_back(g.tp); - signatureScope->privateTypePackBindings[name] = g.tp; } } else @@ -2330,7 +2373,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const } std::vector> ConstraintGraphBuilder::createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache) + const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) { std::vector> result; for (const auto& generic : generics) @@ -2350,6 +2393,9 @@ std::vector> ConstraintGraphBuilder::crea if (generic.defaultValue) defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); + if (addTypes) + scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy}; + result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); } @@ -2357,7 +2403,7 @@ std::vector> ConstraintGraphBuilder::crea } std::vector> ConstraintGraphBuilder::createGenericPacks( - const ScopePtr& scope, AstArray generics, bool useCache) + const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) { std::vector> result; for (const auto& generic : generics) @@ -2378,6 +2424,9 @@ std::vector> ConstraintGraphBuilder:: if (generic.defaultValue) defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); + if (addTypes) + scope->privateTypePackBindings[generic.name.value] = genericTy; + result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); } @@ -2394,11 +2443,9 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo if (auto f = first(tp)) return Inference{*f, refinement}; - TypeId typeResult = freshType(scope); - TypePack onePack{{typeResult}, freshTypePack(scope)}; - TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); - - addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); + TypeId typeResult = arena->addType(BlockedType{}); + TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); + addConstraint(scope, location, UnpackConstraint{resultPack, tp}); return Inference{typeResult, refinement}; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 879dac39..96673e3d 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -22,6 +22,22 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { +size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const +{ + size_t result = 0; + + if (const TypeId* ty = get_if(&bci)) + result = std::hash()(*ty); + else if (const TypePackId* tp = get_if(&bci)) + result = std::hash()(*tp); + else if (Constraint const* const* c = get_if(&bci)) + result = std::hash()(*c); + else + LUAU_ASSERT(!"Should be unreachable"); + + return result; +} + [[maybe_unused]] static void dumpBindings(NotNull scope, ToStringOptions& opts) { for (const auto& [k, v] : scope->bindings) @@ -221,10 +237,12 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) + ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, + DcrLogger* logger) : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) + , reducer(reducer) , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) @@ -326,6 +344,27 @@ void ConstraintSolver::run() if (force) printf("Force "); printf("Dispatched\n\t%s\n", saveMe.c_str()); + + if (force) + { + printf("Blocked on:\n"); + + for (const auto& [bci, cv] : blocked) + { + if (end(cv) == std::find(begin(cv), end(cv), c)) + continue; + + if (auto bty = get_if(&bci)) + printf("\tType %s\n", toString(*bty, opts).c_str()); + else if (auto btp = get_if(&bci)) + printf("\tPack %s\n", toString(*btp, opts).c_str()); + else if (auto cc = get_if(&bci)) + printf("\tCons %s\n", toString(**cc, opts).c_str()); + else + LUAU_ASSERT(!"Unreachable??"); + } + } + dump(this, opts); } } @@ -411,8 +450,12 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*hpc, constraint); else if (auto spc = get(*constraint)) success = tryDispatch(*spc, constraint, force); + else if (auto spc = get(*constraint)) + success = tryDispatch(*spc, constraint, force); else if (auto sottc = get(*constraint)) success = tryDispatch(*sottc, constraint); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); else LUAU_ASSERT(false); @@ -424,26 +467,46 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { - if (!recursiveBlock(c.subType, constraint)) - return false; - if (!recursiveBlock(c.superType, constraint)) - return false; - if (isBlocked(c.subType)) return block(c.subType, constraint); else if (isBlocked(c.superType)) return block(c.superType, constraint); - unify(c.subType, c.superType, constraint->scope); + Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(c.subType, c.superType); + + if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) + { + for (TypeId bt : u.blockedTypes) + block(bt, constraint); + for (TypePackId btp : u.blockedTypePacks) + block(btp, constraint); + return false; + } + + if (!u.errors.empty()) + { + TypeId errorType = errorRecoveryType(); + u.tryUnify(c.subType, errorType); + u.tryUnify(c.superType, errorType); + } + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); + + // unify(c.subType, c.superType, constraint->scope); return true; } bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) { - if (!recursiveBlock(c.subPack, constraint) || !recursiveBlock(c.superPack, constraint)) - return false; - if (isBlocked(c.subPack)) return block(c.subPack, constraint); else if (isBlocked(c.superPack)) @@ -1183,8 +1246,26 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(BlockedType{}); TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); - auto ic = pushConstraint(constraint->scope, constraint->location, InstantiationConstraint{instantiatedTy, fn}); - auto sc = pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{instantiatedTy, inferredTy}); + auto pushConstraintGreedy = [this, constraint](ConstraintV cv) -> Constraint* { + std::unique_ptr c = std::make_unique(constraint->scope, constraint->location, std::move(cv)); + NotNull borrow{c.get()}; + + bool ok = tryDispatch(borrow, false); + if (ok) + return nullptr; + + solverConstraints.push_back(std::move(c)); + unsolvedConstraints.push_back(borrow); + + return borrow; + }; + + // HACK: We don't want other constraints to act on the free type pack + // created above until after these two constraints are solved, so we try to + // dispatch them directly. + + auto ic = pushConstraintGreedy(InstantiationConstraint{instantiatedTy, fn}); + auto sc = pushConstraintGreedy(SubtypeConstraint{instantiatedTy, inferredTy}); // Anything that is blocked on this constraint must also be blocked on our // synthesized constraints. @@ -1193,8 +1274,10 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) { - block(ic, blockedConstraint); - block(sc, blockedConstraint); + if (ic) + block(NotNull{ic}, blockedConstraint); + if (sc) + block(NotNull{sc}, blockedConstraint); } } @@ -1230,6 +1313,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullreduce(subjectType).value_or(subjectType); + std::optional resultType = lookupTableProp(subjectType, c.prop); if (!resultType) { @@ -1360,11 +1445,18 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullscope); + if (!isBlocked(c.propType)) + unify(c.propType, *existingPropType, constraint->scope); bind(c.resultType, c.subjectType); return true; } + if (get(subjectType) || get(subjectType) || get(subjectType)) + { + bind(c.resultType, subjectType); + return true; + } + if (get(subjectType)) { TypeId ty = arena->freshType(constraint->scope); @@ -1381,21 +1473,27 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) { if (ttv->state == TableState::Free) { + LUAU_ASSERT(!subjectType->persistent); + ttv->props[c.path[0]] = Property{c.propType}; bind(c.resultType, c.subjectType); return true; } else if (ttv->state == TableState::Unsealed) { + LUAU_ASSERT(!subjectType->persistent); + std::optional augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType); bind(c.resultType, augmented.value_or(subjectType)); + bind(subjectType, c.resultType); return true; } else @@ -1411,16 +1509,62 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType) || get(subjectType) || get(subjectType)) - { - bind(c.resultType, subjectType); - return true; - } LUAU_ASSERT(0); return true; } +bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force) +{ + TypeId subjectType = follow(c.subjectType); + if (isBlocked(subjectType)) + return block(subjectType, constraint); + + if (auto ft = get(subjectType)) + { + Scope* scope = ft->scope; + TableType* tt = &asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, scope); + tt->indexer = TableIndexer{c.indexType, c.propType}; + + asMutable(c.resultType)->ty.emplace(subjectType); + asMutable(c.propType)->ty.emplace(scope); + unblock(c.propType); + unblock(c.resultType); + + return true; + } + else if (auto tt = get(subjectType)) + { + if (tt->indexer) + { + // TODO This probably has to be invariant. + unify(c.indexType, tt->indexer->indexType, constraint->scope); + asMutable(c.propType)->ty.emplace(tt->indexer->indexResultType); + asMutable(c.resultType)->ty.emplace(subjectType); + unblock(c.propType); + unblock(c.resultType); + return true; + } + else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) + { + auto mtt = getMutable(subjectType); + mtt->indexer = TableIndexer{c.indexType, c.propType}; + asMutable(c.propType)->ty.emplace(tt->scope); + asMutable(c.resultType)->ty.emplace(subjectType); + unblock(c.propType); + unblock(c.resultType); + return true; + } + // Do not augment sealed or generic tables that lack indexers + } + + asMutable(c.propType)->ty.emplace(builtinTypes->errorRecoveryType()); + asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); + unblock(c.propType); + unblock(c.resultType); + return true; +} + bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint) { if (isBlocked(c.discriminantType)) @@ -1439,6 +1583,69 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul return true; } +bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) +{ + TypePackId sourcePack = follow(c.sourcePack); + TypePackId resultPack = follow(c.resultPack); + + if (isBlocked(sourcePack)) + return block(sourcePack, constraint); + + if (isBlocked(resultPack)) + { + asMutable(resultPack)->ty.emplace(sourcePack); + unblock(resultPack); + return true; + } + + TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); + + auto destIter = begin(resultPack); + auto destEnd = end(resultPack); + + size_t i = 0; + while (destIter != destEnd) + { + if (i >= srcPack.head.size()) + break; + TypeId srcTy = follow(srcPack.head[i]); + + if (isBlocked(*destIter)) + { + if (follow(srcTy) == *destIter) + { + // Cyclic type dependency. (????) + asMutable(*destIter)->ty.emplace(constraint->scope); + } + else + asMutable(*destIter)->ty.emplace(srcTy); + unblock(*destIter); + } + else + unify(*destIter, srcTy, constraint->scope); + + ++destIter; + ++i; + } + + // We know that resultPack does not have a tail, but we don't know if + // sourcePack is long enough to fill every value. Replace every remaining + // result TypeId with the error recovery type. + + while (destIter != destEnd) + { + if (isBlocked(*destIter)) + { + asMutable(*destIter)->ty.emplace(builtinTypes->errorRecoveryType()); + unblock(*destIter); + } + + ++destIter; + } + + return true; +} + bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) { auto block_ = [&](auto&& t) { @@ -1628,10 +1835,20 @@ bool ConstraintSolver::tryDispatchIterableFunction( std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) { + std::unordered_set seen; + return lookupTableProp(subjectType, propName, seen); +} + +std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) +{ + if (!seen.insert(subjectType).second) + return std::nullopt; + auto collectParts = [&](auto&& unionOrIntersection) -> std::pair, std::vector> { std::optional blocked; std::vector parts; + std::vector freeParts; for (TypeId expectedPart : unionOrIntersection) { expectedPart = follow(expectedPart); @@ -1644,6 +1861,29 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons else if (ttv->indexer && maybeString(ttv->indexer->indexType)) parts.push_back(ttv->indexer->indexResultType); } + else if (get(expectedPart)) + { + freeParts.push_back(expectedPart); + } + } + + // If the only thing resembling a match is a single fresh type, we can + // confidently tablify it. If other types match or if there are more + // than one free type, we can't do anything. + if (parts.empty() && 1 == freeParts.size()) + { + TypeId freePart = freeParts.front(); + const FreeType* ft = get(freePart); + LUAU_ASSERT(ft); + Scope* scope = ft->scope; + + TableType* tt = &asMutable(freePart)->ty.emplace(); + tt->state = TableState::Free; + tt->scope = scope; + TypeId propType = arena->freshType(scope); + tt->props[propName] = Property{propType}; + + parts.push_back(propType); } return {blocked, parts}; @@ -1651,12 +1891,75 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons std::optional resultType; - if (auto ttv = get(subjectType)) + if (get(subjectType) || get(subjectType)) + { + return subjectType; + } + else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) resultType = prop->second.type; else if (ttv->indexer && maybeString(ttv->indexer->indexType)) resultType = ttv->indexer->indexResultType; + else if (ttv->state == TableState::Free) + { + resultType = arena->addType(FreeType{ttv->scope}); + ttv->props[propName] = Property{*resultType}; + } + } + else if (auto mt = get(subjectType)) + { + if (auto p = lookupTableProp(mt->table, propName, seen)) + return p; + + TypeId mtt = follow(mt->metatable); + + if (get(mtt)) + return mtt; + else if (auto metatable = get(mtt)) + { + auto indexProp = metatable->props.find("__index"); + if (indexProp == metatable->props.end()) + return std::nullopt; + + // TODO: __index can be an overloaded function. + + TypeId indexType = follow(indexProp->second.type); + + if (auto ft = get(indexType)) + { + std::optional ret = first(ft->retTypes); + if (ret) + return *ret; + else + return std::nullopt; + } + + return lookupTableProp(indexType, propName, seen); + } + } + else if (auto ct = get(subjectType)) + { + while (ct) + { + if (auto prop = ct->props.find(propName); prop != ct->props.end()) + return prop->second.type; + else if (ct->parent) + ct = get(follow(*ct->parent)); + else + break; + } + } + else if (auto pt = get(subjectType); pt && pt->metatable) + { + const TableType* metatable = get(follow(*pt->metatable)); + LUAU_ASSERT(metatable); + + auto indexProp = metatable->props.find("__index"); + if (indexProp == metatable->props.end()) + return std::nullopt; + + return lookupTableProp(indexProp->second.type, propName, seen); } else if (auto utv = get(subjectType)) { @@ -1704,7 +2007,7 @@ void ConstraintSolver::block(NotNull target, NotNull constraint) @@ -1715,7 +2018,7 @@ bool ConstraintSolver::block(TypeId target, NotNull constraint if (FFlag::DebugLuauLogSolver) printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); - block_(target, constraint); + block_(follow(target), constraint); return false; } @@ -1802,7 +2105,7 @@ void ConstraintSolver::unblock(NotNull progressed) if (FFlag::DebugLuauLogSolverToJson) logger->popBlock(progressed); - return unblock_(progressed); + return unblock_(progressed.get()); } void ConstraintSolver::unblock(TypeId progressed) @@ -1810,7 +2113,10 @@ void ConstraintSolver::unblock(TypeId progressed) if (FFlag::DebugLuauLogSolverToJson) logger->popBlock(progressed); - return unblock_(progressed); + unblock_(progressed); + + if (auto bt = get(progressed)) + unblock(bt->boundTo); } void ConstraintSolver::unblock(TypePackId progressed) diff --git a/Analysis/src/DcrLogger.cpp b/Analysis/src/DcrLogger.cpp index a1ef650b..9f66b022 100644 --- a/Analysis/src/DcrLogger.cpp +++ b/Analysis/src/DcrLogger.cpp @@ -9,17 +9,39 @@ namespace Luau { +template +static std::string toPointerId(const T* ptr) +{ + return std::to_string(reinterpret_cast(ptr)); +} + +static std::string toPointerId(NotNull ptr) +{ + return std::to_string(reinterpret_cast(ptr.get())); +} + namespace Json { +template +void write(JsonEmitter& emitter, const T* ptr) +{ + write(emitter, toPointerId(ptr)); +} + +void write(JsonEmitter& emitter, NotNull ptr) +{ + write(emitter, toPointerId(ptr)); +} + void write(JsonEmitter& emitter, const Location& location) { - ObjectEmitter o = emitter.writeObject(); - o.writePair("beginLine", location.begin.line); - o.writePair("beginColumn", location.begin.column); - o.writePair("endLine", location.end.line); - o.writePair("endColumn", location.end.column); - o.finish(); + ArrayEmitter a = emitter.writeArray(); + a.writeValue(location.begin.line); + a.writeValue(location.begin.column); + a.writeValue(location.end.line); + a.writeValue(location.end.column); + a.finish(); } void write(JsonEmitter& emitter, const ErrorSnapshot& snapshot) @@ -47,24 +69,43 @@ void write(JsonEmitter& emitter, const TypeBindingSnapshot& snapshot) o.finish(); } +template +void write(JsonEmitter& emitter, const DenseHashMap& map) +{ + ObjectEmitter o = emitter.writeObject(); + for (const auto& [k, v] : map) + o.writePair(toPointerId(k), v); + o.finish(); +} + +void write(JsonEmitter& emitter, const ExprTypesAtLocation& tys) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("location", tys.location); + o.writePair("ty", toPointerId(tys.ty)); + + if (tys.expectedTy) + o.writePair("expectedTy", toPointerId(*tys.expectedTy)); + + o.finish(); +} + +void write(JsonEmitter& emitter, const AnnotationTypesAtLocation& tys) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("location", tys.location); + o.writePair("resolvedTy", toPointerId(tys.resolvedTy)); + o.finish(); +} + void write(JsonEmitter& emitter, const ConstraintGenerationLog& log) { ObjectEmitter o = emitter.writeObject(); o.writePair("source", log.source); - - emitter.writeComma(); - write(emitter, "constraintLocations"); - emitter.writeRaw(":"); - - ObjectEmitter locationEmitter = emitter.writeObject(); - - for (const auto& [id, location] : log.constraintLocations) - { - locationEmitter.writePair(id, location); - } - - locationEmitter.finish(); o.writePair("errors", log.errors); + o.writePair("exprTypeLocations", log.exprTypeLocations); + o.writePair("annotationTypeLocations", log.annotationTypeLocations); + o.finish(); } @@ -78,26 +119,34 @@ void write(JsonEmitter& emitter, const ScopeSnapshot& snapshot) o.finish(); } -void write(JsonEmitter& emitter, const ConstraintBlockKind& kind) -{ - switch (kind) - { - case ConstraintBlockKind::TypeId: - return write(emitter, "type"); - case ConstraintBlockKind::TypePackId: - return write(emitter, "typePack"); - case ConstraintBlockKind::ConstraintId: - return write(emitter, "constraint"); - default: - LUAU_ASSERT(0); - } -} - void write(JsonEmitter& emitter, const ConstraintBlock& block) { ObjectEmitter o = emitter.writeObject(); - o.writePair("kind", block.kind); o.writePair("stringification", block.stringification); + + auto go = [&o](auto&& t) { + using T = std::decay_t; + + o.writePair("id", toPointerId(t)); + + if constexpr (std::is_same_v) + { + o.writePair("kind", "type"); + } + else if constexpr (std::is_same_v) + { + o.writePair("kind", "typePack"); + } + else if constexpr (std::is_same_v>) + { + o.writePair("kind", "constraint"); + } + else + static_assert(always_false_v, "non-exhaustive possibility switch"); + }; + + visit(go, block.target); + o.finish(); } @@ -114,7 +163,8 @@ void write(JsonEmitter& emitter, const BoundarySnapshot& snapshot) { ObjectEmitter o = emitter.writeObject(); o.writePair("rootScope", snapshot.rootScope); - o.writePair("constraints", snapshot.constraints); + o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints); + o.writePair("typeStrings", snapshot.typeStrings); o.finish(); } @@ -125,6 +175,7 @@ void write(JsonEmitter& emitter, const StepSnapshot& snapshot) o.writePair("forced", snapshot.forced); o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints); o.writePair("rootScope", snapshot.rootScope); + o.writePair("typeStrings", snapshot.typeStrings); o.finish(); } @@ -146,11 +197,6 @@ void write(JsonEmitter& emitter, const TypeCheckLog& log) } // namespace Json -static std::string toPointerId(NotNull ptr) -{ - return std::to_string(reinterpret_cast(ptr.get())); -} - static ScopeSnapshot snapshotScope(const Scope* scope, ToStringOptions& opts) { std::unordered_map bindings; @@ -230,6 +276,32 @@ void DcrLogger::captureSource(std::string source) generationLog.source = std::move(source); } +void DcrLogger::captureGenerationModule(const ModulePtr& module) +{ + generationLog.exprTypeLocations.reserve(module->astTypes.size()); + for (const auto& [expr, ty] : module->astTypes) + { + ExprTypesAtLocation tys; + tys.location = expr->location; + tys.ty = ty; + + if (auto expectedTy = module->astExpectedTypes.find(expr)) + tys.expectedTy = *expectedTy; + + generationLog.exprTypeLocations.push_back(tys); + } + + generationLog.annotationTypeLocations.reserve(module->astResolvedTypes.size()); + for (const auto& [annot, ty] : module->astResolvedTypes) + { + AnnotationTypesAtLocation tys; + tys.location = annot->location; + tys.resolvedTy = ty; + + generationLog.annotationTypeLocations.push_back(tys); + } +} + void DcrLogger::captureGenerationError(const TypeError& error) { std::string stringifiedError = toString(error); @@ -239,12 +311,6 @@ void DcrLogger::captureGenerationError(const TypeError& error) }); } -void DcrLogger::captureConstraintLocation(NotNull constraint, Location location) -{ - std::string id = toPointerId(constraint); - generationLog.constraintLocations[id] = location; -} - void DcrLogger::pushBlock(NotNull constraint, TypeId block) { constraintBlocks[constraint].push_back(block); @@ -284,44 +350,70 @@ void DcrLogger::popBlock(NotNull block) } } -void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +static void snapshotTypeStrings(const std::vector& interestedExprs, + const std::vector& interestedAnnots, DenseHashMap& map, ToStringOptions& opts) { - solveLog.initialState.rootScope = snapshotScope(rootScope, opts); - solveLog.initialState.constraints.clear(); + for (const ExprTypesAtLocation& tys : interestedExprs) + { + map[tys.ty] = toString(tys.ty, opts); + + if (tys.expectedTy) + map[*tys.expectedTy] = toString(*tys.expectedTy, opts); + } + + for (const AnnotationTypesAtLocation& tys : interestedAnnots) + { + map[tys.resolvedTy] = toString(tys.resolvedTy, opts); + } +} + +void DcrLogger::captureBoundaryState( + BoundarySnapshot& target, const Scope* rootScope, const std::vector>& unsolvedConstraints) +{ + target.rootScope = snapshotScope(rootScope, opts); + target.unsolvedConstraints.clear(); for (NotNull c : unsolvedConstraints) { - std::string id = toPointerId(c); - solveLog.initialState.constraints[id] = { + target.unsolvedConstraints[c.get()] = { toString(*c.get(), opts), c->location, snapshotBlocks(c), }; } + + snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, target.typeStrings, opts); +} + +void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +{ + captureBoundaryState(solveLog.initialState, rootScope, unsolvedConstraints); } StepSnapshot DcrLogger::prepareStepSnapshot( const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints) { ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts); - std::string currentId = toPointerId(current); - std::unordered_map constraints; + DenseHashMap constraints{nullptr}; for (NotNull c : unsolvedConstraints) { - std::string id = toPointerId(c); - constraints[id] = { + constraints[c.get()] = { toString(*c.get(), opts), c->location, snapshotBlocks(c), }; } + DenseHashMap typeStrings{nullptr}; + snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, typeStrings, opts); + return StepSnapshot{ - currentId, + current, force, - constraints, + std::move(constraints), scopeSnapshot, + std::move(typeStrings), }; } @@ -332,18 +424,7 @@ void DcrLogger::commitStepSnapshot(StepSnapshot snapshot) void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) { - solveLog.finalState.rootScope = snapshotScope(rootScope, opts); - solveLog.finalState.constraints.clear(); - - for (NotNull c : unsolvedConstraints) - { - std::string id = toPointerId(c); - solveLog.finalState.constraints[id] = { - toString(*c.get(), opts), - c->location, - snapshotBlocks(c), - }; - } + captureBoundaryState(solveLog.finalState, rootScope, unsolvedConstraints); } void DcrLogger::captureTypeCheckError(const TypeError& error) @@ -370,21 +451,21 @@ std::vector DcrLogger::snapshotBlocks(NotNull if (const TypeId* ty = get_if(&target)) { snapshot.push_back({ - ConstraintBlockKind::TypeId, + *ty, toString(*ty, opts), }); } else if (const TypePackId* tp = get_if(&target)) { snapshot.push_back({ - ConstraintBlockKind::TypePackId, + *tp, toString(*tp, opts), }); } else if (const NotNull* c = get_if>(&target)) { snapshot.push_back({ - ConstraintBlockKind::ConstraintId, + *c, toString(*(c->get()), opts), }); } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index fb61b4ab..91c72e44 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -899,8 +899,8 @@ ModulePtr check( cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, moduleResolver, - requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, + NotNull{result->reduction.get()}, moduleResolver, requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index a7b2b727..0b760810 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1441,6 +1441,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor if (!unionNormals(here, *tn)) return false; } + else if (get(there)) + LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); else LUAU_ASSERT(!"Unreachable"); diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index aac7864a..845ae3a3 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -183,7 +183,7 @@ struct PureQuantifier : Substitution else if (ttv->state == TableState::Generic) seenGenericType = true; - return ttv->state == TableState::Unsealed || (ttv->state == TableState::Free && subsumes(scope, ttv->scope)); + return (ttv->state == TableState::Unsealed || ttv->state == TableState::Free) && subsumes(scope, ttv->scope); } return false; diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 84925f79..cac72124 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -31,12 +31,12 @@ std::optional Scope::lookup(Symbol sym) const { auto r = const_cast(this)->lookupEx(sym); if (r) - return r->first; + return r->first->typeId; else return std::nullopt; } -std::optional> Scope::lookupEx(Symbol sym) +std::optional> Scope::lookupEx(Symbol sym) { Scope* s = this; @@ -44,7 +44,7 @@ std::optional> Scope::lookupEx(Symbol sym) { auto it = s->bindings.find(sym); if (it != s->bindings.end()) - return std::pair{it->second.typeId, s}; + return std::pair{&it->second, s}; if (s->parent) s = s->parent.get(); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 1972177c..d0c53984 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1533,6 +1533,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]"; return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType); } + else if constexpr (std::is_same_v) + { + return tos(c.resultType) + " ~ setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType); + } else if constexpr (std::is_same_v) { std::string result = tos(c.resultType); @@ -1543,6 +1547,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) else return result + " ~ if isSingleton D then D else unknown where D = " + discriminant; } + else if constexpr (std::is_same_v) + return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack); else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 4322a0da..f23fad78 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -4,6 +4,7 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/DcrLogger.h" #include "Luau/Error.h" #include "Luau/Instantiation.h" @@ -329,11 +330,12 @@ struct TypeChecker2 for (size_t i = 0; i < count; ++i) { AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr; + const bool isPack = value && (value->is() || value->is()); if (value) visit(value, RValue); - if (i != local->values.size - 1 || value) + if (i != local->values.size - 1 || !isPack) { AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; @@ -351,16 +353,19 @@ struct TypeChecker2 visit(var->annotation); } } - else + else if (value) { - LUAU_ASSERT(value); + TypePackId valuePack = lookupPack(value); + TypePack valueTypes; + if (i < local->vars.size) + valueTypes = extendTypePack(module->internalTypes, builtinTypes, valuePack, local->vars.size - i); - TypePackId valueTypes = lookupPack(value); - auto it = begin(valueTypes); + Location errorLocation; for (size_t j = i; j < local->vars.size; ++j) { - if (it == end(valueTypes)) + if (j - i >= valueTypes.head.size()) { + errorLocation = local->vars.data[j]->location; break; } @@ -368,14 +373,28 @@ struct TypeChecker2 if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType); + ErrorVec errors = tryUnify(stack.back(), value->location, valueTypes.head[j - i], varType); if (!errors.empty()) reportErrors(std::move(errors)); visit(var->annotation); } + } - ++it; + if (valueTypes.head.size() < local->vars.size - i) + { + reportError( + CountMismatch{ + // We subtract 1 here because the final AST + // expression is not worth one value. It is worth 0 + // or more depending on valueTypes.head + local->values.size - 1 + valueTypes.head.size(), + std::nullopt, + local->vars.size, + local->values.data[local->values.size - 1]->is() ? CountMismatch::FunctionResult + : CountMismatch::ExprListResult, + }, + errorLocation); } } } @@ -810,6 +829,95 @@ struct TypeChecker2 // TODO! } + ErrorVec visitOverload(AstExprCall* call, NotNull overloadFunctionType, const std::vector& argLocs, + TypePackId expectedArgTypes, TypePackId expectedRetType) + { + ErrorVec overloadErrors = + tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult); + + size_t argIndex = 0; + auto inferredArgIt = begin(overloadFunctionType->argTypes); + auto expectedArgIt = begin(expectedArgTypes); + while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) + { + Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; + ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt); + for (TypeError e : argErrors) + overloadErrors.emplace_back(e); + + ++argIndex; + ++inferredArgIt; + ++expectedArgIt; + } + + // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad + ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes); + for (TypeError e : argumentErrors) + if (get(e) != nullptr) + overloadErrors.emplace_back(std::move(e)); + + return overloadErrors; + } + + void reportOverloadResolutionErrors(AstExprCall* call, std::vector overloads, TypePackId expectedArgTypes, + const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) + { + if (overloads.size() == 1) + { + reportErrors(std::get<0>(overloadsErrors.front())); + return; + } + + std::vector overloadTypes = overloadsThatMatchArgCount; + if (overloadsThatMatchArgCount.size() == 0) + { + reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); + // If no overloads match argument count, just list all overloads. + overloadTypes = overloads; + } + else + { + // Report errors of the first argument-count-matching, but failing overload + TypeId overload = overloadsThatMatchArgCount[0]; + + // Remove the overload we are reporting errors about from the list of alternatives + overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end()); + + const FunctionType* ftv = get(overload); + LUAU_ASSERT(ftv); // overload must be a function type here + + auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [ftv](const std::pair& e) { + return ftv == std::get<1>(e); + }); + + LUAU_ASSERT(error != overloadsErrors.end()); + reportErrors(std::get<0>(*error)); + + // If only one overload matched, we don't need this error because we provided the previous errors. + if (overloadsThatMatchArgCount.size() == 1) + return; + } + + std::string s; + for (size_t i = 0; i < overloadTypes.size(); ++i) + { + TypeId overload = follow(overloadTypes[i]); + + if (i > 0) + s += "; "; + + if (i > 0 && i == overloadTypes.size() - 1) + s += "and "; + + s += toString(overload); + } + + if (overloadsThatMatchArgCount.size() == 0) + reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); + else + reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location); + } + void visit(AstExprCall* call) { visit(call->func, RValue); @@ -865,6 +973,10 @@ struct TypeChecker2 return; } } + else if (auto itv = get(functionType)) + { + // We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function. + } else if (auto utv = get(functionType)) { // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. @@ -930,48 +1042,105 @@ struct TypeChecker2 TypePackId expectedArgTypes = arena->addTypePack(args); - const FunctionType* inferredFunctionType = get(testFunctionType); - LUAU_ASSERT(inferredFunctionType); // testFunctionType should always be a FunctionType here + std::vector overloads = flattenIntersection(testFunctionType); + std::vector> overloadsErrors; + overloadsErrors.reserve(overloads.size()); - size_t argIndex = 0; - auto inferredArgIt = begin(inferredFunctionType->argTypes); - auto expectedArgIt = begin(expectedArgTypes); - while (inferredArgIt != end(inferredFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) + std::vector overloadsThatMatchArgCount; + + for (TypeId overload : overloads) { - Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; - reportErrors(tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt)); + overload = follow(overload); - ++argIndex; - ++inferredArgIt; - ++expectedArgIt; + const FunctionType* overloadFn = get(overload); + if (!overloadFn) + { + reportError(CannotCallNonFunction{overload}, call->func->location); + return; + } + else + { + // We may have to instantiate the overload in order for it to typecheck. + if (std::optional instantiatedFunctionType = instantiation.substitute(overload)) + { + overloadFn = get(*instantiatedFunctionType); + } + else + { + overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overloadFn); + return; + } + } + + ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType); + if (overloadErrors.empty()) + return; + + bool argMismatch = false; + for (auto error : overloadErrors) + { + CountMismatch* cm = get(error); + if (!cm) + continue; + + if (cm->context == CountMismatch::Arg) + { + argMismatch = true; + break; + } + } + + if (!argMismatch) + overloadsThatMatchArgCount.push_back(overload); + + overloadsErrors.emplace_back(std::move(overloadErrors), overloadFn); } - // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad - ErrorVec errors = tryUnify(stack.back(), call->location, expectedArgTypes, inferredFunctionType->argTypes); - for (TypeError e : errors) - if (get(e) != nullptr) - reportError(std::move(e)); + reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); + } - reportErrors(tryUnify(stack.back(), call->location, inferredFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult)); + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) + { + visit(expr, RValue); + + TypeId leftType = lookupType(expr); + const NormalizedType* norm = normalizer.normalize(leftType); + if (!norm) + reportError(NormalizationTooComplex{}, location); + + checkIndexTypeFromType(leftType, *norm, propName, location, context); } void visit(AstExprIndexName* indexName, ValueContext context) { - visit(indexName->expr, RValue); - - TypeId leftType = lookupType(indexName->expr); - const NormalizedType* norm = normalizer.normalize(leftType); - if (!norm) - reportError(NormalizationTooComplex{}, indexName->indexLocation); - - checkIndexTypeFromType(leftType, *norm, indexName->index.value, indexName->location, context); + visitExprName(indexName->expr, indexName->location, indexName->index.value, context); } void visit(AstExprIndexExpr* indexExpr, ValueContext context) { + if (auto str = indexExpr->index->as()) + { + const std::string stringValue(str->value.data, str->value.size); + visitExprName(indexExpr->expr, indexExpr->location, stringValue, context); + return; + } + // TODO! visit(indexExpr->expr, LValue); visit(indexExpr->index, RValue); + + NotNull scope = stack.back(); + + TypeId exprType = lookupType(indexExpr->expr); + TypeId indexType = lookupType(indexExpr->index); + + if (auto tt = get(exprType)) + { + if (tt->indexer) + reportErrors(tryUnify(scope, indexExpr->index->location, indexType, tt->indexer->indexType)); + else + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } } void visit(AstExprFunction* fn) @@ -1879,8 +2048,17 @@ struct TypeChecker2 ty = *mtIndex; } - if (getTableType(ty)) - return bool(findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)); + if (auto tt = getTableType(ty)) + { + if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)) + return true; + + else if (tt->indexer && isPrim(tt->indexer->indexResultType, PrimitiveType::String)) + return tt->indexer->indexResultType; + + else + return false; + } else if (const ClassType* cls = get(ty)) return bool(lookupClassProp(cls, prop)); else if (const UnionType* utv = get(ty)) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e59c7e0e..adca034c 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1759,7 +1759,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate result; @@ -1767,23 +1767,23 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (auto a = expr.as()) result = checkExpr(scope, *a->expr, expectedType); else if (expr.is()) - result = {nilType}; + result = WithPredicate{nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) { if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) - result = {singletonType(bexpr->value)}; + result = WithPredicate{singletonType(bexpr->value)}; else - result = {booleanType}; + result = WithPredicate{booleanType}; } else if (const AstExprConstantString* sexpr = expr.as()) { if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) - result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; + result = WithPredicate{singletonType(std::string(sexpr->value.data, sexpr->value.size))}; else - result = {stringType}; + result = WithPredicate{stringType}; } else if (expr.is()) - result = {numberType}; + result = WithPredicate{numberType}; else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1837,7 +1837,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // ice("AstExprLocal exists but no binding definition for it?", expr.location); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) @@ -1849,7 +1849,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) @@ -1859,26 +1859,26 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (get(varargPack)) { if (std::optional ty = first(varargPack)) - return {*ty}; + return WithPredicate{*ty}; - return {nilType}; + return WithPredicate{nilType}; } else if (get(varargPack)) { TypeId head = freshType(scope); TypePackId tail = freshTypePack(scope); *asMutable(varargPack) = TypePack{{head}, tail}; - return {head}; + return WithPredicate{head}; } if (get(varargPack)) - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) - return {vtp->ty}; + return WithPredicate{vtp->ty}; else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } else ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); @@ -1929,9 +1929,9 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp lhsType = stripFromNilAndReport(lhsType, expr.expr->location); if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, /* addErrors= */ true)) - return {*ty}; + return WithPredicate{*ty}; - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors) @@ -2138,7 +2138,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (std::optional refiTy = resolveLValue(scope, *lvalue)) return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - return {ty}; + return WithPredicate{ty}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) @@ -2147,7 +2147,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp checkFunctionBody(funScope, funTy, expr); - return {quantify(funScope, funTy, expr.location)}; + return WithPredicate{quantify(funScope, funTy, expr.location)}; } TypeId TypeChecker::checkExprTable( @@ -2252,7 +2252,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } std::vector> fieldTypes(expr.items.size); @@ -2339,7 +2339,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp expectedIndexResultType = fieldTypes[i].second; } - return {checkExprTable(scope, expr, fieldTypes, expectedType)}; + return WithPredicate{checkExprTable(scope, expr, fieldTypes, expectedType)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) @@ -2356,7 +2356,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp const bool operandIsAny = get(operandType) || get(operandType) || get(operandType); if (operandIsAny) - return {operandType}; + return WithPredicate{operandType}; if (typeCouldHaveMetatable(operandType)) { @@ -2377,16 +2377,16 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (!state.errors.empty()) retType = errorRecoveryType(retType); - return {retType}; + return WithPredicate{retType}; } reportError(expr.location, GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } reportErrors(tryUnify(operandType, numberType, scope, expr.location)); - return {numberType}; + return WithPredicate{numberType}; } case AstExprUnary::Len: { @@ -2396,7 +2396,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp // # operator is guaranteed to return number if (get(operandType) || get(operandType) || get(operandType)) - return {numberType}; + return WithPredicate{numberType}; DenseHashSet seen{nullptr}; @@ -2420,7 +2420,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (!hasLength(operandType, seen, &recursionCount)) reportError(TypeError{expr.location, NotATable{operandType}}); - return {numberType}; + return WithPredicate{numberType}; } default: ice("Unknown AstExprUnary " + std::to_string(int(expr.op))); @@ -3014,7 +3014,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate rhs = checkExpr(scope, *expr.right); // Intentionally discarding predicates with other operators. - return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; + return WithPredicate{checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; } } @@ -3045,7 +3045,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp // any type errors that may arise from it are going to be useless. currentModule->errors.resize(oldSize); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) @@ -3061,12 +3061,12 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); if (falseType.type == trueType.type) - return {trueType.type}; + return WithPredicate{trueType.type}; std::vector types = reduceUnion({trueType.type, falseType.type}); if (types.empty()) - return {neverType}; - return {types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})}; + return WithPredicate{neverType}; + return WithPredicate{types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprInterpString& expr) @@ -3074,7 +3074,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp for (AstExpr* expr : expr.expressions) checkExpr(scope, *expr); - return {stringType}; + return WithPredicate{stringType}; } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx) @@ -3704,7 +3704,7 @@ WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, cons { WithPredicate result = checkExprPackHelper(scope, expr); if (containsNever(result.type)) - return {uninhabitableTypePack}; + return WithPredicate{uninhabitableTypePack}; return result; } @@ -3715,14 +3715,14 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope else if (expr.is()) { if (!scope->varargPack) - return {errorRecoveryTypePack(scope)}; + return WithPredicate{errorRecoveryTypePack(scope)}; - return {*scope->varargPack}; + return WithPredicate{*scope->varargPack}; } else { TypeId type = checkExpr(scope, expr).type; - return {addTypePack({type})}; + return WithPredicate{addTypePack({type})}; } } @@ -3994,71 +3994,77 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope { retPack = freshTypePack(free->level); TypePackId freshArgPack = freshTypePack(free->level); - asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); + emplaceType(asMutable(actualFunctionType), free->level, freshArgPack, retPack); } else retPack = freshTypePack(scope->level); - // checkExpr will log the pre-instantiated type of the function. - // That's not nearly as interesting as the instantiated type, which will include details about how - // generic functions are being instantiated for this particular callsite. - currentModule->astOriginalCallTypes[expr.func] = follow(functionType); - currentModule->astTypes[expr.func] = actualFunctionType; + // We break this function up into a lambda here to limit our stack footprint. + // The vectors used by this function aren't allocated until the lambda is actually called. + auto the_rest = [&]() -> WithPredicate { + // checkExpr will log the pre-instantiated type of the function. + // That's not nearly as interesting as the instantiated type, which will include details about how + // generic functions are being instantiated for this particular callsite. + currentModule->astOriginalCallTypes[expr.func] = follow(functionType); + currentModule->astTypes[expr.func] = actualFunctionType; - std::vector overloads = flattenIntersection(actualFunctionType); + std::vector overloads = flattenIntersection(actualFunctionType); - std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); + std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); - WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); - TypePackId argPack = argListResult.type; + WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); + TypePackId argPack = argListResult.type; - if (get(argPack)) - return {errorRecoveryTypePack(scope)}; + if (get(argPack)) + return WithPredicate{errorRecoveryTypePack(scope)}; - TypePack* args = nullptr; - if (expr.self) - { - argPack = addTypePack(TypePack{{selfType}, argPack}); - argListResult.type = argPack; - } - args = getMutable(argPack); - LUAU_ASSERT(args); + TypePack* args = nullptr; + if (expr.self) + { + argPack = addTypePack(TypePack{{selfType}, argPack}); + argListResult.type = argPack; + } + args = getMutable(argPack); + LUAU_ASSERT(args); - std::vector argLocations; - argLocations.reserve(expr.args.size + 1); - if (expr.self) - argLocations.push_back(expr.func->as()->expr->location); - for (AstExpr* arg : expr.args) - argLocations.push_back(arg->location); + std::vector argLocations; + argLocations.reserve(expr.args.size + 1); + if (expr.self) + argLocations.push_back(expr.func->as()->expr->location); + for (AstExpr* arg : expr.args) + argLocations.push_back(arg->location); - std::vector errors; // errors encountered for each overload + std::vector errors; // errors encountered for each overload - std::vector overloadsThatMatchArgCount; - std::vector overloadsThatDont; + std::vector overloadsThatMatchArgCount; + std::vector overloadsThatDont; - for (TypeId fn : overloads) - { - fn = follow(fn); + for (TypeId fn : overloads) + { + fn = follow(fn); - if (auto ret = checkCallOverload( - scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) - return *ret; - } + if (auto ret = checkCallOverload( + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) + return *ret; + } - if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) - return {retPack}; + if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) + return WithPredicate{retPack}; - reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); - const FunctionType* overload = nullptr; - if (!overloadsThatMatchArgCount.empty()) - overload = get(overloadsThatMatchArgCount[0]); - if (!overload && !overloadsThatDont.empty()) - overload = get(overloadsThatDont[0]); - if (overload) - return {errorRecoveryTypePack(overload->retTypes)}; + const FunctionType* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return WithPredicate{errorRecoveryTypePack(overload->retTypes)}; - return {errorRecoveryTypePack(retPack)}; + return WithPredicate{errorRecoveryTypePack(retPack)}; + }; + + return the_rest(); } std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) @@ -4119,8 +4125,13 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st return expectedTypes; } -std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, +/* + * Note: We return a std::unique_ptr here rather than an optional to manage our stack consumption. + * If this was an optional, callers would have to pay the stack cost for the result. This is problematic + * for functions that need to support recursion up to 600 levels deep. + */ +std::unique_ptr> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, + TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { LUAU_ASSERT(argLocations); @@ -4130,16 +4141,16 @@ std::optional> TypeChecker::checkCallOverload(const Sc if (get(fn)) { unify(anyTypePack, argPack, scope, expr.location); - return {{anyTypePack}}; + return std::make_unique>(anyTypePack); } if (get(fn)) { - return {{errorRecoveryTypePack(scope)}}; + return std::make_unique>(errorRecoveryTypePack(scope)); } if (get(fn)) - return {{uninhabitableTypePack}}; + return std::make_unique>(uninhabitableTypePack); if (auto ftv = get(fn)) { @@ -4152,7 +4163,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc options.isFunctionCall = true; unify(r, fn, scope, expr.location, options); - return {{retPack}}; + return std::make_unique>(retPack); } std::vector metaArgLocations; @@ -4191,7 +4202,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); unify(errorRecoveryTypePack(scope), retPack, scope, expr.func->location); - return {{errorRecoveryTypePack(retPack)}}; + return std::make_unique>(errorRecoveryTypePack(retPack)); } // When this function type has magic functions and did return something, we select that overload instead. @@ -4200,7 +4211,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) - return *ret; + return std::make_unique>(std::move(*ret)); } Unifier state = mkUnifier(scope, expr.location); @@ -4209,7 +4220,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { - return {}; + return nullptr; } checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, *argLocations); @@ -4244,10 +4255,10 @@ std::optional> TypeChecker::checkCallOverload(const Sc currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload - return {{retPack}}; + return std::make_unique>(retPack); } - return {}; + return nullptr; } bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, @@ -4404,7 +4415,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons }; if (exprs.size == 0) - return {pack}; + return WithPredicate{pack}; TypePack* tp = getMutable(pack); @@ -4484,7 +4495,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons log.commit(); if (uninhabitable) - return {uninhabitableTypePack}; + return WithPredicate{uninhabitableTypePack}; return {pack, predicates}; } diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 94fb4ad3..2393829d 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -16,11 +16,167 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) namespace Luau { +namespace detail +{ +bool TypeReductionMemoization::isIrreducible(TypeId ty) +{ + ty = follow(ty); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (auto edge = types.find(ty); edge && edge->irreducible) + return true; + else if (get(ty) || get(ty) || get(ty)) + return false; + else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) + return false; + else + return true; +} + +bool TypeReductionMemoization::isIrreducible(TypePackId tp) +{ + tp = follow(tp); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (auto edge = typePacks.find(tp); edge && edge->irreducible) + return true; + else if (get(tp) || get(tp)) + return false; + else if (auto vtp = get(tp)) + return isIrreducible(vtp->ty); + else + return true; +} + +TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy) +{ + ty = follow(ty); + reducedTy = follow(reducedTy); + + // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. + // We don't need to recurse much further than that, because we already record the irreducibility from + // the bottom up. + bool irreducible = isIrreducible(reducedTy); + if (auto it = get(reducedTy)) + { + for (TypeId part : it) + irreducible &= isIrreducible(part); + } + else if (auto ut = get(reducedTy)) + { + for (TypeId option : ut) + irreducible &= isIrreducible(option); + } + else if (auto tt = get(reducedTy)) + { + for (auto& [k, p] : tt->props) + irreducible &= isIrreducible(p.type); + + if (tt->indexer) + { + irreducible &= isIrreducible(tt->indexer->indexType); + irreducible &= isIrreducible(tt->indexer->indexResultType); + } + + for (auto ta : tt->instantiatedTypeParams) + irreducible &= isIrreducible(ta); + + for (auto tpa : tt->instantiatedTypePackParams) + irreducible &= isIrreducible(tpa); + } + else if (auto mt = get(reducedTy)) + { + irreducible &= isIrreducible(mt->table); + irreducible &= isIrreducible(mt->metatable); + } + else if (auto ft = get(reducedTy)) + { + irreducible &= isIrreducible(ft->argTypes); + irreducible &= isIrreducible(ft->retTypes); + } + else if (auto nt = get(reducedTy)) + irreducible &= isIrreducible(nt->ty); + + types[ty] = {reducedTy, irreducible}; + types[reducedTy] = {reducedTy, irreducible}; + return reducedTy; +} + +TypePackId TypeReductionMemoization::memoize(TypePackId tp, TypePackId reducedTp) +{ + tp = follow(tp); + reducedTp = follow(reducedTp); + + bool irreducible = isIrreducible(reducedTp); + TypePackIterator it = begin(tp); + while (it != end(tp)) + { + irreducible &= isIrreducible(*it); + ++it; + } + + if (it.tail()) + irreducible &= isIrreducible(*it.tail()); + + typePacks[tp] = {reducedTp, irreducible}; + typePacks[reducedTp] = {reducedTp, irreducible}; + return reducedTp; +} + +std::optional> TypeReductionMemoization::memoizedof(TypeId ty) const +{ + auto fetchContext = [this](TypeId ty) -> std::optional> { + if (auto edge = types.find(ty)) + return *edge; + else + return std::nullopt; + }; + + TypeId currentTy = ty; + std::optional> lastEdge; + while (auto edge = fetchContext(currentTy)) + { + lastEdge = edge; + if (edge->irreducible) + return edge; + else if (edge->type == currentTy) + return edge; + else + currentTy = edge->type; + } + + return lastEdge; +} + +std::optional> TypeReductionMemoization::memoizedof(TypePackId tp) const +{ + auto fetchContext = [this](TypePackId tp) -> std::optional> { + if (auto edge = typePacks.find(tp)) + return *edge; + else + return std::nullopt; + }; + + TypePackId currentTp = tp; + std::optional> lastEdge; + while (auto edge = fetchContext(currentTp)) + { + lastEdge = edge; + if (edge->irreducible) + return edge; + else if (edge->type == currentTp) + return edge; + else + currentTp = edge->type; + } + + return lastEdge; +} +} // namespace detail + namespace { -using detail::ReductionContext; - template std::pair get2(const Thing& one, const Thing& two) { @@ -34,9 +190,7 @@ struct TypeReducer NotNull arena; NotNull builtinTypes; NotNull handle; - - DenseHashMap>* memoizedTypes; - DenseHashMap>* memoizedTypePacks; + NotNull memoization; DenseHashSet* cyclics; int depth = 0; @@ -50,12 +204,6 @@ struct TypeReducer TypeId functionType(TypeId ty); TypeId negationType(TypeId ty); - bool isIrreducible(TypeId ty); - bool isIrreducible(TypePackId tp); - - TypeId memoize(TypeId ty, TypeId reducedTy); - TypePackId memoize(TypePackId tp, TypePackId reducedTp); - using BinaryFold = std::optional (TypeReducer::*)(TypeId, TypeId); using UnaryFold = TypeId (TypeReducer::*)(TypeId); @@ -64,12 +212,15 @@ struct TypeReducer { ty = follow(ty); - if (auto ctx = memoizedTypes->find(ty)) - return {ctx->type, getMutable(ctx->type)}; + if (auto edge = memoization->memoizedof(ty)) + return {edge->type, getMutable(edge->type)}; + // We specifically do not want to use [`detail::TypeReductionMemoization::memoize`] because that will + // potentially consider these copiedTy to be reducible, but we need this to resolve cyclic references + // without attempting to recursively reduce it, causing copies of copies of copies of... TypeId copiedTy = arena->addType(*t); - (*memoizedTypes)[ty] = {copiedTy, true}; - (*memoizedTypes)[copiedTy] = {copiedTy, true}; + memoization->types[ty] = {copiedTy, true}; + memoization->types[copiedTy] = {copiedTy, true}; return {copiedTy, getMutable(copiedTy)}; } @@ -175,8 +326,13 @@ TypeId TypeReducer::reduce(TypeId ty) { ty = follow(ty); - if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible) - return ctx->type; + if (auto edge = memoization->memoizedof(ty)) + { + if (edge->irreducible) + return edge->type; + else + ty = edge->type; + } else if (cyclics->contains(ty)) return ty; @@ -196,15 +352,20 @@ TypeId TypeReducer::reduce(TypeId ty) else result = ty; - return memoize(ty, result); + return memoization->memoize(ty, result); } TypePackId TypeReducer::reduce(TypePackId tp) { tp = follow(tp); - if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible) - return ctx->type; + if (auto edge = memoization->memoizedof(tp)) + { + if (edge->irreducible) + return edge->type; + else + tp = edge->type; + } else if (cyclics->contains(tp)) return tp; @@ -237,11 +398,11 @@ TypePackId TypeReducer::reduce(TypePackId tp) } if (!didReduce) - return memoize(tp, tp); + return memoization->memoize(tp, tp); else if (head.empty() && tail) - return memoize(tp, *tail); + return memoization->memoize(tp, *tail); else - return memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); + return memoization->memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); } std::optional TypeReducer::intersectionType(TypeId left, TypeId right) @@ -832,111 +993,6 @@ TypeId TypeReducer::negationType(TypeId ty) return ty; // for all T except the ones handled above, ~T ~ ~T } -bool TypeReducer::isIrreducible(TypeId ty) -{ - ty = follow(ty); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible) - return true; - else if (get(ty) || get(ty) || get(ty)) - return false; - else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) - return false; - else - return true; -} - -bool TypeReducer::isIrreducible(TypePackId tp) -{ - tp = follow(tp); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible) - return true; - else if (get(tp) || get(tp)) - return false; - else if (auto vtp = get(tp)) - return isIrreducible(vtp->ty); - else - return true; -} - -TypeId TypeReducer::memoize(TypeId ty, TypeId reducedTy) -{ - ty = follow(ty); - reducedTy = follow(reducedTy); - - // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. - // We don't need to recurse much further than that, because we already record the irreducibility from - // the bottom up. - bool irreducible = isIrreducible(reducedTy); - if (auto it = get(reducedTy)) - { - for (TypeId part : it) - irreducible &= isIrreducible(part); - } - else if (auto ut = get(reducedTy)) - { - for (TypeId option : ut) - irreducible &= isIrreducible(option); - } - else if (auto tt = get(reducedTy)) - { - for (auto& [k, p] : tt->props) - irreducible &= isIrreducible(p.type); - - if (tt->indexer) - { - irreducible &= isIrreducible(tt->indexer->indexType); - irreducible &= isIrreducible(tt->indexer->indexResultType); - } - - for (auto ta : tt->instantiatedTypeParams) - irreducible &= isIrreducible(ta); - - for (auto tpa : tt->instantiatedTypePackParams) - irreducible &= isIrreducible(tpa); - } - else if (auto mt = get(reducedTy)) - { - irreducible &= isIrreducible(mt->table); - irreducible &= isIrreducible(mt->metatable); - } - else if (auto ft = get(reducedTy)) - { - irreducible &= isIrreducible(ft->argTypes); - irreducible &= isIrreducible(ft->retTypes); - } - else if (auto nt = get(reducedTy)) - irreducible &= isIrreducible(nt->ty); - - (*memoizedTypes)[ty] = {reducedTy, irreducible}; - (*memoizedTypes)[reducedTy] = {reducedTy, irreducible}; - return reducedTy; -} - -TypePackId TypeReducer::memoize(TypePackId tp, TypePackId reducedTp) -{ - tp = follow(tp); - reducedTp = follow(reducedTp); - - bool irreducible = isIrreducible(reducedTp); - TypePackIterator it = begin(tp); - while (it != end(tp)) - { - irreducible &= isIrreducible(*it); - ++it; - } - - if (it.tail()) - irreducible &= isIrreducible(*it.tail()); - - (*memoizedTypePacks)[tp] = {reducedTp, irreducible}; - (*memoizedTypePacks)[reducedTp] = {reducedTp, irreducible}; - return reducedTp; -} - struct MarkCycles : TypeVisitor { DenseHashSet cyclics{nullptr}; @@ -961,7 +1017,6 @@ struct MarkCycles : TypeVisitor return !cyclics.find(follow(tp)); } }; - } // namespace TypeReduction::TypeReduction( @@ -981,8 +1036,13 @@ std::optional TypeReduction::reduce(TypeId ty) return ty; else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena) return ty; - else if (auto memoized = memoizedof(ty)) - return *memoized; + else if (auto edge = memoization.memoizedof(ty)) + { + if (edge->irreducible) + return edge->type; + else + ty = edge->type; + } else if (hasExceededCartesianProductLimit(ty)) return std::nullopt; @@ -991,7 +1051,7 @@ std::optional TypeReduction::reduce(TypeId ty) MarkCycles finder; finder.traverse(ty); - TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics}; + TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; return reducer.reduce(ty); } catch (const RecursionLimitException&) @@ -1008,8 +1068,13 @@ std::optional TypeReduction::reduce(TypePackId tp) return tp; else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena) return tp; - else if (auto memoized = memoizedof(tp)) - return *memoized; + else if (auto edge = memoization.memoizedof(tp)) + { + if (edge->irreducible) + return edge->type; + else + tp = edge->type; + } else if (hasExceededCartesianProductLimit(tp)) return std::nullopt; @@ -1018,7 +1083,7 @@ std::optional TypeReduction::reduce(TypePackId tp) MarkCycles finder; finder.traverse(tp); - TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics}; + TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; return reducer.reduce(tp); } catch (const RecursionLimitException&) @@ -1039,13 +1104,6 @@ std::optional TypeReduction::reduce(const TypeFun& fun) return std::nullopt; } -TypeReduction TypeReduction::fork(NotNull arena, const TypeReductionOptions& opts) const -{ - TypeReduction child{arena, builtinTypes, handle, opts}; - child.parent = this; - return child; -} - size_t TypeReduction::cartesianProductSize(TypeId ty) const { ty = follow(ty); @@ -1093,24 +1151,4 @@ bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const return false; } -std::optional TypeReduction::memoizedof(TypeId ty) const -{ - if (auto ctx = memoizedTypes.find(ty); ctx && ctx->irreducible) - return ctx->type; - else if (parent) - return parent->memoizedof(ty); - else - return std::nullopt; -} - -std::optional TypeReduction::memoizedof(TypePackId tp) const -{ - if (auto ctx = memoizedTypePacks.find(tp); ctx && ctx->irreducible) - return ctx->type; - else if (parent) - return parent->memoizedof(tp); - else - return std::nullopt; -} - } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 7104f2e7..6364a5aa 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -520,7 +520,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (const UnionType* subUnion = log.getMutable(subTy)) + if (log.getMutable(subTy) && log.getMutable(superTy)) + { + blockedTypes.push_back(subTy); + blockedTypes.push_back(superTy); + } + else if (const UnionType* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index ebbba689..29553421 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -42,6 +42,7 @@ struct IrBuilder IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e); + IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f); IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence IrOp blockAtInst(uint32_t index); @@ -57,6 +58,8 @@ struct IrBuilder IrFunction function; + uint32_t activeBlockIdx = ~0u; + std::vector instIndexToBlock; // Block index at the bytecode instruction }; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 6a709468..18d510cc 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -5,6 +5,7 @@ #include "Luau/RegisterX64.h" #include "Luau/RegisterA64.h" +#include #include #include @@ -186,6 +187,16 @@ enum class IrCmd : uint8_t // A: int INT_TO_NUM, + // Adjust stack top (L->top) to point at 'B' TValues *after* the specified register + // This is used to return muliple values + // A: Rn + // B: int (offset) + ADJUST_STACK_TO_REG, + + // Restore stack top (L->top) to point to the function stack top (L->ci->top) + // This is used to recover after calling a variadic function + ADJUST_STACK_TO_TOP, + // Fallback functions // Perform an arithmetic operation on TValues of any type @@ -329,7 +340,7 @@ enum class IrCmd : uint8_t // Call specified function // A: unsigned int (bytecode instruction index) // B: Rn (function, followed by arguments) - // C: int (argument count or -1 to preserve all arguments up to stack top) + // C: int (argument count or -1 to use all arguments up to stack top) // D: int (result count or -1 to preserve all results and adjust stack top) // Note: return values are placed starting from Rn specified in 'B' LOP_CALL, @@ -337,13 +348,13 @@ enum class IrCmd : uint8_t // Return specified values from the function // A: unsigned int (bytecode instruction index) // B: Rn (value start) - // B: int (result count or -1 to return all values up to stack top) + // C: int (result count or -1 to return all values up to stack top) LOP_RETURN, // Perform a fast call of a built-in function // A: unsigned int (bytecode instruction index) // B: Rn (argument start) - // C: int (argument count or -1 preserve all arguments up to stack top) + // C: int (argument count or -1 use all arguments up to stack top) // D: block (fallback) // Note: return values are placed starting from Rn specified in 'B' LOP_FASTCALL, @@ -560,6 +571,7 @@ struct IrInst IrOp c; IrOp d; IrOp e; + IrOp f; uint32_t lastUse = 0; uint16_t useCount = 0; @@ -584,9 +596,10 @@ struct IrBlock uint16_t useCount = 0; - // Start points to an instruction index in a stream - // End is implicit + // 'start' and 'finish' define an inclusive range of instructions which belong to this block inside the function + // When block has been constructed, 'finish' always points to the first and only terminating instruction uint32_t start = ~0u; + uint32_t finish = ~0u; Label label; }; @@ -633,6 +646,19 @@ struct IrFunction return value.valueTag; } + std::optional asTagOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Tag) + return std::nullopt; + + return value.valueTag; + } + bool boolOp(IrOp op) { IrConst& value = constOp(op); @@ -641,6 +667,19 @@ struct IrFunction return value.valueBool; } + std::optional asBoolOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Bool) + return std::nullopt; + + return value.valueBool; + } + int intOp(IrOp op) { IrConst& value = constOp(op); @@ -649,6 +688,19 @@ struct IrFunction return value.valueInt; } + std::optional asIntOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Int) + return std::nullopt; + + return value.valueInt; + } + unsigned uintOp(IrOp op) { IrConst& value = constOp(op); @@ -657,6 +709,19 @@ struct IrFunction return value.valueUint; } + std::optional asUintOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Uint) + return std::nullopt; + + return value.valueUint; + } + double doubleOp(IrOp op) { IrConst& value = constOp(op); @@ -665,11 +730,31 @@ struct IrFunction return value.valueDouble; } + std::optional asDoubleOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Double) + return std::nullopt; + + return value.valueDouble; + } + IrCondition conditionOp(IrOp op) { LUAU_ASSERT(op.kind == IrOpKind::Condition); return IrCondition(op.index); } + + uint32_t getBlockIndex(const IrBlock& block) + { + // Can only be called with blocks from our vector + LUAU_ASSERT(&block >= blocks.data() && &block <= blocks.data() + blocks.size()); + return uint32_t(&block - blocks.data()); + } }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 3e95813b..0a23b3f7 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -162,6 +162,8 @@ inline bool isPseudo(IrCmd cmd) return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE; } +bool isGCO(uint8_t tag); + // Remove a single instruction void kill(IrFunction& function, IrInst& inst); @@ -179,7 +181,7 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement); // Replace a single instruction // Target instruction index instead of reference is used to handle introduction of a new block terminator -void replace(IrFunction& function, uint32_t instIdx, IrInst replacement); +void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst replacement); // Replace instruction with a different value (using IrCmd::SUBSTITUTE) void substitute(IrFunction& function, IrInst& inst, IrOp replacement); @@ -188,10 +190,13 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement); void applySubstitutions(IrFunction& function, IrOp& op); void applySubstitutions(IrFunction& function, IrInst& inst); +// Compare numbers using IR condition value +bool compare(double a, double b, IrCondition cond); + // Perform constant folding on instruction at index // For most instructions, successful folding results in a IrCmd::SUBSTITUTE // But it can also be successful on conditional control-flow, replacing it with an unconditional IrCmd::JUMP -void foldConstants(IrBuilder& build, IrFunction& function, uint32_t instIdx); +void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t instIdx); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/OptimizeConstProp.h b/CodeGen/include/Luau/OptimizeConstProp.h new file mode 100644 index 00000000..3be04412 --- /dev/null +++ b/CodeGen/include/Luau/OptimizeConstProp.h @@ -0,0 +1,16 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrData.h" + +namespace Luau +{ +namespace CodeGen +{ + +struct IrBuilder; + +void constPropInBlockChains(IrBuilder& build); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 78f001f1..5076cba2 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -7,6 +7,7 @@ #include "Luau/CodeBlockUnwind.h" #include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" +#include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeFinalX64.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" @@ -31,7 +32,7 @@ #endif #endif -LUAU_FASTFLAGVARIABLE(DebugUseOldCodegen, false) +LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) namespace Luau { @@ -40,12 +41,6 @@ namespace CodeGen constexpr uint32_t kFunctionAlignment = 32; -struct InstructionOutline -{ - int pcpos; - int length; -}; - static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) { if (build.logText) @@ -64,346 +59,6 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) emitContinueCallInVm(build); } -static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i, - Label* labelarr, Label& next, Label& fallback) -{ - int skip = 0; - - switch (op) - { - case LOP_NOP: - break; - case LOP_LOADNIL: - emitInstLoadNil(build, pc); - break; - case LOP_LOADB: - emitInstLoadB(build, pc, i, labelarr); - break; - case LOP_LOADN: - emitInstLoadN(build, pc); - break; - case LOP_LOADK: - emitInstLoadK(build, pc); - break; - case LOP_LOADKX: - emitInstLoadKX(build, pc); - break; - case LOP_MOVE: - emitInstMove(build, pc); - break; - case LOP_GETGLOBAL: - emitInstGetGlobal(build, pc, i, fallback); - break; - case LOP_SETGLOBAL: - emitInstSetGlobal(build, pc, i, next, fallback); - break; - case LOP_NAMECALL: - emitInstNameCall(build, pc, i, proto->k, next, fallback); - break; - case LOP_CALL: - emitInstCall(build, helpers, pc, i); - break; - case LOP_RETURN: - emitInstReturn(build, helpers, pc, i); - break; - case LOP_GETTABLE: - emitInstGetTable(build, pc, fallback); - break; - case LOP_SETTABLE: - emitInstSetTable(build, pc, next, fallback); - break; - case LOP_GETTABLEKS: - emitInstGetTableKS(build, pc, i, fallback); - break; - case LOP_SETTABLEKS: - emitInstSetTableKS(build, pc, i, next, fallback); - break; - case LOP_GETTABLEN: - emitInstGetTableN(build, pc, fallback); - break; - case LOP_SETTABLEN: - emitInstSetTableN(build, pc, next, fallback); - break; - case LOP_JUMP: - emitInstJump(build, pc, i, labelarr); - break; - case LOP_JUMPBACK: - emitInstJumpBack(build, pc, i, labelarr); - break; - case LOP_JUMPIF: - emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false); - break; - case LOP_JUMPIFNOT: - emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true); - break; - case LOP_JUMPIFEQ: - emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback); - break; - case LOP_JUMPIFLE: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::LessEqual, fallback); - break; - case LOP_JUMPIFLT: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::Less, fallback); - break; - case LOP_JUMPIFNOTEQ: - emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback); - break; - case LOP_JUMPIFNOTLE: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLessEqual, fallback); - break; - case LOP_JUMPIFNOTLT: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLess, fallback); - break; - case LOP_JUMPX: - emitInstJumpX(build, pc, i, labelarr); - break; - case LOP_JUMPXEQKNIL: - emitInstJumpxEqNil(build, pc, i, labelarr); - break; - case LOP_JUMPXEQKB: - emitInstJumpxEqB(build, pc, i, labelarr); - break; - case LOP_JUMPXEQKN: - emitInstJumpxEqN(build, pc, proto->k, i, labelarr); - break; - case LOP_JUMPXEQKS: - emitInstJumpxEqS(build, pc, i, labelarr); - break; - case LOP_ADD: - emitInstBinary(build, pc, TM_ADD, fallback); - break; - case LOP_SUB: - emitInstBinary(build, pc, TM_SUB, fallback); - break; - case LOP_MUL: - emitInstBinary(build, pc, TM_MUL, fallback); - break; - case LOP_DIV: - emitInstBinary(build, pc, TM_DIV, fallback); - break; - case LOP_MOD: - emitInstBinary(build, pc, TM_MOD, fallback); - break; - case LOP_POW: - emitInstBinary(build, pc, TM_POW, fallback); - break; - case LOP_ADDK: - emitInstBinaryK(build, pc, TM_ADD, fallback); - break; - case LOP_SUBK: - emitInstBinaryK(build, pc, TM_SUB, fallback); - break; - case LOP_MULK: - emitInstBinaryK(build, pc, TM_MUL, fallback); - break; - case LOP_DIVK: - emitInstBinaryK(build, pc, TM_DIV, fallback); - break; - case LOP_MODK: - emitInstBinaryK(build, pc, TM_MOD, fallback); - break; - case LOP_POWK: - emitInstPowK(build, pc, proto->k, fallback); - break; - case LOP_NOT: - emitInstNot(build, pc); - break; - case LOP_MINUS: - emitInstMinus(build, pc, fallback); - break; - case LOP_LENGTH: - emitInstLength(build, pc, fallback); - break; - case LOP_NEWTABLE: - emitInstNewTable(build, pc, i, next); - break; - case LOP_DUPTABLE: - emitInstDupTable(build, pc, i, next); - break; - case LOP_SETLIST: - emitInstSetList(build, pc, next); - break; - case LOP_GETUPVAL: - emitInstGetUpval(build, pc); - break; - case LOP_SETUPVAL: - emitInstSetUpval(build, pc, next); - break; - case LOP_CLOSEUPVALS: - emitInstCloseUpvals(build, pc, next); - break; - case LOP_FASTCALL: - // We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1 - skip = emitInstFastCall(build, pc, i, next) + 1; - break; - case LOP_FASTCALL1: - // We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1 - skip = emitInstFastCall1(build, pc, i, next) + 1; - break; - case LOP_FASTCALL2: - skip = emitInstFastCall2(build, pc, i, next); - break; - case LOP_FASTCALL2K: - skip = emitInstFastCall2K(build, pc, i, next); - break; - case LOP_FORNPREP: - emitInstForNPrep(build, pc, i, next, labelarr[i + 1 + LUAU_INSN_D(*pc)]); - break; - case LOP_FORNLOOP: - emitInstForNLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next); - break; - case LOP_FORGLOOP: - emitinstForGLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next, fallback); - break; - case LOP_FORGPREP_NEXT: - emitInstForGPrepNext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback); - break; - case LOP_FORGPREP_INEXT: - emitInstForGPrepInext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback); - break; - case LOP_AND: - emitInstAnd(build, pc); - break; - case LOP_ANDK: - emitInstAndK(build, pc); - break; - case LOP_OR: - emitInstOr(build, pc); - break; - case LOP_ORK: - emitInstOrK(build, pc); - break; - case LOP_GETIMPORT: - emitInstGetImport(build, pc, fallback); - break; - case LOP_CONCAT: - emitInstConcat(build, pc, i, next); - break; - case LOP_COVERAGE: - emitInstCoverage(build, i); - break; - default: - emitFallback(build, data, op, i); - break; - } - - return skip; -} - -static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr) -{ - switch (op) - { - case LOP_GETIMPORT: - emitSetSavedPc(build, i + 1); - emitInstGetImportFallback(build, LUAU_INSN_A(*pc), pc[1]); - break; - case LOP_GETTABLE: - emitInstGetTableFallback(build, pc, i); - break; - case LOP_SETTABLE: - emitInstSetTableFallback(build, pc, i); - break; - case LOP_GETTABLEN: - emitInstGetTableNFallback(build, pc, i); - break; - case LOP_SETTABLEN: - emitInstSetTableNFallback(build, pc, i); - break; - case LOP_NAMECALL: - // TODO: fast-paths that we've handled can be removed from the fallback - emitFallback(build, data, op, i); - break; - case LOP_JUMPIFEQ: - emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false); - break; - case LOP_JUMPIFLE: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::LessEqual); - break; - case LOP_JUMPIFLT: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::Less); - break; - case LOP_JUMPIFNOTEQ: - emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true); - break; - case LOP_JUMPIFNOTLE: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLessEqual); - break; - case LOP_JUMPIFNOTLT: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLess); - break; - case LOP_ADD: - emitInstBinaryFallback(build, pc, i, TM_ADD); - break; - case LOP_SUB: - emitInstBinaryFallback(build, pc, i, TM_SUB); - break; - case LOP_MUL: - emitInstBinaryFallback(build, pc, i, TM_MUL); - break; - case LOP_DIV: - emitInstBinaryFallback(build, pc, i, TM_DIV); - break; - case LOP_MOD: - emitInstBinaryFallback(build, pc, i, TM_MOD); - break; - case LOP_POW: - emitInstBinaryFallback(build, pc, i, TM_POW); - break; - case LOP_ADDK: - emitInstBinaryKFallback(build, pc, i, TM_ADD); - break; - case LOP_SUBK: - emitInstBinaryKFallback(build, pc, i, TM_SUB); - break; - case LOP_MULK: - emitInstBinaryKFallback(build, pc, i, TM_MUL); - break; - case LOP_DIVK: - emitInstBinaryKFallback(build, pc, i, TM_DIV); - break; - case LOP_MODK: - emitInstBinaryKFallback(build, pc, i, TM_MOD); - break; - case LOP_POWK: - emitInstBinaryKFallback(build, pc, i, TM_POW); - break; - case LOP_MINUS: - emitInstMinusFallback(build, pc, i); - break; - case LOP_LENGTH: - emitInstLengthFallback(build, pc, i); - break; - case LOP_FORGLOOP: - emitinstForGLoopFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); - break; - case LOP_FORGPREP_NEXT: - case LOP_FORGPREP_INEXT: - emitInstForGPrepXnextFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); - break; - case LOP_GETGLOBAL: - // TODO: luaV_gettable + cachedslot update instead of full fallback - emitFallback(build, data, op, i); - break; - case LOP_SETGLOBAL: - // TODO: luaV_settable + cachedslot update instead of full fallback - emitFallback(build, data, op, i); - break; - case LOP_GETTABLEKS: - // Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access - // It is also required to perform cached slot update - // TODO: extra fast-paths could be lowered before the full fallback - emitFallback(build, data, op, i); - break; - case LOP_SETTABLEKS: - // TODO: luaV_settable + cachedslot update instead of full fallback - emitFallback(build, data, op, i); - break; - default: - LUAU_ASSERT(!"Expected fallback for instruction"); - } -} - static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { NativeProto* result = new NativeProto(); @@ -423,153 +78,32 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat build.logAppend("\n"); } - if (!FFlag::DebugUseOldCodegen) - { - build.align(kFunctionAlignment, AlignmentDataX64::Ud2); - - Label start = build.setLabel(); - - IrBuilder builder; - builder.buildFunctionIr(proto); - - optimizeMemoryOperandsX64(builder.function); - - IrLoweringX64 lowering(build, helpers, data, proto, builder.function); - - lowering.lower(options); - - result->instTargets = new uintptr_t[proto->sizecode]; - - for (int i = 0; i < proto->sizecode; i++) - { - auto [irLocation, asmLocation] = builder.function.bcMapping[i]; - - result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location; - } - - result->location = start.location; - - if (build.logText) - build.logAppend("\n"); - - return result; - } - - std::vector