// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Generalization.h" #include "Luau/Common.h" #include "Luau/DenseHash.h" #include "Luau/Scope.h" #include "Luau/Type.h" #include "Luau/ToString.h" #include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include "Luau/VisitType.h" LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound2) namespace Luau { struct MutatingGeneralizer : TypeOnceVisitor { NotNull arena; NotNull builtinTypes; NotNull scope; NotNull> cachedTypes; DenseHashMap positiveTypes; DenseHashMap negativeTypes; std::vector generics; std::vector genericPacks; bool isWithinFunction = false; MutatingGeneralizer( NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, DenseHashMap positiveTypes, DenseHashMap negativeTypes ) : TypeOnceVisitor(/* skipBoundTypes */ true) , arena(arena) , builtinTypes(builtinTypes) , scope(scope) , cachedTypes(cachedTypes) , positiveTypes(std::move(positiveTypes)) , negativeTypes(std::move(negativeTypes)) { } void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) { haystack = follow(haystack); if (seen.find(haystack)) return; seen.insert(haystack); if (UnionType* ut = getMutable(haystack)) { for (auto iter = ut->options.begin(); iter != ut->options.end();) { // FIXME: I bet this function has reentrancy problems TypeId option = follow(*iter); if (option == needle && get(replacement)) { iter = ut->options.erase(iter); continue; } if (option == needle) { *iter = replacement; iter++; continue; } // advance the iterator, nothing after this can use it. iter++; if (seen.find(option)) continue; seen.insert(option); if (get(option)) replace(seen, option, needle, haystack); else if (get(option)) replace(seen, option, needle, haystack); } if (ut->options.size() == 1) { TypeId onlyType = ut->options[0]; LUAU_ASSERT(onlyType != haystack); emplaceType(asMutable(haystack), onlyType); } else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && ut->options.empty()) { emplaceType(asMutable(haystack), builtinTypes->neverType); } return; } if (IntersectionType* it = getMutable(needle)) { for (auto iter = it->parts.begin(); iter != it->parts.end();) { // FIXME: I bet this function has reentrancy problems TypeId part = follow(*iter); if (part == needle && get(replacement)) { iter = it->parts.erase(iter); continue; } if (part == needle) { *iter = replacement; iter++; continue; } // advance the iterator, nothing after this can use it. iter++; if (seen.find(part)) continue; seen.insert(part); if (get(part)) replace(seen, part, needle, haystack); else if (get(part)) replace(seen, part, needle, haystack); } if (it->parts.size() == 1) { TypeId onlyType = it->parts[0]; LUAU_ASSERT(onlyType != needle); emplaceType(asMutable(needle), onlyType); } else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && it->parts.empty()) { emplaceType(asMutable(needle), builtinTypes->unknownType); } return; } } bool visit(TypeId ty, const FunctionType& ft) override { if (cachedTypes->contains(ty)) return false; const bool oldValue = isWithinFunction; isWithinFunction = true; traverse(ft.argTypes); traverse(ft.retTypes); isWithinFunction = oldValue; return false; } bool visit(TypeId ty, const FreeType&) override { LUAU_ASSERT(!cachedTypes->contains(ty)); const FreeType* ft = get(ty); LUAU_ASSERT(ft); traverse(ft->lowerBound); traverse(ft->upperBound); // It is possible for the above traverse() calls to cause ty to be // transmuted. We must reacquire ft if this happens. ty = follow(ty); ft = get(ty); if (!ft) return false; const size_t positiveCount = getCount(positiveTypes, ty); const size_t negativeCount = getCount(negativeTypes, ty); if (!positiveCount && !negativeCount) return false; const bool hasLowerBound = !get(follow(ft->lowerBound)); const bool hasUpperBound = !get(follow(ft->upperBound)); DenseHashSet seen{nullptr}; seen.insert(ty); if (!hasLowerBound && !hasUpperBound) { if (!isWithinFunction || (positiveCount + negativeCount == 1)) emplaceType(asMutable(ty), builtinTypes->unknownType); else { emplaceType(asMutable(ty), scope); generics.push_back(ty); } } // It is possible that this free type has other free types in its upper // or lower bounds. If this is the case, we must replace those // references with never (for the lower bound) or unknown (for the upper // bound). // // If we do not do this, we get tautological bounds like a <: a <: unknown. else if (positiveCount && !hasUpperBound) { TypeId lb = follow(ft->lowerBound); if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == ty) lowerFree->upperBound = builtinTypes->unknownType; else { DenseHashSet replaceSeen{nullptr}; replace(replaceSeen, lb, ty, builtinTypes->unknownType); } if (lb != ty) emplaceType(asMutable(ty), lb); else if (!isWithinFunction || (positiveCount + negativeCount == 1)) emplaceType(asMutable(ty), builtinTypes->unknownType); else { // if the lower bound is the type in question, we don't actually have a lower bound. emplaceType(asMutable(ty), scope); generics.push_back(ty); } } else { TypeId ub = follow(ft->upperBound); if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) upperFree->lowerBound = builtinTypes->neverType; else { DenseHashSet replaceSeen{nullptr}; replace(replaceSeen, ub, ty, builtinTypes->neverType); } if (ub != ty) emplaceType(asMutable(ty), ub); else if (!isWithinFunction || (positiveCount + negativeCount == 1)) emplaceType(asMutable(ty), builtinTypes->unknownType); else { // if the upper bound is the type in question, we don't actually have an upper bound. emplaceType(asMutable(ty), scope); generics.push_back(ty); } } return false; } size_t getCount(const DenseHashMap& map, const void* ty) { if (const size_t* count = map.find(ty)) return *count; else return 0; } bool visit(TypeId ty, const TableType&) override { if (cachedTypes->contains(ty)) return false; const size_t positiveCount = getCount(positiveTypes, ty); const size_t negativeCount = getCount(negativeTypes, ty); // FIXME: Free tables should probably just be replaced by upper bounds on free types. // // eg never <: 'a <: {x: number} & {z: boolean} if (!positiveCount && !negativeCount) return true; TableType* tt = getMutable(ty); LUAU_ASSERT(tt); tt->state = TableState::Sealed; return true; } bool visit(TypePackId tp, const FreeTypePack& ftp) override { if (!subsumes(scope, ftp.scope)) return true; tp = follow(tp); const size_t positiveCount = getCount(positiveTypes, tp); const size_t negativeCount = getCount(negativeTypes, tp); if (1 == positiveCount + negativeCount) emplaceTypePack(asMutable(tp), builtinTypes->unknownTypePack); else { emplaceTypePack(asMutable(tp), scope); genericPacks.push_back(tp); } return true; } }; struct FreeTypeSearcher : TypeVisitor { NotNull scope; NotNull> cachedTypes; explicit FreeTypeSearcher(NotNull scope, NotNull> cachedTypes) : TypeVisitor(/*skipBoundTypes*/ true) , scope(scope) , cachedTypes(cachedTypes) { } Polarity polarity = Polarity::Positive; void flip() { switch (polarity) { case Polarity::Positive: polarity = Polarity::Negative; break; case Polarity::Negative: polarity = Polarity::Positive; break; default: break; } } DenseHashSet seenPositive{nullptr}; DenseHashSet seenNegative{nullptr}; bool seenWithCurrentPolarity(const void* ty) { switch (polarity) { case Polarity::Positive: { if (seenPositive.contains(ty)) return true; seenPositive.insert(ty); return false; } case Polarity::Negative: { if (seenNegative.contains(ty)) return true; seenNegative.insert(ty); return false; } case Polarity::Mixed: { if (seenPositive.contains(ty) && seenNegative.contains(ty)) return true; seenPositive.insert(ty); seenNegative.insert(ty); return false; } default: LUAU_ASSERT(!"Unreachable"); } return false; } // The keys in these maps are either TypeIds or TypePackIds. It's safe to // mix them because we only use these pointers as unique keys. We never // indirect them. DenseHashMap negativeTypes{0}; DenseHashMap positiveTypes{0}; bool visit(TypeId ty) override { if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) return false; LUAU_ASSERT(ty); return true; } bool visit(TypeId ty, const FreeType& ft) override { if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) return false; if (!subsumes(scope, ft.scope)) return true; switch (polarity) { case Polarity::Positive: positiveTypes[ty]++; break; case Polarity::Negative: negativeTypes[ty]++; break; case Polarity::Mixed: positiveTypes[ty]++; negativeTypes[ty]++; break; default: LUAU_ASSERT(!"Unreachable"); } return true; } bool visit(TypeId ty, const TableType& tt) override { if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) return false; if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) { switch (polarity) { case Polarity::Positive: positiveTypes[ty]++; break; case Polarity::Negative: negativeTypes[ty]++; break; case Polarity::Mixed: positiveTypes[ty]++; negativeTypes[ty]++; break; default: LUAU_ASSERT(!"Unreachable"); } } for (const auto& [_name, prop] : tt.props) { if (prop.isReadOnly()) traverse(*prop.readTy); else { LUAU_ASSERT(prop.isShared() || FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete); Polarity p = polarity; polarity = Polarity::Mixed; traverse(prop.type()); polarity = p; } } if (tt.indexer) { traverse(tt.indexer->indexType); traverse(tt.indexer->indexResultType); } return false; } bool visit(TypeId ty, const FunctionType& ft) override { if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) return false; flip(); traverse(ft.argTypes); flip(); traverse(ft.retTypes); return false; } bool visit(TypeId, const ClassType&) override { return false; } bool visit(TypePackId tp, const FreeTypePack& ftp) override { if (seenWithCurrentPolarity(tp)) return false; if (!subsumes(scope, ftp.scope)) return true; switch (polarity) { case Polarity::Positive: positiveTypes[tp]++; break; case Polarity::Negative: negativeTypes[tp]++; break; case Polarity::Mixed: positiveTypes[tp]++; negativeTypes[tp]++; break; default: LUAU_ASSERT(!"Unreachable"); } return true; } }; // We keep a running set of types that will not change under generalization and // only have outgoing references to types that are the same. We use this to // short circuit generalization. It improves performance quite a lot. // // We do this by tracing through the type and searching for types that are // uncacheable. If a type has a reference to an uncacheable type, it is itself // uncacheable. // // If a type has no outbound references to uncacheable types, we add it to the // cache. struct TypeCacher : TypeOnceVisitor { NotNull> cachedTypes; DenseHashSet uncacheable{nullptr}; DenseHashSet uncacheablePacks{nullptr}; explicit TypeCacher(NotNull> cachedTypes) : TypeOnceVisitor(/* skipBoundTypes */ false) , cachedTypes(cachedTypes) { } void cache(TypeId ty) const { cachedTypes->insert(ty); } bool isCached(TypeId ty) const { return cachedTypes->contains(ty); } void markUncacheable(TypeId ty) { uncacheable.insert(ty); } void markUncacheable(TypePackId tp) { uncacheablePacks.insert(tp); } bool isUncacheable(TypeId ty) const { return uncacheable.contains(ty); } bool isUncacheable(TypePackId tp) const { return uncacheablePacks.contains(tp); } bool visit(TypeId ty) override { // NOTE: `TypeCacher` should explicitly visit _all_ types and type packs, // otherwise it's prone to marking types that cannot be cached as // cacheable. LUAU_ASSERT(false); LUAU_UNREACHABLE(); } bool visit(TypeId ty, const BoundType& btv) override { traverse(btv.boundTo); if (isUncacheable(btv.boundTo)) markUncacheable(ty); return false; } bool visit(TypeId ty, const FreeType& ft) override { // Free types are never cacheable. LUAU_ASSERT(!isCached(ty)); if (!isUncacheable(ty)) { traverse(ft.lowerBound); traverse(ft.upperBound); markUncacheable(ty); } return false; } bool visit(TypeId ty, const GenericType&) override { cache(ty); return false; } bool visit(TypeId ty, const ErrorType&) override { cache(ty); return false; } bool visit(TypeId ty, const PrimitiveType&) override { cache(ty); return false; } bool visit(TypeId ty, const SingletonType&) override { cache(ty); return false; } bool visit(TypeId ty, const BlockedType&) override { markUncacheable(ty); return false; } bool visit(TypeId ty, const PendingExpansionType&) override { markUncacheable(ty); return false; } bool visit(TypeId ty, const FunctionType& ft) override { if (isCached(ty) || isUncacheable(ty)) return false; traverse(ft.argTypes); traverse(ft.retTypes); for (TypeId gen : ft.generics) traverse(gen); bool uncacheable = false; if (isUncacheable(ft.argTypes)) uncacheable = true; else if (isUncacheable(ft.retTypes)) uncacheable = true; for (TypeId argTy : ft.argTypes) { if (isUncacheable(argTy)) { uncacheable = true; break; } } for (TypeId retTy : ft.retTypes) { if (isUncacheable(retTy)) { uncacheable = true; break; } } for (TypeId g : ft.generics) { if (isUncacheable(g)) { uncacheable = true; break; } } if (uncacheable) markUncacheable(ty); else cache(ty); return false; } bool visit(TypeId ty, const TableType& tt) override { if (isCached(ty) || isUncacheable(ty)) return false; if (tt.boundTo) { traverse(*tt.boundTo); if (isUncacheable(*tt.boundTo)) { markUncacheable(ty); return false; } } bool uncacheable = false; // This logic runs immediately after generalization, so any remaining // unsealed tables are assuredly not cacheable. They may yet have // properties added to them. if (tt.state == TableState::Free || tt.state == TableState::Unsealed) uncacheable = true; for (const auto& [_name, prop] : tt.props) { if (prop.readTy) { traverse(*prop.readTy); if (isUncacheable(*prop.readTy)) uncacheable = true; } if (prop.writeTy && prop.writeTy != prop.readTy) { traverse(*prop.writeTy); if (isUncacheable(*prop.writeTy)) uncacheable = true; } } if (tt.indexer) { traverse(tt.indexer->indexType); if (isUncacheable(tt.indexer->indexType)) uncacheable = true; traverse(tt.indexer->indexResultType); if (isUncacheable(tt.indexer->indexResultType)) uncacheable = true; } if (uncacheable) markUncacheable(ty); else cache(ty); return false; } bool visit(TypeId ty, const MetatableType& mtv) override { traverse(mtv.table); traverse(mtv.metatable); if (isUncacheable(mtv.table) || isUncacheable(mtv.metatable)) markUncacheable(ty); else cache(ty); return false; } bool visit(TypeId ty, const ClassType&) override { cache(ty); return false; } bool visit(TypeId ty, const AnyType&) override { cache(ty); return false; } bool visit(TypeId ty, const NoRefineType&) override { cache(ty); return false; } bool visit(TypeId ty, const UnionType& ut) override { if (isUncacheable(ty) || isCached(ty)) return false; bool uncacheable = false; for (TypeId partTy : ut.options) { traverse(partTy); uncacheable |= isUncacheable(partTy); } if (uncacheable) markUncacheable(ty); else cache(ty); return false; } bool visit(TypeId ty, const IntersectionType& it) override { if (isUncacheable(ty) || isCached(ty)) return false; bool uncacheable = false; for (TypeId partTy : it.parts) { traverse(partTy); uncacheable |= isUncacheable(partTy); } if (uncacheable) markUncacheable(ty); else cache(ty); return false; } bool visit(TypeId ty, const UnknownType&) override { cache(ty); return false; } bool visit(TypeId ty, const NeverType&) override { cache(ty); return false; } bool visit(TypeId ty, const NegationType& nt) override { if (!isCached(ty) && !isUncacheable(ty)) { traverse(nt.ty); if (isUncacheable(nt.ty)) markUncacheable(ty); else cache(ty); } return false; } bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) override { if (isCached(ty) || isUncacheable(ty)) return false; bool uncacheable = false; for (TypeId argTy : tfit.typeArguments) { traverse(argTy); if (isUncacheable(argTy)) uncacheable = true; } for (TypePackId argPack : tfit.packArguments) { traverse(argPack); if (isUncacheable(argPack)) uncacheable = true; } if (uncacheable) markUncacheable(ty); else cache(ty); return false; } bool visit(TypePackId tp) override { // NOTE: `TypeCacher` should explicitly visit _all_ types and type packs, // otherwise it's prone to marking types that cannot be cached as // cacheable, which will segfault down the line. LUAU_ASSERT(false); LUAU_UNREACHABLE(); } bool visit(TypePackId tp, const FreeTypePack&) override { markUncacheable(tp); return false; } bool visit(TypePackId tp, const GenericTypePack& gtp) override { return true; } bool visit(TypePackId tp, const ErrorTypePack& etp) override { return true; } bool visit(TypePackId tp, const VariadicTypePack& vtp) override { if (isUncacheable(tp)) return false; traverse(vtp.ty); if (isUncacheable(vtp.ty)) markUncacheable(tp); return false; } bool visit(TypePackId tp, const BlockedTypePack&) override { markUncacheable(tp); return false; } bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override { markUncacheable(tp); return false; } bool visit(TypePackId tp, const BoundTypePack& btp) override { traverse(btp.boundTo); if (isUncacheable(btp.boundTo)) markUncacheable(tp); return false; } bool visit(TypePackId tp, const TypePack& typ) override { bool uncacheable = false; for (TypeId ty : typ.head) { traverse(ty); uncacheable |= isUncacheable(ty); } if (typ.tail) { traverse(*typ.tail); uncacheable |= isUncacheable(*typ.tail); } if (uncacheable) markUncacheable(tp); return false; } }; std::optional generalize( NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, TypeId ty ) { ty = follow(ty); if (ty->owningArena != arena || ty->persistent) return ty; FreeTypeSearcher fts{scope, cachedTypes}; fts.traverse(ty); MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; gen.traverse(ty); /* MutatingGeneralizer mutates types in place, so it is possible that ty has * been transmuted to a BoundType. We must follow it again and verify that * we are allowed to mutate it before we attach generics to it. */ ty = follow(ty); if (ty->owningArena != arena || ty->persistent) return ty; TypeCacher cacher{cachedTypes}; cacher.traverse(ty); FunctionType* ftv = getMutable(ty); if (ftv) { // If we're generalizing a function type, add any of the newly inferred // generics to the list of existing generic types. for (const auto g : std::move(gen.generics)) { ftv->generics.push_back(g); } // Ditto for generic packs. for (const auto gp : std::move(gen.genericPacks)) { ftv->genericPacks.push_back(gp); } } return ty; } } // namespace Luau