From be2769ad146a0b1c2d3588802196f49e4c444c8d Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 18 Aug 2022 14:32:08 -0700 Subject: [PATCH] Sync to upstream/release/541 (#644) - Fix autocomplete not suggesting globals defined after the cursor (fixes #622) - Improve type checker stability - Reduce parser C stack consumption which fixes some stack overflow crashes on deeply nested sources - Improve performance of bit32.extract/replace when width is implied (~3% faster chess) - Improve performance of bit32.extract when field/width are constants (~10% faster base64) - Heap dump now annotates thread stacks with local variable/function names --- Analysis/include/Luau/Anyification.h | 38 + Analysis/include/Luau/Constraint.h | 4 +- .../include/Luau/ConstraintGraphBuilder.h | 7 +- Analysis/include/Luau/ConstraintSolver.h | 11 +- Analysis/include/Luau/Module.h | 2 + Analysis/include/Luau/Normalize.h | 20 +- Analysis/include/Luau/TypeInfer.h | 63 +- Analysis/include/Luau/TypeUtils.h | 7 +- Analysis/include/Luau/TypedAllocator.h | 4 +- Analysis/include/Luau/Unifier.h | 6 +- Analysis/include/Luau/VisitTypeVar.h | 22 +- Analysis/src/Anyification.cpp | 96 +++ Analysis/src/Autocomplete.cpp | 45 +- Analysis/src/Constraint.cpp | 3 +- Analysis/src/ConstraintGraphBuilder.cpp | 98 ++- Analysis/src/ConstraintSolver.cpp | 42 +- Analysis/src/Frontend.cpp | 3 + Analysis/src/Module.cpp | 10 +- Analysis/src/Normalize.cpp | 67 +- Analysis/src/Quantify.cpp | 42 +- Analysis/src/TypeChecker2.cpp | 741 +++++++++++++----- Analysis/src/TypeInfer.cpp | 297 +++---- Analysis/src/TypeUtils.cpp | 111 ++- Analysis/src/TypeVar.cpp | 14 +- Analysis/src/TypedAllocator.cpp | 50 +- Analysis/src/Unifier.cpp | 10 +- Ast/include/Luau/Parser.h | 31 +- Ast/src/Parser.cpp | 322 ++++---- Common/include/Luau/Bytecode.h | 3 + Compiler/src/BuiltinFolding.cpp | 11 +- Compiler/src/Compiler.cpp | 108 ++- Sources.cmake | 4 +- VM/src/lapi.cpp | 5 +- VM/src/lbuiltins.cpp | 91 ++- VM/src/lfunc.cpp | 13 +- VM/src/lfunc.h | 1 + VM/src/lgcdebug.cpp | 57 +- tests/Autocomplete.test.cpp | 23 + tests/Compiler.test.cpp | 116 ++- tests/Conformance.test.cpp | 2 +- tests/Fixture.cpp | 3 +- tests/Fixture.h | 1 + tests/Normalize.test.cpp | 18 +- tests/Parser.test.cpp | 1 - tests/TypeInfer.functions.test.cpp | 2 - tests/TypeInfer.modules.test.cpp | 29 + tests/TypeInfer.provisional.test.cpp | 1 - tests/TypeInfer.refinements.test.cpp | 15 +- tests/TypeInfer.tables.test.cpp | 4 - tests/TypeInfer.tryUnify.test.cpp | 3 +- tests/TypeVar.test.cpp | 6 + tests/conformance/bitwise.lua | 4 + tools/faillist.txt | 31 - tools/heapgraph.py | 12 +- tools/heapstat.py | 3 +- tools/perfgraph.py | 23 +- 56 files changed, 1803 insertions(+), 953 deletions(-) create mode 100644 Analysis/include/Luau/Anyification.h create mode 100644 Analysis/src/Anyification.cpp diff --git a/Analysis/include/Luau/Anyification.h b/Analysis/include/Luau/Anyification.h new file mode 100644 index 00000000..ee8d6689 --- /dev/null +++ b/Analysis/include/Luau/Anyification.h @@ -0,0 +1,38 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" + +#include + +namespace Luau +{ + +struct TypeArena; +struct Scope; +struct InternalErrorReporter; +using ScopePtr = std::shared_ptr; + +// A substitution which replaces free types by any +struct Anyification : Substitution +{ + Anyification(TypeArena* arena, const ScopePtr& scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack); + NotNull scope; + InternalErrorReporter* iceHandler; + + TypeId anyType; + TypePackId anyTypePack; + bool normalizationTooComplex = false; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; + + bool ignoreChildren(TypeId ty) override; + bool ignoreChildren(TypePackId ty) override; +}; + +} \ No newline at end of file diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 5b737146..ce90f2c5 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -40,7 +40,6 @@ struct GeneralizationConstraint { TypeId generalizedType; TypeId sourceType; - Scope* scope; }; // subType ~ inst superType @@ -85,13 +84,14 @@ using ConstraintPtr = std::unique_ptr; struct Constraint { - explicit Constraint(ConstraintV&& c); + Constraint(ConstraintV&& c, NotNull scope); Constraint(const Constraint&) = delete; Constraint& operator=(const Constraint&) = delete; ConstraintV c; std::vector> dependencies; + NotNull scope; }; inline Constraint& asMutable(const Constraint& c) diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 41d14326..0e41e1ee 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -72,10 +72,10 @@ struct ConstraintGraphBuilder /** * Fabricates a scope that is a child of another scope. - * @param location the lexical extent of the scope in the source code. + * @param node the lexical node that the scope belongs to. * @param parent the parent scope of the new scope. Must not be null. */ - ScopePtr childScope(Location location, const ScopePtr& parent); + ScopePtr childScope(AstNode* node, const ScopePtr& parent); /** * Adds a new constraint with no dependencies to a given scope. @@ -105,10 +105,12 @@ struct ConstraintGraphBuilder void visit(const ScopePtr& scope, AstStatLocal* local); void visit(const ScopePtr& scope, AstStatFor* for_); void visit(const ScopePtr& scope, AstStatWhile* while_); + void visit(const ScopePtr& scope, AstStatRepeat* repeat); void visit(const ScopePtr& scope, AstStatLocalFunction* function); void visit(const ScopePtr& scope, AstStatFunction* function); void visit(const ScopePtr& scope, AstStatReturn* ret); void visit(const ScopePtr& scope, AstStatAssign* assign); + void visit(const ScopePtr& scope, AstStatCompoundAssign* assign); void visit(const ScopePtr& scope, AstStatIf* ifStatement); void visit(const ScopePtr& scope, AstStatTypeAlias* alias); void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); @@ -133,6 +135,7 @@ struct ConstraintGraphBuilder TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); TypeId check(const ScopePtr& scope, AstExprUnary* unary); TypeId check(const ScopePtr& scope, AstExprBinary* binary); + TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse); TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); struct FunctionSignature diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 9cc0e4cb..661d120e 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -60,6 +60,9 @@ struct ConstraintSolver // Memoized instantiations of type aliases. DenseHashMap instantiatedAliases{{}}; + // Recorded errors that take place within the solver. + ErrorVec errors; + ConstraintSolverLogger logger; explicit ConstraintSolver(TypeArena* arena, NotNull rootScope); @@ -115,7 +118,7 @@ struct ConstraintSolver * @param subType the sub-type to unify. * @param superType the super-type to unify. */ - void unify(TypeId subType, TypeId superType); + void unify(TypeId subType, TypeId superType, NotNull scope); /** * Creates a new Unifier and performs a single unification operation. Commits @@ -123,13 +126,15 @@ struct ConstraintSolver * @param subPack the sub-type pack to unify. * @param superPack the super-type pack to unify. */ - void unify(TypePackId subPack, TypePackId superPack); + void unify(TypePackId subPack, TypePackId superPack, NotNull scope); /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. **/ - void pushConstraint(ConstraintV cv); + void pushConstraint(ConstraintV cv, NotNull scope); + void reportError(TypeErrorData&& data, const Location& location); + void reportError(TypeError e); private: /** * Marks a constraint as being blocked on a type or type pack. The constraint diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 6f4c6098..bec51b81 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -77,6 +77,8 @@ struct Module DenseHashMap astOverloadResolvedTypes{nullptr}; DenseHashMap astResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; + // Map AST nodes to the scope they create. Cannot be NotNull because we need a sentinel value for the map. + DenseHashMap astScopes{nullptr}; std::unordered_map declaredGlobals; ErrorVec errors; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index f5fd9886..78b241e4 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -1,20 +1,28 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Substitution.h" -#include "Luau/TypeVar.h" #include "Luau/Module.h" +#include "Luau/NotNull.h" +#include "Luau/TypeVar.h" + +#include namespace Luau { struct InternalErrorReporter; +struct Module; +struct Scope; -bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice); -bool isSubtype(TypePackId subTy, TypePackId superTy, InternalErrorReporter& ice); +using ModulePtr = std::shared_ptr; -std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, InternalErrorReporter& ice); + +std::pair normalize(TypeId ty, NotNull scope, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypeId ty, NotNull module, InternalErrorReporter& ice); std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, NotNull scope, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, NotNull module, InternalErrorReporter& ice); std::pair normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 5c55ddb0..80f9085f 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Anyification.h" #include "Luau/Predicate.h" #include "Luau/Error.h" #include "Luau/Module.h" @@ -36,40 +37,6 @@ const AstStat* getFallthrough(const AstStat* node); struct UnifierOptions; struct Unifier; -// A substitution which replaces free types by any -struct Anyification : Substitution -{ - Anyification(TypeArena* arena, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) - : Substitution(TxnLog::empty(), arena) - , iceHandler(iceHandler) - , anyType(anyType) - , anyTypePack(anyTypePack) - { - } - - InternalErrorReporter* iceHandler; - - TypeId anyType; - TypePackId anyTypePack; - bool normalizationTooComplex = false; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; - - bool ignoreChildren(TypeId ty) override - { - if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) - return true; - - return ty->persistent; - } - bool ignoreChildren(TypePackId ty) override - { - return ty->persistent; - } -}; - struct GenericTypeDefinitions { std::vector genericTypes; @@ -196,32 +163,32 @@ struct TypeChecker /** Attempt to unify the types. * Treat any failures as type errors in the final typecheck report. */ - bool unify(TypeId subTy, TypeId superTy, const Location& location); - bool unify(TypeId subTy, TypeId superTy, const Location& location, const UnifierOptions& options); - bool unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); + bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); + bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options); + bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); /** Attempt to unify the types. * If this fails, and the subTy type can be instantiated, do so and try unification again. */ - bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location); - void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state); + bool unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); + void unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, Unifier& state); /** Attempt to unify. * If there are errors, undo everything and return the errors. * If there are no errors, commit and return an empty error vector. */ template - ErrorVec tryUnify_(Id subTy, Id superTy, const Location& location); - ErrorVec tryUnify(TypeId subTy, TypeId superTy, const Location& location); - ErrorVec tryUnify(TypePackId subTy, TypePackId superTy, const Location& location); + ErrorVec tryUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location); + ErrorVec tryUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); + ErrorVec tryUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location); // Test whether the two type vars unify. Never commits the result. template - ErrorVec canUnify_(Id subTy, Id superTy, const Location& location); - ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); - ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); + ErrorVec canUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location); + ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); + ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location); - void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location); std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); @@ -290,7 +257,7 @@ private: void reportErrorCodeTooComplex(const Location& location); private: - Unifier mkUnifier(const Location& location); + Unifier mkUnifier(const ScopePtr& scope, const Location& location); // These functions are only safe to call when we are in the process of typechecking a module. @@ -312,7 +279,7 @@ public: std::pair, bool> pickTypesFromSense(TypeId type, bool sense); private: - TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); + TypeId unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes = true); // ex // TypeId id = addType(FreeTypeVar()); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 42c1bc0b..0aff5a7d 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -13,7 +13,10 @@ namespace Luau using ScopePtr = std::shared_ptr; -std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::string entry, Location location); -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, Name name, Location location); +std::optional findMetatableEntry(ErrorVec& errors, TypeId type, const std::string& entry, Location location); +std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, const std::string& name, Location location); +std::optional getIndexTypeFromType( + const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, const Location& location, bool addErrors, + InternalErrorReporter& handle); } // namespace Luau diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h index f67e3d8e..a5dd17b6 100644 --- a/Analysis/include/Luau/TypedAllocator.h +++ b/Analysis/include/Luau/TypedAllocator.h @@ -10,7 +10,7 @@ namespace Luau { void* pagedAllocate(size_t size); -void pagedDeallocate(void* ptr); +void pagedDeallocate(void* ptr, size_t size); void pagedFreeze(void* ptr, size_t size); void pagedUnfreeze(void* ptr, size_t size); @@ -113,7 +113,7 @@ private: for (size_t i = 0; i < blockSize; ++i) block[i].~T(); - pagedDeallocate(block); + pagedDeallocate(block, kBlockSizeBytes); } stuff.clear(); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 9fa29076..312b0584 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -3,9 +3,10 @@ #include "Luau/Error.h" #include "Luau/Location.h" +#include "Luau/Scope.h" #include "Luau/TxnLog.h" -#include "Luau/TypeInfer.h" #include "Luau/TypeArena.h" +#include "Luau/TypeInfer.h" #include "Luau/UnifierSharedState.h" #include @@ -48,6 +49,7 @@ struct Unifier TypeArena* const types; Mode mode; + NotNull scope; // const Scope maybe TxnLog log; ErrorVec errors; Location location; @@ -57,7 +59,7 @@ struct Unifier UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, Mode mode, NotNull scope, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 7e5d71d6..9d7fa9fe 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -69,12 +69,14 @@ struct GenericTypeVarVisitor using Set = S; Set seen; + bool skipBoundTypes = false; int recursionCounter = 0; GenericTypeVarVisitor() = default; - explicit GenericTypeVarVisitor(Set seen) + explicit GenericTypeVarVisitor(Set seen, bool skipBoundTypes = false) : seen(std::move(seen)) + , skipBoundTypes(skipBoundTypes) { } @@ -199,7 +201,9 @@ struct GenericTypeVarVisitor if (auto btv = get(ty)) { - if (visit(ty, *btv)) + if (skipBoundTypes) + traverse(btv->boundTo); + else if (visit(ty, *btv)) traverse(btv->boundTo); } else if (auto ftv = get(ty)) @@ -229,7 +233,11 @@ struct GenericTypeVarVisitor else if (auto ttv = get(ty)) { // Some visitors want to see bound tables, that's why we traverse the original type - if (visit(ty, *ttv)) + if (skipBoundTypes && ttv->boundTo) + { + traverse(*ttv->boundTo); + } + else if (visit(ty, *ttv)) { if (ttv->boundTo) { @@ -394,13 +402,17 @@ struct GenericTypeVarVisitor */ struct TypeVarVisitor : GenericTypeVarVisitor> { + explicit TypeVarVisitor(bool skipBoundTypes = false) + : GenericTypeVarVisitor{{}, skipBoundTypes} + { + } }; /// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it. struct TypeVarOnceVisitor : GenericTypeVarVisitor> { - TypeVarOnceVisitor() - : GenericTypeVarVisitor{DenseHashSet{nullptr}} + explicit TypeVarOnceVisitor(bool skipBoundTypes = false) + : GenericTypeVarVisitor{DenseHashSet{nullptr}, skipBoundTypes} { } }; diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp new file mode 100644 index 00000000..b6e58009 --- /dev/null +++ b/Analysis/src/Anyification.cpp @@ -0,0 +1,96 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Anyification.h" + +#include "Luau/Common.h" +#include "Luau/Normalize.h" +#include "Luau/TxnLog.h" + +LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) + +namespace Luau +{ + +Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) + : Substitution(TxnLog::empty(), arena) + , scope(NotNull{scope.get()}) + , iceHandler(iceHandler) + , anyType(anyType) + , anyTypePack(anyTypePack) +{ +} + +bool Anyification::isDirty(TypeId ty) +{ + if (ty->persistent) + return false; + + if (const TableTypeVar* ttv = log->getMutable(ty)) + return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); + else if (log->getMutable(ty)) + return true; + else if (get(ty)) + return true; + else + return false; +} + +bool Anyification::isDirty(TypePackId tp) +{ + if (tp->persistent) + return false; + + if (log->getMutable(tp)) + return true; + else + return false; +} + +TypeId Anyification::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + if (const TableTypeVar* ttv = log->getMutable(ty)) + { + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; + clone.definitionModuleName = ttv->definitionModuleName; + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.tags = ttv->tags; + TypeId res = addType(std::move(clone)); + asMutable(res)->normal = ty->normal; + return res; + } + else if (auto ctv = get(ty)) + { + std::vector copy = ctv->parts; + for (TypeId& ty : copy) + ty = replace(ty); + TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); + auto [t, ok] = normalize(res, scope, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; + } + else + return anyType; +} + +TypePackId Anyification::clean(TypePackId tp) +{ + LUAU_ASSERT(isDirty(tp)); + return anyTypePack; +} + +bool Anyification::ignoreChildren(TypeId ty) +{ + if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + return true; + + return ty->persistent; +} +bool Anyification::ignoreChildren(TypePackId ty) +{ + return ty->persistent; +} + +} diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 8d5cc723..5c484899 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,8 @@ LUAU_FASTFLAG(LuauSelfCallAutocompleteFix3) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteFixGlobalOrder, false) + static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -135,11 +137,11 @@ static std::optional findExpectedTypeAt(const Module& module, AstNode* n return *it; } -static bool checkTypeMatch(TypeArena* typeArena, TypeId subTy, TypeId superTy) +static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena) { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); + Unifier unifier(typeArena, Mode::Strict, scope, Location(), Variance::Covariant, unifierState); return unifier.canUnify(subTy, superTy).empty(); } @@ -148,12 +150,14 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { ty = follow(ty); - auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { + NotNull moduleScope{module.getModuleScope().get()}; + + auto canUnify = [&typeArena, moduleScope](TypeId subTy, TypeId superTy) { LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3); InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); + Unifier unifier(typeArena, Mode::Strict, moduleScope, Location(), Variance::Covariant, unifierState); unifier.tryUnify(subTy, superTy); bool ok = unifier.errors.empty(); @@ -167,11 +171,11 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); - auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { + auto checkFunctionType = [typeArena, moduleScope, &canUnify, &expectedType](const FunctionTypeVar* ftv) { if (FFlag::LuauSelfCallAutocompleteFix3) { if (std::optional firstRetTy = first(ftv->retTypes)) - return checkTypeMatch(typeArena, *firstRetTy, expectedType); + return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena); return false; } @@ -210,7 +214,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } if (FFlag::LuauSelfCallAutocompleteFix3) - return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena) ? TypeCorrectKind::Correct : TypeCorrectKind::None; else return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } @@ -268,7 +272,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId return colonIndex; } }; - auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) { + auto isWrongIndexer = [typeArena, &module, rootTy, indexType](Luau::TypeId type) { LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix3); if (indexType == PropIndexType::Key) @@ -276,7 +280,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId bool calledWithSelf = indexType == PropIndexType::Colon; - auto isCompatibleCall = [typeArena, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { + auto isCompatibleCall = [typeArena, &module, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { // Strong match with definition is a success if (calledWithSelf == ftv->hasSelf) return true; @@ -289,7 +293,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible if (std::optional firstArgTy = first(ftv->argTypes)) { - if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) + if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena)) return calledWithSelf; } @@ -1073,10 +1077,21 @@ T* extractStat(const std::vector& ancestry) return nullptr; } -static bool isBindingLegalAtCurrentPosition(const Binding& binding, Position pos) +static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos) { - // Default Location used for global bindings, which are always legal. - return binding.location == Location() || binding.location.end < pos; + if (FFlag::LuauAutocompleteFixGlobalOrder) + { + if (symbol.local) + return binding.location.end < pos; + + // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it + return binding.location == Location() || !binding.location.containsClosed(pos); + } + else + { + // Default Location used for global bindings, which are always legal. + return binding.location == Location() || binding.location.end < pos; + } } static AutocompleteEntryMap autocompleteStatement( @@ -1097,7 +1112,7 @@ static AutocompleteEntryMap autocompleteStatement( { for (const auto& [name, binding] : scope->bindings) { - if (!isBindingLegalAtCurrentPosition(binding, position)) + if (!isBindingLegalAtCurrentPosition(name, binding, position)) continue; std::string n = toString(name); @@ -1225,7 +1240,7 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu { for (const auto& [name, binding] : scope->bindings) { - if (!isBindingLegalAtCurrentPosition(binding, position)) + if (!isBindingLegalAtCurrentPosition(name, binding, position)) continue; if (isBeingDefined(ancestry, name)) diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 64e3a666..d272c027 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -5,8 +5,9 @@ namespace Luau { -Constraint::Constraint(ConstraintV&& c) +Constraint::Constraint(ConstraintV&& c, NotNull scope) : c(std::move(c)) + , scope(scope) { } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 2c36d422..c1e54df2 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -2,6 +2,7 @@ #include "Luau/ConstraintGraphBuilder.h" #include "Luau/Ast.h" +#include "Luau/Common.h" #include "Luau/Constraint.h" #include "Luau/RecursionCounter.h" #include "Luau/ToString.h" @@ -26,6 +27,7 @@ ConstraintGraphBuilder::ConstraintGraphBuilder( , globalScope(globalScope) { LUAU_ASSERT(arena); + LUAU_ASSERT(module); } TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope) @@ -39,20 +41,22 @@ TypePackId ConstraintGraphBuilder::freshTypePack(const ScopePtr& scope) return arena->addTypePack(TypePackVar{std::move(f)}); } -ScopePtr ConstraintGraphBuilder::childScope(Location location, const ScopePtr& parent) +ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& parent) { auto scope = std::make_shared(parent); - scopes.emplace_back(location, scope); + scopes.emplace_back(node->location, scope); scope->returnType = parent->returnType; - parent->children.push_back(NotNull(scope.get())); + + parent->children.push_back(NotNull{scope.get()}); + module->astScopes[node] = scope.get(); return scope; } void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, ConstraintV cv) { - scope->constraints.emplace_back(new Constraint{std::move(cv)}); + scope->constraints.emplace_back(new Constraint{std::move(cv), NotNull{scope.get()}}); } void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) @@ -67,6 +71,7 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) ScopePtr scope = std::make_shared(globalScope); rootScope = scope.get(); scopes.emplace_back(block->location, scope); + module->astScopes[block] = NotNull{scope.get()}; rootScope->returnType = freshTypePack(scope); @@ -115,7 +120,7 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, ScopePtr defnScope = scope; if (hasGenerics) { - defnScope = childScope(alias->location, scope); + defnScope = childScope(alias, scope); } TypeId initialType = freshType(scope); @@ -155,6 +160,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) visit(scope, s); else if (auto s = stat->as()) visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); else if (auto f = stat->as()) visit(scope, f); else if (auto f = stat->as()) @@ -163,6 +170,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) visit(scope, r); else if (auto a = stat->as()) visit(scope, a); + else if (auto a = stat->as()) + visit(scope, a); else if (auto e = stat->as()) checkPack(scope, e->expr); else if (auto i = stat->as()) @@ -241,7 +250,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) checkNumber(for_->to); checkNumber(for_->step); - ScopePtr forScope = childScope(for_->location, scope); + ScopePtr forScope = childScope(for_, scope); forScope->bindings[for_->var] = Binding{singletonTypes.numberType, for_->var->location}; visit(forScope, for_->body); @@ -251,11 +260,22 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) { check(scope, while_->condition); - ScopePtr whileScope = childScope(while_->location, scope); + ScopePtr whileScope = childScope(while_, scope); visit(whileScope, while_->body); } +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) +{ + ScopePtr repeatScope = childScope(repeat, scope); + + visit(repeatScope, repeat->body); + + // The condition does indeed have access to bindings from within the body of + // the loop. + check(repeatScope, repeat->condition); +} + void addConstraints(Constraint* constraint, NotNull scope) { scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); @@ -286,8 +306,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* checkFunctionBody(sig.bodyScope, function->func); - std::unique_ptr c{ - new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}}}; + NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; + std::unique_ptr c = std::make_unique(GeneralizationConstraint{functionType, sig.signature}, constraintScope); addConstraints(c.get(), NotNull(sig.bodyScope.get())); addConstraint(scope, std::move(c)); @@ -356,8 +376,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct checkFunctionBody(sig.bodyScope, function->func); - std::unique_ptr c{ - new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}}}; + NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; + std::unique_ptr c = std::make_unique(GeneralizationConstraint{functionType, sig.signature}, constraintScope); addConstraints(c.get(), NotNull(sig.bodyScope.get())); addConstraint(scope, std::move(c)); @@ -371,7 +391,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) { - ScopePtr innerScope = childScope(block->location, scope); + ScopePtr innerScope = childScope(block, scope); visitBlockWithoutChildScope(innerScope, block); } @@ -384,16 +404,30 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); } +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) +{ + // Synthesize A = A op B from A op= B and then build constraints for that instead. + + AstExprBinary exprBinary{assign->location, assign->op, assign->var, assign->value}; + AstExpr* exprBinaryPtr = &exprBinary; + + AstArray vars{&assign->var, 1}; + AstArray values{&exprBinaryPtr, 1}; + AstStatAssign syntheticAssign{assign->location, vars, values}; + + visit(scope, &syntheticAssign); +} + void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { check(scope, ifStatement->condition); - ScopePtr thenScope = childScope(ifStatement->thenbody->location, scope); + ScopePtr thenScope = childScope(ifStatement->thenbody, scope); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { - ScopePtr elseScope = childScope(ifStatement->elsebody->location, scope); + ScopePtr elseScope = childScope(ifStatement->elsebody, scope); visit(elseScope, ifStatement->elsebody); } } @@ -561,7 +595,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction ScopePtr funScope = scope; if (!generics.empty() || !genericPacks.empty()) - funScope = childScope(global->location, scope); + funScope = childScope(global, scope); TypePackId paramPack = resolveTypePack(funScope, global->params); TypePackId retPack = resolveTypePack(funScope, global->retTypes); @@ -739,6 +773,8 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr) result = check(scope, unary); else if (auto binary = expr->as()) result = check(scope, binary); + else if (auto ifElse = expr->as()) + result = check(scope, ifElse); else if (auto typeAssert = expr->as()) result = check(scope, typeAssert); else if (auto err = expr->as()) @@ -819,6 +855,12 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binar addConstraint(scope, SubtypeConstraint{leftType, rightType}); return leftType; } + case AstExprBinary::Add: + { + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, BinaryConstraint{AstExprBinary::Add, leftType, rightType, resultType}); + return resultType; + } case AstExprBinary::Sub: { TypeId resultType = arena->addType(BlockedTypeVar{}); @@ -833,6 +875,24 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binar return nullptr; } +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse) +{ + check(scope, ifElse->condition); + + TypeId thenType = check(scope, ifElse->trueExpr); + TypeId elseType = check(scope, ifElse->falseExpr); + + if (ifElse->hasElse) + { + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, SubtypeConstraint{thenType, resultType}); + addConstraint(scope, SubtypeConstraint{elseType, resultType}); + return resultType; + } + + return thenType; +} + TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) { check(scope, typeAssert->expr); @@ -905,14 +965,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // generics properly. if (hasGenerics) { - signatureScope = childScope(fn->location, parent); + signatureScope = childScope(fn, parent); // We need to assign returnType before creating bodyScope so that the // return type gets propogated to bodyScope. returnType = freshTypePack(signatureScope); signatureScope->returnType = returnType; - bodyScope = childScope(fn->body->location, signatureScope); + bodyScope = childScope(fn->body, signatureScope); std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); @@ -933,7 +993,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS } else { - bodyScope = childScope(fn->body->location, parent); + bodyScope = childScope(fn->body, parent); returnType = freshTypePack(bodyScope); bodyScope->returnType = returnType; @@ -1098,7 +1158,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b // for the generic bindings to live on. if (hasGenerics) { - signatureScope = childScope(fn->location, scope); + signatureScope = childScope(fn, scope); std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index d8105ce3..6c6d2722 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -373,7 +373,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNullscope); unblock(c.subType); unblock(c.superType); @@ -383,7 +383,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { - unify(c.subPack, c.superPack); + unify(c.subPack, c.superPack, constraint->scope); unblock(c.subPack); unblock(c.superPack); @@ -398,9 +398,9 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullty.emplace(c.sourceType); else - unify(c.generalizedType, c.sourceType); + unify(c.generalizedType, c.sourceType, constraint->scope); - TypeId generalized = quantify(arena, c.sourceType, c.scope); + TypeId generalized = quantify(arena, c.sourceType, constraint->scope); *asMutable(c.sourceType) = *generalized; unblock(c.generalizedType); @@ -422,7 +422,7 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNullty.emplace(*instantiated); else - unify(c.subType, *instantiated); + unify(c.subType, *instantiated, constraint->scope); unblock(c.subType); @@ -465,7 +465,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope); asMutable(c.resultType)->ty.emplace(leftType); return true; } @@ -528,16 +528,18 @@ struct InstantiationQueuer : TypeVarOnceVisitor { ConstraintSolver* solver; const InstantiationSignature& signature; + NotNull scope; - explicit InstantiationQueuer(ConstraintSolver* solver, const InstantiationSignature& signature) + explicit InstantiationQueuer(ConstraintSolver* solver, const InstantiationSignature& signature, NotNull scope) : solver(solver) , signature(signature) + , scope(scope) { } bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override { - solver->pushConstraint(TypeAliasExpansionConstraint{ty}); + solver->pushConstraint(TypeAliasExpansionConstraint{ty}, scope); return false; } }; @@ -686,7 +688,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // The application is not recursive, so we need to queue up application of // any child type function instantiations within the result in order for it // to be complete. - InstantiationQueuer queuer{this, signature}; + InstantiationQueuer queuer{this, signature, constraint->scope}; queuer.traverse(target); instantiatedAliases[signature] = target; @@ -766,30 +768,40 @@ bool ConstraintSolver::isBlocked(NotNull constraint) return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } -void ConstraintSolver::unify(TypeId subType, TypeId superType) +void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, scope, Location{}, Covariant, sharedState}; u.tryUnify(subType, superType); u.log.commit(); } -void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) +void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, scope, Location{}, Covariant, sharedState}; u.tryUnify(subPack, superPack); u.log.commit(); } -void ConstraintSolver::pushConstraint(ConstraintV cv) +void ConstraintSolver::pushConstraint(ConstraintV cv, NotNull scope) { - std::unique_ptr c = std::make_unique(std::move(cv)); + std::unique_ptr c = std::make_unique(std::move(cv), scope); NotNull borrow = NotNull(c.get()); solverConstraints.push_back(std::move(c)); unsolvedConstraints.push_back(borrow); } +void ConstraintSolver::reportError(TypeErrorData&& data, const Location& location) +{ + errors.emplace_back(location, std::move(data)); +} + +void ConstraintSolver::reportError(TypeError e) +{ + errors.emplace_back(std::move(e)); +} + } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index c450297f..8ab4e865 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -839,6 +839,9 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)}; cs.run(); + for (TypeError& e : cs.errors) + result->errors.emplace_back(std::move(e)); + result->scopes = std::move(cgb.scopes); result->astTypes = std::move(cgb.astTypes); result->astTypePacks = std::move(cgb.astTypePacks); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 4e6b2585..de796c7a 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -244,12 +244,12 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) if (FFlag::LuauLowerBoundsCalculation) { - normalize(returnType, interfaceTypes, ice); + normalize(returnType, NotNull{this}, ice); if (FFlag::LuauForceExportSurfacesToBeNormal) forceNormal.traverse(returnType); if (varargPack) { - normalize(*varargPack, interfaceTypes, ice); + normalize(*varargPack, NotNull{this}, ice); if (FFlag::LuauForceExportSurfacesToBeNormal) forceNormal.traverse(*varargPack); } @@ -265,7 +265,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) tf = clone(tf, interfaceTypes, cloneState); if (FFlag::LuauLowerBoundsCalculation) { - normalize(tf.type, interfaceTypes, ice); + normalize(tf.type, NotNull{this}, ice); // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables // won't be marked normal. If the types aren't normal by now, they never will be. @@ -276,7 +276,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) if (param.defaultValue) { - normalize(*param.defaultValue, interfaceTypes, ice); + normalize(*param.defaultValue, NotNull{this}, ice); forceNormal.traverse(*param.defaultValue); } } @@ -302,7 +302,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) ty = clone(ty, interfaceTypes, cloneState); if (FFlag::LuauLowerBoundsCalculation) { - normalize(ty, interfaceTypes, ice); + normalize(ty, NotNull{this}, ice); if (FFlag::LuauForceExportSurfacesToBeNormal) forceNormal.traverse(ty); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 33f369a7..94adaf5c 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -15,7 +15,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauFixNormalizationOfCyclicUnions, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau { @@ -55,11 +54,11 @@ struct Replacer } // anonymous namespace -bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; - Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{&arena, Mode::Strict, scope, Location{}, Covariant, sharedState}; u.anyIsTop = true; u.tryUnify(subTy, superTy); @@ -67,11 +66,11 @@ bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) return ok; } -bool isSubtype(TypePackId subPack, TypePackId superPack, InternalErrorReporter& ice) +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; - Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{&arena, Mode::Strict, scope, Location{}, Covariant, sharedState}; u.anyIsTop = true; u.tryUnify(subPack, superPack); @@ -134,13 +133,15 @@ struct Normalize final : TypeVarVisitor { using TypeVarVisitor::Set; - Normalize(TypeArena& arena, InternalErrorReporter& ice) + Normalize(TypeArena& arena, NotNull scope, InternalErrorReporter& ice) : arena(arena) + , scope(scope) , ice(ice) { } TypeArena& arena; + NotNull scope; InternalErrorReporter& ice; int iterationLimit = 0; @@ -215,22 +216,7 @@ struct Normalize final : TypeVarVisitor traverse(part); std::vector newParts = normalizeUnion(parts); - - if (FFlag::LuauQuantifyConstrained) - { - ctv->parts = std::move(newParts); - } - else - { - const bool normal = areNormal(newParts, seen, ice); - - if (newParts.size() == 1) - *asMutable(ty) = BoundTypeVar{newParts[0]}; - else - *asMutable(ty) = UnionTypeVar{std::move(newParts)}; - - asMutable(ty)->normal = normal; - } + ctv->parts = std::move(newParts); return false; } @@ -288,12 +274,7 @@ struct Normalize final : TypeVarVisitor } // An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal. - if (FFlag::LuauQuantifyConstrained) - { - if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal)) - asMutable(ty)->normal = normal; - } - else + if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal)) asMutable(ty)->normal = normal; return false; @@ -518,9 +499,9 @@ struct Normalize final : TypeVarVisitor for (TypeId& part : result) { - if (isSubtype(ty, part, ice)) + if (isSubtype(ty, part, scope, ice)) return; // no need to do anything - else if (isSubtype(part, ty, ice)) + else if (isSubtype(part, ty, scope, ice)) { part = ty; // replace the less general type by the more general one return; @@ -572,12 +553,12 @@ struct Normalize final : TypeVarVisitor bool merged = false; for (TypeId& part : result->parts) { - if (isSubtype(part, ty, ice)) + if (isSubtype(part, ty, scope, ice)) { merged = true; break; // no need to do anything } - else if (isSubtype(ty, part, ice)) + else if (isSubtype(ty, part, scope, ice)) { merged = true; part = ty; // replace the less general type by the more general one @@ -710,13 +691,13 @@ struct Normalize final : TypeVarVisitor /** * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) */ -std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice) +std::pair normalize(TypeId ty, NotNull scope, TypeArena& arena, InternalErrorReporter& ice) { CloneState state; if (FFlag::DebugLuauCopyBeforeNormalizing) (void)clone(ty, arena, state); - Normalize n{arena, ice}; + Normalize n{arena, scope, ice}; n.traverse(ty); return {ty, !n.limitExceeded}; @@ -726,29 +707,39 @@ std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorRepo // reclaim memory used by wantonly allocated intermediate types here. // The main wrinkle here is that we don't want clone() to copy a type if the source and dest // arena are the same. +std::pair normalize(TypeId ty, NotNull module, InternalErrorReporter& ice) +{ + return normalize(ty, NotNull{module->getModuleScope().get()}, module->internalTypes, ice); +} + std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice) { - return normalize(ty, module->internalTypes, ice); + return normalize(ty, NotNull{module.get()}, ice); } /** * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) */ -std::pair normalize(TypePackId tp, TypeArena& arena, InternalErrorReporter& ice) +std::pair normalize(TypePackId tp, NotNull scope, TypeArena& arena, InternalErrorReporter& ice) { CloneState state; if (FFlag::DebugLuauCopyBeforeNormalizing) (void)clone(tp, arena, state); - Normalize n{arena, ice}; + Normalize n{arena, scope, ice}; n.traverse(tp); return {tp, !n.limitExceeded}; } +std::pair normalize(TypePackId tp, NotNull module, InternalErrorReporter& ice) +{ + return normalize(tp, NotNull{module->getModuleScope().get()}, module->internalTypes, ice); +} + std::pair normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice) { - return normalize(tp, module->internalTypes, ice); + return normalize(tp, NotNull{module.get()}, ice); } } // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index ce4afe22..7e6ff2f9 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -10,7 +10,6 @@ LUAU_FASTFLAG(DebugLuauSharedSelf) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) namespace Luau @@ -82,30 +81,25 @@ struct Quantifier final : TypeVarOnceVisitor bool visit(TypeId ty, const ConstrainedTypeVar&) override { - if (FFlag::LuauQuantifyConstrained) - { - ConstrainedTypeVar* ctv = getMutable(ty); + ConstrainedTypeVar* ctv = getMutable(ty); - seenMutableType = true; - - if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level)) - return false; - - std::vector opts = std::move(ctv->parts); - - // We might transmute, so it's not safe to rely on the builtin traversal logic - for (TypeId opt : opts) - traverse(opt); - - if (opts.size() == 1) - *asMutable(ty) = BoundTypeVar{opts[0]}; - else - *asMutable(ty) = UnionTypeVar{std::move(opts)}; + seenMutableType = true; + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level)) return false; - } + + std::vector opts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic + for (TypeId opt : opts) + traverse(opt); + + if (opts.size() == 1) + *asMutable(ty) = BoundTypeVar{opts[0]}; else - return true; + *asMutable(ty) = UnionTypeVar{std::move(opts)}; + + return false; } bool visit(TypeId ty, const TableTypeVar&) override @@ -119,12 +113,6 @@ struct Quantifier final : TypeVarOnceVisitor if (ttv.state == TableState::Free) seenMutableType = true; - if (!FFlag::LuauQuantifyConstrained) - { - if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) - return false; - } - if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level)) { if (ttv.state == TableState::Unsealed) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 9d97e8d9..e5813cd2 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -8,22 +8,59 @@ #include "Luau/Clone.h" #include "Luau/Instantiation.h" #include "Luau/Normalize.h" +#include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" -#include "Luau/ToString.h" namespace Luau { -struct TypeChecker2 : public AstVisitor +/* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. + * TypeChecker2 uses this to maintain knowledge about which scope encloses every + * given AstNode. + */ +struct StackPusher +{ + std::vector>* stack; + NotNull scope; + + explicit StackPusher(std::vector>& stack, Scope* scope) + : stack(&stack) + , scope(scope) + { + stack.push_back(NotNull{scope}); + } + + ~StackPusher() + { + if (stack) + { + LUAU_ASSERT(stack->back() == scope); + stack->pop_back(); + } + } + + StackPusher(const StackPusher&) = delete; + StackPusher&& operator=(const StackPusher&) = delete; + + StackPusher(StackPusher&& other) + : stack(std::exchange(other.stack, nullptr)) + , scope(other.scope) + { + } +}; + +struct TypeChecker2 { const SourceModule* sourceModule; Module* module; InternalErrorReporter ice; // FIXME accept a pointer from Frontend SingletonTypes& singletonTypes; + std::vector> stack; + TypeChecker2(const SourceModule* sourceModule, Module* module) : sourceModule(sourceModule) , module(module) @@ -31,7 +68,13 @@ struct TypeChecker2 : public AstVisitor { } - using AstVisitor::visit; + std::optional pushStack(AstNode* node) + { + if (Scope** scope = module->astScopes.find(node)) + return StackPusher{stack, *scope}; + else + return std::nullopt; + } TypePackId lookupPack(AstExpr* expr) { @@ -118,11 +161,128 @@ struct TypeChecker2 : public AstVisitor return bestScope; } - bool visit(AstStatLocal* local) override + void visit(AstStat* stat) + { + auto pusher = pushStack(stat); + + if (0) + {} + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else + LUAU_ASSERT(!"TypeChecker2 encountered an unknown node type"); + } + + void visit(AstStatBlock* block) + { + auto StackPusher = pushStack(block); + + for (AstStat* statement : block->body) + visit(statement); + } + + void visit(AstStatIf* ifStatement) + { + visit(ifStatement->condition); + visit(ifStatement->thenbody); + if (ifStatement->elsebody) + visit(ifStatement->elsebody); + } + + void visit(AstStatWhile* whileStatement) + { + visit(whileStatement->condition); + visit(whileStatement->body); + } + + void visit(AstStatRepeat* repeatStatement) + { + visit(repeatStatement->body); + visit(repeatStatement->condition); + } + + void visit(AstStatBreak*) + {} + + void visit(AstStatContinue*) + {} + + void visit(AstStatReturn* ret) + { + Scope* scope = findInnermostScope(ret->location); + TypePackId expectedRetType = scope->returnType; + + TypeArena arena; + TypePackId actualRetType = reconstructPack(ret->list, arena); + + UnifierSharedState sharedState{&ice}; + Unifier u{&arena, Mode::Strict, stack.back(), ret->location, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(actualRetType, expectedRetType); + const bool ok = u.errors.empty() && u.log.empty(); + + if (!ok) + { + for (const TypeError& e : u.errors) + reportError(e); + } + + for (AstExpr* expr : ret->list) + visit(expr); + } + + void visit(AstStatExpr* expr) + { + visit(expr->expr); + } + + void visit(AstStatLocal* local) { for (size_t i = 0; i < local->values.size; ++i) { AstExpr* value = local->values.data[i]; + + visit(value); + if (i == local->values.size - 1) { if (i < local->values.size) @@ -140,7 +300,7 @@ struct TypeChecker2 : public AstVisitor if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - if (!isSubtype(*it, varType, ice)) + if (!isSubtype(*it, varType, stack.back(), ice)) { reportError(TypeMismatch{varType, *it}, value->location); } @@ -158,64 +318,244 @@ struct TypeChecker2 : public AstVisitor if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - if (!isSubtype(varType, valueType, ice)) + if (!isSubtype(varType, valueType, stack.back(), ice)) { reportError(TypeMismatch{varType, valueType}, value->location); } } } } - - return true; } - bool visit(AstStatAssign* assign) override + void visit(AstStatFor* forStatement) + { + if (forStatement->var->annotation) + visit(forStatement->var->annotation); + + visit(forStatement->from); + visit(forStatement->to); + if (forStatement->step) + visit(forStatement->step); + visit(forStatement->body); + } + + void visit(AstStatForIn* forInStatement) + { + for (AstLocal* local : forInStatement->vars) + { + if (local->annotation) + visit(local->annotation); + } + + for (AstExpr* expr : forInStatement->values) + visit(expr); + + visit(forInStatement->body); + } + + void visit(AstStatAssign* assign) { size_t count = std::min(assign->vars.size, assign->values.size); for (size_t i = 0; i < count; ++i) { AstExpr* lhs = assign->vars.data[i]; + visit(lhs); TypeId lhsType = lookupType(lhs); AstExpr* rhs = assign->values.data[i]; + visit(rhs); TypeId rhsType = lookupType(rhs); - if (!isSubtype(rhsType, lhsType, ice)) + if (!isSubtype(rhsType, lhsType, stack.back(), ice)) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } } - - return true; } - bool visit(AstStatReturn* ret) override + void visit(AstStatCompoundAssign* stat) { - Scope* scope = findInnermostScope(ret->location); - TypePackId expectedRetType = scope->returnType; + visit(stat->var); + visit(stat->value); + } - TypeArena arena; - TypePackId actualRetType = reconstructPack(ret->list, arena); + void visit(AstStatFunction* stat) + { + visit(stat->name); + visit(stat->func); + } - UnifierSharedState sharedState{&ice}; - Unifier u{&arena, Mode::Strict, ret->location, Covariant, sharedState}; - u.anyIsTop = true; + void visit(AstStatLocalFunction* stat) + { + visit(stat->func); + } - u.tryUnify(actualRetType, expectedRetType); - const bool ok = u.errors.empty() && u.log.empty(); + void visit(const AstTypeList* typeList) + { + for (AstType* ty : typeList->types) + visit(ty); - if (!ok) + if (typeList->tailType) + visit(typeList->tailType); + } + + void visit(AstStatTypeAlias* stat) + { + for (const AstGenericType& el : stat->generics) { - for (const TypeError& e : u.errors) - reportError(e); + if (el.defaultValue) + visit(el.defaultValue); } - return true; + for (const AstGenericTypePack& el : stat->genericPacks) + { + if (el.defaultValue) + visit(el.defaultValue); + } + + visit(stat->type); } - bool visit(AstExprCall* call) override + void visit(AstTypeList types) { + for (AstType* type : types.types) + visit(type); + if (types.tailType) + visit(types.tailType); + } + + void visit(AstStatDeclareFunction* stat) + { + visit(stat->params); + visit(stat->retTypes); + } + + void visit(AstStatDeclareGlobal* stat) + { + visit(stat->type); + } + + void visit(AstStatDeclareClass* stat) + { + for (const AstDeclaredClassProp& prop : stat->props) + visit(prop.ty); + } + + void visit(AstStatError* stat) + { + for (AstExpr* expr : stat->expressions) + visit(expr); + + for (AstStat* s : stat->statements) + visit(s); + } + + void visit(AstExpr* expr) + { + auto StackPusher = pushStack(expr); + + if (0) + {} + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else + LUAU_ASSERT(!"TypeChecker2 encountered an unknown expression type"); + } + + void visit(AstExprGroup* expr) + { + visit(expr->expr); + } + + void visit(AstExprConstantNil* expr) + { + // TODO! + } + + void visit(AstExprConstantBool* expr) + { + // TODO! + } + + void visit(AstExprConstantNumber* number) + { + TypeId actualType = lookupType(number); + TypeId numberType = getSingletonTypes().numberType; + + if (!isSubtype(numberType, actualType, stack.back(), ice)) + { + reportError(TypeMismatch{actualType, numberType}, number->location); + } + } + + void visit(AstExprConstantString* string) + { + TypeId actualType = lookupType(string); + TypeId stringType = getSingletonTypes().stringType; + + if (!isSubtype(stringType, actualType, stack.back(), ice)) + { + reportError(TypeMismatch{actualType, stringType}, string->location); + } + } + + void visit(AstExprLocal* expr) + { + // TODO! + } + + void visit(AstExprGlobal* expr) + { + // TODO! + } + + void visit(AstExprVarargs* expr) + { + // TODO! + } + + void visit(AstExprCall* call) + { + visit(call->func); + + for (AstExpr* arg : call->args) + visit(arg); + TypeArena arena; Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}}; @@ -225,7 +565,7 @@ struct TypeChecker2 : public AstVisitor LUAU_ASSERT(functionType); TypePack args; - for (const auto& arg : call->args) + for (AstExpr* arg : call->args) { TypeId argTy = module->astTypes[arg]; LUAU_ASSERT(argTy); @@ -235,7 +575,7 @@ struct TypeChecker2 : public AstVisitor TypePackId argsTp = arena.addTypePack(args); FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(expectedType, instantiatedFunctionType, ice)) + if (!isSubtype(expectedType, instantiatedFunctionType, stack.back(), ice)) { unfreeze(module->interfaceTypes); CloneState cloneState; @@ -243,12 +583,36 @@ struct TypeChecker2 : public AstVisitor freeze(module->interfaceTypes); reportError(TypeMismatch{expectedType, functionType}, call->location); } - - return true; } - bool visit(AstExprFunction* fn) override + void visit(AstExprIndexName* indexName) { + TypeId leftType = lookupType(indexName->expr); + TypeId resultType = lookupType(indexName); + + // leftType must have a property called indexName->index + + std::optional ty = getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); + if (ty) + { + if (!isSubtype(resultType, *ty, stack.back(), ice)) + { + reportError(TypeMismatch{resultType, *ty}, indexName->location); + } + } + } + + void visit(AstExprIndexExpr* indexExpr) + { + // TODO! + visit(indexExpr->expr); + visit(indexExpr->index); + } + + void visit(AstExprFunction* fn) + { + auto StackPusher = pushStack(fn); + TypeId inferredFnTy = lookupType(fn); const FunctionTypeVar* inferredFtv = get(inferredFnTy); LUAU_ASSERT(inferredFtv); @@ -264,7 +628,7 @@ struct TypeChecker2 : public AstVisitor TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, ice)) + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), ice)) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } @@ -273,68 +637,64 @@ struct TypeChecker2 : public AstVisitor ++argIt; } - return true; + visit(fn->body); } - bool visit(AstExprIndexName* indexName) override + void visit(AstExprTable* expr) { - TypeId leftType = lookupType(indexName->expr); - TypeId resultType = lookupType(indexName); - - // leftType must have a property called indexName->index - - std::optional ty = getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); - if (ty) + // TODO! + for (const AstExprTable::Item& item : expr->items) { - if (!isSubtype(resultType, *ty, ice)) - { - reportError(TypeMismatch{resultType, *ty}, indexName->location); - } + if (item.key) + visit(item.key); + visit(item.value); } - - return true; } - bool visit(AstExprConstantNumber* number) override + void visit(AstExprUnary* expr) { - TypeId actualType = lookupType(number); - TypeId numberType = getSingletonTypes().numberType; - - if (!isSubtype(numberType, actualType, ice)) - { - reportError(TypeMismatch{actualType, numberType}, number->location); - } - - return true; + // TODO! + visit(expr->expr); } - bool visit(AstExprConstantString* string) override + void visit(AstExprBinary* expr) { - TypeId actualType = lookupType(string); - TypeId stringType = getSingletonTypes().stringType; - - if (!isSubtype(stringType, actualType, ice)) - { - reportError(TypeMismatch{actualType, stringType}, string->location); - } - - return true; + // TODO! + visit(expr->left); + visit(expr->right); } - bool visit(AstExprTypeAssertion* expr) override + void visit(AstExprTypeAssertion* expr) { + visit(expr->expr); + visit(expr->annotation); + TypeId annotationType = lookupAnnotation(expr->annotation); TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, ice)) - return true; + if (isSubtype(annotationType, computedType, stack.back(), ice)) + return; - if (isSubtype(computedType, annotationType, ice)) - return true; + if (isSubtype(computedType, annotationType, stack.back(), ice)) + return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); - return true; + } + + void visit(AstExprIfElse* expr) + { + // TODO! + visit(expr->condition); + visit(expr->trueExpr); + visit(expr->falseExpr); + } + + void visit(AstExprError* expr) + { + // TODO! + for (AstExpr* e : expr->expressions) + visit(e); } /** Extract a TypeId for the first type of the provided pack. @@ -375,13 +735,32 @@ struct TypeChecker2 : public AstVisitor ice.ice("flattenPack got a weird pack!"); } - bool visit(AstType* ty) override + void visit(AstType* ty) { - return true; + if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); } - bool visit(AstTypeReference* ty) override + void visit(AstTypeReference* ty) { + for (const AstTypeOrPack& param : ty->parameters) + { + if (param.type) + visit(param.type); + else + visit(param.typePack); + } + Scope* scope = findInnermostScope(ty->location); LUAU_ASSERT(scope); @@ -500,16 +879,76 @@ struct TypeChecker2 : public AstVisitor reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); } } - - return true; } - bool visit(AstTypePack*) override + void visit(AstTypeTable* table) { - return true; + // TODO! + + for (const AstTableProp& prop : table->props) + visit(prop.type); + + if (table->indexer) + { + visit(table->indexer->indexType); + visit(table->indexer->resultType); + } } - bool visit(AstTypePackGeneric* tp) override + void visit(AstTypeFunction* ty) + { + // TODO! + + visit(ty->argTypes); + visit(ty->returnTypes); + } + + void visit(AstTypeTypeof* ty) + { + visit(ty->expr); + } + + void visit(AstTypeUnion* ty) + { + // TODO! + for (AstType* type : ty->types) + visit(type); + } + + void visit(AstTypeIntersection* ty) + { + // TODO! + for (AstType* type : ty->types) + visit(type); + } + + void visit(AstTypePack* pack) + { + if (auto p = pack->as()) + return visit(p); + else if (auto p = pack->as()) + return visit(p); + else if (auto p = pack->as()) + return visit(p); + } + + void visit(AstTypePackExplicit* tp) + { + // TODO! + for (AstType* type : tp->typeList.types) + visit(type); + + if (tp->typeList.tailType) + visit(tp->typeList.tailType); + } + + void visit(AstTypePackVariadic* tp) + { + // TODO! + visit(tp->variadicType); + } + + void visit(AstTypePackGeneric* tp) { Scope* scope = findInnermostScope(tp->location); LUAU_ASSERT(scope); @@ -531,8 +970,6 @@ struct TypeChecker2 : public AstVisitor reportError(UnknownSymbol{tp->genericName.value, UnknownSymbol::Context::Type}, tp->location); } } - - return true; } void reportError(TypeErrorData&& data, const Location& location) @@ -546,139 +983,9 @@ struct TypeChecker2 : public AstVisitor } std::optional getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) + const ScopePtr& scope, TypeId type, const std::string& prop, const Location& location, bool addErrors) { - type = follow(type); - - if (get(type) || get(type) || get(type)) - return type; - - if (auto f = get(type)) - *asMutable(type) = TableTypeVar{TableState::Free, f->level}; - - if (isString(type)) - { - std::optional mtIndex = Luau::findMetatableEntry(module->errors, singletonTypes.stringType, "__index", location); - LUAU_ASSERT(mtIndex); - type = *mtIndex; - } - - if (TableTypeVar* tableType = getMutableTableType(type)) - { - - return findTablePropertyRespectingMeta(module->errors, type, name, location); - } - else if (const ClassTypeVar* cls = get(type)) - { - const Property* prop = lookupClassProp(cls, name); - if (prop) - return prop->type; - } - else if (const UnionTypeVar* utv = get(type)) - { - std::vector goodOptions; - std::vector badOptions; - - for (TypeId t : utv) - { - // TODO: we should probably limit recursion here? - // RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - // Not needed when we normalize types. - if (get(follow(t))) - return t; - - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) - goodOptions.push_back(*ty); - else - badOptions.push_back(t); - } - - if (!badOptions.empty()) - { - if (addErrors) - { - if (goodOptions.empty()) - reportError(UnknownProperty{type, name}, location); - else - reportError(MissingUnionProperty{type, badOptions, name}, location); - } - return std::nullopt; - } - - std::vector result = reduceUnion(goodOptions); - if (result.empty()) - return singletonTypes.neverType; - - if (result.size() == 1) - return result[0]; - - return module->internalTypes.addType(UnionTypeVar{std::move(result)}); - } - else if (const IntersectionTypeVar* itv = get(type)) - { - std::vector parts; - - for (TypeId t : itv->parts) - { - // TODO: we should probably limit recursion here? - // RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) - parts.push_back(*ty); - } - - // If no parts of the intersection had the property we looked up for, it never existed at all. - if (parts.empty()) - { - if (addErrors) - reportError(UnknownProperty{type, name}, location); - return std::nullopt; - } - - if (parts.size() == 1) - return parts[0]; - - return module->internalTypes.addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. - } - - if (addErrors) - reportError(UnknownProperty{type, name}, location); - - return std::nullopt; - } - - std::vector reduceUnion(const std::vector& types) - { - std::vector result; - for (TypeId t : types) - { - t = follow(t); - if (get(t)) - continue; - - if (get(t) || get(t)) - return {t}; - - if (const UnionTypeVar* utv = get(t)) - { - for (TypeId ty : utv) - { - ty = follow(ty); - if (get(ty)) - continue; - if (get(ty) || get(ty)) - return {ty}; - - if (result.end() == std::find(result.begin(), result.end(), ty)) - result.push_back(ty); - } - } - else if (std::find(result.begin(), result.end(), t) == result.end()) - result.push_back(t); - } - - return result; + return Luau::getIndexTypeFromType(scope, module->errors, &module->internalTypes, type, prop, location, addErrors, ice); } }; @@ -686,7 +993,7 @@ void check(const SourceModule& sourceModule, Module* module) { TypeChecker2 typeChecker{&sourceModule, module}; - sourceModule.root->visit(&typeChecker); + typeChecker.visit(sourceModule.root); } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 7ab23360..9886fb19 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,15 +33,13 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAGVARIABLE(LuauExpectedTableUnionIndexerType, false) +LUAU_FASTFLAGVARIABLE(LuauInplaceDemoteSkipAllBound, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix3, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); -LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) -LUAU_FASTFLAG(LuauQuantifyConstrained) -LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) @@ -473,7 +471,8 @@ struct InplaceDemoter : TypeVarOnceVisitor TypeArena* arena; InplaceDemoter(TypeLevel level, TypeArena* arena) - : newLevel(level) + : TypeVarOnceVisitor(/* skipBoundTypes= */ FFlag::LuauInplaceDemoteSkipAllBound) + , newLevel(level) , arena(arena) { } @@ -494,6 +493,7 @@ struct InplaceDemoter : TypeVarOnceVisitor bool visit(TypeId ty, const BoundTypeVar& btyRef) override { + LUAU_ASSERT(!FFlag::LuauInplaceDemoteSkipAllBound); return true; } @@ -656,7 +656,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A TypeId leftType = follow(checkFunctionName(scope, *fun->name, funScope->level)); - unify(funTy, leftType, fun->location); + unify(funTy, leftType, scope, fun->location); } else if (auto fun = (*protoIter)->as()) { @@ -768,20 +768,20 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) } template -ErrorVec TypeChecker::canUnify_(Id subTy, Id superTy, const Location& location) +ErrorVec TypeChecker::canUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location) { - Unifier state = mkUnifier(location); + Unifier state = mkUnifier(scope, location); return state.canUnify(subTy, superTy); } -ErrorVec TypeChecker::canUnify(TypeId subTy, TypeId superTy, const Location& location) +ErrorVec TypeChecker::canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location) { - return canUnify_(subTy, superTy, location); + return canUnify_(subTy, superTy, scope, location); } -ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Location& location) +ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location) { - return canUnify_(subTy, superTy, location); + return canUnify_(subTy, superTy, scope, location); } void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) @@ -802,9 +802,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } -void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location) +void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location) { - Unifier state = mkUnifier(location); + Unifier state = mkUnifier(scope, location); state.unifyLowerBound(subTy, superTy, demotedLevel); state.log.commit(); @@ -858,8 +858,6 @@ struct Demoter : Substitution void demote(std::vector>& expectedTypes) { - if (!FFlag::LuauQuantifyConstrained) - return; for (std::optional& ty : expectedTypes) { if (ty) @@ -897,7 +895,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) if (useConstrainedIntersections()) { - unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), return_.location); + unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), scope, return_.location); return; } @@ -905,7 +903,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) { - ErrorVec errors = tryUnify(retPack, scope->returnType, return_.location); + ErrorVec errors = tryUnify(retPack, scope->returnType, scope, return_.location); if (!errors.empty()) currentModule->getModuleScope()->returnType = addTypePack({anyType}); @@ -913,13 +911,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) return; } - unify(retPack, scope->returnType, return_.location, CountMismatch::Context::Return); + unify(retPack, scope->returnType, scope, return_.location, CountMismatch::Context::Return); } template -ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const Location& location) +ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location) { - Unifier state = mkUnifier(location); + Unifier state = mkUnifier(scope, location); if (FFlag::DebugLuauFreezeDuringUnification) freeze(currentModule->internalTypes); @@ -935,14 +933,14 @@ ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const Location& location) return state.errors; } -ErrorVec TypeChecker::tryUnify(TypeId subTy, TypeId superTy, const Location& location) +ErrorVec TypeChecker::tryUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location) { - return tryUnify_(subTy, superTy, location); + return tryUnify_(subTy, superTy, scope, location); } -ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const Location& location) +ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location) { - return tryUnify_(subTy, superTy, location); + return tryUnify_(subTy, superTy, scope, location); } void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) @@ -1036,9 +1034,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) - unify(anyType, left, loc); + unify(anyType, left, scope, loc); else - unify(right, left, loc); + unify(right, left, scope, loc); } } } @@ -1053,7 +1051,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assi TypeId result = checkBinaryOperation(scope, expr, left, right); - unify(result, left, assign.location); + unify(result, left, scope, assign.location); } void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) @@ -1108,7 +1106,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) TypePackId valuePack = checkExprList(scope, local.location, local.values, /* substituteFreeForNil= */ true, instantiateGenerics, expectedTypes).type; - Unifier state = mkUnifier(local.location); + Unifier state = mkUnifier(scope, local.location); state.ctx = CountMismatch::Result; state.tryUnify(valuePack, variablePack); reportErrors(state.errors); @@ -1184,7 +1182,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) TypeId loopVarType = numberType; if (expr.var->annotation) - unify(loopVarType, resolveType(scope, *expr.var->annotation), expr.location); + unify(loopVarType, resolveType(scope, *expr.var->annotation), scope, expr.location); loopScope->bindings[expr.var] = {loopVarType, expr.var->location}; @@ -1194,11 +1192,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) if (!expr.to) ice("Bad AstStatFor has no to expr"); - unify(checkExpr(loopScope, *expr.from).type, loopVarType, expr.from->location); - unify(checkExpr(loopScope, *expr.to).type, loopVarType, expr.to->location); + unify(checkExpr(loopScope, *expr.from).type, loopVarType, scope, expr.from->location); + unify(checkExpr(loopScope, *expr.to).type, loopVarType, scope, expr.to->location); if (expr.step) - unify(checkExpr(loopScope, *expr.step).type, loopVarType, expr.step->location); + unify(checkExpr(loopScope, *expr.step).type, loopVarType, scope, expr.step->location); check(loopScope, *expr.body); } @@ -1251,12 +1249,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) if (get(callRetPack)) { iterTy = freshType(scope); - unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), forin.location); + unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location); } else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(errorRecoveryType(scope), var, forin.location); + unify(errorRecoveryType(scope), var, scope, forin.location); return check(loopScope, *forin.body); } @@ -1277,7 +1275,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types for (TypeId var : varTypes) - unify(anyType, var, forin.location); + unify(anyType, var, scope, forin.location); return check(loopScope, *forin.body); } @@ -1289,25 +1287,25 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) if (iterTable->indexer) { if (varTypes.size() > 0) - unify(iterTable->indexer->indexType, varTypes[0], forin.location); + unify(iterTable->indexer->indexType, varTypes[0], scope, forin.location); if (varTypes.size() > 1) - unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); + unify(iterTable->indexer->indexResultType, varTypes[1], scope, forin.location); for (size_t i = 2; i < varTypes.size(); ++i) - unify(nilType, varTypes[i], forin.location); + unify(nilType, varTypes[i], scope, forin.location); } else if (isNonstrictMode()) { for (TypeId var : varTypes) - unify(anyType, var, forin.location); + unify(anyType, var, scope, forin.location); } else { TypeId varTy = errorRecoveryType(loopScope); for (TypeId var : varTypes) - unify(varTy, var, forin.location); + unify(varTy, var, scope, forin.location); reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); } @@ -1321,7 +1319,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) - unify(varTy, var, forin.location); + unify(varTy, var, scope, forin.location); if (!get(iterTy) && !get(iterTy) && !get(iterTy) && !get(iterTy)) reportError(firstValue->location, CannotCallNonFunction{iterTy}); @@ -1346,7 +1344,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) argPack = addTypePack(TypePack{}); } - Unifier state = mkUnifier(firstValue->location); + Unifier state = mkUnifier(loopScope, firstValue->location); checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); state.log.commit(); @@ -1365,10 +1363,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; TypePackId retPack = checkExprPack(scope, exprCall).type; - unify(retPack, varPack, forin.location); + unify(retPack, varPack, scope, forin.location); } else - unify(iterFunc->retTypes, varPack, forin.location); + unify(iterFunc->retTypes, varPack, scope, forin.location); check(loopScope, *forin.body); } @@ -1603,7 +1601,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias TypeId& bindingType = bindingsMap[name].type; - if (unify(ty, bindingType, typealias.location)) + if (unify(ty, bindingType, aliasScope, typealias.location)) bindingType = ty; if (FFlag::LuauLowerBoundsCalculation) @@ -1891,7 +1889,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp TypeLevel level = FFlag::LuauLowerBoundsCalculation ? ftp->level : scope->level; TypeId head = freshType(level); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(level)}}); - unify(pack, retPack, expr.location); + unify(pack, retPack, scope, expr.location); return {head, std::move(result.predicates)}; } if (get(retPack)) @@ -1983,20 +1981,15 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( else if (auto indexer = tableType->indexer) { // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. - ErrorVec errors = tryUnify(stringType, indexer->indexType, location); + ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location); - if (FFlag::LuauReportErrorsOnIndexerKeyMismatch) - { - if (errors.empty()) - return indexer->indexResultType; - - if (addErrors) - reportError(location, UnknownProperty{type, name}); - - return std::nullopt; - } - else + if (errors.empty()) return indexer->indexResultType; + + if (addErrors) + reportError(location, UnknownProperty{type, name}); + + return std::nullopt; } else if (tableType->state == TableState::Free) { @@ -2228,8 +2221,8 @@ TypeId TypeChecker::checkExprTable( if (indexer) { - unify(numberType, indexer->indexType, value->location); - unify(valueType, indexer->indexResultType, value->location); + unify(numberType, indexer->indexType, scope, value->location); + unify(valueType, indexer->indexResultType, scope, value->location); } else indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; @@ -2248,13 +2241,13 @@ TypeId TypeChecker::checkExprTable( if (it != expectedTable->props.end()) { Property expectedProp = it->second; - ErrorVec errors = tryUnify(exprType, expectedProp.type, k->location); + ErrorVec errors = tryUnify(exprType, expectedProp.type, scope, k->location); if (errors.empty()) exprType = expectedProp.type; } else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) { - ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); + ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, scope, k->location); if (errors.empty()) exprType = expectedTable->indexer->indexResultType; } @@ -2269,8 +2262,8 @@ TypeId TypeChecker::checkExprTable( if (indexer) { - unify(keyType, indexer->indexType, k->location); - unify(valueType, indexer->indexResultType, value->location); + unify(keyType, indexer->indexType, scope, k->location); + unify(valueType, indexer->indexResultType, scope, value->location); } else if (isNonstrictMode()) { @@ -2411,7 +2404,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp TypePackId retTypePack = freshTypePack(scope); TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); - Unifier state = mkUnifier(expr.location); + Unifier state = mkUnifier(scope, expr.location); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); state.log.commit(); @@ -2429,7 +2422,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {errorRecoveryType(scope)}; } - reportErrors(tryUnify(operandType, numberType, expr.location)); + reportErrors(tryUnify(operandType, numberType, scope, expr.location)); return {numberType}; } case AstExprUnary::Len: @@ -2459,7 +2452,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp TypePackId retTypePack = addTypePack({numberType}); TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); - Unifier state = mkUnifier(expr.location); + Unifier state = mkUnifier(scope, expr.location); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); state.log.commit(); @@ -2509,11 +2502,11 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op) } } -TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes) +TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes) { if (unifyFreeTypes && (get(a) || get(b))) { - if (unify(b, a, location)) + if (unify(b, a, scope, location)) return a; return errorRecoveryType(anyType); @@ -2588,7 +2581,7 @@ TypeId TypeChecker::checkRelationalOperation( { ScopePtr subScope = childScope(scope, subexp->location); resolve(predicates, subScope, true); - return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); + return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), subScope, expr.location); } } @@ -2624,7 +2617,7 @@ TypeId TypeChecker::checkRelationalOperation( * report any problems that might have been surfaced as a result of this step because we might already * have a better, more descriptive error teed up. */ - Unifier state = mkUnifier(expr.location); + Unifier state = mkUnifier(scope, expr.location); if (!isEquality) { state.tryUnify(rhsType, lhsType); @@ -2703,7 +2696,7 @@ TypeId TypeChecker::checkRelationalOperation( { if (isEquality) { - Unifier state = mkUnifier(expr.location); + Unifier state = mkUnifier(scope, expr.location); state.tryUnify(addTypePack({booleanType}), ftv->retTypes); if (!state.errors.empty()) @@ -2755,11 +2748,11 @@ TypeId TypeChecker::checkRelationalOperation( case AstExprBinary::And: if (lhsIsAny) return lhsType; - return unionOfTypes(rhsType, booleanType, expr.location, false); + return unionOfTypes(rhsType, booleanType, scope, expr.location, false); case AstExprBinary::Or: if (lhsIsAny) return lhsType; - return unionOfTypes(lhsType, rhsType, expr.location); + return unionOfTypes(lhsType, rhsType, scope, expr.location); default: LUAU_ASSERT(0); ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location); @@ -2816,7 +2809,7 @@ TypeId TypeChecker::checkBinaryOperation( } if (get(rhsType)) - unify(rhsType, lhsType, expr.location); + unify(rhsType, lhsType, scope, expr.location); if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) { @@ -2826,7 +2819,7 @@ TypeId TypeChecker::checkBinaryOperation( TypePackId retTypePack = freshTypePack(scope); TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); - Unifier state = mkUnifier(expr.location); + Unifier state = mkUnifier(scope, expr.location); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); @@ -2876,8 +2869,8 @@ TypeId TypeChecker::checkBinaryOperation( switch (expr.op) { case AstExprBinary::Concat: - reportErrors(tryUnify(lhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.left->location)); - reportErrors(tryUnify(rhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.right->location)); + reportErrors(tryUnify(lhsType, addType(UnionTypeVar{{stringType, numberType}}), scope, expr.left->location)); + reportErrors(tryUnify(rhsType, addType(UnionTypeVar{{stringType, numberType}}), scope, expr.right->location)); return stringType; case AstExprBinary::Add: case AstExprBinary::Sub: @@ -2885,8 +2878,8 @@ TypeId TypeChecker::checkBinaryOperation( case AstExprBinary::Div: case AstExprBinary::Mod: case AstExprBinary::Pow: - reportErrors(tryUnify(lhsType, numberType, expr.left->location)); - reportErrors(tryUnify(rhsType, numberType, expr.right->location)); + reportErrors(tryUnify(lhsType, numberType, scope, expr.left->location)); + reportErrors(tryUnify(rhsType, numberType, scope, expr.right->location)); return numberType; default: // These should have been handled with checkRelationalOperation @@ -2961,10 +2954,10 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate result = checkExpr(scope, *expr.expr, annotationType); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (canUnify(annotationType, result.type, expr.location).empty()) + if (canUnify(annotationType, result.type, scope, expr.location).empty()) return {annotationType, std::move(result.predicates)}; - if (canUnify(result.type, annotationType, expr.location).empty()) + if (canUnify(result.type, annotationType, scope, expr.location).empty()) return {annotationType, std::move(result.predicates)}; reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); @@ -3101,7 +3094,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (auto indexer = lhsTable->indexer) { - Unifier state = mkUnifier(expr.location); + Unifier state = mkUnifier(scope, expr.location); state.tryUnify(stringType, indexer->indexType); TypeId retType = indexer->indexResultType; if (!state.errors.empty()) @@ -3213,7 +3206,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex if (exprTable->indexer) { const TableIndexer& indexer = *exprTable->indexer; - unify(indexType, indexer.indexType, expr.index->location); + unify(indexType, indexer.indexType, scope, expr.index->location); return indexer.indexResultType; } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) @@ -3829,7 +3822,7 @@ void TypeChecker::checkArgumentList( { // The use of unify here is deliberate. We don't want this unification // to be undoable. - unify(errorRecoveryType(scope), *argIter, state.location); + unify(errorRecoveryType(scope), *argIter, scope, state.location); ++argIter; } reportCountMismatchError(); @@ -3852,7 +3845,7 @@ void TypeChecker::checkArgumentList( TypeId e = errorRecoveryType(scope); while (argIter != endIter) { - unify(e, *argIter, state.location); + unify(e, *argIter, scope, state.location); ++argIter; } @@ -3869,7 +3862,7 @@ void TypeChecker::checkArgumentList( if (argIndex < argLocations.size()) location = argLocations[argIndex]; - unify(*argIter, vtp->ty, location); + unify(*argIter, vtp->ty, scope, location); ++argIter; ++argIndex; } @@ -3906,7 +3899,7 @@ void TypeChecker::checkArgumentList( } else { - unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); + unifyWithInstantiationIfNeeded(*argIter, *paramIter, scope, state); ++argIter; ++paramIter; } @@ -4114,7 +4107,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc if (get(fn)) { - unify(anyTypePack, argPack, expr.location); + unify(anyTypePack, argPack, scope, expr.location); return {{anyTypePack}}; } @@ -4160,7 +4153,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc UnifierOptions options; options.isFunctionCall = true; - unify(r, fn, expr.location, options); + unify(r, fn, scope, expr.location, options); return {{retPack}}; } @@ -4194,7 +4187,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc if (!ftv) { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(errorRecoveryTypePack(scope), retPack, expr.func->location); + unify(errorRecoveryTypePack(scope), retPack, scope, expr.func->location); return {{errorRecoveryTypePack(retPack)}}; } @@ -4207,7 +4200,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc return *ret; } - Unifier state = mkUnifier(expr.location); + Unifier state = mkUnifier(scope, expr.location); // Unify return types checkArgumentList(scope, state, retPack, ftv->retTypes, /*argLocations*/ {}); @@ -4269,7 +4262,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal std::vector editedParamList(args->head.begin() + 1, args->head.end()); TypePackId editedArgPack = addTypePack(TypePack{editedParamList}); - Unifier editedState = mkUnifier(expr.location); + Unifier editedState = mkUnifier(scope, expr.location); checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); if (editedState.errors.empty()) @@ -4299,7 +4292,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal editedArgList.insert(editedArgList.begin(), checkExpr(scope, *indexName->expr).type); TypePackId editedArgPack = addTypePack(TypePack{editedArgList}); - Unifier editedState = mkUnifier(expr.location); + Unifier editedState = mkUnifier(scope, expr.location); checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); @@ -4365,7 +4358,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast for (size_t i = 0; i < overloadTypes.size(); ++i) { TypeId overload = overloadTypes[i]; - Unifier state = mkUnifier(expr.location); + Unifier state = mkUnifier(scope, expr.location); // Unify return types if (const FunctionTypeVar* ftv = get(overload)) @@ -4415,7 +4408,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons size_t lastIndex = exprs.size - 1; tp->head.reserve(lastIndex); - Unifier state = mkUnifier(location); + Unifier state = mkUnifier(scope, location); std::vector inverseLogs; @@ -4580,15 +4573,15 @@ TypeId TypeChecker::anyIfNonstrict(TypeId ty) const return ty; } -bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location) +bool TypeChecker::unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location) { UnifierOptions options; - return unify(subTy, superTy, location, options); + return unify(subTy, superTy, scope, location, options); } -bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location, const UnifierOptions& options) +bool TypeChecker::unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options) { - Unifier state = mkUnifier(location); + Unifier state = mkUnifier(scope, location); state.tryUnify(subTy, superTy, options.isFunctionCall); state.log.commit(); @@ -4598,9 +4591,9 @@ bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location, return state.errors.empty(); } -bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx) +bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, CountMismatch::Context ctx) { - Unifier state = mkUnifier(location); + Unifier state = mkUnifier(scope, location); state.ctx = ctx; state.tryUnify(subTy, superTy); @@ -4611,10 +4604,10 @@ bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const Location& lo return state.errors.empty(); } -bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location) +bool TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location) { - Unifier state = mkUnifier(location); - unifyWithInstantiationIfNeeded(scope, subTy, superTy, state); + Unifier state = mkUnifier(scope, location); + unifyWithInstantiationIfNeeded(subTy, superTy, scope, state); state.log.commit(); @@ -4623,7 +4616,7 @@ bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s return state.errors.empty(); } -void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state) +void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, Unifier& state) { if (!maybeGeneric(subTy)) // Quick check to see if we definitely can't instantiate @@ -4662,77 +4655,6 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s } } -bool Anyification::isDirty(TypeId ty) -{ - if (ty->persistent) - return false; - - if (const TableTypeVar* ttv = log->getMutable(ty)) - return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); - else if (log->getMutable(ty)) - return true; - else if (get(ty)) - return true; - else - return false; -} - -bool Anyification::isDirty(TypePackId tp) -{ - if (tp->persistent) - return false; - - if (log->getMutable(tp)) - return true; - else - return false; -} - -TypeId Anyification::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = log->getMutable(ty)) - { - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; - clone.definitionModuleName = ttv->definitionModuleName; - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.tags = ttv->tags; - TypeId res = addType(std::move(clone)); - asMutable(res)->normal = ty->normal; - return res; - } - else if (auto ctv = get(ty)) - { - if (FFlag::LuauQuantifyConstrained) - { - std::vector copy = ctv->parts; - for (TypeId& ty : copy) - ty = replace(ty); - TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); - auto [t, ok] = normalize(res, *arena, *iceHandler); - if (!ok) - normalizationTooComplex = true; - return t; - } - else - { - auto [t, ok] = normalize(ty, *arena, *iceHandler); - if (!ok) - normalizationTooComplex = true; - return t; - } - } - else - return anyType; -} - -TypePackId Anyification::clean(TypePackId tp) -{ - LUAU_ASSERT(isDirty(tp)); - return anyTypePack; -} - TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location) { ty = follow(ty); @@ -4804,7 +4726,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) ty = t; } - Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; + Anyification anyification{¤tModule->internalTypes, scope, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (anyification.normalizationTooComplex) reportError(location, NormalizationTooComplex{}); @@ -4827,7 +4749,7 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo ty = t; } - Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; + Anyification anyification{¤tModule->internalTypes, scope, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4963,9 +4885,9 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) }); } -Unifier TypeChecker::mkUnifier(const Location& location) +Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, location, Variance::Covariant, unifierState}; + return Unifier{¤tModule->internalTypes, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant, unifierState}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -5029,10 +4951,7 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) return sense ? std::nullopt : std::optional(ty); // at this point, anything else is kept if sense is true, or replaced by nil - if (FFlag::LuauFalsyPredicateReturnsNilInstead) - return sense ? ty : nilType; - else - return sense ? std::optional(ty) : std::nullopt; + return sense ? ty : nilType; }; } @@ -5875,8 +5794,8 @@ void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const { auto predicate = [&](TypeId option) -> std::optional { // This by itself is not truly enough to determine that A is stronger than B or vice versa. - bool optionIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); - bool targetIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); + bool optionIsSubtype = canUnify(option, isaP.ty, scope, isaP.location).empty(); + bool targetIsSubtype = canUnify(isaP.ty, option, scope, isaP.location).empty(); // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. if (!optionIsSubtype && targetIsSubtype) @@ -6019,7 +5938,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc if (maybeSingleton(eqP.type)) { // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. - if (!sense || canUnify(eqP.type, option, eqP.location).empty()) + if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty()) return sense ? eqP.type : option; // local variable works around an odd gcc 9.3 warning: may be used uninitialized @@ -6053,7 +5972,7 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp size_t oldErrorsSize = currentModule->errors.size(); - unify(tp, expectedTypePack, location); + unify(tp, expectedTypePack, scope, location); // HACK: tryUnify would undo the changes to the expectedTypePack if the length mismatches, but // we want to tie up free types to be error types, so we do this instead. diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 66b38cf3..60bca0a3 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeUtils.h" +#include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" @@ -8,7 +9,7 @@ namespace Luau { -std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::string entry, Location location) +std::optional findMetatableEntry(ErrorVec& errors, TypeId type, const std::string& entry, Location location) { type = follow(type); @@ -35,7 +36,7 @@ std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::str return std::nullopt; } -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, Name name, Location location) +std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, const std::string& name, Location location) { if (get(ty)) return ty; @@ -83,4 +84,110 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t return std::nullopt; } +std::optional getIndexTypeFromType( + const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, const Location& location, bool addErrors, + InternalErrorReporter& handle) +{ + type = follow(type); + + if (get(type) || get(type) || get(type)) + return type; + + if (auto f = get(type)) + *asMutable(type) = TableTypeVar{TableState::Free, f->level}; + + if (isString(type)) + { + std::optional mtIndex = Luau::findMetatableEntry(errors, getSingletonTypes().stringType, "__index", location); + LUAU_ASSERT(mtIndex); + type = *mtIndex; + } + + if (getTableType(type)) + { + return findTablePropertyRespectingMeta(errors, type, prop, location); + } + else if (const ClassTypeVar* cls = get(type)) + { + if (const Property* p = lookupClassProp(cls, prop)) + return p->type; + } + else if (const UnionTypeVar* utv = get(type)) + { + std::vector goodOptions; + std::vector badOptions; + + for (TypeId t : utv) + { + // TODO: we should probably limit recursion here? + // RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + + // Not needed when we normalize types. + if (get(follow(t))) + return t; + + if (std::optional ty = getIndexTypeFromType(scope, errors, arena, t, prop, location, /* addErrors= */ false, handle)) + goodOptions.push_back(*ty); + else + badOptions.push_back(t); + } + + if (!badOptions.empty()) + { + if (addErrors) + { + if (goodOptions.empty()) + errors.push_back(TypeError{location, UnknownProperty{type, prop}}); + else + errors.push_back(TypeError{location, MissingUnionProperty{type, badOptions, prop}}); + } + return std::nullopt; + } + + if (goodOptions.empty()) + return getSingletonTypes().neverType; + + if (goodOptions.size() == 1) + return goodOptions[0]; + + // TODO: inefficient. + TypeId result = arena->addType(UnionTypeVar{std::move(goodOptions)}); + auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, handle); + if (!ok && addErrors) + errors.push_back(TypeError{location, NormalizationTooComplex{}}); + return ok ? ty : getSingletonTypes().anyType; + } + else if (const IntersectionTypeVar* itv = get(type)) + { + std::vector parts; + + for (TypeId t : itv->parts) + { + // TODO: we should probably limit recursion here? + // RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + + if (std::optional ty = getIndexTypeFromType(scope, errors, arena, t, prop, location, /* addErrors= */ false, handle)) + parts.push_back(*ty); + } + + // If no parts of the intersection had the property we looked up for, it never existed at all. + if (parts.empty()) + { + if (addErrors) + errors.push_back(TypeError{location, UnknownProperty{type, prop}}); + return std::nullopt; + } + + if (parts.size() == 1) + return parts[0]; + + return arena->addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. + } + + if (addErrors) + errors.push_back(TypeError{location, UnknownProperty{type, prop}}); + + return std::nullopt; +} + } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index ada2b012..9020b1ac 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -1135,7 +1135,7 @@ std::optional> magicFunctionFormat( { Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - typechecker.unify(params[i + paramOffset], expected[i], location); + typechecker.unify(params[i + paramOffset], expected[i], scope, location); } // if we know the argument count or if we have too many arguments for sure, we can issue an error @@ -1234,7 +1234,7 @@ static std::optional> magicFunctionGmatch( if (returnTypes.empty()) return std::nullopt; - typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); const TypePackId emptyPack = arena.addTypePack({}); const TypePackId returnList = arena.addTypePack(returnTypes); @@ -1269,13 +1269,13 @@ static std::optional> magicFunctionMatch( if (returnTypes.empty()) return std::nullopt; - typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); size_t initIndex = expr.self ? 1 : 2; if (params.size() == 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location); + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); const TypePackId returnList = arena.addTypePack(returnTypes); return WithPredicate{returnList}; @@ -1320,17 +1320,17 @@ static std::optional> magicFunctionFind( return std::nullopt; } - typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); const TypeId optionalBoolean = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.booleanType}}); size_t initIndex = expr.self ? 1 : 2; if (params.size() >= 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location); + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); if (params.size() == 4 && expr.args.size > plainIndex) - typechecker.unify(params[3], optionalBoolean, expr.args.data[plainIndex]->location); + typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index c7f31822..9ce8c3dc 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -27,27 +27,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { -static void* systemAllocateAligned(size_t size, size_t align) -{ -#ifdef _WIN32 - return _aligned_malloc(size, align); -#elif defined(__ANDROID__) // for Android 4.1 - return memalign(align, size); -#else - void* ptr; - return posix_memalign(&ptr, align, size) == 0 ? ptr : 0; -#endif -} - -static void systemDeallocateAligned(void* ptr) -{ -#ifdef _WIN32 - _aligned_free(ptr); -#else - free(ptr); -#endif -} - static size_t pageAlign(size_t size) { return (size + kPageSize - 1) & ~(kPageSize - 1); @@ -55,18 +34,31 @@ static size_t pageAlign(size_t size) void* pagedAllocate(size_t size) { - if (FFlag::DebugLuauFreezeArena) - return systemAllocateAligned(pageAlign(size), kPageSize); - else + // By default we use operator new/delete instead of malloc/free so that they can be overridden externally + if (!FFlag::DebugLuauFreezeArena) return ::operator new(size, std::nothrow); + + // On Windows, VirtualAlloc results in 64K granularity allocations; we allocate in chunks of ~32K so aligned_malloc is a little more efficient + // On Linux, we must use mmap because using regular heap results in mprotect() fragmenting the page table and us bumping into 64K mmap limit. +#ifdef _WIN32 + return _aligned_malloc(size, kPageSize); +#else + return mmap(nullptr, pageAlign(size), PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); +#endif } -void pagedDeallocate(void* ptr) +void pagedDeallocate(void* ptr, size_t size) { - if (FFlag::DebugLuauFreezeArena) - systemDeallocateAligned(ptr); - else - ::operator delete(ptr); + // By default we use operator new/delete instead of malloc/free so that they can be overridden externally + if (!FFlag::DebugLuauFreezeArena) + return ::operator delete(ptr); + +#ifdef _WIN32 + _aligned_free(ptr); +#else + int rc = munmap(ptr, size); + LUAU_ASSERT(rc == 0); +#endif } void pagedFreeze(void* ptr, size_t size) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index a0f4725d..b5f58c83 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -20,7 +20,6 @@ LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) @@ -318,9 +317,10 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) +Unifier::Unifier(TypeArena* types, Mode mode, NotNull scope, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) + , scope(scope) , log(parentLog) , location(location) , variance(variance) @@ -2091,13 +2091,11 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel de if (!freeTailPack) return; - TypeLevel level = FFlag::LuauQuantifyConstrained ? demotedLevel : freeTailPack->level; - TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); for (; subIter != subEndIter; ++subIter) { - tp->head.push_back(types->addType(ConstrainedTypeVar{level, {follow(*subIter)}})); + tp->head.push_back(types->addType(ConstrainedTypeVar{demotedLevel, {follow(*subIter)}})); } tp->tail = subIter.tail(); @@ -2270,7 +2268,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; + Unifier u = Unifier{types, mode, scope, location, variance, sharedState, &log}; u.anyIsTop = anyIsTop; return u; } diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 4b5ae315..046706d3 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -138,7 +138,7 @@ private: // funcbody ::= `(' [parlist] `)' block end // parlist ::= namelist [`,' `...'] | `...' std::pair parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, std::optional localName); + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); // explist ::= {exp `,'} exp void parseExprList(TempVector& result); @@ -217,7 +217,7 @@ private: AstExpr* parseSimpleExpr(); // args ::= `(' [explist] `)' | tableconstructor | String - AstExpr* parseFunctionArgs(AstExpr* func, bool self, const Location& selfLocation); + AstExpr* parseFunctionArgs(AstExpr* func, bool self); // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] @@ -241,6 +241,7 @@ private: std::optional> parseCharArray(); AstExpr* parseString(); + AstExpr* parseNumber(); AstLocal* pushLocal(const Binding& binding); @@ -253,11 +254,24 @@ private: bool expectAndConsume(Lexeme::Type type, const char* context = nullptr); void expectAndConsumeFail(Lexeme::Type type, const char* context); - bool expectMatchAndConsume(char value, const Lexeme& begin, bool searchForMissing = false); - void expectMatchAndConsumeFail(Lexeme::Type type, const Lexeme& begin, const char* extra = nullptr); + struct MatchLexeme + { + MatchLexeme(const Lexeme& l) + : type(l.type) + , position(l.location.begin) + { + } - bool expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin); - void expectMatchEndAndConsumeFail(Lexeme::Type type, const Lexeme& begin); + Lexeme::Type type; + Position position; + }; + + bool expectMatchAndConsume(char value, const MatchLexeme& begin, bool searchForMissing = false); + void expectMatchAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin, const char* extra = nullptr); + bool expectMatchAndConsumeRecover(char value, const MatchLexeme& begin, bool searchForMissing); + + bool expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin); + void expectMatchEndAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin); template AstArray copy(const T* data, std::size_t size); @@ -283,6 +297,9 @@ private: AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); + AstExpr* reportFunctionArgsError(AstExpr* func, bool self); + void reportAmbiguousCallError(); + void nextLexeme(); struct Function @@ -350,7 +367,7 @@ private: AstName nameError; AstName nameNil; - Lexeme endMismatchSuspect; + MatchLexeme endMismatchSuspect; std::vector functionStack; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 1eb9565f..e46eebf3 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,8 +14,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) - LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) @@ -177,7 +175,7 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc , lexer(buffer, bufferSize, names) , allocator(allocator) , recursionCounter(0) - , endMismatchSuspect(Location(), Lexeme::Eof) + , endMismatchSuspect(Lexeme(Location(), Lexeme::Eof)) , localMap(AstName()) { Function top; @@ -657,7 +655,7 @@ AstStat* Parser::parseFunctionStat() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, {}).first; + AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -686,7 +684,7 @@ AstStat* Parser::parseLocal() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - auto [body, var] = parseFunctionBody(false, matchFunction, name.name, name); + auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name); matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -778,7 +776,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() genericPacks.size = 0; genericPacks.data = nullptr; - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); expectAndConsume('(', "function parameter list start"); TempVector args(scratchBinding); @@ -834,7 +832,7 @@ AstStat* Parser::parseDeclaration(const Location& start) auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); expectAndConsume('(', "global function declaration"); @@ -970,13 +968,13 @@ AstStat* Parser::parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op) // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // parlist ::= bindinglist [`,' `...'] | `...' std::pair Parser::parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, std::optional localName) + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) { Location start = matchFunction.location; auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); expectAndConsume('(', "function"); TempVector args(scratchBinding); @@ -988,7 +986,7 @@ std::pair Parser::parseFunctionBody( std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); std::optional argLocation = matchParen.type == Lexeme::Type('(') && lexer.current().type == Lexeme::Type(')') - ? std::make_optional(Location(matchParen.location.begin, lexer.current().location.end)) + ? std::make_optional(Location(matchParen.position, lexer.current().location.end)) : std::nullopt; expectMatchAndConsume(')', matchParen, true); @@ -1255,7 +1253,7 @@ AstType* Parser::parseTableTypeAnnotation() Location start = lexer.current().location; - Lexeme matchBrace = lexer.current(); + MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table type"); while (lexer.current().type != '}') @@ -1628,7 +1626,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { return parseFunctionTypeAnnotation(allowPack); } - else if (FFlag::LuauParserFunctionKeywordAsTypeHelp && lexer.current().type == Lexeme::ReservedFunction) + else if (lexer.current().type == Lexeme::ReservedFunction) { Location location = lexer.current().location; @@ -1912,14 +1910,14 @@ AstExpr* Parser::parsePrefixExpr() { if (lexer.current().type == '(') { - Location start = lexer.current().location; + Position start = lexer.current().location.begin; - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); nextLexeme(); AstExpr* expr = parseExpr(); - Location end = lexer.current().location; + Position end = lexer.current().location.end; if (lexer.current().type != ')') { @@ -1927,7 +1925,7 @@ AstExpr* Parser::parsePrefixExpr() expectMatchAndConsumeFail(static_cast(')'), matchParen, suggestion); - end = lexer.previousLocation(); + end = lexer.previousLocation().end; } else { @@ -1945,7 +1943,7 @@ AstExpr* Parser::parsePrefixExpr() // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } AstExpr* Parser::parsePrimaryExpr(bool asStatement) { - Location start = lexer.current().location; + Position start = lexer.current().location.begin; AstExpr* expr = parsePrefixExpr(); @@ -1960,16 +1958,16 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) Name index = parseIndexName(nullptr, opPosition); - expr = allocator.alloc(Location(start, index.location), expr, index.name, index.location, opPosition, '.'); + expr = allocator.alloc(Location(start, index.location.end), expr, index.name, index.location, opPosition, '.'); } else if (lexer.current().type == '[') { - Lexeme matchBracket = lexer.current(); + MatchLexeme matchBracket = lexer.current(); nextLexeme(); AstExpr* index = parseExpr(); - Location end = lexer.current().location; + Position end = lexer.current().location.end; expectMatchAndConsume(']', matchBracket); @@ -1981,27 +1979,24 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) nextLexeme(); Name index = parseIndexName("method name", opPosition); - AstExpr* func = allocator.alloc(Location(start, index.location), expr, index.name, index.location, opPosition, ':'); + AstExpr* func = allocator.alloc(Location(start, index.location.end), expr, index.name, index.location, opPosition, ':'); - expr = parseFunctionArgs(func, true, index.location); + expr = parseFunctionArgs(func, true); } else if (lexer.current().type == '(') { // This error is handled inside 'parseFunctionArgs' as well, but for better error recovery we need to break out the current loop here if (!asStatement && expr->location.end.line != lexer.current().location.begin.line) { - report(lexer.current().location, - "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " - "new statement; use ';' to separate statements"); - + reportAmbiguousCallError(); break; } - expr = parseFunctionArgs(expr, false, Location()); + expr = parseFunctionArgs(expr, false); } else if (lexer.current().type == '{' || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { - expr = parseFunctionArgs(expr, false, Location()); + expr = parseFunctionArgs(expr, false); } else { @@ -2156,7 +2151,7 @@ static ConstantNumberParseResult parseInteger(double& result, const char* data, return ConstantNumberParseResult::Ok; } -static ConstantNumberParseResult parseNumber(double& result, const char* data) +static ConstantNumberParseResult parseDouble(double& result, const char* data) { LUAU_ASSERT(FFlag::LuauLintParseIntegerIssues); @@ -2214,61 +2209,11 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), {}).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; } else if (lexer.current().type == Lexeme::Number) { - scratchData.assign(lexer.current().data, lexer.current().length); - - // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al - if (scratchData.find('_') != std::string::npos) - { - scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); - } - - if (FFlag::LuauLintParseIntegerIssues) - { - double value = 0; - ConstantNumberParseResult result = parseNumber(value, scratchData.c_str()); - nextLexeme(); - - if (result == ConstantNumberParseResult::Malformed) - return reportExprError(start, {}, "Malformed number"); - - return allocator.alloc(start, value, result); - } - else if (DFFlag::LuaReportParseIntegerIssues) - { - double value = 0; - if (const char* error = parseNumber_DEPRECATED2(value, scratchData.c_str())) - { - nextLexeme(); - - return reportExprError(start, {}, "%s", error); - } - else - { - nextLexeme(); - - return allocator.alloc(start, value); - } - } - else - { - double value = 0; - if (parseNumber_DEPRECATED(value, scratchData.c_str())) - { - nextLexeme(); - - return allocator.alloc(start, value); - } - else - { - nextLexeme(); - - return reportExprError(start, {}, "Malformed number"); - } - } + return parseNumber(); } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { @@ -2309,18 +2254,15 @@ AstExpr* Parser::parseSimpleExpr() } // args ::= `(' [explist] `)' | tableconstructor | String -AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self, const Location& selfLocation) +AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) { if (lexer.current().type == '(') { Position argStart = lexer.current().location.end; if (func->location.end.line != lexer.current().location.begin.line) - { - report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " - "new statement; use ';' to separate statements"); - } + reportAmbiguousCallError(); - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); nextLexeme(); TempVector args(scratchExpr); @@ -2352,18 +2294,29 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self, const Location& sel } else { - if (self && lexer.current().location.begin.line != func->location.end.line) - { - return reportExprError(func->location, copy({func}), "Expected function call arguments after '('"); - } - else - { - return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}), - "Expected '(', '{' or when parsing function call, got %s", lexer.current().toString().c_str()); - } + return reportFunctionArgsError(func, self); } } +LUAU_NOINLINE AstExpr* Parser::reportFunctionArgsError(AstExpr* func, bool self) +{ + if (self && lexer.current().location.begin.line != func->location.end.line) + { + return reportExprError(func->location, copy({func}), "Expected function call arguments after '('"); + } + else + { + return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}), + "Expected '(', '{' or when parsing function call, got %s", lexer.current().toString().c_str()); + } +} + +LUAU_NOINLINE void Parser::reportAmbiguousCallError() +{ + report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " + "new statement; use ';' to separate statements"); +} + // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] // field ::= `[' exp `]' `=' exp | Name `=' exp | exp @@ -2374,14 +2327,14 @@ AstExpr* Parser::parseTableConstructor() Location start = lexer.current().location; - Lexeme matchBrace = lexer.current(); + MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table literal"); while (lexer.current().type != '}') { if (lexer.current().type == '[') { - Lexeme matchLocationBracket = lexer.current(); + MatchLexeme matchLocationBracket = lexer.current(); nextLexeme(); AstExpr* key = parseExpr(); @@ -2692,6 +2645,63 @@ AstExpr* Parser::parseString() return reportExprError(location, {}, "String literal contains malformed escape sequence"); } +AstExpr* Parser::parseNumber() +{ + Location start = lexer.current().location; + + scratchData.assign(lexer.current().data, lexer.current().length); + + // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al + if (scratchData.find('_') != std::string::npos) + { + scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); + } + + if (FFlag::LuauLintParseIntegerIssues) + { + double value = 0; + ConstantNumberParseResult result = parseDouble(value, scratchData.c_str()); + nextLexeme(); + + if (result == ConstantNumberParseResult::Malformed) + return reportExprError(start, {}, "Malformed number"); + + return allocator.alloc(start, value, result); + } + else if (DFFlag::LuaReportParseIntegerIssues) + { + double value = 0; + if (const char* error = parseNumber_DEPRECATED2(value, scratchData.c_str())) + { + nextLexeme(); + + return reportExprError(start, {}, "%s", error); + } + else + { + nextLexeme(); + + return allocator.alloc(start, value); + } + } + else + { + double value = 0; + if (parseNumber_DEPRECATED(value, scratchData.c_str())) + { + nextLexeme(); + + return allocator.alloc(start, value); + } + else + { + nextLexeme(); + + return reportExprError(start, {}, "Malformed number"); + } + } +} + AstLocal* Parser::pushLocal(const Binding& binding) { const Name& name = binding.name; @@ -2763,7 +2773,7 @@ LUAU_NOINLINE void Parser::expectAndConsumeFail(Lexeme::Type type, const char* c report(lexer.current().location, "Expected %s, got %s", typeString.c_str(), currLexemeString.c_str()); } -bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchForMissing) +bool Parser::expectMatchAndConsume(char value, const MatchLexeme& begin, bool searchForMissing) { Lexeme::Type type = static_cast(static_cast(value)); @@ -2771,42 +2781,7 @@ bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchF { expectMatchAndConsumeFail(type, begin); - if (searchForMissing) - { - // previous location is taken because 'current' lexeme is already the next token - unsigned currentLine = lexer.previousLocation().end.line; - - // search to the end of the line for expected token - // we will also stop if we hit a token that can be handled by parsing function above the current one - Lexeme::Type lexemeType = lexer.current().type; - - while (currentLine == lexer.current().location.begin.line && lexemeType != type && matchRecoveryStopOnToken[lexemeType] == 0) - { - nextLexeme(); - lexemeType = lexer.current().type; - } - - if (lexemeType == type) - { - nextLexeme(); - - return true; - } - } - else - { - // check if this is an extra token and the expected token is next - if (lexer.lookahead().type == type) - { - // skip invalid and consume expected - nextLexeme(); - nextLexeme(); - - return true; - } - } - - return false; + return expectMatchAndConsumeRecover(value, begin, searchForMissing); } else { @@ -2816,21 +2791,64 @@ bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchF } } -// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is -// cold -LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const Lexeme& begin, const char* extra) +LUAU_NOINLINE bool Parser::expectMatchAndConsumeRecover(char value, const MatchLexeme& begin, bool searchForMissing) { - std::string typeString = Lexeme(Location(Position(0, 0), 0), type).toString(); + Lexeme::Type type = static_cast(static_cast(value)); - if (lexer.current().location.begin.line == begin.location.begin.line) - report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), begin.toString().c_str(), - begin.location.begin.column + 1, lexer.current().toString().c_str(), extra ? extra : ""); + if (searchForMissing) + { + // previous location is taken because 'current' lexeme is already the next token + unsigned currentLine = lexer.previousLocation().end.line; + + // search to the end of the line for expected token + // we will also stop if we hit a token that can be handled by parsing function above the current one + Lexeme::Type lexemeType = lexer.current().type; + + while (currentLine == lexer.current().location.begin.line && lexemeType != type && matchRecoveryStopOnToken[lexemeType] == 0) + { + nextLexeme(); + lexemeType = lexer.current().type; + } + + if (lexemeType == type) + { + nextLexeme(); + + return true; + } + } else - report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), begin.toString().c_str(), - begin.location.begin.line + 1, lexer.current().toString().c_str(), extra ? extra : ""); + { + // check if this is an extra token and the expected token is next + if (lexer.lookahead().type == type) + { + // skip invalid and consume expected + nextLexeme(); + nextLexeme(); + + return true; + } + } + + return false; } -bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin) +// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is +// cold +LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin, const char* extra) +{ + std::string typeString = Lexeme(Location(Position(0, 0), 0), type).toString(); + std::string matchString = Lexeme(Location(Position(0, 0), 0), begin.type).toString(); + + if (lexer.current().location.begin.line == begin.position.line) + report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), matchString.c_str(), + begin.position.column + 1, lexer.current().toString().c_str(), extra ? extra : ""); + else + report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), matchString.c_str(), + begin.position.line + 1, lexer.current().toString().c_str(), extra ? extra : ""); +} + +bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin) { if (lexer.current().type != type) { @@ -2852,9 +2870,9 @@ bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin) { // If the token matches on a different line and a different column, it suggests misleading indentation // This can be used to pinpoint the problem location for a possible future *actual* mismatch - if (lexer.current().location.begin.line != begin.location.begin.line && - lexer.current().location.begin.column != begin.location.begin.column && - endMismatchSuspect.location.begin.line < begin.location.begin.line) // Only replace the previous suspect with more recent suspects + if (lexer.current().location.begin.line != begin.position.line && + lexer.current().location.begin.column != begin.position.column && + endMismatchSuspect.position.line < begin.position.line) // Only replace the previous suspect with more recent suspects { endMismatchSuspect = begin; } @@ -2867,12 +2885,12 @@ bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin) // LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is // cold -LUAU_NOINLINE void Parser::expectMatchEndAndConsumeFail(Lexeme::Type type, const Lexeme& begin) +LUAU_NOINLINE void Parser::expectMatchEndAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin) { - if (endMismatchSuspect.type != Lexeme::Eof && endMismatchSuspect.location.begin.line > begin.location.begin.line) + if (endMismatchSuspect.type != Lexeme::Eof && endMismatchSuspect.position.line > begin.position.line) { - std::string suggestion = - format("; did you forget to close %s at line %d?", endMismatchSuspect.toString().c_str(), endMismatchSuspect.location.begin.line + 1); + std::string matchString = Lexeme(Location(Position(0, 0), 0), endMismatchSuspect.type).toString(); + std::string suggestion = format("; did you forget to close %s at line %d?", matchString.c_str(), endMismatchSuspect.position.line + 1); expectMatchAndConsumeFail(type, begin, suggestion.c_str()); } diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 1d6b18e5..d6660f05 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -515,6 +515,9 @@ enum LuauBuiltinFunction // rawlen LBF_RAWLEN, + + // bit32.extract(_, k, k) + LBF_BIT32_EXTRACTK, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/BuiltinFolding.cpp b/Compiler/src/BuiltinFolding.cpp index e76da4e2..03b5918c 100644 --- a/Compiler/src/BuiltinFolding.cpp +++ b/Compiler/src/BuiltinFolding.cpp @@ -319,11 +319,12 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count) break; case LBF_BIT32_EXTRACT: - if (count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number) + if (count >= 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && + (count == 2 || args[2].type == Constant::Type_Number)) { uint32_t u = bit32(args[0].valueNumber); int f = int(args[1].valueNumber); - int w = int(args[2].valueNumber); + int w = count == 2 ? 1 : int(args[2].valueNumber); if (f >= 0 && w > 0 && f + w <= 32) { @@ -356,13 +357,13 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count) break; case LBF_BIT32_REPLACE: - if (count == 4 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number && - args[3].type == Constant::Type_Number) + if (count >= 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number && + (count == 3 || args[3].type == Constant::Type_Number)) { uint32_t n = bit32(args[0].valueNumber); uint32_t v = bit32(args[1].valueNumber); int f = int(args[2].valueNumber); - int w = int(args[3].valueNumber); + int w = count == 3 ? 1 : int(args[3].valueNumber); if (f >= 0 && w > 0 && f + w <= 32) { diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index eb815225..bd8744cc 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -23,13 +23,12 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false) - -LUAU_FASTFLAGVARIABLE(LuauCompileFreeReassign, false) LUAU_FASTFLAGVARIABLE(LuauCompileXEQ, false) LUAU_FASTFLAGVARIABLE(LuauCompileOptimalAssignment, false) +LUAU_FASTFLAGVARIABLE(LuauCompileExtractK, false) + namespace Luau { @@ -403,18 +402,37 @@ struct Compiler } } - void compileExprFastcallN(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid) + void compileExprFastcallN(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid, int bfK = -1) { LUAU_ASSERT(!expr->self); - LUAU_ASSERT(expr->args.size <= 2); + LUAU_ASSERT(expr->args.size >= 1); + LUAU_ASSERT(expr->args.size <= 2 || (bfid == LBF_BIT32_EXTRACTK && expr->args.size == 3)); + LUAU_ASSERT(bfid == LBF_BIT32_EXTRACTK ? bfK >= 0 : bfK < 0); LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; - uint32_t args[2] = {}; + if (FFlag::LuauCompileExtractK) + { + opc = expr->args.size == 1 ? LOP_FASTCALL1 : (bfK >= 0 || isConstant(expr->args.data[1])) ? LOP_FASTCALL2K : LOP_FASTCALL2; + } + + uint32_t args[3] = {}; for (size_t i = 0; i < expr->args.size; ++i) { - if (i > 0) + if (FFlag::LuauCompileExtractK) + { + if (i > 0 && opc == LOP_FASTCALL2K) + { + int32_t cid = getConstantIndex(expr->args.data[i]); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + args[i] = cid; + continue; // TODO: remove this and change if below to else if + } + } + else if (i > 0) { if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) { @@ -425,7 +443,9 @@ struct Compiler } if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0) + { args[i] = uint8_t(reg); + } else { args[i] = uint8_t(regs + 1 + i); @@ -437,21 +457,31 @@ struct Compiler bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); if (opc != LOP_FASTCALL1) - bytecode.emitAux(args[1]); + bytecode.emitAux(bfK >= 0 ? bfK : args[1]); // Set up a traditional Lua stack for the subsequent LOP_CALL. // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for // these FASTCALL variants. for (size_t i = 0; i < expr->args.size; ++i) { - if (i > 0 && opc == LOP_FASTCALL2K) + if (FFlag::LuauCompileExtractK) { - emitLoadK(uint8_t(regs + 1 + i), args[i]); - break; + if (i > 0 && opc == LOP_FASTCALL2K) + emitLoadK(uint8_t(regs + 1 + i), args[i]); + else if (args[i] != regs + 1 + i) + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); } + else + { + if (i > 0 && opc == LOP_FASTCALL2K) + { + emitLoadK(uint8_t(regs + 1 + i), args[i]); + break; + } - if (args[i] != regs + 1 + i) - bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); + if (args[i] != regs + 1 + i) + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); + } } // note, these instructions are normally not executed and are used as a fallback for FASTCALL @@ -600,7 +630,7 @@ struct Compiler } else { - AstExprLocal* le = FFlag::LuauCompileFreeReassign ? getExprLocal(arg) : arg->as(); + AstExprLocal* le = getExprLocal(arg); Variable* lv = le ? variables.find(le->local) : nullptr; // if the argument is a local that isn't mutated, we will simply reuse the existing register @@ -723,6 +753,26 @@ struct Compiler bfid = -1; } + // Optimization: for bit32.extract with constant in-range f/w we compile using FASTCALL2K and a special builtin + if (FFlag::LuauCompileExtractK && bfid == LBF_BIT32_EXTRACT && expr->args.size == 3 && isConstant(expr->args.data[1]) && isConstant(expr->args.data[2])) + { + Constant fc = getConstant(expr->args.data[1]); + Constant wc = getConstant(expr->args.data[2]); + + int fi = fc.type == Constant::Type_Number ? int(fc.valueNumber) : -1; + int wi = wc.type == Constant::Type_Number ? int(wc.valueNumber) : -1; + + if (fi >= 0 && wi > 0 && fi + wi <= 32) + { + int fwp = fi | ((wi - 1) << 5); + int32_t cid = bytecode.addConstantNumber(fwp); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, LBF_BIT32_EXTRACTK, cid); + } + } + // Optimization: for 1/2 argument fast calls use specialized opcodes if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2 && !isExprMultRet(expr->args.data[expr->args.size - 1])) return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); @@ -1218,7 +1268,7 @@ struct Compiler { // disambiguation: there's 4 cases (we only need truthy or falsy results based on onlyTruth) // onlyTruth = 1: a and b transforms to a ? b : dontcare - // onlyTruth = 1: a or b transforms to a ? a : a + // onlyTruth = 1: a or b transforms to a ? a : b // onlyTruth = 0: a and b transforms to !a ? a : b // onlyTruth = 0: a or b transforms to !a ? b : dontcare if (onlyTruth == (expr->op == AstExprBinary::And)) @@ -2576,7 +2626,7 @@ struct Compiler return; // Optimization: for 1-1 local assignments, we can reuse the register *if* neither local is mutated - if (FFlag::LuauCompileFreeReassign && options.optimizationLevel >= 1 && stat->vars.size == 1 && stat->values.size == 1) + if (options.optimizationLevel >= 1 && stat->vars.size == 1 && stat->values.size == 1) { if (AstExprLocal* re = getExprLocal(stat->values.data[0])) { @@ -2790,7 +2840,6 @@ struct Compiler LUAU_ASSERT(vars == regs + 3); LuauOpcode skipOp = LOP_FORGPREP; - LuauOpcode loopOp = LOP_FORGLOOP; // Optimization: when we iterate via pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration index // These instructions dynamically check if generator is equal to next/inext and bail out @@ -2802,25 +2851,16 @@ struct Compiler Builtin builtin = getBuiltin(stat->values.data[0]->as()->func, globals, variables); if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) - { skipOp = LOP_FORGPREP_INEXT; - loopOp = FFlag::LuauCompileNoIpairs ? LOP_FORGLOOP : LOP_FORGLOOP_INEXT; - } else if (builtin.isGlobal("pairs")) // for .. in pairs(t) - { skipOp = LOP_FORGPREP_NEXT; - loopOp = LOP_FORGLOOP; - } } else if (stat->values.size == 2) { Builtin builtin = getBuiltin(stat->values.data[0], globals, variables); if (builtin.isGlobal("next")) // for .. in next,t - { skipOp = LOP_FORGPREP_NEXT; - loopOp = LOP_FORGLOOP; - } } } @@ -2846,19 +2886,9 @@ struct Compiler size_t backLabel = bytecode.emitLabel(); - bytecode.emitAD(loopOp, regs, 0); - - if (FFlag::LuauCompileNoIpairs) - { - // TODO: remove loopOp as it's a constant now - LUAU_ASSERT(loopOp == LOP_FORGLOOP); - - // FORGLOOP uses aux to encode variable count and fast path flag for ipairs traversal in the high bit - bytecode.emitAux((skipOp == LOP_FORGPREP_INEXT ? 0x80000000 : 0) | uint32_t(stat->vars.size)); - } - // note: FORGLOOP needs variable count encoded in AUX field, other loop instructions assume a fixed variable count - else if (loopOp == LOP_FORGLOOP) - bytecode.emitAux(uint32_t(stat->vars.size)); + // FORGLOOP uses aux to encode variable count and fast path flag for ipairs traversal in the high bit + bytecode.emitAD(LOP_FORGLOOP, regs, 0); + bytecode.emitAux((skipOp == LOP_FORGPREP_INEXT ? 0x80000000 : 0) | uint32_t(stat->vars.size)); size_t endLabel = bytecode.emitLabel(); diff --git a/Sources.cmake b/Sources.cmake index 7a27684e..50770f92 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -66,6 +66,7 @@ target_sources(Luau.CodeGen PRIVATE # Luau.Analysis Sources target_sources(Luau.Analysis PRIVATE + Analysis/include/Luau/Anyification.h Analysis/include/Luau/ApplyTypeFunction.h Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstQuery.h @@ -115,6 +116,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Variant.h Analysis/include/Luau/VisitTypeVar.h + Analysis/src/Anyification.cpp Analysis/src/ApplyTypeFunction.cpp Analysis/src/AstJsonEncoder.cpp Analysis/src/AstQuery.cpp @@ -126,6 +128,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintSolver.cpp Analysis/src/ConstraintSolverLogger.cpp + Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp Analysis/src/Instantiation.cpp @@ -155,7 +158,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypeVar.cpp Analysis/src/Unifiable.cpp Analysis/src/Unifier.cpp - Analysis/src/EmbeddedBuiltinDefinitions.cpp ) # Luau.VM Sources diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 77788e40..af97dc4a 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1209,9 +1209,8 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) { luaC_checkGC(L); luaC_checkthreadsleep(L); - size_t as = sz + sizeof(dtor); - if (as < sizeof(dtor)) - as = SIZE_MAX; // Will cause a memory error in luaU_newudata. + // make sure sz + sizeof(dtor) doesn't overflow; luaU_newdata will reject SIZE_MAX correctly + size_t as = sz < SIZE_MAX - sizeof(dtor) ? sz + sizeof(dtor) : SIZE_MAX; Udata* u = luaU_newudata(L, as, UTAG_IDTOR); memcpy(&u->data + sz, &dtor, sizeof(dtor)); setuvalue(L, L->top, u); diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index e98660a7..422c82b1 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -15,6 +15,8 @@ #include #endif +LUAU_FASTFLAGVARIABLE(LuauFasterBit32NoWidth, false) + // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -600,24 +602,39 @@ static int luauF_btest(lua_State* L, StkId res, TValue* arg0, int nresults, StkI static int luauF_extract(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + if (nparams >= (3 - FFlag::LuauFasterBit32NoWidth) && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) { double a1 = nvalue(arg0); double a2 = nvalue(args); - double a3 = nvalue(args + 1); unsigned n; luai_num2unsigned(n, a1); int f = int(a2); - int w = int(a3); - if (f >= 0 && w > 0 && f + w <= 32) + if (nparams == 2) { - uint32_t m = ~(0xfffffffeu << (w - 1)); - uint32_t r = (n >> f) & m; + if (unsigned(f) < 32) + { + uint32_t m = 1; + uint32_t r = (n >> f) & m; - setnvalue(res, double(r)); - return 1; + setnvalue(res, double(r)); + return 1; + } + } + else if (ttisnumber(args + 1)) + { + double a3 = nvalue(args + 1); + int w = int(a3); + + if (f >= 0 && w > 0 && f + w <= 32) + { + uint32_t m = ~(0xfffffffeu << (w - 1)); + uint32_t r = (n >> f) & m; + + setnvalue(res, double(r)); + return 1; + } } } @@ -676,26 +693,41 @@ static int luauF_lshift(lua_State* L, StkId res, TValue* arg0, int nresults, Stk static int luauF_replace(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (nparams >= 4 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1) && ttisnumber(args + 2)) + if (nparams >= (4 - FFlag::LuauFasterBit32NoWidth) && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) { double a1 = nvalue(arg0); double a2 = nvalue(args); double a3 = nvalue(args + 1); - double a4 = nvalue(args + 2); unsigned n, v; luai_num2unsigned(n, a1); luai_num2unsigned(v, a2); int f = int(a3); - int w = int(a4); - if (f >= 0 && w > 0 && f + w <= 32) + if (nparams == 3) { - uint32_t m = ~(0xfffffffeu << (w - 1)); - uint32_t r = (n & ~(m << f)) | ((v & m) << f); + if (unsigned(f) < 32) + { + uint32_t m = 1; + uint32_t r = (n & ~(m << f)) | ((v & m) << f); - setnvalue(res, double(r)); - return 1; + setnvalue(res, double(r)); + return 1; + } + } + else if (ttisnumber(args + 2)) + { + double a4 = nvalue(args + 2); + int w = int(a4); + + if (f >= 0 && w > 0 && f + w <= 32) + { + uint32_t m = ~(0xfffffffeu << (w - 1)); + uint32_t r = (n & ~(m << f)) | ((v & m) << f); + + setnvalue(res, double(r)); + return 1; + } } } @@ -1138,6 +1170,31 @@ static int luauF_rawlen(lua_State* L, StkId res, TValue* arg0, int nresults, Stk return -1; } +static int luauF_extractk(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + // args is known to contain a number constant with packed in-range f/w + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + + unsigned n; + luai_num2unsigned(n, a1); + int fw = int(a2); + + int f = fw & 31; + int w1 = fw >> 5; + + uint32_t m = ~(0xfffffffeu << w1); + uint32_t r = (n >> f) & m; + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1211,4 +1268,6 @@ luau_FastFunction luauF_table[256] = { luauF_select, luauF_rawlen, + + luauF_extractk, }; diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index dfde6dcb..bd0f8264 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -175,8 +175,7 @@ void luaF_freeclosure(lua_State* L, Closure* c, lua_Page* page) const LocVar* luaF_getlocal(const Proto* f, int local_number, int pc) { - int i; - for (i = 0; i < f->sizelocvars; i++) + for (int i = 0; i < f->sizelocvars; i++) { if (pc >= f->locvars[i].startpc && pc < f->locvars[i].endpc) { // is variable active? @@ -185,5 +184,15 @@ const LocVar* luaF_getlocal(const Proto* f, int local_number, int pc) return &f->locvars[i]; } } + + return NULL; // not found +} + +const LocVar* luaF_findlocal(const Proto* f, int local_reg, int pc) +{ + for (int i = 0; i < f->sizelocvars; i++) + if (local_reg == f->locvars[i].reg && pc >= f->locvars[i].startpc && pc < f->locvars[i].endpc) + return &f->locvars[i]; + return NULL; // not found } diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index a260d00a..59ab5726 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -17,3 +17,4 @@ LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c, struct lua_Page* page) LUAI_FUNC void luaF_unlinkupval(UpVal* uv); LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv, struct lua_Page* page); LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc); +LUAI_FUNC const LocVar* luaF_findlocal(const Proto* func, int local_reg, int pc); diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index bc997d44..b6204b1a 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -345,6 +345,9 @@ static void dumpclosure(FILE* f, Closure* cl) if (cl->isC) { + if (cl->c.debugname) + fprintf(f, ",\"name\":\"%s\"", cl->c.debugname + 0); + if (cl->nupvalues) { fprintf(f, ",\"upvalues\":["); @@ -354,6 +357,9 @@ static void dumpclosure(FILE* f, Closure* cl) } else { + if (cl->l.p->debugname) + fprintf(f, ",\"name\":\"%s\"", getstr(cl->l.p->debugname)); + fprintf(f, ",\"proto\":"); dumpref(f, obj2gco(cl->l.p)); if (cl->nupvalues) @@ -403,7 +409,7 @@ static void dumpthread(FILE* f, lua_State* th) fprintf(f, ",\"source\":\""); dumpstringdata(f, p->source->data, p->source->len); - fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); + fprintf(f, "\",\"line\":%d", p->linedefined); } if (th->top > th->stack) @@ -411,6 +417,55 @@ static void dumpthread(FILE* f, lua_State* th) fprintf(f, ",\"stack\":["); dumprefs(f, th->stack, th->top - th->stack); fprintf(f, "]"); + + CallInfo* ci = th->base_ci; + bool first = true; + + fprintf(f, ",\"stacknames\":["); + for (StkId v = th->stack; v < th->top; ++v) + { + if (!iscollectable(v)) + continue; + + while (ci < th->ci && v >= (ci + 1)->func) + ci++; + + if (!first) + fputc(',', f); + first = false; + + if (v == ci->func) + { + Closure* cl = ci_func(ci); + + if (cl->isC) + { + fprintf(f, "\"frame:%s\"", cl->c.debugname ? cl->c.debugname : "[C]"); + } + else + { + Proto* p = cl->l.p; + fprintf(f, "\"frame:"); + if (p->source) + dumpstringdata(f, p->source->data, p->source->len); + fprintf(f, ":%d:%s\"", p->linedefined, p->debugname ? getstr(p->debugname) : ""); + } + } + else if (isLua(ci)) + { + Proto* p = ci_func(ci)->l.p; + int pc = pcRel(ci->savedpc, p); + const LocVar* var = luaF_findlocal(p, int(v - ci->base), pc); + + if (var && var->varname) + fprintf(f, "\"%s\"", getstr(var->varname)); + else + fprintf(f, "null"); + } + else + fprintf(f, "null"); + } + fprintf(f, "]"); } fprintf(f, "}"); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 0f17531b..25447dd8 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3189,4 +3189,27 @@ a.@1 CHECK(ac.entryMap.count("y")); } +TEST_CASE_FIXTURE(ACFixture, "globals_are_order_independent") +{ + ScopedFastFlag sff("LuauAutocompleteFixGlobalOrder", true); + + check(R"( + local myLocal = 4 + function abc0() + local myInnerLocal = 1 +@1 + end + + function abc1() + local myInnerLocal = 1 + end + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("myInnerLocal")); + CHECK(ac.entryMap.count("abc0")); + CHECK(ac.entryMap.count("abc1")); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index a4ce88b1..a2e748a8 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,8 +261,6 @@ L1: RETURN R0 0 TEST_CASE("ForBytecode") { - ScopedFastFlag sff("LuauCompileNoIpairs", true); - // basic for loop: variable directly refers to internal iteration index (R2) CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( LOADN R2 1 @@ -349,8 +347,6 @@ RETURN R0 0 TEST_CASE("ForBytecodeBuiltin") { - ScopedFastFlag sff("LuauCompileNoIpairs", true); - // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( GETIMPORT R0 1 @@ -2065,6 +2061,69 @@ TEST_CASE("RecursionParse") { CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); } + + try + { + Luau::compileOrThrow(bcb, rep("a(", 1500) + "42" + rep(")", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "return " + rep("{", 1500) + "42" + rep("}", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, rep("while true do ", 1500) + "print()" + rep(" end", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, rep("for i=1,1 do ", 1500) + "print()" + rep(" end", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + +#if 0 + // This currently requires too much stack space on MSVC/x64 and crashes with stack overflow at recursion depth 935 + try + { + Luau::compileOrThrow(bcb, rep("function a() ", 1500) + "print()" + rep(" end", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "return " + rep("function() return ", 1500) + "42" + rep(" end", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); + } +#endif } TEST_CASE("ArrayIndexLiteral") @@ -2111,8 +2170,6 @@ L1: RETURN R3 -1 TEST_CASE("UpvaluesLoopsBytecode") { - ScopedFastFlag sff("LuauCompileNoIpairs", true); - CHECK_EQ("\n" + compileFunction(R"( function test() for i=1,10 do @@ -3790,8 +3847,6 @@ RETURN R0 1 TEST_CASE("SharedClosure") { - ScopedFastFlag sff("LuauCompileFreeReassign", true); - // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... @@ -6004,6 +6059,8 @@ return math.clamp(-1, 0, 1), math.sign(77), math.round(7.6), + bit32.extract(-1, 31), + bit32.replace(100, 1, 0), (type("fin")) )", 0, 2), @@ -6055,8 +6112,10 @@ LOADK R43 K2 LOADN R44 0 LOADN R45 1 LOADN R46 8 -LOADK R47 K3 -RETURN R0 48 +LOADN R47 1 +LOADN R48 101 +LOADK R49 K3 +RETURN R0 50 )"); } @@ -6067,7 +6126,8 @@ return math.abs(), math.max(1, true), string.byte("abc", 42), - bit32.rshift(10, 42) + bit32.rshift(10, 42), + bit32.extract(1, 2, "3") )", 0, 2), R"( @@ -6088,8 +6148,14 @@ L2: LOADN R4 10 FASTCALL2K 39 R4 K7 L3 LOADK R5 K7 GETIMPORT R3 13 -CALL R3 2 -1 -L3: RETURN R0 -1 +CALL R3 2 1 +L3: LOADN R5 1 +LOADN R6 2 +LOADK R7 K14 +FASTCALL 34 L4 +GETIMPORT R4 16 +CALL R4 3 -1 +L4: RETURN R0 -1 )"); } @@ -6146,8 +6212,6 @@ RETURN R0 1 TEST_CASE("LocalReassign") { - ScopedFastFlag sff("LuauCompileFreeReassign", true); - // locals can be re-assigned and the register gets reused CHECK_EQ("\n" + compileFunction0(R"( local function test(a, b) @@ -6459,4 +6523,26 @@ RETURN R0 0 )"); } +TEST_CASE("BuiltinExtractK") +{ + ScopedFastFlag sff("LuauCompileExtractK", true); + + // below, K0 refers to a packed f+w constant for bit32.extractk builtin + // K1 and K2 refer to 1 and 3 and are only used during fallback path + CHECK_EQ("\n" + compileFunction0(R"( +local v = ... + +return bit32.extract(v, 1, 3) +)"), R"( +GETVARARGS R0 1 +FASTCALL2K 59 R0 K0 L0 +MOVE R2 R0 +LOADK R3 K1 +LOADK R4 K2 +GETIMPORT R1 5 +CALL R1 3 -1 +L0: RETURN R1 -1 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 17c7ac58..be2feacb 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -727,7 +727,7 @@ TEST_CASE("NewUserdataOverflow") // The overflow might segfault in the following call. lua_getmetatable(L1, -1); return 0; - }, "PCall"); + }, nullptr); CHECK(lua_pcall(L, 0, 0, 0) == LUA_ERRRUN); CHECK(strcmp(lua_tostring(L, -1), "memory allocation error: block too big") == 0); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index f0146602..40be39a0 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -443,7 +443,8 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() - , cgb(mainModuleName, getMainModule(), &arena, NotNull(&ice), frontend.getGlobalScope()) + , mainModule(new Module) + , cgb(mainModuleName, mainModule, &arena, NotNull(&ice), frontend.getGlobalScope()) , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { BlockedTypeVar::nextIndex = 0; diff --git a/tests/Fixture.h b/tests/Fixture.h index a716fe9b..8dc3dd24 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -162,6 +162,7 @@ struct BuiltinsFixture : Fixture struct ConstraintGraphBuilderFixture : Fixture { TypeArena arena; + ModulePtr mainModule; ConstraintGraphBuilder cgb; ScopedFastFlag forceTheFlag; diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index c64c41c5..b017d8dd 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -12,6 +12,11 @@ using namespace Luau; struct NormalizeFixture : Fixture { ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true}; + + bool isSubtype(TypeId a, TypeId b) + { + return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, ice); + } }; void createSomeClasses(TypeChecker& typeChecker) @@ -49,12 +54,6 @@ void createSomeClasses(TypeChecker& typeChecker) freeze(arena); } -static bool isSubtype(TypeId a, TypeId b) -{ - InternalErrorReporter ice; - return isSubtype(a, b, ice); -} - TEST_SUITE_BEGIN("isSubtype"); TEST_CASE_FIXTURE(NormalizeFixture, "primitives") @@ -511,6 +510,8 @@ TEST_CASE_FIXTURE(NormalizeFixture, "classes") { createSomeClasses(typeChecker); + check(""); // Ensure that we have a main Module. + TypeId p = typeChecker.globalScope->lookupType("Parent")->type; TypeId c = typeChecker.globalScope->lookupType("Child")->type; TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; @@ -595,6 +596,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_sub )"); ModulePtr tempModule{new Module}; + tempModule->scopes.emplace_back(Location(), std::make_shared(getSingletonTypes().anyTypePack)); // HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze // the arena that the type lives in. @@ -880,7 +882,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, - {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -921,7 +922,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, - {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -961,7 +961,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, - {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -1149,7 +1148,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever") { ScopedFastFlag sff[]{ {"LuauLowerBoundsCalculation", true}, - {"LuauQuantifyConstrained", true}, }; CheckResult result = check(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 5b86807c..c55ec181 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2535,7 +2535,6 @@ end TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation") { - ScopedFastFlag sff{"LuauParserFunctionKeywordAsTypeHelp", true}; ParseResult result = tryParse(R"( type Foo = function )"); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 074c86c3..5cc759d6 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1637,7 +1637,6 @@ TEST_CASE_FIXTURE(Fixture, "quantify_constrained_types") { ScopedFastFlag sff[]{ {"LuauLowerBoundsCalculation", true}, - {"LuauQuantifyConstrained", true}, }; CheckResult result = check(R"( @@ -1662,7 +1661,6 @@ TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantifie { ScopedFastFlag sff[]{ {"LuauLowerBoundsCalculation", true}, - {"LuauQuantifyConstrained", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 754fb193..b5f2296c 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -329,6 +329,35 @@ function tbl:foo(b: number, c: number) -- introduce BoundTypeVar to imported type arrayops.foo(self._regions) end +-- this alias decreases function type level and causes a demotion of its type +type Table = typeof(tbl) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_5") +{ + ScopedFastFlag luauInplaceDemoteSkipAllBound{"LuauInplaceDemoteSkipAllBound", true}; + + fileResolver.source["game/A"] = R"( +export type Type = {x: number, y: number} +local arrayops = {} +function arrayops.foo(x: Type) end +return arrayops + )"; + + CheckResult result = check(R"( +local arrayops = require(game.A) + +local tbl = {} +tbl.a = 2 +function tbl:foo(b: number, c: number) + -- introduce boundTo TableTypeVar to imported type + self.x.a = 2 + arrayops.foo(self.x) +end +-- this alias decreases function type level and causes a demotion of its type type Table = typeof(tbl) )"); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 01923f38..9a917a6f 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -485,7 +485,6 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") { ScopedFastFlag sff[]{ {"LuauLowerBoundsCalculation", true}, - {"LuauQuantifyConstrained", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 8a1cadcf..e61e6e45 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -41,6 +41,7 @@ struct RefinementClassFixture : Fixture RefinementClassFixture() { TypeArena& arena = typeChecker.globalTypes; + NotNull scope{typeChecker.globalScope.get()}; unfreeze(arena); TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); @@ -49,7 +50,7 @@ struct RefinementClassFixture : Fixture {"Y", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}}, }; - normalize(vec3, arena, *typeChecker.iceHandler); + normalize(vec3, scope, arena, *typeChecker.iceHandler); TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); @@ -57,21 +58,21 @@ struct RefinementClassFixture : Fixture TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - normalize(isA, arena, *typeChecker.iceHandler); + normalize(isA, scope, arena, *typeChecker.iceHandler); getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, {"IsA", Property{isA}}, }; - normalize(inst, arena, *typeChecker.iceHandler); + normalize(inst, scope, arena, *typeChecker.iceHandler); TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); - normalize(folder, arena, *typeChecker.iceHandler); + normalize(folder, scope, arena, *typeChecker.iceHandler); TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); getMutable(part)->props = { {"Position", Property{vec3}}, }; - normalize(part, arena, *typeChecker.iceHandler); + normalize(part, scope, arena, *typeChecker.iceHandler); typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; @@ -934,8 +935,6 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { - ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; - CheckResult result = check(R"( type T = {tag: "missing", x: nil} | {tag: "exists", x: string} @@ -1230,8 +1229,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") { - ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; - CheckResult result = check(R"( local function f(t: {number}) local x = t[1] diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 8c7d8a53..95b85c48 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3003,8 +3003,6 @@ TEST_CASE_FIXTURE(Fixture, "expected_indexer_from_table_union") TEST_CASE_FIXTURE(Fixture, "prop_access_on_key_whose_types_mismatches") { - ScopedFastFlag sff{"LuauReportErrorsOnIndexerKeyMismatch", true}; - CheckResult result = check(R"( local t: {number} = {} local x = t.x @@ -3016,8 +3014,6 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_key_whose_types_mismatches") TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_types_mismatches") { - ScopedFastFlag sff{"LuauReportErrorsOnIndexerKeyMismatch", true}; - CheckResult result = check(R"( local t: { [number]: number } | { [boolean]: number } = {} local u = t.x diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index e0a0e5b5..02fdfd73 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -17,7 +17,8 @@ struct TryUnifyFixture : Fixture ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; - Unifier state{&arena, Mode::Strict, Location{}, Variance::Covariant, unifierState}; + + Unifier state{&arena, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant, unifierState}; }; TEST_SUITE_BEGIN("TryUnifyTests"); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index edec8442..32be8215 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -193,6 +193,11 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_only_cyclic_union") CHECK(actual.empty()); } + +/* FIXME: This test is pretty weird. It would be much nicer if we could + * perform this operation without a TypeChecker so that we don't have to jam + * all this state into it to make stuff work. + */ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; @@ -268,6 +273,7 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId root = &ttvTweenResult; typeChecker.currentModule = std::make_shared(); + typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(getSingletonTypes().anyTypePack)); TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index 13be3f94..f0c5698d 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -100,6 +100,9 @@ assert(bit32.extract(0xa0001111, 28, 4) == 0xa) assert(bit32.extract(0xa0001111, 31, 1) == 1) assert(bit32.extract(0x50000111, 31, 1) == 0) assert(bit32.extract(0xf2345679, 0, 32) == 0xf2345679) +assert(bit32.extract(0xa0001111, 16) == 0) +assert(bit32.extract(0xa0001111, 31) == 1) +assert(bit32.extract(42, 1, 3) == 5) assert(not pcall(bit32.extract, 0, -1)) assert(not pcall(bit32.extract, 0, 32)) @@ -152,5 +155,6 @@ assert(bit32.btest(1, "3") == true) assert(bit32.btest("1", 3) == true) assert(bit32.countlz("42") == 26) assert(bit32.countrz("42") == 1) +assert(bit32.extract("42", 1, 3) == 5) return('OK') diff --git a/tools/faillist.txt b/tools/faillist.txt index d796236d..630bf9f7 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -10,12 +10,10 @@ AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag_handler AnnotationTests.luau_ice_triggers_an_ice_handler AnnotationTests.luau_print_is_magic_if_the_flag_is_set -AnnotationTests.luau_print_is_not_special_without_the_flag AnnotationTests.occurs_check_on_cyclic_intersection_typevar AnnotationTests.occurs_check_on_cyclic_union_typevar AnnotationTests.too_many_type_params AnnotationTests.two_type_params -AnnotationTests.unknown_type_reference_generates_error AnnotationTests.use_type_required_from_another_file AstQuery.last_argument_function_call_type AstQuery::getDocumentationSymbolAtPosition.overloaded_fn @@ -27,17 +25,13 @@ AutocompleteTest.autocomplete_end_with_lambda AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_for_in_middle_keywords AutocompleteTest.autocomplete_for_middle_keywords -AutocompleteTest.autocomplete_if_else_regression AutocompleteTest.autocomplete_if_middle_keywords -AutocompleteTest.autocomplete_ifelse_expressions AutocompleteTest.autocomplete_on_string_singletons AutocompleteTest.autocomplete_oop_implicit_self AutocompleteTest.autocomplete_repeat_middle_keyword AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.autocomplete_string_singleton_escape AutocompleteTest.autocomplete_string_singletons -AutocompleteTest.autocomplete_until_expression -AutocompleteTest.autocomplete_until_in_repeat AutocompleteTest.autocomplete_while_middle_keywords AutocompleteTest.autocompleteProp_index_function_metamethod_is_variadic AutocompleteTest.bias_toward_inner_scope @@ -60,7 +54,6 @@ AutocompleteTest.get_suggestions_for_the_very_start_of_the_script AutocompleteTest.global_function_params AutocompleteTest.global_functions_are_not_scoped_lexically AutocompleteTest.if_then_else_elseif_completions -AutocompleteTest.if_then_else_full_keywords AutocompleteTest.keyword_methods AutocompleteTest.keyword_types AutocompleteTest.library_non_self_calls_are_fine @@ -181,7 +174,6 @@ DefinitionTests.single_class_type_identity_in_global_types FrontendTest.ast_node_at_position FrontendTest.automatically_check_dependent_scripts FrontendTest.check_without_builtin_next -FrontendTest.clearStats FrontendTest.dont_reparse_clean_file_when_linting FrontendTest.environments FrontendTest.imported_table_modification_2 @@ -195,7 +187,6 @@ FrontendTest.reexport_cyclic_type FrontendTest.reexport_type_alias FrontendTest.report_require_to_nonexistent_file FrontendTest.report_syntax_error_in_required_file -FrontendTest.stats_are_not_reset_between_checks FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 @@ -213,8 +204,6 @@ GenericsTests.duplicate_generic_types GenericsTests.error_detailed_function_mismatch_generic_pack GenericsTests.error_detailed_function_mismatch_generic_types GenericsTests.factories_of_generics -GenericsTests.function_arguments_can_be_polytypes -GenericsTests.function_results_can_be_polytypes GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_factories @@ -243,7 +232,6 @@ GenericsTests.properties_can_be_instantiated_polytypes GenericsTests.rank_N_types_via_typeof GenericsTests.reject_clashing_generic_and_pack_names GenericsTests.self_recursive_instantiated_param -GenericsTests.variadic_generics IntersectionTypes.argument_is_intersection IntersectionTypes.error_detailed_intersection_all IntersectionTypes.error_detailed_intersection_part @@ -289,7 +277,6 @@ Normalize.cyclic_intersection Normalize.cyclic_table_normalizes_sensibly Normalize.cyclic_union Normalize.fuzz_failure_bound_type_is_normal_but_not_its_bounded_to -Normalize.higher_order_function Normalize.intersection_combine_on_bound_self Normalize.intersection_inside_a_table_inside_another_intersection Normalize.intersection_inside_a_table_inside_another_intersection_2 @@ -304,7 +291,6 @@ Normalize.normalization_does_not_convert_ever Normalize.normalize_module_return_type Normalize.normalize_unions_containing_never Normalize.normalize_unions_containing_unknown -Normalize.return_type_is_not_a_constrained_intersection Normalize.union_of_distinct_free_types Normalize.variadic_tail_is_marked_normal Normalize.visiting_a_type_twice_is_not_considered_normal @@ -456,7 +442,6 @@ TableTests.inferred_return_type_of_free_table TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 TableTests.instantiate_tables_at_scope_level -TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound TableTests.leaking_bad_metatable_errors TableTests.length_operator_intersection TableTests.length_operator_non_table_union @@ -521,7 +506,6 @@ ToDot.function ToDot.metatable ToDot.table ToString.exhaustive_toString_of_cyclic_table -ToString.function_type_with_argument_names_and_self ToString.function_type_with_argument_names_generic ToString.no_parentheses_around_cyclic_function_type_in_union ToString.toStringDetailed2 @@ -565,10 +549,7 @@ TypeInfer.do_not_bind_a_free_table_to_a_union_containing_that_table TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.globals TypeInfer.globals2 -TypeInfer.infer_assignment_value_types TypeInfer.infer_assignment_value_types_mutable_lval -TypeInfer.infer_through_group_expr -TypeInfer.no_heap_use_after_free_error TypeInfer.no_stack_overflow_from_isoptional TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.tc_if_else_expressions1 @@ -589,11 +570,8 @@ TypeInferAnyError.for_in_loop_iterator_returns_any TypeInferAnyError.for_in_loop_iterator_returns_any2 TypeInferAnyError.length_of_error_type_does_not_produce_an_error TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any -TypeInferAnyError.type_error_addition TypeInferClasses.call_base_method TypeInferClasses.call_instance_method -TypeInferClasses.can_read_prop_of_base_class -TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_can_have_overloaded_operators TypeInferClasses.classes_without_overloaded_operators_cannot_be_added @@ -609,8 +587,6 @@ TypeInferFunctions.another_indirect_function_case_where_it_is_ok_to_provide_too_ TypeInferFunctions.another_recursive_local_function TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument -TypeInferFunctions.cannot_hoist_interior_defns_into_signature -TypeInferFunctions.check_function_before_lambda_that_uses_it TypeInferFunctions.complicated_return_types_require_an_explicit_annotation TypeInferFunctions.cyclic_function_type_in_args TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists @@ -635,7 +611,6 @@ TypeInferFunctions.ignored_return_values TypeInferFunctions.inconsistent_higher_order_function TypeInferFunctions.inconsistent_return_types TypeInferFunctions.infer_anonymous_function_arguments -TypeInferFunctions.infer_anonymous_function_arguments_outside_call TypeInferFunctions.infer_return_type_from_selected_overload TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.it_is_ok_not_to_supply_enough_retvals @@ -680,9 +655,6 @@ TypeInferLoops.loop_iter_no_indexer_strict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.loop_typecheck_crash_on_empty_optional TypeInferLoops.properly_infer_iteratee_is_a_free_table -TypeInferLoops.repeat_loop -TypeInferLoops.repeat_loop_condition_binds_to_its_block -TypeInferLoops.symbols_in_repeat_block_should_not_be_visible_beyond_until_condition TypeInferLoops.unreachable_code_after_infinite_loop TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferModules.do_not_modify_imported_types @@ -718,8 +690,6 @@ TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compare_numbers TypeInferOperators.compare_strings -TypeInferOperators.compound_assign_basic -TypeInferOperators.compound_assign_metatable TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.compound_assign_mismatch_op TypeInferOperators.compound_assign_mismatch_result @@ -775,7 +745,6 @@ TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never TypeInferUnknownNever.unknown_is_reflexive -TypePackTests.cyclic_type_packs TypePackTests.higher_order_function TypePackTests.multiple_varargs_inference_are_not_confused TypePackTests.no_return_size_should_be_zero diff --git a/tools/heapgraph.py b/tools/heapgraph.py index 2817c38f..b8dc207f 100644 --- a/tools/heapgraph.py +++ b/tools/heapgraph.py @@ -132,8 +132,16 @@ while offset < len(queue): queue.append((obj["metatable"], node.child("__meta"))) elif obj["type"] == "thread": queue.append((obj["env"], node.child("__env"))) - for a in obj.get("stack", []): - queue.append((a, node.child("__stack"))) + stack = obj.get("stack") + stacknames = obj.get("stacknames", []) + stacknode = node.child("__stack") + framenode = None + for i in range(len(stack)): + name = stacknames[i] if stacknames else None + if name and name.startswith("frame:"): + framenode = stacknode.child(name[6:]) + name = None + queue.append((stack[i], framenode.child(name) if framenode and name else framenode or stacknode)) elif obj["type"] == "proto": for a in obj.get("constants", []): queue.append((a, node)) diff --git a/tools/heapstat.py b/tools/heapstat.py index 4c0cb407..7337aa44 100644 --- a/tools/heapstat.py +++ b/tools/heapstat.py @@ -59,5 +59,6 @@ if len(size_category) != 0: print("objects by category:") for type, (count, size) in sortedsize(size_category.items()): - name = dump["stats"]["categories"][type]["name"] + cat = dump["stats"]["categories"][type] + name = cat["name"] if "name" in cat else str(type) print(name.ljust(30), str(size).rjust(8), "bytes", str(count).rjust(5), "objects") diff --git a/tools/perfgraph.py b/tools/perfgraph.py index ae74b7d6..25bb0dd9 100644 --- a/tools/perfgraph.py +++ b/tools/perfgraph.py @@ -56,18 +56,28 @@ def nodeFromCallstackListFile(source_file): return root +def getDuration(obj): + total = obj['TotalDuration'] -def nodeFromJSONbject(node, key, obj): + if 'Children' in obj: + for key, obj in obj['Children'].items(): + total -= obj['TotalDuration'] + + return total + + +def nodeFromJSONObject(node, key, obj): source, function, line = key.split(",") node.function = function node.source = source node.line = int(line) if len(line) > 0 else 0 - node.ticks = obj['Duration'] + node.ticks = getDuration(obj) - for key, obj in obj['Children'].items(): - nodeFromJSONbject(node.child(key), key, obj) + if 'Children' in obj: + for key, obj in obj['Children'].items(): + nodeFromJSONObject(node.child(key), key, obj) return node @@ -77,8 +87,9 @@ def nodeFromJSONFile(source_file): root = Node() - for key, obj in dump['Children'].items(): - nodeFromJSONbject(root.child(key), key, obj) + if 'Children' in dump: + for key, obj in dump['Children'].items(): + nodeFromJSONObject(root.child(key), key, obj) return root