diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index eb1b1fed..1d3f20ee 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -13,6 +13,7 @@ #include "Luau/Type.h" #include "Luau/TypeUtils.h" #include "Luau/Variant.h" +#include "Normalize.h" #include #include @@ -86,6 +87,8 @@ struct ConstraintGraphBuilder // It is pretty uncommon for constraint generation to itself produce errors, but it can happen. std::vector errors; + // Needed to be able to enable error-suppression preservation for immediate refinements. + NotNull normalizer; // Needed to resolve modules to make 'require' import types properly. NotNull moduleResolver; // Occasionally constraint generation needs to produce an ICE. @@ -98,7 +101,7 @@ struct ConstraintGraphBuilder DcrLogger* logger; - ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull builtinTypes, + ConstraintGraphBuilder(ModulePtr module, NotNull normalizer, NotNull moduleResolver, NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, std::function prepareModuleScope, DcrLogger* logger, NotNull dfg, std::vector requireCycles); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index cba2cbb4..47effcea 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -204,7 +204,7 @@ struct ConstraintSolver * @param subType the sub-type to unify. * @param superType the super-type to unify. */ - ErrorVec unify(TypeId subType, TypeId superType, NotNull scope); + ErrorVec unify(NotNull scope, Location location, TypeId subType, TypeId superType); /** * Creates a new Unifier and performs a single unification operation. Commits @@ -212,7 +212,7 @@ struct ConstraintSolver * @param subPack the sub-type pack to unify. * @param superPack the super-type pack to unify. */ - ErrorVec unify(TypePackId subPack, TypePackId superPack, NotNull scope); + ErrorVec unify(NotNull scope, Location location, TypePackId subPack, TypePackId superPack); /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. diff --git a/Analysis/include/Luau/Differ.h b/Analysis/include/Luau/Differ.h index 60f555dc..e9656ad4 100644 --- a/Analysis/include/Luau/Differ.h +++ b/Analysis/include/Luau/Differ.h @@ -3,6 +3,7 @@ #include "Luau/DenseHash.h" #include "Luau/Type.h" +#include "Luau/UnifierSharedState.h" #include #include #include @@ -151,8 +152,31 @@ struct DifferEnvironment { TypeId rootLeft; TypeId rootRight; - DenseHashMap genericMatchedPairs; + DenseHashMap genericTpMatchedPairs; + + DifferEnvironment(TypeId rootLeft, TypeId rootRight) + : rootLeft(rootLeft) + , rootRight(rootRight) + , genericMatchedPairs(nullptr) + , genericTpMatchedPairs(nullptr) + { + } + + bool isProvenEqual(TypeId left, TypeId right) const; + bool isAssumedEqual(TypeId left, TypeId right) const; + void recordProvenEqual(TypeId left, TypeId right); + void pushVisiting(TypeId left, TypeId right); + void popVisiting(); + std::vector>::const_reverse_iterator visitingBegin() const; + std::vector>::const_reverse_iterator visitingEnd() const; + +private: + // TODO: consider using DenseHashSet + std::unordered_set, TypeIdPairHash> provenEqual; + // Ancestors of current types + std::unordered_set, TypeIdPairHash> visiting; + std::vector> visitingStack; }; DifferResult diff(TypeId ty1, TypeId ty2); diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h index c916f953..642f2b9e 100644 --- a/Analysis/include/Luau/Instantiation.h +++ b/Analysis/include/Luau/Instantiation.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/NotNull.h" #include "Luau/Substitution.h" #include "Luau/Type.h" #include "Luau/Unifiable.h" @@ -8,15 +9,18 @@ namespace Luau { -struct TypeArena; +struct BuiltinTypes; struct TxnLog; +struct TypeArena; +struct TypeCheckLimits; // A substitution which replaces generic types in a given set by free types. struct ReplaceGenerics : Substitution { - ReplaceGenerics(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope, const std::vector& generics, + ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope, const std::vector& generics, const std::vector& genericPacks) : Substitution(log, arena) + , builtinTypes(builtinTypes) , level(level) , scope(scope) , generics(generics) @@ -24,6 +28,8 @@ struct ReplaceGenerics : Substitution { } + NotNull builtinTypes; + TypeLevel level; Scope* scope; std::vector generics; @@ -38,13 +44,16 @@ struct ReplaceGenerics : Substitution // A substitution which replaces generic functions by monomorphic functions struct Instantiation : Substitution { - Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope) + Instantiation(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope) : Substitution(log, arena) + , builtinTypes(builtinTypes) , level(level) , scope(scope) { } + NotNull builtinTypes; + TypeLevel level; Scope* scope; bool ignoreChildren(TypeId ty) override; @@ -54,4 +63,20 @@ struct Instantiation : Substitution TypePackId clean(TypePackId tp) override; }; +/** Attempt to instantiate a type. Only used under local type inference. + * + * When given a generic function type, instantiate() will return a copy with the + * generics replaced by fresh types. Instantiation will return the same TypeId + * back if the function does not have any generics. + * + * All higher order generics are left as-is. For example, instantiation of + * ((Y) -> (X, Y)) -> (X, Y) is ((Y) -> ('x, Y)) -> ('x, Y) + * + * We substitute the generic X for the free 'x, but leave the generic Y alone. + * + * Instantiation fails only when processing the type causes internal recursion + * limits to be exceeded. + */ +std::optional instantiate(NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty); + } // namespace Luau diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index e9420922..c152fc02 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -77,13 +77,15 @@ using TypeId = const Type*; using Name = std::string; -// A free type var is one whose exact shape has yet to be fully determined. +// A free type is one whose exact shape has yet to be fully determined. struct FreeType { explicit FreeType(TypeLevel level); explicit FreeType(Scope* scope); FreeType(Scope* scope, TypeLevel level); + FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); + int index; TypeLevel level; Scope* scope = nullptr; @@ -92,6 +94,10 @@ struct FreeType // recursive type alias whose definitions haven't been // resolved yet. bool forwardedTypeAlias = false; + + // Only used under local type inference + TypeId lowerBound = nullptr; + TypeId upperBound = nullptr; }; struct GenericType @@ -994,6 +1000,8 @@ private: } }; +TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope); + using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 11d2aff9..db81d9cf 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -9,10 +9,12 @@ namespace Luau { -struct DcrLogger; struct BuiltinTypes; +struct DcrLogger; +struct TypeCheckLimits; +struct UnifierSharedState; -void check(NotNull builtinTypes, NotNull sharedState, DcrLogger* logger, const SourceModule& sourceModule, +void check(NotNull builtinTypes, NotNull sharedState, NotNull limits, DcrLogger* logger, const SourceModule& sourceModule, Module* module); } // namespace Luau diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h new file mode 100644 index 00000000..cf769da3 --- /dev/null +++ b/Analysis/include/Luau/Unifier2.h @@ -0,0 +1,75 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/NotNull.h" + +#include +#include +#include + +namespace Luau +{ + +using TypeId = const struct Type*; +using TypePackId = const struct TypePackVar*; + +struct BuiltinTypes; +struct InternalErrorReporter; +struct Scope; +struct TypeArena; + +enum class OccursCheckResult +{ + Pass, + Fail +}; + +struct Unifier2 +{ + NotNull arena; + NotNull builtinTypes; + NotNull ice; + + int recursionCount = 0; + int recursionLimit = 0; + + Unifier2(NotNull arena, NotNull builtinTypes, NotNull ice); + + /** Attempt to commit the subtype relation subTy <: superTy to the type + * graph. + * + * @returns true if successful. + * + * Note that incoherent types can and will successfully be unified. We stop + * when we *cannot know* how to relate the provided types, not when doing so + * would narrow something down to never or broaden it to unknown. + * + * Presently, the only way unification can fail is if we attempt to bind one + * free TypePack to another and encounter an occurs check violation. + */ + bool unify(TypeId subTy, TypeId superTy); + + // TODO think about this one carefully. We don't do unions or intersections of type packs + bool unify(TypePackId subTp, TypePackId superTp); + + std::optional generalize(NotNull scope, TypeId ty); +private: + + /** + * @returns simplify(left | right) + */ + TypeId mkUnion(TypeId left, TypeId right); + + /** + * @returns simplify(left & right) + */ + TypeId mkIntersection(TypeId left, TypeId right); + + // Returns true if needle occurs within haystack already. ie if we bound + // needle to haystack, would a cyclic TypePack result? + OccursCheckResult occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); +}; + +} diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index 016c51f6..95fdfac4 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -1,13 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Common.h" +#include #include #include -#include -#include #include +#include + namespace Luau { diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index cd3f4c6e..eaf47b77 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,7 +14,7 @@ LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTFLAGVARIABLE(LuauDisableCompletionOutsideQuotes, false) -LUAU_FASTFLAGVARIABLE(LuauAnonymousAutofilled, false); +LUAU_FASTFLAGVARIABLE(LuauAnonymousAutofilled1, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteLastTypecheck, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteHideSelfArg, false) @@ -618,7 +618,7 @@ std::optional getLocalTypeInScopeAt(const Module& module, Position posit template static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) { - LUAU_ASSERT(FFlag::LuauAnonymousAutofilled); + LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1); ToStringOptions opts; opts.useLineBreaks = false; opts.hideTableKind = true; @@ -637,7 +637,7 @@ static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool if (!canSuggestInferredType(scope, ty)) return std::nullopt; - if (FFlag::LuauAnonymousAutofilled) + if (FFlag::LuauAnonymousAutofilled1) { return tryToStringDetailed(scope, ty, functionTypeArguments); } @@ -1419,7 +1419,7 @@ static AutocompleteResult autocompleteWhileLoopKeywords(std::vector an static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) { - LUAU_ASSERT(FFlag::LuauAnonymousAutofilled); + LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1); std::string result = "function("; auto [args, tail] = Luau::flatten(funcTy.argTypes); @@ -1485,7 +1485,7 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func static std::optional makeAnonymousAutofilled(const ModulePtr& module, Position position, const AstNode* node, const std::vector& ancestry) { - LUAU_ASSERT(FFlag::LuauAnonymousAutofilled); + LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1); const AstExprCall* call = node->as(); if (!call && ancestry.size() > 1) call = ancestry[ancestry.size() - 2]->as(); @@ -1803,7 +1803,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (node->asExpr()) { - if (FFlag::LuauAnonymousAutofilled) + if (FFlag::LuauAnonymousAutofilled1) { AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 197aad7a..9080b1fc 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -9,6 +9,7 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAGVARIABLE(LuauCloneCyclicUnions, false) @@ -204,7 +205,14 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const FreeType& t) { - defaultClone(t); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + FreeType ft{t.scope, clone(t.lowerBound, dest, cloneState), clone(t.upperBound, dest, cloneState)}; + TypeId res = dest.addType(ft); + seenTypes[typeId] = res; + } + else + defaultClone(t); } void TypeCloner::operator()(const GenericType& t) @@ -363,7 +371,10 @@ void TypeCloner::operator()(const UnionType& t) { if (FFlag::LuauCloneCyclicUnions) { - TypeId result = dest.addType(FreeType{nullptr}); + // We're just using this FreeType as a placeholder until we've finished + // cloning the parts of this union so it is okay that its bounds are + // nullptr. We'll never indirect them. + TypeId result = dest.addType(FreeType{nullptr, /*lowerBound*/nullptr, /*upperBound*/nullptr}); seenTypes[typeId] = result; std::vector options; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 9c2766ec..7d35ebc6 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -21,6 +21,7 @@ #include LUAU_FASTINT(LuauCheckRecursionLimit); +LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauParseDeclareClassIndexer); @@ -137,15 +138,16 @@ void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const Con } // namespace -ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, +ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, NotNull normalizer, NotNull moduleResolver, NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, std::function prepareModuleScope, DcrLogger* logger, NotNull dfg, std::vector requireCycles) : module(module) , builtinTypes(builtinTypes) - , arena(arena) + , arena(normalizer->arena) , rootScope(nullptr) , dfg(dfg) + , normalizer(normalizer) , moduleResolver(moduleResolver) , ice(ice) , globalScope(globalScope) @@ -158,7 +160,7 @@ ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* aren TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope) { - return arena->addType(FreeType{scope.get()}); + return Luau::freshType(arena, builtinTypes, scope.get()); } TypePackId ConstraintGraphBuilder::freshTypePack(const ScopePtr& scope) @@ -414,7 +416,22 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo ty = r; } else - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + { + switch (shouldSuppressErrors(normalizer, ty)) + { + case ErrorSuppression::DoNotSuppress: + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + break; + case ErrorSuppression::Suppress: + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + ty = simplifyUnion(builtinTypes, arena, ty, builtinTypes->errorType).result; + break; + case ErrorSuppression::NormalizationFailed: + reportError(location, NormalizationTooComplex{}); + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + break; + } + } } scope->dcrRefinements[def] = ty; @@ -777,7 +794,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* f } // It is always ok to provide too few variables, so we give this pack a free tail. - TypePackId variablePack = arena->addTypePack(std::move(variableTypes), arena->addTypePack(FreeTypePack{loopScope.get()})); + TypePackId variablePack = arena->addTypePack(std::move(variableTypes), freshTypePack(loopScope)); addConstraint( loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astForInNextTypes}); @@ -982,6 +999,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* b return flow; } +// TODO Clip? static void bindFreeType(TypeId a, TypeId b) { FreeType* af = getMutable(a); @@ -1488,7 +1506,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa if (selfTy) args.push_back(*selfTy); else - args.push_back(arena->freshType(scope.get())); + args.push_back(freshType(scope)); } else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) { @@ -2148,6 +2166,8 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) { + const bool expectedTypeIsFree = expectedType && get(follow(*expectedType)); + TypeId ty = arena->addType(TableType{}); TableType* ttv = getMutable(ty); LUAU_ASSERT(ttv); @@ -2192,7 +2212,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp if (item.kind == AstExprTable::Item::Kind::General || item.kind == AstExprTable::Item::Kind::List) isIndexedResultType = true; - if (item.key && expectedType) + if (item.key && expectedType && !expectedTypeIsFree) { if (auto stringKey = item.key->as()) { diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index fbe08162..e1291eeb 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -18,6 +18,7 @@ #include "Luau/TypeFamily.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier.h" +#include "Luau/Unifier2.h" #include "Luau/VisitType.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); @@ -441,6 +442,17 @@ void ConstraintSolver::finalizeModule() { rootScope->returnType = anyifyModuleReturnTypePackGenerics(*returnType); } + + Unifier2 u2{NotNull{arena}, builtinTypes, NotNull{&iceReporter}}; + + for (auto& [name, binding] : rootScope->bindings) + { + auto generalizedTy = u2.generalize(rootScope, binding.typeId); + if (generalizedTy) + binding.typeId = *generalizedTy; + else + reportError(CodeTooComplex{}, binding.location); + } } bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) @@ -526,19 +538,28 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull(generalizedType)) return block(generalizedType, constraint); - std::optional generalized = quantify(arena, c.sourceType, constraint->scope); + std::optional generalized; + + Unifier2 u2{NotNull{arena}, builtinTypes, NotNull{&iceReporter}}; + + std::optional generalizedTy = u2.generalize(constraint->scope, c.sourceType); + if (generalizedTy) + generalized = QuantifierResult{*generalizedTy}; // FIXME insertedGenerics and insertedGenericPacks + else + reportError(CodeTooComplex{}, constraint->location); + if (generalized) { if (get(generalizedType)) asMutable(generalizedType)->ty.emplace(generalized->result); else - unify(generalizedType, generalized->result, constraint->scope); + unify(constraint->scope, constraint->location, generalizedType, generalized->result); for (auto [free, gen] : generalized->insertedGenerics.pairings) - unify(free, gen, constraint->scope); + unify(constraint->scope, constraint->location, free, gen); for (auto [free, gen] : generalized->insertedGenericPacks.pairings) - unify(free, gen, constraint->scope); + unify(constraint->scope, constraint->location, free, gen); } else { @@ -560,12 +581,8 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNullscope); - - if (limits.instantiationChildLimit) - inst.childLimit = *limits.instantiationChildLimit; - - std::optional instantiated = inst.substitute(c.superType); + // TODO childLimit + std::optional instantiated = instantiate(builtinTypes, NotNull{arena}, NotNull{&limits}, constraint->scope, c.superType); LUAU_ASSERT(get(c.subType)); @@ -634,7 +651,8 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNulladdTypePack(TypePack{{operandType}, {}}); TypePackId retPack = arena->addTypePack(BlockedTypePack{}); - asMutable(c.resultType)->ty.emplace(constraint->scope); + TypeId res = freshType(arena, builtinTypes, constraint->scope); + asMutable(c.resultType)->ty.emplace(res); pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{retPack, arena->addTypePack(TypePack{{c.resultType}})}); @@ -722,12 +740,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope}; - - if (limits.instantiationChildLimit) - instantiation.childLimit = *limits.instantiationChildLimit; - - std::optional instantiatedMm = instantiation.substitute(*mm); + std::optional instantiatedMm = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, *mm); if (!instantiatedMm) { reportError(CodeTooComplex{}, constraint->location); @@ -750,7 +763,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNulladdTypePack({leftType, rightType}); } - unify(inferredArgs, ftv->argTypes, constraint->scope); + unify(constraint->scope, constraint->location, inferredArgs, ftv->argTypes); TypeId mmResult; @@ -802,14 +815,14 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullisExactlyNumber() || get(normLeftTy->tops))) { - unify(leftType, rightType, constraint->scope); + unify(constraint->scope, constraint->location, leftType, rightType); asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); unblock(resultType, constraint->location); return true; } else if (get(leftType) || get(rightType)) { - unify(leftType, rightType, constraint->scope); + unify(constraint->scope, constraint->location, leftType, rightType); asMutable(resultType)->ty.emplace(builtinTypes->neverType); unblock(resultType, constraint->location); return true; @@ -826,14 +839,14 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullnormalize(leftType); if (leftNormTy && leftNormTy->isSubtypeOfString()) { - unify(leftType, rightType, constraint->scope); + unify(constraint->scope, constraint->location, leftType, rightType); asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); unblock(resultType, constraint->location); return true; } else if (get(leftType) || get(rightType)) { - unify(leftType, rightType, constraint->scope); + unify(constraint->scope, constraint->location, leftType, rightType); asMutable(resultType)->ty.emplace(builtinTypes->neverType); unblock(resultType, constraint->location); return true; @@ -909,8 +922,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullscope); - unify(rightType, errorRecoveryType(), constraint->scope); + unify(constraint->scope, constraint->location, leftType, errorRecoveryType()); + unify(constraint->scope, constraint->location, rightType, errorRecoveryType()); asMutable(resultType)->ty.emplace(errorRecoveryType()); unblock(resultType, constraint->location); @@ -979,7 +992,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope, builtinTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; std::optional anyified = anyify.substitute(c.variables); LUAU_ASSERT(anyified); - unify(*anyified, c.variables, constraint->scope); + unify(constraint->scope, constraint->location, *anyified, c.variables); return true; } @@ -1330,11 +1343,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulltypeFromNormal(*normFn); std::vector overloads = flattenIntersection(normFnTy); - Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); - - if (limits.instantiationChildLimit) - inst.childLimit = *limits.instantiationChildLimit; - std::vector arityMatchingOverloads; std::optional bestOverloadLog; @@ -1342,7 +1350,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull instantiated = inst.substitute(overload); + std::optional instantiated = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, overload); if (!instantiated.has_value()) { @@ -1451,7 +1459,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullty.emplace(TableState::Free, TypeLevel{}, constraint->scope); ttv.props[c.prop] = Property{c.resultType}; - asMutable(c.resultType)->ty.emplace(constraint->scope); + TypeId res = freshType(arena, builtinTypes, constraint->scope); + asMutable(c.resultType)->ty.emplace(res); unblock(c.resultType, constraint->location); return true; } @@ -1579,7 +1588,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullscope); + unify(constraint->scope, constraint->location, c.propType, *existingPropType); bind(c.resultType, c.subjectType); unblock(c.resultType, constraint->location); return true; @@ -1590,7 +1599,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) { - TypeId ty = arena->freshType(constraint->scope); + TypeId ty = freshType(arena, builtinTypes, constraint->scope); // Mint a chain of free tables per c.path for (auto it = rbegin(c.path); it != rend(c.path); ++it) @@ -1661,7 +1670,8 @@ bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNullindexer = TableIndexer{c.indexType, c.propType}; asMutable(c.resultType)->ty.emplace(subjectType); - asMutable(c.propType)->ty.emplace(scope); + TypeId propType = freshType(arena, builtinTypes, scope); + asMutable(c.propType)->ty.emplace(propType); unblock(c.propType, constraint->location); unblock(c.resultType, constraint->location); @@ -1672,7 +1682,7 @@ bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNullindexer) { // TODO This probably has to be invariant. - unify(c.indexType, tt->indexer->indexType, constraint->scope); + unify(constraint->scope, constraint->location, c.indexType, tt->indexer->indexType); asMutable(c.propType)->ty.emplace(tt->indexer->indexResultType); asMutable(c.resultType)->ty.emplace(subjectType); unblock(c.propType, constraint->location); @@ -1681,12 +1691,13 @@ bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNullstate == TableState::Free || tt->state == TableState::Unsealed) { - TypeId promotedIndexTy = arena->freshType(tt->scope); - unify(c.indexType, promotedIndexTy, constraint->scope); + TypeId promotedIndexTy = freshType(arena, builtinTypes, tt->scope); + unify(constraint->scope, constraint->location, c.indexType, promotedIndexTy); auto mtt = getMutable(subjectType); mtt->indexer = TableIndexer{promotedIndexTy, c.propType}; - asMutable(c.propType)->ty.emplace(tt->scope); + TypeId propType = freshType(arena, builtinTypes, tt->scope); + asMutable(c.propType)->ty.emplace(propType); asMutable(c.resultType)->ty.emplace(subjectType); unblock(c.propType, constraint->location); unblock(c.resultType, constraint->location); @@ -1754,14 +1765,15 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNullty.emplace(constraint->scope); + TypeId f = freshType(arena, builtinTypes, constraint->scope); + asMutable(*destIter)->ty.emplace(f); } else asMutable(*destIter)->ty.emplace(srcTy); unblock(*destIter, constraint->location); } else - unify(*destIter, srcTy, constraint->scope); + unify(constraint->scope, constraint->location, *destIter, srcTy); ++destIter; ++i; @@ -1889,7 +1901,10 @@ bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNull(follow(c.discriminant))) - asMutable(c.resultType)->ty.emplace(constraint->scope); + { + TypeId f = freshType(arena, builtinTypes, constraint->scope); + asMutable(c.resultType)->ty.emplace(f); + } else asMutable(c.resultType)->ty.emplace(c.discriminant); @@ -1999,7 +2014,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (!anyified) reportError(CodeTooComplex{}, constraint->location); else - unify(*anyified, ty, constraint->scope); + unify(constraint->scope, constraint->location, *anyified, ty); }; auto unknownify = [&](auto ty) { @@ -2008,7 +2023,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (!anyified) reportError(CodeTooComplex{}, constraint->location); else - unify(*anyified, ty, constraint->scope); + unify(constraint->scope, constraint->location, *anyified, ty); }; auto errorify = [&](auto ty) { @@ -2017,7 +2032,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (!errorified) reportError(CodeTooComplex{}, constraint->location); else - unify(*errorified, ty, constraint->scope); + unify(constraint->scope, constraint->location, *errorified, ty); }; auto neverify = [&](auto ty) { @@ -2026,7 +2041,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (!neverified) reportError(CodeTooComplex{}, constraint->location); else - unify(*neverified, ty, constraint->scope); + unify(constraint->scope, constraint->location, *neverified, ty); }; if (get(iteratorTy)) @@ -2065,7 +2080,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (iteratorTable->indexer) { TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); - unify(c.variables, expectedVariablePack, constraint->scope); + unify(constraint->scope, constraint->location, c.variables, expectedVariablePack); } else errorify(c.variables); @@ -2077,17 +2092,12 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl return block(*iterFn, constraint); } - Instantiation instantiation(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); - - if (limits.instantiationChildLimit) - instantiation.childLimit = *limits.instantiationChildLimit; - - if (std::optional instantiatedIterFn = instantiation.substitute(*iterFn)) + if (std::optional instantiatedIterFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, *iterFn)) { if (auto iterFtv = get(*instantiatedIterFn)) { TypePackId expectedIterArgs = arena->addTypePack({iteratorTy}); - unify(iterFtv->argTypes, expectedIterArgs, constraint->scope); + unify(constraint->scope, constraint->location, iterFtv->argTypes, expectedIterArgs); TypePack iterRets = extendTypePack(*arena, builtinTypes, iterFtv->retTypes, 2); @@ -2099,11 +2109,11 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl } TypeId nextFn = iterRets.head[0]; - TypeId table = iterRets.head.size() == 2 ? iterRets.head[1] : arena->freshType(constraint->scope); + TypeId table = iterRets.head.size() == 2 ? iterRets.head[1] : freshType(arena, builtinTypes, constraint->scope); - if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) + if (std::optional instantiatedNextFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, nextFn)) { - const TypeId firstIndex = arena->freshType(constraint->scope); + const TypeId firstIndex = freshType(arena, builtinTypes, constraint->scope); // nextTy : (iteratorTy, indexTy?) -> (indexTy, valueTailTy...) const TypePackId nextArgPack = arena->addTypePack({table, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); @@ -2111,7 +2121,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); const TypeId expectedNextTy = arena->addType(FunctionType{nextArgPack, nextRetPack}); - unify(*instantiatedNextFn, expectedNextTy, constraint->scope); + unify(constraint->scope, constraint->location, *instantiatedNextFn, expectedNextTy); pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); } @@ -2165,7 +2175,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( TypeId retIndex; if (isNil(firstIndexTy) || isOptional(firstIndexTy)) { - firstIndex = arena->addType(UnionType{{arena->freshType(constraint->scope), builtinTypes->nilType}}); + firstIndex = arena->addType(UnionType{{freshType(arena, builtinTypes, constraint->scope), builtinTypes->nilType}}); retIndex = firstIndex; } else @@ -2180,7 +2190,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( const TypePackId nextRetPack = arena->addTypePack(TypePack{{retIndex}, valueTailTy}); const TypeId expectedNextTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); - ErrorVec errors = unify(nextTy, expectedNextTy, constraint->scope); + ErrorVec errors = unify(constraint->scope, constraint->location, nextTy, expectedNextTy); // if there are no errors from unifying the two, we can pass forward the expected type as our selected resolution. if (errors.empty()) @@ -2241,7 +2251,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, ttv->indexer->indexResultType}; else if (ttv->state == TableState::Free) { - TypeId result = arena->freshType(ttv->scope); + TypeId result = freshType(arena, builtinTypes, ttv->scope); ttv->props[propName] = Property{result}; return {{}, result}; } @@ -2309,7 +2319,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa TableType* tt = &asMutable(subjectType)->ty.emplace(); tt->state = TableState::Free; tt->scope = scope; - TypeId propType = arena->freshType(scope); + TypeId propType = freshType(arena, builtinTypes, scope); tt->props[propName] = Property{propType}; return {{}, propType}; @@ -2376,49 +2386,22 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, std::nullopt}; } -static TypeId getErrorType(NotNull builtinTypes, TypeId) -{ - return builtinTypes->errorRecoveryType(); -} - -static TypePackId getErrorType(NotNull builtinTypes, TypePackId) -{ - return builtinTypes->errorRecoveryTypePack(); -} - template bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, TID superTy) { - Unifier u{normalizer, constraint->scope, constraint->location, Covariant}; - u.enableNewSolver(); + Unifier2 u2{NotNull{arena}, builtinTypes, NotNull{&iceReporter}}; - u.tryUnify(subTy, superTy); + bool success = u2.unify(subTy, superTy); - if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) + if (!success) { - for (TypeId bt : u.blockedTypes) - block(bt, constraint); - for (TypePackId btp : u.blockedTypePacks) - block(btp, constraint); - return false; + // Unification only fails when doing so would fail the occurs check. + // ie create a self-bound type or a cyclic type pack + reportError(OccursCheckFailed{}, constraint->location); } - if (const auto& e = hasUnificationTooComplex(u.errors)) - reportError(*e); - - if (!u.errors.empty()) - { - TID errorType = getErrorType(builtinTypes, TID{}); - u.tryUnify(subTy, errorType); - u.tryUnify(superTy, errorType); - } - - const auto [changedTypes, changedPacks] = u.log.getChanges(); - - u.log.commit(); - - unblock(changedTypes, constraint->location); - unblock(changedPacks, constraint->location); + unblock(subTy, constraint->location); + unblock(superTy, constraint->location); return true; } @@ -2636,31 +2619,22 @@ bool ConstraintSolver::isBlocked(NotNull constraint) return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } -ErrorVec ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) +ErrorVec ConstraintSolver::unify(NotNull scope, Location location, TypeId subType, TypeId superType) { - Unifier u{normalizer, scope, Location{}, Covariant}; - u.enableNewSolver(); + Unifier2 u2{NotNull{arena}, builtinTypes, NotNull{&iceReporter}}; - u.tryUnify(subType, superType); + const bool ok = u2.unify(subType, superType); - if (!u.errors.empty()) - { - TypeId errorType = errorRecoveryType(); - u.tryUnify(subType, errorType); - u.tryUnify(superType, errorType); - } + if (!ok) + reportError(UnificationTooComplex{}, location); - const auto [changedTypes, changedPacks] = u.log.getChanges(); + unblock(subType, Location{}); + unblock(superType, Location{}); - u.log.commit(); - - unblock(changedTypes, Location{}); - unblock(changedPacks, Location{}); - - return std::move(u.errors); + return {}; } -ErrorVec ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) +ErrorVec ConstraintSolver::unify(NotNull scope, Location location, TypePackId subPack, TypePackId superPack) { UnifierSharedState sharedState{&iceReporter}; Unifier u{normalizer, scope, Location{}, Covariant}; diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp index 307446ef..20f059d3 100644 --- a/Analysis/src/Differ.cpp +++ b/Analysis/src/Differ.cpp @@ -199,6 +199,13 @@ std::string getDevFixFriendlyName(TypeId ty) else if (table->syntheticName.has_value()) return *table->syntheticName; } + if (auto metatable = get(ty)) + { + if (metatable->syntheticName.has_value()) + { + return *metatable->syntheticName; + } + } // else if (auto primitive = get(ty)) //{ // return ""; @@ -246,11 +253,13 @@ void DifferResult::wrapDiffPath(DiffPathNode node) static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffMetatable(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId right); +static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right); struct FindSeteqCounterexampleResult { // nullopt if no counterexample found @@ -269,6 +278,7 @@ static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonN static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, const std::pair, std::optional>& left, const std::pair, std::optional>& right); static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right); +static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right); static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) { @@ -315,6 +325,28 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) return DifferResult{}; } +static DifferResult diffMetatable(DifferEnvironment& env, TypeId left, TypeId right) +{ + const MetatableType* leftMetatable = get(left); + const MetatableType* rightMetatable = get(right); + LUAU_ASSERT(leftMetatable); + LUAU_ASSERT(rightMetatable); + + DifferResult diffRes = diffUsingEnv(env, leftMetatable->table, rightMetatable->table); + if (diffRes.diffError.has_value()) + { + return diffRes; + } + + diffRes = diffUsingEnv(env, leftMetatable->metatable, rightMetatable->metatable); + if (diffRes.diffError.has_value()) + { + diffRes.wrapDiffPath(DiffPathNode::constructWithTableProperty("__metatable")); + return diffRes; + } + return DifferResult{}; +} + static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right) { const PrimitiveType* leftPrimitive = get(left); @@ -420,6 +452,27 @@ static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId rig return differResult; } +static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right) +{ + const ClassType* leftClass = get(left); + const ClassType* rightClass = get(right); + LUAU_ASSERT(leftClass); + LUAU_ASSERT(rightClass); + + if (leftClass == rightClass) + { + return DifferResult{}; + } + + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(left), + DiffPathNodeLeaf::detailsNormal(right), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; +} + static FindSeteqCounterexampleResult findSeteqCounterexample( DifferEnvironment& env, const std::vector& left, const std::vector& right) { @@ -438,8 +491,8 @@ static FindSeteqCounterexampleResult findSeteqCounterexample( unmatchedRightIdxIt++; continue; } - // unmatchedRightIdxIt is matched with current leftIdx + env.recordProvenEqual(left[leftIdx], right[*unmatchedRightIdxIt]); leftIdxIsMatched = true; unmatchedRightIdxIt = unmatchedRightIdxes.erase(unmatchedRightIdxIt); } @@ -537,6 +590,10 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig // Both left and right are the same variant + // Check cycles & caches + if (env.isAssumedEqual(left, right) || env.isProvenEqual(left, right)) + return DifferResult{}; + if (isSimple(left)) { if (auto lp = get(left)) @@ -550,39 +607,89 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig // Both left and right must be Any if either is Any for them to be equal! return DifferResult{}; } + else if (auto lu = get(left)) + { + return DifferResult{}; + } + else if (auto ln = get(left)) + { + return DifferResult{}; + } else if (auto ln = get(left)) { return diffNegation(env, left, right); } + else if (auto lc = get(left)) + { + return diffClass(env, left, right); + } throw InternalCompilerError{"Unimplemented Simple TypeId variant for diffing"}; } // Both left and right are the same non-Simple + // Non-simple types must record visits in the DifferEnvironment + env.pushVisiting(left, right); if (auto lt = get(left)) { - return diffTable(env, left, right); + DifferResult diffRes = diffTable(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; + } + if (auto lm = get(left)) + { + env.popVisiting(); + return diffMetatable(env, left, right); } if (auto lf = get(left)) { - return diffFunction(env, left, right); + DifferResult diffRes = diffFunction(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; } if (auto lg = get(left)) { - return diffGeneric(env, left, right); + DifferResult diffRes = diffGeneric(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; } if (auto lu = get(left)) { - return diffUnion(env, left, right); + DifferResult diffRes = diffUnion(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; } if (auto li = get(left)) { - return diffIntersection(env, left, right); + DifferResult diffRes = diffIntersection(env, left, right); + if (!diffRes.diffError.has_value()) + { + env.recordProvenEqual(left, right); + } + env.popVisiting(); + return diffRes; } if (auto le = get(left)) { // TODO: return debug-friendly result state + env.popVisiting(); return DifferResult{}; } @@ -658,7 +765,13 @@ static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::K if (left->ty.index() != right->ty.index()) { - throw InternalCompilerError{"Unhandled case where the tail of 2 normalized typepacks have different variants"}; + return DifferResult{DiffError{ + DiffError::Kind::Normal, + DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->first), + DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->second), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; } // Both left and right are the same variant @@ -688,13 +801,116 @@ static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::K } } } + if (auto lg = get(left)) + { + DifferResult diffRes = diffGenericTp(env, left, right); + if (!diffRes.diffError.has_value()) + return DifferResult{}; + switch (possibleNonNormalErrorKind) + { + case DiffError::Kind::LengthMismatchInFnArgs: + { + diffRes.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionArgument)); + return diffRes; + } + case DiffError::Kind::LengthMismatchInFnRets: + { + diffRes.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionReturn)); + return diffRes; + } + default: + { + throw InternalCompilerError{"Unhandled flattened tail case for GenericTypePack"}; + } + } + } throw InternalCompilerError{"Unhandled tail type pack variant for flattened tails"}; } +static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right) +{ + LUAU_ASSERT(get(left)); + LUAU_ASSERT(get(right)); + // Try to pair up the generics + bool isLeftFree = !env.genericTpMatchedPairs.contains(left); + bool isRightFree = !env.genericTpMatchedPairs.contains(right); + if (isLeftFree && isRightFree) + { + env.genericTpMatchedPairs[left] = right; + env.genericTpMatchedPairs[right] = left; + return DifferResult{}; + } + else if (isLeftFree || isRightFree) + { + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; + } + + // Both generics are already paired up + if (*env.genericTpMatchedPairs.find(left) == right) + return DifferResult{}; + + return DifferResult{DiffError{ + DiffError::Kind::IncompatibleGeneric, + DiffPathNodeLeaf::nullopts(), + DiffPathNodeLeaf::nullopts(), + getDevFixFriendlyName(env.rootLeft), + getDevFixFriendlyName(env.rootRight), + }}; +} + +bool DifferEnvironment::isProvenEqual(TypeId left, TypeId right) const +{ + return provenEqual.find({left, right}) != provenEqual.end(); +} + +bool DifferEnvironment::isAssumedEqual(TypeId left, TypeId right) const +{ + return visiting.find({left, right}) != visiting.end(); +} + +void DifferEnvironment::recordProvenEqual(TypeId left, TypeId right) +{ + provenEqual.insert({left, right}); + provenEqual.insert({right, left}); +} + +void DifferEnvironment::pushVisiting(TypeId left, TypeId right) +{ + LUAU_ASSERT(visiting.find({left, right}) == visiting.end()); + LUAU_ASSERT(visiting.find({right, left}) == visiting.end()); + visitingStack.push_back({left, right}); + visiting.insert({left, right}); + visiting.insert({right, left}); +} + +void DifferEnvironment::popVisiting() +{ + auto tyPair = visitingStack.back(); + visiting.erase({tyPair.first, tyPair.second}); + visiting.erase({tyPair.second, tyPair.first}); + visitingStack.pop_back(); +} + +std::vector>::const_reverse_iterator DifferEnvironment::visitingBegin() const +{ + return visitingStack.crbegin(); +} + +std::vector>::const_reverse_iterator DifferEnvironment::visitingEnd() const +{ + return visitingStack.crend(); +} + DifferResult diff(TypeId ty1, TypeId ty2) { - DifferEnvironment differEnv{ty1, ty2, DenseHashMap{nullptr}}; + DifferEnvironment differEnv{ty1, ty2}; return diffUsingEnv(differEnv, ty1, ty2); } @@ -702,7 +918,8 @@ bool isSimple(TypeId ty) { ty = follow(ty); // TODO: think about GenericType, etc. - return get(ty) || get(ty) || get(ty) || get(ty); + return get(ty) || get(ty) || get(ty) || get(ty) || get(ty) || + get(ty) || get(ty); } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 362fcdcc..52eedece 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1182,7 +1182,7 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorinternalTypes, builtinTypes, NotNull{&unifierState}}; - ConstraintGraphBuilder cgb{result, &result->internalTypes, moduleResolver, builtinTypes, iceHandler, parentScope, std::move(prepareModuleScope), + ConstraintGraphBuilder cgb{result, NotNull{&normalizer}, moduleResolver, builtinTypes, iceHandler, parentScope, std::move(prepareModuleScope), logger.get(), NotNull{&dfg}, requireCycles}; cgb.visit(sourceModule.root); @@ -1229,7 +1229,7 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorgenerics, ftv->genericPacks}; + ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks}; // TODO: What to do if this returns nullopt? // We don't have access to the error-reporting machinery @@ -118,8 +123,16 @@ TypeId ReplaceGenerics::clean(TypeId ty) clone.definitionLocation = ttv->definitionLocation; return addType(std::move(clone)); } + else if (FFlag::DebugLuauDeferredConstraintResolution) + { + TypeId res = freshType(NotNull{arena}, builtinTypes, scope); + getMutable(res)->level = level; + return res; + } else + { return addType(FreeType{scope, level}); + } } TypePackId ReplaceGenerics::clean(TypePackId tp) @@ -128,4 +141,75 @@ TypePackId ReplaceGenerics::clean(TypePackId tp) return addTypePack(TypePackVar(FreeTypePack{scope, level})); } +struct Replacer : Substitution +{ + DenseHashMap replacements; + DenseHashMap replacementPacks; + + Replacer(NotNull arena, DenseHashMap replacements, DenseHashMap replacementPacks) + : Substitution(TxnLog::empty(), arena) + , replacements(std::move(replacements)) + , replacementPacks(std::move(replacementPacks)) + { + } + + bool isDirty(TypeId ty) override + { + return replacements.find(ty) != nullptr; + } + + bool isDirty(TypePackId tp) override + { + return replacementPacks.find(tp) != nullptr; + } + + TypeId clean(TypeId ty) override + { + return replacements[ty]; + } + + TypePackId clean(TypePackId tp) override + { + return replacementPacks[tp]; + } +}; + +std::optional instantiate(NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty) +{ + ty = follow(ty); + + const FunctionType* ft = get(ty); + if (!ft) + return ty; + + if (ft->generics.empty() && ft->genericPacks.empty()) + return ty; + + DenseHashMap replacements{nullptr}; + DenseHashMap replacementPacks{nullptr}; + + for (TypeId g : ft->generics) + replacements[g] = freshType(arena, builtinTypes, scope); + + for (TypePackId g : ft->genericPacks) + replacementPacks[g] = arena->freshTypePack(scope); + + Replacer r{arena, std::move(replacements), std::move(replacementPacks)}; + + if (limits->instantiationChildLimit) + r.childLimit = *limits->instantiationChildLimit; + + std::optional res = r.substitute(ty); + if (!res) + return res; + + FunctionType* ft2 = getMutable(*res); + LUAU_ASSERT(ft != ft2); + + ft2->generics.clear(); + ft2->genericPacks.clear(); + + return res; +} + } // namespace Luau diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index 20a9fa57..15b3b2c6 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -9,6 +9,7 @@ #include LUAU_FASTINT(LuauTypeReductionRecursionLimit) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau { @@ -1109,6 +1110,19 @@ TypeId TypeSimplifier::intersect(TypeId left, TypeId right) if (get(right)) return right; + if (auto lf = get(left)) + { + Relation r = relate(lf->upperBound, right); + if (r == Relation::Subset || r == Relation::Coincident) + return left; + } + else if (auto rf = get(right)) + { + Relation r = relate(left, rf->upperBound); + if (r == Relation::Superset || r == Relation::Coincident) + return right; + } + if (isTypeVariable(left)) { blockedTypes.insert(left); @@ -1160,6 +1174,11 @@ TypeId TypeSimplifier::union_(TypeId left, TypeId right) left = simplify(left); right = simplify(right); + if (get(left)) + return right; + if (get(right)) + return left; + if (auto leftUnion = get(left)) { bool changed = false; @@ -1263,6 +1282,8 @@ TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet& seen) SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) { + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + TypeSimplifier s{builtinTypes, arena}; // fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str()); @@ -1276,11 +1297,13 @@ SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull< SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) { + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + TypeSimplifier s{builtinTypes, arena}; TypeId res = s.union_(left, right); - // fprintf(stderr, "Union %s and %s -> %s\n", toString(a).c_str(), toString(b).c_str(), toString(res).c_str()); + // fprintf(stderr, "Union %s and %s -> %s\n", toString(left).c_str(), toString(right).c_str(), toString(res).c_str()); return SimplifyResult{res, std::move(s.blockedTypes)}; } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 19776d0a..96492a39 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -75,6 +75,20 @@ struct FindCyclicTypes final : TypeVisitor return visitedPacks.insert(tp).second; } + bool visit(TypeId ty, const FreeType& ft) override + { + if (!visited.insert(ty).second) + return false; + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + traverse(ft.lowerBound); + traverse(ft.upperBound); + } + + return false; + } + bool visit(TypeId ty, const TableType& ttv) override { if (!visited.insert(ty).second) @@ -428,6 +442,36 @@ struct TypeStringifier { state.result.invalid = true; + if (FFlag::DebugLuauDeferredConstraintResolution) + { + const TypeId lowerBound = follow(ftv.lowerBound); + const TypeId upperBound = follow(ftv.upperBound); + if (get(lowerBound) && get(upperBound)) + { + state.emit("'"); + state.emit(state.getName(ty)); + } + else + { + state.emit("("); + if (!get(lowerBound)) + { + stringify(lowerBound); + state.emit(" <: "); + } + state.emit("'"); + state.emit(state.getName(ty)); + + if (!get(upperBound)) + { + state.emit(" <: "); + stringify(upperBound); + } + state.emit(")"); + } + return; + } + if (FInt::DebugLuauVerboseTypeNames >= 1) state.emit("free-"); diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 2aa13bc9..fb72bc12 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -476,6 +476,14 @@ FreeType::FreeType(Scope* scope, TypeLevel level) { } +FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound) + : index(Unifiable::freshIndex()) + , scope(scope) + , lowerBound(lowerBound) + , upperBound(upperBound) +{ +} + GenericType::GenericType() : index(Unifiable::freshIndex()) , name("g" + std::to_string(index)) @@ -1351,7 +1359,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) // unify the prefix one argument at a time for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) { - context.solver->unify(params[i + paramOffset], expected[i], context.solver->rootScope); + context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]); } // if we know the argument count or if we have too many arguments for sure, we can issue an error @@ -1481,7 +1489,7 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) if (returnTypes.empty()) return false; - context.solver->unify(params[0], context.solver->builtinTypes->stringType, context.solver->rootScope); + context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); const TypePackId emptyPack = arena->addTypePack({}); const TypePackId returnList = arena->addTypePack(returnTypes); @@ -1550,13 +1558,13 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) if (returnTypes.empty()) return false; - context.solver->unify(params[0], context.solver->builtinTypes->stringType, context.solver->rootScope); + context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); size_t initIndex = context.callSite->self ? 1 : 2; if (params.size() == 3 && context.callSite->args.size > initIndex) - context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); const TypePackId returnList = arena->addTypePack(returnTypes); asMutable(context.result)->ty.emplace(returnList); @@ -1653,17 +1661,17 @@ static bool dcrMagicFunctionFind(MagicFunctionCallContext context) return false; } - context.solver->unify(params[0], builtinTypes->stringType, context.solver->rootScope); + context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType); const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); size_t initIndex = context.callSite->self ? 1 : 2; if (params.size() >= 3 && context.callSite->args.size > initIndex) - context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); if (params.size() == 4 && context.callSite->args.size > plainIndex) - context.solver->unify(params[3], optionalBoolean, context.solver->rootScope); + context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean); returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); @@ -1672,6 +1680,11 @@ static bool dcrMagicFunctionFind(MagicFunctionCallContext context) return true; } +TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope) +{ + return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType}); +} + std::vector filterMap(TypeId type, TypeIdPredicate predicate) { type = follow(type); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 40a4bd0f..08b1ffbc 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -228,7 +228,8 @@ struct TypeChecker2 { NotNull builtinTypes; DcrLogger* logger; - NotNull ice; + const NotNull limits; + const NotNull ice; const SourceModule* sourceModule; Module* module; TypeArena testArena; @@ -240,10 +241,11 @@ struct TypeChecker2 Normalizer normalizer; - TypeChecker2(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule* sourceModule, + TypeChecker2(NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, const SourceModule* sourceModule, Module* module) : builtinTypes(builtinTypes) , logger(logger) + , limits(limits) , ice(unifierState->iceHandler) , sourceModule(sourceModule) , module(module) @@ -807,9 +809,9 @@ struct TypeChecker2 else if (std::optional iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) { - Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}, scope}; + Instantiation instantiation{TxnLog::empty(), &arena, builtinTypes, TypeLevel{}, scope}; - if (std::optional instantiatedIterMmTy = instantiation.substitute(*iterMmTy)) + if (std::optional instantiatedIterMmTy = instantiate(builtinTypes, NotNull{&arena}, limits, scope, *iterMmTy)) { if (const FunctionType* iterMmFtv = get(*instantiatedIterMmTy)) { @@ -2679,9 +2681,9 @@ struct TypeChecker2 }; void check( - NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule& sourceModule, Module* module) + NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, const SourceModule& sourceModule, Module* module) { - TypeChecker2 typeChecker{builtinTypes, unifierState, logger, &sourceModule, module}; + TypeChecker2 typeChecker{builtinTypes, unifierState, limits, logger, &sourceModule, module}; typeChecker.visit(sourceModule.root); diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 4adf0f8a..61a92c77 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -3,14 +3,15 @@ #include "Luau/TypeFamily.h" #include "Luau/DenseHash.h" -#include "Luau/VisitType.h" -#include "Luau/TxnLog.h" -#include "Luau/Substitution.h" -#include "Luau/ToString.h" -#include "Luau/TypeUtils.h" -#include "Luau/Unifier.h" #include "Luau/Instantiation.h" #include "Luau/Normalize.h" +#include "Luau/Substitution.h" +#include "Luau/ToString.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier.h" +#include "Luau/VisitType.h" LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); @@ -397,8 +398,8 @@ TypeFamilyReductionResult addFamilyFn(std::vector typeParams, st if (!mmFtv) return {std::nullopt, true, {}, {}}; - Instantiation instantiation{log.get(), arena.get(), TypeLevel{}, scope.get()}; - if (std::optional instantiatedAddMm = instantiation.substitute(log->follow(*addMm))) + TypeCheckLimits limits; // TODO: We need to thread TypeCheckLimits in from Frontend to here. + if (std::optional instantiatedAddMm = instantiate(builtins, arena, NotNull{&limits}, scope, log->follow(*addMm))) { if (const FunctionType* instantiatedMmFtv = get(*instantiatedAddMm)) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index a8025096..d0ae4133 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -4860,7 +4860,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat if (ftv && ftv->hasNoFreeOrGenericTypes) return ty; - Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; + Instantiation instantiation{log, ¤tModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr}; if (instantiationChildLimit) instantiation.childLimit = *instantiationChildLimit; diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 4f87de8f..089008fa 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -8,6 +8,8 @@ #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -190,7 +192,13 @@ TypePack extendTypePack( } else { - t = arena.freshType(ftp->scope); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType}; + t = arena.addType(ft); + } + else + t = arena.freshType(ftp->scope); } newPack.head.push_back(t); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index c1b5e45e..db8e2008 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -24,6 +24,8 @@ LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAGVARIABLE(LuauTableUnifyRecursionLimit, false) namespace Luau { @@ -1741,7 +1743,10 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } auto mkFreshType = [this](Scope* scope, TypeLevel level) { - return types->freshType(scope, level); + if (FFlag::DebugLuauDeferredConstraintResolution) + return freshType(NotNull{types}, builtinTypes, scope); + else + return types->freshType(scope, level); }; const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); @@ -1977,7 +1982,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal // generic methods in tables to be marked read-only. if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) { - Instantiation instantiation{&log, types, scope->level, scope}; + Instantiation instantiation{&log, types, builtinTypes, scope->level, scope}; std::optional instantiated = instantiation.substitute(subTy); if (instantiated.has_value()) @@ -2126,7 +2131,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, { if (variance == Covariant && subTable->state == TableState::Generic && superTable->state != TableState::Generic) { - Instantiation instantiation{&log, types, subTable->level, scope}; + Instantiation instantiation{&log, types, builtinTypes, subTable->level, scope}; std::optional instantiated = instantiation.substitute(subTy); if (instantiated.has_value()) @@ -2251,10 +2256,23 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (superTable != newSuperTable || subTable != newSubTable) { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else + if (FFlag::LuauTableUnifyRecursionLimit) + { + if (errors.empty()) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } + return; + } + else + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; + } } } @@ -2329,10 +2347,23 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (superTable != newSuperTable || subTable != newSubTable) { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else + if (FFlag::LuauTableUnifyRecursionLimit) + { + if (errors.empty()) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } + return; + } + else + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; + } } } diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp new file mode 100644 index 00000000..0be6941b --- /dev/null +++ b/Analysis/src/Unifier2.cpp @@ -0,0 +1,366 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Unifier2.h" + +#include "Luau/Scope.h" +#include "Luau/Simplify.h" +#include "Luau/Substitution.h" +#include "Luau/ToString.h" +#include "Luau/TxnLog.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeUtils.h" +#include "Luau/VisitType.h" + +#include +#include + +LUAU_FASTINT(LuauTypeInferRecursionLimit) + +namespace Luau +{ + +Unifier2::Unifier2(NotNull arena, NotNull builtinTypes, NotNull ice) + : arena(arena) + , builtinTypes(builtinTypes) + , ice(ice) + , recursionLimit(FInt::LuauTypeInferRecursionLimit) +{ + +} + +bool Unifier2::unify(TypeId subTy, TypeId superTy) +{ + subTy = follow(subTy); + superTy = follow(superTy); + + if (subTy == superTy) + return true; + + FreeType* subFree = getMutable(subTy); + FreeType* superFree = getMutable(superTy); + + if (subFree) + subFree->upperBound = mkIntersection(subFree->upperBound, superTy); + + if (superFree) + superFree->lowerBound = mkUnion(superFree->lowerBound, subTy); + + if (subFree || superFree) + return true; + + const FunctionType* subFn = get(subTy); + const FunctionType* superFn = get(superTy); + + if (subFn && superFn) + { + bool argResult = unify(superFn->argTypes, subFn->argTypes); + bool retResult = unify(subFn->retTypes, superFn->retTypes); + return argResult && retResult; + } + + // The unification failed, but we're not doing type checking. + return true; +} + +// FIXME? This should probably return an ErrorVec or an optional +// rather than a boolean to signal an occurs check failure. +bool Unifier2::unify(TypePackId subTp, TypePackId superTp) +{ + subTp = follow(subTp); + superTp = follow(superTp); + + const FreeTypePack* subFree = get(subTp); + const FreeTypePack* superFree = get(superTp); + + if (subFree) + { + DenseHashSet seen{nullptr}; + if (OccursCheckResult::Fail == occursCheck(seen, subTp, superTp)) + { + asMutable(subTp)->ty.emplace(builtinTypes->errorRecoveryTypePack()); + return false; + } + + asMutable(subTp)->ty.emplace(superTp); + return true; + } + + if (superFree) + { + DenseHashSet seen{nullptr}; + if (OccursCheckResult::Fail == occursCheck(seen, superTp, subTp)) + { + asMutable(superTp)->ty.emplace(builtinTypes->errorRecoveryTypePack()); + return false; + } + + asMutable(superTp)->ty.emplace(subTp); + return true; + } + + size_t maxLength = std::max( + flatten(subTp).first.size(), + flatten(superTp).first.size() + ); + + auto [subTypes, subTail] = extendTypePack(*arena, builtinTypes, subTp, maxLength); + auto [superTypes, superTail] = extendTypePack(*arena, builtinTypes, superTp, maxLength); + + if (subTypes.size() < maxLength || superTypes.size() < maxLength) + return true; + + for (size_t i = 0; i < maxLength; ++i) + unify(subTypes[i], superTypes[i]); + + return true; +} + +struct FreeTypeSearcher : TypeVisitor +{ + NotNull scope; + + explicit FreeTypeSearcher(NotNull scope) + : TypeVisitor(/*skipBoundTypes*/ true) + , scope(scope) + {} + + enum { Positive, Negative } polarity = Positive; + + void flip() + { + switch (polarity) + { + case Positive: polarity = Negative; break; + case Negative: polarity = Positive; break; + } + } + + std::unordered_set negativeTypes; + std::unordered_set positiveTypes; + + bool visit(TypeId ty) override + { + LUAU_ASSERT(ty); + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + if (!subsumes(scope, ft.scope)) + return true; + + switch (polarity) + { + case Positive: positiveTypes.insert(ty); break; + case Negative: negativeTypes.insert(ty); break; + } + + return true; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + flip(); + traverse(ft.argTypes); + flip(); + + traverse(ft.retTypes); + + return false; + } +}; + +struct MutatingGeneralizer : TypeOnceVisitor +{ + NotNull builtinTypes; + + NotNull scope; + std::unordered_set positiveTypes; + std::unordered_set negativeTypes; + std::vector generics; + + MutatingGeneralizer(NotNull builtinTypes, NotNull scope, std::unordered_set positiveTypes, std::unordered_set negativeTypes) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , builtinTypes(builtinTypes) + , scope(scope) + , positiveTypes(std::move(positiveTypes)) + , negativeTypes(std::move(negativeTypes)) + {} + + static void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) + { + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + seen.insert(haystack); + + std::vector* parts = nullptr; + if (UnionType* ut = getMutable(haystack)) + parts = &ut->options; + else if (IntersectionType* it = getMutable(needle)) + parts = &it->parts; + else + return; + + LUAU_ASSERT(parts); + + for (TypeId& option : *parts) + { + // FIXME: I bet this function has reentrancy problems + option = follow(option); + if (option == needle) + { + LUAU_ASSERT(!seen.find(option)); + option = replacement; + } + + // TODO seen set + else if (get(option)) + replace(seen, option, needle, haystack); + else if (get(option)) + replace(seen, option, needle, haystack); + } + } + + bool visit (TypeId ty, const FreeType&) override + { + const FreeType* ft = get(ty); + LUAU_ASSERT(ft); + + traverse(ft->lowerBound); + traverse(ft->upperBound); + + // ft is potentially invalid now. + ty = follow(ty); + ft = get(ty); + if (!ft) + return false; + + const bool isPositive = positiveTypes.count(ty); + const bool isNegative = negativeTypes.count(ty); + + if (!isPositive && !isNegative) + return false; + + const bool hasLowerBound = !get(follow(ft->lowerBound)); + const bool hasUpperBound = !get(follow(ft->upperBound)); + + DenseHashSet seen{nullptr}; + seen.insert(ty); + + if (!hasLowerBound && !hasUpperBound) + { + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + + // It is possible that this free type has other free types in its upper + // or lower bounds. If this is the case, we must replace those + // references with never (for the lower bound) or unknown (for the upper + // bound). + // + // If we do not do this, we get tautological bounds like a <: a <: unknown. + else if (isPositive && !hasUpperBound) + { + if (FreeType* lowerFree = getMutable(ft->lowerBound); lowerFree && lowerFree->upperBound == ty) + lowerFree->upperBound = builtinTypes->unknownType; + else + replace(seen, ft->lowerBound, ty, builtinTypes->unknownType); + emplaceType(asMutable(ty), ft->lowerBound); + } + else + { + if (FreeType* upperFree = getMutable(ft->upperBound); upperFree && upperFree->lowerBound == ty) + upperFree->lowerBound = builtinTypes->neverType; + else + replace(seen, ft->upperBound, ty, builtinTypes->neverType); + emplaceType(asMutable(ty), ft->upperBound); + } + + return false; + } +}; + +std::optional Unifier2::generalize(NotNull scope, TypeId ty) +{ + ty = follow(ty); + + if (ty->owningArena != arena) + return ty; + + if (ty->persistent) + return ty; + + if (const FunctionType* ft = get(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty())) + return ty; + + FreeTypeSearcher fts{scope}; + fts.traverse(ty); + + MutatingGeneralizer gen{builtinTypes, scope, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; + + gen.traverse(ty); + + std::optional res = ty; + + FunctionType* ftv = getMutable(follow(*res)); + if (ftv) + ftv->generics = std::move(gen.generics); + + return res; +} + +TypeId Unifier2::mkUnion(TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + + return simplifyUnion(builtinTypes, arena, left, right).result; +} + +TypeId Unifier2::mkIntersection(TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + + return simplifyIntersection(builtinTypes, arena, left, right).result; +} + +OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) +{ + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return OccursCheckResult::Pass; + + seen.insert(haystack); + + if (getMutable(needle)) + return OccursCheckResult::Pass; + + if (!getMutable(needle)) + ice->ice("Expected needle pack to be free"); + + RecursionLimiter _ra(&recursionCount, recursionLimit); + + while (!getMutable(haystack)) + { + if (needle == haystack) + return OccursCheckResult::Fail; + + if (auto a = get(haystack); a && a->tail) + { + haystack = follow(*a->tail); + continue; + } + + break; + } + + return OccursCheckResult::Pass; +} + +} diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index b950a8ec..67909b13 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -168,7 +168,12 @@ enum class IrCmd : uint8_t // Compute Luau 'not' operation on destructured TValue // A: tag // B: int (value) - NOT_ANY, // TODO: boolean specialization will be useful + NOT_ANY, + + // Perform a TValue comparison, supported conditions are LessEqual, Less and Equal + // A, B: Rn + // C: condition + CMP_ANY, // Unconditional jump // A: block/vmexit @@ -224,13 +229,6 @@ enum class IrCmd : uint8_t // E: block (if false) JUMP_CMP_NUM, - // Perform a conditional jump based on the result of TValue comparison - // A, B: Rn - // C: condition - // D: block (if true) - // E: block (if false) - JUMP_CMP_ANY, - // Perform a conditional jump based on cached table node slot matching the actual table node slot for a key // A: pointer (LuaNode) // B: Kn @@ -377,27 +375,33 @@ enum class IrCmd : uint8_t // instead. CHECK_TAG, + // Guard against a falsy tag+value + // A: tag + // B: value + // C: block/vmexit/undef + CHECK_TRUTHY, + // Guard against readonly table // A: pointer (Table) - // B: block/undef + // B: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure CHECK_READONLY, // Guard against table having a metatable // A: pointer (Table) - // B: block/undef + // B: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure CHECK_NO_METATABLE, // Guard against executing in unsafe environment, exits to VM on check failure - // A: vmexit/undef + // A: vmexit/vmexit/undef // When undef is specified, execution is aborted on check failure CHECK_SAFE_ENV, // Guard against index overflowing the table array size // A: pointer (Table) // B: int (index) - // C: block/undef + // C: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure CHECK_ARRAY_SIZE, @@ -410,7 +414,7 @@ enum class IrCmd : uint8_t // Guard against table node with a linked next node to ensure that our lookup hits the main position of the key // A: pointer (LuaNode) - // B: block/undef + // B: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure CHECK_NODE_NO_NEXT, diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 9a9b84b4..fe38cb90 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -99,7 +99,6 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_GE_UINT: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: - case IrCmd::JUMP_CMP_ANY: case IrCmd::JUMP_SLOT_MATCH: case IrCmd::RETURN: case IrCmd::FORGLOOP: @@ -122,6 +121,7 @@ inline bool isNonTerminatingJump(IrCmd cmd) case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::CHECK_FASTCALL_RES: case IrCmd::CHECK_TAG: + case IrCmd::CHECK_TRUTHY: case IrCmd::CHECK_READONLY: case IrCmd::CHECK_NO_METATABLE: case IrCmd::CHECK_SAFE_ENV: @@ -167,6 +167,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::NOT_ANY: + case IrCmd::CMP_ANY: case IrCmd::TABLE_LEN: case IrCmd::STRING_LEN: case IrCmd::NEW_TABLE: diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 20269cfd..e2d5ac58 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -794,34 +794,31 @@ const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId ba return pc; } -const Instruction* executeGETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k) +void executeGETVARARGSMultRet(lua_State* L, const Instruction* pc, StkId base, int rai) { [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - int b = LUAU_INSN_B(insn) - 1; int n = cast_int(base - L->ci->func) - cl->l.p->numparams - 1; - if (b == LUA_MULTRET) - { - VM_PROTECT(luaD_checkstack(L, n)); - StkId ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack + VM_PROTECT(luaD_checkstack(L, n)); + StkId ra = VM_REG(rai); // previous call may change the stack - for (int j = 0; j < n; j++) - setobj2s(L, ra + j, base - n + j); + for (int j = 0; j < n; j++) + setobj2s(L, ra + j, base - n + j); - L->top = ra + n; - return pc; - } - else - { - StkId ra = VM_REG(LUAU_INSN_A(insn)); + L->top = ra + n; +} - for (int j = 0; j < b && j < n; j++) - setobj2s(L, ra + j, base - n + j); - for (int j = n; j < b; j++) - setnilvalue(ra + j); - return pc; - } +void executeGETVARARGSConst(lua_State* L, StkId base, int rai, int b) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + int n = cast_int(base - L->ci->func) - cl->l.p->numparams - 1; + + StkId ra = VM_REG(rai); + + for (int j = 0; j < b && j < n; j++) + setobj2s(L, ra + j, base - n + j); + for (int j = n; j < b; j++) + setnilvalue(ra + j); } const Instruction* executeDUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k) diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index a30d7e98..15b794d2 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -27,7 +27,8 @@ const Instruction* executeNEWCLOSURE(lua_State* L, const Instruction* pc, StkId const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* executeSETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* executeGETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); +void executeGETVARARGSMultRet(lua_State* L, const Instruction* pc, StkId base, int rai); +void executeGETVARARGSConst(lua_State* L, StkId base, int rai, int b); const Instruction* executeDUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* executePREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 2ad5b040..e6fae4cc 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -66,29 +66,6 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, } } -void jumpOnAnyCmpFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label) -{ - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); - callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); - - if (cond == IrCondition::NotLessEqual || cond == IrCondition::LessEqual) - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]); - else if (cond == IrCondition::NotLess || cond == IrCondition::Less) - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]); - else if (cond == IrCondition::NotEqual || cond == IrCondition::Equal) - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]); - else - LUAU_ASSERT(!"Unsupported condition"); - - emitUpdateBase(build); - build.test(eax, eax); - build.jcc(cond == IrCondition::NotLessEqual || cond == IrCondition::NotLess || cond == IrCondition::NotEqual ? ConditionX64::Zero - : ConditionX64::NotZero, - label); -} - void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos) { LUAU_ASSERT(tmp != node); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 02d9f40b..d8c68da4 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -160,7 +160,6 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label } void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label); -void jumpOnAnyCmpFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label); void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 62c0b8ab..23f2dd21 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -274,14 +274,14 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::STORE_TVALUE: maybeDef(inst.a); // Argument can also be a pointer value break; + case IrCmd::CMP_ANY: + use(inst.a); + use(inst.b); + break; case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_FALSY: use(inst.a); break; - case IrCmd::JUMP_CMP_ANY: - use(inst.a); - use(inst.b); - break; // A <- B, C case IrCmd::DO_ARITH: case IrCmd::GET_TABLE: diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index ce0cbfb3..12b75bcc 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -147,6 +147,8 @@ const char* getCmdName(IrCmd cmd) return "ABS_NUM"; case IrCmd::NOT_ANY: return "NOT_ANY"; + case IrCmd::CMP_ANY: + return "CMP_ANY"; case IrCmd::JUMP: return "JUMP"; case IrCmd::JUMP_IF_TRUTHY: @@ -165,8 +167,6 @@ const char* getCmdName(IrCmd cmd) return "JUMP_EQ_POINTER"; case IrCmd::JUMP_CMP_NUM: return "JUMP_CMP_NUM"; - case IrCmd::JUMP_CMP_ANY: - return "JUMP_CMP_ANY"; case IrCmd::JUMP_SLOT_MATCH: return "JUMP_SLOT_MATCH"; case IrCmd::TABLE_LEN: @@ -219,6 +219,8 @@ const char* getCmdName(IrCmd cmd) return "PREPARE_FORN"; case IrCmd::CHECK_TAG: return "CHECK_TAG"; + case IrCmd::CHECK_TRUTHY: + return "CHECK_TRUTHY"; case IrCmd::CHECK_READONLY: return "CHECK_READONLY"; case IrCmd::CHECK_NO_METATABLE: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 3c247abd..16796a28 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -538,6 +538,33 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; } + case IrCmd::CMP_ANY: + { + IrCondition cond = conditionOp(inst.c); + + regs.spill(build, index); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (cond == IrCondition::LessEqual) + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_lessequal))); + else if (cond == IrCondition::Less) + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_lessthan))); + else if (cond == IrCondition::Equal) + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_equalval))); + else + LUAU_ASSERT(!"Unsupported condition"); + + build.blr(x3); + + emitUpdateBase(build); + + // since w0 came from a call, we need to move it so that we don't violate zextReg safety contract + inst.regA64 = regs.allocReg(KindA64::w, index); + build.mov(inst.regA64, w0); + break; + } case IrCmd::JUMP: if (inst.a.kind == IrOpKind::VmExit) { @@ -671,35 +698,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.e), next); break; } - case IrCmd::JUMP_CMP_ANY: - { - IrCondition cond = conditionOp(inst.c); - - regs.spill(build, index); - build.mov(x0, rState); - build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); - build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); - - if (cond == IrCondition::NotLessEqual || cond == IrCondition::LessEqual) - build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_lessequal))); - else if (cond == IrCondition::NotLess || cond == IrCondition::Less) - build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_lessthan))); - else if (cond == IrCondition::NotEqual || cond == IrCondition::Equal) - build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_equalval))); - else - LUAU_ASSERT(!"Unsupported condition"); - - build.blr(x3); - - emitUpdateBase(build); - - if (cond == IrCondition::NotLessEqual || cond == IrCondition::NotLess || cond == IrCondition::NotEqual) - build.cbz(x0, labelOp(inst.d)); - else - build.cbnz(x0, labelOp(inst.d)); - jumpOrFallthrough(blockOp(inst.e), next); - break; - } // IrCmd::JUMP_SLOT_MATCH implemented below case IrCmd::TABLE_LEN: { @@ -1072,6 +1070,36 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) finalizeTargetLabel(inst.c, fresh); break; } + case IrCmd::CHECK_TRUTHY: + { + // Constant tags which don't require boolean value check should've been removed in constant folding + LUAU_ASSERT(inst.a.kind != IrOpKind::Constant || tagOp(inst.a) == LUA_TBOOLEAN); + + Label fresh; // used when guard aborts execution or jumps to a VM exit + Label& target = getTargetLabel(inst.c, fresh); + + Label skip; + + if (inst.a.kind != IrOpKind::Constant) + { + // fail to fallback on 'nil' (falsy) + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(regOp(inst.a), target); + + // skip value test if it's not a boolean (truthy) + build.cmp(regOp(inst.a), LUA_TBOOLEAN); + build.b(ConditionA64::NotEqual, skip); + } + + // fail to fallback on 'false' boolean value (falsy) + build.cbz(regOp(inst.b), target); + + if (inst.a.kind != IrOpKind::Constant) + build.setLabel(skip); + + finalizeTargetLabel(inst.c, fresh); + break; + } case IrCmd::CHECK_READONLY: { Label fresh; // used when guard aborts execution or jumps to a VM exit @@ -1530,7 +1558,26 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); regs.spill(build, index); - emitFallback(build, offsetof(NativeContext, executeGETVARARGS), uintOp(inst.a)); + build.mov(x0, rState); + + if (intOp(inst.c) == LUA_MULTRET) + { + emitAddOffset(build, x1, rCode, uintOp(inst.a) * sizeof(Instruction)); + build.mov(x2, rBase); + build.mov(x3, vmRegOp(inst.b)); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, executeGETVARARGSMultRet))); + build.blr(x4); + + emitUpdateBase(build); + } + else + { + build.mov(x1, rBase); + build.mov(x2, vmRegOp(inst.b)); + build.mov(x3, intOp(inst.c)); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, executeGETVARARGSConst))); + build.blr(x4); + } break; case IrCmd::NEWCLOSURE: { diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index e791e55d..6fe1e771 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -541,6 +541,29 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.setLabel(exit); break; } + case IrCmd::CMP_ANY: + { + IrCondition cond = conditionOp(inst.c); + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(vmRegOp(inst.a))); + callWrap.addArgument(SizeX64::qword, luauRegAddress(vmRegOp(inst.b))); + + if (cond == IrCondition::LessEqual) + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]); + else if (cond == IrCondition::Less) + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]); + else if (cond == IrCondition::Equal) + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]); + else + LUAU_ASSERT(!"Unsupported condition"); + + emitUpdateBase(build); + + inst.regX64 = regs.takeReg(eax, index); + break; + } case IrCmd::JUMP: if (inst.a.kind == IrOpKind::VmExit) { @@ -589,10 +612,28 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::JUMP_EQ_INT: - build.cmp(regOp(inst.a), intOp(inst.b)); + if (intOp(inst.b) == 0) + { + build.test(regOp(inst.a), regOp(inst.a)); - build.jcc(ConditionX64::Equal, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); + if (isFallthroughBlock(blockOp(inst.c), next)) + { + build.jcc(ConditionX64::NotZero, labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.c), next); + } + else + { + build.jcc(ConditionX64::Zero, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + } + } + else + { + build.cmp(regOp(inst.a), intOp(inst.b)); + + build.jcc(ConditionX64::Equal, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + } break; case IrCmd::JUMP_LT_INT: build.cmp(regOp(inst.a), intOp(inst.b)); @@ -623,10 +664,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.e), next); break; } - case IrCmd::JUMP_CMP_ANY: - jumpOnAnyCmpFallback(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), conditionOp(inst.c), labelOp(inst.d)); - jumpOrFallthrough(blockOp(inst.e), next); - break; case IrCmd::TABLE_LEN: { IrCallWrapperX64 callWrap(regs, build, index); @@ -944,6 +981,32 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.c, continueInVm); break; } + case IrCmd::CHECK_TRUTHY: + { + // Constant tags which don't require boolean value check should've been removed in constant folding + LUAU_ASSERT(inst.a.kind != IrOpKind::Constant || tagOp(inst.a) == LUA_TBOOLEAN); + + Label skip; + + if (inst.a.kind != IrOpKind::Constant) + { + // Fail to fallback on 'nil' (falsy) + build.cmp(memRegTagOp(inst.a), LUA_TNIL); + jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.c); + + // Skip value test if it's not a boolean (truthy) + build.cmp(memRegTagOp(inst.a), LUA_TBOOLEAN); + build.jcc(ConditionX64::NotEqual, skip); + } + + // fail to fallback on 'false' boolean value (falsy) + build.cmp(memRegUintOp(inst.b), 0); + jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.c); + + if (inst.a.kind != IrOpKind::Constant) + build.setLabel(skip); + break; + } case IrCmd::CHECK_READONLY: build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0); jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.b); @@ -1231,7 +1294,30 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - emitFallback(regs, build, offsetof(NativeContext, executeGETVARARGS), uintOp(inst.a)); + if (intOp(inst.c) == LUA_MULTRET) + { + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + + RegisterX64 reg = callWrap.suggestNextArgumentRegister(SizeX64::qword); + build.mov(reg, sCode); + callWrap.addArgument(SizeX64::qword, addr[reg + uintOp(inst.a) * sizeof(Instruction)]); + + callWrap.addArgument(SizeX64::qword, rBase); + callWrap.addArgument(SizeX64::dword, vmRegOp(inst.b)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, executeGETVARARGSMultRet)]); + + emitUpdateBase(build); + } + else + { + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, rBase); + callWrap.addArgument(SizeX64::dword, vmRegOp(inst.b)); + callWrap.addArgument(SizeX64::dword, intOp(inst.c)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, executeGETVARARGSConst)]); + } break; case IrCmd::NEWCLOSURE: { @@ -1585,6 +1671,8 @@ OperandX64 IrLoweringX64::memRegUintOp(IrOp op) return regOp(op); case IrOpKind::Constant: return OperandX64(unsigned(intOp(op))); + case IrOpKind::VmReg: + return luauRegValueInt(vmRegOp(op)); default: LUAU_ASSERT(!"Unsupported operand kind"); } diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 73055c39..8392ad84 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -137,16 +137,17 @@ static BuiltinImplResult translateBuiltinNumberTo2Number( return {BuiltinImplType::Full, 2}; } -static BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults != 0) return {BuiltinImplType::None, -1}; - IrOp cont = build.block(IrBlockKind::Internal); + IrOp tag = build.inst(IrCmd::LOAD_TAG, build.vmReg(arg)); - // TODO: maybe adding a guard like CHECK_TRUTHY can be useful - build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(arg), fallback, cont); - build.beginBlock(cont); + // We don't know if it's really a boolean at this point, but we will only check this value if it is + IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(arg)); + + build.inst(IrCmd::CHECK_TRUTHY, tag, value, build.vmExit(pcpos)); return {BuiltinImplType::UsesFallback, 0}; } @@ -463,8 +464,6 @@ static BuiltinImplResult translateBuiltinBit32Shift( if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - IrOp block = build.block(IrBlockKind::Internal); - builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); @@ -472,10 +471,22 @@ static BuiltinImplResult translateBuiltinBit32Shift( IrOp vb = builtinLoadDouble(build, args); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); - IrOp vbi = build.inst(IrCmd::NUM_TO_INT, vb); - build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constInt(32), fallback, block); - build.beginBlock(block); + IrOp vbi; + + if (std::optional vbd = build.function.asDoubleOp(vb); vbd && *vbd >= INT_MIN && *vbd <= INT_MAX) + vbi = build.constInt(int(*vbd)); + else + vbi = build.inst(IrCmd::NUM_TO_INT, vb); + + bool knownGoodShift = unsigned(build.function.asIntOp(vbi).value_or(-1)) < 32u; + + if (!knownGoodShift) + { + IrOp block = build.block(IrBlockKind::Internal); + build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constInt(32), fallback, block); + build.beginBlock(block); + } IrCmd cmd = IrCmd::NOP; if (bfid == LBF_BIT32_LSHIFT) @@ -763,7 +774,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, switch (bfid) { case LBF_ASSERT: - return translateBuiltinAssert(build, nparams, ra, arg, args, nresults, fallback); + return translateBuiltinAssert(build, nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_DEG: return translateBuiltinMathDeg(build, nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_RAD: diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 63e756e1..363a1cdb 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -167,7 +167,9 @@ void translateInstJumpIfEq(IrBuilder& build, const Instruction* pc, int pcpos, b build.beginBlock(fallback); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::JUMP_CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(not_ ? IrCondition::NotEqual : IrCondition::Equal), target, next); + + IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(IrCondition::Equal)); + build.inst(IrCmd::JUMP_EQ_INT, result, build.constInt(0), not_ ? target : next, not_ ? next : target); build.beginBlock(next); } @@ -195,7 +197,27 @@ void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos, build.beginBlock(fallback); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::JUMP_CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(cond), target, next); + + bool reverse = false; + + if (cond == IrCondition::NotLessEqual) + { + reverse = true; + cond = IrCondition::LessEqual; + } + else if (cond == IrCondition::NotLess) + { + reverse = true; + cond = IrCondition::Less; + } + else if (cond == IrCondition::NotEqual) + { + reverse = true; + cond = IrCondition::Equal; + } + + IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(cond)); + build.inst(IrCmd::JUMP_EQ_INT, result, build.constInt(0), reverse ? target : next, reverse ? next : target); build.beginBlock(next); } diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 310c15b8..c2d3e1a8 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -66,6 +66,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::ABS_NUM: return IrValueKind::Double; case IrCmd::NOT_ANY: + case IrCmd::CMP_ANY: return IrValueKind::Int; case IrCmd::JUMP: case IrCmd::JUMP_IF_TRUTHY: @@ -76,7 +77,6 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::JUMP_GE_UINT: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: - case IrCmd::JUMP_CMP_ANY: case IrCmd::JUMP_SLOT_MATCH: return IrValueKind::None; case IrCmd::TABLE_LEN: @@ -114,6 +114,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::SET_UPVALUE: case IrCmd::PREPARE_FORN: case IrCmd::CHECK_TAG: + case IrCmd::CHECK_TRUTHY: case IrCmd::CHECK_READONLY: case IrCmd::CHECK_NO_METATABLE: case IrCmd::CHECK_SAFE_ENV: @@ -624,6 +625,29 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path } break; + case IrCmd::CHECK_TRUTHY: + if (inst.a.kind == IrOpKind::Constant) + { + if (function.tagOp(inst.a) == LUA_TNIL) + { + replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path + } + else if (function.tagOp(inst.a) == LUA_TBOOLEAN) + { + if (inst.b.kind == IrOpKind::Constant) + { + if (function.intOp(inst.b) == 0) + replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path + else + kill(function, inst); + } + } + else + { + kill(function, inst); + } + } + break; case IrCmd::BITAND_UINT: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index 0ed7c388..4536630b 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -90,9 +90,9 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::LOAD_DOUBLE: case IrCmd::LOAD_INT: case IrCmd::LOAD_TVALUE: + case IrCmd::CMP_ANY: case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_FALSY: - case IrCmd::JUMP_CMP_ANY: case IrCmd::SET_TABLE: case IrCmd::SET_UPVALUE: case IrCmd::INTERRUPT: @@ -114,6 +114,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) // These instrucitons read VmReg only after optimizeMemoryOperandsX64 case IrCmd::CHECK_TAG: + case IrCmd::CHECK_TRUTHY: case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: case IrCmd::MUL_NUM: diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 65984562..e7a0c424 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -101,7 +101,8 @@ void initFunctions(NativeState& data) data.context.executeNEWCLOSURE = executeNEWCLOSURE; data.context.executeNAMECALL = executeNAMECALL; data.context.executeFORGPREP = executeFORGPREP; - data.context.executeGETVARARGS = executeGETVARARGS; + data.context.executeGETVARARGSMultRet = executeGETVARARGSMultRet; + data.context.executeGETVARARGSConst = executeGETVARARGSConst; data.context.executeDUPCLOSURE = executeDUPCLOSURE; data.context.executePREPVARARGS = executePREPVARARGS; data.context.executeSETLIST = executeSETLIST; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 1a039812..4aa5c8a2 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -98,7 +98,8 @@ struct NativeContext const Instruction* (*executeNAMECALL)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; const Instruction* (*executeSETLIST)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; const Instruction* (*executeFORGPREP)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; - const Instruction* (*executeGETVARARGS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + void (*executeGETVARARGSMultRet)(lua_State* L, const Instruction* pc, StkId base, int rai) = nullptr; + void (*executeGETVARARGSConst)(lua_State* L, StkId base, int rai, int b) = nullptr; const Instruction* (*executeDUPCLOSURE)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; const Instruction* (*executePREPVARARGS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 758518a2..a5e20b16 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -513,6 +513,15 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { state.invalidateValue(inst.a); state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_POINTER); + + if (IrInst* instOp = function.asInstOp(inst.b); instOp && instOp->cmd == IrCmd::NEW_TABLE) + { + if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a)) + { + info->knownNotReadonly = true; + info->knownNoMetatable = true; + } + } } break; case IrCmd::STORE_DOUBLE: @@ -681,6 +690,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } break; } + case IrCmd::CHECK_TRUTHY: + // It is possible to check if current tag in state is truthy or not, but this case almost never comes up + break; case IrCmd::CHECK_READONLY: if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a)) { @@ -782,6 +794,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::NOT_ANY: state.substituteOrRecord(inst, index); break; + case IrCmd::CMP_ANY: + state.invalidateUserCall(); + break; case IrCmd::JUMP: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_SLOT_MATCH: @@ -840,9 +855,6 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::FINDUPVAL: break; - case IrCmd::JUMP_CMP_ANY: - state.invalidateUserCall(); // TODO: if arguments are strings, there will be no user calls - break; case IrCmd::DO_ARITH: state.invalidate(inst.a); state.invalidateUserCall(); diff --git a/CodeGen/src/OptimizeFinalX64.cpp b/CodeGen/src/OptimizeFinalX64.cpp index 5ee626ae..63642c46 100644 --- a/CodeGen/src/OptimizeFinalX64.cpp +++ b/CodeGen/src/OptimizeFinalX64.cpp @@ -35,6 +35,25 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block) } break; } + case IrCmd::CHECK_TRUTHY: + { + if (inst.a.kind == IrOpKind::Inst) + { + IrInst& tag = function.instOp(inst.a); + + if (tag.useCount == 1 && tag.cmd == IrCmd::LOAD_TAG && (tag.a.kind == IrOpKind::VmReg || tag.a.kind == IrOpKind::VmConst)) + replace(function, inst.a, tag.a); + } + + if (inst.b.kind == IrOpKind::Inst) + { + IrInst& value = function.instOp(inst.b); + + if (value.useCount == 1 && value.cmd == IrCmd::LOAD_INT) + replace(function, inst.b, value.a); + } + break; + } case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: case IrCmd::MUL_NUM: diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 8eca1050..fd074be1 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -12,7 +12,6 @@ inline bool isFlagExperimental(const char* flag) // or critical bugs that are found after the code has been submitted. static const char* const kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code - "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) "LuauTinyControlFlowAnalysis", // waiting for updates to packages depended by internal builtin plugins // makes sure we always have at least one entry nullptr, diff --git a/Compiler/src/BuiltinFolding.cpp b/Compiler/src/BuiltinFolding.cpp index 03b5918c..8fa4b7c7 100644 --- a/Compiler/src/BuiltinFolding.cpp +++ b/Compiler/src/BuiltinFolding.cpp @@ -10,7 +10,8 @@ namespace Luau namespace Compile { -const double kRadDeg = 3.14159265358979323846 / 180.0; +const double kPi = 3.14159265358979323846; +const double kRadDeg = kPi / 180.0; static Constant cvar() { @@ -460,5 +461,16 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count) return cvar(); } +Constant foldBuiltinMath(AstName index) +{ + if (index == "pi") + return cnum(kPi); + + if (index == "huge") + return cnum(HUGE_VAL); + + return cvar(); +} + } // namespace Compile } // namespace Luau diff --git a/Compiler/src/BuiltinFolding.h b/Compiler/src/BuiltinFolding.h index 1904e14f..dd1ca8c0 100644 --- a/Compiler/src/BuiltinFolding.h +++ b/Compiler/src/BuiltinFolding.h @@ -9,6 +9,7 @@ namespace Compile { Constant foldBuiltin(int bfid, const Constant* args, size_t count); +Constant foldBuiltinMath(AstName index); } // namespace Compile } // namespace Luau diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 23deec9b..6ae31825 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -31,6 +31,8 @@ LUAU_FASTFLAGVARIABLE(LuauCompileNativeComment, false) LUAU_FASTFLAGVARIABLE(LuauCompileFixBuiltinArity, false) +LUAU_FASTFLAGVARIABLE(LuauCompileFoldMathK, false) + namespace Luau { @@ -661,7 +663,7 @@ struct Compiler inlineFrames.push_back({func, oldLocals, target, targetCount}); // fold constant values updated above into expressions in the function body - foldConstants(constants, variables, locstants, builtinsFold, func->body); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); bool usedFallthrough = false; @@ -702,7 +704,7 @@ struct Compiler if (Constant* var = locstants.find(func->args.data[i])) var->type = Constant::Type_Unknown; - foldConstants(constants, variables, locstants, builtinsFold, func->body); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); } void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) @@ -2807,7 +2809,7 @@ struct Compiler locstants[var].type = Constant::Type_Number; locstants[var].valueNumber = from + iv * step; - foldConstants(constants, variables, locstants, builtinsFold, stat); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); size_t iterJumps = loopJumps.size(); @@ -2835,7 +2837,7 @@ struct Compiler // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again locstants[var].type = Constant::Type_Unknown; - foldConstants(constants, variables, locstants, builtinsFold, stat); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); } void compileStatFor(AstStatFor* stat) @@ -3604,6 +3606,7 @@ struct Compiler { Compiler* self; std::vector& functions; + bool hasTypes = false; FunctionVisitor(Compiler* self, std::vector& functions) : self(self) @@ -3617,6 +3620,10 @@ struct Compiler { node->body->visit(this); + if (FFlag::LuauCompileFunctionType) + for (AstLocal* arg : node->args) + hasTypes |= arg->annotation != nullptr; + // this makes sure all functions that are used when compiling this one have been already added to the vector functions.push_back(node); @@ -3824,6 +3831,7 @@ struct Compiler DenseHashMap typeMap; const DenseHashMap* builtinsFold = nullptr; + bool builtinsFoldMathK = false; unsigned int regTop = 0; unsigned int stackSize = 0; @@ -3874,15 +3882,21 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c // builtin folding is enabled on optimization level 2 since we can't deoptimize folding at runtime if (options.optimizationLevel >= 2) + { compiler.builtinsFold = &compiler.builtins; + if (FFlag::LuauCompileFoldMathK) + if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default) + compiler.builtinsFoldMathK = true; + } + if (options.optimizationLevel >= 1) { // this pass tracks which calls are builtins and can be compiled more efficiently analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root); // this pass analyzes constantness of expressions - foldConstants(compiler.constants, compiler.variables, compiler.locstants, compiler.builtinsFold, root); + foldConstants(compiler.constants, compiler.variables, compiler.locstants, compiler.builtinsFold, compiler.builtinsFoldMathK, root); // this pass analyzes table assignments to estimate table shapes for initially empty tables predictTableShapes(compiler.tableShapes, root); @@ -3895,17 +3909,16 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c root->visit(&fenvVisitor); } - if (FFlag::LuauCompileFunctionType) - { - buildTypeMap(compiler.typeMap, root, options.vectorType); - } - // gathers all functions with the invariant that all function references are to functions earlier in the list // for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo std::vector functions; Compiler::FunctionVisitor functionVisitor(&compiler, functions); root->visit(&functionVisitor); + // computes type information for all functions based on type annotations + if (FFlag::LuauCompileFunctionType && functionVisitor.hasTypes) + buildTypeMap(compiler.typeMap, root, options.vectorType); + for (AstExprFunction* expr : functions) compiler.compileFunction(expr, 0); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 510f2b7b..a49a7748 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -197,17 +197,19 @@ struct ConstantVisitor : AstVisitor DenseHashMap& locals; const DenseHashMap* builtins; + bool foldMathK = false; bool wasEmpty = false; std::vector builtinArgs; ConstantVisitor(DenseHashMap& constants, DenseHashMap& variables, - DenseHashMap& locals, const DenseHashMap* builtins) + DenseHashMap& locals, const DenseHashMap* builtins, bool foldMathK) : constants(constants) , variables(variables) , locals(locals) , builtins(builtins) + , foldMathK(foldMathK) { // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries wasEmpty = constants.empty() && locals.empty(); @@ -296,6 +298,14 @@ struct ConstantVisitor : AstVisitor else if (AstExprIndexName* expr = node->as()) { analyze(expr->expr); + + if (foldMathK) + { + if (AstExprGlobal* eg = expr->expr->as(); eg && eg->name == "math") + { + result = foldBuiltinMath(expr->index); + } + } } else if (AstExprIndexExpr* expr = node->as()) { @@ -437,9 +447,9 @@ struct ConstantVisitor : AstVisitor }; void foldConstants(DenseHashMap& constants, DenseHashMap& variables, - DenseHashMap& locals, const DenseHashMap* builtins, AstNode* root) + DenseHashMap& locals, const DenseHashMap* builtins, bool foldMathK, AstNode* root) { - ConstantVisitor visitor{constants, variables, locals, builtins}; + ConstantVisitor visitor{constants, variables, locals, builtins, foldMathK}; root->visit(&visitor); } diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h index d67d9285..f0798ea2 100644 --- a/Compiler/src/ConstantFolding.h +++ b/Compiler/src/ConstantFolding.h @@ -43,7 +43,7 @@ struct Constant }; void foldConstants(DenseHashMap& constants, DenseHashMap& variables, - DenseHashMap& locals, const DenseHashMap* builtins, AstNode* root); + DenseHashMap& locals, const DenseHashMap* builtins, bool foldMathK, AstNode* root); } // namespace Compile } // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index c1230f30..88f3bb01 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -194,6 +194,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeUtils.h Analysis/include/Luau/Unifiable.h Analysis/include/Luau/Unifier.h + Analysis/include/Luau/Unifier2.h Analysis/include/Luau/UnifierSharedState.h Analysis/include/Luau/Variant.h Analysis/include/Luau/VisitType.h @@ -246,6 +247,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypeUtils.cpp Analysis/src/Unifiable.cpp Analysis/src/Unifier.cpp + Analysis/src/Unifier2.cpp ) # Luau.VM Sources @@ -424,6 +426,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.unknownnever.test.cpp tests/TypePack.test.cpp tests/TypeVar.test.cpp + tests/Unifier2.test.cpp tests/Variant.test.cpp tests/VisitType.test.cpp tests/InsertionOrderedMap.test.cpp diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 190cf66a..327bfefd 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -99,6 +99,7 @@ LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int bo LUALIB_API void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc); LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t l, int boxloc); LUALIB_API void luaL_addvalue(luaL_Buffer* B); +LUALIB_API void luaL_addvalueany(luaL_Buffer* B, int idx); LUALIB_API void luaL_pushresult(luaL_Buffer* B); LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 0b9787a0..951b3028 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,6 +11,8 @@ #include +LUAU_FASTFLAG(LuauFasterInterp) + // convert a stack index to positive #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -440,6 +442,52 @@ void luaL_addvalue(luaL_Buffer* B) } } +void luaL_addvalueany(luaL_Buffer* B, int idx) +{ + lua_State* L = B->L; + + switch (lua_type(L, idx)) + { + case LUA_TNONE: + { + LUAU_ASSERT(!"expected value"); + break; + } + case LUA_TNIL: + luaL_addstring(B, "nil"); + break; + case LUA_TBOOLEAN: + if (lua_toboolean(L, idx)) + luaL_addstring(B, "true"); + else + luaL_addstring(B, "false"); + break; + case LUA_TNUMBER: + { + double n = lua_tonumber(L, idx); + char s[LUAI_MAXNUM2STR]; + char* e = luai_num2str(s, n); + luaL_addlstring(B, s, e - s, -1); + break; + } + case LUA_TSTRING: + { + size_t len; + const char* s = lua_tolstring(L, idx, &len); + luaL_addlstring(B, s, len, -1); + break; + } + default: + { + size_t len; + const char* s = luaL_tolstring(L, idx, &len); + + luaL_addlstring(B, s, len, -2); + lua_pop(L, 1); + } + } +} + void luaL_pushresult(luaL_Buffer* B) { lua_State* L = B->L; @@ -476,13 +524,29 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { if (luaL_callmeta(L, idx, "__tostring")) // is there a metafield? { - if (!lua_isstring(L, -1)) - luaL_error(L, "'__tostring' must return a string"); - return lua_tolstring(L, -1, len); + if (FFlag::LuauFasterInterp) + { + const char* s = lua_tolstring(L, -1, len); + if (!s) + luaL_error(L, "'__tostring' must return a string"); + return s; + } + else + { + if (!lua_isstring(L, -1)) + luaL_error(L, "'__tostring' must return a string"); + return lua_tolstring(L, -1, len); + } } switch (lua_type(L, idx)) { + case LUA_TNIL: + lua_pushliteral(L, "nil"); + break; + case LUA_TBOOLEAN: + lua_pushstring(L, (lua_toboolean(L, idx) ? "true" : "false")); + break; case LUA_TNUMBER: { double n = lua_tonumber(L, idx); @@ -491,15 +555,6 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) lua_pushlstring(L, s, e - s); break; } - case LUA_TSTRING: - lua_pushvalue(L, idx); - break; - case LUA_TBOOLEAN: - lua_pushstring(L, (lua_toboolean(L, idx) ? "true" : "false")); - break; - case LUA_TNIL: - lua_pushliteral(L, "nil"); - break; case LUA_TVECTOR: { const float* v = lua_tovector(L, idx); @@ -518,6 +573,9 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) lua_pushlstring(L, s, e - s); break; } + case LUA_TSTRING: + lua_pushvalue(L, idx); + break; default: { const void* ptr = lua_topointer(L, idx); diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index cb3c9d3b..844fc6dd 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -394,6 +394,15 @@ int luaG_getline(Proto* p, int pc) return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc]; } +int luaG_isnative(lua_State* L, int level) +{ + if (unsigned(level) >= unsigned(L->ci - L->base_ci)) + return 0; + + CallInfo* ci = L->ci - level; + return (ci->flags & LUA_CALLINFO_NATIVE) != 0 ? 1 : 0; +} + void lua_singlestep(lua_State* L, int enabled) { L->singlestep = bool(enabled); diff --git a/VM/src/ldebug.h b/VM/src/ldebug.h index a93e412f..49b1ca88 100644 --- a/VM/src/ldebug.h +++ b/VM/src/ldebug.h @@ -29,3 +29,5 @@ LUAI_FUNC void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable); LUAI_FUNC bool luaG_onbreak(lua_State* L); LUAI_FUNC int luaG_getline(Proto* p, int pc); + +LUAI_FUNC int luaG_isnative(lua_State* L, int level); diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index e68b84a9..5b18c26f 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -648,16 +648,28 @@ static void enumtable(EnumContext* ctx, Table* h) if (h->node != &luaH_dummynode) { + bool weakkey = false; + bool weakvalue = false; + + if (const TValue* mode = gfasttm(ctx->L->global, h->metatable, TM_MODE)) + { + if (ttisstring(mode)) + { + weakkey = strchr(svalue(mode), 'k') != NULL; + weakvalue = strchr(svalue(mode), 'v') != NULL; + } + } + for (int i = 0; i < sizenode(h); ++i) { const LuaNode& n = h->node[i]; if (!ttisnil(&n.val) && (iscollectable(&n.key) || iscollectable(&n.val))) { - if (iscollectable(&n.key)) + if (!weakkey && iscollectable(&n.key)) enumedge(ctx, obj2gco(h), gcvalue(&n.key), "[key]"); - if (iscollectable(&n.val)) + if (!weakvalue && iscollectable(&n.val)) { if (ttisstring(&n.key)) { @@ -671,7 +683,9 @@ static void enumtable(EnumContext* ctx, Table* h) } else { - enumedge(ctx, obj2gco(h), gcvalue(&n.val), NULL); + char buf[32]; + snprintf(buf, sizeof(buf), "[%s]", getstr(ctx->L->global->ttname[n.key.tt])); + enumedge(ctx, obj2gco(h), gcvalue(&n.val), buf); } } } @@ -745,7 +759,14 @@ static void enumthread(EnumContext* ctx, lua_State* th) { Proto* p = tcl->l.p; - enumnode(ctx, obj2gco(th), getstr(p->source)); + char buf[LUA_IDSIZE]; + + if (p->source) + snprintf(buf, sizeof(buf), "%s:%d %s", p->debugname ? getstr(p->debugname) : "", p->linedefined, getstr(p->source)); + else + snprintf(buf, sizeof(buf), "%s:%d", p->debugname ? getstr(p->debugname) : "", p->linedefined); + + enumnode(ctx, obj2gco(th), buf); } else { diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 875a479a..d9ce71f9 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,6 +8,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauFasterInterp, false) + // macro to `unsign' a character #define uchar(c) ((unsigned char)(c)) @@ -966,6 +968,14 @@ static int str_format(lua_State* L) luaL_addchar(&b, *strfrmt++); else if (*++strfrmt == L_ESC) luaL_addchar(&b, *strfrmt++); // %% + else if (FFlag::LuauFasterInterp && *strfrmt == '*') + { + strfrmt++; + if (++arg > top) + luaL_error(L, "missing argument #%d", arg); + + luaL_addvalueany(&b, arg); + } else { // format item char form[MAX_FORMAT]; // to store the format (`%...') @@ -1034,7 +1044,7 @@ static int str_format(lua_State* L) } case '*': { - if (formatItemSize != 1) + if (FFlag::LuauFasterInterp || formatItemSize != 1) luaL_error(L, "'%%*' does not take a form"); size_t length; diff --git a/bench/micro_tests/test_StringInterp.lua b/bench/micro_tests/test_StringInterp.lua new file mode 100644 index 00000000..33d5ecea --- /dev/null +++ b/bench/micro_tests/test_StringInterp.lua @@ -0,0 +1,44 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +bench.runCode(function() + for j=1,1e6 do + local _ = "j=" .. tostring(j) + end +end, "interp: tostring") + +bench.runCode(function() + for j=1,1e6 do + local _ = "j=" .. j + end +end, "interp: concat") + +bench.runCode(function() + for j=1,1e6 do + local _ = string.format("j=%f", j) + end +end, "interp: %f format") + +bench.runCode(function() + for j=1,1e6 do + local _ = string.format("j=%d", j) + end +end, "interp: %d format") + +bench.runCode(function() + for j=1,1e6 do + local _ = string.format("j=%*", j) + end +end, "interp: %* format") + +bench.runCode(function() + for j=1,1e6 do + local _ = `j={j}` + end +end, "interp: interp number") + +bench.runCode(function() + local ok = "hello!" + for j=1,1e6 do + local _ = `j={ok}` + end +end, "interp: interp string") \ No newline at end of file diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index e13e203a..daa1f81a 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3595,7 +3595,7 @@ TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_empty") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: () -> ()) @@ -3618,7 +3618,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (number, string) -> ()) @@ -3641,7 +3641,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_single_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (number, string) -> (string)) @@ -3664,7 +3664,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (number, string) -> (string, number)) @@ -3687,7 +3687,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__noargs_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: () -> (string, number)) @@ -3710,7 +3710,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__varargs_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (...number) -> (string, number)) @@ -3733,7 +3733,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (string, ...number) -> (string, number)) @@ -3756,7 +3756,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_varargs_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (string, ...number) -> ...number) @@ -3779,7 +3779,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_varargs_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (string, ...number) -> (boolean, ...number)) @@ -3802,7 +3802,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_named_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (foo: number, bar: string) -> (string, number)) @@ -3825,7 +3825,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (number, bar: string) -> (string, number)) @@ -3848,7 +3848,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args_last") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (foo: number, string) -> (string, number)) @@ -3871,7 +3871,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local t = { a = 1, b = 2 } @@ -3896,7 +3896,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (tbl: { x: number, y: number }) -> number) return a({x=2, y = 3}) end @@ -3916,7 +3916,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_returns") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local t = { a = 1, b = 2 } @@ -3941,7 +3941,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: () -> { x: number, y: number }) return {x=2, y = 3} end @@ -3961,7 +3961,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_vararg") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local t = { a = 1, b = 2 } @@ -3986,7 +3986,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_type_pack_vararg") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (...A) -> number, ...: A) @@ -4009,7 +4009,7 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_on_argument_type_pack_vararg") { - ScopedFastFlag flag{"LuauAnonymousAutofilled", true}; + ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; check(R"( local function foo(a: (...: T...) -> number) diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 7abf0423..d368af66 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -7264,4 +7264,47 @@ end )"); } +TEST_CASE("BuiltinFoldMathK") +{ + ScopedFastFlag sff("LuauCompileFoldMathK", true); + + // we can fold math.pi at optimization level 2 + CHECK_EQ("\n" + compileFunction(R"( +function test() + return math.pi * 2 +end +)", 0, 2), + R"( +LOADK R0 K0 [6.2831853071795862] +RETURN R0 1 +)"); + + // we don't do this at optimization level 1 because it may interfere with environment substitution + CHECK_EQ("\n" + compileFunction(R"( +function test() + return math.pi * 2 +end +)", 0, 1), + R"( +GETIMPORT R1 3 [math.pi] +MULK R0 R1 K0 [2] +RETURN R0 1 +)"); + + // we also don't do it if math global is assigned to + CHECK_EQ("\n" + compileFunction(R"( +function test() + return math.pi * 2 +end + +math = { pi = 4 } +)", 0, 2), + R"( +GETGLOBAL R2 K1 ['math'] +GETTABLEKS R1 R2 K2 ['pi'] +MULK R0 R1 K0 [2] +RETURN R0 1 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c07aab0d..f4a74b66 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1723,6 +1723,52 @@ TEST_CASE("Native") runConformance("native.lua"); } +TEST_CASE("NativeTypeAnnotations") +{ + ScopedFastFlag bytecodeVersion4("BytecodeVersion4", true); + ScopedFastFlag luauCompileFunctionType("LuauCompileFunctionType", true); + + // This tests requires code to run natively, otherwise all 'is_native' checks will fail + if (!codegen || !luau_codegen_supported()) + return; + + lua_CompileOptions copts = defaultOptions(); + copts.vectorCtor = "vector"; + copts.vectorType = "vector"; + + runConformance( + "native_types.lua", + [](lua_State* L) { + // add is_native() function + lua_pushcclosurek( + L, + [](lua_State* L) -> int { + extern int luaG_isnative(lua_State * L, int level); + + lua_pushboolean(L, luaG_isnative(L, 1)); + return 1; + }, + "is_native", 0, nullptr); + lua_setglobal(L, "is_native"); + + // for vector tests + lua_pushcfunction(L, lua_vector, "vector"); + lua_setglobal(L, "vector"); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); +#else + lua_pushvector(L, 0.0f, 0.0f, 0.0f); +#endif + luaL_newmetatable(L, "vector"); + + lua_setreadonly(L, -1, true); + lua_setmetatable(L, -2); + lua_pop(L, 1); + }, + nullptr, nullptr, &copts); +} + TEST_CASE("HugeFunction") { std::string source; diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 01b9a5dd..6c1b6fdd 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -20,7 +20,7 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) { AstStatBlock* root = parse(code); dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); - cgb = std::make_unique(mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), + cgb = std::make_unique(mainModule, NotNull{&normalizer}, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), frontend.globals.globalScope, /*prepareModuleScope*/ nullptr, &logger, NotNull{dfg.get()}, std::vector()); cgb->visit(root); rootScope = cgb->rootScope; diff --git a/tests/Differ.test.cpp b/tests/Differ.test.cpp index 132b0267..c2b09bbd 100644 --- a/tests/Differ.test.cpp +++ b/tests/Differ.test.cpp @@ -5,8 +5,10 @@ #include "Luau/Frontend.h" #include "Fixture.h" +#include "ClassFixture.h" #include "Luau/Symbol.h" +#include "Luau/Type.h" #include "ScopedFlags.h" #include "doctest.h" #include @@ -128,6 +130,592 @@ TEST_CASE_FIXTURE(DifferFixture, "a_nested_table_wrong_match") "{ on: string } } } }, while the right type at almostFoo.inner.table.has.wrong.variant has type string"); } +TEST_CASE_FIXTURE(DifferFixture, "left_cyclic_table_right_table_missing_property") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = { x = 2 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .foo has type t1 where t1 = { foo: t1 }, while the right type at almostFoo is missing the property foo)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "left_cyclic_table_right_table_property_wrong") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = { foo = 2 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .foo has type t1 where t1 = { foo: t1 }, while the right type at almostFoo.foo has type number)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "right_cyclic_table_left_table_missing_property") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = { x = 2 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("almostFoo", "foo", + R"(DiffError: these two types are not equal because the left type at almostFoo.x has type number, while the right type at is missing the property x)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "right_cyclic_table_left_table_property_wrong") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = { foo = 2 } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("almostFoo", "foo", + R"(DiffError: these two types are not equal because the left type at almostFoo.foo has type number, while the right type at .foo has type t1 where t1 = { foo: t1 })"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_two_cyclic_tables_are_not_different") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = foo + local almostFoo = id({}) + almostFoo.foo = almostFoo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_two_shifted_circles_are_not_different") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.foo.foo = foo + + local builder = id({}) + builder.foo = id({}) + builder.foo.foo = id({}) + builder.foo.foo.foo = id({}) + builder.foo.foo.foo.foo = builder + -- Shift + local almostFoo = builder.foo.foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "table_left_circle_right_measuring_tape") +{ + // Left is a circle, right is a measuring tape + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.bar = id({}) -- anchor to pin shape + foo.foo.foo.foo.foo = foo + local almostFoo = id({}) + almostFoo.foo = id({}) + almostFoo.foo.foo = id({}) + almostFoo.foo.foo.foo = id({}) + almostFoo.foo.foo.bar = id({}) -- anchor to pin shape + almostFoo.foo.foo.foo.foo = almostFoo.foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .foo.foo.foo.foo.foo is missing the property bar, while the right type at .foo.foo.foo.foo.foo.bar has type { })"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_measuring_tapes") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.foo.foo = foo.foo + local almostFoo = id({}) + almostFoo.foo = id({}) + almostFoo.foo.foo = id({}) + almostFoo.foo.foo.foo = id({}) + almostFoo.foo.foo.foo.foo = almostFoo.foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_A_B_C") +{ + CheckResult result = check(R"( + local function id(x: a): a + return x + end + + -- Remove name from cyclic table + local foo = id({}) + foo.foo = id({}) + foo.foo.foo = id({}) + foo.foo.foo.foo = id({}) + foo.foo.foo.foo.foo = foo.foo + local almostFoo = id({}) + almostFoo.foo = id({}) + almostFoo.foo.foo = id({}) + almostFoo.foo.foo.foo = id({}) + almostFoo.foo.foo.foo.foo = almostFoo.foo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_kind_A") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + local foo = id({}) + foo.left = id({}) + foo.right = id({}) + foo.left.left = id({}) + foo.left.right = id({}) + foo.right.left = id({}) + foo.right.right = id({}) + foo.right.left.left = id({}) + foo.right.left.right = id({}) + + foo.right.left.left.child = foo.right + + local almostFoo = id({}) + almostFoo.left = id({}) + almostFoo.right = id({}) + almostFoo.left.left = id({}) + almostFoo.left.right = id({}) + almostFoo.right.left = id({}) + almostFoo.right.right = id({}) + almostFoo.right.left.left = id({}) + almostFoo.right.left.right = id({}) + + almostFoo.right.left.left.child = almostFoo.right + + -- Bindings for requireType + local fooLeft = foo.left + local fooRight = foo.left.right + local fooLeftLeft = foo.left.left + local fooLeftRight = foo.left.right + local fooRightLeft = foo.right.left + local fooRightRight = foo.right.right + local fooRightLeftLeft = foo.right.left.left + local fooRightLeftRight = foo.right.left.right + + local almostFooLeft = almostFoo.left + local almostFooRight = almostFoo.left.right + local almostFooLeftLeft = almostFoo.left.left + local almostFooLeftRight = almostFoo.left.right + local almostFooRightLeft = almostFoo.right.left + local almostFooRightRight = almostFoo.right.right + local almostFooRightLeftLeft = almostFoo.right.left.left + local almostFooRightLeftRight = almostFoo.right.left.right + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_kind_B") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + local foo = id({}) + foo.left = id({}) + foo.right = id({}) + foo.left.left = id({}) + foo.left.right = id({}) + foo.right.left = id({}) + foo.right.right = id({}) + foo.right.left.left = id({}) + foo.right.left.right = id({}) + + foo.right.left.left.child = foo.left + + local almostFoo = id({}) + almostFoo.left = id({}) + almostFoo.right = id({}) + almostFoo.left.left = id({}) + almostFoo.left.right = id({}) + almostFoo.right.left = id({}) + almostFoo.right.right = id({}) + almostFoo.right.left.left = id({}) + almostFoo.right.left.right = id({}) + + almostFoo.right.left.left.child = almostFoo.left + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_kind_C") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + local foo = id({}) + foo.left = id({}) + foo.right = id({}) + foo.left.left = id({}) + foo.left.right = id({}) + foo.right.left = id({}) + foo.right.right = id({}) + foo.right.left.left = id({}) + foo.right.left.right = id({}) + + foo.right.left.left.child = foo + + local almostFoo = id({}) + almostFoo.left = id({}) + almostFoo.right = id({}) + almostFoo.left.left = id({}) + almostFoo.left.right = id({}) + almostFoo.right.left = id({}) + almostFoo.right.right = id({}) + almostFoo.right.left.left = id({}) + almostFoo.right.left.right = id({}) + + almostFoo.right.left.left.child = almostFoo + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_kind_D") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + local foo = id({}) + foo.left = id({}) + foo.right = id({}) + foo.left.left = id({}) + foo.left.right = id({}) + foo.right.left = id({}) + foo.right.right = id({}) + foo.right.left.left = id({}) + foo.right.left.right = id({}) + + foo.right.left.left.child = foo.right.left.left + + local almostFoo = id({}) + almostFoo.left = id({}) + almostFoo.right = id({}) + almostFoo.left.left = id({}) + almostFoo.left.right = id({}) + almostFoo.right.left = id({}) + almostFoo.right.right = id({}) + almostFoo.right.left.left = id({}) + almostFoo.right.left.right = id({}) + + almostFoo.right.left.left.child = almostFoo.right.left.left + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_table_cyclic_diamonds_unraveled") +{ + CheckResult result = check(R"( + -- Remove name from cyclic table + local function id(x: a): a + return x + end + + -- Pattern 1 + local foo = id({}) + foo.child = id({}) + foo.child.left = id({}) + foo.child.right = id({}) + + foo.child.left.child = foo + foo.child.right.child = foo + + -- Pattern 2 + local almostFoo = id({}) + almostFoo.child = id({}) + almostFoo.child.left = id({}) + almostFoo.child.right = id({}) + + almostFoo.child.left.child = id({}) -- Use a new table + almostFoo.child.right.child = almostFoo.child.left.child -- Refer to the same new table + + almostFoo.child.left.child.child = id({}) + almostFoo.child.left.child.child.left = id({}) + almostFoo.child.left.child.child.right = id({}) + + almostFoo.child.left.child.child.left.child = almostFoo.child.left.child + almostFoo.child.left.child.child.right.child = almostFoo.child.left.child + + -- Pattern 3 + local anotherFoo = id({}) + anotherFoo.child = id({}) + anotherFoo.child.left = id({}) + anotherFoo.child.right = id({}) + + anotherFoo.child.left.child = id({}) -- Use a new table + anotherFoo.child.right.child = id({}) -- Use another new table + + anotherFoo.child.left.child.child = id({}) + anotherFoo.child.left.child.child.left = id({}) + anotherFoo.child.left.child.child.right = id({}) + anotherFoo.child.right.child.child = id({}) + anotherFoo.child.right.child.child.left = id({}) + anotherFoo.child.right.child.child.right = id({}) + + anotherFoo.child.left.child.child.left.child = anotherFoo.child.left.child + anotherFoo.child.left.child.child.right.child = anotherFoo.child.left.child + anotherFoo.child.right.child.child.left.child = anotherFoo.child.right.child + anotherFoo.child.right.child.child.right.child = anotherFoo.child.right.child + + -- Pattern 4 + local cleverFoo = id({}) + cleverFoo.child = id({}) + cleverFoo.child.left = id({}) + cleverFoo.child.right = id({}) + + cleverFoo.child.left.child = id({}) -- Use a new table + cleverFoo.child.right.child = id({}) -- Use another new table + + cleverFoo.child.left.child.child = id({}) + cleverFoo.child.left.child.child.left = id({}) + cleverFoo.child.left.child.child.right = id({}) + cleverFoo.child.right.child.child = id({}) + cleverFoo.child.right.child.child.left = id({}) + cleverFoo.child.right.child.child.right = id({}) + -- Same as pattern 3, but swapped here + cleverFoo.child.left.child.child.left.child = cleverFoo.child.right.child -- Swap + cleverFoo.child.left.child.child.right.child = cleverFoo.child.right.child + cleverFoo.child.right.child.child.left.child = cleverFoo.child.left.child + cleverFoo.child.right.child.child.right.child = cleverFoo.child.left.child + + -- Pattern 5 + local cheekyFoo = id({}) + cheekyFoo.child = id({}) + cheekyFoo.child.left = id({}) + cheekyFoo.child.right = id({}) + + cheekyFoo.child.left.child = foo -- Use existing pattern + cheekyFoo.child.right.child = foo -- Use existing pattern + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::vector symbols{"foo", "almostFoo", "anotherFoo", "cleverFoo", "cheekyFoo"}; + + for (auto left : symbols) + { + for (auto right : symbols) + { + compareTypesEq(left, right); + } + } +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_function_cyclic") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo() + return foo + end + function almostFoo() + function bar() + return bar + end + return bar + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_function_table_cyclic") +{ + // Old solver does not correctly infer function typepacks + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo() + return { + bar = foo + } + end + function almostFoo() + function bar() + return { + bar = bar + } + end + return { + bar = bar + } + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "function_table_self_referential_cyclic") +{ + // Old solver does not correctly infer function typepacks + // ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo() + return { + bar = foo + } + end + function almostFoo() + function bar() + return bar + end + return { + bar = bar + } + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .Ret[1].bar.Ret[1] has type t1 where t1 = {| bar: () -> t1 |}, while the right type at .Ret[1].bar.Ret[1] has type t1 where t1 = () -> t1)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_union_cyclic") +{ + TypeArena arena; + TypeId number = arena.addType(PrimitiveType{PrimitiveType::Number}); + TypeId string = arena.addType(PrimitiveType{PrimitiveType::String}); + + TypeId foo = arena.addType(UnionType{std::vector{number, string}}); + UnionType* unionFoo = getMutable(foo); + unionFoo->options.push_back(foo); + + TypeId almostFoo = arena.addType(UnionType{std::vector{number, string}}); + UnionType* unionAlmostFoo = getMutable(almostFoo); + unionAlmostFoo->options.push_back(almostFoo); + + compareEq(foo, almostFoo); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_intersection_cyclic") +{ + // Old solver does not correctly refine test types + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function foo1(x: number) + return x + end + function foo2(x: string) + return 0 + end + function bar1(x: number) + return x + end + function bar2(x: string) + return 0 + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId foo1 = requireType("foo1"); + TypeId foo2 = requireType("foo2"); + TypeId bar1 = requireType("bar1"); + TypeId bar2 = requireType("bar2"); + + TypeArena arena; + + TypeId foo = arena.addType(IntersectionType{std::vector{foo1, foo2}}); + IntersectionType* intersectionFoo = getMutable(foo); + intersectionFoo->parts.push_back(foo); + + TypeId almostFoo = arena.addType(IntersectionType{std::vector{bar1, bar2}}); + IntersectionType* intersectionAlmostFoo = getMutable(almostFoo); + intersectionAlmostFoo->parts.push_back(almostFoo); + + compareEq(foo, almostFoo); +} + TEST_CASE_FIXTURE(DifferFixture, "singleton") { CheckResult result = check(R"( @@ -700,4 +1288,244 @@ TEST_CASE_FIXTURE(DifferFixture, "generic_three_or_three") R"(DiffError: these two types are not equal because the left generic at .Arg[2] cannot be the same type parameter as the right generic at .Arg[2])"); } +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "equal_metatable") +{ + CheckResult result = check(R"( + local metaFoo = { + metaBar = 5 + } + local metaAlmostFoo = { + metaBar = 1 + } + local foo = { + bar = 3 + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = 4 + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "metatable_normal") +{ + CheckResult result = check(R"( + local metaFoo = { + metaBar = 5 + } + local metaAlmostFoo = { + metaBar = 1 + } + local foo = { + bar = 3 + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = "hello" + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .bar has type number, while the right type at .bar has type string)"); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "metatable_metanormal") +{ + CheckResult result = check(R"( + local metaFoo = { + metaBar = "world" + } + local metaAlmostFoo = { + metaBar = 1 + } + local foo = { + bar = "amazing" + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = "hello" + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .__metatable.metaBar has type string, while the right type at .__metatable.metaBar has type number)"); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "metatable_metamissing_left") +{ + CheckResult result = check(R"( + local metaFoo = { + metaBar = "world" + } + local metaAlmostFoo = { + metaBar = 1, + thisIsOnlyInRight = 2, + } + local foo = { + bar = "amazing" + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = "hello" + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .__metatable is missing the property thisIsOnlyInRight, while the right type at .__metatable.thisIsOnlyInRight has type number)"); +} + +TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "metatable_metamissing_right") +{ + CheckResult result = check(R"( + local metaFoo = { + metaBar = "world", + thisIsOnlyInLeft = 2, + } + local metaAlmostFoo = { + metaBar = 1, + } + local foo = { + bar = "amazing" + } + setmetatable(foo, metaFoo) + local almostFoo = { + bar = "hello" + } + setmetatable(almostFoo, metaAlmostFoo) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at .__metatable.thisIsOnlyInLeft has type number, while the right type at .__metatable is missing the property thisIsOnlyInLeft)"); +} + +TEST_CASE_FIXTURE(DifferFixtureGeneric, "equal_class") +{ + CheckResult result = check(R"( + local foo = BaseClass + local almostFoo = BaseClass + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixtureGeneric, "class_normal") +{ + CheckResult result = check(R"( + local foo = BaseClass + local almostFoo = ChildClass + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at has type BaseClass, while the right type at has type ChildClass)"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_generictp") +{ + CheckResult result = check(R"( + local foo: () -> T... + local almostFoo: () -> U... + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesEq("foo", "almostFoo"); +} + +TEST_CASE_FIXTURE(DifferFixture, "generictp_ne_fn") +{ + CheckResult result = check(R"( + local foo: (...T) -> U... + local almostFoo: (U...) -> U... + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at has type (...T) -> (U...), while the right type at has type (U...) -> (U...))"); +} + +TEST_CASE_FIXTURE(DifferFixture, "generictp_normal") +{ + CheckResult result = check(R"( + -- trN should be X... -> Y... + -- s should be X -> Y... + -- x should be X + -- bij should be X... -> X... + + -- Intended signature: (X... -> Y..., Z -> X..., X... -> Y..., Z, Y... -> Y...) -> () + function foo(tr, s, tr2, x, bij) + bij(bij(tr(s(x)))) + bij(bij(tr2(s(x)))) + end + -- Intended signature: (X... -> X..., Z -> X..., X... -> Y..., Z, Y... -> Y...) -> () + function almostFoo(bij, s, tr, x, bij2) + bij(bij(s(x))) + bij2(bij2(tr(s(x)))) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Arg[1].Ret[Variadic] cannot be the same type parameter as the right generic at .Arg[1].Ret[Variadic])"); +} + +TEST_CASE_FIXTURE(DifferFixture, "generictp_normal_2") +{ + CheckResult result = check(R"( + -- trN should be X... -> Y... + -- s should be X -> Y... + -- x should be X + -- bij should be X... -> X... + + function foo(s, tr, tr2, x, bij) + bij(bij(tr(s(x)))) + bij(bij(tr2(s(x)))) + end + function almostFoo(s, bij, tr, x, bij2) + bij2(bij2(bij(bij(tr(s(x)))))) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left generic at .Arg[2].Arg[Variadic] cannot be the same type parameter as the right generic at .Arg[2].Arg[Variadic])"); +} + +TEST_CASE_FIXTURE(DifferFixture, "equal_generictp_cyclic") +{ + CheckResult result = check(R"( + function foo(f, g, s, x) + f(f(g(g(s(x))))) + return foo + end + function almostFoo(f, g, s, x) + g(g(f(f(s(x))))) + return almostFoo + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesEq("foo", "almostFoo"); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 0be3fa98..5b1849a7 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -267,13 +267,16 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_free_types") { - Type freeTy(FreeType{TypeLevel{}}); + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", false}; + + TypeArena arena; + TypeId freeTy = freshType(NotNull{&arena}, builtinTypes, nullptr); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypeArena dest; CloneState cloneState; - TypeId clonedTy = clone(&freeTy, dest, cloneState); + TypeId clonedTy = clone(freeTy, dest, cloneState); CHECK(get(clonedTy)); cloneState = {}; diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 234034d7..3798082b 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1125,6 +1125,10 @@ until false TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_local_function") { + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", false}, + }; + try { parse(R"(-- i am line 1 @@ -1157,6 +1161,10 @@ end TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_failsafe_earlier") { + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", false}, + }; + try { parse(R"(-- i am line 1 @@ -2418,6 +2426,10 @@ TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") } }; + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", false}, + }; + checkRecovery("function foo(a, b. c) return a + b end", "function foo(a, b) return a + b end", 1); checkRecovery("function foo(a, b: { a: number, b: number. c:number }) return a + b end", "function foo(a, b: { a: number, b: number }) return a + b end", 1); @@ -2648,6 +2660,10 @@ TEST_CASE_FIXTURE(Fixture, "AstName_comparison") TEST_CASE_FIXTURE(Fixture, "generic_type_list_recovery") { + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", false}, + }; + try { parse(R"( diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 093570d3..3fff6920 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -35,6 +35,10 @@ TEST_SUITE_BEGIN("RuntimeLimits"); TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") { + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", false}, + }; + constexpr const char* src = R"LUA( --!strict diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp index 63c03ba8..a1491baf 100644 --- a/tests/Simplify.test.cpp +++ b/tests/Simplify.test.cpp @@ -42,7 +42,7 @@ struct SimplifyFixture : Fixture const TypeId truthyTy = builtinTypes->truthyType; const TypeId falsyTy = builtinTypes->falsyType; - const TypeId freeTy = arena->addType(FreeType{&scope}); + const TypeId freeTy = freshType(arena, builtinTypes, &scope); const TypeId genericTy = arena->addType(GenericType{}); const TypeId blockedTy = arena->addType(BlockedType{}); const TypeId pendingTy = arena->addType(PendingExpansionType{{}, {}, {}, {}}); @@ -60,6 +60,8 @@ struct SimplifyFixture : Fixture TypeId anotherChildClassTy = nullptr; TypeId unrelatedClassTy = nullptr; + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + SimplifyFixture() { createSomeClasses(&frontend); @@ -176,8 +178,8 @@ TEST_CASE_FIXTURE(SimplifyFixture, "boolean_and_truthy_and_falsy") TEST_CASE_FIXTURE(SimplifyFixture, "any_and_indeterminate_types") { - CHECK("a" == intersectStr(anyTy, freeTy)); - CHECK("a" == intersectStr(freeTy, anyTy)); + CHECK("'a" == intersectStr(anyTy, freeTy)); + CHECK("'a" == intersectStr(freeTy, anyTy)); CHECK("b" == intersectStr(anyTy, genericTy)); CHECK("b" == intersectStr(genericTy, anyTy)); @@ -191,17 +193,25 @@ TEST_CASE_FIXTURE(SimplifyFixture, "any_and_indeterminate_types") TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_indeterminate_types") { - CHECK(isIntersection(intersect(unknownTy, freeTy))); - CHECK(isIntersection(intersect(freeTy, unknownTy))); + CHECK(freeTy == intersect(unknownTy, freeTy)); + CHECK(freeTy == intersect(freeTy, unknownTy)); - CHECK(isIntersection(intersect(unknownTy, genericTy))); - CHECK(isIntersection(intersect(genericTy, unknownTy))); + TypeId t = nullptr; - CHECK(isIntersection(intersect(unknownTy, blockedTy))); - CHECK(isIntersection(intersect(blockedTy, unknownTy))); + t = intersect(unknownTy, genericTy); + CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); + t = intersect(genericTy, unknownTy); + CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); - CHECK(isIntersection(intersect(unknownTy, pendingTy))); - CHECK(isIntersection(intersect(pendingTy, unknownTy))); + t = intersect(unknownTy, blockedTy); + CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); + t = intersect(blockedTy, unknownTy); + CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); + + t = intersect(unknownTy, pendingTy); + CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); + t = intersect(pendingTy, unknownTy); + CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); } TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_concrete") @@ -225,8 +235,8 @@ TEST_CASE_FIXTURE(SimplifyFixture, "error_and_other_tops_and_bottom_types") TEST_CASE_FIXTURE(SimplifyFixture, "error_and_indeterminate_types") { - CHECK("*error-type* & a" == intersectStr(errorTy, freeTy)); - CHECK("*error-type* & a" == intersectStr(freeTy, errorTy)); + CHECK("'a & *error-type*" == intersectStr(errorTy, freeTy)); + CHECK("'a & *error-type*" == intersectStr(freeTy, errorTy)); CHECK("*error-type* & b" == intersectStr(errorTy, genericTy)); CHECK("*error-type* & b" == intersectStr(genericTy, errorTy)); @@ -430,7 +440,7 @@ TEST_CASE_FIXTURE(SimplifyFixture, "curious_union") TypeId curious = arena->addType(UnionType{{arena->addType(IntersectionType{{freeTy, falseTy}}), arena->addType(IntersectionType{{freeTy, nilTy}})}}); - CHECK("(a & false) | (a & nil) | number" == toString(union_(curious, numberTy))); + CHECK("('a & false) | ('a & nil) | number" == toString(union_(curious, numberTy))); } TEST_CASE_FIXTURE(SimplifyFixture, "negations") @@ -516,4 +526,13 @@ TEST_CASE_FIXTURE(SimplifyFixture, "simplify_stops_at_cycles") CHECK(t2 == intersect(anyTy, t2)); } +TEST_CASE_FIXTURE(SimplifyFixture, "free_type_bound_by_any_with_any") +{ + CHECK(freeTy == intersect(freeTy, anyTy)); + CHECK(freeTy == intersect(anyTy, freeTy)); + + CHECK(freeTy == intersect(freeTy, anyTy)); + CHECK(freeTy == intersect(anyTy, freeTy)); +} + TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 73ae4773..9293bfb2 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -269,6 +269,10 @@ n3 [label="TableType 3"]; TEST_CASE_FIXTURE(Fixture, "free") { + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", false}, + }; + Type type{TypeVariant{FreeType{TypeLevel{0, 0}}}}; ToDotOptions opts; diff --git a/tests/TxnLog.test.cpp b/tests/TxnLog.test.cpp index bfd29765..2d302ea6 100644 --- a/tests/TxnLog.test.cpp +++ b/tests/TxnLog.test.cpp @@ -22,9 +22,9 @@ struct TxnLogFixture ScopePtr globalScope = std::make_shared(builtinTypes.anyTypePack); ScopePtr childScope = std::make_shared(globalScope); - TypeId a = arena.freshType(globalScope.get()); - TypeId b = arena.freshType(globalScope.get()); - TypeId c = arena.freshType(childScope.get()); + TypeId a = freshType(NotNull{&arena}, NotNull{&builtinTypes}, globalScope.get()); + TypeId b = freshType(NotNull{&arena}, NotNull{&builtinTypes}, globalScope.get()); + TypeId c = freshType(NotNull{&arena}, NotNull{&builtinTypes}, childScope.get()); TypeId g = arena.addType(GenericType{"G"}); }; @@ -108,8 +108,8 @@ TEST_CASE_FIXTURE(TxnLogFixture, "colliding_coincident_logs_do_not_create_degene log.commit(); - CHECK("a" == toString(a)); - CHECK("a" == toString(b)); + CHECK("'a" == toString(a)); + CHECK("'a" == toString(b)); } TEST_CASE_FIXTURE(TxnLogFixture, "replacing_persistent_types_is_allowed_but_makes_the_log_radioactive") diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index e4577df6..55f4caec 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -336,6 +336,10 @@ TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_typ // Check that recursive intersection type doesn't generate an OOM TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") { + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", false}, + }; // FIXME + CheckResult result = check(R"( function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any end diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index b37bcf83..379ecac6 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1912,9 +1912,8 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceeded_during_generalization") { ScopedFastInt sfi{"LuauTarjanChildLimit", 2}; - ScopedFastFlag sff[] = { - {"DebugLuauDeferredConstraintResolution", true}, - }; + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; CheckResult result = check(R"( function f(t) @@ -2156,4 +2155,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "num_is_solved_after_num_or_str") CHECK_EQ("() -> number", toString(requireType("num_or_str"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "apply_of_lambda_with_inferred_and_explicit_types") +{ + CheckResult result = check(R"( + local function apply(f, x) return f(x) end + local x = apply(function(x: string): number return 5 end, "hello!") + + local function apply_explicit(f: (A) -> B..., x: A): B... return f(x) end + local x = apply_explicit(function(x: string): number return 5 end, "hello!") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 35df644b..6b933616 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -935,7 +935,7 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") std::optional methodProp = get(argTable->props, "method"); REQUIRE(bool(methodProp)); - const FunctionType* methodFunction = get(methodProp->type()); + const FunctionType* methodFunction = get(follow(methodProp->type())); REQUIRE(methodFunction != nullptr); std::optional methodArg = first(methodFunction->argTypes); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index fef000e4..c7d2896c 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -77,7 +77,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_regression_issue_69967") TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_regression_issue_69967_alt") { - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + CheckResult result = check(R"( type Iterable = typeof(setmetatable( {}, @@ -911,7 +913,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_xpath_candidates") TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_on_never_gives_never") { - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + CheckResult result = check(R"( local iter: never local ans diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 6f4f9328..b3d70bd7 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" +#include "Luau/RecursionCounter.h" #include "Fixture.h" @@ -999,4 +1000,48 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +// We would prefer this unification to be able to complete, but at least it should not crash +TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_infinite_recursion") +{ + ScopedFastFlag luauTableUnifyRecursionLimit{"LuauTableUnifyRecursionLimit", true}; + +#if defined(_NOOPT) || defined(_DEBUG) + ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; +#endif + + fileResolver.source["game/A"] = R"( +local tbl = {} + +function tbl:f1(state) + self.someNonExistentvalue2 = state +end + +function tbl:f2() + self.someNonExistentvalue:Dc() +end + +function tbl:f3() + self:f2() + self:f1(false) +end +return tbl + )"; + + fileResolver.source["game/B"] = R"( +local tbl = require(game.A) +tbl:f3() + )"; + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // TODO: DCR should transform RecursionLimitException into a CodeTooComplex error (currently it rethows it as InternalCompilerError) + CHECK_THROWS_AS(frontend.check("game/B"), Luau::InternalCompilerError); + } + else + { + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index ca302a2f..c0dbfce8 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -289,17 +289,26 @@ TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_if_condition_position") { - CheckResult result = check(R"( - function f(s: any) + CheckResult result1 = check(R"( + function f(s: any, t: unknown) if type(s) == "number" then local n = s end + if type(t) == "number" then + local n = t + end end )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result1); + + // DCR changes refinements to preserve error suppression. + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("*error-type* | number", toString(requireTypeAtPosition({3, 26}))); + else + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") @@ -322,16 +331,28 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "call_an_incompatible_function_after_using_ty return x end - local function g(x: any) + local function g(x: unknown) + if type(x) == "string" then + f(x) + end + end + + local function h(x: any) if type(x) == "string" then f(x) end end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_ERROR_COUNT(1, result); + else + LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + + if (!FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[1])); } TEST_CASE_FIXTURE(BuiltinsFixture, "impossible_type_narrow_is_not_an_error") @@ -785,16 +806,23 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(BuiltinsFixture, "either_number_or_string") { CheckResult result = check(R"( - local function f(x: any) + local function f(x: any, y: unknown) if type(x) == "number" or type(x) == "string" then local foo = x end + if type(y) == "number" or type(y) == "string" then + local foo = y + end end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("*error-type* | number | string", toString(requireTypeAtPosition({3, 28}))); + else + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({6, 28}))); } TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") @@ -906,15 +934,30 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") function f(v:any) return if typeof(v) == "number" then v else returnOne(v) end + + function g(v:unknown) + return if typeof(v) == "number" then v else returnOne(v) + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("~number", toString(requireTypeAtPosition({6, 66}))); + { + CHECK_EQ("*error-type* | number", toString(requireTypeAtPosition({6, 49}))); + CHECK_EQ("*error-type* | ~number", toString(requireTypeAtPosition({6, 66}))); + } else + { + CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); + } + + CHECK_EQ("number", toString(requireTypeAtPosition({10, 49}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("unknown & ~number", toString(requireTypeAtPosition({10, 66}))); + else + CHECK_EQ("unknown", toString(requireTypeAtPosition({10, 66}))); } @@ -1862,4 +1905,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table") CHECK_EQ("unknown", toString(requireType("val"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "conditional_refinement_should_stay_error_suppressing") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + // this test is DCR-only as an instance of DCR fixing a bug in the old solver + + CheckResult result = check(R"( + local function test(element: any?) + if element then + local owner = element._owner + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 36422f8d..4f40b386 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -78,9 +78,37 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") f("foo") )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("number | string" == toString(requireType("a"))); + CHECK("(number | string) -> ()" == toString(requireType("f"))); - CHECK_EQ("number", toString(requireType("a"))); + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("number", toString(requireType("a"))); + } +} +TEST_CASE_FIXTURE(Fixture, "interesting_local_type_inference_case") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + }; + + CheckResult result = check(R"( + local a + function f(x) a = x end + f({x = 5}) + f({x = 5}) + )"); + + CHECK("{ x: number }" == toString(requireType("a"))); + CHECK("({ x: number }) -> ()" == toString(requireType("f"))); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") diff --git a/tests/Unifier2.test.cpp b/tests/Unifier2.test.cpp new file mode 100644 index 00000000..363c8109 --- /dev/null +++ b/tests/Unifier2.test.cpp @@ -0,0 +1,108 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" +#include "Luau/Unifier2.h" +#include "Luau/Error.h" + +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +struct Unifier2Fixture +{ + TypeArena arena; + BuiltinTypes builtinTypes; + Scope scope{builtinTypes.anyTypePack}; + InternalErrorReporter iceReporter; + Unifier2 u2{NotNull{&arena}, NotNull{&builtinTypes}, NotNull{&iceReporter}}; + ToStringOptions opts; + + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + std::pair freshType() + { + FreeType ft{&scope, builtinTypes.neverType, builtinTypes.unknownType}; + + TypeId ty = arena.addType(ft); + FreeType* ftv = getMutable(ty); + REQUIRE(ftv != nullptr); + + return {ty, ftv}; + } + + std::string toString(TypeId ty) + { + return ::Luau::toString(ty, opts); + } +}; + +TEST_SUITE_BEGIN("Unifier2"); + +TEST_CASE_FIXTURE(Unifier2Fixture, "T <: number") +{ + auto [left, freeLeft] = freshType(); + + CHECK(u2.unify(left, builtinTypes.numberType)); + + CHECK("never" == toString(freeLeft->lowerBound)); + CHECK("number" == toString(freeLeft->upperBound)); +} + +TEST_CASE_FIXTURE(Unifier2Fixture, "number <: T") +{ + auto [right, freeRight] = freshType(); + + CHECK(u2.unify(builtinTypes.numberType, right)); + + CHECK("number" == toString(freeRight->lowerBound)); + CHECK("unknown" == toString(freeRight->upperBound)); +} + +TEST_CASE_FIXTURE(Unifier2Fixture, "T <: U") +{ + auto [left, freeLeft] = freshType(); + auto [right, freeRight] = freshType(); + + CHECK(u2.unify(left, right)); + + CHECK("t1 where t1 = ('a <: (t1 <: 'b))" == toString(left)); + CHECK("t1 where t1 = (('a <: t1) <: 'b)" == toString(right)); + + CHECK("never" == toString(freeLeft->lowerBound)); + CHECK("t1 where t1 = (('a <: t1) <: 'b)" == toString(freeLeft->upperBound)); + + CHECK("t1 where t1 = ('a <: (t1 <: 'b))" == toString(freeRight->lowerBound)); + CHECK("unknown" == toString(freeRight->upperBound)); +} + +TEST_CASE_FIXTURE(Unifier2Fixture, "(string) -> () <: (X) -> Y...") +{ + TypeId stringToUnit = arena.addType(FunctionType{ + arena.addTypePack({builtinTypes.stringType}), + arena.addTypePack({}) + }); + + auto [x, xFree] = freshType(); + TypePackId y = arena.freshTypePack(&scope); + + TypeId xToY = arena.addType(FunctionType{ + arena.addTypePack({x}), + y + }); + + u2.unify(stringToUnit, xToY); + + CHECK("string" == toString(xFree->upperBound)); + + const TypePack* yPack = get(follow(y)); + REQUIRE(yPack != nullptr); + + CHECK(0 == yPack->head.size()); + CHECK(!yPack->tail); +} + +TEST_SUITE_END(); diff --git a/tests/conformance/native_types.lua b/tests/conformance/native_types.lua new file mode 100644 index 00000000..67779230 --- /dev/null +++ b/tests/conformance/native_types.lua @@ -0,0 +1,72 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing native code generation with type annotations") + +function call(fn, ...) + local ok, res = pcall(fn, ...) + assert(ok) + return res +end + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub(err:find(": ") + 2, #err) +end + +local function add(a: number, b: number, native: boolean) + assert(native == is_native()) + return a + b +end + +call(add, 1, 3, true) +ecall(add, nil, 2, false) + +local function isnil(x: nil) + assert(is_native()) + return not x +end + +call(isnil, nil) +ecall(isnil, 2) + +local function isany(x: any, y: number) + assert(is_native()) + return not not x +end + +call(isany, nil, 1) +call(isany, 2, 1) +call(isany, {}, 1) + +local function optstring(s: string?) + assert(is_native()) + return if s then s..'2' else '3' +end + +assert(call(optstring, nil) == '3') +assert(call(optstring, 'two: ') == 'two: 2') +ecall(optstring, 2) + +local function checktable(a: {x:number}) assert(is_native()) end +local function checkfunction(a: () -> ()) assert(is_native()) end +local function checkthread(a: thread) assert(is_native()) end +local function checkuserdata(a: userdata) assert(is_native()) end +local function checkvector(a: vector) assert(is_native()) end + +call(checktable, {}) +ecall(checktable, 2) + +call(checkfunction, function() end) +ecall(checkfunction, 2) + +call(checkthread, coroutine.create(function() end)) +ecall(checkthread, 2) + +call(checkuserdata, newproxy()) +ecall(checkuserdata, 2) + +call(checkvector, vector(1, 2, 3)) +ecall(checkvector, 2) + + +return('OK') diff --git a/tools/faillist.txt b/tools/faillist.txt index 3e2ee185..390b8b62 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,33 +1,73 @@ +AnnotationTests.infer_type_of_value_a_via_typeof_with_assignment +AnnotationTests.two_type_params AstQuery.last_argument_function_call_type AutocompleteTest.anonymous_autofilled_generic_on_argument_type_pack_vararg AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg +AutocompleteTest.autocomplete_if_else_regression +AutocompleteTest.autocomplete_interpolated_string_as_singleton +AutocompleteTest.autocomplete_oop_implicit_self +AutocompleteTest.autocomplete_response_perf1 +AutocompleteTest.autocomplete_string_singleton_escape +AutocompleteTest.autocomplete_string_singletons +AutocompleteTest.cyclic_table +AutocompleteTest.suggest_external_module_type +AutocompleteTest.type_correct_expected_argument_type_pack_suggestion +AutocompleteTest.type_correct_expected_argument_type_suggestion +AutocompleteTest.type_correct_expected_argument_type_suggestion_optional +AutocompleteTest.type_correct_expected_argument_type_suggestion_self +AutocompleteTest.type_correct_function_no_parenthesis +AutocompleteTest.type_correct_function_return_types +AutocompleteTest.type_correct_keywords +AutocompleteTest.type_correct_suggestion_in_argument +AutocompleteTest.unsealed_table_2 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy BuiltinTests.bad_select_should_not_crash +BuiltinTests.next_iterator_should_infer_types_and_type_check BuiltinTests.select_slightly_out_of_range BuiltinTests.select_way_out_of_range BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types +BuiltinTests.sort_with_bad_predicate +BuiltinTests.string_format_arg_types_inference BuiltinTests.string_format_as_method BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_tostring_specifier_type_constraint BuiltinTests.string_format_use_correct_argument2 +BuiltinTests.table_dot_remove_optionally_returns_generic +BuiltinTests.table_freeze_is_generic +BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload +BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload BuiltinTests.table_pack BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic DefinitionTests.class_definition_indexer DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props +Differ.equal_generictp_cyclic +Differ.equal_table_A_B_C +Differ.equal_table_cyclic_diamonds_unraveled +Differ.equal_table_kind_A +Differ.equal_table_kind_B +Differ.equal_table_kind_C +Differ.equal_table_kind_D +Differ.equal_table_measuring_tapes +Differ.equal_table_two_shifted_circles_are_not_different +Differ.function_table_self_referential_cyclic +Differ.generictp_normal +Differ.generictp_normal_2 +Differ.table_left_circle_right_measuring_tape GenericsTests.better_mismatch_error_messages +GenericsTests.bidirectional_checking_and_generalization_play_nice GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_infer_generic_functions -GenericsTests.dont_unify_bound_types +GenericsTests.dont_substitute_bound_types GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_functions_should_be_memory_safe @@ -37,71 +77,125 @@ GenericsTests.infer_generic_function_function_argument_2 GenericsTests.infer_generic_function_function_argument_3 GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_lib_function_function_argument +GenericsTests.infer_generic_property +GenericsTests.instantiate_generic_function_in_assignments +GenericsTests.instantiate_generic_function_in_assignments2 GenericsTests.instantiated_function_argument_names +GenericsTests.mutable_state_polymorphism GenericsTests.no_stack_overflow_from_quantifying +GenericsTests.properties_can_be_polytypes +GenericsTests.quantify_functions_even_if_they_have_an_explicit_generic GenericsTests.self_recursive_instantiated_param IntersectionTypes.intersection_of_tables_with_top_properties +IntersectionTypes.less_greedy_unification_with_intersection_types IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect +Normalize.negations_of_tables +Normalize.specific_functions_cannot_be_negated +ParserTests.parse_nesting_based_end_detection +ParserTests.parse_nesting_based_end_detection_single_line ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean ProvisionalTests.free_options_can_be_unified_together ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns -ProvisionalTests.luau-polyfill.Array.filter ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument +ProvisionalTests.table_unification_infinite_recursion ProvisionalTests.typeguard_inference_incomplete +RefinementTest.call_an_incompatible_function_after_using_typeguard +RefinementTest.dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never RefinementTest.discriminate_from_truthiness_of_x +RefinementTest.fail_to_refine_a_property_of_subscript_expression +RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil +RefinementTest.isa_type_refinement_must_be_known_ahead_of_time +RefinementTest.narrow_property_of_a_bounded_variable +RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.not_t_or_some_prop_of_t +RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage_2 RefinementTest.refine_a_property_of_some_global RefinementTest.truthy_constraint_on_properties RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table -RuntimeLimits.typescript_port_of_Result_type +TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible +TableTests.call_method +TableTests.call_method_with_explicit_self_argument +TableTests.cannot_augment_sealed_table +TableTests.cannot_change_type_of_unsealed_table_prop +TableTests.casting_sealed_tables_with_props_into_table_with_indexer +TableTests.casting_tables_with_props_into_table_with_indexer4 +TableTests.casting_unsealed_tables_with_props_into_table_with_indexer TableTests.checked_prop_too_early +TableTests.cyclic_shifted_tables +TableTests.defining_a_method_for_a_local_sealed_table_must_fail +TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar +TableTests.dont_extend_unsealed_tables_in_rvalue_position TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index +TableTests.dont_leak_free_table_props TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop TableTests.explicitly_typed_table TableTests.explicitly_typed_table_with_indexer -TableTests.fuzz_table_unify_instantiated_table TableTests.fuzz_table_unify_instantiated_table_with_prop_realloc +TableTests.generalize_table_argument TableTests.generic_table_instantiation_potential_regression TableTests.give_up_after_one_metatable_index_look_up -TableTests.indexer_on_sealed_table_must_unify_with_free_table +TableTests.hide_table_error_properties +TableTests.indexers_get_quantified_too TableTests.indexing_from_a_table_should_prefer_properties_when_possible TableTests.inequality_operators_imply_exactly_matching_types +TableTests.infer_array_2 +TableTests.infer_indexer_from_value_property_in_literal +TableTests.infer_type_when_indexing_from_a_table_indexer +TableTests.inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table TableTests.inferred_return_type_of_free_table TableTests.instantiate_table_cloning_3 TableTests.leaking_bad_metatable_errors TableTests.less_exponential_blowup_please TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys -TableTests.nil_assign_doesnt_hit_indexer -TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr +TableTests.ok_to_provide_a_subtype_during_construction +TableTests.okay_to_add_property_to_unsealed_tables_by_assignment +TableTests.okay_to_add_property_to_unsealed_tables_by_function_call +TableTests.only_ascribe_synthetic_names_at_module_scope +TableTests.oop_indexer_works TableTests.oop_polymorphic TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table +TableTests.quantifying_a_bound_var_works TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table +TableTests.recursive_metatable_type_call TableTests.right_table_missing_key2 TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables TableTests.table_call_metamethod_basic TableTests.table_simple_call +TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors +TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors2 +TableTests.table_unification_4 +TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon +TableTests.used_dot_instead_of_colon_but_correctly +TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type +TableTests.wrong_assign_does_hit_indexer +ToDot.function +ToString.exhaustive_toString_of_cyclic_table +ToString.free_types +ToString.named_metatable_toStringNamedFunction +ToString.pick_distinct_names_for_mixed_explicit_and_implicit_generics ToString.toStringDetailed2 ToString.toStringErrorPack +ToString.toStringGenericPack ToString.toStringNamedFunction_generic_pack TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained @@ -118,34 +212,52 @@ TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_locations TypeAliases.type_alias_of_an_imported_recursive_generic_type +TypeFamilyTests.family_as_fn_arg +TypeFamilyTests.table_internal_families +TypeFamilyTests.unsolvable_family +TypeInfer.bidirectional_checking_of_higher_order_function TypeInfer.check_type_infer_recursion_count +TypeInfer.cli_39932_use_unifier_in_ensure_methods TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.follow_on_new_types_in_substitution TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval +TypeInfer.infer_locals_via_assignment_from_its_call_site TypeInfer.no_stack_overflow_from_isoptional +TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error +TypeInfer.type_infer_cache_limit_normalizer TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer +TypeInferAnyError.can_subscript_any TypeInferAnyError.for_in_loop_iterator_is_any2 +TypeInferAnyError.for_in_loop_iterator_returns_any +TypeInferAnyError.intersection_of_any_can_have_props +TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any +TypeInferAnyError.union_of_types_regression_test +TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.index_instance_property TypeInferFunctions.cannot_hoist_interior_defns_into_signature +TypeInferFunctions.dont_assert_when_the_tarjan_limit_is_exceeded_during_generalization TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite TypeInferFunctions.function_does_not_return_enough_values +TypeInferFunctions.function_exprs_are_generalized_at_signature_scope_not_enclosing TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer +TypeInferFunctions.higher_order_function_2 +TypeInferFunctions.higher_order_function_4 TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.infer_anonymous_function_arguments +TypeInferFunctions.infer_anonymous_function_arguments_outside_call TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type -TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.report_exiting_without_return_strict TypeInferFunctions.return_type_by_overload TypeInferFunctions.too_few_arguments_variadic @@ -154,33 +266,53 @@ TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function -TypeInferLoops.dcr_iteration_explore_raycast_minimization +TypeInferFunctions.toposort_doesnt_break_mutual_recursion +TypeInferFunctions.vararg_function_is_quantified +TypeInferLoops.cli_68448_iterators_need_not_accept_nil +TypeInferLoops.dcr_iteration_on_never_gives_never +TypeInferLoops.for_in_loop TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values +TypeInferLoops.for_in_loop_with_custom_iterator +TypeInferLoops.for_in_loop_with_incompatible_args_to_iterator +TypeInferLoops.for_in_loop_with_next +TypeInferLoops.ipairs_produces_integral_indices +TypeInferLoops.iteration_regression_issue_69967_alt +TypeInferLoops.loop_iter_metamethod_nil TypeInferLoops.loop_iter_metamethod_ok_with_inference TypeInferLoops.loop_iter_trailing_nil -TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated +TypeInferOOP.cycle_between_object_constructor_and_alias +TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 +TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted +TypeInferOperators.and_binexps_dont_unify TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_mismatch_metatable +TypeInferOperators.concat_op_on_string_lhs_and_free_rhs TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops +TypeInferOperators.luau_polyfill_is_array TypeInferOperators.operator_eq_completely_incompatible TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.typecheck_unary_len_error +TypeInferOperators.typecheck_unary_minus +TypeInferOperators.typecheck_unary_minus_error TypeInferOperators.unrelated_classes_cannot_be_compared TypeInferOperators.unrelated_primitives_cannot_be_compared TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.string_index +TypeInferUnknownNever.length_of_never TypeInferUnknownNever.math_operators_and_never TypePackTests.detect_cyclic_typepacks2 +TypePackTests.higher_order_function TypePackTests.pack_tail_unification_check TypePackTests.type_alias_backwards_compatible TypePackTests.type_alias_default_type_errors +TypeSingletons.function_args_infer_singletons TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.no_widening_from_callsites @@ -192,4 +324,5 @@ TypeSingletons.widening_happens_almost_everywhere UnionTypes.dont_allow_cyclic_unions_to_be_inferred UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property +UnionTypes.less_greedy_unification_with_union_types UnionTypes.table_union_write_indirect diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 208096fa..30f8a310 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -126,12 +126,6 @@ def main(): args = parser.parse_args() - if args.write and args.rwp: - print_stderr( - "Cannot run test_dcr.py with --write *and* --rwp. You don't want to commit local type inference faillist.txt yet." - ) - sys.exit(1) - failList = loadFailList() flags = ["true", "DebugLuauDeferredConstraintResolution"]