From 31a017c5c7d48841c63ac66a0a62fad06f71860e Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 15 Sep 2023 10:26:59 -0700 Subject: [PATCH 1/3] Sync to upstream/release/595 (#1044) * Rerun clang-format on the code * Fix the variance on indexer result subtyping. This fixes some issues with inconsistent error reporting. * Fix a bug in the normalization logic for intersections of strings New Type Solver * New overload selection logic * Subtype tests now correctly treat a generic as its upper bound within that generic's scope * Semantic subtyping for negation types * Semantic subtyping between strings and compatible table types like `{lower: (string) -> string}` * Further work toward finalizing our new subtype test * Correctly generalize module-scope symbols Native Codegen * Lowering statistics for assembly * Make executable allocation size/limit configurable without a rebuild. Use `FInt::LuauCodeGenBlockSize` and `FInt::LuauCodeGenMaxTotalSize`. --------- Co-authored-by: Arseny Kapoulkine Co-authored-by: Vyacheslav Egorov Co-authored-by: Lily Brown --- .../include/Luau/ConstraintGraphBuilder.h | 7 +- Analysis/include/Luau/GlobalTypes.h | 2 +- Analysis/include/Luau/Instantiation.h | 7 +- Analysis/include/Luau/Subtyping.h | 88 ++-- Analysis/include/Luau/Type.h | 12 + Analysis/include/Luau/TypeChecker2.h | 4 +- Analysis/include/Luau/TypeUtils.h | 3 +- Analysis/include/Luau/Unifier.h | 6 +- Analysis/include/Luau/Unifier2.h | 4 +- Analysis/include/Luau/VisitType.h | 17 +- Analysis/src/Autocomplete.cpp | 29 +- Analysis/src/Clone.cpp | 2 +- Analysis/src/ConstraintGraphBuilder.cpp | 22 +- Analysis/src/ConstraintSolver.cpp | 69 ++- Analysis/src/Differ.cpp | 51 +- Analysis/src/Frontend.cpp | 3 +- Analysis/src/GlobalTypes.cpp | 2 +- Analysis/src/Instantiation.cpp | 3 +- Analysis/src/Linter.cpp | 4 +- Analysis/src/Normalize.cpp | 68 ++- Analysis/src/Subtyping.cpp | 495 +++++++++++++----- Analysis/src/ToDot.cpp | 3 +- Analysis/src/TypeChecker2.cpp | 19 +- Analysis/src/Unifier.cpp | 3 +- Analysis/src/Unifier2.cpp | 92 +++- Ast/include/Luau/Ast.h | 2 +- CLI/Compile.cpp | 20 +- CodeGen/include/Luau/CodeGen.h | 13 +- CodeGen/include/Luau/IrRegAllocX64.h | 6 +- CodeGen/include/Luau/IrUtils.h | 9 + CodeGen/src/AssemblyBuilderA64.cpp | 2 +- CodeGen/src/CodeGen.cpp | 2 +- CodeGen/src/CodeGenAssembly.cpp | 16 +- CodeGen/src/CodeGenLower.h | 48 +- CodeGen/src/IrBuilder.cpp | 3 +- CodeGen/src/IrLoweringA64.cpp | 16 +- CodeGen/src/IrLoweringA64.h | 4 +- CodeGen/src/IrLoweringX64.cpp | 14 +- CodeGen/src/IrLoweringX64.h | 4 +- CodeGen/src/IrRegAllocA64.cpp | 15 +- CodeGen/src/IrRegAllocA64.h | 6 +- CodeGen/src/IrRegAllocX64.cpp | 10 +- CodeGen/src/IrTranslation.cpp | 14 +- CodeGen/src/IrUtils.cpp | 38 ++ CodeGen/src/NativeState.cpp | 8 +- Common/include/Luau/ExperimentalFlags.h | 1 + VM/src/lmathlib.cpp | 18 +- VM/src/lvmexecute.cpp | 2 +- tests/Autocomplete.test.cpp | 3 +- tests/CodeAllocator.test.cpp | 3 +- tests/Compiler.test.cpp | 11 +- tests/IrCallWrapperX64.test.cpp | 2 +- tests/IrRegAllocX64.test.cpp | 2 +- tests/Normalize.test.cpp | 65 +++ tests/Subtyping.test.cpp | 326 ++++++------ tests/TypeInfer.operators.test.cpp | 3 +- tests/TypeInfer.refinements.test.cpp | 1 - tests/TypeInfer.tables.test.cpp | 36 ++ tests/TypeInfer.test.cpp | 2 +- tests/Unifier2.test.cpp | 60 ++- tools/faillist.txt | 31 +- 61 files changed, 1250 insertions(+), 581 deletions(-) diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 1d3f20ee..902da0d5 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -101,9 +101,10 @@ struct ConstraintGraphBuilder DcrLogger* logger; - ConstraintGraphBuilder(ModulePtr module, NotNull normalizer, NotNull moduleResolver, NotNull builtinTypes, - NotNull ice, const ScopePtr& globalScope, std::function prepareModuleScope, - DcrLogger* logger, NotNull dfg, std::vector requireCycles); + ConstraintGraphBuilder(ModulePtr module, NotNull normalizer, NotNull moduleResolver, + NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, + std::function prepareModuleScope, DcrLogger* logger, NotNull dfg, + std::vector requireCycles); /** * Fabricates a new free type belonging to a given scope. diff --git a/Analysis/include/Luau/GlobalTypes.h b/Analysis/include/Luau/GlobalTypes.h index 86bfd943..7a34f935 100644 --- a/Analysis/include/Luau/GlobalTypes.h +++ b/Analysis/include/Luau/GlobalTypes.h @@ -23,4 +23,4 @@ struct GlobalTypes ScopePtr globalScope; // shared by all modules }; -} +} // namespace Luau diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h index 642f2b9e..1dbf6b67 100644 --- a/Analysis/include/Luau/Instantiation.h +++ b/Analysis/include/Luau/Instantiation.h @@ -17,8 +17,8 @@ struct TypeCheckLimits; // A substitution which replaces generic types in a given set by free types. struct ReplaceGenerics : Substitution { - ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope, const std::vector& generics, - const std::vector& genericPacks) + ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope, + const std::vector& generics, const std::vector& genericPacks) : Substitution(log, arena) , builtinTypes(builtinTypes) , level(level) @@ -77,6 +77,7 @@ struct Instantiation : Substitution * Instantiation fails only when processing the type causes internal recursion * limits to be exceeded. */ -std::optional instantiate(NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty); +std::optional instantiate( + NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty); } // namespace Luau diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index 70cd8bae..de702d53 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -19,20 +19,18 @@ class TypeIds; class Normalizer; struct NormalizedType; struct NormalizedClassType; +struct NormalizedStringType; struct NormalizedFunctionType; + struct SubtypingResult { - // Did the test succeed? bool isSubtype = false; bool isErrorSuppressing = false; bool normalizationTooComplex = false; - // If so, what constraints are implied by this relation? - // If not, what happened? - - void andAlso(const SubtypingResult& other); - void orElse(const SubtypingResult& other); + SubtypingResult& andAlso(const SubtypingResult& other); + SubtypingResult& orElse(const SubtypingResult& other); // Only negates the `isSubtype`. static SubtypingResult negate(const SubtypingResult& result); @@ -47,6 +45,8 @@ struct Subtyping NotNull normalizer; NotNull iceReporter; + NotNull scope; + enum class Variance { Covariant, @@ -72,6 +72,12 @@ struct Subtyping SeenSet seenTypes; + Subtyping(const Subtyping&) = delete; + Subtyping& operator=(const Subtyping&) = delete; + + Subtyping(Subtyping&&) = default; + Subtyping& operator=(Subtyping&&) = default; + // TODO cache // TODO cyclic types // TODO recursion limits @@ -80,43 +86,61 @@ struct Subtyping SubtypingResult isSubtype(TypePackId subTy, TypePackId superTy); private: - SubtypingResult isSubtype_(TypeId subTy, TypeId superTy); - SubtypingResult isSubtype_(TypePackId subTy, TypePackId superTy); + SubtypingResult isCovariantWith(TypeId subTy, TypeId superTy); + SubtypingResult isCovariantWith(TypePackId subTy, TypePackId superTy); template - SubtypingResult isSubtype_(const TryPair& pair); + SubtypingResult isContravariantWith(SubTy&& subTy, SuperTy&& superTy); - SubtypingResult isSubtype_(TypeId subTy, const UnionType* superUnion); - SubtypingResult isSubtype_(const UnionType* subUnion, TypeId superTy); - SubtypingResult isSubtype_(TypeId subTy, const IntersectionType* superIntersection); - SubtypingResult isSubtype_(const IntersectionType* subIntersection, TypeId superTy); - SubtypingResult isSubtype_(const PrimitiveType* subPrim, const PrimitiveType* superPrim); - SubtypingResult isSubtype_(const SingletonType* subSingleton, const PrimitiveType* superPrim); - SubtypingResult isSubtype_(const SingletonType* subSingleton, const SingletonType* superSingleton); - SubtypingResult isSubtype_(const TableType* subTable, const TableType* superTable); - SubtypingResult isSubtype_(const MetatableType* subMt, const MetatableType* superMt); - SubtypingResult isSubtype_(const MetatableType* subMt, const TableType* superTable); - SubtypingResult isSubtype_(const ClassType* subClass, const ClassType* superClass); - SubtypingResult isSubtype_(const ClassType* subClass, const TableType* superTable); // Actually a class <: shape. - SubtypingResult isSubtype_(const FunctionType* subFunction, const FunctionType* superFunction); - SubtypingResult isSubtype_(const PrimitiveType* subPrim, const TableType* superTable); - SubtypingResult isSubtype_(const SingletonType* subSingleton, const TableType* superTable); + template + SubtypingResult isInvariantWith(SubTy&& subTy, SuperTy&& superTy); - SubtypingResult isSubtype_(const NormalizedType* subNorm, const NormalizedType* superNorm); - SubtypingResult isSubtype_(const NormalizedClassType& subClass, const NormalizedClassType& superClass, const TypeIds& superTables); - SubtypingResult isSubtype_(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction); - SubtypingResult isSubtype_(const TypeIds& subTypes, const TypeIds& superTypes); + template + SubtypingResult isCovariantWith(const TryPair& pair); - SubtypingResult isSubtype_(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic); + template + SubtypingResult isContravariantWith(const TryPair& pair); + + template + SubtypingResult isInvariantWith(const TryPair& pair); + + SubtypingResult isCovariantWith(TypeId subTy, const UnionType* superUnion); + SubtypingResult isCovariantWith(const UnionType* subUnion, TypeId superTy); + SubtypingResult isCovariantWith(TypeId subTy, const IntersectionType* superIntersection); + SubtypingResult isCovariantWith(const IntersectionType* subIntersection, TypeId superTy); + + SubtypingResult isCovariantWith(const NegationType* subNegation, TypeId superTy); + SubtypingResult isCovariantWith(const TypeId subTy, const NegationType* superNegation); + + SubtypingResult isCovariantWith(const PrimitiveType* subPrim, const PrimitiveType* superPrim); + SubtypingResult isCovariantWith(const SingletonType* subSingleton, const PrimitiveType* superPrim); + SubtypingResult isCovariantWith(const SingletonType* subSingleton, const SingletonType* superSingleton); + SubtypingResult isCovariantWith(const TableType* subTable, const TableType* superTable); + SubtypingResult isCovariantWith(const MetatableType* subMt, const MetatableType* superMt); + SubtypingResult isCovariantWith(const MetatableType* subMt, const TableType* superTable); + SubtypingResult isCovariantWith(const ClassType* subClass, const ClassType* superClass); + SubtypingResult isCovariantWith(const ClassType* subClass, const TableType* superTable); + SubtypingResult isCovariantWith(const FunctionType* subFunction, const FunctionType* superFunction); + SubtypingResult isCovariantWith(const PrimitiveType* subPrim, const TableType* superTable); + SubtypingResult isCovariantWith(const SingletonType* subSingleton, const TableType* superTable); + + SubtypingResult isCovariantWith(const NormalizedType* subNorm, const NormalizedType* superNorm); + SubtypingResult isCovariantWith(const NormalizedClassType& subClass, const NormalizedClassType& superClass); + SubtypingResult isCovariantWith(const NormalizedClassType& subClass, const TypeIds& superTables); + SubtypingResult isCovariantWith(const NormalizedStringType& subString, const NormalizedStringType& superString); + SubtypingResult isCovariantWith(const NormalizedStringType& subString, const TypeIds& superTables); + SubtypingResult isCovariantWith(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction); + SubtypingResult isCovariantWith(const TypeIds& subTypes, const TypeIds& superTypes); + + SubtypingResult isCovariantWith(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic); bool bindGeneric(TypeId subTp, TypeId superTp); bool bindGeneric(TypePackId subTp, TypePackId superTp); - template + template TypeId makeAggregateType(const Container& container, TypeId orElse); - [[noreturn]] - void unexpected(TypePackId tp); + [[noreturn]] void unexpected(TypePackId tp); }; } // namespace Luau diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index ffbe3fa0..d43266cb 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -849,6 +849,18 @@ bool isSubclass(const ClassType* cls, const ClassType* parent); Type* asMutable(TypeId ty); +template +bool is(T&& tv) +{ + if (!tv) + return false; + + if constexpr (std::is_same_v && !(std::is_same_v || ...)) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); + + return (get(tv) || ...); +} + template const T* get(TypeId tv) { diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index db81d9cf..aeeab0f8 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -14,7 +14,7 @@ struct DcrLogger; struct TypeCheckLimits; struct UnifierSharedState; -void check(NotNull builtinTypes, NotNull sharedState, NotNull limits, DcrLogger* logger, const SourceModule& sourceModule, - Module* module); +void check(NotNull builtinTypes, NotNull sharedState, NotNull limits, DcrLogger* logger, + const SourceModule& sourceModule, Module* module); } // namespace Luau diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 9699c4ae..4d41926c 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -104,7 +104,8 @@ ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId // Similar to `std::optional>`, but whose `sizeof()` is the same as `std::pair` // and cooperates with C++'s `if (auto p = ...)` syntax without the extra fatness of `std::optional`. template -struct TryPair { +struct TryPair +{ A first; B second; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index f7c5c94c..1260ac93 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -105,10 +105,12 @@ struct Unifier * Populate the vector errors with any type errors that may arise. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. */ - void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); + void tryUnify( + TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); private: - void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); + void tryUnify_( + TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy); // Traverse the two types provided and block on any BlockedTypes we find. diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index cf769da3..6d32e03f 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -55,8 +55,8 @@ struct Unifier2 bool unify(TypePackId subTp, TypePackId superTp); std::optional generalize(NotNull scope, TypeId ty); -private: +private: /** * @returns simplify(left | right) */ @@ -72,4 +72,4 @@ private: OccursCheckResult occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); }; -} +} // namespace Luau diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index a84fb48c..28dfffbf 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -10,6 +10,7 @@ LUAU_FASTINT(LuauVisitRecursionLimit) LUAU_FASTFLAG(LuauBoundLazyTypes2) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauReadWriteProperties) namespace Luau @@ -220,7 +221,21 @@ struct GenericTypeVisitor traverse(btv->boundTo); } else if (auto ftv = get(ty)) - visit(ty, *ftv); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (visit(ty, *ftv)) + { + LUAU_ASSERT(ftv->lowerBound); + traverse(ftv->lowerBound); + + LUAU_ASSERT(ftv->upperBound); + traverse(ftv->upperBound); + } + } + else + visit(ty, *ftv); + } else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 4a5638b1..3eba2e12 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -282,20 +282,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Property, - type, - prop.deprecated, - isWrongIndexer(type), - typeCorrect, - containingClass, - &prop, - prop.documentationSymbol, - {}, - parens, - {}, - indexType == PropIndexType::Colon - }; + result[name] = AutocompleteEntry{AutocompleteEntryKind::Property, type, prop.deprecated, isWrongIndexer(type), typeCorrect, + containingClass, &prop, prop.documentationSymbol, {}, parens, {}, indexType == PropIndexType::Colon}; } } }; @@ -606,7 +594,7 @@ std::optional getLocalTypeInScopeAt(const Module& module, Position posit return {}; } -template +template static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) { ToStringOptions opts; @@ -1418,7 +1406,7 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func name = "a" + std::to_string(argIdx); if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) - result += name + ": " + *type; + result += name + ": " + *type; else result += name; } @@ -1434,7 +1422,7 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) varArgType = std::move(res); } - + if (varArgType) result += "...: " + *varArgType; else @@ -1461,7 +1449,8 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func return result; } -static std::optional makeAnonymousAutofilled(const ModulePtr& module, Position position, const AstNode* node, const std::vector& ancestry) +static std::optional makeAnonymousAutofilled( + const ModulePtr& module, Position position, const AstNode* node, const std::vector& ancestry) { const AstExprCall* call = node->as(); if (!call && ancestry.size() > 1) @@ -1498,10 +1487,10 @@ static std::optional makeAnonymousAutofilled(const ModulePtr& auto [args, tail] = flatten(outerFunction->argTypes); if (argument < args.size()) argType = args[argument]; - + if (!argType) return std::nullopt; - + TypeId followed = follow(*argType); const FunctionType* type = get(followed); if (!type) diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index ad96527e..c28e1678 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -780,7 +780,7 @@ void TypeCloner::operator()(const UnionType& t) // We're just using this FreeType as a placeholder until we've finished // cloning the parts of this union so it is okay that its bounds are // nullptr. We'll never indirect them. - TypeId result = dest.addType(FreeType{nullptr, /*lowerBound*/nullptr, /*upperBound*/nullptr}); + TypeId result = dest.addType(FreeType{nullptr, /*lowerBound*/ nullptr, /*upperBound*/ nullptr}); seenTypes[typeId] = result; std::vector options; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index d2413adb..ae143ca5 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -420,17 +420,17 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo { switch (shouldSuppressErrors(normalizer, ty)) { - case ErrorSuppression::DoNotSuppress: - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; - break; - case ErrorSuppression::Suppress: - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; - ty = simplifyUnion(builtinTypes, arena, ty, builtinTypes->errorType).result; - break; - case ErrorSuppression::NormalizationFailed: - reportError(location, NormalizationTooComplex{}); - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; - break; + case ErrorSuppression::DoNotSuppress: + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + break; + case ErrorSuppression::Suppress: + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + ty = simplifyUnion(builtinTypes, arena, ty, builtinTypes->errorType).result; + break; + case ErrorSuppression::NormalizationFailed: + reportError(location, NormalizationTooComplex{}); + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + break; } } } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index fd0c1c00..6faea1d2 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -430,6 +430,35 @@ bool ConstraintSolver::isDone() return unsolvedConstraints.empty(); } +namespace +{ + +struct TypeAndLocation +{ + TypeId typeId; + Location location; +}; + +struct FreeTypeSearcher : TypeOnceVisitor +{ + std::deque* result; + Location location; + + FreeTypeSearcher(std::deque* result, Location location) + : result(result) + , location(location) + { + } + + bool visit(TypeId ty, const FreeType&) override + { + result->push_back({ty, location}); + return false; + } +}; + +} // namespace + void ConstraintSolver::finalizeModule() { Anyification a{arena, rootScope, builtinTypes, &iceReporter, builtinTypes->anyType, builtinTypes->anyTypePack}; @@ -446,12 +475,28 @@ void ConstraintSolver::finalizeModule() Unifier2 u2{NotNull{arena}, builtinTypes, NotNull{&iceReporter}}; + std::deque queue; for (auto& [name, binding] : rootScope->bindings) + queue.push_back({binding.typeId, binding.location}); + + DenseHashSet seen{nullptr}; + + while (!queue.empty()) { - auto generalizedTy = u2.generalize(rootScope, binding.typeId); - if (generalizedTy) - binding.typeId = *generalizedTy; - else + TypeAndLocation binding = queue.front(); + queue.pop_front(); + + TypeId ty = follow(binding.typeId); + + if (seen.find(ty)) + continue; + seen.insert(ty); + + FreeTypeSearcher fts{&queue, binding.location}; + fts.traverse(ty); + + auto result = u2.generalize(rootScope, ty); + if (!result) reportError(CodeTooComplex{}, binding.location); } } @@ -2642,20 +2687,14 @@ ErrorVec ConstraintSolver::unify(NotNull scope, Location location, TypeId ErrorVec ConstraintSolver::unify(NotNull scope, Location location, TypePackId subPack, TypePackId superPack) { - UnifierSharedState sharedState{&iceReporter}; - Unifier u{normalizer, scope, Location{}, Covariant}; - u.enableNewSolver(); + Unifier2 u{arena, builtinTypes, NotNull{&iceReporter}}; - u.tryUnify(subPack, superPack); + u.unify(subPack, superPack); - const auto [changedTypes, changedPacks] = u.log.getChanges(); + unblock(subPack, Location{}); + unblock(superPack, Location{}); - u.log.commit(); - - unblock(changedTypes, Location{}); - unblock(changedPacks, Location{}); - - return std::move(u.errors); + return {}; } NotNull ConstraintSolver::pushConstraint(NotNull scope, const Location& location, ConstraintV cv) diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp index 84505071..5b614cef 100644 --- a/Analysis/src/Differ.cpp +++ b/Analysis/src/Differ.cpp @@ -117,9 +117,7 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea case DiffError::Kind::Normal: { checkNonMissingPropertyLeavesHaveNulloptTableProperty(); - return pathStr + conditionalNewline - + "has type" + conditionalNewline - + conditionalIndent + Luau::toString(*leaf.ty); + return pathStr + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty); } case DiffError::Kind::MissingTableProperty: { @@ -127,17 +125,14 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea { if (!leaf.tableProperty.has_value()) throw InternalCompilerError{"leaf.tableProperty is nullopt"}; - return pathStr + "." + *leaf.tableProperty + conditionalNewline - + "has type" + conditionalNewline - + conditionalIndent + Luau::toString(*leaf.ty); + return pathStr + "." + *leaf.tableProperty + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + + Luau::toString(*leaf.ty); } else if (otherLeaf.ty.has_value()) { if (!otherLeaf.tableProperty.has_value()) throw InternalCompilerError{"otherLeaf.tableProperty is nullopt"}; - return pathStr + conditionalNewline - + "is missing the property" + conditionalNewline - + conditionalIndent + *otherLeaf.tableProperty; + return pathStr + conditionalNewline + "is missing the property" + conditionalNewline + conditionalIndent + *otherLeaf.tableProperty; } throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; } @@ -148,15 +143,11 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea { if (!leaf.unionIndex.has_value()) throw InternalCompilerError{"leaf.unionIndex is nullopt"}; - return pathStr + conditionalNewline - + "is a union containing type" + conditionalNewline - + conditionalIndent + Luau::toString(*leaf.ty); + return pathStr + conditionalNewline + "is a union containing type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty); } else if (otherLeaf.ty.has_value()) { - return pathStr + conditionalNewline - + "is a union missing type" + conditionalNewline - + conditionalIndent + Luau::toString(*otherLeaf.ty); + return pathStr + conditionalNewline + "is a union missing type" + conditionalNewline + conditionalIndent + Luau::toString(*otherLeaf.ty); } throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; } @@ -169,15 +160,13 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea { if (!leaf.unionIndex.has_value()) throw InternalCompilerError{"leaf.unionIndex is nullopt"}; - return pathStr + conditionalNewline - + "is an intersection containing type" + conditionalNewline - + conditionalIndent + Luau::toString(*leaf.ty); + return pathStr + conditionalNewline + "is an intersection containing type" + conditionalNewline + conditionalIndent + + Luau::toString(*leaf.ty); } else if (otherLeaf.ty.has_value()) { - return pathStr + conditionalNewline - + "is an intersection missing type" + conditionalNewline - + conditionalIndent + Luau::toString(*otherLeaf.ty); + return pathStr + conditionalNewline + "is an intersection missing type" + conditionalNewline + conditionalIndent + + Luau::toString(*otherLeaf.ty); } throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; } @@ -185,15 +174,13 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea { if (!leaf.minLength.has_value()) throw InternalCompilerError{"leaf.minLength is nullopt"}; - return pathStr + conditionalNewline - + "takes " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " arguments"; + return pathStr + conditionalNewline + "takes " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " arguments"; } case DiffError::Kind::LengthMismatchInFnRets: { if (!leaf.minLength.has_value()) throw InternalCompilerError{"leaf.minLength is nullopt"}; - return pathStr + conditionalNewline - + "returns " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " values"; + return pathStr + conditionalNewline + "returns " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " values"; } default: { @@ -249,17 +236,15 @@ std::string DiffError::toString(bool multiLine) const case DiffError::Kind::IncompatibleGeneric: { std::string diffPathStr{diffPath.toString(true)}; - return "DiffError: these two types are not equal because the left generic at" + conditionalNewline - + conditionalIndent + leftRootName + diffPathStr + conditionalNewline - + "cannot be the same type parameter as the right generic at" + conditionalNewline - + conditionalIndent + rightRootName + diffPathStr; + return "DiffError: these two types are not equal because the left generic at" + conditionalNewline + conditionalIndent + leftRootName + + diffPathStr + conditionalNewline + "cannot be the same type parameter as the right generic at" + conditionalNewline + + conditionalIndent + rightRootName + diffPathStr; } default: { - return "DiffError: these two types are not equal because the left type at" + conditionalNewline - + conditionalIndent + toStringALeaf(leftRootName, left, right, multiLine) + "," + conditionalNewline + - "while the right type at" + conditionalNewline - + conditionalIndent + toStringALeaf(rightRootName, right, left, multiLine); + return "DiffError: these two types are not equal because the left type at" + conditionalNewline + conditionalIndent + + toStringALeaf(leftRootName, left, right, multiLine) + "," + conditionalNewline + "while the right type at" + conditionalNewline + + conditionalIndent + toStringALeaf(rightRootName, right, left, multiLine); } } } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b71a9354..677458ab 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1289,7 +1289,8 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectortimeout || result->cancelled) { - // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending types + // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending + // types ScopePtr moduleScope = result->getModuleScope(); moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); diff --git a/Analysis/src/GlobalTypes.cpp b/Analysis/src/GlobalTypes.cpp index 9e26a2e3..654cfa5d 100644 --- a/Analysis/src/GlobalTypes.cpp +++ b/Analysis/src/GlobalTypes.cpp @@ -31,4 +31,4 @@ GlobalTypes::GlobalTypes(NotNull builtinTypes) } } -} +} // namespace Luau diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 0c8bc1ec..e74ece06 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -174,7 +174,8 @@ struct Replacer : Substitution } }; -std::optional instantiate(NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty) +std::optional instantiate( + NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty) { ty = follow(ty); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index e8cd7dbd..d4a16a75 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -2791,8 +2791,8 @@ static void lintComments(LintContext& context, const std::vector& ho else if (first == "native") { if (space != std::string::npos) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "native directive has extra symbols at the end of the line"); + emitWarning( + context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line"); } else { diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index bcad75b0..ac58c7f6 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -176,7 +176,7 @@ const NormalizedStringType NormalizedStringType::never; bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) { - if (subStr.isUnion() && superStr.isUnion()) + if (subStr.isUnion() && (superStr.isUnion() && !superStr.isNever())) { for (auto [name, ty] : subStr.singletons) { @@ -1983,18 +1983,68 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) { + /* There are 9 cases to worry about here + Normalized Left | Normalized Right + C1 string | string ===> trivial + C2 string - {u_1,..} | string ===> trivial + C3 {u_1, ..} | string ===> trivial + C4 string | string - {v_1, ..} ===> string - {v_1, ..} + C5 string - {u_1,..} | string - {v_1, ..} ===> string - ({u_s} U {v_s}) + C6 {u_1, ..} | string - {v_1, ..} ===> {u_s} - {v_s} + C7 string | {v_1, ..} ===> {v_s} + C8 string - {u_1,..} | {v_1, ..} ===> {v_s} - {u_s} + C9 {u_1, ..} | {v_1, ..} ===> {u_s} ∩ {v_s} + */ + // Case 1,2,3 if (there.isString()) return; - if (here.isString()) - here.resetToNever(); - - for (auto it = here.singletons.begin(); it != here.singletons.end();) + // Case 4, Case 7 + else if (here.isString()) { - if (there.singletons.count(it->first)) - it++; - else - it = here.singletons.erase(it); + here.singletons.clear(); + for (const auto& [key, type] : there.singletons) + here.singletons[key] = type; + here.isCofinite = here.isCofinite && there.isCofinite; } + // Case 5 + else if (here.isIntersection() && there.isIntersection()) + { + here.isCofinite = true; + for (const auto& [key, type] : there.singletons) + here.singletons[key] = type; + } + // Case 6 + else if (here.isUnion() && there.isIntersection()) + { + here.isCofinite = false; + for (const auto& [key, _] : there.singletons) + here.singletons.erase(key); + } + // Case 8 + else if (here.isIntersection() && there.isUnion()) + { + here.isCofinite = false; + std::map result(there.singletons); + for (const auto& [key, _] : here.singletons) + result.erase(key); + here.singletons = result; + } + // Case 9 + else if (here.isUnion() && there.isUnion()) + { + here.isCofinite = false; + std::map result; + result.insert(here.singletons.begin(), here.singletons.end()); + result.insert(there.singletons.begin(), there.singletons.end()); + for (auto it = result.begin(); it != result.end();) + if (!here.singletons.count(it->first) || !there.singletons.count(it->first)) + it = result.erase(it); + else + ++it; + here.singletons = result; + } + else + LUAU_ASSERT(0 && "Internal Error - unrecognized case"); } std::optional Normalizer::intersectionOfTypePacks(TypePackId here, TypePackId there) diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index 8012bac7..e216d623 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -5,6 +5,7 @@ #include "Luau/Common.h" #include "Luau/Error.h" #include "Luau/Normalize.h" +#include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" #include "Luau/Type.h" @@ -28,12 +29,12 @@ struct VarianceFlipper { switch (oldValue) { - case Subtyping::Variance::Covariant: - *variance = Subtyping::Variance::Contravariant; - break; - case Subtyping::Variance::Contravariant: - *variance = Subtyping::Variance::Covariant; - break; + case Subtyping::Variance::Covariant: + *variance = Subtyping::Variance::Contravariant; + break; + case Subtyping::Variance::Contravariant: + *variance = Subtyping::Variance::Covariant; + break; } } @@ -43,19 +44,21 @@ struct VarianceFlipper } }; -void SubtypingResult::andAlso(const SubtypingResult& other) +SubtypingResult& SubtypingResult::andAlso(const SubtypingResult& other) { isSubtype &= other.isSubtype; // `|=` is intentional here, we want to preserve error related flags. isErrorSuppressing |= other.isErrorSuppressing; normalizationTooComplex |= other.normalizationTooComplex; + return *this; } -void SubtypingResult::orElse(const SubtypingResult& other) +SubtypingResult& SubtypingResult::orElse(const SubtypingResult& other) { isSubtype |= other.isSubtype; isErrorSuppressing |= other.isErrorSuppressing; normalizationTooComplex |= other.normalizationTooComplex; + return *this; } SubtypingResult SubtypingResult::negate(const SubtypingResult& result) @@ -88,9 +91,9 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy) mappedGenerics.clear(); mappedGenericPacks.clear(); - SubtypingResult result = isSubtype_(subTy, superTy); + SubtypingResult result = isCovariantWith(subTy, superTy); - for (const auto& [subTy, bounds]: mappedGenerics) + for (const auto& [subTy, bounds] : mappedGenerics) { const auto& lb = bounds.lowerBound; const auto& ub = bounds.upperBound; @@ -98,7 +101,7 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy) TypeId lowerBound = makeAggregateType(lb, builtinTypes->neverType); TypeId upperBound = makeAggregateType(ub, builtinTypes->unknownType); - result.andAlso(isSubtype_(lowerBound, upperBound)); + result.andAlso(isCovariantWith(lowerBound, upperBound)); } return result; @@ -106,7 +109,7 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy) SubtypingResult Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) { - return isSubtype_(subTp, superTp); + return isCovariantWith(subTp, superTp); } namespace @@ -119,16 +122,17 @@ struct SeenSetPopper SeenSetPopper(Subtyping::SeenSet* seenTypes, std::pair pair) : seenTypes(seenTypes) , pair(pair) - {} + { + } ~SeenSetPopper() { seenTypes->erase(pair); } }; -} +} // namespace -SubtypingResult Subtyping::isSubtype_(TypeId subTy, TypeId superTy) +SubtypingResult Subtyping::isCovariantWith(TypeId subTy, TypeId superTy) { subTy = follow(subTy); superTy = follow(superTy); @@ -146,19 +150,27 @@ SubtypingResult Subtyping::isSubtype_(TypeId subTy, TypeId superTy) SeenSetPopper ssp{&seenTypes, typePair}; + // Within the scope to which a generic belongs, that generic should be + // tested as though it were its upper bounds. We do not yet support bounded + // generics, so the upper bound is always unknown. + if (auto subGeneric = get(subTy); subGeneric && subsumes(subGeneric->scope, scope)) + return isCovariantWith(builtinTypes->unknownType, superTy); + if (auto superGeneric = get(superTy); superGeneric && subsumes(superGeneric->scope, scope)) + return isCovariantWith(subTy, builtinTypes->unknownType); + if (auto subUnion = get(subTy)) - return isSubtype_(subUnion, superTy); + return isCovariantWith(subUnion, superTy); else if (auto superUnion = get(superTy)) - return isSubtype_(subTy, superUnion); + return isCovariantWith(subTy, superUnion); else if (auto superIntersection = get(superTy)) - return isSubtype_(subTy, superIntersection); + return isCovariantWith(subTy, superIntersection); else if (auto subIntersection = get(subTy)) { - SubtypingResult result = isSubtype_(subIntersection, superTy); + SubtypingResult result = isCovariantWith(subIntersection, superTy); if (result.isSubtype || result.isErrorSuppressing || result.normalizationTooComplex) return result; else - return isSubtype_(normalizer->normalize(subTy), normalizer->normalize(superTy)); + return isCovariantWith(normalizer->normalize(subTy), normalizer->normalize(superTy)); } else if (get(superTy)) return {true}; // This is always true. @@ -166,14 +178,12 @@ SubtypingResult Subtyping::isSubtype_(TypeId subTy, TypeId superTy) { // any = unknown | error, so we rewrite this to match. // As per TAPL: A | B <: T iff A <: T && B <: T - SubtypingResult result = isSubtype_(builtinTypes->unknownType, superTy); - result.andAlso(isSubtype_(builtinTypes->errorType, superTy)); - return result; + return isCovariantWith(builtinTypes->unknownType, superTy).andAlso(isCovariantWith(builtinTypes->errorType, superTy)); } else if (get(superTy)) { - LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. - LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. bool errorSuppressing = get(subTy); @@ -185,6 +195,12 @@ SubtypingResult Subtyping::isSubtype_(TypeId subTy, TypeId superTy) return {false, true}; else if (get(subTy)) return {false, true}; + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p.first->ty, p.second->ty); + else if (auto subNegation = get(subTy)) + return isCovariantWith(subNegation, superTy); + else if (auto superNegation = get(superTy)) + return isCovariantWith(subTy, superNegation); else if (auto subGeneric = get(subTy); subGeneric && variance == Variance::Covariant) { bool ok = bindGeneric(subTy, superTy); @@ -196,32 +212,32 @@ SubtypingResult Subtyping::isSubtype_(TypeId subTy, TypeId superTy) return {ok}; } else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); else if (auto p = get2(subTy, superTy)) - return isSubtype_(p); + return isCovariantWith(p); return {false}; } -SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) +SubtypingResult Subtyping::isCovariantWith(TypePackId subTp, TypePackId superTp) { subTp = follow(subTp); superTp = follow(superTp); @@ -241,7 +257,7 @@ SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) for (size_t i = 0; i < headSize; ++i) { - results.push_back(isSubtype_(subHead[i], superHead[i])); + results.push_back(isCovariantWith(subHead[i], superHead[i])); if (!results.back().isSubtype) return {false}; } @@ -255,7 +271,7 @@ SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) if (auto vt = get(*subTail)) { for (size_t i = headSize; i < superHead.size(); ++i) - results.push_back(isSubtype_(vt->ty, superHead[i])); + results.push_back(isCovariantWith(vt->ty, superHead[i])); } else if (auto gt = get(*subTail)) { @@ -266,11 +282,11 @@ SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) // (X) -> () <: (T) -> () // Possible optimization: If headSize == 0 then we can just use subTp as-is. - std::vector headSlice(begin(superHead), end(superHead) + headSize); + std::vector headSlice(begin(superHead), begin(superHead) + headSize); TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); if (TypePackId* other = mappedGenericPacks.find(*subTail)) - results.push_back(isSubtype_(*other, superTailPack)); + results.push_back(isCovariantWith(*other, superTailPack)); else mappedGenericPacks.try_insert(*subTail, superTailPack); @@ -300,7 +316,7 @@ SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) if (auto vt = get(*superTail)) { for (size_t i = headSize; i < subHead.size(); ++i) - results.push_back(isSubtype_(subHead[i], vt->ty)); + results.push_back(isCovariantWith(subHead[i], vt->ty)); } else if (auto gt = get(*superTail)) { @@ -311,11 +327,11 @@ SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) // (X...) -> () <: (T) -> () // Possible optimization: If headSize == 0 then we can just use subTp as-is. - std::vector headSlice(begin(subHead), end(subHead) + headSize); + std::vector headSlice(begin(subHead), begin(subHead) + headSize); TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); if (TypePackId* other = mappedGenericPacks.find(*superTail)) - results.push_back(isSubtype_(*other, subTailPack)); + results.push_back(isCovariantWith(*other, subTailPack)); else mappedGenericPacks.try_insert(*superTail, subTailPack); @@ -344,7 +360,7 @@ SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) { if (auto p = get2(*subTail, *superTail)) { - results.push_back(isSubtype_(p)); + results.push_back(isCovariantWith(p)); } else if (auto p = get2(*subTail, *superTail)) { @@ -380,7 +396,8 @@ SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) } } else - iceReporter->ice(format("Subtyping::isSubtype got unexpected type packs %s and %s", toString(*subTail).c_str(), toString(*superTail).c_str())); + iceReporter->ice( + format("Subtyping::isSubtype got unexpected type packs %s and %s", toString(*subTail).c_str(), toString(*superTail).c_str())); } else if (subTail) { @@ -428,9 +445,33 @@ SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) } template -SubtypingResult Subtyping::isSubtype_(const TryPair& pair) +SubtypingResult Subtyping::isContravariantWith(SubTy&& subTy, SuperTy&& superTy) { - return isSubtype_(pair.first, pair.second); + return isCovariantWith(superTy, subTy); +} + +template +SubtypingResult Subtyping::isInvariantWith(SubTy&& subTy, SuperTy&& superTy) +{ + return isCovariantWith(subTy, superTy).andAlso(isContravariantWith(subTy, superTy)); +} + +template +SubtypingResult Subtyping::isCovariantWith(const TryPair& pair) +{ + return isCovariantWith(pair.first, pair.second); +} + +template +SubtypingResult Subtyping::isContravariantWith(const TryPair& pair) +{ + return isCovariantWith(pair.second, pair.first); +} + +template +SubtypingResult Subtyping::isInvariantWith(const TryPair& pair) +{ + return isCovariantWith(pair).andAlso(isContravariantWith(pair)); } /* @@ -464,48 +505,219 @@ SubtypingResult Subtyping::isSubtype_(const TryPair subtypings; for (TypeId ty : superUnion) - subtypings.push_back(isSubtype_(subTy, ty)); + subtypings.push_back(isCovariantWith(subTy, ty)); return SubtypingResult::any(subtypings); } -SubtypingResult Subtyping::isSubtype_(const UnionType* subUnion, TypeId superTy) +SubtypingResult Subtyping::isCovariantWith(const UnionType* subUnion, TypeId superTy) { // As per TAPL: A | B <: T iff A <: T && B <: T std::vector subtypings; for (TypeId ty : subUnion) - subtypings.push_back(isSubtype_(ty, superTy)); + subtypings.push_back(isCovariantWith(ty, superTy)); return SubtypingResult::all(subtypings); } -SubtypingResult Subtyping::isSubtype_(TypeId subTy, const IntersectionType* superIntersection) +SubtypingResult Subtyping::isCovariantWith(TypeId subTy, const IntersectionType* superIntersection) { // As per TAPL: T <: A & B iff T <: A && T <: B std::vector subtypings; for (TypeId ty : superIntersection) - subtypings.push_back(isSubtype_(subTy, ty)); + subtypings.push_back(isCovariantWith(subTy, ty)); return SubtypingResult::all(subtypings); } -SubtypingResult Subtyping::isSubtype_(const IntersectionType* subIntersection, TypeId superTy) +SubtypingResult Subtyping::isCovariantWith(const IntersectionType* subIntersection, TypeId superTy) { // As per TAPL: A & B <: T iff A <: T || B <: T std::vector subtypings; for (TypeId ty : subIntersection) - subtypings.push_back(isSubtype_(ty, superTy)); + subtypings.push_back(isCovariantWith(ty, superTy)); return SubtypingResult::any(subtypings); } -SubtypingResult Subtyping::isSubtype_(const PrimitiveType* subPrim, const PrimitiveType* superPrim) +SubtypingResult Subtyping::isCovariantWith(const NegationType* subNegation, TypeId superTy) +{ + TypeId negatedTy = follow(subNegation->ty); + + // In order to follow a consistent codepath, rather than folding the + // isCovariantWith test down to its conclusion here, we test the subtyping test + // of the result of negating the type for never, unknown, any, and error. + if (is(negatedTy)) + { + // ¬never ~ unknown + return isCovariantWith(builtinTypes->unknownType, superTy); + } + else if (is(negatedTy)) + { + // ¬unknown ~ never + return isCovariantWith(builtinTypes->neverType, superTy); + } + else if (is(negatedTy)) + { + // ¬any ~ any + return isCovariantWith(negatedTy, superTy); + } + else if (auto u = get(negatedTy)) + { + // ¬(A ∪ B) ~ ¬A ∩ ¬B + // follow intersection rules: A & B <: T iff A <: T && B <: T + std::vector subtypings; + + for (TypeId ty : u) + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(&negatedTmp, superTy)); + } + + return SubtypingResult::all(subtypings); + } + else if (auto i = get(negatedTy)) + { + // ¬(A ∩ B) ~ ¬A ∪ ¬B + // follow union rules: A | B <: T iff A <: T || B <: T + std::vector subtypings; + + for (TypeId ty : i) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(negatedPart->ty, superTy)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(&negatedTmp, superTy)); + } + } + + return SubtypingResult::any(subtypings); + } + else if (is(negatedTy)) + { + iceReporter->ice("attempting to negate a non-testable type"); + } + // negating a different subtype will get you a very wide type that's not a + // subtype of other stuff. + else + { + return {false}; + } +} + +SubtypingResult Subtyping::isCovariantWith(const TypeId subTy, const NegationType* superNegation) +{ + TypeId negatedTy = follow(superNegation->ty); + + if (is(negatedTy)) + { + // ¬never ~ unknown + return isCovariantWith(subTy, builtinTypes->unknownType); + } + else if (is(negatedTy)) + { + // ¬unknown ~ never + return isCovariantWith(subTy, builtinTypes->neverType); + } + else if (is(negatedTy)) + { + // ¬any ~ any + return isSubtype(subTy, negatedTy); + } + else if (auto u = get(negatedTy)) + { + // ¬(A ∪ B) ~ ¬A ∩ ¬B + // follow intersection rules: A & B <: T iff A <: T && B <: T + std::vector subtypings; + + for (TypeId ty : u) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(subTy, negatedPart->ty)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(subTy, &negatedTmp)); + } + } + + return SubtypingResult::all(subtypings); + } + else if (auto i = get(negatedTy)) + { + // ¬(A ∩ B) ~ ¬A ∪ ¬B + // follow union rules: A | B <: T iff A <: T || B <: T + std::vector subtypings; + + for (TypeId ty : i) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(subTy, negatedPart->ty)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(subTy, &negatedTmp)); + } + } + + return SubtypingResult::any(subtypings); + } + else if (auto p = get2(subTy, negatedTy)) + { + // number <: ¬boolean + // number type != p.second->type}; + } + else if (auto p = get2(subTy, negatedTy)) + { + // "foo" (p.first) && p.second->type == PrimitiveType::String) + return {false}; + // false (p.first) && p.second->type == PrimitiveType::Boolean) + return {false}; + // other cases are true + else + return {true}; + } + else if (auto p = get2(subTy, negatedTy)) + { + if (p.first->type == PrimitiveType::String && get(p.second)) + return {false}; + else if (p.first->type == PrimitiveType::Boolean && get(p.second)) + return {false}; + else + return {true}; + } + // the top class type is not actually a primitive type, so the negation of + // any one of them includes the top class type. + else if (auto p = get2(subTy, negatedTy)) + return {true}; + else if (auto p = get(negatedTy); p && is(subTy)) + return {p->type != PrimitiveType::Table}; + else if (auto p = get2(subTy, negatedTy)) + return {p.second->type != PrimitiveType::Function}; + else if (auto p = get2(subTy, negatedTy)) + return {*p.first != *p.second}; + else if (auto p = get2(subTy, negatedTy)) + return SubtypingResult::negate(isCovariantWith(p.first, p.second)); + else if (get2(subTy, negatedTy)) + return {true}; + else if (is(negatedTy)) + iceReporter->ice("attempting to negate a non-testable type"); + + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith(const PrimitiveType* subPrim, const PrimitiveType* superPrim) { return {subPrim->type == superPrim->type}; } -SubtypingResult Subtyping::isSubtype_(const SingletonType* subSingleton, const PrimitiveType* superPrim) +SubtypingResult Subtyping::isCovariantWith(const SingletonType* subSingleton, const PrimitiveType* superPrim) { if (get(subSingleton) && superPrim->type == PrimitiveType::String) return {true}; @@ -515,24 +727,20 @@ SubtypingResult Subtyping::isSubtype_(const SingletonType* subSingleton, const P return {false}; } -SubtypingResult Subtyping::isSubtype_(const SingletonType* subSingleton, const SingletonType* superSingleton) +SubtypingResult Subtyping::isCovariantWith(const SingletonType* subSingleton, const SingletonType* superSingleton) { return {*subSingleton == *superSingleton}; } -SubtypingResult Subtyping::isSubtype_(const TableType* subTable, const TableType* superTable) +SubtypingResult Subtyping::isCovariantWith(const TableType* subTable, const TableType* superTable) { SubtypingResult result{true}; - for (const auto& [name, prop]: superTable->props) + for (const auto& [name, prop] : superTable->props) { auto it = subTable->props.find(name); if (it != subTable->props.end()) - { - // Table properties are invariant - result.andAlso(isSubtype(it->second.type(), prop.type())); - result.andAlso(isSubtype(prop.type(), it->second.type())); - } + result.andAlso(isInvariantWith(prop.type(), it->second.type())); else return SubtypingResult{false}; } @@ -540,17 +748,18 @@ SubtypingResult Subtyping::isSubtype_(const TableType* subTable, const TableType return result; } -SubtypingResult Subtyping::isSubtype_(const MetatableType* subMt, const MetatableType* superMt) +SubtypingResult Subtyping::isCovariantWith(const MetatableType* subMt, const MetatableType* superMt) { return SubtypingResult::all({ - isSubtype_(subMt->table, superMt->table), - isSubtype_(subMt->metatable, superMt->metatable), + isCovariantWith(subMt->table, superMt->table), + isCovariantWith(subMt->metatable, superMt->metatable), }); } -SubtypingResult Subtyping::isSubtype_(const MetatableType* subMt, const TableType* superTable) +SubtypingResult Subtyping::isCovariantWith(const MetatableType* subMt, const TableType* superTable) { - if (auto subTable = get(subMt->table)) { + if (auto subTable = get(subMt->table)) + { // Metatables cannot erase properties from the table they're attached to, so // the subtyping rule for this is just if the table component is a subtype // of the supertype table. @@ -560,7 +769,7 @@ SubtypingResult Subtyping::isSubtype_(const MetatableType* subMt, const TableTyp // that the metatable isn't a subtype of the table, even though they have // compatible properties/shapes. We'll revisit this later when we have a // better understanding of how important this is. - return isSubtype_(subTable, superTable); + return isCovariantWith(subTable, superTable); } else { @@ -569,23 +778,19 @@ SubtypingResult Subtyping::isSubtype_(const MetatableType* subMt, const TableTyp } } -SubtypingResult Subtyping::isSubtype_(const ClassType* subClass, const ClassType* superClass) +SubtypingResult Subtyping::isCovariantWith(const ClassType* subClass, const ClassType* superClass) { return {isSubclass(subClass, superClass)}; } -SubtypingResult Subtyping::isSubtype_(const ClassType* subClass, const TableType* superTable) +SubtypingResult Subtyping::isCovariantWith(const ClassType* subClass, const TableType* superTable) { SubtypingResult result{true}; - for (const auto& [name, prop]: superTable->props) + for (const auto& [name, prop] : superTable->props) { if (auto classProp = lookupClassProp(subClass, name)) - { - // Table properties are invariant - result.andAlso(isSubtype_(classProp->type(), prop.type())); - result.andAlso(isSubtype_(prop.type(), classProp->type())); - } + result.andAlso(isInvariantWith(prop.type(), classProp->type())); else return SubtypingResult{false}; } @@ -593,20 +798,20 @@ SubtypingResult Subtyping::isSubtype_(const ClassType* subClass, const TableType return result; } -SubtypingResult Subtyping::isSubtype_(const FunctionType* subFunction, const FunctionType* superFunction) +SubtypingResult Subtyping::isCovariantWith(const FunctionType* subFunction, const FunctionType* superFunction) { SubtypingResult result; { VarianceFlipper vf{&variance}; - result.orElse(isSubtype_(superFunction->argTypes, subFunction->argTypes)); + result.orElse(isContravariantWith(subFunction->argTypes, superFunction->argTypes)); } - result.andAlso(isSubtype_(subFunction->retTypes, superFunction->retTypes)); + result.andAlso(isCovariantWith(subFunction->retTypes, superFunction->retTypes)); return result; } -SubtypingResult Subtyping::isSubtype_(const PrimitiveType* subPrim, const TableType* superTable) +SubtypingResult Subtyping::isCovariantWith(const PrimitiveType* subPrim, const TableType* superTable) { SubtypingResult result{false}; if (subPrim->type == PrimitiveType::String) @@ -618,7 +823,7 @@ SubtypingResult Subtyping::isSubtype_(const PrimitiveType* subPrim, const TableT if (auto it = mttv->props.find("__index"); it != mttv->props.end()) { if (auto stringTable = get(it->second.type())) - result.orElse(isSubtype_(stringTable, superTable)); + result.orElse(isCovariantWith(stringTable, superTable)); } } } @@ -627,7 +832,7 @@ SubtypingResult Subtyping::isSubtype_(const PrimitiveType* subPrim, const TableT return result; } -SubtypingResult Subtyping::isSubtype_(const SingletonType* subSingleton, const TableType* superTable) +SubtypingResult Subtyping::isCovariantWith(const SingletonType* subSingleton, const TableType* superTable) { SubtypingResult result{false}; if (auto stringleton = get(subSingleton)) @@ -639,7 +844,7 @@ SubtypingResult Subtyping::isSubtype_(const SingletonType* subSingleton, const T if (auto it = mttv->props.find("__index"); it != mttv->props.end()) { if (auto stringTable = get(it->second.type())) - result.orElse(isSubtype_(stringTable, superTable)); + result.orElse(isCovariantWith(stringTable, superTable)); } } } @@ -647,29 +852,27 @@ SubtypingResult Subtyping::isSubtype_(const SingletonType* subSingleton, const T return result; } -SubtypingResult Subtyping::isSubtype_(const NormalizedType* subNorm, const NormalizedType* superNorm) +SubtypingResult Subtyping::isCovariantWith(const NormalizedType* subNorm, const NormalizedType* superNorm) { if (!subNorm || !superNorm) return {false, true, true}; - SubtypingResult result = isSubtype_(subNorm->tops, superNorm->tops); - result.andAlso(isSubtype_(subNorm->booleans, superNorm->booleans)); - result.andAlso(isSubtype_(subNorm->classes, superNorm->classes, superNorm->tables)); - result.andAlso(isSubtype_(subNorm->errors, superNorm->errors)); - result.andAlso(isSubtype_(subNorm->nils, superNorm->nils)); - result.andAlso(isSubtype_(subNorm->numbers, superNorm->numbers)); - result.isSubtype &= Luau::isSubtype(subNorm->strings, superNorm->strings); - // isSubtype_(subNorm->strings, superNorm->tables); - result.andAlso(isSubtype_(subNorm->threads, superNorm->threads)); - result.andAlso(isSubtype_(subNorm->tables, superNorm->tables)); - // isSubtype_(subNorm->tables, superNorm->strings); - // isSubtype_(subNorm->tables, superNorm->classes); - result.andAlso(isSubtype_(subNorm->functions, superNorm->functions)); - // isSubtype_(subNorm->tyvars, superNorm->tyvars); + SubtypingResult result = isCovariantWith(subNorm->tops, superNorm->tops); + result.andAlso(isCovariantWith(subNorm->booleans, superNorm->booleans)); + result.andAlso(isCovariantWith(subNorm->classes, superNorm->classes).orElse(isCovariantWith(subNorm->classes, superNorm->tables))); + result.andAlso(isCovariantWith(subNorm->errors, superNorm->errors)); + result.andAlso(isCovariantWith(subNorm->nils, superNorm->nils)); + result.andAlso(isCovariantWith(subNorm->numbers, superNorm->numbers)); + result.andAlso(isCovariantWith(subNorm->strings, superNorm->strings)); + result.andAlso(isCovariantWith(subNorm->strings, superNorm->tables)); + result.andAlso(isCovariantWith(subNorm->threads, superNorm->threads)); + result.andAlso(isCovariantWith(subNorm->tables, superNorm->tables)); + result.andAlso(isCovariantWith(subNorm->functions, superNorm->functions)); + // isCovariantWith(subNorm->tyvars, superNorm->tyvars); return result; } -SubtypingResult Subtyping::isSubtype_(const NormalizedClassType& subClass, const NormalizedClassType& superClass, const TypeIds& superTables) +SubtypingResult Subtyping::isCovariantWith(const NormalizedClassType& subClass, const NormalizedClassType& superClass) { for (const auto& [subClassTy, _] : subClass.classes) { @@ -677,24 +880,18 @@ SubtypingResult Subtyping::isSubtype_(const NormalizedClassType& subClass, const for (const auto& [superClassTy, superNegations] : superClass.classes) { - result.orElse(isSubtype_(subClassTy, superClassTy)); + result.orElse(isCovariantWith(subClassTy, superClassTy)); if (!result.isSubtype) continue; for (TypeId negation : superNegations) { - result.andAlso(SubtypingResult::negate(isSubtype_(subClassTy, negation))); + result.andAlso(SubtypingResult::negate(isCovariantWith(subClassTy, negation))); if (result.isSubtype) break; } } - if (result.isSubtype) - continue; - - for (TypeId superTableTy : superTables) - result.orElse(isSubtype_(subClassTy, superTableTy)); - if (!result.isSubtype) return result; } @@ -702,17 +899,79 @@ SubtypingResult Subtyping::isSubtype_(const NormalizedClassType& subClass, const return {true}; } -SubtypingResult Subtyping::isSubtype_(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction) +SubtypingResult Subtyping::isCovariantWith(const NormalizedClassType& subClass, const TypeIds& superTables) +{ + for (const auto& [subClassTy, _] : subClass.classes) + { + SubtypingResult result; + + for (TypeId superTableTy : superTables) + result.orElse(isCovariantWith(subClassTy, superTableTy)); + + if (!result.isSubtype) + return result; + } + + return {true}; +} + +SubtypingResult Subtyping::isCovariantWith(const NormalizedStringType& subString, const NormalizedStringType& superString) +{ + bool isSubtype = Luau::isSubtype(subString, superString); + return {isSubtype}; +} + +SubtypingResult Subtyping::isCovariantWith(const NormalizedStringType& subString, const TypeIds& superTables) +{ + if (subString.isNever()) + return {true}; + + if (subString.isCofinite) + { + SubtypingResult result; + for (const auto& superTable : superTables) + { + result.orElse(isCovariantWith(builtinTypes->stringType, superTable)); + if (result.isSubtype) + return result; + } + return result; + } + + // Finite case + // S = s1 | s2 | s3 ... sn <: t1 | t2 | ... | tn + // iff for some ti, S <: ti + // iff for all sj, sj <: ti + for (const auto& superTable : superTables) + { + SubtypingResult result{true}; + for (const auto& [_, subString] : subString.singletons) + { + result.andAlso(isCovariantWith(subString, superTable)); + if (!result.isSubtype) + break; + } + + if (!result.isSubtype) + continue; + else + return result; + } + + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction) { if (subFunction.isNever()) return {true}; else if (superFunction.isTop) return {true}; else - return isSubtype_(subFunction.parts, superFunction.parts); + return isCovariantWith(subFunction.parts, superFunction.parts); } -SubtypingResult Subtyping::isSubtype_(const TypeIds& subTypes, const TypeIds& superTypes) +SubtypingResult Subtyping::isCovariantWith(const TypeIds& subTypes, const TypeIds& superTypes) { std::vector results; @@ -720,15 +979,15 @@ SubtypingResult Subtyping::isSubtype_(const TypeIds& subTypes, const TypeIds& su { results.emplace_back(); for (TypeId superTy : superTypes) - results.back().orElse(isSubtype_(subTy, superTy)); + results.back().orElse(isCovariantWith(subTy, superTy)); } return SubtypingResult::all(results); } -SubtypingResult Subtyping::isSubtype_(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic) +SubtypingResult Subtyping::isCovariantWith(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic) { - return isSubtype_(subVariadic->ty, superVariadic->ty); + return isCovariantWith(subVariadic->ty, superVariadic->ty); } bool Subtyping::bindGeneric(TypeId subTy, TypeId superTy) @@ -772,7 +1031,7 @@ bool Subtyping::bindGeneric(TypePackId subTp, TypePackId superTp) return true; } -template +template TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) { if (container.empty()) diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 04d04470..09851024 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -145,8 +145,7 @@ void StateDot::visitChildren(TypeId ty, int index) startNode(index); startNodeLabel(); - auto go = [&](auto&& t) - { + auto go = [&](auto&& t) { using T = std::decay_t; if constexpr (std::is_same_v) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 65da9dfa..1daabda3 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -242,8 +242,8 @@ struct TypeChecker2 Normalizer normalizer; - TypeChecker2(NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, const SourceModule* sourceModule, - Module* module) + TypeChecker2(NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, + const SourceModule* sourceModule, Module* module) : builtinTypes(builtinTypes) , logger(logger) , limits(limits) @@ -1295,13 +1295,8 @@ struct TypeChecker2 else if (auto assertion = expr->as()) return isLiteral(assertion->expr); - return - expr->is() || - expr->is() || - expr->is() || - expr->is() || - expr->is() || - expr->is(); + return expr->is() || expr->is() || expr->is() || + expr->is() || expr->is() || expr->is(); } static std::unique_ptr buildLiteralPropertiesSet(AstExpr* expr) @@ -1423,7 +1418,7 @@ struct TypeChecker2 LUAU_ASSERT(argOffset == args->head.size()); const Location argLoc = argExprs->empty() ? Location{} // TODO - : argExprs->at(argExprs->size() - 1)->location; + : argExprs->at(argExprs->size() - 1)->location; if (paramIter.tail() && args->tail) { @@ -2686,8 +2681,8 @@ struct TypeChecker2 } }; -void check( - NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, const SourceModule& sourceModule, Module* module) +void check(NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, + const SourceModule& sourceModule, Module* module) { TypeChecker2 typeChecker{builtinTypes, unifierState, limits, logger, &sourceModule, module}; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 7cf05cda..4a27bbde 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -25,6 +25,7 @@ LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) namespace Luau { @@ -2285,7 +2286,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, variance = Invariant; Unifier innerState = makeChildUnifier(); - if (useNewSolver) + if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering) innerState.tryUnify_(prop.type(), superTable->indexer->indexResultType); else { diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 0be6941b..1f3330dc 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -26,7 +26,6 @@ Unifier2::Unifier2(NotNull arena, NotNull builtinTypes, , ice(ice) , recursionLimit(FInt::LuauTypeInferRecursionLimit) { - } bool Unifier2::unify(TypeId subTy, TypeId superTy) @@ -99,10 +98,7 @@ bool Unifier2::unify(TypePackId subTp, TypePackId superTp) return true; } - size_t maxLength = std::max( - flatten(subTp).first.size(), - flatten(superTp).first.size() - ); + size_t maxLength = std::max(flatten(subTp).first.size(), flatten(superTp).first.size()); auto [subTypes, subTail] = extendTypePack(*arena, builtinTypes, subTp, maxLength); auto [superTypes, superTail] = extendTypePack(*arena, builtinTypes, superTp, maxLength); @@ -123,16 +119,25 @@ struct FreeTypeSearcher : TypeVisitor explicit FreeTypeSearcher(NotNull scope) : TypeVisitor(/*skipBoundTypes*/ true) , scope(scope) - {} + { + } - enum { Positive, Negative } polarity = Positive; + enum + { + Positive, + Negative + } polarity = Positive; void flip() { switch (polarity) { - case Positive: polarity = Negative; break; - case Negative: polarity = Positive; break; + case Positive: + polarity = Negative; + break; + case Negative: + polarity = Positive; + break; } } @@ -152,8 +157,12 @@ struct FreeTypeSearcher : TypeVisitor switch (polarity) { - case Positive: positiveTypes.insert(ty); break; - case Negative: negativeTypes.insert(ty); break; + case Positive: + positiveTypes.insert(ty); + break; + case Negative: + negativeTypes.insert(ty); + break; } return true; @@ -180,13 +189,17 @@ struct MutatingGeneralizer : TypeOnceVisitor std::unordered_set negativeTypes; std::vector generics; - MutatingGeneralizer(NotNull builtinTypes, NotNull scope, std::unordered_set positiveTypes, std::unordered_set negativeTypes) + bool isWithinFunction = false; + + MutatingGeneralizer( + NotNull builtinTypes, NotNull scope, std::unordered_set positiveTypes, std::unordered_set negativeTypes) : TypeOnceVisitor(/* skipBoundTypes */ true) , builtinTypes(builtinTypes) , scope(scope) , positiveTypes(std::move(positiveTypes)) , negativeTypes(std::move(negativeTypes)) - {} + { + } static void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) { @@ -211,10 +224,7 @@ struct MutatingGeneralizer : TypeOnceVisitor // FIXME: I bet this function has reentrancy problems option = follow(option); if (option == needle) - { - LUAU_ASSERT(!seen.find(option)); option = replacement; - } // TODO seen set else if (get(option)) @@ -224,7 +234,21 @@ struct MutatingGeneralizer : TypeOnceVisitor } } - bool visit (TypeId ty, const FreeType&) override + bool visit(TypeId ty, const FunctionType& ft) override + { + const bool oldValue = isWithinFunction; + + isWithinFunction = true; + + traverse(ft.argTypes); + traverse(ft.retTypes); + + isWithinFunction = oldValue; + + return false; + } + + bool visit(TypeId ty, const FreeType&) override { const FreeType* ft = get(ty); LUAU_ASSERT(ft); @@ -232,7 +256,8 @@ struct MutatingGeneralizer : TypeOnceVisitor traverse(ft->lowerBound); traverse(ft->upperBound); - // ft is potentially invalid now. + // It is possible for the above traverse() calls to cause ty to be + // transmuted. We must reaquire ft if this happens. ty = follow(ty); ft = get(ty); if (!ft) @@ -252,8 +277,13 @@ struct MutatingGeneralizer : TypeOnceVisitor if (!hasLowerBound && !hasUpperBound) { - emplaceType(asMutable(ty), scope); - generics.push_back(ty); + if (isWithinFunction) + { + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + else + emplaceType(asMutable(ty), builtinTypes->unknownType); } // It is possible that this free type has other free types in its upper @@ -264,19 +294,27 @@ struct MutatingGeneralizer : TypeOnceVisitor // If we do not do this, we get tautological bounds like a <: a <: unknown. else if (isPositive && !hasUpperBound) { - if (FreeType* lowerFree = getMutable(ft->lowerBound); lowerFree && lowerFree->upperBound == ty) + TypeId lb = follow(ft->lowerBound); + if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == ty) lowerFree->upperBound = builtinTypes->unknownType; else - replace(seen, ft->lowerBound, ty, builtinTypes->unknownType); - emplaceType(asMutable(ty), ft->lowerBound); + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, lb, ty, builtinTypes->unknownType); + } + emplaceType(asMutable(ty), lb); } else { - if (FreeType* upperFree = getMutable(ft->upperBound); upperFree && upperFree->lowerBound == ty) + TypeId ub = follow(ft->upperBound); + if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) upperFree->lowerBound = builtinTypes->neverType; else - replace(seen, ft->upperBound, ty, builtinTypes->neverType); - emplaceType(asMutable(ty), ft->upperBound); + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, ub, ty, builtinTypes->neverType); + } + emplaceType(asMutable(ty), ub); } return false; @@ -363,4 +401,4 @@ OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypePack return OccursCheckResult::Pass; } -} +} // namespace Luau diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index b60fec28..f3b5b149 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -534,7 +534,7 @@ class AstStatBlock : public AstStat public: LUAU_RTTI(AstStatBlock) - AstStatBlock(const Location& location, const AstArray& body, bool hasEnd=true); + AstStatBlock(const Location& location, const AstArray& body, bool hasEnd = true); void visit(AstVisitor* visitor) override; diff --git a/CLI/Compile.cpp b/CLI/Compile.cpp index 6197f03e..c35f9c3d 100644 --- a/CLI/Compile.cpp +++ b/CLI/Compile.cpp @@ -89,13 +89,14 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } -static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) +static std::string getCodegenAssembly( + const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options, Luau::CodeGen::LoweringStats* stats) { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) - return Luau::CodeGen::getAssembly(L, -1, options); + return Luau::CodeGen::getAssembly(L, -1, options, stats); fprintf(stderr, "Error loading bytecode %s\n", name); return ""; @@ -119,6 +120,8 @@ struct CompileStats double parseTime; double compileTime; double codegenTime; + + Luau::CodeGen::LoweringStats lowerStats; }; static double recordDeltaTime(double& timer) @@ -213,10 +216,10 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A case CompileFormat::CodegenAsm: case CompileFormat::CodegenIr: case CompileFormat::CodegenVerbose: - printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); + printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).c_str()); break; case CompileFormat::CodegenNull: - stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); + stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).size(); stats.codegenTime += recordDeltaTime(currts); break; case CompileFormat::Null: @@ -355,13 +358,22 @@ int main(int argc, char** argv) failed += !compileFile(path.c_str(), compileFormat, assemblyTarget, stats); if (compileFormat == CompileFormat::Null) + { printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), int(stats.bytecode / 1024), stats.readTime, stats.parseTime, stats.compileTime); + } else if (compileFormat == CompileFormat::CodegenNull) + { printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024), stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime, stats.codegenTime); + printf("Lowering stats:\n"); + printf("- spills to stack: %d, spills to restore: %d, max spill slot %u\n", stats.lowerStats.spillsToSlot, stats.lowerStats.spillsToRestore, + stats.lowerStats.maxSpillSlotsUsed); + printf("- regalloc failed: %d, lowering failed %d\n", stats.lowerStats.regAllocErrors, stats.lowerStats.loweringErrors); + } + return failed ? 1 : 0; } diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 85f19d01..031779f7 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -3,6 +3,7 @@ #include +#include #include struct lua_State; @@ -74,8 +75,18 @@ struct AssemblyOptions void* annotatorContext = nullptr; }; +struct LoweringStats +{ + int spillsToSlot = 0; + int spillsToRestore = 0; + unsigned maxSpillSlotsUsed = 0; + + int regAllocErrors = 0; + int loweringErrors = 0; +}; + // Generates assembly for target function and all inner functions -std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {}); +std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {}, LoweringStats* stats = nullptr); using PerfLogFn = void (*)(void* context, uintptr_t addr, unsigned size, const char* symbol); diff --git a/CodeGen/include/Luau/IrRegAllocX64.h b/CodeGen/include/Luau/IrRegAllocX64.h index 665b5229..632499a4 100644 --- a/CodeGen/include/Luau/IrRegAllocX64.h +++ b/CodeGen/include/Luau/IrRegAllocX64.h @@ -12,6 +12,9 @@ namespace Luau { namespace CodeGen { + +struct LoweringStats; + namespace X64 { @@ -33,7 +36,7 @@ struct IrSpillX64 struct IrRegAllocX64 { - IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function); + IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function, LoweringStats* stats); RegisterX64 allocReg(SizeX64 size, uint32_t instIdx); RegisterX64 allocRegOrReuse(SizeX64 size, uint32_t instIdx, std::initializer_list oprefs); @@ -70,6 +73,7 @@ struct IrRegAllocX64 AssemblyBuilderX64& build; IrFunction& function; + LoweringStats* stats = nullptr; uint32_t currInstIdx = ~0u; diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 5db5f6f1..50b3dc7c 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -264,5 +264,14 @@ uint32_t getNativeContextOffset(int bfid); // Cleans up blocks that were created with no users void killUnusedBlocks(IrFunction& function); +// Get blocks in order that tries to maximize fallthrough between them during lowering +// We want to mostly preserve build order with fallbacks outlined +// But we also use hints from optimization passes that chain blocks together where there's only one out-in edge between them +std::vector getSortedBlockOrder(IrFunction& function); + +// Returns first non-dead block that comes after block at index 'i' in the sorted blocks array +// 'dummy' block is returned if the end of array was reached +IrBlock& getNextBlock(IrFunction& function, std::vector& sortedBlocks, IrBlock& dummy, size_t i); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index f385cd0a..15be6b88 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -1091,7 +1091,7 @@ void AssemblyBuilderA64::placeER(const char* name, RegisterA64 dst, RegisterA64 LUAU_ASSERT(shift >= 0 && shift <= 4); uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; // could be useful in the future for byte->word extends - int option = 0b010; // UXTW + int option = 0b010; // UXTW place(dst.index | (src1.index << 5) | (shift << 10) | (option << 13) | (src2.index << 16) | (1 << 21) | (op << 24) | sf); commit(); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 9d117b1d..a25e1046 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -106,7 +106,7 @@ static std::optional createNativeFunction(AssemblyBuilder& build, M IrBuilder ir; ir.buildFunctionIr(proto); - if (!lowerFunction(ir, build, helpers, proto, {})) + if (!lowerFunction(ir, build, helpers, proto, {}, /* stats */ nullptr)) return std::nullopt; return createNativeProto(proto, ir); diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index fed5ddd3..4fd24cd9 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -43,7 +43,7 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) } template -static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options) +static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options, LoweringStats* stats) { std::vector protos; gatherFunctions(protos, clvalue(func)->l.p); @@ -66,7 +66,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A if (options.includeAssembly || options.includeIr) logFunctionHeader(build, p); - if (!lowerFunction(ir, build, helpers, p, options)) + if (!lowerFunction(ir, build, helpers, p, options, stats)) { if (build.logText) build.logAppend("; skipping (can't lower)\n"); @@ -90,7 +90,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A unsigned int getCpuFeaturesA64(); #endif -std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) +std::string getAssembly(lua_State* L, int idx, AssemblyOptions options, LoweringStats* stats) { LUAU_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); @@ -106,35 +106,35 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); #endif - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } case AssemblyOptions::A64: { A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, /* features= */ A64::Feature_JSCVT); - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } case AssemblyOptions::A64_NoFeatures: { A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, /* features= */ 0); - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } case AssemblyOptions::X64_Windows: { X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly, X64::ABIX64::Windows); - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } case AssemblyOptions::X64_SystemV: { X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly, X64::ABIX64::SystemV); - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } default: diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 24e1c38c..9b729a92 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -44,42 +44,10 @@ inline void gatherFunctions(std::vector& results, Proto* proto) gatherFunctions(results, proto->p[i]); } -inline IrBlock& getNextBlock(IrFunction& function, std::vector& sortedBlocks, IrBlock& dummy, size_t i) -{ - for (size_t j = i + 1; j < sortedBlocks.size(); ++j) - { - IrBlock& block = function.blocks[sortedBlocks[j]]; - if (block.kind != IrBlockKind::Dead) - return block; - } - - return dummy; -} - template inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) { - // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined - std::vector sortedBlocks; - sortedBlocks.reserve(function.blocks.size()); - for (uint32_t i = 0; i < function.blocks.size(); i++) - sortedBlocks.push_back(i); - - std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { - const IrBlock& a = function.blocks[idxA]; - const IrBlock& b = function.blocks[idxB]; - - // Place fallback blocks at the end - if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) - return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); - - // Try to order by instruction order - if (a.sortkey != b.sortkey) - return a.sortkey < b.sortkey; - - // Chains of blocks are merged together by having the same sort key and consecutive chain key - return a.chainkey < b.chainkey; - }); + std::vector sortedBlocks = getSortedBlockOrder(function); // For each IR instruction that begins a bytecode instruction, which bytecode instruction is it? std::vector bcLocations(function.instructions.size() + 1, ~0u); @@ -231,24 +199,26 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& return true; } -inline bool lowerIr(X64::AssemblyBuilderX64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +inline bool lowerIr( + X64::AssemblyBuilderX64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats) { optimizeMemoryOperandsX64(ir.function); - X64::IrLoweringX64 lowering(build, helpers, ir.function); + X64::IrLoweringX64 lowering(build, helpers, ir.function, stats); return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } -inline bool lowerIr(A64::AssemblyBuilderA64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +inline bool lowerIr( + A64::AssemblyBuilderA64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats) { - A64::IrLoweringA64 lowering(build, helpers, ir.function); + A64::IrLoweringA64 lowering(build, helpers, ir.function, stats); return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } template -inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats) { killUnusedBlocks(ir.function); @@ -264,7 +234,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& createLinearBlocks(ir, useValueNumbering); } - return lowerIr(build, ir, helpers, proto, options); + return lowerIr(build, ir, helpers, proto, options, stats); } } // namespace CodeGen diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index e467ca68..1aca2993 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -385,7 +385,8 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstDupTable(*this, pc, i); break; case LOP_SETLIST: - inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1]), undef()); + inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1]), + undef()); break; case LOP_GETUPVAL: translateInstGetUpval(*this, pc, i); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index a030f955..6e51b0d5 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -240,11 +240,12 @@ static bool emitBuiltin( } } -IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function) +IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats) : build(build) , helpers(helpers) , function(function) - , regs(function, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) + , stats(stats) + , regs(function, stats, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) , valueTracker(function) , exitHandlerMap(~0u) { @@ -858,7 +859,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaH_setnum))); build.blr(x3); inst.regA64 = regs.takeReg(x0, index); - break; + break; } case IrCmd::NEW_TABLE: { @@ -2016,6 +2017,15 @@ void IrLoweringA64::finishFunction() build.mov(x0, handler.pcpos * sizeof(Instruction)); build.b(helpers.updatePcAndContinueInVm); } + + if (stats) + { + if (error) + stats->loweringErrors++; + + if (regs.error) + stats->regAllocErrors++; + } } bool IrLoweringA64::hasError() const diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 5134ceda..46f41021 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -17,13 +17,14 @@ namespace CodeGen struct ModuleHelpers; struct AssemblyOptions; +struct LoweringStats; namespace A64 { struct IrLoweringA64 { - IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function); + IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats); void lowerInst(IrInst& inst, uint32_t index, const IrBlock& next); void finishBlock(const IrBlock& curr, const IrBlock& next); @@ -74,6 +75,7 @@ struct IrLoweringA64 ModuleHelpers& helpers; IrFunction& function; + LoweringStats* stats = nullptr; IrRegAllocA64 regs; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index fe5127ac..03ae4f0c 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -22,11 +22,12 @@ namespace CodeGen namespace X64 { -IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function) +IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats) : build(build) , helpers(helpers) , function(function) - , regs(build, function) + , stats(stats) + , regs(build, function, stats) , valueTracker(function) , exitHandlerMap(~0u) { @@ -1646,6 +1647,15 @@ void IrLoweringX64::finishFunction() build.mov(edx, handler.pcpos * sizeof(Instruction)); build.jmp(helpers.updatePcAndContinueInVm); } + + if (stats) + { + if (regs.maxUsedSlot > kSpillSlots) + stats->regAllocErrors++; + + if (regs.maxUsedSlot > stats->maxSpillSlotsUsed) + stats->maxSpillSlotsUsed = regs.maxUsedSlot; + } } bool IrLoweringX64::hasError() const diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index a32e034d..920ad002 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -19,13 +19,14 @@ namespace CodeGen struct ModuleHelpers; struct AssemblyOptions; +struct LoweringStats; namespace X64 { struct IrLoweringX64 { - IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function); + IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats); void lowerInst(IrInst& inst, uint32_t index, const IrBlock& next); void finishBlock(const IrBlock& curr, const IrBlock& next); @@ -76,6 +77,7 @@ struct IrLoweringX64 ModuleHelpers& helpers; IrFunction& function; + LoweringStats* stats = nullptr; IrRegAllocX64 regs; diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index f552c17f..d2b7ff14 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -2,6 +2,7 @@ #include "IrRegAllocA64.h" #include "Luau/AssemblyBuilderA64.h" +#include "Luau/CodeGen.h" #include "Luau/IrUtils.h" #include "BitUtils.h" @@ -109,8 +110,9 @@ static void restoreInst(AssemblyBuilderA64& build, uint32_t& freeSpillSlots, IrF inst.regA64 = reg; } -IrRegAllocA64::IrRegAllocA64(IrFunction& function, std::initializer_list> regs) +IrRegAllocA64::IrRegAllocA64(IrFunction& function, LoweringStats* stats, std::initializer_list> regs) : function(function) + , stats(stats) { for (auto& p : regs) { @@ -329,6 +331,9 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init spills.push_back(s); def.needsReload = true; + + if (stats) + stats->spillsToRestore++; } else { @@ -345,6 +350,14 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init spills.push_back(s); def.spilled = true; + + if (stats) + { + stats->spillsToSlot++; + + if (slot != kInvalidSpill && unsigned(slot + 1) > stats->maxSpillSlotsUsed) + stats->maxSpillSlotsUsed = slot + 1; + } } def.regA64 = noreg; diff --git a/CodeGen/src/IrRegAllocA64.h b/CodeGen/src/IrRegAllocA64.h index ae3110d7..d16fa19f 100644 --- a/CodeGen/src/IrRegAllocA64.h +++ b/CodeGen/src/IrRegAllocA64.h @@ -12,6 +12,9 @@ namespace Luau { namespace CodeGen { + +struct LoweringStats; + namespace A64 { @@ -19,7 +22,7 @@ class AssemblyBuilderA64; struct IrRegAllocA64 { - IrRegAllocA64(IrFunction& function, std::initializer_list> regs); + IrRegAllocA64(IrFunction& function, LoweringStats* stats, std::initializer_list> regs); RegisterA64 allocReg(KindA64 kind, uint32_t index); RegisterA64 allocTemp(KindA64 kind); @@ -69,6 +72,7 @@ struct IrRegAllocA64 Set& getSet(KindA64 kind); IrFunction& function; + LoweringStats* stats = nullptr; Set gpr, simd; std::vector spills; diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index 7690f69a..81cf2f4c 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrRegAllocX64.h" +#include "Luau/CodeGen.h" #include "Luau/IrUtils.h" #include "EmitCommonX64.h" @@ -14,9 +15,10 @@ namespace X64 static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; -IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function) +IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function, LoweringStats* stats) : build(build) , function(function) + , stats(stats) , usableXmmRegCount(getXmmRegisterCount(build.abi)) { freeGprMap.fill(true); @@ -225,10 +227,16 @@ void IrRegAllocX64::preserve(IrInst& inst) spill.stackSlot = uint8_t(i); inst.spilled = true; + + if (stats) + stats->spillsToSlot++; } else { inst.needsReload = true; + + if (stats) + stats->spillsToRestore++; } spills.push_back(spill); diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index f1eea645..13256789 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -615,7 +615,7 @@ static IrOp getLoopStepK(IrBuilder& build, int ra) { IrBlock& active = build.function.blocks[build.activeBlockIdx]; - if (active.start + 2 < build.function.instructions.size()) + if (active.start + 2 < build.function.instructions.size()) { IrInst& sv = build.function.instructions[build.function.instructions.size() - 2]; IrInst& st = build.function.instructions[build.function.instructions.size() - 1]; @@ -665,7 +665,7 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); // step > 0 - // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 + // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); // Condition to start the loop: step > 0 ? idx <= limit : limit <= idx @@ -763,7 +763,7 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) IrOp reverse = build.block(IrBlockKind::Internal); // step > 0 - // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 + // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx @@ -776,8 +776,8 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(direct); build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); } - else - { + else + { double stepN = build.function.doubleOp(stepK); // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx @@ -785,9 +785,9 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); else build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); - } + } } - else + else { build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index d263d3aa..687b0c2a 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -860,5 +860,43 @@ void killUnusedBlocks(IrFunction& function) } } +std::vector getSortedBlockOrder(IrFunction& function) +{ + std::vector sortedBlocks; + sortedBlocks.reserve(function.blocks.size()); + for (uint32_t i = 0; i < function.blocks.size(); i++) + sortedBlocks.push_back(i); + + std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { + const IrBlock& a = function.blocks[idxA]; + const IrBlock& b = function.blocks[idxB]; + + // Place fallback blocks at the end + if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) + return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); + + // Try to order by instruction order + if (a.sortkey != b.sortkey) + return a.sortkey < b.sortkey; + + // Chains of blocks are merged together by having the same sort key and consecutive chain key + return a.chainkey < b.chainkey; + }); + + return sortedBlocks; +} + +IrBlock& getNextBlock(IrFunction& function, std::vector& sortedBlocks, IrBlock& dummy, size_t i) +{ + for (size_t j = i + 1; j < sortedBlocks.size(); ++j) + { + IrBlock& block = function.blocks[sortedBlocks[j]]; + if (block.kind != IrBlockKind::Dead) + return block; + } + + return dummy; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 6f567a2b..7b2f068b 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -14,21 +14,21 @@ #include #include +LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) +LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) + namespace Luau { namespace CodeGen { -constexpr unsigned kBlockSize = 4 * 1024 * 1024; -constexpr unsigned kMaxTotalSize = 256 * 1024 * 1024; - NativeState::NativeState() : NativeState(nullptr, nullptr) { } NativeState::NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext) - : codeAllocator{kBlockSize, kMaxTotalSize, allocationCallback, allocationCallbackContext} + : codeAllocator{size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext} { } diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index fd074be1..94d13ea8 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,6 +13,7 @@ inline bool isFlagExperimental(const char* flag) static const char* const kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauTinyControlFlowAnalysis", // waiting for updates to packages depended by internal builtin plugins + "LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative // makes sure we always have at least one entry nullptr, }; diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 8a140780..5a817f25 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -275,15 +275,15 @@ static int math_randomseed(lua_State* L) return 0; } -static const unsigned char kPerlinHash[257] = {151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, 99, - 37, 240, 21, 10, 23, 190, 6, 148, 247, 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, 174, - 20, 125, 136, 171, 168, 68, 175, 74, 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, 41, - 55, 46, 245, 40, 244, 102, 143, 54, 65, 25, 63, 161, 1, 216, 80, 73, 209, 76, 132, 187, 208, 89, 18, 169, 200, 196, 135, 130, 116, 188, 159, 86, - 164, 100, 109, 198, 173, 186, 3, 64, 52, 217, 226, 250, 124, 123, 5, 202, 38, 147, 118, 126, 255, 82, 85, 212, 207, 206, 59, 227, 47, 16, 58, 17, - 182, 189, 28, 42, 223, 183, 170, 213, 119, 248, 152, 2, 44, 154, 163, 70, 221, 153, 101, 155, 167, 43, 172, 9, 129, 22, 39, 253, 19, 98, 108, 110, - 79, 113, 224, 232, 178, 185, 112, 104, 218, 246, 97, 228, 251, 34, 242, 193, 238, 210, 144, 12, 191, 179, 162, 241, 81, 51, 145, 235, 249, 14, - 239, 107, 49, 192, 214, 31, 181, 199, 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, 24, - 72, 243, 141, 128, 195, 78, 66, 215, 61, 156, 180, 151}; +static const unsigned char kPerlinHash[257] = {151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, + 99, 37, 240, 21, 10, 23, 190, 6, 148, 247, 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, + 174, 20, 125, 136, 171, 168, 68, 175, 74, 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, + 41, 55, 46, 245, 40, 244, 102, 143, 54, 65, 25, 63, 161, 1, 216, 80, 73, 209, 76, 132, 187, 208, 89, 18, 169, 200, 196, 135, 130, 116, 188, 159, + 86, 164, 100, 109, 198, 173, 186, 3, 64, 52, 217, 226, 250, 124, 123, 5, 202, 38, 147, 118, 126, 255, 82, 85, 212, 207, 206, 59, 227, 47, 16, 58, + 17, 182, 189, 28, 42, 223, 183, 170, 213, 119, 248, 152, 2, 44, 154, 163, 70, 221, 153, 101, 155, 167, 43, 172, 9, 129, 22, 39, 253, 19, 98, 108, + 110, 79, 113, 224, 232, 178, 185, 112, 104, 218, 246, 97, 228, 251, 34, 242, 193, 238, 210, 144, 12, 191, 179, 162, 241, 81, 51, 145, 235, 249, + 14, 239, 107, 49, 192, 214, 31, 181, 199, 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, + 24, 72, 243, 141, 128, 195, 78, 66, 215, 61, 156, 180, 151}; const float kPerlinGrad[16][3] = {{1, 1, 0}, {-1, 1, 0}, {1, -1, 0}, {-1, -1, 0}, {1, 0, 1}, {-1, 0, 1}, {1, 0, -1}, {-1, 0, -1}, {0, 1, 1}, {0, -1, 1}, {0, 1, -1}, {0, -1, -1}, {1, 1, 0}, {0, -1, 1}, {-1, 1, 0}, {0, -1, -1}}; diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 68efc0e4..086ed649 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -2291,7 +2291,7 @@ reentry: { // table or userdata with __call, will be called during FORGLOOP // TODO: we might be able to stop supporting this depending on whether it's used in practice - void (*telemetrycb)(lua_State* L, int gtt, int stt, int itt) = lua_iter_call_telemetry; + void (*telemetrycb)(lua_State * L, int gtt, int stt, int itt) = lua_iter_call_telemetry; if (telemetrycb) telemetrycb(L, ttype(ra), ttype(ra + 1), ttype(ra + 2)); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index dd90dd8e..84dee432 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3667,8 +3667,7 @@ TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") )"); StringCompletionCallback callback = [](std::string, std::optional, - std::optional contents) -> std::optional - { + std::optional contents) -> std::optional { Luau::AutocompleteEntryMap results = {{"test", Luau::AutocompleteEntry{Luau::AutocompleteEntryKind::String, std::nullopt, false, false}}}; return results; }; diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 298035c2..9c99862d 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -57,8 +57,7 @@ TEST_CASE("CodeAllocationCallbacks") AllocationData allocationData{}; - const auto allocationCallback = [](void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize) - { + const auto allocationCallback = [](void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize) { AllocationData& allocationData = *static_cast(context); if (oldPointer != nullptr) { diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 93290567..49498744 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -74,7 +74,7 @@ TEST_CASE("BytecodeIsStable") // Bytecode ops (serialized & in-memory) CHECK(LOP_FASTCALL2K == 75); // bytecode v1 - CHECK(LOP_JUMPXEQKS == 80); // bytecode v3 + CHECK(LOP_JUMPXEQKS == 80); // bytecode v3 // Bytecode fastcall ids (serialized & in-memory) // Note: these aren't strictly bound to specific bytecode versions, but must monotonically increase to keep backwards compat @@ -7371,7 +7371,8 @@ TEST_CASE("BuiltinFoldMathK") function test() return math.pi * 2 end -)", 0, 2), +)", + 0, 2), R"( LOADK R0 K0 [6.2831853071795862] RETURN R0 1 @@ -7382,7 +7383,8 @@ RETURN R0 1 function test() return math.pi * 2 end -)", 0, 1), +)", + 0, 1), R"( GETIMPORT R1 3 [math.pi] MULK R0 R1 K0 [2] @@ -7396,7 +7398,8 @@ function test() end math = { pi = 4 } -)", 0, 2), +)", + 0, 2), R"( GETGLOBAL R2 K1 ['math'] GETTABLEKS R1 R2 K2 ['pi'] diff --git a/tests/IrCallWrapperX64.test.cpp b/tests/IrCallWrapperX64.test.cpp index ec04e531..1ff22a32 100644 --- a/tests/IrCallWrapperX64.test.cpp +++ b/tests/IrCallWrapperX64.test.cpp @@ -12,7 +12,7 @@ class IrCallWrapperX64Fixture public: IrCallWrapperX64Fixture(ABIX64 abi = ABIX64::Windows) : build(/* logText */ true, abi) - , regs(build, function) + , regs(build, function, nullptr) , callWrap(regs, build, ~0u) { } diff --git a/tests/IrRegAllocX64.test.cpp b/tests/IrRegAllocX64.test.cpp index bbf9c154..b4b63f4b 100644 --- a/tests/IrRegAllocX64.test.cpp +++ b/tests/IrRegAllocX64.test.cpp @@ -11,7 +11,7 @@ class IrRegAllocX64Fixture public: IrRegAllocX64Fixture() : build(/* logText */ true, ABIX64::Windows) - , regs(build, function) + , regs(build, function, nullptr) { } diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index afcc08f2..7bed110b 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -516,6 +516,71 @@ struct NormalizeFixture : Fixture TEST_SUITE_BEGIN("Normalize"); +TEST_CASE_FIXTURE(NormalizeFixture, "string_intersection_is_commutative") +{ + auto c4 = toString(normal(R"( + string & (string & Not<"a"> & Not<"b">) +)")); + auto c4Reverse = toString(normal(R"( + (string & Not<"a"> & Not<"b">) & string +)")); + CHECK(c4 == c4Reverse); + CHECK_EQ("string & ~\"a\" & ~\"b\"", c4); + + auto c5 = toString(normal(R"( + (string & Not<"a"> & Not<"b">) & (string & Not<"b"> & Not<"c">) +)")); + auto c5Reverse = toString(normal(R"( + (string & Not<"b"> & Not<"c">) & (string & Not<"a"> & Not<"c">) +)")); + CHECK(c5 == c5Reverse); + CHECK_EQ("string & ~\"a\" & ~\"b\" & ~\"c\"", c5); + + auto c6 = toString(normal(R"( + ("a" | "b") & (string & Not<"b"> & Not<"c">) +)")); + auto c6Reverse = toString(normal(R"( + (string & Not<"b"> & Not<"c">) & ("a" | "b") +)")); + CHECK(c6 == c6Reverse); + CHECK_EQ("\"a\"", c6); + + auto c7 = toString(normal(R"( + string & ("b" | "c") +)")); + auto c7Reverse = toString(normal(R"( + ("b" | "c") & string +)")); + CHECK(c7 == c7Reverse); + CHECK_EQ("\"b\" | \"c\"", c7); + + auto c8 = toString(normal(R"( +(string & Not<"a"> & Not<"b">) & ("b" | "c") +)")); + auto c8Reverse = toString(normal(R"( + ("b" | "c") & (string & Not<"a"> & Not<"b">) +)")); + CHECK(c8 == c8Reverse); + CHECK_EQ("\"c\"", c8); + auto c9 = toString(normal(R"( + ("a" | "b") & ("b" | "c") + )")); + auto c9Reverse = toString(normal(R"( + ("b" | "c") & ("a" | "b") + )")); + CHECK(c9 == c9Reverse); + CHECK_EQ("\"b\"", c9); + + auto l = toString(normal(R"( + (string | number) & ("a" | true) + )")); + auto r = toString(normal(R"( + ("a" | true) & (string | number) + )")); + CHECK(l == r); + CHECK_EQ("\"a\"", l); +} + TEST_CASE_FIXTURE(NormalizeFixture, "negate_string") { CHECK("number" == toString(normal(R"( diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index 6089f036..ce418b71 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -6,6 +6,7 @@ #include "Luau/Normalize.h" #include "Luau/Subtyping.h" +#include "Luau/Type.h" #include "Luau/TypePack.h" using namespace Luau; @@ -17,7 +18,15 @@ struct SubtypeFixture : Fixture UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&iceReporter}}; + ScopePtr rootScope{new Scope(builtinTypes->emptyTypePack)}; + ScopePtr moduleScope{new Scope(rootScope)}; + + Subtyping subtyping = mkSubtyping(rootScope); + + Subtyping mkSubtyping(const ScopePtr& scope) + { + return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&iceReporter}, NotNull{scope.get()}}; + } TypePackId pack(std::initializer_list tys) { @@ -66,11 +75,18 @@ struct SubtypeFixture : Fixture return arena.addType(UnionType{{a, b}}); } + // `~` TypeId negate(TypeId ty) { return arena.addType(NegationType{ty}); } + // "literal" + TypeId str(const char* literal) + { + return arena.addType(SingletonType{StringSingleton{literal}}); + } + TypeId cls(const std::string& name, std::optional parent = std::nullopt) { return arena.addType(ClassType{name, {}, parent.value_or(builtinTypes->classType), {}, {}, nullptr, ""}); @@ -97,8 +113,8 @@ struct SubtypeFixture : Fixture return arena.addType(MetatableType{tbl(std::move(tableProps)), tbl(std::move(metaProps))}); } - TypeId genericT = arena.addType(GenericType{"T"}); - TypeId genericU = arena.addType(GenericType{"U"}); + TypeId genericT = arena.addType(GenericType{moduleScope.get(), "T"}); + TypeId genericU = arena.addType(GenericType{moduleScope.get(), "U"}); TypePackId genericAs = arena.addTypePack(GenericTypePack{"A"}); TypePackId genericBs = arena.addTypePack(GenericTypePack{"B"}); @@ -113,6 +129,10 @@ struct SubtypeFixture : Fixture TypeId helloType2 = arena.addType(SingletonType{StringSingleton{"hello"}}); TypeId worldType = arena.addType(SingletonType{StringSingleton{"world"}}); + TypeId aType = arena.addType(SingletonType{StringSingleton{"a"}}); + TypeId bType = arena.addType(SingletonType{StringSingleton{"b"}}); + TypeId trueSingleton = arena.addType(SingletonType{BooleanSingleton{true}}); + TypeId falseSingleton = arena.addType(SingletonType{BooleanSingleton{false}}); TypeId helloOrWorldType = join(helloType, worldType); TypeId trueOrFalseType = join(builtinTypes->trueType, builtinTypes->falseType); @@ -128,7 +148,7 @@ struct SubtypeFixture : Fixture * \- AnotherChild * |- AnotherGrandchildOne * \- AnotherGrandchildTwo - */ + */ TypeId rootClass = cls("Root"); TypeId childClass = cls("Child", rootClass); TypeId grandchildOneClass = cls("GrandchildOne", childClass); @@ -138,9 +158,9 @@ struct SubtypeFixture : Fixture TypeId anotherGrandchildTwoClass = cls("AnotherGrandchildTwo", anotherChildClass); TypeId vec2Class = cls("Vec2", { - {"X", builtinTypes->numberType}, - {"Y", builtinTypes->numberType}, - }); + {"X", builtinTypes->numberType}, + {"Y", builtinTypes->numberType}, + }); // "hello" | "hello" TypeId helloOrHelloType = arena.addType(UnionType{{helloType, helloType}}); @@ -149,160 +169,76 @@ struct SubtypeFixture : Fixture const TypeId nothingToNothingType = fn({}, {}); // (number) -> string - const TypeId numberToStringType = fn( - {builtinTypes->numberType}, - {builtinTypes->stringType} - ); + const TypeId numberToStringType = fn({builtinTypes->numberType}, {builtinTypes->stringType}); // (unknown) -> string - const TypeId unknownToStringType = fn( - {builtinTypes->unknownType}, - {builtinTypes->stringType} - ); + const TypeId unknownToStringType = fn({builtinTypes->unknownType}, {builtinTypes->stringType}); // (number) -> () - const TypeId numberToNothingType = fn( - {builtinTypes->numberType}, - {} - ); + const TypeId numberToNothingType = fn({builtinTypes->numberType}, {}); // () -> number - const TypeId nothingToNumberType = fn( - {}, - {builtinTypes->numberType} - ); + const TypeId nothingToNumberType = fn({}, {builtinTypes->numberType}); // (number) -> number - const TypeId numberToNumberType = fn( - {builtinTypes->numberType}, - {builtinTypes->numberType} - ); + const TypeId numberToNumberType = fn({builtinTypes->numberType}, {builtinTypes->numberType}); // (number) -> unknown - const TypeId numberToUnknownType = fn( - {builtinTypes->numberType}, - {builtinTypes->unknownType} - ); + const TypeId numberToUnknownType = fn({builtinTypes->numberType}, {builtinTypes->unknownType}); // (number) -> (string, string) - const TypeId numberToTwoStringsType = fn( - {builtinTypes->numberType}, - {builtinTypes->stringType, builtinTypes->stringType} - ); + const TypeId numberToTwoStringsType = fn({builtinTypes->numberType}, {builtinTypes->stringType, builtinTypes->stringType}); // (number) -> (string, unknown) - const TypeId numberToStringAndUnknownType = fn( - {builtinTypes->numberType}, - {builtinTypes->stringType, builtinTypes->unknownType} - ); + const TypeId numberToStringAndUnknownType = fn({builtinTypes->numberType}, {builtinTypes->stringType, builtinTypes->unknownType}); // (number, number) -> string - const TypeId numberNumberToStringType = fn( - {builtinTypes->numberType, builtinTypes->numberType}, - {builtinTypes->stringType} - ); + const TypeId numberNumberToStringType = fn({builtinTypes->numberType, builtinTypes->numberType}, {builtinTypes->stringType}); // (unknown, number) -> string - const TypeId unknownNumberToStringType = fn( - {builtinTypes->unknownType, builtinTypes->numberType}, - {builtinTypes->stringType} - ); + const TypeId unknownNumberToStringType = fn({builtinTypes->unknownType, builtinTypes->numberType}, {builtinTypes->stringType}); // (number, string) -> string - const TypeId numberAndStringToStringType = fn( - {builtinTypes->numberType, builtinTypes->stringType}, - {builtinTypes->stringType} - ); + const TypeId numberAndStringToStringType = fn({builtinTypes->numberType, builtinTypes->stringType}, {builtinTypes->stringType}); // (number, ...string) -> string - const TypeId numberAndStringsToStringType = fn( - {builtinTypes->numberType}, VariadicTypePack{builtinTypes->stringType}, - {builtinTypes->stringType} - ); + const TypeId numberAndStringsToStringType = + fn({builtinTypes->numberType}, VariadicTypePack{builtinTypes->stringType}, {builtinTypes->stringType}); // (number, ...string?) -> string - const TypeId numberAndOptionalStringsToStringType = fn( - {builtinTypes->numberType}, VariadicTypePack{builtinTypes->optionalStringType}, - {builtinTypes->stringType} - ); + const TypeId numberAndOptionalStringsToStringType = + fn({builtinTypes->numberType}, VariadicTypePack{builtinTypes->optionalStringType}, {builtinTypes->stringType}); // (...number) -> number - const TypeId numbersToNumberType = arena.addType(FunctionType{ - arena.addTypePack(VariadicTypePack{builtinTypes->numberType}), - arena.addTypePack({builtinTypes->numberType}) - }); + const TypeId numbersToNumberType = + arena.addType(FunctionType{arena.addTypePack(VariadicTypePack{builtinTypes->numberType}), arena.addTypePack({builtinTypes->numberType})}); // (T) -> () - const TypeId genericTToNothingType = arena.addType(FunctionType{ - {genericT}, - {}, - arena.addTypePack({genericT}), - builtinTypes->emptyTypePack - }); + const TypeId genericTToNothingType = arena.addType(FunctionType{{genericT}, {}, arena.addTypePack({genericT}), builtinTypes->emptyTypePack}); // (T) -> T - const TypeId genericTToTType = arena.addType(FunctionType{ - {genericT}, - {}, - arena.addTypePack({genericT}), - arena.addTypePack({genericT}) - }); + const TypeId genericTToTType = arena.addType(FunctionType{{genericT}, {}, arena.addTypePack({genericT}), arena.addTypePack({genericT})}); // (U) -> () - const TypeId genericUToNothingType = arena.addType(FunctionType{ - {genericU}, - {}, - arena.addTypePack({genericU}), - builtinTypes->emptyTypePack - }); + const TypeId genericUToNothingType = arena.addType(FunctionType{{genericU}, {}, arena.addTypePack({genericU}), builtinTypes->emptyTypePack}); // () -> T - const TypeId genericNothingToTType = arena.addType(FunctionType{ - {genericT}, - {}, - builtinTypes->emptyTypePack, - arena.addTypePack({genericT}) - }); + const TypeId genericNothingToTType = arena.addType(FunctionType{{genericT}, {}, builtinTypes->emptyTypePack, arena.addTypePack({genericT})}); // (A...) -> A... - const TypeId genericAsToAsType = arena.addType(FunctionType{ - {}, - {genericAs}, - genericAs, - genericAs - }); + const TypeId genericAsToAsType = arena.addType(FunctionType{{}, {genericAs}, genericAs, genericAs}); // (A...) -> number - const TypeId genericAsToNumberType = arena.addType(FunctionType{ - {}, - {genericAs}, - genericAs, - arena.addTypePack({builtinTypes->numberType}) - }); + const TypeId genericAsToNumberType = arena.addType(FunctionType{{}, {genericAs}, genericAs, arena.addTypePack({builtinTypes->numberType})}); // (B...) -> B... - const TypeId genericBsToBsType = arena.addType(FunctionType{ - {}, - {genericBs}, - genericBs, - genericBs - }); + const TypeId genericBsToBsType = arena.addType(FunctionType{{}, {genericBs}, genericBs, genericBs}); // (B...) -> C... - const TypeId genericBsToCsType = arena.addType(FunctionType{ - {}, - {genericBs, genericCs}, - genericBs, - genericCs - }); + const TypeId genericBsToCsType = arena.addType(FunctionType{{}, {genericBs, genericCs}, genericBs, genericCs}); // () -> A... - const TypeId genericNothingToAsType = arena.addType(FunctionType{ - {}, - {genericAs}, - builtinTypes->emptyTypePack, - genericAs - }); + const TypeId genericNothingToAsType = arena.addType(FunctionType{{}, {genericAs}, builtinTypes->emptyTypePack, genericAs}); // { lower : string -> string } TypeId tableWithLower = tbl(TableType::Props{{"lower", fn({builtinTypes->stringType}, {builtinTypes->stringType})}}); @@ -728,60 +664,98 @@ TEST_CASE_FIXTURE(SubtypeFixture, "{x: number?} (T) -> ()} <: {x: (U) -> ()}") { - CHECK_IS_SUBTYPE( - tbl({{"x", genericTToNothingType}}), - tbl({{"x", genericUToNothingType}}) - ); + CHECK_IS_SUBTYPE(tbl({{"x", genericTToNothingType}}), tbl({{"x", genericUToNothingType}})); } TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { x: number } } <: { @metatable {} }") { - CHECK_IS_SUBTYPE( - meta({{"x", builtinTypes->numberType}}), - meta({}) - ); + CHECK_IS_SUBTYPE(meta({{"x", builtinTypes->numberType}}), meta({})); } TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { x: number } } numberType}}), - meta({{"x", builtinTypes->booleanType}}) - ); + CHECK_IS_NOT_SUBTYPE(meta({{"x", builtinTypes->numberType}}), meta({{"x", builtinTypes->booleanType}})); } TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable {} } booleanType}}) - ); + CHECK_IS_NOT_SUBTYPE(meta({}), meta({{"x", builtinTypes->booleanType}})); } TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable {} } <: {}") { - CHECK_IS_SUBTYPE( - meta({}), - tbl({}) - ); + CHECK_IS_SUBTYPE(meta({}), tbl({})); } TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { u: boolean }, x: number } <: { x: number }") { - CHECK_IS_SUBTYPE( - meta({{"u", builtinTypes->booleanType}}, {{"x", builtinTypes->numberType}}), - tbl({{"x", builtinTypes->numberType}}) - ); + CHECK_IS_SUBTYPE(meta({{"u", builtinTypes->booleanType}}, {{"x", builtinTypes->numberType}}), tbl({{"x", builtinTypes->numberType}})); } TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { x: number } } numberType}}), - tbl({{"x", builtinTypes->numberType}}) - ); + CHECK_IS_NOT_SUBTYPE(meta({{"x", builtinTypes->numberType}}), tbl({{"x", builtinTypes->numberType}})); } +// Negated subtypes +TEST_IS_NOT_SUBTYPE(negate(builtinTypes->neverType), builtinTypes->stringType); +TEST_IS_SUBTYPE(negate(builtinTypes->unknownType), builtinTypes->stringType); +TEST_IS_NOT_SUBTYPE(negate(builtinTypes->anyType), builtinTypes->stringType); +TEST_IS_SUBTYPE(negate(meet(builtinTypes->neverType, builtinTypes->unknownType)), builtinTypes->stringType); +TEST_IS_NOT_SUBTYPE(negate(join(builtinTypes->neverType, builtinTypes->unknownType)), builtinTypes->stringType); + +// Negated supertypes: never/unknown/any/error +TEST_IS_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->neverType)); +TEST_IS_SUBTYPE(builtinTypes->neverType, negate(builtinTypes->unknownType)); +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->unknownType)); +TEST_IS_SUBTYPE(builtinTypes->numberType, negate(builtinTypes->anyType)); +TEST_IS_SUBTYPE(builtinTypes->unknownType, negate(builtinTypes->anyType)); + +// Negated supertypes: unions +TEST_IS_SUBTYPE(builtinTypes->booleanType, negate(join(builtinTypes->stringType, builtinTypes->numberType))); +TEST_IS_SUBTYPE(rootClass, negate(join(childClass, builtinTypes->numberType))); +TEST_IS_SUBTYPE(str("foo"), negate(join(builtinTypes->numberType, builtinTypes->booleanType))); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(join(builtinTypes->stringType, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(childClass, negate(join(rootClass, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(numbersToNumberType, negate(join(builtinTypes->functionType, rootClass))); + +// Negated supertypes: intersections +TEST_IS_SUBTYPE(builtinTypes->booleanType, negate(meet(builtinTypes->stringType, str("foo")))); +TEST_IS_SUBTYPE(builtinTypes->trueType, negate(meet(builtinTypes->booleanType, builtinTypes->numberType))); +TEST_IS_SUBTYPE(rootClass, negate(meet(builtinTypes->classType, childClass))); +TEST_IS_SUBTYPE(childClass, negate(meet(builtinTypes->classType, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(builtinTypes->unknownType, negate(meet(builtinTypes->classType, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(meet(builtinTypes->stringType, negate(str("bar"))))); + +// Negated supertypes: tables and metatables +TEST_IS_SUBTYPE(tbl({}), negate(builtinTypes->numberType)); +TEST_IS_NOT_SUBTYPE(tbl({}), negate(builtinTypes->tableType)); +TEST_IS_SUBTYPE(meta({}), negate(builtinTypes->numberType)); +TEST_IS_NOT_SUBTYPE(meta({}), negate(builtinTypes->tableType)); + +// Negated supertypes: Functions +TEST_IS_SUBTYPE(numberToNumberType, negate(builtinTypes->classType)); +TEST_IS_NOT_SUBTYPE(numberToNumberType, negate(builtinTypes->functionType)); + +// Negated supertypes: Primitives and singletons +TEST_IS_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->numberType)); +TEST_IS_SUBTYPE(str("foo"), meet(builtinTypes->stringType, negate(str("bar")))); +TEST_IS_NOT_SUBTYPE(builtinTypes->trueType, negate(builtinTypes->booleanType)); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(str("foo"))); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(builtinTypes->stringType)); +TEST_IS_SUBTYPE(builtinTypes->falseType, negate(builtinTypes->trueType)); +TEST_IS_SUBTYPE(builtinTypes->falseType, meet(builtinTypes->booleanType, negate(builtinTypes->trueType))); +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, meet(builtinTypes->booleanType, negate(builtinTypes->trueType))); +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, negate(str("foo"))); +TEST_IS_NOT_SUBTYPE(builtinTypes->booleanType, negate(builtinTypes->falseType)); + +// Negated supertypes: Classes +TEST_IS_SUBTYPE(rootClass, negate(builtinTypes->tableType)); +TEST_IS_NOT_SUBTYPE(rootClass, negate(builtinTypes->classType)); +TEST_IS_NOT_SUBTYPE(childClass, negate(rootClass)); +TEST_IS_NOT_SUBTYPE(childClass, meet(builtinTypes->classType, negate(rootClass))); +TEST_IS_SUBTYPE(anotherChildClass, meet(builtinTypes->classType, negate(childClass))); + TEST_CASE_FIXTURE(SubtypeFixture, "Root <: class") { CHECK_IS_SUBTYPE(rootClass, builtinTypes->classType); @@ -829,13 +803,11 @@ TEST_CASE_FIXTURE(SubtypeFixture, "Child & ~GrandchildOne string} <: t2 where t2 = {trim: (t2) -> string}") { - TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) - { + TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) { tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); }); - TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) - { + TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) { tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); }); @@ -844,13 +816,11 @@ TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> string} <: t2 wh TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> string} t2}") { - TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) - { + TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) { tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); }); - TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) - { + TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) { tt->props["trim"] = fn({ty}, {ty}); }); @@ -859,13 +829,11 @@ TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> string} t1} string}") { - TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) - { + TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) { tt->props["trim"] = fn({ty}, {ty}); }); - TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) - { + TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) { tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); }); @@ -960,6 +928,50 @@ TEST_CASE_FIXTURE(SubtypeFixture, "(string) -> number <: ~fun & (string) -> numb CHECK_IS_NOT_SUBTYPE(numberToStringType, meet(negate(builtinTypes->functionType), numberToStringType)); } +TEST_CASE_FIXTURE(SubtypeFixture, "~\"a\" & ~\"b\" & string <: { lower : (string) -> ()}") +{ + CHECK_IS_SUBTYPE(meet(meet(negate(aType), negate(bType)), builtinTypes->stringType), tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"a\" | (~\"b\" & string) <: { lower : (string) -> ()}") +{ + CHECK_IS_SUBTYPE(join(aType, meet(negate(bType), builtinTypes->stringType)), tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(string | number) & (\"a\" | true) <: { lower: (string) -> string }") +{ + auto base = meet(join(builtinTypes->stringType, builtinTypes->numberType), join(aType, trueSingleton)); + CHECK_IS_SUBTYPE(base, tableWithLower); +} + +/* + * Within the scope to which a generic belongs, that generic ought to be treated + * as its bounds. + * + * We do not yet support bounded generics, so all generics are considered to be + * bounded by unknown. + */ +TEST_CASE_FIXTURE(SubtypeFixture, "unknown <: X") +{ + ScopePtr childScope{new Scope(rootScope)}; + ScopePtr grandChildScope{new Scope(childScope)}; + + TypeId genericX = arena.addType(GenericType(childScope.get(), "X")); + + SubtypingResult usingGlobalScope = subtyping.isSubtype(builtinTypes->unknownType, genericX); + CHECK_MESSAGE(!usingGlobalScope.isSubtype, "Expected " << builtinTypes->unknownType << " unknownType, genericX); + CHECK_MESSAGE(usingChildScope.isSubtype, "Expected " << builtinTypes->unknownType << " <: " << genericX); + + Subtyping grandChildSubtyping{mkSubtyping(grandChildScope)}; + + SubtypingResult usingGrandChildScope = grandChildSubtyping.isSubtype(builtinTypes->unknownType, genericX); + CHECK_MESSAGE(usingGrandChildScope.isSubtype, "Expected " << builtinTypes->unknownType << " <: " << genericX); +} + /* * (A) -> A <: (X) -> X * A can be bound to X. diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 6a551811..c588eaba 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -1067,8 +1067,9 @@ local w = c and 1 CHECK("false | number" == toString(requireType("z"))); else CHECK("boolean | number" == toString(requireType("z"))); // 'false' widened to boolean + if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK("((false?) & a) | number" == toString(requireType("w"))); + CHECK("((false?) & unknown) | number" == toString(requireType("w"))); else CHECK("(boolean | number)?" == toString(requireType("w"))); } diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index c0dbfce8..c42830aa 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -308,7 +308,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_if_condition_position") else CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number", toString(requireTypeAtPosition({6, 26}))); - } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 315798c2..aea38253 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3788,4 +3788,40 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_shifted_tables") LUAU_REQUIRE_NO_ERRORS(result); } + +TEST_CASE_FIXTURE(Fixture, "cli_84607_missing_prop_in_array_or_dict") +{ + ScopedFastFlag sff{"LuauFixIndexerSubtypingOrdering", true}; + + CheckResult result = check(R"( + type Thing = { name: string, prop: boolean } + + local arrayOfThings : {Thing} = { + { name = "a" } + } + + local dictOfThings : {[string]: Thing} = { + a = { name = "a" } + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + TypeError& err1 = result.errors[0]; + MissingProperties* error1 = get(err1); + REQUIRE(error1); + REQUIRE(error1->properties.size() == 1); + + CHECK_EQ("prop", error1->properties[0]); + + TypeError& err2 = result.errors[1]; + TypeMismatch* mismatch = get(err2); + REQUIRE(mismatch); + MissingProperties* error2 = get(*mismatch->error); + REQUIRE(error2); + REQUIRE(error2->properties.size() == 1); + + CHECK_EQ("prop", error2->properties[0]); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index aeabf0ac..2d34fc7f 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1428,7 +1428,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "be_sure_to_use_active_txnlog_when_evaluating LUAU_REQUIRE_ERRORS(result); - for (const auto& e: result.errors) + for (const auto& e : result.errors) CHECK(5 == e.location.begin.line); } diff --git a/tests/Unifier2.test.cpp b/tests/Unifier2.test.cpp index 363c8109..2e6cf3b6 100644 --- a/tests/Unifier2.test.cpp +++ b/tests/Unifier2.test.cpp @@ -81,18 +81,12 @@ TEST_CASE_FIXTURE(Unifier2Fixture, "T <: U") TEST_CASE_FIXTURE(Unifier2Fixture, "(string) -> () <: (X) -> Y...") { - TypeId stringToUnit = arena.addType(FunctionType{ - arena.addTypePack({builtinTypes.stringType}), - arena.addTypePack({}) - }); + TypeId stringToUnit = arena.addType(FunctionType{arena.addTypePack({builtinTypes.stringType}), arena.addTypePack({})}); auto [x, xFree] = freshType(); TypePackId y = arena.freshTypePack(&scope); - TypeId xToY = arena.addType(FunctionType{ - arena.addTypePack({x}), - y - }); + TypeId xToY = arena.addType(FunctionType{arena.addTypePack({x}), y}); u2.unify(stringToUnit, xToY); @@ -105,4 +99,54 @@ TEST_CASE_FIXTURE(Unifier2Fixture, "(string) -> () <: (X) -> Y...") CHECK(!yPack->tail); } +TEST_CASE_FIXTURE(Unifier2Fixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type") +{ + auto [t1, ft1] = freshType(); + auto [t2, ft2] = freshType(); + + // t2 <: t1 <: unknown + // unknown <: t2 <: t1 + + ft1->lowerBound = t2; + ft2->upperBound = t1; + ft2->lowerBound = builtinTypes.unknownType; + + auto t2generalized = u2.generalize(NotNull{&scope}, t2); + REQUIRE(t2generalized); + + CHECK(follow(t1) == follow(t2)); + + auto t1generalized = u2.generalize(NotNull{&scope}, t1); + REQUIRE(t1generalized); + + CHECK(builtinTypes.unknownType == follow(t1)); + CHECK(builtinTypes.unknownType == follow(t2)); +} + +// Same as generalize_a_type_that_is_bounded_by_another_generalizable_type +// except that we generalize the types in the opposite order +TEST_CASE_FIXTURE(Unifier2Fixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type_in_reverse_order") +{ + auto [t1, ft1] = freshType(); + auto [t2, ft2] = freshType(); + + // t2 <: t1 <: unknown + // unknown <: t2 <: t1 + + ft1->lowerBound = t2; + ft2->upperBound = t1; + ft2->lowerBound = builtinTypes.unknownType; + + auto t1generalized = u2.generalize(NotNull{&scope}, t1); + REQUIRE(t1generalized); + + CHECK(follow(t1) == follow(t2)); + + auto t2generalized = u2.generalize(NotNull{&scope}, t2); + REQUIRE(t2generalized); + + CHECK(builtinTypes.unknownType == follow(t1)); + CHECK(builtinTypes.unknownType == follow(t2)); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index f179a63f..69790ba5 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -3,7 +3,6 @@ AnnotationTests.two_type_params AstQuery.last_argument_function_call_type AutocompleteTest.anonymous_autofilled_generic_on_argument_type_pack_vararg AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg -AutocompleteTest.autocomplete_if_else_regression AutocompleteTest.autocomplete_interpolated_string_as_singleton AutocompleteTest.autocomplete_oop_implicit_self AutocompleteTest.autocomplete_response_perf1 @@ -92,8 +91,6 @@ IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect Normalize.negations_of_tables Normalize.specific_functions_cannot_be_negated -ParserTests.parse_nesting_based_end_detection -ParserTests.parse_nesting_based_end_detection_single_line ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean @@ -124,13 +121,13 @@ RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.call_method -TableTests.call_method_with_explicit_self_argument TableTests.cannot_augment_sealed_table TableTests.cannot_change_type_of_unsealed_table_prop TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.casting_unsealed_tables_with_props_into_table_with_indexer TableTests.checked_prop_too_early +TableTests.cli_84607_missing_prop_in_array_or_dict TableTests.cyclic_shifted_tables TableTests.defining_a_method_for_a_local_sealed_table_must_fail TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail @@ -139,6 +136,7 @@ TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_extend_unsealed_tables_in_rvalue_position TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index TableTests.dont_leak_free_table_props +TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop TableTests.explicitly_typed_table @@ -154,19 +152,21 @@ TableTests.inequality_operators_imply_exactly_matching_types TableTests.infer_array_2 TableTests.infer_indexer_from_value_property_in_literal TableTests.infer_type_when_indexing_from_a_table_indexer -TableTests.inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table TableTests.inferred_return_type_of_free_table TableTests.instantiate_table_cloning_3 TableTests.leaking_bad_metatable_errors TableTests.less_exponential_blowup_please TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys +TableTests.ok_to_add_property_to_free_table TableTests.ok_to_provide_a_subtype_during_construction TableTests.okay_to_add_property_to_unsealed_tables_by_assignment -TableTests.okay_to_add_property_to_unsealed_tables_by_function_call -TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_indexer_works TableTests.oop_polymorphic +TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table +TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 +TableTests.pass_incompatible_union_to_a_generic_table_without_crashing +TableTests.passing_compatible_unions_to_a_generic_table_without_crashing TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table TableTests.quantifying_a_bound_var_works @@ -185,7 +185,6 @@ TableTests.table_unification_4 TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon -TableTests.used_dot_instead_of_colon_but_correctly TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type TableTests.wrong_assign_does_hit_indexer ToDot.function @@ -215,6 +214,7 @@ TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeFamilyTests.family_as_fn_arg TypeFamilyTests.table_internal_families TypeFamilyTests.unsolvable_family +TypeInfer.be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload TypeInfer.bidirectional_checking_of_higher_order_function TypeInfer.check_type_infer_recursion_count TypeInfer.cli_39932_use_unifier_in_ensure_methods @@ -226,21 +226,24 @@ TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.infer_locals_via_assignment_from_its_call_site TypeInfer.no_stack_overflow_from_isoptional -TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_cache_limit_normalizer TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.can_subscript_any +TypeInferAnyError.for_in_loop_iterator_is_any TypeInferAnyError.for_in_loop_iterator_is_any2 +TypeInferAnyError.for_in_loop_iterator_is_error +TypeInferAnyError.for_in_loop_iterator_is_error2 TypeInferAnyError.for_in_loop_iterator_returns_any +TypeInferAnyError.for_in_loop_iterator_returns_any2 TypeInferAnyError.intersection_of_any_can_have_props TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferAnyError.union_of_types_regression_test -TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.index_instance_property +TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_assert_when_the_tarjan_limit_is_exceeded_during_generalization TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site @@ -254,10 +257,10 @@ TypeInferFunctions.higher_order_function_2 TypeInferFunctions.higher_order_function_4 TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.infer_anonymous_function_arguments +TypeInferFunctions.infer_anonymous_function_arguments_outside_call TypeInferFunctions.infer_generic_function_function_argument TypeInferFunctions.infer_generic_function_function_argument_overloaded TypeInferFunctions.infer_generic_lib_function_function_argument -TypeInferFunctions.infer_anonymous_function_arguments_outside_call TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type @@ -269,21 +272,22 @@ TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function -TypeInferFunctions.toposort_doesnt_break_mutual_recursion -TypeInferFunctions.vararg_function_is_quantified TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.for_in_loop TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values +TypeInferLoops.for_in_loop_on_error TypeInferLoops.for_in_loop_with_custom_iterator TypeInferLoops.for_in_loop_with_incompatible_args_to_iterator TypeInferLoops.for_in_loop_with_next TypeInferLoops.ipairs_produces_integral_indices TypeInferLoops.iteration_regression_issue_69967_alt +TypeInferLoops.loop_iter_basic TypeInferLoops.loop_iter_metamethod_nil TypeInferLoops.loop_iter_metamethod_ok_with_inference TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.unreachable_code_after_infinite_loop +TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated @@ -294,7 +298,6 @@ TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted TypeInferOperators.and_binexps_dont_unify TypeInferOperators.cli_38355_recursive_union -TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.concat_op_on_string_lhs_and_free_rhs TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops TypeInferOperators.luau_polyfill_is_array From 309001020a784a9f18acc3136146b0ac4fd57b52 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Sat, 16 Sep 2023 12:21:09 +0200 Subject: [PATCH 2/3] Update benchmark.yml Update apt-get cache before installing valgrind as it looks like the default cache got out of date. --- .github/workflows/benchmark.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index c7531608..7a11fbe1 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -25,6 +25,7 @@ jobs: - name: Install valgrind run: | + sudo apt-get update sudo apt-get install valgrind - name: Build Luau (gcc) From d00e93c82c0ad156afa8cf0c100c2becfe83b808 Mon Sep 17 00:00:00 2001 From: Amber Grace <131925693+AmberGraceSoftware@users.noreply.github.com> Date: Thu, 21 Sep 2023 16:28:42 -0600 Subject: [PATCH 3/3] Support Control Flow type Refinements for "break" and "continue" statements (#1004) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes: https://github.com/Roblox/luau/issues/913 This PR adds support for type refinements around guard clauses that use `break` and `continue` statements inside a loop, similar to how guard clauses with `return` is supported. I had some free time today, so I figure I'd give a shot at a naïve fix for this at the very least. --- ## Resulting Change: Luau now supports type refinements within loops where a `continue` or `break` guard clause was used. For example: ```lua for _, object in objects :: {{value: string?}} do if not object.value then continue end local x: string = object.value -- OK; Used to emit "Type 'string?' could not be converted into 'string'" end ``` --------- Co-authored-by: Alexander McCord --- Analysis/include/Luau/ControlFlow.h | 4 +- Analysis/src/ConstraintGraphBuilder.cpp | 18 +- Analysis/src/TypeInfer.cpp | 18 +- tests/TypeInfer.cfa.test.cpp | 764 ++++++++++++++++++++++++ 4 files changed, 786 insertions(+), 18 deletions(-) diff --git a/Analysis/include/Luau/ControlFlow.h b/Analysis/include/Luau/ControlFlow.h index 566d77bd..82c0403c 100644 --- a/Analysis/include/Luau/ControlFlow.h +++ b/Analysis/include/Luau/ControlFlow.h @@ -14,8 +14,8 @@ enum class ControlFlow None = 0b00001, Returns = 0b00010, Throws = 0b00100, - Break = 0b01000, // Currently unused. - Continue = 0b10000, // Currently unused. + Breaks = 0b01000, + Continues = 0b10000, }; inline ControlFlow operator&(ControlFlow a, ControlFlow b) diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index ae143ca5..f9b0dbf8 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -24,6 +24,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauParseDeclareClassIndexer); +LUAU_FASTFLAG(LuauLoopControlFlowAnalysis); LUAU_FASTFLAG(LuauFloorDivision); namespace Luau @@ -537,11 +538,10 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) return visit(scope, s); else if (auto s = stat->as()) return visit(scope, s); - else if (stat->is() || stat->is()) - { - // Nothing - return ControlFlow::None; // TODO: ControlFlow::Break/Continue - } + else if (stat->is()) + return FFlag::LuauLoopControlFlowAnalysis ? ControlFlow::Breaks : ControlFlow::None; + else if (stat->is()) + return FFlag::LuauLoopControlFlowAnalysis ? ControlFlow::Continues : ControlFlow::None; else if (auto r = stat->as()) return visit(scope, r); else if (auto e = stat->as()) @@ -1072,12 +1072,14 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifSt if (ifStatement->elsebody) elsecf = visit(elseScope, ifStatement->elsebody); - if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && elsecf == ControlFlow::None) + if (thencf != ControlFlow::None && elsecf == ControlFlow::None) scope->inheritRefinements(elseScope); - else if (thencf == ControlFlow::None && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) scope->inheritRefinements(thenScope); - if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + if (FFlag::LuauLoopControlFlowAnalysis && thencf == elsecf) + return thencf; + else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) return ControlFlow::Returns; else return ControlFlow::None; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 61c90ba8..a29b1e06 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -38,6 +38,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) +LUAU_FASTFLAGVARIABLE(LuauLoopControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauVariadicOverloadFix, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAG(LuauParseDeclareClassIndexer) @@ -350,11 +351,10 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program) return check(scope, *while_); else if (auto repeat = program.as()) return check(scope, *repeat); - else if (program.is() || program.is()) - { - // Nothing to do - return ControlFlow::None; - } + else if (program.is()) + return FFlag::LuauLoopControlFlowAnalysis ? ControlFlow::Breaks : ControlFlow::None; + else if (program.is()) + return FFlag::LuauLoopControlFlowAnalysis ? ControlFlow::Continues : ControlFlow::None; else if (auto return_ = program.as()) return check(scope, *return_); else if (auto expr = program.as()) @@ -752,12 +752,14 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement if (statement.elsebody) elsecf = check(elseScope, *statement.elsebody); - if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && elsecf == ControlFlow::None) + if (thencf != ControlFlow::None && elsecf == ControlFlow::None) scope->inheritRefinements(elseScope); - else if (thencf == ControlFlow::None && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) scope->inheritRefinements(thenScope); - if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + if (FFlag::LuauLoopControlFlowAnalysis && thencf == elsecf) + return thencf; + else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) return ControlFlow::Returns; else return ControlFlow::None; diff --git a/tests/TypeInfer.cfa.test.cpp b/tests/TypeInfer.cfa.test.cpp index 04aeb54b..19700d2c 100644 --- a/tests/TypeInfer.cfa.test.cpp +++ b/tests/TypeInfer.cfa.test.cpp @@ -26,6 +26,52 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return") CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + if not record.value then + break + end + + local foo = record.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 34}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + if not record.value then + continue + end + + local foo = record.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 38}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_return") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; @@ -48,6 +94,118 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_return") CHECK_EQ("string", toString(requireTypeAtPosition({9, 24}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_break") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + elseif not recordY.value then + break + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_continue") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + elseif not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_break") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + return + elseif not recordY.value then + break + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_continue") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + elseif not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 38}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_rand_return_elif_not_y_return") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; @@ -72,6 +230,66 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_rand_return_elif_not_y_ CHECK_EQ("string", toString(requireTypeAtPosition({11, 24}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_rand_break_elif_not_y_break") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + elseif math.random() > 0.5 then + break + elseif not recordY.value then + break + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_rand_continue_elif_not_y_continue") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + elseif math.random() > 0.5 then + continue + elseif not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_rand_return_elif_not_y_fallthrough") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; @@ -96,6 +314,66 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_rand_return_elif_no CHECK_EQ("string?", toString(requireTypeAtPosition({11, 24}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_rand_break_elif_not_y_fallthrough") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + elseif math.random() > 0.5 then + break + elseif not recordY.value then + + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_rand_continue_elif_not_y_fallthrough") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + elseif math.random() > 0.5 then + continue + elseif not recordY.value then + + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({13, 38}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_fallthrough_elif_not_z_return") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; @@ -122,6 +400,138 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_fallthrough_elif_ CHECK_EQ("string?", toString(requireTypeAtPosition({12, 24}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_fallthrough_elif_not_z_break") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + local recordZ = y[i] + if not recordX.value then + break + elseif not recordY.value then + + elseif not recordZ.value then + break + end + + local foo = recordX.value + local bar = recordY.value + local baz = recordZ.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({14, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({15, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_fallthrough_elif_not_z_continue") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + local recordZ = y[i] + if not recordX.value then + continue + elseif not recordY.value then + + elseif not recordZ.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + local baz = recordZ.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({14, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({15, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + local recordZ = y[i] + if not recordX.value then + continue + elseif not recordY.value then + error("Y value not defined") + elseif not recordZ.value then + + end + + local foo = recordX.value + local bar = recordY.value + local baz = recordZ.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({14, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({15, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_fallthrough_elif_not_z_break") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + local recordZ = y[i] + if not recordX.value then + return + elseif not recordY.value then + + elseif not recordZ.value then + break + end + + local foo = recordX.value + local bar = recordY.value + local baz = recordZ.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({14, 38}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({15, 38}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "do_if_not_x_return") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; @@ -142,6 +552,56 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_if_not_x_return") CHECK_EQ("string", toString(requireTypeAtPosition({8, 24}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "for_record_do_if_not_x_break") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + do + if not record.value then + break + end + end + + local foo = record.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({9, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_record_do_if_not_x_continue") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + do + if not record.value then + continue + end + end + + local foo = record.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({9, 38}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_isnt_guaranteed_to_run_first") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; @@ -271,6 +731,126 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_if_not_y_return") CHECK_EQ("string", toString(requireTypeAtPosition({11, 24}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_if_not_y_break") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + end + + if not recordY.value then + break + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_if_not_y_continue") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + end + + if not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_if_not_y_throw") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + continue + end + + if not recordY.value then + error("Y value not defined") + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_if_not_y_continue") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}, y: {{value: string?}}) + for i, recordX in x do + local recordY = y[i] + if not recordX.value then + break + end + + if not recordY.value then + continue + end + + local foo = recordX.value + local bar = recordY.value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({12, 38}))); + CHECK_EQ("string", toString(requireTypeAtPosition({13, 38}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; @@ -294,6 +874,62 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out") CHECK_EQ("nil", toString(requireTypeAtPosition({8, 29}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out_breaking") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + if typeof(record.value) == "string" then + break + else + type Foo = number + end + + local foo: Foo = record.value + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Unknown type 'Foo'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 43}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out_continuing") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + if typeof(record.value) == "string" then + continue + else + type Foo = number + end + + local foo: Foo = record.value + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Unknown type 'Foo'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 43}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; @@ -320,6 +956,62 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_ CHECK_EQ("nil", toString(requireTypeAtPosition({8, 29}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope_breaking") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + type Foo = number + + if typeof(record.value) == "string" then + break + end + + local foo: Foo = record.value + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'nil' could not be converted into 'number'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 43}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope_continuing") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + local function f(x: {{value: string?}}) + for _, record in x do + type Foo = number + + if typeof(record.value) == "string" then + continue + end + + local foo: Foo = record.value + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'nil' could not be converted into 'number'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 43}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; @@ -355,6 +1047,78 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions") CHECK_EQ("Err", toString(requireTypeAtPosition({16, 19}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions_breaking") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", error: E } + type Result = Ok | Err + + local function process(results: {Result}) + for _, result in results do + if result.tag == "ok" then + local tag = result.tag + local val = result.value + + break + end + + local tag = result.tag + local err = result.error + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("\"ok\"", toString(requireTypeAtPosition({8, 39}))); + CHECK_EQ("T", toString(requireTypeAtPosition({9, 39}))); + + CHECK_EQ("\"err\"", toString(requireTypeAtPosition({14, 35}))); + CHECK_EQ("E", toString(requireTypeAtPosition({15, 35}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions_continuing") +{ + ScopedFastFlag flags[] = { + {"LuauTinyControlFlowAnalysis", true}, + {"LuauLoopControlFlowAnalysis", true} + }; + + CheckResult result = check(R"( + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", error: E } + type Result = Ok | Err + + local function process(results: {Result}) + for _, result in results do + if result.tag == "ok" then + local tag = result.tag + local val = result.value + + continue + end + + local tag = result.tag + local err = result.error + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("\"ok\"", toString(requireTypeAtPosition({8, 39}))); + CHECK_EQ("T", toString(requireTypeAtPosition({9, 39}))); + + CHECK_EQ("\"err\"", toString(requireTypeAtPosition({14, 35}))); + CHECK_EQ("E", toString(requireTypeAtPosition({15, 35}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "do_assert_x") { ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true};