diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index b540b82f..28cfb5aa 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -118,7 +118,7 @@ struct ConstraintGenerator std::function prepareModuleScope; std::vector requireCycles; - DenseHashMap> localTypes{nullptr}; + DenseHashMap localTypes{nullptr}; DcrLogger* logger; diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 902dd15d..6e62a2e3 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -94,6 +94,10 @@ struct ConstraintSolver // Irreducible/uninhabited type families or type pack families. DenseHashSet uninhabitedTypeFamilies{{}}; + // The set of types that will definitely be unchanged by generalization. + DenseHashSet generalizedTypes_{nullptr}; + const NotNull> generalizedTypes{&generalizedTypes_}; + // Recorded errors that take place within the solver. ErrorVec errors; @@ -103,6 +107,8 @@ struct ConstraintSolver DcrLogger* logger; TypeCheckLimits limits; + DenseHashMap typeFamiliesToFinalize{nullptr}; + explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger, TypeCheckLimits limits); @@ -116,8 +122,35 @@ struct ConstraintSolver **/ void run(); + + /** + * Attempts to perform one final reduction on type families after every constraint has been completed + * + **/ + void finalizeTypeFamilies(); + bool isDone(); +private: + /** + * Bind a type variable to another type. + * + * A constraint is required and will validate that blockedTy is owned by this + * constraint. This prevents one constraint from interfering with another's + * blocked types. + * + * Bind will also unblock the type variable for you. + */ + void bind(NotNull constraint, TypeId ty, TypeId boundTo); + void bind(NotNull constraint, TypePackId tp, TypePackId boundTo); + + template + void emplace(NotNull constraint, TypeId ty, Args&&... args); + + template + void emplace(NotNull constraint, TypePackId tp, Args&&... args); + +public: /** Attempt to dispatch a constraint. Returns true if it was successful. If * tryDispatch() returns false, the constraint remains in the unsolved set * and will be retried later. @@ -135,19 +168,14 @@ struct ConstraintSolver bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); + bool tryDispatchHasIndexer( int& recursionDepth, NotNull constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set& seen); bool tryDispatch(const HasIndexerConstraint& c, NotNull constraint); - std::pair> tryDispatchSetIndexer( - NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds); - bool tryDispatch(const AssignPropConstraint& c, NotNull constraint); bool tryDispatch(const AssignIndexConstraint& c, NotNull constraint); - - bool tryDispatchUnpack1(NotNull constraint, TypeId resultType, TypeId sourceType); bool tryDispatch(const UnpackConstraint& c, NotNull constraint); - bool tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); bool tryDispatch(const EqualityConstraint& c, NotNull constraint, bool force); @@ -298,22 +326,6 @@ struct ConstraintSolver template bool unify(NotNull constraint, TID subTy, TID superTy); -private: - /** - * Bind a BlockedType to another type while taking care not to bind it to - * itself in the case that resultTy == blockedTy. This can happen if we - * have a tautological constraint. When it does, we must instead bind - * blockedTy to a fresh type belonging to an appropriate scope. - * - * To determine which scope is appropriate, we also accept rootTy, which is - * to be the type that contains blockedTy. - * - * A constraint is required and will validate that blockedTy is owned by this - * constraint. This prevents one constraint from interfering with another's - * blocked types. - */ - void bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull constraint); - /** * Marks a constraint as being blocked on a type or type pack. The constraint * solver will not attempt to dispatch blocked constraints until their diff --git a/Analysis/include/Luau/Generalization.h b/Analysis/include/Luau/Generalization.h index bf196f3e..44d0db67 100644 --- a/Analysis/include/Luau/Generalization.h +++ b/Analysis/include/Luau/Generalization.h @@ -8,6 +8,6 @@ namespace Luau { -std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, TypeId ty); +std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> bakedTypes, TypeId ty); } diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 197c7f9c..152d8c65 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -102,6 +102,12 @@ struct Module DenseHashMap astResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; + // The computed result type of a compound assignment. (eg foo += 1) + // + // Type checking uses this to check that the result of such an operation is + // actually compatible with the left-side operand. + DenseHashMap astCompoundAssignResultTypes{nullptr}; + DenseHashMap>> upperBoundContributors{nullptr}; // Map AST nodes to the scope they create. Cannot be NotNull because diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 6105ede3..881dc646 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -658,7 +658,7 @@ struct NegationType using ErrorType = Unifiable::Error; using TypeVariant = - Unifiable::Variant; struct Type final diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index 5b72a370..fa23a6ba 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -179,6 +179,8 @@ struct BuiltinTypeFamilies TypeFamily keyofFamily; TypeFamily rawkeyofFamily; + TypeFamily indexFamily; + void addToScope(NotNull arena, NotNull scope) const; }; diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 7b3377cb..a62879fa 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -93,6 +93,10 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const { rci.traverse(taec->target); } + else if (auto fchc = get(*this)) + { + rci.traverse(fchc->argsPack); + } else if (auto ptc = get(*this)) { rci.traverse(ptc->freeType); diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 9d825408..b784f4aa 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -253,7 +253,11 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) // FIXME: This isn't the most efficient thing. TypeId domainTy = builtinTypes->neverType; for (TypeId d : domain) + { + if (d == ty) + continue; domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + } LUAU_ASSERT(get(ty)); asMutable(ty)->ty.emplace(domainTy); @@ -323,7 +327,7 @@ std::optional ConstraintGenerator::lookup(const ScopePtr& scope, Locatio if (!ty) { ty = arena->addType(BlockedType{}); - localTypes[*ty] = {}; + localTypes.try_insert(*ty, {}); rootScope->lvalueTypes[operand] = *ty; } @@ -717,7 +721,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat const Location location = local->location; TypeId assignee = arena->addType(BlockedType{}); - localTypes[assignee] = {}; + localTypes.try_insert(assignee, {}); assignees.push_back(assignee); @@ -756,9 +760,9 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat for (size_t i = 0; i < statLocal->vars.size; ++i) { LUAU_ASSERT(get(assignees[i])); - std::vector* localDomain = localTypes.find(assignees[i]); + TypeIds* localDomain = localTypes.find(assignees[i]); LUAU_ASSERT(localDomain); - localDomain->push_back(annotatedTypes[i]); + localDomain->insert(annotatedTypes[i]); } TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); @@ -790,9 +794,9 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat for (size_t i = 0; i < statLocal->vars.size; ++i) { LUAU_ASSERT(get(assignees[i])); - std::vector* localDomain = localTypes.find(assignees[i]); + TypeIds* localDomain = localTypes.find(assignees[i]); LUAU_ASSERT(localDomain); - localDomain->push_back(valueTypes[i]); + localDomain->insert(valueTypes[i]); } } @@ -898,7 +902,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI variableTypes.push_back(assignee); TypeId loopVar = arena->addType(BlockedType{}); - localTypes[loopVar].push_back(assignee); + localTypes[loopVar].insert(assignee); if (var->annotation) { @@ -1183,8 +1187,13 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAss { AstExprBinary binop = AstExprBinary{assign->location, assign->op, assign->var, assign->value}; TypeId resultTy = check(scope, &binop).ty; + module->astCompoundAssignResultTypes[assign] = resultTy; - visitLValue(scope, assign->var, resultTy); + TypeId lhsType = check(scope, assign->var).ty; + visitLValue(scope, assign->var, lhsType); + + follow(lhsType); + follow(resultTy); return ControlFlow::None; } @@ -1383,16 +1392,15 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas } } - if (ctv->props.count(propName) == 0) + TableType::Props& props = assignToMetatable ? metatable->props : ctv->props; + + if (props.count(propName) == 0) { - if (assignToMetatable) - metatable->props[propName] = {propTy}; - else - ctv->props[propName] = {propTy}; + props[propName] = {propTy}; } else { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type() : ctv->props[propName].type(); + TypeId currentTy = props[propName].type(); // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. @@ -1402,19 +1410,13 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas options.push_back(propTy); TypeId newItv = arena->addType(IntersectionType{std::move(options)}); - if (assignToMetatable) - metatable->props[propName] = {newItv}; - else - ctv->props[propName] = {newItv}; + props[propName] = {newItv}; } else if (get(currentTy)) { TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); - if (assignToMetatable) - metatable->props[propName] = {intersection}; - else - ctv->props[propName] = {intersection}; + props[propName] = {intersection}; } else { @@ -1913,8 +1915,8 @@ Inference ConstraintGenerator::checkIndexName( // the current lexical position within the script. if (!tt) { - if (auto localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) - tt = getTableType(localDomain->front()); + if (TypeIds* localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) + tt = getTableType(*localDomain->begin()); } if (tt) @@ -2327,14 +2329,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (ty) { - std::vector* localDomain = localTypes.find(*ty); + TypeIds* localDomain = localTypes.find(*ty); if (localDomain) - localDomain->push_back(rhsType); + localDomain->insert(rhsType); } else { ty = arena->addType(BlockedType{}); - localTypes[*ty].push_back(rhsType); + localTypes[*ty].insert(rhsType); if (annotatedTy) { @@ -2359,8 +2361,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (annotatedTy) addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy}); - if (auto localDomain = localTypes.find(*ty)) - localDomain->push_back(rhsType); + if (TypeIds* localDomain = localTypes.find(*ty)) + localDomain->insert(rhsType); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) @@ -2383,7 +2385,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* e bool incremented = recordPropertyAssignment(lhsTy); - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented}); + auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented}); + getMutable(propTy)->setOwner(apc); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) @@ -2398,7 +2401,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e bool incremented = recordPropertyAssignment(lhsTy); - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, incremented}); + auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, incremented}); + getMutable(propTy)->setOwner(apc); return; } @@ -2407,7 +2411,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e TypeId indexTy = check(scope, expr->index).ty; TypeId propTy = arena->addType(BlockedType{}); module->astTypes[expr] = propTy; - addConstraint(scope, expr->location, AssignIndexConstraint{lhsTy, indexTy, rhsType, propTy}); + auto aic = addConstraint(scope, expr->location, AssignIndexConstraint{lhsTy, indexTy, rhsType, propTy}); + getMutable(propTy)->setOwner(aic); } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) @@ -2447,7 +2452,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, if (AstExprConstantString* key = item.key->as()) { - ttv->props[key->value.begin()] = {itemTy}; + std::string propName{key->value.data, key->value.size}; + ttv->props[propName] = {itemTy}; } else { @@ -3187,7 +3193,7 @@ bool ConstraintGenerator::recordPropertyAssignment(TypeId ty) } else if (auto mt = get(t)) queue.push_back(mt->table); - else if (auto localDomain = localTypes.find(t)) + else if (TypeIds* localDomain = localTypes.find(t)) { for (TypeId domainTy : *localDomain) queue.push_back(domainTy); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 07fc26fb..e59bc8a7 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1,9 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ConstraintSolver.h" #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" #include "Luau/Common.h" -#include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" #include "Luau/Generalization.h" #include "Luau/Instantiation.h" @@ -22,8 +22,8 @@ #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" -#include "Luau/VecDeque.h" #include "Luau/VisitType.h" + #include #include @@ -67,7 +67,11 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const [[maybe_unused]] static bool canMutate(TypeId ty, NotNull constraint) { if (auto blocked = get(ty)) - return blocked->getOwner() == constraint; + { + Constraint* owner = blocked->getOwner(); + LUAU_ASSERT(owner); + return owner == constraint; + } return true; } @@ -76,7 +80,11 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const [[maybe_unused]] static bool canMutate(TypePackId tp, NotNull constraint) { if (auto blocked = get(tp)) - return blocked->owner == nullptr || blocked->owner == constraint; + { + Constraint* owner = blocked->owner; + LUAU_ASSERT(owner); + return owner == constraint; + } return true; } @@ -478,6 +486,12 @@ void ConstraintSolver::run() progress |= runSolverPass(true); } while (progress); + // After we have run all the constraints, type families should be generalized + // At this point, we can try to perform one final simplification to suss out + // whether type families are truly uninhabited or if they can reduce + + finalizeTypeFamilies(); + if (FFlag::DebugLuauLogSolver || FFlag::DebugLuauLogBindings) dumpBindings(rootScope, opts); @@ -487,6 +501,25 @@ void ConstraintSolver::run() } } +void ConstraintSolver::finalizeTypeFamilies() +{ + // At this point, we've generalized. Let's try to finish reducing as much as we can, we'll leave warning to the typechecker + for (auto [t, constraint] : typeFamiliesToFinalize) + { + TypeId ty = follow(t); + if (get(ty)) + { + FamilyGraphReductionResult result = + reduceFamilies(t, constraint->location, TypeFamilyContext{NotNull{this}, constraint->scope, NotNull{constraint}}, true); + + for (TypeId r : result.reducedTypes) + unblock(r, constraint->location); + for (TypePackId r : result.reducedPacks) + unblock(r, constraint->location); + } + } +} + bool ConstraintSolver::isDone() { return unsolvedConstraints.empty(); @@ -503,6 +536,56 @@ struct TypeAndLocation } // namespace +void ConstraintSolver::bind(NotNull constraint, TypeId ty, TypeId boundTo) +{ + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(canMutate(ty, constraint)); + + boundTo = follow(boundTo); + if (get(ty) && ty == boundTo) + return emplace(constraint, ty, constraint->scope, builtinTypes->neverType, builtinTypes->unknownType); + + shiftReferences(ty, boundTo); + emplaceType(asMutable(ty), boundTo); + unblock(ty, constraint->location); +} + +void ConstraintSolver::bind(NotNull constraint, TypePackId tp, TypePackId boundTo) +{ + LUAU_ASSERT(get(tp) || get(tp)); + LUAU_ASSERT(canMutate(tp, constraint)); + + boundTo = follow(boundTo); + LUAU_ASSERT(tp != boundTo); + + emplaceTypePack(asMutable(tp), boundTo); + unblock(tp, constraint->location); +} + +template +void ConstraintSolver::emplace(NotNull constraint, TypeId ty, Args&&... args) +{ + static_assert(!std::is_same_v, "cannot use `emplace`! use `bind`"); + + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(canMutate(ty, constraint)); + + emplaceType(asMutable(ty), std::forward(args)...); + unblock(ty, constraint->location); +} + +template +void ConstraintSolver::emplace(NotNull constraint, TypePackId tp, Args&&... args) +{ + static_assert(!std::is_same_v, "cannot use `emplace`! use `bind`"); + + LUAU_ASSERT(get(tp) || get(tp)); + LUAU_ASSERT(canMutate(tp, constraint)); + + emplaceTypePack(asMutable(tp), std::forward(args)...); + unblock(tp, constraint->location); +} + bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) { if (!force && isBlocked(constraint)) @@ -547,9 +630,6 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo else LUAU_ASSERT(false); - if (success) - unblock(constraint); - return success; } @@ -588,7 +668,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull generalized; - std::optional generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, c.sourceType); + std::optional generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, c.sourceType); if (generalizedTy) generalized = QuantifierResult{*generalizedTy}; // FIXME insertedGenerics and insertedGenericPacks else @@ -597,7 +677,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull(generalizedType)) - bindBlockedType(generalizedType, generalized->result, c.sourceType, constraint); + bind(constraint, generalizedType, generalized->result); else unify(constraint, generalizedType, generalized->result); @@ -610,17 +690,11 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNulllocation); - emplaceType(asMutable(c.generalizedType), builtinTypes->errorType); + bind(constraint, c.generalizedType, builtinTypes->errorRecoveryType()); } - unblock(c.generalizedType, constraint->location); - unblock(c.sourceType, constraint->location); - for (TypeId ty : c.interiorTypes) - { - generalize(NotNull{arena}, builtinTypes, constraint->scope, ty); - unblock(ty, constraint->location); - } + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty); return true; } @@ -710,18 +784,18 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullnilType, builtinTypes->nilType, constraint); + bind(constraint, *it, builtinTypes->nilType); ++it; } @@ -813,15 +887,14 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul const PendingExpansionType* petv = get(follow(c.target)); if (!petv) { - unblock(c.target, constraint->location); + unblock(c.target, constraint->location); // TODO: do we need this? any re-entrancy? return true; } auto bindResult = [this, &c, constraint](TypeId result) { LUAU_ASSERT(get(c.target)); shiftReferences(c.target, result); - emplaceType(asMutable(c.target), result); - unblock(c.target, constraint->location); + bind(constraint, c.target, result); }; std::optional tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value) @@ -1009,19 +1082,23 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) + { + emplaceTypePack(asMutable(c.result), builtinTypes->anyTypePack); + unblock(c.result, constraint->location); + return true; + } + // if we're calling an error type, the result is an error type, and that's that. if (get(fn)) { - emplaceTypePack(asMutable(c.result), builtinTypes->errorTypePack); - unblock(c.result, constraint->location); - + bind(constraint, c.result, builtinTypes->errorRecoveryTypePack()); return true; } if (get(fn)) { - emplaceTypePack(asMutable(c.result), builtinTypes->neverTypePack); - unblock(c.result, constraint->location); + bind(constraint, c.result, builtinTypes->neverTypePack); return true; } @@ -1078,7 +1155,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdTypePack(TypePack{std::move(argsHead), argsTail}); fn = follow(*callMm); - emplaceTypePack(asMutable(c.result), constraint->scope); + emplace(constraint, c.result, constraint->scope); } else { @@ -1095,14 +1172,21 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(asMutable(c.result), constraint->scope); + emplace(constraint, c.result, constraint->scope); } for (std::optional ty : c.discriminantTypes) { - if (!ty || !isBlocked(*ty)) + if (!ty) continue; + // If the discriminant type has been transmuted, we need to unblock them. + if (!isBlocked(*ty)) + { + unblock(*ty, constraint->location); + continue; + } + // We use `any` here because the discriminant type may be pointed at by both branches, // where the discriminant type is not negated, and the other where it is negated, i.e. // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` @@ -1110,7 +1194,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType}; + emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); } OverloadResolver resolver{ @@ -1120,7 +1204,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}}; @@ -1150,12 +1233,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulllocation); - InstantiationQueuer queuer{constraint->scope, constraint->location, this}; queuer.traverse(overloadToUse); queuer.traverse(inferredTy); + unblock(c.result, constraint->location); + return true; } @@ -1250,7 +1333,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullargs.data[j]->annotation && get(follow(lambdaArgTys[j]))) { shiftReferences(lambdaArgTys[j], expectedLambdaArgTys[j]); - emplaceType(asMutable(lambdaArgTys[j]), expectedLambdaArgTys[j]); + bind(constraint, lambdaArgTys[j], expectedLambdaArgTys[j]); } } } @@ -1303,7 +1386,7 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNulllowerBound; shiftReferences(c.freeType, bindTo); - emplaceType(asMutable(c.freeType), bindTo); + bind(constraint, c.freeType, bindTo); return true; } @@ -1336,8 +1419,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullanyType), c.subjectType, constraint); - unblock(resultType, constraint->location); + bind(constraint, resultType, result.value_or(builtinTypes->anyType)); return true; } @@ -1361,12 +1443,12 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (auto tbl = get(follow(ft->upperBound)); tbl && tbl->indexer) { unify(constraint, indexType, tbl->indexer->indexType); - bindBlockedType(resultType, tbl->indexer->indexResultType, subjectType, constraint); + bind(constraint, resultType, tbl->indexer->indexResultType); return true; } FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType}; - emplaceType(asMutable(resultType), freeResult); + emplace(constraint, resultType, freeResult); TypeId upperBound = arena->addType(TableType{/* props */ {}, TableIndexer{indexType, resultType}, TypeLevel{}, TableState::Unsealed}); @@ -1380,7 +1462,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( { unify(constraint, indexType, indexer->indexType); - bindBlockedType(resultType, indexer->indexResultType, subjectType, constraint); + bind(constraint, resultType, indexer->indexResultType); return true; } else if (tt->state == TableState::Unsealed) @@ -1388,7 +1470,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( // FIXME this is greedy. FreeType freeResult{tt->scope, builtinTypes->neverType, builtinTypes->unknownType}; - emplaceType(asMutable(resultType), freeResult); + emplace(constraint, resultType, freeResult); tt->indexer = TableIndexer{indexType, resultType}; return true; @@ -1401,12 +1483,12 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (auto indexer = ct->indexer) { unify(constraint, indexType, indexer->indexType); - bindBlockedType(resultType, indexer->indexResultType, subjectType, constraint); + bind(constraint, resultType, indexer->indexResultType); return true; } else if (isString(indexType)) { - bindBlockedType(resultType, builtinTypes->unknownType, subjectType, constraint); + bind(constraint, resultType, builtinTypes->unknownType); return true; } } @@ -1441,11 +1523,11 @@ bool ConstraintSolver::tryDispatchHasIndexer( } if (0 == results.size()) - bindBlockedType(resultType, builtinTypes->errorType, subjectType, constraint); + bind(constraint, resultType, builtinTypes->errorType); else if (1 == results.size()) - bindBlockedType(resultType, *results.begin(), subjectType, constraint); + bind(constraint, resultType, *results.begin()); else - emplaceType(asMutable(resultType), std::vector(results.begin(), results.end())); + emplace(constraint, resultType, std::vector(results.begin(), results.end())); return true; } @@ -1473,20 +1555,20 @@ bool ConstraintSolver::tryDispatchHasIndexer( } if (0 == results.size()) - emplaceType(asMutable(resultType), builtinTypes->errorType); + bind(constraint, resultType, builtinTypes->errorType); else if (1 == results.size()) { TypeId firstResult = *results.begin(); shiftReferences(resultType, firstResult); - emplaceType(asMutable(resultType), firstResult); + bind(constraint, resultType, firstResult); } else - emplaceType(asMutable(resultType), std::vector(results.begin(), results.end())); + emplace(constraint, resultType, std::vector(results.begin(), results.end())); return true; } - bindBlockedType(resultType, builtinTypes->errorType, subjectType, constraint); + bind(constraint, resultType, builtinTypes->errorType); return true; } @@ -1534,86 +1616,7 @@ bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull seen{nullptr}; - bool ok = tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); - if (ok) - unblock(c.resultType, constraint->location); - return ok; -} - -std::pair> ConstraintSolver::tryDispatchSetIndexer( - NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds) -{ - if (isBlocked(subjectType)) - return {block(subjectType, constraint), std::nullopt}; - - if (auto tt = getMutable(subjectType)) - { - if (tt->indexer) - { - if (isBlocked(tt->indexer->indexResultType)) - return {block(tt->indexer->indexResultType, constraint), std::nullopt}; - - unify(constraint, indexType, tt->indexer->indexType); - return {true, tt->indexer->indexResultType}; - } - else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) - { - TypeId resultTy = freshType(arena, builtinTypes, constraint->scope.get()); - tt->indexer = TableIndexer{indexType, resultTy}; - return {true, resultTy}; - } - } - else if (auto ft = getMutable(subjectType); ft && expandFreeTypeBounds) - { - // Setting an indexer on some fresh type means we use that fresh type in a negative position. - // Therefore, we only care about the upper bound. - // - // We'll extend the upper bound if we could dispatch, but could not find a table type to update the indexer. - auto [dispatched, resultTy] = tryDispatchSetIndexer(constraint, ft->upperBound, indexType, propType, /*expandFreeTypeBounds=*/false); - if (dispatched && !resultTy) - { - // Despite that we haven't found a table type, adding a table type causes us to have one that we can /now/ find. - resultTy = freshType(arena, builtinTypes, constraint->scope.get()); - - TypeId tableTy = arena->addType(TableType{TableState::Sealed, TypeLevel{}, constraint->scope.get()}); - TableType* tt2 = getMutable(tableTy); - tt2->indexer = TableIndexer{indexType, *resultTy}; - - ft->upperBound = - simplifyIntersection(builtinTypes, arena, ft->upperBound, tableTy).result; // TODO: intersect type family or a constraint. - } - - return {dispatched, resultTy}; - } - else if (auto it = get(subjectType)) - { - bool dispatched = true; - std::vector results; - - for (TypeId part : it) - { - auto [dispatched2, found] = tryDispatchSetIndexer(constraint, part, indexType, propType, expandFreeTypeBounds); - dispatched &= dispatched2; - results.push_back(found.value_or(builtinTypes->errorRecoveryType())); - - if (!dispatched) - return {dispatched, std::nullopt}; - } - - TypeId resultTy = arena->addType(TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, - std::move(results), - {}, - }); - - pushConstraint(constraint->scope, constraint->location, ReduceConstraint{resultTy}); - - return {dispatched, resultTy}; - } - else if (is(subjectType)) - return {true, subjectType}; - - return {true, std::nullopt}; + return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); } bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull constraint) @@ -1643,7 +1646,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullwriteTy.has_value()) return true; - emplaceType(asMutable(c.propType), *prop->writeTy); + bind(constraint, c.propType, *prop->writeTy); unify(constraint, rhsType, *prop->writeTy); return true; } @@ -1663,7 +1666,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullupperBound = simplifyIntersection(builtinTypes, arena, lhsFree->upperBound, newUpperBound).result; - emplaceType(asMutable(c.propType), rhsType); + bind(constraint, c.propType, rhsType); return true; } } @@ -1681,7 +1684,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull(asMutable(c.propType), propTy); + bind(constraint, c.propType, propTy); unify(constraint, rhsType, propTy); return true; } @@ -1700,7 +1703,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull(asMutable(c.propType), *prop.writeTy); + bind(constraint, c.propType, *prop.writeTy); unify(constraint, rhsType, *prop.writeTy); return true; } @@ -1710,13 +1713,13 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullstate == TableState::Unsealed || lhsTable->state == TableState::Free) { prop.writeTy = prop.readTy; - emplaceType(asMutable(c.propType), *prop.writeTy); + bind(constraint, c.propType, *prop.writeTy); unify(constraint, rhsType, *prop.writeTy); return true; } else { - emplaceType(asMutable(c.propType), builtinTypes->errorType); + bind(constraint, c.propType, builtinTypes->errorType); return true; } } @@ -1724,28 +1727,27 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullindexer && maybeString(lhsTable->indexer->indexType)) { - emplaceType(asMutable(c.propType), rhsType); + bind(constraint, c.propType, rhsType); unify(constraint, rhsType, lhsTable->indexer->indexResultType); return true; } if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - emplaceType(asMutable(c.propType), rhsType); + bind(constraint, c.propType, rhsType); lhsTable->props[propName] = Property::rw(rhsType); if (lhsTable->state == TableState::Unsealed && c.decrementPropCount) { LUAU_ASSERT(lhsTable->remainingProps > 0); lhsTable->remainingProps -= 1; - unblock(lhsType, constraint->location); } return true; } } - emplaceType(asMutable(c.propType), builtinTypes->errorType); + bind(constraint, c.propType, builtinTypes->errorType); return true; } @@ -1772,14 +1774,14 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer->indexType); unify(constraint, rhsType, lhsTable->indexer->indexResultType); - emplaceType(asMutable(c.propType), lhsTable->indexer->indexResultType); + bind(constraint, c.propType, lhsTable->indexer->indexResultType); return true; } if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { lhsTable->indexer = TableIndexer{indexType, rhsType}; - emplaceType(asMutable(c.propType), rhsType); + bind(constraint, c.propType, rhsType); return true; } @@ -1802,7 +1804,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer); - emplaceType(asMutable(c.propType), newTable->indexer->indexResultType); + bind(constraint, c.propType, newTable->indexer->indexResultType); return true; } @@ -1821,7 +1823,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer->indexType); unify(constraint, rhsType, lhsClass->indexer->indexResultType); - emplaceType(asMutable(c.propType), lhsClass->indexer->indexResultType); + bind(constraint, c.propType, lhsClass->indexer->indexResultType); return true; } @@ -1878,40 +1880,11 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull(asMutable(c.propType), builtinTypes->errorType); + bind(constraint, c.propType, builtinTypes->errorType); return true; } -bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, TypeId resultTy, TypeId srcTy) -{ - resultTy = follow(resultTy); - LUAU_ASSERT(canMutate(resultTy, constraint)); - - LUAU_ASSERT(get(resultTy)); - - if (get(resultTy)) - { - if (follow(srcTy) == resultTy) - { - // It is sometimes the case that we find that a blocked type - // is only blocked on itself. This doesn't actually - // constitute any meaningful constraint, so we replace it - // with a free type. - TypeId f = freshType(arena, builtinTypes, constraint->scope); - shiftReferences(resultTy, f); - emplaceType(asMutable(resultTy), f); - } - else - bindBlockedType(resultTy, srcTy, srcTy, constraint); - } - else - unify(constraint, srcTy, resultTy); - - unblock(resultTy, constraint->location); - return true; -} - bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) { TypePackId sourcePack = follow(c.sourcePack); @@ -1932,7 +1905,29 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull(resultTy)); + LUAU_ASSERT(canMutate(resultTy, constraint)); + + if (get(resultTy)) + { + if (follow(srcTy) == resultTy) + { + // It is sometimes the case that we find that a blocked type + // is only blocked on itself. This doesn't actually + // constitute any meaningful constraint, so we replace it + // with a free type. + TypeId f = freshType(arena, builtinTypes, constraint->scope); + shiftReferences(resultTy, f); + emplaceType(asMutable(resultTy), f); + } + else + bind(constraint, resultTy, srcTy); + } + else + unify(constraint, srcTy, resultTy); + + unblock(resultTy, constraint->location); ++resultIter; ++i; @@ -1948,8 +1943,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull(resultTy) || get(resultTy)) { - emplaceType(asMutable(resultTy), builtinTypes->nilType); - unblock(resultTy, constraint->location); + bind(constraint, resultTy, builtinTypes->nilType); } ++resultIter; @@ -1972,6 +1966,11 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull(ty)) + typeFamiliesToFinalize[ty] = constraint; + if (force || reductionFinished) { // if we're completely dispatching this constraint, we want to record any uninhabited type families to unblock. @@ -2058,11 +2057,11 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl auto endIt = end(c.variables); if (it != endIt) { - bindBlockedType(*it, keyTy, keyTy, constraint); + bind(constraint, *it, keyTy); ++it; } if (it != endIt) - bindBlockedType(*it, valueTy, valueTy, constraint); + bind(constraint, *it, valueTy); return true; } @@ -2072,7 +2071,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl { LUAU_ASSERT(get(varTy)); LUAU_ASSERT(varTy != ty); - bindBlockedType(varTy, ty, ty, constraint); + bind(constraint, varTy, ty); } }; @@ -2121,8 +2120,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl unify(constraint, c.variables[i], expectedVariables[i]); - bindBlockedType(c.variables[i], expectedVariables[i], expectedVariables[i], constraint); - unblock(c.variables[i], constraint->location); + bind(constraint, c.variables[i], expectedVariables[i]); } } else @@ -2517,42 +2515,9 @@ bool ConstraintSolver::unify(NotNull constraint, TID subTy, TI return false; } - unblock(subTy, constraint->location); - unblock(superTy, constraint->location); - return true; } -void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull constraint) -{ - resultTy = follow(resultTy); - - LUAU_ASSERT(get(blockedTy) && canMutate(blockedTy, constraint)); - - if (blockedTy == resultTy) - { - rootTy = follow(rootTy); - Scope* freeScope = nullptr; - if (auto ft = get(rootTy)) - freeScope = ft->scope; - else if (auto tt = get(rootTy); tt && tt->state == TableState::Free) - freeScope = tt->scope; - else - iceReporter.ice("bindBlockedType couldn't find an appropriate scope for a fresh type!", constraint->location); - - LUAU_ASSERT(freeScope); - - TypeId freeType = arena->freshType(freeScope); - shiftReferences(blockedTy, freeType); - emplaceType(asMutable(blockedTy), freeType); - } - else - { - shiftReferences(blockedTy, resultTy); - emplaceType(asMutable(blockedTy), resultTy); - } -} - bool ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { // If a set is not present for the target, construct a new DenseHashSet for it, @@ -2884,7 +2849,7 @@ std::optional ConstraintSolver::generalizeFreeType(NotNull scope, // that until all constraint generation is complete. } - return generalize(NotNull{arena}, builtinTypes, scope, type); + return generalize(NotNull{arena}, builtinTypes, scope, generalizedTypes, type); } bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 2087e3d3..d356b1cc 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,6 +7,7 @@ #include "Luau/NotNull.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include "Luau/Type.h" #include "Luau/TypeFamily.h" #include @@ -666,6 +667,18 @@ struct ErrorConverter return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; } + if ("index" == tfit->family->name) + { + if (tfit->typeArguments.size() != 2) + return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; + + if (auto errType = get(tfit->typeArguments[1])) // Second argument to index<_,_> is not a type + return "Second argument to index<" + Luau::toString(tfit->typeArguments[0]) + ", _> is not a valid index type"; + else // Second argument to index<_,_> is not a property of the first argument + return "Property '" + Luau::toString(tfit->typeArguments[1]) + "' does not exist on type '" + Luau::toString(tfit->typeArguments[0]) + + "'"; + } + if (kUnreachableTypeFamilies.count(tfit->family->name)) { return "Type family instance " + Luau::toString(e.ty) + " is uninhabited\n" + diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 7823f3d4..618a9a9c 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1003,6 +1003,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) module->astForInNextTypes.clear(); module->astResolvedTypes.clear(); module->astResolvedTypePacks.clear(); + module->astCompoundAssignResultTypes.clear(); module->astScopes.clear(); module->upperBoundContributors.clear(); diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index 081ea153..c2c44d96 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -4,6 +4,7 @@ #include "Luau/Scope.h" #include "Luau/Type.h" +#include "Luau/ToString.h" #include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include "Luau/VisitType.h" @@ -16,6 +17,7 @@ struct MutatingGeneralizer : TypeOnceVisitor NotNull builtinTypes; NotNull scope; + NotNull> cachedTypes; DenseHashMap positiveTypes; DenseHashMap negativeTypes; std::vector generics; @@ -23,11 +25,12 @@ struct MutatingGeneralizer : TypeOnceVisitor bool isWithinFunction = false; - MutatingGeneralizer(NotNull builtinTypes, NotNull scope, DenseHashMap positiveTypes, + MutatingGeneralizer(NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, DenseHashMap positiveTypes, DenseHashMap negativeTypes) : TypeOnceVisitor(/* skipBoundTypes */ true) , builtinTypes(builtinTypes) , scope(scope) + , cachedTypes(cachedTypes) , positiveTypes(std::move(positiveTypes)) , negativeTypes(std::move(negativeTypes)) { @@ -130,6 +133,9 @@ struct MutatingGeneralizer : TypeOnceVisitor bool visit(TypeId ty, const FunctionType& ft) override { + if (cachedTypes->contains(ty)) + return false; + const bool oldValue = isWithinFunction; isWithinFunction = true; @@ -144,6 +150,8 @@ struct MutatingGeneralizer : TypeOnceVisitor bool visit(TypeId ty, const FreeType&) override { + LUAU_ASSERT(!cachedTypes->contains(ty)); + const FreeType* ft = get(ty); LUAU_ASSERT(ft); @@ -244,6 +252,9 @@ struct MutatingGeneralizer : TypeOnceVisitor bool visit(TypeId ty, const TableType&) override { + if (cachedTypes->contains(ty)) + return false; + const size_t positiveCount = getCount(positiveTypes, ty); const size_t negativeCount = getCount(negativeTypes, ty); @@ -287,10 +298,12 @@ struct MutatingGeneralizer : TypeOnceVisitor struct FreeTypeSearcher : TypeVisitor { NotNull scope; + NotNull> cachedTypes; - explicit FreeTypeSearcher(NotNull scope) + explicit FreeTypeSearcher(NotNull scope, NotNull> cachedTypes) : TypeVisitor(/*skipBoundTypes*/ true) , scope(scope) + , cachedTypes(cachedTypes) { } @@ -363,7 +376,7 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty) override { - if (seenWithPolarity(ty)) + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) return false; LUAU_ASSERT(ty); @@ -372,7 +385,7 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const FreeType& ft) override { - if (seenWithPolarity(ty)) + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) return false; if (!subsumes(scope, ft.scope)) @@ -397,7 +410,7 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const TableType& tt) override { - if (seenWithPolarity(ty)) + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) return false; if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) @@ -443,7 +456,7 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const FunctionType& ft) override { - if (seenWithPolarity(ty)) + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) return false; flip(); @@ -486,8 +499,371 @@ struct FreeTypeSearcher : TypeVisitor } }; +// We keep a running set of types that will not change under generalization and +// only have outgoing references to types that are the same. We use this to +// short circuit generalization. It improves performance quite a lot. +// +// We do this by tracing through the type and searching for types that are +// uncacheable. If a type has a reference to an uncacheable type, it is itself +// uncacheable. +// +// If a type has no outbound references to uncacheable types, we add it to the +// cache. +struct TypeCacher : TypeOnceVisitor +{ + NotNull> cachedTypes; -std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, TypeId ty) + DenseHashSet uncacheable{nullptr}; + DenseHashSet uncacheablePacks{nullptr}; + + explicit TypeCacher(NotNull> cachedTypes) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , cachedTypes(cachedTypes) + {} + + void cache(TypeId ty) + { + cachedTypes->insert(ty); + } + + bool isCached(TypeId ty) const + { + return cachedTypes->contains(ty); + } + + void markUncacheable(TypeId ty) + { + uncacheable.insert(ty); + } + + void markUncacheable(TypePackId tp) + { + uncacheablePacks.insert(tp); + } + + bool isUncacheable(TypeId ty) const + { + return uncacheable.contains(ty); + } + + bool isUncacheable(TypePackId tp) const + { + return uncacheablePacks.contains(tp); + } + + bool visit(TypeId ty) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + // Free types are never cacheable. + LUAU_ASSERT(!isCached(ty)); + + if (!isUncacheable(ty)) + { + traverse(ft.lowerBound); + traverse(ft.upperBound); + + markUncacheable(ty); + } + + return false; + } + + bool visit(TypeId ty, const GenericType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const PrimitiveType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const SingletonType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const BlockedType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + traverse(ft.argTypes); + traverse(ft.retTypes); + for (TypeId gen: ft.generics) + traverse(gen); + + bool uncacheable = false; + + if (isUncacheable(ft.argTypes)) + uncacheable = true; + + else if (isUncacheable(ft.retTypes)) + uncacheable = true; + + for (TypeId argTy: ft.argTypes) + { + if (isUncacheable(argTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId retTy: ft.retTypes) + { + if (isUncacheable(retTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId g: ft.generics) + { + if (isUncacheable(g)) + { + uncacheable = true; + break; + } + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + if (tt.boundTo) + { + traverse(*tt.boundTo); + if (isUncacheable(*tt.boundTo)) + { + markUncacheable(ty); + return false; + } + } + + bool uncacheable = false; + + // This logic runs immediately after generalization, so any remaining + // unsealed tables are assuredly not cacheable. They may yet have + // properties added to them. + if (tt.state == TableState::Free || tt.state == TableState::Unsealed) + uncacheable = true; + + for (const auto& [_name, prop] : tt.props) + { + if (prop.readTy) + { + traverse(*prop.readTy); + + if (isUncacheable(*prop.readTy)) + uncacheable = true; + } + if (prop.writeTy && prop.writeTy != prop.readTy) + { + traverse(*prop.writeTy); + + if (isUncacheable(*prop.writeTy)) + uncacheable = true; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + if (isUncacheable(tt.indexer->indexType)) + uncacheable = true; + + traverse(tt.indexer->indexResultType); + if (isUncacheable(tt.indexer->indexResultType)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const ClassType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const AnyType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const UnionType& ut) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : ut.options) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const IntersectionType& it) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : it.parts) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const UnknownType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NeverType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NegationType& nt) override + { + if (!isCached(ty) && !isUncacheable(ty)) + { + traverse(nt.ty); + + if (isUncacheable(nt.ty)) + markUncacheable(ty); + else + cache(ty); + } + + return false; + } + + bool visit(TypeId ty, const TypeFamilyInstanceType& tfit) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + bool uncacheable = false; + + for (TypeId argTy : tfit.typeArguments) + { + traverse(argTy); + + if (isUncacheable(argTy)) + uncacheable = true; + } + + for (TypePackId argPack : tfit.packArguments) + { + traverse(argPack); + + if (isUncacheable(argPack)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypePackId tp, const FreeTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const VariadicTypePack& vtp) override + { + if (isUncacheable(tp)) + return false; + + traverse(vtp.ty); + + if (isUncacheable(vtp.ty)) + markUncacheable(tp); + + return false; + } + + bool visit(TypePackId tp, const BlockedTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override + { + markUncacheable(tp); + return false; + } +}; + +std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, TypeId ty) { ty = follow(ty); @@ -497,10 +873,10 @@ std::optional generalize(NotNull arena, NotNull if (const FunctionType* ft = get(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty())) return ty; - FreeTypeSearcher fts{scope}; + FreeTypeSearcher fts{scope, cachedTypes}; fts.traverse(ty); - MutatingGeneralizer gen{builtinTypes, scope, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; + MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; gen.traverse(ty); @@ -513,6 +889,9 @@ std::optional generalize(NotNull arena, NotNull if (ty->owningArena != arena || ty->persistent) return ty; + TypeCacher cacher{cachedTypes}; + cacher.traverse(ty); + FunctionType* ftv = getMutable(ty); if (ftv) { diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 7ce50284..16fe9546 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false) LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false); -LUAU_FASTFLAGVARIABLE(LuauFixCyclicUnionsOfIntersections, false); LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false); LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false); @@ -27,11 +26,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -static bool fixCyclicUnionsOfIntersections() -{ - return FFlag::LuauFixCyclicUnionsOfIntersections || FFlag::DebugLuauDeferredConstraintResolution; -} - static bool fixReduceStackPressure() { return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution; @@ -1776,12 +1770,9 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t } else if (const IntersectionType* itv = get(there)) { - if (fixCyclicUnionsOfIntersections()) - { - if (seenSetTypes.count(there)) - return NormalizationResult::True; - seenSetTypes.insert(there); - } + if (seenSetTypes.count(there)) + return NormalizationResult::True; + seenSetTypes.insert(there); NormalizedType norm{builtinTypes}; norm.tops = builtinTypes->anyType; @@ -1790,14 +1781,12 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes); if (res != NormalizationResult::True) { - if (fixCyclicUnionsOfIntersections()) - seenSetTypes.erase(there); + seenSetTypes.erase(there); return res; } } - if (fixCyclicUnionsOfIntersections()) - seenSetTypes.erase(there); + seenSetTypes.erase(there); return unionNormals(here, norm); } diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp index 414544b6..3514ff65 100644 --- a/Analysis/src/TableLiteralInference.cpp +++ b/Analysis/src/TableLiteralInference.cpp @@ -337,7 +337,9 @@ TypeId matchLiteralType(NotNull> astTypes, TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, expectedTableTy->indexer->indexResultType, *propTy, item.value, toBlock); - tableTy->indexer->indexResultType = matchedType; + // if the index result type is the prop type, we can replace it with the matched type here. + if (tableTy->indexer->indexResultType == *propTy) + tableTy->indexer->indexResultType = matchedType; } } else if (item.kind == AstExprTable::Item::General) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 5ffeb951..cc02bea6 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -446,7 +446,6 @@ struct TypeChecker2 .errors; if (!isErrorSuppressing(location, instance)) reportErrors(std::move(errors)); - return instance; } @@ -1108,10 +1107,13 @@ struct TypeChecker2 void visit(AstStatCompoundAssign* stat) { AstExprBinary fake{stat->location, stat->op, stat->var, stat->value}; - TypeId resultTy = visit(&fake, stat); + visit(&fake, stat); + + TypeId* resultTy = module->astCompoundAssignResultTypes.find(stat); + LUAU_ASSERT(resultTy); TypeId varTy = lookupType(stat->var); - testIsSubtype(resultTy, varTy, stat->location); + testIsSubtype(*resultTy, varTy, stat->location); } void visit(AstStatFunction* stat) @@ -1857,7 +1859,7 @@ struct TypeChecker2 bool isStringOperation = (normLeft ? normLeft->isSubtypeOfString() : isString(leftType)) && (normRight ? normRight->isSubtypeOfString() : isString(rightType)); - + leftType = follow(leftType); if (get(leftType) || get(leftType) || get(leftType)) return leftType; else if (get(rightType) || get(rightType) || get(rightType)) diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 89de1912..c65fde00 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -1490,6 +1490,18 @@ TypeFamilyReductionResult refineFamilyFn(TypeId instance, const std::vec if (get(follow(nt->ty))) return {targetTy, false, {}, {}}; + // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the + // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. + if (get(targetTy)) + { + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, targetTy, discriminantTy); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + return {result.result, false, {}, {}}; + } + + // In the general case, we'll still use normalization though. TypeId intersection = ctx->arena->addType(IntersectionType{{targetTy, discriminantTy}}); std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); std::shared_ptr normType = ctx->normalizer->normalize(targetTy); @@ -1853,6 +1865,208 @@ TypeFamilyReductionResult rawkeyofFamilyFn(TypeId instance, const std::v return keyofFamilyImpl(typeParams, packParams, ctx, /* isRaw */ true); } +/* Searches through table's or class's props/indexer to find the property of `ty` + If found, appends that property to `result` and returns true + Else, returns false */ +bool searchPropsAndIndexer( + TypeId ty, TableType::Props tblProps, std::optional tblIndexer, DenseHashSet& result, NotNull ctx) +{ + ty = follow(ty); + + // index into tbl's properties + if (auto stringSingleton = get(get(ty))) + { + if (tblProps.find(stringSingleton->value) != tblProps.end()) + { + TypeId propTy = follow(tblProps.at(stringSingleton->value).type()); + + // property is a union type -> we need to extend our reduction type + if (auto propUnionTy = get(propTy)) + { + for (TypeId option : propUnionTy->options) + result.insert(option); + } + else // property is a singular type or intersection type -> we can simply append + result.insert(propTy); + + return true; + } + } + + // index into tbl's indexer + if (tblIndexer) + { + if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, *ctx->ice)) + { + TypeId idxResultTy = follow(tblIndexer->indexResultType); + + // indexResultType is a union type -> we need to extend our reduction type + if (auto idxResUnionTy = get(idxResultTy)) + { + for (TypeId option : idxResUnionTy->options) + result.insert(option); + } + else // indexResultType is a singular type or intersection type -> we can simply append + result.insert(idxResultTy); + + return true; + } + } + + return false; +} + +/* Handles recursion / metamethods of tables/classes + `isRaw` parameter indicates whether or not we should follow __index metamethods + returns false if property of `ty` could not be found */ +bool tblIndexInto(TypeId indexer, TypeId indexee, DenseHashSet& result, NotNull ctx, bool isRaw) +{ + indexer = follow(indexer); + indexee = follow(indexee); + + // we have a table type to try indexing + if (auto tableTy = get(indexee)) + { + return searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx); + } + + // we have a metatable type to try indexing + if (auto metatableTy = get(indexee)) + { + if (auto tableTy = get(metatableTy->table)) + { + + // try finding all properties within the current scope of the table + if (searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx)) + return true; + } + + // if the code reached here, it means we weren't able to find all properties -> look into __index metamethod + if (!isRaw) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, indexee, "__index", Location{}); + if (mmType) + return tblIndexInto(indexer, *mmType, result, ctx, isRaw); + } + } + + return false; +} + +/* Vocabulary note: indexee refers to the type that contains the properties, + indexer refers to the type that is used to access indexee + Example: index => `Person` is the indexee and `"name"` is the indexer */ +TypeFamilyReductionResult indexFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("index type family: encountered a type family instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId indexeeTy = follow(typeParams.at(0)); + std::shared_ptr indexeeNormTy = ctx->normalizer->normalize(indexeeTy); + + // if the indexee failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!indexeeNormTy) + return {std::nullopt, false, {}, {}}; + + // if we don't have either just tables or just classes, we've got nothing to index into + if (indexeeNormTy->hasTables() == indexeeNormTy->hasClasses()) + return {std::nullopt, true, {}, {}}; + + // we're trying to reject any type that has not normalized to a table/class or a union of tables/classes. + if (indexeeNormTy->hasTops() || indexeeNormTy->hasBooleans() || indexeeNormTy->hasErrors() || indexeeNormTy->hasNils() || + indexeeNormTy->hasNumbers() || indexeeNormTy->hasStrings() || indexeeNormTy->hasThreads() || indexeeNormTy->hasBuffers() || + indexeeNormTy->hasFunctions() || indexeeNormTy->hasTyvars()) + return {std::nullopt, true, {}, {}}; + + TypeId indexerTy = follow(typeParams.at(1)); + std::shared_ptr indexerNormTy = ctx->normalizer->normalize(indexerTy); + + // if the indexer failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!indexerNormTy) + return {std::nullopt, false, {}, {}}; + + // we're trying to reject any type that is not a string singleton or primitive (string, number, boolean, thread, nil, function, table, or buffer) + if (indexerNormTy->hasTops() || indexerNormTy->hasErrors()) + return {std::nullopt, true, {}, {}}; + + // indexer can be a union —> break them down into a vector + const std::vector* typesToFind; + const std::vector singleType{indexerTy}; + if (auto unionTy = get(indexerTy)) + typesToFind = &unionTy->options; + else + typesToFind = &singleType; + + DenseHashSet properties{{}}; // vector of types that will be returned + bool isRaw = false; + + if (indexeeNormTy->hasClasses()) + { + LUAU_ASSERT(!indexeeNormTy->hasTables()); + + // at least one class is guaranteed to be in the iterator by .hasClasses() + for (auto classesIter = indexeeNormTy->classes.ordering.begin(); classesIter != indexeeNormTy->classes.ordering.end(); ++classesIter) + { + auto classTy = get(*classesIter); + if (!classTy) + { + LUAU_ASSERT(false); // this should not be possible according to normalization's spec + return {std::nullopt, true, {}, {}}; + } + + for (TypeId ty : *typesToFind) + { + // Search for all instances of indexer in class->props and class->indexer using `indexInto` + if (searchPropsAndIndexer(ty, classTy->props, classTy->indexer, properties, ctx)) + continue; // Indexer was found in this class, so we can move on to the next + + // If code reaches here,that means the property not found -> check in the metatable's __index + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, *classesIter, "__index", Location{}); + if (!mmType) // if a metatable does not exist, there is no where else to look + return {std::nullopt, true, {}, {}}; + + if (!tblIndexInto(ty, *mmType, properties, ctx, isRaw)) // if indexer is not in the metatable, we fail to reduce + return {std::nullopt, true, {}, {}}; + } + } + } + + if (indexeeNormTy->hasTables()) + { + LUAU_ASSERT(!indexeeNormTy->hasClasses()); + + // at least one table is guaranteed to be in the iterator by .hasTables() + for (auto tablesIter = indexeeNormTy->tables.begin(); tablesIter != indexeeNormTy->tables.end(); ++tablesIter) + { + for (TypeId ty : *typesToFind) + if (!tblIndexInto(ty, *tablesIter, properties, ctx, isRaw)) + return {std::nullopt, true, {}, {}}; + } + } + + // Call `follow()` on each element to resolve all Bound types before returning + std::transform(properties.begin(), properties.end(), properties.begin(), [](TypeId ty) { + return follow(ty); + }); + + // If the type being reduced to is a single type, no need to union + if (properties.size() == 1) + return {*properties.begin(), false, {}, {}}; + + return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), false, {}, {}}; +} + BuiltinTypeFamilies::BuiltinTypeFamilies() : notFamily{"not", notFamilyFn} , lenFamily{"len", lenFamilyFn} @@ -1876,6 +2090,7 @@ BuiltinTypeFamilies::BuiltinTypeFamilies() , intersectFamily{"intersect", intersectFamilyFn} , keyofFamily{"keyof", keyofFamilyFn} , rawkeyofFamily{"rawkeyof", rawkeyofFamilyFn} + , indexFamily{"index", indexFamilyFn} { } @@ -1917,6 +2132,8 @@ void BuiltinTypeFamilies::addToScope(NotNull arena, NotNull sc scope->exportedTypeBindings[keyofFamily.name] = mkUnaryTypeFamily(&keyofFamily); scope->exportedTypeBindings[rawkeyofFamily.name] = mkUnaryTypeFamily(&rawkeyofFamily); + + scope->exportedTypeBindings[indexFamily.name] = mkBinaryTypeFamily(&indexFamily); } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index eed3c715..3050f09e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -37,7 +37,6 @@ LUAU_FASTFLAGVARIABLE(LuauMetatableInstantiationCloneCheck, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) -LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false) LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) namespace Luau @@ -667,7 +666,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (typealias->name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && typealias->name == "typeof")) + if (typealias->name == kParseNameError || typealias->name == "typeof") continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -1535,7 +1534,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty if (name == kParseNameError) return ControlFlow::None; - if (FFlag::LuauForbidAliasNamedTypeof && name == "typeof") + if (name == "typeof") { reportError(typealias.location, GenericError{"Type aliases cannot be named typeof"}); return ControlFlow::None; @@ -1656,7 +1655,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea // If the alias is missing a name, we can't do anything with it. Ignore it. // Also, typeof is not a valid type alias name. We will report an error for // this in check() - if (name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && name == "typeof")) + if (name == kParseNameError || name == "typeof") return; std::optional binding; diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index e8479e09..ab0d40e2 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -207,6 +207,7 @@ public: enum Type { Checked, + Native, }; AstAttr(const Location& location, Type type); @@ -420,6 +421,8 @@ public: void visit(AstVisitor* visitor) override; + bool hasNativeAttribute() const; + AstArray attributes; AstArray generics; AstArray genericPacks; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index c1fd43ea..5a945e26 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -234,7 +234,7 @@ private: // asexp -> simpleexp [`::' Type] AstExpr* parseAssertionExpr(); - // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp + // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp AstExpr* parseSimpleExpr(); // args ::= `(' [explist] `)' | tableconstructor | String diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 4c956307..14b79767 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" LUAU_FASTFLAG(LuauAttributeSyntax); +LUAU_FASTFLAG(LuauNativeAttribute); namespace Luau { @@ -214,6 +215,18 @@ void AstExprFunction::visit(AstVisitor* visitor) } } +bool AstExprFunction::hasNativeAttribute() const +{ + LUAU_ASSERT(FFlag::LuauNativeAttribute); + + for (const auto attribute : attributes) + { + if (attribute->type == AstAttr::Type::Native) + return true; + } + return false; +} + AstExprTable::AstExprTable(const Location& location, const AstArray& items) : AstExpr(ClassIndex(), location) , items(items) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index d80878d5..3a6625a5 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -18,7 +18,9 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // See docs/SyntaxChanges.md for an explanation. LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(LuauAttributeSyntax) -LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand, false) +LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand2, false) +LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) +LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) namespace Luau { @@ -29,7 +31,7 @@ struct AttributeEntry AstAttr::Type type; }; -AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {nullptr, AstAttr::Type::Checked}}; +AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {"@native", AstAttr::Type::Native}, {nullptr, AstAttr::Type::Checked}}; ParseError::ParseError(const Location& location, const std::string& message) : location(location) @@ -703,6 +705,10 @@ std::pair Parser::validateAttribute(const char* attributeNa if (found) { type = kAttributeEntries[i].type; + + if (!FFlag::LuauNativeAttribute && type == AstAttr::Type::Native) + found = false; + break; } } @@ -772,7 +778,7 @@ AstStat* Parser::parseAttributeStat() { LUAU_ASSERT(FFlag::LuauAttributeSyntax); - AstArray attributes = Parser::parseAttributes(); + AstArray attributes = parseAttributes(); Lexeme::Type type = lexer.current().type; @@ -1654,7 +1660,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) { TempVector parts(scratchType); - if (!FFlag::LuauLeadingBarAndAmpersand || type != nullptr) + if (!FFlag::LuauLeadingBarAndAmpersand2 || type != nullptr) { parts.push_back(type); } @@ -1682,6 +1688,8 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) } else if (c == '?') { + LUAU_ASSERT(parts.size() >= 1); + Location loc = lexer.current().location; nextLexeme(); @@ -1714,7 +1722,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) } if (parts.size() == 1) - return type; + return FFlag::LuauLeadingBarAndAmpersand2 ? parts[0] : type; if (isUnion && isIntersection) { @@ -1761,7 +1769,7 @@ AstType* Parser::parseType(bool inDeclarationContext) Location begin = lexer.current().location; - if (FFlag::LuauLeadingBarAndAmpersand) + if (FFlag::LuauLeadingBarAndAmpersand2) { AstType* type = nullptr; @@ -2369,11 +2377,24 @@ static ConstantNumberParseResult parseDouble(double& result, const char* data) return ConstantNumberParseResult::Ok; } -// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp +// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp AstExpr* Parser::parseSimpleExpr() { Location start = lexer.current().location; + AstArray attributes{nullptr, 0}; + + if (FFlag::LuauAttributeSyntax && FFlag::LuauAttributeSyntaxFunExpr && lexer.current().type == Lexeme::Attribute) + { + attributes = parseAttributes(); + + if (lexer.current().type != Lexeme::ReservedFunction) + { + return reportExprError( + start, {}, "Expected 'function' declaration after attribute, but got %s intead", lexer.current().toString().c_str()); + } + } + if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); @@ -2397,7 +2418,7 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), nullptr, AstArray({nullptr, 0})).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr, attributes).first; } else if (lexer.current().type == Lexeme::Number) { diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 8ad75fbe..2077cce0 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -54,6 +54,7 @@ struct IrBuilder IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f); + IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, IrOp g); IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence IrOp blockAtInst(uint32_t index); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 60af706f..d0e40ca3 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -31,7 +31,7 @@ enum // * Rn - VM stack register slot, n in 0..254 // * Kn - VM proto constant slot, n in 0..2^23-1 // * UPn - VM function upvalue slot, n in 0..199 -// * A, B, C, D, E are instruction arguments +// * A, B, C, D, E, F, G are instruction arguments enum class IrCmd : uint8_t { NOP, @@ -869,6 +869,7 @@ struct IrInst IrOp d; IrOp e; IrOp f; + IrOp g; uint32_t lastUse = 0; uint16_t useCount = 0; @@ -923,6 +924,7 @@ struct IrInstHash h = mix(h, key.d); h = mix(h, key.e); h = mix(h, key.f); + h = mix(h, key.g); // MurmurHash2 tail h ^= h >> 13; @@ -937,7 +939,7 @@ struct IrInstEq { bool operator()(const IrInst& a, const IrInst& b) const { - return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f; + return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f && a.g == b.g; } }; diff --git a/CodeGen/include/Luau/IrVisitUseDef.h b/CodeGen/include/Luau/IrVisitUseDef.h index 58c88661..32dd6c2a 100644 --- a/CodeGen/include/Luau/IrVisitUseDef.h +++ b/CodeGen/include/Luau/IrVisitUseDef.h @@ -228,6 +228,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i CODEGEN_ASSERT(inst.d.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.e.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.f.kind != IrOpKind::VmReg); + CODEGEN_ASSERT(inst.g.kind != IrOpKind::VmReg); break; } } diff --git a/CodeGen/src/BytecodeSummary.cpp b/CodeGen/src/BytecodeSummary.cpp index 0089f592..d0d71504 100644 --- a/CodeGen/src/BytecodeSummary.cpp +++ b/CodeGen/src/BytecodeSummary.cpp @@ -8,6 +8,8 @@ #include "lobject.h" #include "lstate.h" +LUAU_FASTFLAG(LuauNativeAttribute) + namespace Luau { namespace CodeGen @@ -56,7 +58,10 @@ std::vector summarizeBytecode(lua_State* L, int idx, un Proto* root = clvalue(func)->l.p; std::vector protos; - gatherFunctions(protos, root, CodeGen_ColdFunctions); + if (FFlag::LuauNativeAttribute) + gatherFunctions(protos, root, CodeGen_ColdFunctions, root->flags & LPF_NATIVE_FUNCTION); + else + gatherFunctions_DEPRECATED(protos, root, CodeGen_ColdFunctions); std::vector summaries; summaries.reserve(protos.size()); diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index 269bf8dc..121535be 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAG(LuauCodegenTypeInfo) LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { @@ -200,7 +201,10 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A return std::string(); std::vector protos; - gatherFunctions(protos, root, options.compilationOptions.flags); + if (FFlag::LuauNativeAttribute) + gatherFunctions(protos, root, options.compilationOptions.flags, root->flags & LPF_NATIVE_FUNCTION); + else + gatherFunctions_DEPRECATED(protos, root, options.compilationOptions.flags); protos.erase(std::remove_if(protos.begin(), protos.end(), [](Proto* p) { diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index ae9e41f1..67a2676e 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false) LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { @@ -455,7 +456,7 @@ template Proto* root = clvalue(func)->l.p; - if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0 && (root->flags & LPF_NATIVE_FUNCTION) == 0) return CompilationResult{CodeGenCompilationResult::NotNativeModule}; BaseCodeGenContext* codeGenContext = getCodeGenContext(L); @@ -463,7 +464,10 @@ template return CompilationResult{CodeGenCompilationResult::CodeGenNotInitialized}; std::vector protos; - gatherFunctions(protos, root, options.flags); + if (FFlag::LuauNativeAttribute) + gatherFunctions(protos, root, options.flags, root->flags & LPF_NATIVE_FUNCTION); + else + gatherFunctions_DEPRECATED(protos, root, options.flags); // Skip protos that have been compiled during previous invocations of CodeGen::compile protos.erase(std::remove_if(protos.begin(), protos.end(), diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 6015ef10..4523d62b 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -29,13 +29,14 @@ LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { namespace CodeGen { -inline void gatherFunctions(std::vector& results, Proto* proto, unsigned int flags) +inline void gatherFunctions_DEPRECATED(std::vector& results, Proto* proto, unsigned int flags) { if (results.size() <= size_t(proto->bytecodeid)) results.resize(proto->bytecodeid + 1); @@ -50,7 +51,36 @@ inline void gatherFunctions(std::vector& results, Proto* proto, unsigned // Recursively traverse child protos even if we aren't compiling this one for (int i = 0; i < proto->sizep; i++) - gatherFunctions(results, proto->p[i], flags); + gatherFunctions_DEPRECATED(results, proto->p[i], flags); +} + +inline void gatherFunctionsHelper( + std::vector& results, Proto* proto, const unsigned int flags, const bool hasNativeFunctions, const bool root) +{ + if (results.size() <= size_t(proto->bytecodeid)) + results.resize(proto->bytecodeid + 1); + + // Skip protos that we've already compiled in this run: this happens because at -O2, inlined functions get their protos reused + if (results[proto->bytecodeid]) + return; + + // if native module, compile cold functions if requested + // if not native module, compile function if it has native attribute and is not root + bool shouldGather = hasNativeFunctions ? (!root && (proto->flags & LPF_NATIVE_FUNCTION) != 0) + : ((proto->flags & LPF_NATIVE_COLD) == 0 || (flags & CodeGen_ColdFunctions) != 0); + + if (shouldGather) + results[proto->bytecodeid] = proto; + + // Recursively traverse child protos even if we aren't compiling this one + for (int i = 0; i < proto->sizep; i++) + gatherFunctionsHelper(results, proto->p[i], flags, hasNativeFunctions, false); +} + +inline void gatherFunctions(std::vector& results, Proto* root, const unsigned int flags, const bool hasNativeFunctions = false) +{ + LUAU_ASSERT(FFlag::LuauNativeAttribute); + gatherFunctionsHelper(results, root, flags, hasNativeFunctions, true); } inline unsigned getInstructionCount(const std::vector& instructions, IrCmd cmd) diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 30ed42a0..f78823df 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauCodegenInstG, false) + namespace Luau { namespace CodeGen @@ -52,6 +54,9 @@ void updateUseCounts(IrFunction& function) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } @@ -95,6 +100,9 @@ void updateLastUseLocations(IrFunction& function, const std::vector& s checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } } @@ -128,6 +136,12 @@ uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t s if (inst.f.kind == IrOpKind::Inst && inst.f.index == targetInstIdx) return i; + + if (FFlag::LuauCodegenInstG) + { + if (inst.g.kind == IrOpKind::Inst && inst.g.index == targetInstIdx) + return i; + } } // There must be a next use since there is the last use location @@ -165,6 +179,9 @@ std::pair getLiveInOutValueCount(IrFunction& function, IrBlo checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } return std::make_pair(liveIns, liveOuts); @@ -488,6 +505,9 @@ static void computeCfgBlockEdges(IrFunction& function) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 723d35c4..e62885eb 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauCodegenInstG) namespace Luau { @@ -741,6 +742,9 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) redirect(clone.e); redirect(clone.f); + if (FFlag::LuauCodegenInstG) + redirect(clone.g); + addUse(function, clone.a); addUse(function, clone.b); addUse(function, clone.c); @@ -748,11 +752,17 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) addUse(function, clone.e); addUse(function, clone.f); + if (FFlag::LuauCodegenInstG) + addUse(function, clone.g); + // Instructions that referenced the original will have to be adjusted to use the clone instRedir[index] = uint32_t(function.instructions.size()); // Reconstruct the fresh clone - inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f); + if (FFlag::LuauCodegenInstG) + inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f, clone.g); + else + inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f); } } @@ -850,8 +860,33 @@ IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e) IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f) { + if (FFlag::LuauCodegenInstG) + { + return inst(cmd, a, b, c, d, e, f, {}); + } + else + { + uint32_t index = uint32_t(function.instructions.size()); + function.instructions.push_back({cmd, a, b, c, d, e, f}); + + CODEGEN_ASSERT(!inTerminatedBlock); + + if (isBlockTerminator(cmd)) + { + function.blocks[activeBlockIdx].finish = index; + inTerminatedBlock = true; + } + + return {IrOpKind::Inst, index}; + } +} + +IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, IrOp g) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenInstG); + uint32_t index = uint32_t(function.instructions.size()); - function.instructions.push_back({cmd, a, b, c, d, e, f}); + function.instructions.push_back({cmd, a, b, c, d, e, f, g}); CODEGEN_ASSERT(!inTerminatedBlock); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index a82ee894..5465d0a0 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauCodegenInstG) namespace Luau { @@ -417,6 +418,9 @@ void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index) checkOp(inst.d, ", "); checkOp(inst.e, ", "); checkOp(inst.f, ", "); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g, ", "); } void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) @@ -656,6 +660,8 @@ static RegisterSet getJumpTargetExtraLiveIn(IrToStringContext& ctx, const IrBloc op = inst.e; else if (inst.f.kind == IrOpKind::Block) op = inst.f; + else if (FFlag::LuauCodegenInstG && inst.g.kind == IrOpKind::Block) + op = inst.g; if (op.kind == IrOpKind::Block && op.index < ctx.cfg.in.size()) { @@ -940,6 +946,9 @@ std::string toDot(const IrFunction& function, bool includeInst) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index 24b0b285..af63a2fc 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAGVARIABLE(DebugCodegenChaosA64, false) +LUAU_FASTFLAG(LuauCodegenInstG) namespace Luau { @@ -256,6 +257,9 @@ void IrRegAllocA64::freeLastUseRegs(const IrInst& inst, uint32_t index) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } void IrRegAllocA64::freeTempRegs() diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index 2b5da623..60326074 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -6,6 +6,8 @@ #include "EmitCommonX64.h" +LUAU_FASTFLAG(LuauCodegenInstG) + namespace Luau { namespace CodeGen @@ -181,6 +183,9 @@ void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t instIdx) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t instIdx) const diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index d1bfca45..2244c4d3 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAG(LuauCodegenInstG) + namespace Luau { namespace CodeGen @@ -315,12 +317,18 @@ void kill(IrFunction& function, IrInst& inst) removeUse(function, inst.e); removeUse(function, inst.f); + if (FFlag::LuauCodegenInstG) + removeUse(function, inst.g); + inst.a = {}; inst.b = {}; inst.c = {}; inst.d = {}; inst.e = {}; inst.f = {}; + + if (FFlag::LuauCodegenInstG) + inst.g = {}; } void kill(IrFunction& function, uint32_t start, uint32_t end) @@ -370,6 +378,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl addUse(function, replacement.e); addUse(function, replacement.f); + if (FFlag::LuauCodegenInstG) + addUse(function, replacement.g); + // An extra reference is added so block will not remove itself block.useCount++; @@ -392,6 +403,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl removeUse(function, inst.e); removeUse(function, inst.f); + if (FFlag::LuauCodegenInstG) + removeUse(function, inst.g); + // Inherit existing use count (last use is skipped as it will be defined later) replacement.useCount = inst.useCount; @@ -417,12 +431,18 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement) removeUse(function, inst.e); removeUse(function, inst.f); + if (FFlag::LuauCodegenInstG) + removeUse(function, inst.g); + inst.a = replacement; inst.b = {}; inst.c = {}; inst.d = {}; inst.e = {}; inst.f = {}; + + if (FFlag::LuauCodegenInstG) + inst.g = {}; } void applySubstitutions(IrFunction& function, IrOp& op) @@ -466,6 +486,9 @@ void applySubstitutions(IrFunction& function, IrInst& inst) applySubstitutions(function, inst.d); applySubstitutions(function, inst.e); applySubstitutions(function, inst.f); + + if (FFlag::LuauCodegenInstG) + applySubstitutions(function, inst.g); } bool compare(double a, double b, IrCondition cond) diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index 3dc72610..c6b2d044 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -146,6 +146,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) CODEGEN_ASSERT(inst.d.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.e.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.f.kind != IrOpKind::VmReg); + CODEGEN_ASSERT(inst.g.kind != IrOpKind::VmReg); break; } } diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 2ae54c67..85fef5aa 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -612,4 +612,6 @@ enum LuauProtoFlag LPF_NATIVE_MODULE = 1 << 0, // used to tag individual protos as not profitable to compile natively LPF_NATIVE_COLD = 1 << 1, + // used to tag main proto for modules that have at least one function with native attribute + LPF_NATIVE_FUNCTION = 1 << 2, }; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 4842b9a1..db86fbc6 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -30,6 +30,8 @@ LUAU_FASTFLAG(LuauCompileTypeInfo) LUAU_FASTFLAGVARIABLE(LuauCompileTempTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauCompileUserdataInfo, false) +LUAU_FASTFLAG(LuauNativeAttribute) + namespace Luau { @@ -195,7 +197,7 @@ struct Compiler return node->as(); } - uint32_t compileFunction(AstExprFunction* func, uint8_t protoflags) + uint32_t compileFunction(AstExprFunction* func, uint8_t& protoflags) { LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); @@ -297,6 +299,9 @@ struct Compiler if (func->functionDepth == 0 && !hasLoops) protoflags |= LPF_NATIVE_COLD; + if (FFlag::LuauNativeAttribute && func->hasNativeAttribute()) + protoflags |= LPF_NATIVE_FUNCTION; + bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size()), protoflags); Function& f = functions[func]; @@ -3863,13 +3868,12 @@ struct Compiler struct FunctionVisitor : AstVisitor { - Compiler* self; std::vector& functions; bool hasTypes = false; + bool hasNativeFunction = false; - FunctionVisitor(Compiler* self, std::vector& functions) - : self(self) - , functions(functions) + FunctionVisitor(std::vector& functions) + : functions(functions) { // preallocate the result; this works around std::vector's inefficient growth policy for small arrays functions.reserve(16); @@ -3885,6 +3889,9 @@ struct Compiler // this makes sure all functions that are used when compiling this one have been already added to the vector functions.push_back(node); + if (FFlag::LuauNativeAttribute && !hasNativeFunction && node->hasNativeAttribute()) + hasNativeFunction = true; + return false; } }; @@ -4117,6 +4124,14 @@ struct Compiler std::vector> interpStrings; }; +static void setCompileOptionsForNativeCompilation(CompileOptions& options) +{ + options.optimizationLevel = 2; // note: this might be removed in the future in favor of --!optimize + + if (FFlag::LuauCompileTypeInfo) + options.typeInfoLevel = 1; +} + void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions) { LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler"); @@ -4135,15 +4150,21 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c if (hc.header && hc.content == "native") { mainFlags |= LPF_NATIVE_MODULE; - options.optimizationLevel = 2; // note: this might be removed in the future in favor of --!optimize - - if (FFlag::LuauCompileTypeInfo) - options.typeInfoLevel = 1; + setCompileOptionsForNativeCompilation(options); } } AstStatBlock* root = parseResult.root; + // gathers all functions with the invariant that all function references are to functions earlier in the list + // for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo + std::vector functions; + Compiler::FunctionVisitor functionVisitor(functions); + root->visit(&functionVisitor); + + if (functionVisitor.hasNativeFunction) + setCompileOptionsForNativeCompilation(options); + Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables @@ -4180,12 +4201,6 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c predictTableShapes(compiler.tableShapes, root); } - // gathers all functions with the invariant that all function references are to functions earlier in the list - // for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo - std::vector functions; - Compiler::FunctionVisitor functionVisitor(&compiler, functions); - root->visit(&functionVisitor); - if (FFlag::LuauCompileUserdataInfo) { if (const char* const* ptr = options.userdataTypes) @@ -4217,7 +4232,15 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c } for (AstExprFunction* expr : functions) - compiler.compileFunction(expr, 0); + { + uint8_t protoflags = 0; + compiler.compileFunction(expr, protoflags); + + // If a function has native attribute and the whole module is not native, we set LPF_NATIVE_FUNCTION flag + // This ensures that LPF_NATIVE_MODULE and LPF_NATIVE_FUNCTION are exclusive. + if (FFlag::LuauNativeAttribute && (protoflags & LPF_NATIVE_FUNCTION) && !(mainFlags & LPF_NATIVE_MODULE)) + mainFlags |= LPF_NATIVE_FUNCTION; + } AstExprFunction main(root->location, /*attributes=*/AstArray({nullptr, 0}), /*generics= */ AstArray(), /*genericPacks= */ AstArray(), diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 9c1fca9e..516e02f4 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -34,6 +34,8 @@ void luaC_validate(lua_State* L); LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTFLAG(LuauCodegenFixSplitStoreConstMismatch) +LUAU_FASTFLAG(LuauAttributeSyntax) +LUAU_FASTFLAG(LuauNativeAttribute) static lua_CompileOptions defaultOptions() { @@ -2707,4 +2709,57 @@ end 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); } +TEST_CASE("NativeAttribute") +{ + if (!codegen || !luau_codegen_supported()) + return; + + ScopedFastFlag sffs[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauNativeAttribute, true}}; + + std::string source = R"R( + @native + local function sum(x, y) + local function sumHelper(z) + return (x+y+z) + end + return sumHelper + end + + local function sub(x, y) + @native + local function subHelper(z) + return (x+y-z) + end + return subHelper + end)R"; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + luau_codegen_create(L); + + luaL_openlibs(L); + luaL_sandbox(L); + luaL_sandboxthread(L); + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), nullptr, &bytecodeSize); + int result = luau_load(L, "=Code", bytecode, bytecodeSize, 0); + free(bytecode); + + REQUIRE(result == 0); + + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; + Luau::CodeGen::CompilationStats nativeStats = {}; + Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, nativeOptions, &nativeStats); + + CHECK(nativeResult.result == Luau::CodeGen::CodeGenCompilationResult::Success); + + CHECK(!nativeResult.hasErrors()); + REQUIRE(nativeResult.protoFailures.empty()); + + // We should be able to compile at least one of our functions + CHECK_EQ(nativeStats.functionsCompiled, 2); +} + TEST_SUITE_END(); diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp index 8268dde6..43bd7325 100644 --- a/tests/Generalization.test.cpp +++ b/tests/Generalization.test.cpp @@ -21,14 +21,18 @@ struct GeneralizationFixture { TypeArena arena; BuiltinTypes builtinTypes; - Scope scope{builtinTypes.anyTypePack}; + ScopePtr globalScope = std::make_shared(builtinTypes.anyTypePack); + ScopePtr scope = std::make_shared(globalScope); ToStringOptions opts; + DenseHashSet generalizedTypes_{nullptr}; + NotNull> generalizedTypes{&generalizedTypes_}; + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; std::pair freshType() { - FreeType ft{&scope, builtinTypes.neverType, builtinTypes.unknownType}; + FreeType ft{scope.get(), builtinTypes.neverType, builtinTypes.unknownType}; TypeId ty = arena.addType(ft); FreeType* ftv = getMutable(ty); @@ -49,7 +53,7 @@ struct GeneralizationFixture std::optional generalize(TypeId ty) { - return ::Luau::generalize(NotNull{&arena}, NotNull{&builtinTypes}, NotNull{&scope}, ty); + return ::Luau::generalize(NotNull{&arena}, NotNull{&builtinTypes}, NotNull{scope.get()}, generalizedTypes, ty); } }; @@ -116,4 +120,71 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "dont_traverse_into_class_types_when_ge CHECK(is(*genPropTy)); } +TEST_CASE_FIXTURE(GeneralizationFixture, "cache_fully_generalized_types") +{ + CHECK(generalizedTypes->empty()); + + TypeId tinyTable = arena.addType(TableType{ + TableType::Props{{"one", builtinTypes.numberType}, {"two", builtinTypes.stringType}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + }); + + generalize(tinyTable); + + CHECK(generalizedTypes->contains(tinyTable)); + CHECK(generalizedTypes->contains(builtinTypes.numberType)); + CHECK(generalizedTypes->contains(builtinTypes.stringType)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "dont_cache_types_that_arent_done_yet") +{ + TypeId freeTy = arena.addType(FreeType{NotNull{globalScope.get()}, builtinTypes.neverType, builtinTypes.stringType}); + + TypeId fnTy = arena.addType(FunctionType{ + builtinTypes.emptyTypePack, + arena.addTypePack(TypePack{{builtinTypes.numberType}}) + }); + + TypeId tableTy = arena.addType(TableType{ + TableType::Props{{"one", builtinTypes.numberType}, {"two", freeTy}, {"three", fnTy}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + }); + + generalize(tableTy); + + CHECK(generalizedTypes->contains(fnTy)); + CHECK(generalizedTypes->contains(builtinTypes.numberType)); + CHECK(generalizedTypes->contains(builtinTypes.neverType)); + CHECK(generalizedTypes->contains(builtinTypes.stringType)); + CHECK(!generalizedTypes->contains(freeTy)); + CHECK(!generalizedTypes->contains(tableTy)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can_be_cached") +{ + TypeId selfTy = arena.addType(BlockedType{}); + + TypeId methodTy = arena.addType(FunctionType{ + arena.addTypePack({selfTy}), + arena.addTypePack({builtinTypes.numberType}), + }); + + asMutable(selfTy)->ty.emplace( + TableType::Props{{"count", builtinTypes.numberType}, {"method", methodTy}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + ); + + generalize(methodTy); + + CHECK(generalizedTypes->contains(methodTy)); + CHECK(generalizedTypes->contains(selfTy)); + CHECK(generalizedTypes->contains(builtinTypes.numberType)); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 2eb8ca91..e8a10e92 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -12,7 +12,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeNotUnknownIntersection) -LUAU_FASTFLAG(LuauFixCyclicUnionsOfIntersections); LUAU_FASTINT(LuauTypeInferRecursionLimit) using namespace Luau; @@ -799,8 +798,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union") TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union_of_intersection") { - ScopedFastFlag sff{FFlag::LuauFixCyclicUnionsOfIntersections, true}; - // t1 where t1 = (string & t1) | string TypeId boundTy = arena.addType(BlockedType{}); TypeId intersectTy = arena.addType(IntersectionType{{builtinTypes->stringType, boundTy}}); @@ -814,8 +811,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union_of_intersection") TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_intersection_of_unions") { - ScopedFastFlag sff{FFlag::LuauFixCyclicUnionsOfIntersections, true}; - // t1 where t1 = (string & t1) | string TypeId boundTy = arena.addType(BlockedType{}); TypeId unionTy = arena.addType(UnionType{{builtinTypes->stringType, boundTy}}); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 8b2cc6ba..0e5f0dd0 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -17,7 +17,8 @@ LUAU_FASTINT(LuauTypeLengthLimit); LUAU_FASTINT(LuauParseErrorLimit); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauAttributeSyntax); -LUAU_FASTFLAG(LuauLeadingBarAndAmpersand); +LUAU_FASTFLAG(LuauLeadingBarAndAmpersand2); +LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr); namespace { @@ -3234,6 +3235,45 @@ end)"); checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); } +TEST_CASE_FIXTURE(Fixture, "parse_attribute_for_function_expression") +{ + ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauAttributeSyntaxFunExpr, true}}; + + AstStatBlock* stat1 = parse(R"( +local function invoker(f) + return f(1) +end + +invoker(@checked function(x) return (x + 2) end) +)"); + + LUAU_ASSERT(stat1 != nullptr); + + AstExprFunction* func1 = stat1->body.data[1]->as()->expr->as()->args.data[0]->as(); + LUAU_ASSERT(func1 != nullptr); + + AstArray attributes1 = func1->attributes; + + CHECK_EQ(attributes1.size, 1); + + checkAttribute(attributes1.data[0], AstAttr::Type::Checked, Location(Position(5, 8), Position(5, 16))); + + AstStatBlock* stat2 = parse(R"( +local f = @checked function(x) return (x + 2) end +)"); + + LUAU_ASSERT(stat2 != nullptr); + + AstExprFunction* func2 = stat2->body.data[0]->as()->values.data[0]->as(); + LUAU_ASSERT(func2 != nullptr); + + AstArray attributes2 = func2->attributes; + + CHECK_EQ(attributes2.size, 1); + + checkAttribute(attributes2.data[0], AstAttr::Type::Checked, Location(Position(1, 10), Position(1, 18))); +} + TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_local_function_stat") { ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; @@ -3342,6 +3382,22 @@ function foo1 () @checked return 'a' end "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'return' intead"); } +TEST_CASE_FIXTURE(Fixture, "dont_parse_attribute_on_argument_non_function") +{ + ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauAttributeSyntaxFunExpr, true}}; + + ParseResult pr = tryParse(R"( +local function invoker(f, y) + return f(y) +end + +invoker(function(x) return (x + 2) end, @checked 1) +)"); + + checkFirstErrorForAttributes( + pr.errors, 1, Location(Position(5, 40), Position(5, 48)), "Expected 'function' declaration after attribute, but got '1' intead"); +} + TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_type_declaration") { ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; @@ -3472,21 +3528,21 @@ end)"); TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully") { - ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; parse(R"(type A = | "Hello" | "World")"); } TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully") { - ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; parse(R"(type A = & { string } & { number })"); } TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") { - ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index d7cb225a..17faa2e7 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -356,32 +356,15 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") )"); if (FFlag::DebugLuauDeferredConstraintResolution) { - LUAU_REQUIRE_ERROR_COUNT(3, result); - auto err = get(result.errors[0]); - LUAU_ASSERT(err); - CHECK("(...any) -> ()" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); - err = get(result.errors[1]); - LUAU_ASSERT(err); - // FIXME: this recommendation could be better - CHECK("(a) -> or ()>" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); - err = get(result.errors[2]); - LUAU_ASSERT(err); - // FIXME: this recommendation could be better - CHECK("(a) -> or(b) -> or ()>>" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); + LUAU_REQUIRE_NO_ERRORS(result); ToStringOptions o; o.exhaustive = false; o.maxTypeLength = 20; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> or ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> or(a... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> or(b... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ())... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ())... *TRUNCATED*"); } else { @@ -408,32 +391,15 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") if (FFlag::DebugLuauDeferredConstraintResolution) { - LUAU_REQUIRE_ERROR_COUNT(3, result); - auto err = get(result.errors[0]); - LUAU_ASSERT(err); - CHECK("(...any) -> ()" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); - err = get(result.errors[1]); - LUAU_ASSERT(err); - // FIXME: this recommendation could be better - CHECK("(a) -> or ()>" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); - err = get(result.errors[2]); - LUAU_ASSERT(err); - // FIXME: this recommendation could be better - CHECK("(a) -> or(b) -> or ()>>" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); + LUAU_REQUIRE_NO_ERRORS(result); ToStringOptions o; o.exhaustive = true; o.maxTypeLength = 20; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> or ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> or(a... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> or(b... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ())... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ())... *TRUNCATED*"); } else { diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index c66f0227..063ed39c 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -167,15 +167,13 @@ TEST_CASE_FIXTURE(FamilyFixture, "table_internal_families") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(toString(requireType("a")) == "{string}"); CHECK(toString(requireType("b")) == "{number}"); - CHECK(toString(requireType("c")) == "{Swap}"); - CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); + // FIXME: table types are constructing a trivial union here. + CHECK(toString(requireType("c")) == "{Swap}"); + CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); } TEST_CASE_FIXTURE(FamilyFixture, "function_internal_families") { - // This test is broken right now, but it's not because of type families. See - // CLI-71143. - if (!FFlag::DebugLuauDeferredConstraintResolution) return; @@ -829,4 +827,222 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_equivalence_with_distributivity") CHECK(toString(requireTypeAlias("U")) == "A | A | B | B"); } -TEST_SUITE_END(); +TEST_CASE_FIXTURE(BuiltinsFixture, "we_shouldnt_warn_that_a_reducible_type_family_is_uninhabited") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + +local Debounce = false +local Active = false + +local function Use(Mode) + + if Mode ~= nil then + + if Mode == false and Active == false then + return + else + Active = not Mode + end + + Debounce = false + end + Active = not Active + +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type IdxAType = index + type IdxBType = index> + + local function ok(idx: IdxAType): string return idx end + local function ok2(idx: IdxBType): string | number | boolean return idx end + local function err(idx: IdxAType): boolean return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("boolean", toString(tpm->wantedTp)); + CHECK_EQ("string", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_array") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local MyObject = {"hello", 1, true} + type IdxAType = index + + local function ok(idx: IdxAType): string | number | boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_generic_types") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function access(tbl: T & {}, key: K): index + return tbl[key] + end + + local subjects = { + english = "boring", + math = "fun" + } + + local key: "english" = "english" + local a: string = access(subjects, key) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_errors_w_bad_indexer") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type errType1 = index + type errType2 = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Property '\"d\"' does not exist on type 'MyObject'"); + CHECK(toString(result.errors[1]) == "Property 'boolean' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_errors_w_var_indexer") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + local key = "a" + + type errType1 = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Second argument to index is not a valid index type"); + CHECK(toString(result.errors[1]) == "Unknown type 'key'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_union_type_indexer") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + + type idxType = index + local function ok(idx: idxType): string | number return idx end + + type errType = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"a\" | \"d\"' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_union_type_indexee") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type MyObject2 = {a: number} + + type idxTypeA = index + local function ok(idx: idxTypeA): string | number return idx end + + type errType = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject | MyObject2'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_rfc_alternative_section") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string} + type MyObject2 = {a: string, b: number} + + local function edgeCase(param: MyObject) + type unknownType = index + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(ClassFixture, "index_type_family_works_on_classes") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type KeysOfMyObject = index + + local function ok(idx: KeysOfMyObject): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_index_metatables") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local exampleClass = { Foo = "text", Bar = true } + + local exampleClass2 = setmetatable({ Foo = 8 }, { __index = exampleClass }) + type exampleTy2 = index + local function ok(idx: exampleTy2): number return idx end + + local exampleClass3 = setmetatable({ Bar = 5 }, { __index = exampleClass }) + type exampleTy3 = index + local function ok2(idx: exampleTy3): string return idx end + + type exampleTy4 = index + local function ok3(idx: exampleTy4): string | number return idx end + + type errTy = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"Car\"' does not exist on type 'exampleClass2'"); +} + +TEST_SUITE_END(); \ No newline at end of file diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 06e698a8..54cf1cef 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -9,7 +9,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauSharedSelf); -LUAU_FASTFLAG(LuauForbidAliasNamedTypeof); TEST_SUITE_BEGIN("TypeAliases"); @@ -1065,8 +1064,6 @@ TEST_CASE_FIXTURE(Fixture, "table_types_record_the_property_locations") TEST_CASE_FIXTURE(Fixture, "typeof_is_not_a_valid_alias_name") { - ScopedFastFlag sff{FFlag::LuauForbidAliasNamedTypeof, true}; - CheckResult result = check(R"( type typeof = number )"); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index ce6988aa..a58fb638 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -696,13 +696,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") if (FFlag::DebugLuauDeferredConstraintResolution) { TypeId keyTy = requireType("key"); - - const UnionType* ut = get(keyTy); - REQUIRE(ut); - - REQUIRE(ut->options.size() == 2); - CHECK_EQ(builtinTypes->nilType, follow(ut->options[0])); - CHECK_EQ(*builtinTypes->numberType, *ut->options[1]); + CHECK("number?" == toString(keyTy)); } else CHECK_EQ(*builtinTypes->numberType, *requireType("key")); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index fac86150..b8bb9795 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -396,9 +396,17 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") s += 10 )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") @@ -423,6 +431,33 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable_with_changing_return_type") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + CheckResult result = check(R"( + --!strict + type T = { x: number } + local MT = {} + + function MT:__add(other): number + return 112 + end + + local t = setmetatable({x = 2}, MT) + local u = t + 3 + t += 3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK("t" == toString(tm->wantedType)); + CHECK("number" == toString(tm->givenType)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_result_must_be_compatible_with_var") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index ebf1fde4..e089c7be 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -576,15 +576,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "singletons_stick_around_under_assignment") local foo = (nil :: any) :: Foo - print(foo.kind == "Bar") -- TypeError: Type "Foo" cannot be compared with "Bar" + print(foo.kind == "Bar") -- type of equality refines to `false` local kind = foo.kind - print(kind == "Bar") -- SHOULD BE: TypeError: Type "Foo" cannot be compared with "Bar" + print(kind == "Bar") -- type of equality refines to `false` )"); - // FIXME: Under the new solver, we get both the errors we expect, but they're - // duplicated because of how we are currently running type family reduction. if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(4, result); + LUAU_REQUIRE_NO_ERRORS(result); else LUAU_REQUIRE_ERROR_COUNT(1, result); } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 516a761b..2c9614d0 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4511,4 +4511,30 @@ end )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_literal_inference_assert") +{ + CheckResult result = check(R"( + local buttons = { + buttons = {}; + } + + buttons.Button = { + call = nil; + lightParts = nil; + litPropertyOverrides = nil; + model = nil; + pivot = nil; + unlitPropertyOverrides = nil; + } + buttons.Button.__index = buttons.Button + + local lightFuncs: { (self: types.Button, lit: boolean) -> nil } = { + ['\x00'] = function(self: types.Button, lit: boolean) + end; + } + )"); + + +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 9ea9539f..60903733 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping); +LUAU_FASTFLAG(LuauLeadingBarAndAmpersand2) LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); @@ -1572,4 +1573,62 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bad_iter_metamethod") } } +TEST_CASE_FIXTURE(Fixture, "leading_bar") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Bar = | number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar_question_mark") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Bar = |? + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got '?'" == toString(result.errors[0])); + CHECK("*error-type*?" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_ampersand") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Amp = & string + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("string" == toString(requireTypeAlias("Amp"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar_no_type") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Bar = | + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got " == toString(result.errors[0])); + CHECK("*error-type*" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_ampersand_no_type") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Amp = & + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got " == toString(result.errors[0])); + CHECK("*error-type*" == toString(requireTypeAlias("Amp"))); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index b2677bf4..834b24c0 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -32,16 +32,6 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument2 BuiltinTests.table_freeze_is_generic BuiltinTests.tonumber_returns_optional_number_type -ControlFlowAnalysis.for_record_do_if_not_x_break -ControlFlowAnalysis.for_record_do_if_not_x_continue -ControlFlowAnalysis.if_not_x_break_elif_not_y_break -ControlFlowAnalysis.if_not_x_break_elif_not_y_continue -ControlFlowAnalysis.if_not_x_break_elif_rand_break_elif_not_y_break -ControlFlowAnalysis.if_not_x_continue_elif_not_y_continue -ControlFlowAnalysis.if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough -ControlFlowAnalysis.if_not_x_continue_elif_rand_continue_elif_not_y_continue -ControlFlowAnalysis.if_not_x_return_elif_not_y_break -DefinitionTests.class_definition_overload_metamethods Differ.metatable_metamissing_left Differ.metatable_metamissing_right Differ.metatable_metanormal @@ -238,8 +228,6 @@ ToString.named_metatable_toStringNamedFunction ToString.no_parentheses_around_cyclic_function_type_in_intersection ToString.pick_distinct_names_for_mixed_explicit_and_implicit_generics ToString.primitive -ToString.quit_stringifying_type_when_length_is_exceeded -ToString.stringifying_type_is_still_capped_when_exhaustive ToString.toStringDetailed2 ToString.toStringErrorPack TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType @@ -262,7 +250,6 @@ TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeFamilyTests.add_family_at_work TypeFamilyTests.family_as_fn_arg TypeFamilyTests.internal_families_raise_errors -TypeFamilyTests.mul_family_with_union_of_multiplicatives_2 TypeFamilyTests.unsolvable_family TypeInfer.be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload TypeInfer.check_type_infer_recursion_count @@ -319,7 +306,6 @@ TypeInferFunctions.function_exprs_are_generalized_at_signature_scope_not_enclosi TypeInferFunctions.function_is_supertype_of_concrete_functions TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer TypeInferFunctions.generic_packs_are_not_variadic -TypeInferFunctions.higher_order_function_2 TypeInferFunctions.higher_order_function_4 TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors @@ -339,7 +325,7 @@ TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible_2 TypeInferFunctions.report_exiting_without_return_nonstrict TypeInferFunctions.return_type_by_overload -TypeInferFunctions.tf_suggest_return_type +TypeInferFunctions.simple_unannotated_mutual_recursion TypeInferFunctions.too_few_arguments_variadic TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic2 @@ -377,7 +363,6 @@ TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.promise_type_error_too_complex TypeInferOperators.add_type_family_works TypeInferOperators.cli_38355_recursive_union -TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.compound_assign_result_must_be_compatible_with_var TypeInferOperators.concat_op_on_free_lhs_and_string_rhs TypeInferOperators.concat_op_on_string_lhs_and_free_rhs @@ -408,8 +393,6 @@ TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened -TypeSingletons.singletons_stick_around_under_assignment -TypeSingletons.string_singleton_function_call TypeSingletons.table_properties_type_error_escapes TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeStatesTest.typestates_preserve_error_suppression_properties