From 9c588be16de68567623d32ea89fa45fdb691b4d2 Mon Sep 17 00:00:00 2001 From: aaron Date: Fri, 26 Jan 2024 19:20:56 -0800 Subject: [PATCH 1/4] Sync to upstream/release/610 (#1154) # What's changed? * Check interrupt handler inside the pattern match engine to eliminate potential for programs to hang during string library function execution. * Allow iteration over table properties to pass the old type solver. ### Native Code Generation * Use in-place memory operands for math library operations on x64. * Replace opaque bools with separate enum classes in IrDump to improve code maintainability. * Translate operations on inferred vectors to IR. * Enable support for debugging native-compiled functions in Roblox Studio. ### New Type Solver * Rework type inference for boolean and string literals to introduce bounded free types (bounded below by the singleton type, and above by the primitive type) and reworked primitive type constraint to decide which is the appropriate type for the literal. * Introduce `FunctionCheckConstraint` to handle bidirectional typechecking for function calls, pushing the expected parameter types from the function onto the arguments. * Introduce `union` and `intersect` type families to compute deferred simplified unions and intersections to be employed by the constraint generation logic in the new solver. * Implement support for expanding the domain of local types in `Unifier2`. * Rework type inference for iteration variables bound by for in loops to use local types. * Change constraint blocking logic to use a set to prevent accidental re-blocking. * Add logic to detect missing return statements in functions. ### Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Aviral Goel Co-authored-by: Vyacheslav Egorov --------- Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Vighnesh Co-authored-by: Aviral Goel Co-authored-by: David Cope Co-authored-by: Lily Brown Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Constraint.h | 42 +++- Analysis/include/Luau/ConstraintGenerator.h | 11 +- Analysis/include/Luau/ConstraintSolver.h | 5 +- Analysis/include/Luau/TypeFamily.h | 3 + Analysis/include/Luau/Unifier2.h | 1 + Analysis/src/AstJsonEncoder.cpp | 21 +- Analysis/src/Autocomplete.cpp | 82 ++----- Analysis/src/Constraint.cpp | 6 + Analysis/src/ConstraintGenerator.cpp | 172 ++++++++------ Analysis/src/ConstraintSolver.cpp | 237 +++++++++++++------ Analysis/src/Subtyping.cpp | 3 + Analysis/src/ToString.cpp | 9 +- Analysis/src/Type.cpp | 4 + Analysis/src/TypeChecker2.cpp | 153 +++++++++++++ Analysis/src/TypeFamily.cpp | 69 ++++++ Analysis/src/TypeInfer.cpp | 5 +- Analysis/src/Unifier2.cpp | 6 + Ast/include/Luau/Ast.h | 24 +- Ast/src/Ast.cpp | 15 +- Ast/src/Parser.cpp | 48 ++-- CodeGen/include/Luau/AssemblyBuilderA64.h | 5 + CodeGen/include/Luau/AssemblyBuilderX64.h | 11 + CodeGen/include/Luau/CodeGen.h | 36 ++- CodeGen/include/Luau/IrData.h | 20 ++ CodeGen/include/Luau/IrDump.h | 8 +- CodeGen/include/Luau/IrUtils.h | 7 + CodeGen/include/Luau/IrVisitUseDef.h | 1 + CodeGen/src/AssemblyBuilderA64.cpp | 138 ++++++++++- CodeGen/src/AssemblyBuilderX64.cpp | 99 +++++++- CodeGen/src/CodeGen.cpp | 43 ++++ CodeGen/src/CodeGenLower.h | 6 +- CodeGen/src/IrDump.cpp | 41 ++-- CodeGen/src/IrLoweringA64.cpp | 103 +++++++++ CodeGen/src/IrLoweringX64.cpp | 121 ++++++++++ CodeGen/src/IrLoweringX64.h | 6 + CodeGen/src/IrTranslation.cpp | 179 ++++++++++++++- CodeGen/src/IrUtils.cpp | 10 + CodeGen/src/IrValueLocationTracking.cpp | 6 + CodeGen/src/OptimizeConstProp.cpp | 22 +- CodeGen/src/OptimizeFinalX64.cpp | 17 ++ VM/src/ldebug.cpp | 9 +- VM/src/lstate.h | 1 + VM/src/lstrlib.cpp | 17 ++ VM/src/lvmexecute.cpp | 2 +- tests/AssemblyBuilderA64.test.cpp | 22 ++ tests/AssemblyBuilderX64.test.cpp | 39 +++- tests/AstJsonEncoder.test.cpp | 6 - tests/Autocomplete.test.cpp | 9 + tests/Conformance.test.cpp | 48 +++- tests/IrBuilder.test.cpp | 150 ++++++------ tests/IrLowering.test.cpp | 242 +++++++++++++++++++- tests/TypeFamily.test.cpp | 8 +- tests/TypeInfer.loops.test.cpp | 94 ++++++-- tests/TypeInfer.oop.test.cpp | 2 +- tests/TypeInfer.refinements.test.cpp | 2 + tests/TypeInfer.typestates.test.cpp | 6 +- tests/conformance/interrupt.lua | 33 +++ tests/conformance/native.lua | 24 ++ tools/faillist.txt | 40 +--- 59 files changed, 2048 insertions(+), 501 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index e69346bd..692e4a14 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -90,18 +90,42 @@ struct FunctionCallConstraint DenseHashMap* astOverloadResolvedTypes = nullptr; }; -// result ~ prim ExpectedType SomeSingletonType MultitonType +// function_check fn argsPack // -// If ExpectedType is potentially a singleton (an actual singleton or a union -// that contains a singleton), then result ~ SomeSingletonType +// If fn is a function type and argsPack is a partially solved +// pack of arguments to be supplied to the function, propagate the argument +// types of fn into the types of argsPack. This is used to implement +// bidirectional inference of lambda arguments. +struct FunctionCheckConstraint +{ + TypeId fn; + TypePackId argsPack; + + class AstExprCall* callSite = nullptr; +}; + +// prim FreeType ExpectedType PrimitiveType // -// else result ~ MultitonType +// FreeType is bounded below by the singleton type and above by PrimitiveType +// initially. When this constraint is resolved, it will check that the bounds +// of the free type are well-formed by subtyping. +// +// If they are not well-formed, then FreeType is replaced by its lower bound +// +// If they are well-formed and ExpectedType is potentially a singleton (an +// actual singleton or a union that contains a singleton), +// then FreeType is replaced by its lower bound +// +// else FreeType is replaced by PrimitiveType struct PrimitiveTypeConstraint { - TypeId resultType; - TypeId expectedType; - TypeId singletonType; - TypeId multitonType; + TypeId freeType; + + // potentially gets used to force the lower bound? + std::optional expectedType; + + // the primitive type to check against + TypeId primitiveType; }; // result ~ hasProp type "prop_name" @@ -230,7 +254,7 @@ struct ReducePackConstraint }; using ConstraintV = Variant; struct Constraint diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index ebd237f6..2f746a74 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -150,7 +150,7 @@ private: */ ScopePtr childScope(AstNode* node, const ScopePtr& parent); - std::optional lookup(Scope* scope, DefId def, bool prototype = true); + std::optional lookup(const ScopePtr& scope, DefId def, bool prototype = true); /** * Adds a new constraint with no dependencies to a given scope. @@ -178,8 +178,8 @@ private: }; using RefinementContext = InsertionOrderedMap; - void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints); - void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector* constraints); + void unionRefinements(const ScopePtr& scope, Location location, const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints); + void computeRefinement(const ScopePtr& scope, Location location, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector* constraints); void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); @@ -329,6 +329,11 @@ private: void reportError(Location location, TypeErrorData err); void reportCodeTooComplex(Location location); + // make a union type family of these two types + TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + // make an intersect type family of these two types + TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + /** Scan the program for global definitions. * * ConstraintGenerator needs to differentiate between globals and accesses to undefined symbols. Doing this "for diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index f258b28b..e962f343 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -75,7 +75,7 @@ struct ConstraintSolver // anything. std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map>, HashBlockedConstraintId> blocked; + std::unordered_map, HashBlockedConstraintId> blocked; // Memoized instantiations of type aliases. DenseHashMap instantiatedAliases{{}}; // Breadcrumbs for where a free type's upper bound was expanded. We use @@ -126,6 +126,7 @@ struct ConstraintSolver bool tryDispatch(const NameConstraint& c, NotNull constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); + bool tryDispatch(const FunctionCheckConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); bool tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force); @@ -285,7 +286,7 @@ private: * @param target the type or type pack pointer that the constraint is blocked on. * @param constraint the constraint to block. **/ - void block_(BlockedConstraintId target, NotNull constraint); + bool block_(BlockedConstraintId target, NotNull constraint); /** * Informs the solver that progress has been made on a type or type pack. The diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index 49c652ee..77fd6e8a 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -163,6 +163,9 @@ struct BuiltinTypeFamilies TypeFamily eqFamily; TypeFamily refineFamily; + TypeFamily unionFamily; + TypeFamily intersectFamily; + TypeFamily keyofFamily; TypeFamily rawkeyofFamily; diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index 4930df6f..f9a3fdc9 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -56,6 +56,7 @@ struct Unifier2 * free TypePack to another and encounter an occurs check violation. */ bool unify(TypeId subTy, TypeId superTy); + bool unify(const LocalType* subTy, TypeId superFn); bool unify(TypeId subTy, const FunctionType* superFn); bool unify(const UnionType* subUnion, TypeId superTy); bool unify(TypeId subTy, const UnionType* superUnion); diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index dcee3492..470d69b3 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -8,8 +8,6 @@ #include -LUAU_FASTFLAG(LuauClipExtraHasEndProps); - namespace Luau { @@ -393,8 +391,6 @@ struct AstJsonEncoder : public AstVisitor PROP(body); PROP(functionDepth); PROP(debugname); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } @@ -591,11 +587,8 @@ struct AstJsonEncoder : public AstVisitor void write(class AstStatBlock* node) { writeNode(node, "AstStatBlock", [&]() { - if (FFlag::LuauClipExtraHasEndProps) - { - writeRaw(",\"hasEnd\":"); - write(node->hasEnd); - } + writeRaw(",\"hasEnd\":"); + write(node->hasEnd); writeRaw(",\"body\":["); bool comma = false; for (AstStat* stat : node->body) @@ -619,8 +612,6 @@ struct AstJsonEncoder : public AstVisitor if (node->elsebody) PROP(elsebody); write("hasThen", node->thenLocation.has_value()); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } @@ -630,8 +621,6 @@ struct AstJsonEncoder : public AstVisitor PROP(condition); PROP(body); PROP(hasDo); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } @@ -640,8 +629,6 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatRepeat", [&]() { PROP(condition); PROP(body); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasUntil", node->DEPRECATED_hasUntil); }); } @@ -687,8 +674,6 @@ struct AstJsonEncoder : public AstVisitor PROP(step); PROP(body); PROP(hasDo); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } @@ -700,8 +685,6 @@ struct AstJsonEncoder : public AstVisitor PROP(body); PROP(hasIn); PROP(hasDo); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 55e6d8f0..75fd9f58 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -15,7 +15,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauReadWriteProperties); -LUAU_FASTFLAG(LuauClipExtraHasEndProps); LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringLiteralBounds, false); static const std::unordered_set kStatementStartingKeywords = { @@ -1068,51 +1067,30 @@ static AutocompleteEntryMap autocompleteStatement( for (const auto& kw : kStatementStartingKeywords) result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (FFlag::LuauClipExtraHasEndProps) + for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) { - for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) + if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatIf* statIf = (*it)->as()) { - if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatIf* statIf = (*it)->as()) + bool hasEnd = statIf->thenbody->hasEnd; + if (statIf->elsebody) { - bool hasEnd = statIf->thenbody->hasEnd; - if (statIf->elsebody) - { - if (AstStatBlock* elseBlock = statIf->elsebody->as()) - hasEnd = elseBlock->hasEnd; - } - - if (!hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatBlock* elseBlock = statIf->elsebody->as()) + hasEnd = elseBlock->hasEnd; } - else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - } - else - { - for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) - { - if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatIf* statIf = (*it)->as(); statIf && !statIf->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) + + if (!hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } + else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } if (ancestry.size() >= 2) @@ -1127,16 +1105,8 @@ static AutocompleteEntryMap autocompleteStatement( } } - if (FFlag::LuauClipExtraHasEndProps) - { - if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - else - { - if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->DEPRECATED_hasUntil) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } + if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } if (ancestry.size() >= 4) @@ -1150,16 +1120,8 @@ static AutocompleteEntryMap autocompleteStatement( } } - if (FFlag::LuauClipExtraHasEndProps) - { - if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - else - { - if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->DEPRECATED_hasUntil) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } + if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); return result; } diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 3035d480..7d3f9e31 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -45,6 +45,12 @@ DenseHashSet Constraint::getFreeTypes() const ftc.traverse(psc->subPack); ftc.traverse(psc->superPack); } + else if (auto ptc = get(*this)) + { + // we need to take into account primitive type constraints to prevent type families from reducing on + // primitive whose types we have not yet selected to be singleton or not. + ftc.traverse(ptc->freeType); + } return types; } diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 91f2ab7e..bc412c0e 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -20,6 +20,7 @@ #include "Luau/InsertionOrderedMap.h" #include +#include LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); @@ -205,7 +206,7 @@ ScopePtr ConstraintGenerator::childScope(AstNode* node, const ScopePtr& parent) return scope; } -std::optional ConstraintGenerator::lookup(Scope* scope, DefId def, bool prototype) +std::optional ConstraintGenerator::lookup(const ScopePtr& scope, DefId def, bool prototype) { if (get(def)) return scope->lookup(def); @@ -230,7 +231,7 @@ std::optional ConstraintGenerator::lookup(Scope* scope, DefId def, bool rootScope->lvalueTypes[operand] = *ty; } - res = simplifyUnion(builtinTypes, arena, res, *ty).result; + res = makeUnion(scope, Location{} /* TODO: can we provide a real location here? */, res, *ty); } scope->lvalueTypes[def] = res; @@ -250,18 +251,13 @@ NotNull ConstraintGenerator::addConstraint(const ScopePtr& scope, st return NotNull{constraints.emplace_back(std::move(c)).get()}; } -void ConstraintGenerator::unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints) +void ConstraintGenerator::unionRefinements(const ScopePtr& scope, Location location, const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints) { const auto intersect = [&](const std::vector& types) { if (1 == types.size()) return types[0]; else if (2 == types.size()) - { - // TODO: It may be advantageous to introduce a refine type family here when there are blockedTypes. - SimplifyResult sr = simplifyIntersection(builtinTypes, arena, types[0], types[1]); - if (sr.blockedTypes.empty()) - return sr.result; - } + return makeIntersect(scope, location, types[0], types[1]); return arena->addType(IntersectionType{types}); }; @@ -281,48 +277,48 @@ void ConstraintGenerator::unionRefinements(const RefinementContext& lhs, const R rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] : intersect(rhsIt->second.discriminantTypes); dest.insert(def, {}); - dest.get(def)->discriminantTypes.push_back(simplifyUnion(builtinTypes, arena, leftDiscriminantTy, rightDiscriminantTy).result); + dest.get(def)->discriminantTypes.push_back(makeUnion(scope, location, leftDiscriminantTy, rightDiscriminantTy)); dest.get(def)->shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; } } -void ConstraintGenerator::computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector* constraints) +void ConstraintGenerator::computeRefinement(const ScopePtr& scope, Location location, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector* constraints) { if (!refinement) return; else if (auto variadic = get(refinement)) { for (RefinementId refi : variadic->refinements) - computeRefinement(scope, refi, refis, sense, eq, constraints); + computeRefinement(scope, location, refi, refis, sense, eq, constraints); } else if (auto negation = get(refinement)) - return computeRefinement(scope, negation->refinement, refis, !sense, eq, constraints); + return computeRefinement(scope, location, negation->refinement, refis, !sense, eq, constraints); else if (auto conjunction = get(refinement)) { RefinementContext lhsRefis; RefinementContext rhsRefis; - computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints); - computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints); + computeRefinement(scope, location, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints); + computeRefinement(scope, location, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints); if (!sense) - unionRefinements(lhsRefis, rhsRefis, *refis, constraints); + unionRefinements(scope, location, lhsRefis, rhsRefis, *refis, constraints); } else if (auto disjunction = get(refinement)) { RefinementContext lhsRefis; RefinementContext rhsRefis; - computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints); - computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints); + computeRefinement(scope, location, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints); + computeRefinement(scope, location, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints); if (sense) - unionRefinements(lhsRefis, rhsRefis, *refis, constraints); + unionRefinements(scope, location, lhsRefis, rhsRefis, *refis, constraints); } else if (auto equivalence = get(refinement)) { - computeRefinement(scope, equivalence->lhs, refis, sense, true, constraints); - computeRefinement(scope, equivalence->rhs, refis, sense, true, constraints); + computeRefinement(scope, location, equivalence->lhs, refis, sense, true, constraints); + computeRefinement(scope, location, equivalence->rhs, refis, sense, true, constraints); } else if (auto proposition = get(refinement)) { @@ -423,11 +419,11 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat RefinementContext refinements; std::vector constraints; - computeRefinement(scope, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints); + computeRefinement(scope, location, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints); for (auto& [def, partition] : refinements) { - if (std::optional defTy = lookup(scope.get(), def)) + if (std::optional defTy = lookup(scope, def)) { TypeId ty = *defTy; if (partition.shouldAppendNilType) @@ -455,15 +451,15 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat switch (shouldSuppressErrors(normalizer, ty)) { case ErrorSuppression::DoNotSuppress: - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + ty = makeIntersect(scope, location, ty, dt); break; case ErrorSuppression::Suppress: - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; - ty = simplifyUnion(builtinTypes, arena, ty, builtinTypes->errorType).result; + ty = makeIntersect(scope, location, ty, dt); + ty = makeUnion(scope, location, ty, builtinTypes->errorType); break; case ErrorSuppression::NormalizationFailed: reportError(location, NormalizationTooComplex{}); - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + ty = makeIntersect(scope, location, ty, dt); break; } } @@ -761,7 +757,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI for (AstLocal* var : forIn->vars) { - TypeId assignee = arena->addType(BlockedType{}); + TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, var->name.value}); variableTypes.push_back(assignee); if (var->annotation) @@ -872,7 +868,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f DenseHashSet excludeList{nullptr}; DefId def = dfg->getDef(function->name); - std::optional existingFunctionTy = lookup(scope.get(), def); + std::optional existingFunctionTy = lookup(scope, def); if (AstExprLocal* localName = function->name->as()) { @@ -1492,8 +1488,12 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* discriminantTypes.push_back(std::nullopt); } + Checkpoint funcBeginCheckpoint = checkpoint(this); + TypeId fnType = check(scope, call->func).ty; + Checkpoint funcEndCheckpoint = checkpoint(this); + std::vector> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); module->astOriginalCallTypes[call->func] = fnType; @@ -1624,13 +1624,31 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* &module->astOverloadResolvedTypes, }); - // We force constraints produced by checking function arguments to wait - // until after we have resolved the constraint on the function itself. - // This ensures, for instance, that we start inferring the contents of - // lambdas under the assumption that their arguments and return types - // will be compatible with the enclosing function call. - forEachConstraint(argBeginCheckpoint, argEndCheckpoint, this, [fcc](const ConstraintPtr& constraint) { - constraint->dependencies.emplace_back(fcc); + NotNull foo = addConstraint(scope, call->func->location, + FunctionCheckConstraint{ + fnType, + argPack, + call + } + ); + + /* + * To make bidirectional type checking work, we need to solve these constraints in a particular order: + * + * 1. Solve the function type + * 2. Propagate type information from the function type to the argument types + * 3. Solve the argument types + * 4. Solve the call + */ + + forEachConstraint(funcBeginCheckpoint, funcEndCheckpoint, this, [foo](const ConstraintPtr& constraint) { + foo->dependencies.emplace_back(constraint.get()); + }); + + forEachConstraint(argBeginCheckpoint, argEndCheckpoint, this, [foo, fcc](const ConstraintPtr& constraint) { + constraint->dependencies.emplace_back(foo); + + fcc->dependencies.emplace_back(constraint.get()); }); return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; @@ -1712,23 +1730,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantStrin if (forceSingleton) return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy) || get(expectedTy)) - { - TypeId ty = arena->addType(BlockedType{}); - TypeId singletonType = arena->addType(SingletonType(StringSingleton{std::string(string->value.data, string->value.size)})); - addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->stringType}); - return Inference{ty}; - } - else if (maybeSingleton(expectedTy)) - return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - - return Inference{builtinTypes->stringType}; - } - - return Inference{builtinTypes->stringType}; + FreeType ft = FreeType{scope.get()}; + ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); + ft.upperBound = builtinTypes->stringType; + const TypeId freeTy = arena->addType(ft); + addConstraint(scope, string->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->stringType}); + return Inference{freeTy}; } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) @@ -1737,23 +1744,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* if (forceSingleton) return Inference{singletonType}; - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - - if (get(expectedTy) || get(expectedTy)) - { - TypeId ty = arena->addType(BlockedType{}); - addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->booleanType}); - return Inference{ty}; - } - else if (maybeSingleton(expectedTy)) - return Inference{singletonType}; - - return Inference{builtinTypes->booleanType}; - } - - return Inference{builtinTypes->booleanType}; + FreeType ft = FreeType{scope.get()}; + ft.lowerBound = singletonType; + ft.upperBound = builtinTypes->booleanType; + const TypeId freeTy = arena->addType(ft); + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->booleanType}); + return Inference{freeTy}; } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprLocal* local) @@ -1766,12 +1762,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprLocal* local) // if we have a refinement key, we can look up its type. if (key) - maybeTy = lookup(scope.get(), key->def); + maybeTy = lookup(scope, key->def); // if the current def doesn't have a type, we might be doing a compound assignment // and therefore might need to look at the rvalue def instead. if (!maybeTy && rvalueDef) - maybeTy = lookup(scope.get(), *rvalueDef); + maybeTy = lookup(scope, *rvalueDef); if (maybeTy) { @@ -1797,7 +1793,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* globa /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any * global that is not already in-scope is definitely an unknown symbol. */ - if (auto ty = lookup(scope.get(), def, /*prototype=*/false)) + if (auto ty = lookup(scope, def, /*prototype=*/false)) { rootScope->lvalueTypes[def] = *ty; return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; @@ -1816,7 +1812,7 @@ Inference ConstraintGenerator::checkIndexName(const ScopePtr& scope, const Refin if (key) { - if (auto ty = lookup(scope.get(), key->def)) + if (auto ty = lookup(scope, key->def)) return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; scope->rvalueRefinements[key->def] = result; @@ -1852,7 +1848,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIndexExpr* in const RefinementKey* key = dfg->getRefinementKey(indexExpr); if (key) { - if (auto ty = lookup(scope.get(), key->def)) + if (auto ty = lookup(scope, key->def)) return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; scope->rvalueRefinements[key->def] = result; @@ -2120,7 +2116,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIfElse* ifEls applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); TypeId elseType = check(elseScope, ifElse->falseExpr, expectedType).ty; - return Inference{expectedType ? *expectedType : simplifyUnion(builtinTypes, arena, thenType, elseType).result}; + return Inference{expectedType ? *expectedType : makeUnion(scope, ifElse->location, thenType, elseType)}; } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) @@ -3172,6 +3168,30 @@ void ConstraintGenerator::reportCodeTooComplex(Location location) logger->captureGenerationError(errors.back()); } +TypeId ConstraintGenerator::makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) +{ + TypeId resultType = arena->addType(TypeFamilyInstanceType{ + NotNull{&kBuiltinTypeFamilies.unionFamily}, + {lhs, rhs}, + {}, + }); + addConstraint(scope, location, ReduceConstraint{resultType}); + + return resultType; +} + +TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) +{ + TypeId resultType = arena->addType(TypeFamilyInstanceType{ + NotNull{&kBuiltinTypeFamilies.intersectFamily}, + {lhs, rhs}, + {}, + }); + addConstraint(scope, location, ReduceConstraint{resultType}); + + return resultType; +} + struct GlobalPrepopulator : AstVisitor { const NotNull globalScope; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 7430dbf3..ec0d6c8a 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -283,7 +283,8 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNullgetFreeTypes()) { // increment the reference count for `ty` - unresolvedConstraints[ty] += 1; + auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); + refCount += 1; } for (NotNull dep : c->dependencies) @@ -368,7 +369,13 @@ void ConstraintSolver::run() // decrement the referenced free types for this constraint if we dispatched successfully! for (auto ty : c->getFreeTypes()) - unresolvedConstraints[ty] -= 1; + { + // this is a little weird, but because we're only counting free types in subtyping constraints, + // some constraints (like unpack) might actually produce _more_ references to a free type. + size_t& refCount = unresolvedConstraints[ty]; + if (refCount > 0) + refCount -= 1; + } if (logger) { @@ -534,6 +541,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*taec, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); + else if (auto fcc = get(*constraint)) + success = tryDispatch(*fcc, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) @@ -992,6 +1001,15 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) + { + asMutable(c.result)->ty.emplace(builtinTypes->errorTypePack); + unblock(c.result, constraint->location); + + return true; + } + auto [argsHead, argsTail] = flatten(argsPack); bool blocked = false; @@ -1080,44 +1098,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType}; } - // We know the type of the function and the arguments it expects to receive. - // We also know the TypeIds of the actual arguments that will be passed. - // - // Bidirectional type checking: Force those TypeIds to be the expected - // arguments. If something is incoherent, we'll spot it in type checking. - // - // Most important detail: If a function argument is a lambda, we also want - // to force unannotated argument types of that lambda to be the expected - // types. - - // FIXME: Bidirectional type checking of overloaded functions is not yet supported. - if (auto ftv = get(fn)) - { - const std::vector expectedArgs = flatten(ftv->argTypes).first; - const std::vector argPackHead = flatten(argsPack).first; - - for (size_t i = 0; i < c.callSite->args.size && i < expectedArgs.size() && i < argPackHead.size(); ++i) - { - const FunctionType* expectedLambdaTy = get(follow(expectedArgs[i])); - const FunctionType* lambdaTy = get(follow(argPackHead[i])); - const AstExprFunction* lambdaExpr = c.callSite->args.data[i]->as(); - - if (expectedLambdaTy && lambdaTy && lambdaExpr) - { - const std::vector expectedLambdaArgTys = flatten(expectedLambdaTy->argTypes).first; - const std::vector lambdaArgTys = flatten(lambdaTy->argTypes).first; - - for (size_t j = 0; j < expectedLambdaArgTys.size() && j < lambdaArgTys.size() && j < lambdaExpr->args.size; ++j) - { - if (!lambdaExpr->args.data[j]->annotation && get(follow(lambdaArgTys[j]))) - { - asMutable(lambdaArgTys[j])->ty.emplace(expectedLambdaArgTys[j]); - } - } - } - } - } - TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}}; @@ -1141,17 +1121,94 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull constraint) +{ + const TypeId fn = follow(c.fn); + const TypePackId argsPack = follow(c.argsPack); + + if (isBlocked(fn)) + return block(fn, constraint); + + if (isBlocked(argsPack)) + return block(argsPack, constraint); + + // We know the type of the function and the arguments it expects to receive. + // We also know the TypeIds of the actual arguments that will be passed. + // + // Bidirectional type checking: Force those TypeIds to be the expected + // arguments. If something is incoherent, we'll spot it in type checking. + // + // Most important detail: If a function argument is a lambda, we also want + // to force unannotated argument types of that lambda to be the expected + // types. + + // FIXME: Bidirectional type checking of overloaded functions is not yet supported. + if (auto ftv = get(fn)) + { + const std::vector expectedArgs = flatten(ftv->argTypes).first; + const std::vector argPackHead = flatten(argsPack).first; + + for (size_t i = 0; i < c.callSite->args.size && i < expectedArgs.size() && i < argPackHead.size(); ++i) + { + const TypeId expectedArgTy = follow(expectedArgs[i]); + const TypeId actualArgTy = follow(argPackHead[i]); + + const FunctionType* expectedLambdaTy = get(expectedArgTy); + const FunctionType* lambdaTy = get(actualArgTy); + const AstExprFunction* lambdaExpr = c.callSite->args.data[i]->as(); + + if (expectedLambdaTy && lambdaTy && lambdaExpr) + { + const std::vector expectedLambdaArgTys = flatten(expectedLambdaTy->argTypes).first; + const std::vector lambdaArgTys = flatten(lambdaTy->argTypes).first; + + for (size_t j = 0; j < expectedLambdaArgTys.size() && j < lambdaArgTys.size() && j < lambdaExpr->args.size; ++j) + { + if (!lambdaExpr->args.data[j]->annotation && get(follow(lambdaArgTys[j]))) + { + asMutable(lambdaArgTys[j])->ty.emplace(expectedLambdaArgTys[j]); + } + } + } + else + { + Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; + u2.unify(actualArgTy, expectedArgTy); + } + } + } + + return true; +} + bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint) { - TypeId expectedType = follow(c.expectedType); - if (isBlocked(expectedType) || get(expectedType)) - return block(expectedType, constraint); + std::optional expectedType = c.expectedType ? std::make_optional(follow(*c.expectedType)) : std::nullopt; + if (expectedType && (isBlocked(*expectedType) || get(*expectedType))) + return block(*expectedType, constraint); - LUAU_ASSERT(get(c.resultType)); + const FreeType* freeType = get(follow(c.freeType)); - TypeId bindTo = maybeSingleton(expectedType) ? c.singletonType : c.multitonType; - asMutable(c.resultType)->ty.emplace(bindTo); - unblock(c.resultType, constraint->location); + // if this is no longer a free type, then we're done. + if (!freeType) + return true; + + // We will wait if there are any other references to the free type mentioned here. + // This is probably the only thing that makes this not insane to do. + if (auto refCount = unresolvedConstraints.find(c.freeType); refCount && *refCount > 1) + { + block(c.freeType, constraint); + return false; + } + + TypeId bindTo = c.primitiveType; + + if (freeType->upperBound != c.primitiveType && maybeSingleton(freeType->upperBound)) + bindTo = freeType->lowerBound; + else if (expectedType && maybeSingleton(*expectedType)) + bindTo = freeType->lowerBound; + + asMutable(c.freeType)->ty.emplace(bindTo); return true; } @@ -1163,7 +1220,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(resultType)); - if (isBlocked(subjectType) || get(subjectType)) + if (isBlocked(subjectType) || get(subjectType) || get(subjectType)) return block(subjectType, constraint); auto [blocked, result] = lookupTableProp(subjectType, c.prop, c.suppressSimplification); @@ -1599,7 +1656,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl auto unpack = [&](TypeId ty) { TypePackId variadic = arena->addTypePack(VariadicTypePack{ty}); - pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, variadic}); + pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, variadic, /* resultIsLValue */ true}); }; if (get(iteratorTy)) @@ -1639,6 +1696,23 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl { TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); unify(constraint->scope, constraint->location, c.variables, expectedVariablePack); + + auto [variableTys, variablesTail] = flatten(c.variables); + + // the local types for the indexer _should_ be all set after unification + for (TypeId ty : variableTys) + { + if (auto lt = getMutable(ty)) + { + LUAU_ASSERT(lt->blockCount > 0); + --lt->blockCount; + + LUAU_ASSERT(0 <= lt->blockCount); + + if (0 == lt->blockCount) + asMutable(ty)->ty.emplace(lt->domain); + } + } } else unpack(builtinTypes->errorType); @@ -1775,7 +1849,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( modifiedNextRetHead.push_back(*it); TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail()); - auto psc = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, modifiedNextRetPack}); + auto psc = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, modifiedNextRetPack, /* resultIsLValue */ true}); inheritBlocks(constraint, psc); return true; @@ -1876,7 +1950,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa { const TypeId upperBound = follow(ft->upperBound); - if (get(upperBound)) + if (get(upperBound) || get(upperBound)) return lookupTableProp(upperBound, propName, suppressSimplification, seen); // TODO: The upper bound could be an intersection that contains suitable tables or classes. @@ -2008,46 +2082,63 @@ void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId asMutable(blockedTy)->ty.emplace(resultTy); } -void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) +bool ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { - blocked[target].push_back(constraint); + // If a set is not present for the target, construct a new DenseHashSet for it, + // else grab the address of the existing set. + NotNull> blockVec{&blocked.try_emplace(target, nullptr).first->second}; + + if (blockVec->find(constraint)) + return false; + + blockVec->insert(constraint); auto& count = blockedConstraints[constraint]; count += 1; + + return true; } void ConstraintSolver::block(NotNull target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(target.get(), constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str()); - - block_(target.get(), constraint); + if (FFlag::DebugLuauLogSolver) + printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str()); + } } bool ConstraintSolver::block(TypeId target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(follow(target), constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + if (FFlag::DebugLuauLogSolver) + printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + } - block_(follow(target), constraint); return false; } bool ConstraintSolver::block(TypePackId target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(target, constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + if (FFlag::DebugLuauLogSolver) + printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + } - block_(target, constraint); return false; } @@ -2058,9 +2149,9 @@ void ConstraintSolver::inheritBlocks(NotNull source, NotNullsecond) + for (const Constraint* blockedConstraint : blockedIt->second) { - block(addition, blockedConstraint); + block(addition, NotNull{blockedConstraint}); } } } @@ -2112,9 +2203,9 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) return; // unblocked should contain a value always, because of the above check - for (NotNull unblockedConstraint : it->second) + for (const Constraint* unblockedConstraint : it->second) { - auto& count = blockedConstraints[unblockedConstraint]; + auto& count = blockedConstraints[NotNull{unblockedConstraint}]; if (FFlag::DebugLuauLogSolver) printf("Unblocking count=%d\t%s\n", int(count), toString(*unblockedConstraint, opts).c_str()); diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index b8671d21..62b574ca 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -699,6 +699,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail)); } } + else if (get(*subTail) || get(*superTail)) + // error type is fine on either side + results.push_back(SubtypingResult{true}.withBothComponent(TypePath::PackField::Tail)); else iceReporter->ice( format("Subtyping::isSubtype got unexpected type packs %s and %s", toString(*subTail).c_str(), toString(*superTail).c_str())); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index b14df311..0f520116 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1742,9 +1742,16 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return "call " + tos(c.fn) + "( " + tos(c.argsPack) + " )" + " with { result = " + tos(c.result) + " }"; } + else if constexpr (std::is_same_v) + { + return "function_check " + tos(c.fn) + " " + tos(c.argsPack); + } else if constexpr (std::is_same_v) { - return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + tos(c.multitonType); + if (c.expectedType) + return "prim " + tos(c.freeType) + "[expected: " + tos(*c.expectedType) + "] as " + tos(c.primitiveType); + else + return "prim " + tos(c.freeType) + " as " + tos(c.primitiveType); } else if constexpr (std::is_same_v) { diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index abe719e1..71a07d73 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -419,6 +419,10 @@ bool maybeSingleton(TypeId ty) for (TypeId option : utv) if (get(follow(option))) return true; + if (const IntersectionType* itv = get(ty)) + for (TypeId part : itv) + if (maybeSingleton(part)) // will i regret this? + return true; return false; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 349443ab..fdac6daf 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -261,6 +261,155 @@ struct TypeChecker2 { } + static bool allowsNoReturnValues(const TypePackId tp) + { + for (TypeId ty : tp) + { + if (!get(follow(ty))) + return false; + } + + return true; + } + + static Location getEndLocation(const AstExprFunction* function) + { + Location loc = function->location; + if (loc.begin.line != loc.end.line) + { + Position begin = loc.end; + begin.column = std::max(0u, begin.column - 3); + loc = Location(begin, 3); + } + + return loc; + } + + bool isErrorCall(const AstExprCall* call) + { + const AstExprGlobal* global = call->func->as(); + if (!global) + return false; + + if (global->name == "error") + return true; + else if (global->name == "assert") + { + // assert() will error because it is missing the first argument + if (call->args.size == 0) + return true; + + if (AstExprConstantBool* expr = call->args.data[0]->as()) + if (!expr->value) + return true; + } + + return false; + } + + bool hasBreak(AstStat* node) + { + if (AstStatBlock* stat = node->as()) + { + for (size_t i = 0; i < stat->body.size; ++i) + { + if (hasBreak(stat->body.data[i])) + return true; + } + + return false; + } + + if (node->is()) + return true; + + if (AstStatIf* stat = node->as()) + { + if (hasBreak(stat->thenbody)) + return true; + + if (stat->elsebody && hasBreak(stat->elsebody)) + return true; + + return false; + } + + return false; + } + + // returns the last statement before the block implicitly exits, or nullptr if the block does not implicitly exit + // i.e. returns nullptr if the block returns properly or never returns + const AstStat* getFallthrough(const AstStat* node) + { + if (const AstStatBlock* stat = node->as()) + { + if (stat->body.size == 0) + return stat; + + for (size_t i = 0; i < stat->body.size - 1; ++i) + { + if (getFallthrough(stat->body.data[i]) == nullptr) + return nullptr; + } + + return getFallthrough(stat->body.data[stat->body.size - 1]); + } + + if (const AstStatIf* stat = node->as()) + { + if (const AstStat* thenf = getFallthrough(stat->thenbody)) + return thenf; + + if (stat->elsebody) + { + if (const AstStat* elsef = getFallthrough(stat->elsebody)) + return elsef; + + return nullptr; + } + else + return stat; + } + + if (node->is()) + return nullptr; + + if (const AstStatExpr* stat = node->as()) + { + if (AstExprCall* call = stat->expr->as(); call && isErrorCall(call)) + return nullptr; + + return stat; + } + + if (const AstStatWhile* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) + { + if (expr->value && !hasBreak(stat->body)) + return nullptr; + } + + return node; + } + + if (const AstStatRepeat* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) + { + if (!expr->value && !hasBreak(stat->body)) + return nullptr; + } + + if (getFallthrough(stat->body) == nullptr) + return nullptr; + + return node; + } + + return node; + } + std::optional pushStack(AstNode* node) { if (Scope** scope = module->astScopes.find(node)) @@ -1723,6 +1872,10 @@ struct TypeChecker2 ++argIt; } + + bool reachesImplicitReturn = getFallthrough(fn->body) != nullptr; + if (reachesImplicitReturn && !allowsNoReturnValues(follow(inferredFtv->retTypes))) + reportError(FunctionExitsWithoutReturning{inferredFtv->retTypes}, getEndLocation(fn)); } visit(fn->body); diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index ed81b663..4903865f 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -1052,6 +1052,73 @@ TypeFamilyReductionResult refineFamilyFn(const std::vector& type return {resultTy, false, {}, {}}; } +TypeFamilyReductionResult unionFamilyFn(const std::vector& typeParams, const std::vector& packParams, NotNull ctx) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("union type family: encountered a type family instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (get(lhsTy)) // if the lhs is never, we don't need this family anymore + return {rhsTy, false, {}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + else if (get(rhsTy)) // if the rhs is never, we don't need this family anymore + return {lhsTy, false, {}, {}}; + + + SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, lhsTy, rhsTy); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + return {result.result, false, {}, {}}; +} + + +TypeFamilyReductionResult intersectFamilyFn(const std::vector& typeParams, const std::vector& packParams, NotNull ctx) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("intersect type family: encountered a type family instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (get(lhsTy)) // if the lhs is never, we don't need this family anymore + return {ctx->builtins->neverType, false, {}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + else if (get(rhsTy)) // if the rhs is never, we don't need this family anymore + return {ctx->builtins->neverType, false, {}, {}}; + + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, rhsTy); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + // if the intersection simplifies to `never`, this gives us bad autocomplete. + // we'll just produce the intersection plainly instead, but this might be revisitable + // if we ever give `never` some kind of "explanation" trail. + if (get(result.result)) + { + TypeId intersection = ctx->arena->addType(IntersectionType{{lhsTy, rhsTy}}); + return {intersection, false, {}, {}}; + } + + return {result.result, false, {}, {}}; +} + // computes the keys of `ty` into `result` // `isRaw` parameter indicates whether or not we should follow __index metamethods // returns `false` if `result` should be ignored because the answer is "all strings" @@ -1262,6 +1329,8 @@ BuiltinTypeFamilies::BuiltinTypeFamilies() , leFamily{"le", leFamilyFn} , eqFamily{"eq", eqFamilyFn} , refineFamily{"refine", refineFamilyFn} + , unionFamily{"union", unionFamilyFn} + , intersectFamily{"intersect", intersectFamilyFn} , keyofFamily{"keyof", keyofFamilyFn} , rawkeyofFamily{"rawkeyof", rawkeyofFamilyFn} { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index a1247915..653beb0e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -40,6 +40,7 @@ LUAU_FASTFLAGVARIABLE(LuauLoopControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false) +LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) namespace Luau { @@ -1335,10 +1336,10 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (size_t i = 2; i < varTypes.size(); ++i) unify(nilType, varTypes[i], scope, forin.location); } - else if (isNonstrictMode()) + else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties) { for (TypeId var : varTypes) - unify(anyType, var, scope, forin.location); + unify(unknownType, var, scope, forin.location); } else { diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 6b213aea..cba2f4bb 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -56,6 +56,12 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy) if (subFree || superFree) return true; + if (auto subLocal = getMutable(subTy)) + { + subLocal->domain = mkUnion(subLocal->domain, superTy); + expandedFreeTypes[subTy].push_back(superTy); + } + auto subFn = get(subTy); auto superFn = get(superTy); if (subFn && superFn) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index d4836bb5..993116d6 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -387,7 +387,7 @@ public: AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, - bool DEPRECATED_hasEnd = false, const std::optional& argLocation = std::nullopt); + const std::optional& argLocation = std::nullopt); void visit(AstVisitor* visitor) override; @@ -406,8 +406,6 @@ public: AstName debugname; - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; std::optional argLocation; }; @@ -573,7 +571,7 @@ public: LUAU_RTTI(AstStatIf) AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, const std::optional& thenLocation, - const std::optional& elseLocation, bool DEPRECATED_hasEnd); + const std::optional& elseLocation); void visit(AstVisitor* visitor) override; @@ -585,9 +583,6 @@ public: // Active for 'elseif' as well std::optional elseLocation; - - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; }; class AstStatWhile : public AstStat @@ -595,7 +590,7 @@ class AstStatWhile : public AstStat public: LUAU_RTTI(AstStatWhile) - AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool DEPRECATED_hasEnd); + AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation); void visit(AstVisitor* visitor) override; @@ -604,9 +599,6 @@ public: bool hasDo = false; Location doLocation; - - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; }; class AstStatRepeat : public AstStat @@ -690,7 +682,7 @@ public: LUAU_RTTI(AstStatFor) AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, - const Location& doLocation, bool DEPRECATED_hasEnd); + const Location& doLocation); void visit(AstVisitor* visitor) override; @@ -702,9 +694,6 @@ public: bool hasDo = false; Location doLocation; - - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; }; class AstStatForIn : public AstStat @@ -713,7 +702,7 @@ public: LUAU_RTTI(AstStatForIn) AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, bool hasIn, - const Location& inLocation, bool hasDo, const Location& doLocation, bool DEPRECATED_hasEnd); + const Location& inLocation, bool hasDo, const Location& doLocation); void visit(AstVisitor* visitor) override; @@ -726,9 +715,6 @@ public: bool hasDo = false; Location doLocation; - - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; }; class AstStatAssign : public AstStat diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 9a6ca4d7..0409a622 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -163,7 +163,7 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, bool DEPRECATED_hasEnd, + const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, const std::optional& argLocation) : AstExpr(ClassIndex(), location) , generics(generics) @@ -177,7 +177,6 @@ AstExprFunction::AstExprFunction(const Location& location, const AstArray& thenLocation, const std::optional& elseLocation, bool DEPRECATED_hasEnd) + const std::optional& thenLocation, const std::optional& elseLocation) : AstStat(ClassIndex(), location) , condition(condition) , thenbody(thenbody) , elsebody(elsebody) , thenLocation(thenLocation) , elseLocation(elseLocation) - , DEPRECATED_hasEnd(DEPRECATED_hasEnd) { } @@ -418,13 +416,12 @@ void AstStatIf::visit(AstVisitor* visitor) } } -AstStatWhile::AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool DEPRECATED_hasEnd) +AstStatWhile::AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation) : AstStat(ClassIndex(), location) , condition(condition) , body(body) , hasDo(hasDo) , doLocation(doLocation) - , DEPRECATED_hasEnd(DEPRECATED_hasEnd) { } @@ -526,7 +523,7 @@ void AstStatLocal::visit(AstVisitor* visitor) } AstStatFor::AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, - const Location& doLocation, bool DEPRECATED_hasEnd) + const Location& doLocation) : AstStat(ClassIndex(), location) , var(var) , from(from) @@ -535,7 +532,6 @@ AstStatFor::AstStatFor(const Location& location, AstLocal* var, AstExpr* from, A , body(body) , hasDo(hasDo) , doLocation(doLocation) - , DEPRECATED_hasEnd(DEPRECATED_hasEnd) { } @@ -557,7 +553,7 @@ void AstStatFor::visit(AstVisitor* visitor) } AstStatForIn::AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, - bool hasIn, const Location& inLocation, bool hasDo, const Location& doLocation, bool DEPRECATED_hasEnd) + bool hasIn, const Location& inLocation, bool hasDo, const Location& doLocation) : AstStat(ClassIndex(), location) , vars(vars) , values(values) @@ -566,7 +562,6 @@ AstStatForIn::AstStatForIn(const Location& location, const AstArray& , inLocation(inLocation) , hasDo(hasDo) , doLocation(doLocation) - , DEPRECATED_hasEnd(DEPRECATED_hasEnd) { } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 6d15e709..c4d1c65d 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -16,7 +16,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. -LUAU_FASTFLAGVARIABLE(LuauClipExtraHasEndProps, false) LUAU_FASTFLAG(LuauCheckedFunctionSyntax) LUAU_FASTFLAGVARIABLE(LuauReadWritePropertySyntax, false) @@ -373,18 +372,15 @@ AstStat* Parser::parseIf() AstStat* elsebody = nullptr; Location end = start; std::optional elseLocation; - bool DEPRECATED_hasEnd = false; if (lexer.current().type == Lexeme::ReservedElseif) { - if (FFlag::LuauClipExtraHasEndProps) - thenbody->hasEnd = true; + thenbody->hasEnd = true; unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("elseif"); elseLocation = lexer.current().location; elsebody = parseIf(); end = elsebody->location; - DEPRECATED_hasEnd = elsebody->as()->DEPRECATED_hasEnd; recursionCounter = oldRecursionCount; } else @@ -393,8 +389,7 @@ AstStat* Parser::parseIf() if (lexer.current().type == Lexeme::ReservedElse) { - if (FFlag::LuauClipExtraHasEndProps) - thenbody->hasEnd = true; + thenbody->hasEnd = true; elseLocation = lexer.current().location; matchThenElse = lexer.current(); nextLexeme(); @@ -406,21 +401,17 @@ AstStat* Parser::parseIf() end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); - DEPRECATED_hasEnd = hasEnd; - if (FFlag::LuauClipExtraHasEndProps) + if (elsebody) { - if (elsebody) - { - if (AstStatBlock* elseBlock = elsebody->as()) - elseBlock->hasEnd = hasEnd; - } - else - thenbody->hasEnd = hasEnd; + if (AstStatBlock* elseBlock = elsebody->as()) + elseBlock->hasEnd = hasEnd; } + else + thenbody->hasEnd = hasEnd; } - return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation, DEPRECATED_hasEnd); + return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation); } // while exp do block end @@ -444,10 +435,9 @@ AstStat* Parser::parseWhile() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasEnd; + body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), cond, body, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), cond, body, hasDo, matchDo.location); } // repeat block until exp @@ -467,8 +457,7 @@ AstStat* Parser::parseRepeat() functionStack.back().loopDepth--; bool hasUntil = expectMatchEndAndConsume(Lexeme::ReservedUntil, matchRepeat); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasUntil; + body->hasEnd = hasUntil; AstExpr* cond = parseExpr(); @@ -565,10 +554,9 @@ AstStat* Parser::parseFor() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasEnd; + body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); } else { @@ -609,11 +597,10 @@ AstStat* Parser::parseFor() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasEnd; + body->hasEnd = hasEnd; return allocator.alloc( - Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location, hasEnd); + Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); } } @@ -1100,11 +1087,10 @@ std::pair Parser::parseFunctionBody( Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasEnd; + body->hasEnd = hasEnd; return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, - functionStack.size(), debugname, typelist, varargAnnotation, hasEnd, argLocation), + functionStack.size(), debugname, typelist, varargAnnotation, argLocation), funLocal}; } diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 0ed35910..78251012 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -139,6 +139,10 @@ public: void fsqrt(RegisterA64 dst, RegisterA64 src); void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); + void ins_4s(RegisterA64 dst, uint8_t dstIndex, RegisterA64 src, uint8_t srcIndex); + void dup_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); + // Floating-point rounding and conversions void frinta(RegisterA64 dst, RegisterA64 src); void frintm(RegisterA64 dst, RegisterA64 src); @@ -207,6 +211,7 @@ private: void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0, int N = 0); void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0); void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); + void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t sizes, uint8_t op, uint8_t op2); void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index e5b82f08..0be59fb1 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -119,13 +119,18 @@ public: void vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vsubsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vsubps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vmulsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmulps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vdivsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vdivps(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vandps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vandpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vandnpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vorps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vucomisd(OperandX64 src1, OperandX64 src2); @@ -159,6 +164,9 @@ public: void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3); + void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle); + void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset); + // Run final checks bool finalize(); @@ -176,9 +184,11 @@ public: } // Constant allocation (uses rip-relative addressing) + OperandX64 i32(int32_t value); OperandX64 i64(int64_t value); OperandX64 f32(float value); OperandX64 f64(double value); + OperandX64 u32x4(uint32_t x, uint32_t y, uint32_t z, uint32_t w); OperandX64 f32x4(float x, float y, float z, float w); OperandX64 f64x2(double x, double y); OperandX64 bytes(const void* ptr, size_t size, size_t align = 8); @@ -260,6 +270,7 @@ private: std::vector