From 0edacdded41c89caf26cf3b780ca5ad07adcd9f2 Mon Sep 17 00:00:00 2001 From: Aaron Weiss Date: Fri, 26 Jan 2024 18:30:40 -0800 Subject: [PATCH] Sync to upstream/release/610 --- Analysis/include/Luau/Constraint.h | 42 +++- Analysis/include/Luau/ConstraintGenerator.h | 11 +- Analysis/include/Luau/ConstraintSolver.h | 5 +- Analysis/include/Luau/TypeFamily.h | 3 + Analysis/include/Luau/Unifier2.h | 1 + Analysis/src/AstJsonEncoder.cpp | 21 +- Analysis/src/Autocomplete.cpp | 82 ++----- Analysis/src/Constraint.cpp | 6 + Analysis/src/ConstraintGenerator.cpp | 172 ++++++++------ Analysis/src/ConstraintSolver.cpp | 237 +++++++++++++------ Analysis/src/Subtyping.cpp | 3 + Analysis/src/ToString.cpp | 9 +- Analysis/src/Type.cpp | 4 + Analysis/src/TypeChecker2.cpp | 153 +++++++++++++ Analysis/src/TypeFamily.cpp | 69 ++++++ Analysis/src/TypeInfer.cpp | 5 +- Analysis/src/Unifier2.cpp | 6 + Ast/include/Luau/Ast.h | 24 +- Ast/src/Ast.cpp | 15 +- Ast/src/Parser.cpp | 48 ++-- CodeGen/include/Luau/AssemblyBuilderA64.h | 5 + CodeGen/include/Luau/AssemblyBuilderX64.h | 11 + CodeGen/include/Luau/CodeGen.h | 36 ++- CodeGen/include/Luau/IrData.h | 20 ++ CodeGen/include/Luau/IrDump.h | 8 +- CodeGen/include/Luau/IrUtils.h | 7 + CodeGen/include/Luau/IrVisitUseDef.h | 1 + CodeGen/src/AssemblyBuilderA64.cpp | 138 ++++++++++- CodeGen/src/AssemblyBuilderX64.cpp | 99 +++++++- CodeGen/src/CodeGen.cpp | 43 ++++ CodeGen/src/CodeGenLower.h | 6 +- CodeGen/src/IrDump.cpp | 41 ++-- CodeGen/src/IrLoweringA64.cpp | 103 +++++++++ CodeGen/src/IrLoweringX64.cpp | 121 ++++++++++ CodeGen/src/IrLoweringX64.h | 6 + CodeGen/src/IrTranslation.cpp | 179 ++++++++++++++- CodeGen/src/IrUtils.cpp | 10 + CodeGen/src/IrValueLocationTracking.cpp | 6 + CodeGen/src/OptimizeConstProp.cpp | 22 +- CodeGen/src/OptimizeFinalX64.cpp | 17 ++ VM/src/ldebug.cpp | 9 +- VM/src/lstate.h | 1 + VM/src/lstrlib.cpp | 17 ++ VM/src/lvmexecute.cpp | 2 +- tests/AssemblyBuilderA64.test.cpp | 22 ++ tests/AssemblyBuilderX64.test.cpp | 39 +++- tests/AstJsonEncoder.test.cpp | 6 - tests/Autocomplete.test.cpp | 9 + tests/Conformance.test.cpp | 48 +++- tests/IrBuilder.test.cpp | 150 ++++++------ tests/IrLowering.test.cpp | 242 +++++++++++++++++++- tests/TypeFamily.test.cpp | 8 +- tests/TypeInfer.loops.test.cpp | 94 ++++++-- tests/TypeInfer.oop.test.cpp | 2 +- tests/TypeInfer.refinements.test.cpp | 2 + tests/TypeInfer.typestates.test.cpp | 6 +- tests/conformance/interrupt.lua | 33 +++ tests/conformance/native.lua | 24 ++ tools/faillist.txt | 40 +--- 59 files changed, 2048 insertions(+), 501 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index e69346bd..692e4a14 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -90,18 +90,42 @@ struct FunctionCallConstraint DenseHashMap* astOverloadResolvedTypes = nullptr; }; -// result ~ prim ExpectedType SomeSingletonType MultitonType +// function_check fn argsPack // -// If ExpectedType is potentially a singleton (an actual singleton or a union -// that contains a singleton), then result ~ SomeSingletonType +// If fn is a function type and argsPack is a partially solved +// pack of arguments to be supplied to the function, propagate the argument +// types of fn into the types of argsPack. This is used to implement +// bidirectional inference of lambda arguments. +struct FunctionCheckConstraint +{ + TypeId fn; + TypePackId argsPack; + + class AstExprCall* callSite = nullptr; +}; + +// prim FreeType ExpectedType PrimitiveType // -// else result ~ MultitonType +// FreeType is bounded below by the singleton type and above by PrimitiveType +// initially. When this constraint is resolved, it will check that the bounds +// of the free type are well-formed by subtyping. +// +// If they are not well-formed, then FreeType is replaced by its lower bound +// +// If they are well-formed and ExpectedType is potentially a singleton (an +// actual singleton or a union that contains a singleton), +// then FreeType is replaced by its lower bound +// +// else FreeType is replaced by PrimitiveType struct PrimitiveTypeConstraint { - TypeId resultType; - TypeId expectedType; - TypeId singletonType; - TypeId multitonType; + TypeId freeType; + + // potentially gets used to force the lower bound? + std::optional expectedType; + + // the primitive type to check against + TypeId primitiveType; }; // result ~ hasProp type "prop_name" @@ -230,7 +254,7 @@ struct ReducePackConstraint }; using ConstraintV = Variant; struct Constraint diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index ebd237f6..2f746a74 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -150,7 +150,7 @@ private: */ ScopePtr childScope(AstNode* node, const ScopePtr& parent); - std::optional lookup(Scope* scope, DefId def, bool prototype = true); + std::optional lookup(const ScopePtr& scope, DefId def, bool prototype = true); /** * Adds a new constraint with no dependencies to a given scope. @@ -178,8 +178,8 @@ private: }; using RefinementContext = InsertionOrderedMap; - void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints); - void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector* constraints); + void unionRefinements(const ScopePtr& scope, Location location, const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints); + void computeRefinement(const ScopePtr& scope, Location location, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector* constraints); void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); @@ -329,6 +329,11 @@ private: void reportError(Location location, TypeErrorData err); void reportCodeTooComplex(Location location); + // make a union type family of these two types + TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + // make an intersect type family of these two types + TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + /** Scan the program for global definitions. * * ConstraintGenerator needs to differentiate between globals and accesses to undefined symbols. Doing this "for diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index f258b28b..e962f343 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -75,7 +75,7 @@ struct ConstraintSolver // anything. std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map>, HashBlockedConstraintId> blocked; + std::unordered_map, HashBlockedConstraintId> blocked; // Memoized instantiations of type aliases. DenseHashMap instantiatedAliases{{}}; // Breadcrumbs for where a free type's upper bound was expanded. We use @@ -126,6 +126,7 @@ struct ConstraintSolver bool tryDispatch(const NameConstraint& c, NotNull constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); + bool tryDispatch(const FunctionCheckConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); bool tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force); @@ -285,7 +286,7 @@ private: * @param target the type or type pack pointer that the constraint is blocked on. * @param constraint the constraint to block. **/ - void block_(BlockedConstraintId target, NotNull constraint); + bool block_(BlockedConstraintId target, NotNull constraint); /** * Informs the solver that progress has been made on a type or type pack. The diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index 49c652ee..77fd6e8a 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -163,6 +163,9 @@ struct BuiltinTypeFamilies TypeFamily eqFamily; TypeFamily refineFamily; + TypeFamily unionFamily; + TypeFamily intersectFamily; + TypeFamily keyofFamily; TypeFamily rawkeyofFamily; diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index 4930df6f..f9a3fdc9 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -56,6 +56,7 @@ struct Unifier2 * free TypePack to another and encounter an occurs check violation. */ bool unify(TypeId subTy, TypeId superTy); + bool unify(const LocalType* subTy, TypeId superFn); bool unify(TypeId subTy, const FunctionType* superFn); bool unify(const UnionType* subUnion, TypeId superTy); bool unify(TypeId subTy, const UnionType* superUnion); diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index dcee3492..470d69b3 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -8,8 +8,6 @@ #include -LUAU_FASTFLAG(LuauClipExtraHasEndProps); - namespace Luau { @@ -393,8 +391,6 @@ struct AstJsonEncoder : public AstVisitor PROP(body); PROP(functionDepth); PROP(debugname); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } @@ -591,11 +587,8 @@ struct AstJsonEncoder : public AstVisitor void write(class AstStatBlock* node) { writeNode(node, "AstStatBlock", [&]() { - if (FFlag::LuauClipExtraHasEndProps) - { - writeRaw(",\"hasEnd\":"); - write(node->hasEnd); - } + writeRaw(",\"hasEnd\":"); + write(node->hasEnd); writeRaw(",\"body\":["); bool comma = false; for (AstStat* stat : node->body) @@ -619,8 +612,6 @@ struct AstJsonEncoder : public AstVisitor if (node->elsebody) PROP(elsebody); write("hasThen", node->thenLocation.has_value()); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } @@ -630,8 +621,6 @@ struct AstJsonEncoder : public AstVisitor PROP(condition); PROP(body); PROP(hasDo); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } @@ -640,8 +629,6 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatRepeat", [&]() { PROP(condition); PROP(body); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasUntil", node->DEPRECATED_hasUntil); }); } @@ -687,8 +674,6 @@ struct AstJsonEncoder : public AstVisitor PROP(step); PROP(body); PROP(hasDo); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } @@ -700,8 +685,6 @@ struct AstJsonEncoder : public AstVisitor PROP(body); PROP(hasIn); PROP(hasDo); - if (!FFlag::LuauClipExtraHasEndProps) - write("hasEnd", node->DEPRECATED_hasEnd); }); } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 55e6d8f0..75fd9f58 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -15,7 +15,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauReadWriteProperties); -LUAU_FASTFLAG(LuauClipExtraHasEndProps); LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringLiteralBounds, false); static const std::unordered_set kStatementStartingKeywords = { @@ -1068,51 +1067,30 @@ static AutocompleteEntryMap autocompleteStatement( for (const auto& kw : kStatementStartingKeywords) result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (FFlag::LuauClipExtraHasEndProps) + for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) { - for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) + if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatIf* statIf = (*it)->as()) { - if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatIf* statIf = (*it)->as()) + bool hasEnd = statIf->thenbody->hasEnd; + if (statIf->elsebody) { - bool hasEnd = statIf->thenbody->hasEnd; - if (statIf->elsebody) - { - if (AstStatBlock* elseBlock = statIf->elsebody->as()) - hasEnd = elseBlock->hasEnd; - } - - if (!hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatBlock* elseBlock = statIf->elsebody->as()) + hasEnd = elseBlock->hasEnd; } - else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - } - else - { - for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) - { - if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatIf* statIf = (*it)->as(); statIf && !statIf->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->DEPRECATED_hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) + + if (!hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } + else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } if (ancestry.size() >= 2) @@ -1127,16 +1105,8 @@ static AutocompleteEntryMap autocompleteStatement( } } - if (FFlag::LuauClipExtraHasEndProps) - { - if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - else - { - if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->DEPRECATED_hasUntil) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } + if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); } if (ancestry.size() >= 4) @@ -1150,16 +1120,8 @@ static AutocompleteEntryMap autocompleteStatement( } } - if (FFlag::LuauClipExtraHasEndProps) - { - if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - else - { - if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->DEPRECATED_hasUntil) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } + if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); return result; } diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 3035d480..7d3f9e31 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -45,6 +45,12 @@ DenseHashSet Constraint::getFreeTypes() const ftc.traverse(psc->subPack); ftc.traverse(psc->superPack); } + else if (auto ptc = get(*this)) + { + // we need to take into account primitive type constraints to prevent type families from reducing on + // primitive whose types we have not yet selected to be singleton or not. + ftc.traverse(ptc->freeType); + } return types; } diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 91f2ab7e..bc412c0e 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -20,6 +20,7 @@ #include "Luau/InsertionOrderedMap.h" #include +#include LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); @@ -205,7 +206,7 @@ ScopePtr ConstraintGenerator::childScope(AstNode* node, const ScopePtr& parent) return scope; } -std::optional ConstraintGenerator::lookup(Scope* scope, DefId def, bool prototype) +std::optional ConstraintGenerator::lookup(const ScopePtr& scope, DefId def, bool prototype) { if (get(def)) return scope->lookup(def); @@ -230,7 +231,7 @@ std::optional ConstraintGenerator::lookup(Scope* scope, DefId def, bool rootScope->lvalueTypes[operand] = *ty; } - res = simplifyUnion(builtinTypes, arena, res, *ty).result; + res = makeUnion(scope, Location{} /* TODO: can we provide a real location here? */, res, *ty); } scope->lvalueTypes[def] = res; @@ -250,18 +251,13 @@ NotNull ConstraintGenerator::addConstraint(const ScopePtr& scope, st return NotNull{constraints.emplace_back(std::move(c)).get()}; } -void ConstraintGenerator::unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints) +void ConstraintGenerator::unionRefinements(const ScopePtr& scope, Location location, const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints) { const auto intersect = [&](const std::vector& types) { if (1 == types.size()) return types[0]; else if (2 == types.size()) - { - // TODO: It may be advantageous to introduce a refine type family here when there are blockedTypes. - SimplifyResult sr = simplifyIntersection(builtinTypes, arena, types[0], types[1]); - if (sr.blockedTypes.empty()) - return sr.result; - } + return makeIntersect(scope, location, types[0], types[1]); return arena->addType(IntersectionType{types}); }; @@ -281,48 +277,48 @@ void ConstraintGenerator::unionRefinements(const RefinementContext& lhs, const R rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] : intersect(rhsIt->second.discriminantTypes); dest.insert(def, {}); - dest.get(def)->discriminantTypes.push_back(simplifyUnion(builtinTypes, arena, leftDiscriminantTy, rightDiscriminantTy).result); + dest.get(def)->discriminantTypes.push_back(makeUnion(scope, location, leftDiscriminantTy, rightDiscriminantTy)); dest.get(def)->shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; } } -void ConstraintGenerator::computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector* constraints) +void ConstraintGenerator::computeRefinement(const ScopePtr& scope, Location location, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector* constraints) { if (!refinement) return; else if (auto variadic = get(refinement)) { for (RefinementId refi : variadic->refinements) - computeRefinement(scope, refi, refis, sense, eq, constraints); + computeRefinement(scope, location, refi, refis, sense, eq, constraints); } else if (auto negation = get(refinement)) - return computeRefinement(scope, negation->refinement, refis, !sense, eq, constraints); + return computeRefinement(scope, location, negation->refinement, refis, !sense, eq, constraints); else if (auto conjunction = get(refinement)) { RefinementContext lhsRefis; RefinementContext rhsRefis; - computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints); - computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints); + computeRefinement(scope, location, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints); + computeRefinement(scope, location, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints); if (!sense) - unionRefinements(lhsRefis, rhsRefis, *refis, constraints); + unionRefinements(scope, location, lhsRefis, rhsRefis, *refis, constraints); } else if (auto disjunction = get(refinement)) { RefinementContext lhsRefis; RefinementContext rhsRefis; - computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints); - computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints); + computeRefinement(scope, location, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints); + computeRefinement(scope, location, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints); if (sense) - unionRefinements(lhsRefis, rhsRefis, *refis, constraints); + unionRefinements(scope, location, lhsRefis, rhsRefis, *refis, constraints); } else if (auto equivalence = get(refinement)) { - computeRefinement(scope, equivalence->lhs, refis, sense, true, constraints); - computeRefinement(scope, equivalence->rhs, refis, sense, true, constraints); + computeRefinement(scope, location, equivalence->lhs, refis, sense, true, constraints); + computeRefinement(scope, location, equivalence->rhs, refis, sense, true, constraints); } else if (auto proposition = get(refinement)) { @@ -423,11 +419,11 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat RefinementContext refinements; std::vector constraints; - computeRefinement(scope, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints); + computeRefinement(scope, location, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints); for (auto& [def, partition] : refinements) { - if (std::optional defTy = lookup(scope.get(), def)) + if (std::optional defTy = lookup(scope, def)) { TypeId ty = *defTy; if (partition.shouldAppendNilType) @@ -455,15 +451,15 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat switch (shouldSuppressErrors(normalizer, ty)) { case ErrorSuppression::DoNotSuppress: - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + ty = makeIntersect(scope, location, ty, dt); break; case ErrorSuppression::Suppress: - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; - ty = simplifyUnion(builtinTypes, arena, ty, builtinTypes->errorType).result; + ty = makeIntersect(scope, location, ty, dt); + ty = makeUnion(scope, location, ty, builtinTypes->errorType); break; case ErrorSuppression::NormalizationFailed: reportError(location, NormalizationTooComplex{}); - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + ty = makeIntersect(scope, location, ty, dt); break; } } @@ -761,7 +757,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI for (AstLocal* var : forIn->vars) { - TypeId assignee = arena->addType(BlockedType{}); + TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, var->name.value}); variableTypes.push_back(assignee); if (var->annotation) @@ -872,7 +868,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f DenseHashSet excludeList{nullptr}; DefId def = dfg->getDef(function->name); - std::optional existingFunctionTy = lookup(scope.get(), def); + std::optional existingFunctionTy = lookup(scope, def); if (AstExprLocal* localName = function->name->as()) { @@ -1492,8 +1488,12 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* discriminantTypes.push_back(std::nullopt); } + Checkpoint funcBeginCheckpoint = checkpoint(this); + TypeId fnType = check(scope, call->func).ty; + Checkpoint funcEndCheckpoint = checkpoint(this); + std::vector> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); module->astOriginalCallTypes[call->func] = fnType; @@ -1624,13 +1624,31 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* &module->astOverloadResolvedTypes, }); - // We force constraints produced by checking function arguments to wait - // until after we have resolved the constraint on the function itself. - // This ensures, for instance, that we start inferring the contents of - // lambdas under the assumption that their arguments and return types - // will be compatible with the enclosing function call. - forEachConstraint(argBeginCheckpoint, argEndCheckpoint, this, [fcc](const ConstraintPtr& constraint) { - constraint->dependencies.emplace_back(fcc); + NotNull foo = addConstraint(scope, call->func->location, + FunctionCheckConstraint{ + fnType, + argPack, + call + } + ); + + /* + * To make bidirectional type checking work, we need to solve these constraints in a particular order: + * + * 1. Solve the function type + * 2. Propagate type information from the function type to the argument types + * 3. Solve the argument types + * 4. Solve the call + */ + + forEachConstraint(funcBeginCheckpoint, funcEndCheckpoint, this, [foo](const ConstraintPtr& constraint) { + foo->dependencies.emplace_back(constraint.get()); + }); + + forEachConstraint(argBeginCheckpoint, argEndCheckpoint, this, [foo, fcc](const ConstraintPtr& constraint) { + constraint->dependencies.emplace_back(foo); + + fcc->dependencies.emplace_back(constraint.get()); }); return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; @@ -1712,23 +1730,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantStrin if (forceSingleton) return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy) || get(expectedTy)) - { - TypeId ty = arena->addType(BlockedType{}); - TypeId singletonType = arena->addType(SingletonType(StringSingleton{std::string(string->value.data, string->value.size)})); - addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->stringType}); - return Inference{ty}; - } - else if (maybeSingleton(expectedTy)) - return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - - return Inference{builtinTypes->stringType}; - } - - return Inference{builtinTypes->stringType}; + FreeType ft = FreeType{scope.get()}; + ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); + ft.upperBound = builtinTypes->stringType; + const TypeId freeTy = arena->addType(ft); + addConstraint(scope, string->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->stringType}); + return Inference{freeTy}; } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) @@ -1737,23 +1744,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* if (forceSingleton) return Inference{singletonType}; - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - - if (get(expectedTy) || get(expectedTy)) - { - TypeId ty = arena->addType(BlockedType{}); - addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, builtinTypes->booleanType}); - return Inference{ty}; - } - else if (maybeSingleton(expectedTy)) - return Inference{singletonType}; - - return Inference{builtinTypes->booleanType}; - } - - return Inference{builtinTypes->booleanType}; + FreeType ft = FreeType{scope.get()}; + ft.lowerBound = singletonType; + ft.upperBound = builtinTypes->booleanType; + const TypeId freeTy = arena->addType(ft); + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->booleanType}); + return Inference{freeTy}; } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprLocal* local) @@ -1766,12 +1762,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprLocal* local) // if we have a refinement key, we can look up its type. if (key) - maybeTy = lookup(scope.get(), key->def); + maybeTy = lookup(scope, key->def); // if the current def doesn't have a type, we might be doing a compound assignment // and therefore might need to look at the rvalue def instead. if (!maybeTy && rvalueDef) - maybeTy = lookup(scope.get(), *rvalueDef); + maybeTy = lookup(scope, *rvalueDef); if (maybeTy) { @@ -1797,7 +1793,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* globa /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any * global that is not already in-scope is definitely an unknown symbol. */ - if (auto ty = lookup(scope.get(), def, /*prototype=*/false)) + if (auto ty = lookup(scope, def, /*prototype=*/false)) { rootScope->lvalueTypes[def] = *ty; return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; @@ -1816,7 +1812,7 @@ Inference ConstraintGenerator::checkIndexName(const ScopePtr& scope, const Refin if (key) { - if (auto ty = lookup(scope.get(), key->def)) + if (auto ty = lookup(scope, key->def)) return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; scope->rvalueRefinements[key->def] = result; @@ -1852,7 +1848,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIndexExpr* in const RefinementKey* key = dfg->getRefinementKey(indexExpr); if (key) { - if (auto ty = lookup(scope.get(), key->def)) + if (auto ty = lookup(scope, key->def)) return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; scope->rvalueRefinements[key->def] = result; @@ -2120,7 +2116,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIfElse* ifEls applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); TypeId elseType = check(elseScope, ifElse->falseExpr, expectedType).ty; - return Inference{expectedType ? *expectedType : simplifyUnion(builtinTypes, arena, thenType, elseType).result}; + return Inference{expectedType ? *expectedType : makeUnion(scope, ifElse->location, thenType, elseType)}; } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) @@ -3172,6 +3168,30 @@ void ConstraintGenerator::reportCodeTooComplex(Location location) logger->captureGenerationError(errors.back()); } +TypeId ConstraintGenerator::makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) +{ + TypeId resultType = arena->addType(TypeFamilyInstanceType{ + NotNull{&kBuiltinTypeFamilies.unionFamily}, + {lhs, rhs}, + {}, + }); + addConstraint(scope, location, ReduceConstraint{resultType}); + + return resultType; +} + +TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) +{ + TypeId resultType = arena->addType(TypeFamilyInstanceType{ + NotNull{&kBuiltinTypeFamilies.intersectFamily}, + {lhs, rhs}, + {}, + }); + addConstraint(scope, location, ReduceConstraint{resultType}); + + return resultType; +} + struct GlobalPrepopulator : AstVisitor { const NotNull globalScope; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 7430dbf3..ec0d6c8a 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -283,7 +283,8 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNullgetFreeTypes()) { // increment the reference count for `ty` - unresolvedConstraints[ty] += 1; + auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); + refCount += 1; } for (NotNull dep : c->dependencies) @@ -368,7 +369,13 @@ void ConstraintSolver::run() // decrement the referenced free types for this constraint if we dispatched successfully! for (auto ty : c->getFreeTypes()) - unresolvedConstraints[ty] -= 1; + { + // this is a little weird, but because we're only counting free types in subtyping constraints, + // some constraints (like unpack) might actually produce _more_ references to a free type. + size_t& refCount = unresolvedConstraints[ty]; + if (refCount > 0) + refCount -= 1; + } if (logger) { @@ -534,6 +541,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*taec, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); + else if (auto fcc = get(*constraint)) + success = tryDispatch(*fcc, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) @@ -992,6 +1001,15 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) + { + asMutable(c.result)->ty.emplace(builtinTypes->errorTypePack); + unblock(c.result, constraint->location); + + return true; + } + auto [argsHead, argsTail] = flatten(argsPack); bool blocked = false; @@ -1080,44 +1098,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType}; } - // We know the type of the function and the arguments it expects to receive. - // We also know the TypeIds of the actual arguments that will be passed. - // - // Bidirectional type checking: Force those TypeIds to be the expected - // arguments. If something is incoherent, we'll spot it in type checking. - // - // Most important detail: If a function argument is a lambda, we also want - // to force unannotated argument types of that lambda to be the expected - // types. - - // FIXME: Bidirectional type checking of overloaded functions is not yet supported. - if (auto ftv = get(fn)) - { - const std::vector expectedArgs = flatten(ftv->argTypes).first; - const std::vector argPackHead = flatten(argsPack).first; - - for (size_t i = 0; i < c.callSite->args.size && i < expectedArgs.size() && i < argPackHead.size(); ++i) - { - const FunctionType* expectedLambdaTy = get(follow(expectedArgs[i])); - const FunctionType* lambdaTy = get(follow(argPackHead[i])); - const AstExprFunction* lambdaExpr = c.callSite->args.data[i]->as(); - - if (expectedLambdaTy && lambdaTy && lambdaExpr) - { - const std::vector expectedLambdaArgTys = flatten(expectedLambdaTy->argTypes).first; - const std::vector lambdaArgTys = flatten(lambdaTy->argTypes).first; - - for (size_t j = 0; j < expectedLambdaArgTys.size() && j < lambdaArgTys.size() && j < lambdaExpr->args.size; ++j) - { - if (!lambdaExpr->args.data[j]->annotation && get(follow(lambdaArgTys[j]))) - { - asMutable(lambdaArgTys[j])->ty.emplace(expectedLambdaArgTys[j]); - } - } - } - } - } - TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}}; @@ -1141,17 +1121,94 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull constraint) +{ + const TypeId fn = follow(c.fn); + const TypePackId argsPack = follow(c.argsPack); + + if (isBlocked(fn)) + return block(fn, constraint); + + if (isBlocked(argsPack)) + return block(argsPack, constraint); + + // We know the type of the function and the arguments it expects to receive. + // We also know the TypeIds of the actual arguments that will be passed. + // + // Bidirectional type checking: Force those TypeIds to be the expected + // arguments. If something is incoherent, we'll spot it in type checking. + // + // Most important detail: If a function argument is a lambda, we also want + // to force unannotated argument types of that lambda to be the expected + // types. + + // FIXME: Bidirectional type checking of overloaded functions is not yet supported. + if (auto ftv = get(fn)) + { + const std::vector expectedArgs = flatten(ftv->argTypes).first; + const std::vector argPackHead = flatten(argsPack).first; + + for (size_t i = 0; i < c.callSite->args.size && i < expectedArgs.size() && i < argPackHead.size(); ++i) + { + const TypeId expectedArgTy = follow(expectedArgs[i]); + const TypeId actualArgTy = follow(argPackHead[i]); + + const FunctionType* expectedLambdaTy = get(expectedArgTy); + const FunctionType* lambdaTy = get(actualArgTy); + const AstExprFunction* lambdaExpr = c.callSite->args.data[i]->as(); + + if (expectedLambdaTy && lambdaTy && lambdaExpr) + { + const std::vector expectedLambdaArgTys = flatten(expectedLambdaTy->argTypes).first; + const std::vector lambdaArgTys = flatten(lambdaTy->argTypes).first; + + for (size_t j = 0; j < expectedLambdaArgTys.size() && j < lambdaArgTys.size() && j < lambdaExpr->args.size; ++j) + { + if (!lambdaExpr->args.data[j]->annotation && get(follow(lambdaArgTys[j]))) + { + asMutable(lambdaArgTys[j])->ty.emplace(expectedLambdaArgTys[j]); + } + } + } + else + { + Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; + u2.unify(actualArgTy, expectedArgTy); + } + } + } + + return true; +} + bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint) { - TypeId expectedType = follow(c.expectedType); - if (isBlocked(expectedType) || get(expectedType)) - return block(expectedType, constraint); + std::optional expectedType = c.expectedType ? std::make_optional(follow(*c.expectedType)) : std::nullopt; + if (expectedType && (isBlocked(*expectedType) || get(*expectedType))) + return block(*expectedType, constraint); - LUAU_ASSERT(get(c.resultType)); + const FreeType* freeType = get(follow(c.freeType)); - TypeId bindTo = maybeSingleton(expectedType) ? c.singletonType : c.multitonType; - asMutable(c.resultType)->ty.emplace(bindTo); - unblock(c.resultType, constraint->location); + // if this is no longer a free type, then we're done. + if (!freeType) + return true; + + // We will wait if there are any other references to the free type mentioned here. + // This is probably the only thing that makes this not insane to do. + if (auto refCount = unresolvedConstraints.find(c.freeType); refCount && *refCount > 1) + { + block(c.freeType, constraint); + return false; + } + + TypeId bindTo = c.primitiveType; + + if (freeType->upperBound != c.primitiveType && maybeSingleton(freeType->upperBound)) + bindTo = freeType->lowerBound; + else if (expectedType && maybeSingleton(*expectedType)) + bindTo = freeType->lowerBound; + + asMutable(c.freeType)->ty.emplace(bindTo); return true; } @@ -1163,7 +1220,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(resultType)); - if (isBlocked(subjectType) || get(subjectType)) + if (isBlocked(subjectType) || get(subjectType) || get(subjectType)) return block(subjectType, constraint); auto [blocked, result] = lookupTableProp(subjectType, c.prop, c.suppressSimplification); @@ -1599,7 +1656,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl auto unpack = [&](TypeId ty) { TypePackId variadic = arena->addTypePack(VariadicTypePack{ty}); - pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, variadic}); + pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, variadic, /* resultIsLValue */ true}); }; if (get(iteratorTy)) @@ -1639,6 +1696,23 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl { TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); unify(constraint->scope, constraint->location, c.variables, expectedVariablePack); + + auto [variableTys, variablesTail] = flatten(c.variables); + + // the local types for the indexer _should_ be all set after unification + for (TypeId ty : variableTys) + { + if (auto lt = getMutable(ty)) + { + LUAU_ASSERT(lt->blockCount > 0); + --lt->blockCount; + + LUAU_ASSERT(0 <= lt->blockCount); + + if (0 == lt->blockCount) + asMutable(ty)->ty.emplace(lt->domain); + } + } } else unpack(builtinTypes->errorType); @@ -1775,7 +1849,7 @@ bool ConstraintSolver::tryDispatchIterableFunction( modifiedNextRetHead.push_back(*it); TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail()); - auto psc = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, modifiedNextRetPack}); + auto psc = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, modifiedNextRetPack, /* resultIsLValue */ true}); inheritBlocks(constraint, psc); return true; @@ -1876,7 +1950,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa { const TypeId upperBound = follow(ft->upperBound); - if (get(upperBound)) + if (get(upperBound) || get(upperBound)) return lookupTableProp(upperBound, propName, suppressSimplification, seen); // TODO: The upper bound could be an intersection that contains suitable tables or classes. @@ -2008,46 +2082,63 @@ void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId asMutable(blockedTy)->ty.emplace(resultTy); } -void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) +bool ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { - blocked[target].push_back(constraint); + // If a set is not present for the target, construct a new DenseHashSet for it, + // else grab the address of the existing set. + NotNull> blockVec{&blocked.try_emplace(target, nullptr).first->second}; + + if (blockVec->find(constraint)) + return false; + + blockVec->insert(constraint); auto& count = blockedConstraints[constraint]; count += 1; + + return true; } void ConstraintSolver::block(NotNull target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(target.get(), constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str()); - - block_(target.get(), constraint); + if (FFlag::DebugLuauLogSolver) + printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str()); + } } bool ConstraintSolver::block(TypeId target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(follow(target), constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + if (FFlag::DebugLuauLogSolver) + printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + } - block_(follow(target), constraint); return false; } bool ConstraintSolver::block(TypePackId target, NotNull constraint) { - if (logger) - logger->pushBlock(constraint, target); + const bool newBlock = block_(target, constraint); + if (newBlock) + { + if (logger) + logger->pushBlock(constraint, target); - if (FFlag::DebugLuauLogSolver) - printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + if (FFlag::DebugLuauLogSolver) + printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + } - block_(target, constraint); return false; } @@ -2058,9 +2149,9 @@ void ConstraintSolver::inheritBlocks(NotNull source, NotNullsecond) + for (const Constraint* blockedConstraint : blockedIt->second) { - block(addition, blockedConstraint); + block(addition, NotNull{blockedConstraint}); } } } @@ -2112,9 +2203,9 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) return; // unblocked should contain a value always, because of the above check - for (NotNull unblockedConstraint : it->second) + for (const Constraint* unblockedConstraint : it->second) { - auto& count = blockedConstraints[unblockedConstraint]; + auto& count = blockedConstraints[NotNull{unblockedConstraint}]; if (FFlag::DebugLuauLogSolver) printf("Unblocking count=%d\t%s\n", int(count), toString(*unblockedConstraint, opts).c_str()); diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index b8671d21..62b574ca 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -699,6 +699,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail)); } } + else if (get(*subTail) || get(*superTail)) + // error type is fine on either side + results.push_back(SubtypingResult{true}.withBothComponent(TypePath::PackField::Tail)); else iceReporter->ice( format("Subtyping::isSubtype got unexpected type packs %s and %s", toString(*subTail).c_str(), toString(*superTail).c_str())); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index b14df311..0f520116 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1742,9 +1742,16 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return "call " + tos(c.fn) + "( " + tos(c.argsPack) + " )" + " with { result = " + tos(c.result) + " }"; } + else if constexpr (std::is_same_v) + { + return "function_check " + tos(c.fn) + " " + tos(c.argsPack); + } else if constexpr (std::is_same_v) { - return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + tos(c.multitonType); + if (c.expectedType) + return "prim " + tos(c.freeType) + "[expected: " + tos(*c.expectedType) + "] as " + tos(c.primitiveType); + else + return "prim " + tos(c.freeType) + " as " + tos(c.primitiveType); } else if constexpr (std::is_same_v) { diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index abe719e1..71a07d73 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -419,6 +419,10 @@ bool maybeSingleton(TypeId ty) for (TypeId option : utv) if (get(follow(option))) return true; + if (const IntersectionType* itv = get(ty)) + for (TypeId part : itv) + if (maybeSingleton(part)) // will i regret this? + return true; return false; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 349443ab..fdac6daf 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -261,6 +261,155 @@ struct TypeChecker2 { } + static bool allowsNoReturnValues(const TypePackId tp) + { + for (TypeId ty : tp) + { + if (!get(follow(ty))) + return false; + } + + return true; + } + + static Location getEndLocation(const AstExprFunction* function) + { + Location loc = function->location; + if (loc.begin.line != loc.end.line) + { + Position begin = loc.end; + begin.column = std::max(0u, begin.column - 3); + loc = Location(begin, 3); + } + + return loc; + } + + bool isErrorCall(const AstExprCall* call) + { + const AstExprGlobal* global = call->func->as(); + if (!global) + return false; + + if (global->name == "error") + return true; + else if (global->name == "assert") + { + // assert() will error because it is missing the first argument + if (call->args.size == 0) + return true; + + if (AstExprConstantBool* expr = call->args.data[0]->as()) + if (!expr->value) + return true; + } + + return false; + } + + bool hasBreak(AstStat* node) + { + if (AstStatBlock* stat = node->as()) + { + for (size_t i = 0; i < stat->body.size; ++i) + { + if (hasBreak(stat->body.data[i])) + return true; + } + + return false; + } + + if (node->is()) + return true; + + if (AstStatIf* stat = node->as()) + { + if (hasBreak(stat->thenbody)) + return true; + + if (stat->elsebody && hasBreak(stat->elsebody)) + return true; + + return false; + } + + return false; + } + + // returns the last statement before the block implicitly exits, or nullptr if the block does not implicitly exit + // i.e. returns nullptr if the block returns properly or never returns + const AstStat* getFallthrough(const AstStat* node) + { + if (const AstStatBlock* stat = node->as()) + { + if (stat->body.size == 0) + return stat; + + for (size_t i = 0; i < stat->body.size - 1; ++i) + { + if (getFallthrough(stat->body.data[i]) == nullptr) + return nullptr; + } + + return getFallthrough(stat->body.data[stat->body.size - 1]); + } + + if (const AstStatIf* stat = node->as()) + { + if (const AstStat* thenf = getFallthrough(stat->thenbody)) + return thenf; + + if (stat->elsebody) + { + if (const AstStat* elsef = getFallthrough(stat->elsebody)) + return elsef; + + return nullptr; + } + else + return stat; + } + + if (node->is()) + return nullptr; + + if (const AstStatExpr* stat = node->as()) + { + if (AstExprCall* call = stat->expr->as(); call && isErrorCall(call)) + return nullptr; + + return stat; + } + + if (const AstStatWhile* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) + { + if (expr->value && !hasBreak(stat->body)) + return nullptr; + } + + return node; + } + + if (const AstStatRepeat* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) + { + if (!expr->value && !hasBreak(stat->body)) + return nullptr; + } + + if (getFallthrough(stat->body) == nullptr) + return nullptr; + + return node; + } + + return node; + } + std::optional pushStack(AstNode* node) { if (Scope** scope = module->astScopes.find(node)) @@ -1723,6 +1872,10 @@ struct TypeChecker2 ++argIt; } + + bool reachesImplicitReturn = getFallthrough(fn->body) != nullptr; + if (reachesImplicitReturn && !allowsNoReturnValues(follow(inferredFtv->retTypes))) + reportError(FunctionExitsWithoutReturning{inferredFtv->retTypes}, getEndLocation(fn)); } visit(fn->body); diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index ed81b663..4903865f 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -1052,6 +1052,73 @@ TypeFamilyReductionResult refineFamilyFn(const std::vector& type return {resultTy, false, {}, {}}; } +TypeFamilyReductionResult unionFamilyFn(const std::vector& typeParams, const std::vector& packParams, NotNull ctx) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("union type family: encountered a type family instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (get(lhsTy)) // if the lhs is never, we don't need this family anymore + return {rhsTy, false, {}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + else if (get(rhsTy)) // if the rhs is never, we don't need this family anymore + return {lhsTy, false, {}, {}}; + + + SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, lhsTy, rhsTy); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + return {result.result, false, {}, {}}; +} + + +TypeFamilyReductionResult intersectFamilyFn(const std::vector& typeParams, const std::vector& packParams, NotNull ctx) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("intersect type family: encountered a type family instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId lhsTy = follow(typeParams.at(0)); + TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (get(lhsTy)) // if the lhs is never, we don't need this family anymore + return {ctx->builtins->neverType, false, {}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + else if (get(rhsTy)) // if the rhs is never, we don't need this family anymore + return {ctx->builtins->neverType, false, {}, {}}; + + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, rhsTy); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + // if the intersection simplifies to `never`, this gives us bad autocomplete. + // we'll just produce the intersection plainly instead, but this might be revisitable + // if we ever give `never` some kind of "explanation" trail. + if (get(result.result)) + { + TypeId intersection = ctx->arena->addType(IntersectionType{{lhsTy, rhsTy}}); + return {intersection, false, {}, {}}; + } + + return {result.result, false, {}, {}}; +} + // computes the keys of `ty` into `result` // `isRaw` parameter indicates whether or not we should follow __index metamethods // returns `false` if `result` should be ignored because the answer is "all strings" @@ -1262,6 +1329,8 @@ BuiltinTypeFamilies::BuiltinTypeFamilies() , leFamily{"le", leFamilyFn} , eqFamily{"eq", eqFamilyFn} , refineFamily{"refine", refineFamilyFn} + , unionFamily{"union", unionFamilyFn} + , intersectFamily{"intersect", intersectFamilyFn} , keyofFamily{"keyof", keyofFamilyFn} , rawkeyofFamily{"rawkeyof", rawkeyofFamilyFn} { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index a1247915..653beb0e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -40,6 +40,7 @@ LUAU_FASTFLAGVARIABLE(LuauLoopControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false) +LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) namespace Luau { @@ -1335,10 +1336,10 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (size_t i = 2; i < varTypes.size(); ++i) unify(nilType, varTypes[i], scope, forin.location); } - else if (isNonstrictMode()) + else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties) { for (TypeId var : varTypes) - unify(anyType, var, scope, forin.location); + unify(unknownType, var, scope, forin.location); } else { diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 6b213aea..cba2f4bb 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -56,6 +56,12 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy) if (subFree || superFree) return true; + if (auto subLocal = getMutable(subTy)) + { + subLocal->domain = mkUnion(subLocal->domain, superTy); + expandedFreeTypes[subTy].push_back(superTy); + } + auto subFn = get(subTy); auto superFn = get(superTy); if (subFn && superFn) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index d4836bb5..993116d6 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -387,7 +387,7 @@ public: AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, - bool DEPRECATED_hasEnd = false, const std::optional& argLocation = std::nullopt); + const std::optional& argLocation = std::nullopt); void visit(AstVisitor* visitor) override; @@ -406,8 +406,6 @@ public: AstName debugname; - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; std::optional argLocation; }; @@ -573,7 +571,7 @@ public: LUAU_RTTI(AstStatIf) AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, const std::optional& thenLocation, - const std::optional& elseLocation, bool DEPRECATED_hasEnd); + const std::optional& elseLocation); void visit(AstVisitor* visitor) override; @@ -585,9 +583,6 @@ public: // Active for 'elseif' as well std::optional elseLocation; - - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; }; class AstStatWhile : public AstStat @@ -595,7 +590,7 @@ class AstStatWhile : public AstStat public: LUAU_RTTI(AstStatWhile) - AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool DEPRECATED_hasEnd); + AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation); void visit(AstVisitor* visitor) override; @@ -604,9 +599,6 @@ public: bool hasDo = false; Location doLocation; - - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; }; class AstStatRepeat : public AstStat @@ -690,7 +682,7 @@ public: LUAU_RTTI(AstStatFor) AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, - const Location& doLocation, bool DEPRECATED_hasEnd); + const Location& doLocation); void visit(AstVisitor* visitor) override; @@ -702,9 +694,6 @@ public: bool hasDo = false; Location doLocation; - - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; }; class AstStatForIn : public AstStat @@ -713,7 +702,7 @@ public: LUAU_RTTI(AstStatForIn) AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, bool hasIn, - const Location& inLocation, bool hasDo, const Location& doLocation, bool DEPRECATED_hasEnd); + const Location& inLocation, bool hasDo, const Location& doLocation); void visit(AstVisitor* visitor) override; @@ -726,9 +715,6 @@ public: bool hasDo = false; Location doLocation; - - // TODO clip with FFlag::LuauClipExtraHasEndProps - bool DEPRECATED_hasEnd = false; }; class AstStatAssign : public AstStat diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 9a6ca4d7..0409a622 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -163,7 +163,7 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, bool DEPRECATED_hasEnd, + const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, const std::optional& argLocation) : AstExpr(ClassIndex(), location) , generics(generics) @@ -177,7 +177,6 @@ AstExprFunction::AstExprFunction(const Location& location, const AstArray& thenLocation, const std::optional& elseLocation, bool DEPRECATED_hasEnd) + const std::optional& thenLocation, const std::optional& elseLocation) : AstStat(ClassIndex(), location) , condition(condition) , thenbody(thenbody) , elsebody(elsebody) , thenLocation(thenLocation) , elseLocation(elseLocation) - , DEPRECATED_hasEnd(DEPRECATED_hasEnd) { } @@ -418,13 +416,12 @@ void AstStatIf::visit(AstVisitor* visitor) } } -AstStatWhile::AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool DEPRECATED_hasEnd) +AstStatWhile::AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation) : AstStat(ClassIndex(), location) , condition(condition) , body(body) , hasDo(hasDo) , doLocation(doLocation) - , DEPRECATED_hasEnd(DEPRECATED_hasEnd) { } @@ -526,7 +523,7 @@ void AstStatLocal::visit(AstVisitor* visitor) } AstStatFor::AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, - const Location& doLocation, bool DEPRECATED_hasEnd) + const Location& doLocation) : AstStat(ClassIndex(), location) , var(var) , from(from) @@ -535,7 +532,6 @@ AstStatFor::AstStatFor(const Location& location, AstLocal* var, AstExpr* from, A , body(body) , hasDo(hasDo) , doLocation(doLocation) - , DEPRECATED_hasEnd(DEPRECATED_hasEnd) { } @@ -557,7 +553,7 @@ void AstStatFor::visit(AstVisitor* visitor) } AstStatForIn::AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, - bool hasIn, const Location& inLocation, bool hasDo, const Location& doLocation, bool DEPRECATED_hasEnd) + bool hasIn, const Location& inLocation, bool hasDo, const Location& doLocation) : AstStat(ClassIndex(), location) , vars(vars) , values(values) @@ -566,7 +562,6 @@ AstStatForIn::AstStatForIn(const Location& location, const AstArray& , inLocation(inLocation) , hasDo(hasDo) , doLocation(doLocation) - , DEPRECATED_hasEnd(DEPRECATED_hasEnd) { } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 6d15e709..c4d1c65d 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -16,7 +16,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. -LUAU_FASTFLAGVARIABLE(LuauClipExtraHasEndProps, false) LUAU_FASTFLAG(LuauCheckedFunctionSyntax) LUAU_FASTFLAGVARIABLE(LuauReadWritePropertySyntax, false) @@ -373,18 +372,15 @@ AstStat* Parser::parseIf() AstStat* elsebody = nullptr; Location end = start; std::optional elseLocation; - bool DEPRECATED_hasEnd = false; if (lexer.current().type == Lexeme::ReservedElseif) { - if (FFlag::LuauClipExtraHasEndProps) - thenbody->hasEnd = true; + thenbody->hasEnd = true; unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("elseif"); elseLocation = lexer.current().location; elsebody = parseIf(); end = elsebody->location; - DEPRECATED_hasEnd = elsebody->as()->DEPRECATED_hasEnd; recursionCounter = oldRecursionCount; } else @@ -393,8 +389,7 @@ AstStat* Parser::parseIf() if (lexer.current().type == Lexeme::ReservedElse) { - if (FFlag::LuauClipExtraHasEndProps) - thenbody->hasEnd = true; + thenbody->hasEnd = true; elseLocation = lexer.current().location; matchThenElse = lexer.current(); nextLexeme(); @@ -406,21 +401,17 @@ AstStat* Parser::parseIf() end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); - DEPRECATED_hasEnd = hasEnd; - if (FFlag::LuauClipExtraHasEndProps) + if (elsebody) { - if (elsebody) - { - if (AstStatBlock* elseBlock = elsebody->as()) - elseBlock->hasEnd = hasEnd; - } - else - thenbody->hasEnd = hasEnd; + if (AstStatBlock* elseBlock = elsebody->as()) + elseBlock->hasEnd = hasEnd; } + else + thenbody->hasEnd = hasEnd; } - return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation, DEPRECATED_hasEnd); + return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation); } // while exp do block end @@ -444,10 +435,9 @@ AstStat* Parser::parseWhile() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasEnd; + body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), cond, body, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), cond, body, hasDo, matchDo.location); } // repeat block until exp @@ -467,8 +457,7 @@ AstStat* Parser::parseRepeat() functionStack.back().loopDepth--; bool hasUntil = expectMatchEndAndConsume(Lexeme::ReservedUntil, matchRepeat); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasUntil; + body->hasEnd = hasUntil; AstExpr* cond = parseExpr(); @@ -565,10 +554,9 @@ AstStat* Parser::parseFor() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasEnd; + body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location, hasEnd); + return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); } else { @@ -609,11 +597,10 @@ AstStat* Parser::parseFor() Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasEnd; + body->hasEnd = hasEnd; return allocator.alloc( - Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location, hasEnd); + Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); } } @@ -1100,11 +1087,10 @@ std::pair Parser::parseFunctionBody( Location end = lexer.current().location; bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); - if (FFlag::LuauClipExtraHasEndProps) - body->hasEnd = hasEnd; + body->hasEnd = hasEnd; return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, - functionStack.size(), debugname, typelist, varargAnnotation, hasEnd, argLocation), + functionStack.size(), debugname, typelist, varargAnnotation, argLocation), funLocal}; } diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 0ed35910..78251012 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -139,6 +139,10 @@ public: void fsqrt(RegisterA64 dst, RegisterA64 src); void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); + void ins_4s(RegisterA64 dst, uint8_t dstIndex, RegisterA64 src, uint8_t srcIndex); + void dup_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); + // Floating-point rounding and conversions void frinta(RegisterA64 dst, RegisterA64 src); void frintm(RegisterA64 dst, RegisterA64 src); @@ -207,6 +211,7 @@ private: void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0, int N = 0); void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0); void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); + void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t sizes, uint8_t op, uint8_t op2); void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index e5b82f08..0be59fb1 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -119,13 +119,18 @@ public: void vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vsubsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vsubps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vmulsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmulps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vdivsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vdivps(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vandps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vandpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vandnpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vorps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vucomisd(OperandX64 src1, OperandX64 src2); @@ -159,6 +164,9 @@ public: void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3); + void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle); + void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset); + // Run final checks bool finalize(); @@ -176,9 +184,11 @@ public: } // Constant allocation (uses rip-relative addressing) + OperandX64 i32(int32_t value); OperandX64 i64(int64_t value); OperandX64 f32(float value); OperandX64 f64(double value); + OperandX64 u32x4(uint32_t x, uint32_t y, uint32_t z, uint32_t w); OperandX64 f32x4(float x, float y, float z, float w); OperandX64 f64x2(double x, double y); OperandX64 bytes(const void* ptr, size_t size, size_t align = 8); @@ -260,6 +270,7 @@ private: std::vector