From ae459a01972a8db6e08c1dd60cc46da70648b70d Mon Sep 17 00:00:00 2001 From: aaron Date: Fri, 8 Mar 2024 16:47:53 -0800 Subject: [PATCH 1/5] Sync to upstream/release/616 (#1184) # What's Changed * Add a compiler hint to improve Luau memory allocation inlining ### New Type Solver * Added a system for recommending explicit type annotations to users in cases where we've inferred complex generic types with type families. * Marked string library functions as `@checked` for use in new non-strict mode. * Fixed a bug with new non-strict mode where we would incorrectly report arity mismatches when optional arguments were missing. * Implement an occurs check for unifications that would produce self-recursive types. * Fix a bug where overload resolution would fail when applied to non-overloaded functions. * Fix a bug that caused the subtyping check to report an error whenever a generic was instantiated in an invariant context. * Fix crash caused by `SetPropConstraint` not blocking properly. ### Native Code Generation * Implement optimization to eliminate dead stores * Optimize vector ops for X64 when the source is computed (thanks, @zeux!) * Use more efficient lowering for UNM_* (thanks, @zeux!) 
--- ### Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: David Cope Co-authored-by: Lily Brown Co-authored-by: Vyacheslav Egorov --------- Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Vighnesh Co-authored-by: Aviral Goel Co-authored-by: David Cope Co-authored-by: Lily Brown Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/BuiltinDefinitions.h | 10 +- Analysis/include/Luau/Error.h | 24 +- Analysis/include/Luau/Normalize.h | 3 + Analysis/include/Luau/OverloadResolution.h | 32 ++ Analysis/include/Luau/Subtyping.h | 18 +- Analysis/include/Luau/TypeFamily.h | 2 + .../include/Luau/TypeFamilyReductionGuesser.h | 81 +++ Analysis/include/Luau/Unifier2.h | 4 + Analysis/src/BuiltinDefinitions.cpp | 222 +++++--- Analysis/src/ConstraintGenerator.cpp | 11 +- Analysis/src/ConstraintSolver.cpp | 11 +- Analysis/src/Error.cpp | 32 ++ Analysis/src/IostreamHelpers.cpp | 9 + Analysis/src/NonStrictTypeChecker.cpp | 18 +- Analysis/src/Normalize.cpp | 19 + Analysis/src/OverloadResolution.cpp | 88 ++- Analysis/src/Scope.cpp | 5 +- Analysis/src/Simplify.cpp | 3 + Analysis/src/Subtyping.cpp | 19 +- Analysis/src/ToString.cpp | 12 +- Analysis/src/TypeChecker2.cpp | 32 +- Analysis/src/TypeFamily.cpp | 71 ++- Analysis/src/TypeFamilyReductionGuesser.cpp | 409 ++++++++++++++ Analysis/src/Unifier2.cpp | 59 ++ CodeGen/include/Luau/IrData.h | 1 + CodeGen/include/Luau/IrVisitUseDef.h | 17 +- CodeGen/include/Luau/OptimizeDeadStore.h | 16 + CodeGen/src/CodeAllocator.cpp | 8 + CodeGen/src/CodeGenLower.h | 5 + CodeGen/src/IrLoweringX64.cpp | 136 +++-- CodeGen/src/OptimizeDeadStore.cpp | 530 ++++++++++++++++++ Sources.cmake | 4 + VM/src/lmem.cpp | 5 +- tests/ClassFixture.cpp | 6 + tests/IrBuilder.test.cpp | 420 +++++++++++++- tests/IrLowering.test.cpp | 52 +- tests/NonStrictTypeChecker.test.cpp | 31 + tests/ScopedFlags.h | 2 +- tests/Simplify.test.cpp | 6 + tests/TypeFamily.test.cpp | 
23 + tests/TypeInfer.functions.test.cpp | 56 ++ tests/TypeInfer.loops.test.cpp | 13 + tests/TypeInfer.oop.test.cpp | 17 + tests/TypeInfer.operators.test.cpp | 18 + tests/TypeInfer.refinements.test.cpp | 35 ++ tests/TypeInfer.tables.test.cpp | 12 +- tests/conformance/native.lua | 20 + tools/faillist.txt | 15 - 48 files changed, 2365 insertions(+), 277 deletions(-) create mode 100644 Analysis/include/Luau/TypeFamilyReductionGuesser.h create mode 100644 Analysis/src/TypeFamilyReductionGuesser.cpp create mode 100644 CodeGen/include/Luau/OptimizeDeadStore.h create mode 100644 CodeGen/src/OptimizeDeadStore.cpp diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 6154f3d1..5622c143 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -25,19 +25,21 @@ TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t /** Small utility function for building up type definitions from C++. */ TypeId makeFunction( // Monomorphic - TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes); + TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes, + bool checked = false); TypeId makeFunction( // Polymorphic TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, - std::initializer_list paramTypes, std::initializer_list retTypes); + std::initializer_list paramTypes, std::initializer_list retTypes, bool checked = false); TypeId makeFunction( // Monomorphic TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list paramNames, - std::initializer_list retTypes); + std::initializer_list retTypes, bool checked = false); TypeId makeFunction( // Polymorphic TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, - std::initializer_list 
paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); + std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes, + bool checked = false); void attachMagicFunction(TypeId ty, MagicFunction fn); void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index d9b5f1ba..4fbb4089 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -341,6 +341,13 @@ struct UninhabitedTypeFamily bool operator==(const UninhabitedTypeFamily& rhs) const; }; +struct ExplicitFunctionAnnotationRecommended +{ + std::vector> recommendedArgs; + TypeId recommendedReturn; + bool operator==(const ExplicitFunctionAnnotationRecommended& rhs) const; +}; + struct UninhabitedTypePackFamily { TypePackId tp; @@ -416,14 +423,15 @@ struct UnexpectedTypePackInSubtyping bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; }; -using TypeErrorData = Variant; +using TypeErrorData = + Variant; struct TypeErrorSummary { diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 9458462d..9d6312a5 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -307,6 +307,9 @@ struct NormalizedType bool hasTables() const; bool hasFunctions() const; bool hasTyvars() const; + + bool isFalsy() const; + bool isTruthy() const; }; diff --git a/Analysis/include/Luau/OverloadResolution.h b/Analysis/include/Luau/OverloadResolution.h index 42256727..bf9a5d40 100644 --- a/Analysis/include/Luau/OverloadResolution.h +++ b/Analysis/include/Luau/OverloadResolution.h @@ -67,4 +67,36 @@ private: void add(Analysis analysis, TypeId ty, ErrorVec&& errors); }; +struct SolveResult +{ + enum OverloadCallResult { + Ok, + CodeTooComplex, + OccursCheckFailed, + NoMatchingOverload, + }; + + OverloadCallResult result; + std::optional typePackId; // nullopt if result != Ok + + TypeId overloadToUse = 
nullptr; + TypeId inferredTy = nullptr; + DenseHashMap> expandedFreeTypes{nullptr}; +}; + +// Helper utility, presently used for binary operator type families. +// +// Given a function and a set of arguments, select a suitable overload. +SolveResult solveFunctionCall( + NotNull arena, + NotNull builtinTypes, + NotNull normalizer, + NotNull iceReporter, + NotNull limits, + NotNull scope, + const Location& location, + TypeId fn, + TypePackId argsPack +); + } // namespace Luau diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index 8fa4b8b0..a421fc99 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -219,23 +219,7 @@ private: template TypeId makeAggregateType(const Container& container, TypeId orElse); - - std::pair handleTypeFamilyReductionResult(const TypeFamilyInstanceType* familyInstance) - { - TypeFamilyContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}}; - TypeId family = arena->addType(*familyInstance); - std::string familyString = toString(family); - FamilyGraphReductionResult result = reduceFamilies(family, {}, context, true); - ErrorVec errors; - if (result.blockedTypes.size() != 0 || result.blockedPacks.size() != 0) - { - errors.push_back(TypeError{{}, UninhabitedTypeFamily{family}}); - return {builtinTypes->neverType, errors}; - } - if (result.reducedTypes.contains(family)) - return {family, errors}; - return {builtinTypes->neverType, errors}; - } + std::pair handleTypeFamilyReductionResult(const TypeFamilyInstanceType* familyInstance); [[noreturn]] void unexpected(TypeId ty); [[noreturn]] void unexpected(TypePackId tp); diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index 389a2e00..99f4f446 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -57,6 +57,8 @@ struct TypeFamilyContext , constraint(nullptr) { } + + NotNull pushConstraint(ConstraintV&& c); }; /// Represents a 
reduction result, which may have successfully reduced the type, /// may have concretely failed to reduce the type, or may simply be stuck diff --git a/Analysis/include/Luau/TypeFamilyReductionGuesser.h b/Analysis/include/Luau/TypeFamilyReductionGuesser.h new file mode 100644 index 00000000..0092c317 --- /dev/null +++ b/Analysis/include/Luau/TypeFamilyReductionGuesser.h @@ -0,0 +1,81 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Ast.h" +#include "Luau/VecDeque.h" +#include "Luau/DenseHash.h" +#include "Luau/TypeFamily.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/Normalize.h" +#include "Luau/TypeFwd.h" +#include "Luau/VisitType.h" +#include "Luau/NotNull.h" + +namespace Luau +{ + +struct TypeFamilyReductionGuessResult +{ + std::vector> guessedFunctionAnnotations; + TypeId guessedReturnType; + bool shouldRecommendAnnotation = true; +}; + +// An Inference result for a type family is a list of types corresponding to the guessed argument types, followed by a type for the result +struct TypeFamilyInferenceResult +{ + std::vector operandInference; + TypeId familyResultInference; +}; + +struct TypeFamilyReductionGuesser +{ + // Tracks our hypothesis about what a type family reduces to + DenseHashMap familyReducesTo{nullptr}; + // Tracks our constraints on type family operands + DenseHashMap substitutable{nullptr}; + // List of instances to try progress + VecDeque toInfer; + DenseHashSet cyclicInstances{nullptr}; + + // Utilities + NotNull builtins; + NotNull normalizer; + + TypeFamilyReductionGuesser(NotNull builtins, NotNull normalizer); + + TypeFamilyReductionGuessResult guessTypeFamilyReductionForFunction(const AstExprFunction& expr, const FunctionType* ftv, TypeId retTy); + +private: + std::optional guessType(TypeId arg); + void dumpGuesses(); + + bool isNumericBinopFamily(const TypeFamilyInstanceType& 
instance); + bool isComparisonFamily(const TypeFamilyInstanceType& instance); + bool isOrAndFamily(const TypeFamilyInstanceType& instance); + bool isNotFamily(const TypeFamilyInstanceType& instance); + bool isLenFamily(const TypeFamilyInstanceType& instance); + bool isUnaryMinus(const TypeFamilyInstanceType& instance); + + // Operand is assignable if it looks like a cyclic family instance, or a generic type + bool operandIsAssignable(TypeId ty); + std::optional tryAssignOperandType(TypeId ty); + + const NormalizedType* normalize(TypeId ty); + void step(); + void infer(); + bool done(); + + bool isFunctionGenericsSaturated(const FunctionType& ftv, DenseHashSet& instanceArgs); + void inferTypeFamilySubstitutions(TypeId ty, const TypeFamilyInstanceType* instance); + TypeFamilyInferenceResult inferNumericBinopFamily(const TypeFamilyInstanceType* instance); + TypeFamilyInferenceResult inferComparisonFamily(const TypeFamilyInstanceType* instance); + TypeFamilyInferenceResult inferOrAndFamily(const TypeFamilyInstanceType* instance); + TypeFamilyInferenceResult inferNotFamily(const TypeFamilyInstanceType* instance); + TypeFamilyInferenceResult inferLenFamily(const TypeFamilyInstanceType* instance); + TypeFamilyInferenceResult inferUnaryMinusFamily(const TypeFamilyInstanceType* instance); +}; +} // namespace Luau diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index a0553a2a..6728c0f0 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -86,6 +86,10 @@ private: */ TypeId mkIntersection(TypeId left, TypeId right); + // Returns true if needle occurs within haystack already. ie if we bound + // needle to haystack, would a cyclic type result? + OccursCheckResult occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); + // Returns true if needle occurs within haystack already. ie if we bound // needle to haystack, would a cyclic TypePack result? 
OccursCheckResult occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 7a6188a6..3a7fd724 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -25,6 +25,7 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauSetMetatableOnUnionsOfTables, false); +LUAU_FASTFLAGVARIABLE(LuauMakeStringMethodsChecked, false); namespace Luau { @@ -62,26 +63,26 @@ TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t } TypeId makeFunction( - TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes) + TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes, bool checked) { - return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes); + return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes, checked); } TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, - std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list retTypes) + std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list retTypes, bool checked) { - return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes); + return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes, checked); } TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, - std::initializer_list paramNames, std::initializer_list retTypes) + std::initializer_list paramNames, std::initializer_list retTypes, bool checked) { - return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes); + return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes, checked); } TypeId makeFunction(TypeArena& arena, std::optional 
selfType, std::initializer_list generics, std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, - std::initializer_list retTypes) + std::initializer_list retTypes, bool checked) { std::vector params; if (selfType) @@ -108,6 +109,8 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi ftv.argNames.push_back(std::nullopt); } + ftv.isCheckedFunction = checked; + return arena.addType(std::move(ftv)); } @@ -289,17 +292,10 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC // declare function assert(value: T, errorMessage: string?): intersect TypeId genericT = arena.addType(GenericType{"T"}); TypeId refinedTy = arena.addType(TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.intersectFamily}, - {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, - {} - }); + NotNull{&kBuiltinTypeFamilies.intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}}); TypeId assertTy = arena.addType(FunctionType{ - {genericT}, - {}, - arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), - arena.addTypePack(TypePack{{refinedTy}}) - }); + {genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})}); addGlobalBinding(globals, "assert", assertTy, "@luau"); } @@ -773,72 +769,158 @@ TypeId makeStringMetatable(NotNull builtinTypes) const TypePackId anyTypePack = builtinTypes->anyTypePack; const TypePackId variadicTailPack = FFlag::DebugLuauDeferredConstraintResolution ? 
builtinTypes->unknownTypePack : anyTypePack; - - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; - const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - const TypePackId emptyPack = arena->addTypePack({}); const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); - const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); - const TypeId replArgType = - arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), - makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); - const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); - const TypeId gmatchFunc = - makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + if (FFlag::LuauMakeStringMethodsChecked) + { + FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; + formatFTV.magicFunction = &magicFunctionFormat; + formatFTV.isCheckedFunction = true; + const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - const TypeId matchFunc = arena->addType( - FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - 
const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); - TableType::Props stringLib = { - {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {findFunc}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {matchFunc}}, - {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, - {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, variadicTailPack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"unpack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - variadicTailPack, - })}}, - }; + const TypeId replArgType = arena->addType( + UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + makeFunction(*arena, std::nullopt, {}, {}, 
{stringType}, {}, {stringType}, /* checked */ false)}}); + const TypeId gsubFunc = + makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); + const TypeId gmatchFunc = makeFunction( + *arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + FunctionType matchFuncTy{ + arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}; + matchFuncTy.isCheckedFunction = true; + const TypeId matchFunc = arena->addType(matchFuncTy); + attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}; + findFuncTy.isCheckedFunction = true; + const TypeId findFunc = arena->addType(findFuncTy); + attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - if (TableType* ttv = getMutable(tableType)) - ttv->name = "typeof(string)"; + // string.byte : string -> number? -> number? -> ...number + FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; + stringDotByte.isCheckedFunction = true; - return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); + // string.char : .... 
number -> string + FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})}; + stringDotChar.isCheckedFunction = true; + + // string.unpack : string -> string -> number? -> ...any + FunctionType stringDotUnpack{ + arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + variadicTailPack, + }; + stringDotUnpack.isCheckedFunction = true; + + TableType::Props stringLib = { + {"byte", {arena->addType(stringDotByte)}}, + {"char", {arena->addType(stringDotChar)}}, + {"find", {findFunc}}, + {"format", {formatFn}}, // FIXME + {"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"lower", {stringToStringType}}, + {"match", {matchFunc}}, + {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}}, + {"reverse", {stringToStringType}}, + {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}}, + {"upper", {stringToStringType}}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, + {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})}, + /* checked */ true)}}, + {"pack", {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType}, variadicTailPack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"unpack", {arena->addType(stringDotUnpack)}}, + }; + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + + TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + + if (TableType* ttv = getMutable(tableType)) + ttv->name = "typeof(string)"; + + return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); + } + else + { + FunctionType 
formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; + formatFTV.magicFunction = &magicFunctionFormat; + const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); + + const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); + + const TypeId replArgType = arena->addType( + UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); + const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); + const TypeId gmatchFunc = + makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + + const TypeId matchFunc = arena->addType(FunctionType{ + arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); + attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + + const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); + attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + + TableType::Props stringLib = { + {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, + {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, + {"find", {findFunc}}, + {"format", {formatFn}}, // FIXME + 
{"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, + {"lower", {stringToStringType}}, + {"match", {matchFunc}}, + {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, + {"reverse", {stringToStringType}}, + {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, + {"upper", {stringToStringType}}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, + {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, + {"pack", {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType}, variadicTailPack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, + {"unpack", {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + variadicTailPack, + })}}, + }; + + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + + TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + + if (TableType* ttv = getMutable(tableType)) + ttv->name = "typeof(string)"; + + return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); + } } static std::optional> magicFunctionSelect( diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 7fcd4d9e..7208f2cc 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -290,8 +290,8 @@ std::optional ConstraintGenerator::lookup(const ScopePtr& scope, DefId d { if (auto found = scope->lookup(def)) return *found; - else if (phi->operands.size() == 1) - return lookup(scope, phi->operands[0], prototype); + else if (!prototype && phi->operands.size() == 1) + return lookup(scope, phi->operands.at(0), prototype); else if 
(!prototype) return std::nullopt; @@ -963,7 +963,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f DefId def = dfg->getDef(function->name); std::optional existingFunctionTy = lookup(scope, def); - if (sigFullyDefined && existingFunctionTy && get(*existingFunctionTy)) + if (existingFunctionTy && (sigFullyDefined || function->name->is()) && get(*existingFunctionTy)) asMutable(*existingFunctionTy)->ty.emplace(sig.signature); if (AstExprLocal* localName = function->name->as()) @@ -2537,7 +2537,7 @@ TypeId ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr, std::vector segmentStrings(begin(segments), end(segments)); TypeId updatedType = arena->addType(BlockedType{}); - addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), assignedTy}); + auto setC = addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), assignedTy}); TypeId prevSegmentTy = updatedType; for (size_t i = 0; i < segments.size(); ++i) @@ -2545,7 +2545,8 @@ TypeId ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr, TypeId segmentTy = arena->addType(BlockedType{}); module->astTypes[exprs[i]] = segmentTy; ValueContext ctx = i == segments.size() - 1 ? 
ValueContext::LValue : ValueContext::RValue; - addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i], ctx}); + auto hasC = addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i], ctx}); + setC->dependencies.push_back(hasC); prevSegmentTy = segmentTy; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index a1a6aa3b..00829b85 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1068,7 +1068,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, NotNull{&iceReporter}, NotNull{&limits}, c.callSite->location}; + builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location}; auto [status, overload] = resolver.selectOverload(fn, argsPack); TypeId overloadToUse = fn; if (status == OverloadResolver::Analysis::Ok) @@ -2184,14 +2184,15 @@ bool ConstraintSolver::block_(BlockedConstraintId target, NotNull> blockVec{&blocked.try_emplace(target, nullptr).first->second}; + auto [iter, inserted] = blocked.try_emplace(target, nullptr); + auto& [key, blockVec] = *iter; - if (blockVec->find(constraint)) + if (blockVec.find(constraint)) return false; - blockVec->insert(constraint); + blockVec.insert(constraint); - auto& count = blockedConstraints[constraint]; + size_t& count = blockedConstraints[constraint]; count += 1; return true; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index d1516b09..68e732e6 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -509,6 +509,26 @@ struct ErrorConverter return "Type family instance " + Luau::toString(e.ty) + " is uninhabited"; } + std::string operator()(const ExplicitFunctionAnnotationRecommended& r) const + { + std::string toReturn = toString(r.recommendedReturn); + std::string argAnnotations; + for (auto [arg, type] : r.recommendedArgs) + { + argAnnotations += arg + ": " + 
toString(type) + ", "; + } + if (argAnnotations.length() >= 2) + { + argAnnotations.pop_back(); + argAnnotations.pop_back(); + } + + if (argAnnotations.empty()) + return "Consider annotating the return with " + toReturn; + + return "Consider placing the following annotations on the arguments: " + argAnnotations + " or instead annotating the return as " + toReturn; + } + std::string operator()(const UninhabitedTypePackFamily& e) const { return "Type pack family instance " + Luau::toString(e.tp) + " is uninhabited"; @@ -883,6 +903,12 @@ bool UninhabitedTypeFamily::operator==(const UninhabitedTypeFamily& rhs) const return ty == rhs.ty; } + +bool ExplicitFunctionAnnotationRecommended::operator==(const ExplicitFunctionAnnotationRecommended& rhs) const +{ + return recommendedReturn == rhs.recommendedReturn && recommendedArgs == rhs.recommendedArgs; +} + bool UninhabitedTypePackFamily::operator==(const UninhabitedTypePackFamily& rhs) const { return tp == rhs.tp; @@ -1084,6 +1110,12 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState) e.ty = clone(e.ty); else if constexpr (std::is_same_v) e.ty = clone(e.ty); + else if constexpr (std::is_same_v) + { + e.recommendedReturn = clone(e.recommendedReturn); + for (auto [_, t] : e.recommendedArgs) + t = clone(t); + } else if constexpr (std::is_same_v) e.tp = clone(e.tp); else if constexpr (std::is_same_v) diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 54afd5d6..dd392faa 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -195,6 +195,15 @@ static void errorToString(std::ostream& stream, const T& err) stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }"; else if constexpr (std::is_same_v) stream << "UninhabitedTypeFamily { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + { + std::string recArgs = "["; + for (auto [s, t] : err.recommendedArgs) + recArgs += " " + s + ": " + toString(t); + recArgs 
+= " ]"; + stream << "ExplicitFunctionAnnotationRecommended { recommmendedReturn = '" + toString(err.recommendedReturn) + + "', recommmendedArgs = " + recArgs + "}"; + } else if constexpr (std::is_same_v) stream << "UninhabitedTypePackFamily { " << toString(err.tp) << " }"; else if constexpr (std::is_same_v) diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index 2aeb8ebd..d285bf26 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -542,11 +542,11 @@ struct NonStrictTypeChecker } } } - // For a checked function, these gotta be the same size std::string functionName = getFunctionNameAsString(*call->func).value_or(""); - if (call->args.size != argTypes.size()) + if (call->args.size > argTypes.size()) { + // We are passing more arguments than we expect, so we should error reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); return fresh; } @@ -572,6 +572,20 @@ struct NonStrictTypeChecker if (auto runTimeFailureType = willRunTimeError(arg, fresh)) reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); } + + if (call->args.size < argTypes.size()) + { + // We are passing fewer arguments than we expect + // so we need to ensure that the rest of the args are optional. 
+ bool remainingArgsOptional = true; + for (size_t i = call->args.size; i < argTypes.size(); i++) + remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]); + if (!remainingArgsOptional) + { + reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); + return fresh; + } + } } } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 62aeb145..015507cb 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -394,6 +394,25 @@ bool NormalizedType::hasTyvars() const return !tyvars.empty(); } +bool NormalizedType::isFalsy() const +{ + + bool hasAFalse = false; + if (auto singleton = get(booleans)) + { + if (auto bs = singleton->variant.get_if()) + hasAFalse = !bs->value; + } + + return (hasAFalse || hasNils()) && (!hasTops() && !hasClasses() && !hasErrors() && !hasNumbers() && !hasStrings() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars()); +} + +bool NormalizedType::isTruthy() const +{ + return !isFalsy(); +} + static bool isShallowInhabited(const NormalizedType& norm) { // This test is just a shallow check, for example it returns `true` for `{ p : never }` diff --git a/Analysis/src/OverloadResolution.cpp b/Analysis/src/OverloadResolution.cpp index 8bce3efd..c6c71baa 100644 --- a/Analysis/src/OverloadResolution.cpp +++ b/Analysis/src/OverloadResolution.cpp @@ -1,12 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/OverloadResolution.h" +#include "Luau/Instantiation2.h" #include "Luau/Subtyping.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" +#include "Luau/TypeFamily.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" -#include "Luau/TypeFamily.h" +#include "Luau/Unifier2.h" namespace Luau { @@ -26,19 +28,28 @@ OverloadResolver::OverloadResolver(NotNull builtinTypes, NotNull OverloadResolver::selectOverload(TypeId ty, TypePackId 
argsPack) { + auto tryOne = [&](TypeId f) { + if (auto ftv = get(f)) + { + SubtypingResult r = subtyping.isSubtype(argsPack, ftv->argTypes); + if (r.isSubtype) + return true; + } + + return false; + }; + TypeId t = follow(ty); + + if (tryOne(ty)) + return {Analysis::Ok, ty}; + if (auto it = get(t)) { for (TypeId component : it) { - if (auto ftv = get(component)) - { - SubtypingResult r = subtyping.isSubtype(argsPack, ftv->argTypes); - if (r.isSubtype) - return {Analysis::Ok, component}; - } - else - continue; + if (tryOne(component)) + return {Analysis::Ok, component}; } } @@ -348,4 +359,63 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors) } +SolveResult solveFunctionCall( + NotNull arena, + NotNull builtinTypes, + NotNull normalizer, + NotNull iceReporter, + NotNull limits, + NotNull scope, + const Location& location, + TypeId fn, + TypePackId argsPack +) +{ + OverloadResolver resolver{ + builtinTypes, NotNull{arena}, normalizer, scope, iceReporter, limits, location}; + auto [status, overload] = resolver.selectOverload(fn, argsPack); + TypeId overloadToUse = fn; + if (status == OverloadResolver::Analysis::Ok) + overloadToUse = overload; + else if (get(fn) || get(fn)) + { + // Nothing. 
Let's keep going + } + else + return {SolveResult::NoMatchingOverload}; + + TypePackId resultPack = arena->freshTypePack(scope); + + TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, scope.get(), argsPack, resultPack}); + Unifier2 u2{NotNull{arena}, builtinTypes, scope, iceReporter}; + + const bool occursCheckPassed = u2.unify(overloadToUse, inferredTy); + + if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty()) + { + Instantiation2 instantiation{arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions)}; + + std::optional subst = instantiation.substitute(resultPack); + + if (!subst) + return {SolveResult::CodeTooComplex}; + else + resultPack = *subst; + } + + if (!occursCheckPassed) + return {SolveResult::OccursCheckFailed}; + + SolveResult result; + result.result = SolveResult::Ok; + result.typePackId = resultPack; + + LUAU_ASSERT(overloadToUse); + result.overloadToUse = overloadToUse; + result.inferredTy = inferredTy; + result.expandedFreeTypes = std::move(u2.expandedFreeTypes); + + return result; +} + } // namespace Luau diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index a3182c0a..791167c8 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -44,8 +44,9 @@ std::optional> Scope::lookupEx(DefId def) while (true) { - TypeId* it = s->lvalueTypes.find(def); - if (it) + if (TypeId* it = s->lvalueTypes.find(def)) + return std::pair{*it, s}; + else if (TypeId* it = s->rvalueRefinements.find(def)) return std::pair{*it, s}; if (s->parent) diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index c4eb9368..931ec6d7 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -1128,6 +1128,9 @@ TypeId TypeSimplifier::intersect(TypeId left, TypeId right) left = simplify(left); right = simplify(right); + if (left == right) + return left; + if (get(left) && get(right)) return right; if (get(right) && get(left)) diff --git a/Analysis/src/Subtyping.cpp 
b/Analysis/src/Subtyping.cpp index ecdb039f..8a8b668a 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -480,7 +480,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub // tested as though it were its upper bounds. We do not yet support bounded // generics, so the upper bound is always unknown. if (auto subGeneric = get(subTy); subGeneric && subsumes(subGeneric->scope, scope)) - return isCovariantWith(env, builtinTypes->unknownType, superTy); + return isCovariantWith(env, builtinTypes->neverType, superTy); if (auto superGeneric = get(superTy); superGeneric && subsumes(superGeneric->scope, scope)) return isCovariantWith(env, subTy, builtinTypes->unknownType); @@ -1611,4 +1611,21 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) return arena->addType(T{std::vector(begin(container), end(container))}); } +std::pair Subtyping::handleTypeFamilyReductionResult(const TypeFamilyInstanceType* familyInstance) +{ + TypeFamilyContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}}; + TypeId family = arena->addType(*familyInstance); + std::string familyString = toString(family); + FamilyGraphReductionResult result = reduceFamilies(family, {}, context, true); + ErrorVec errors; + if (result.blockedTypes.size() != 0 || result.blockedPacks.size() != 0) + { + errors.push_back(TypeError{{}, UninhabitedTypeFamily{family}}); + return {builtinTypes->neverType, errors}; + } + if (result.reducedTypes.contains(family)) + return {family, errors}; + return {builtinTypes->neverType, errors}; +} + } // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index e26ed138..c4cccc8f 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -20,7 +20,6 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAGVARIABLE(LuauToStringPrettifyLocation, false) 
LUAU_FASTFLAGVARIABLE(LuauToStringSimpleCompositeTypesSingleLine, false) /* @@ -1879,15 +1878,8 @@ std::string toString(const Position& position) std::string toString(const Location& location, int offset, bool useBegin) { - if (FFlag::LuauToStringPrettifyLocation) - { - return "(" + std::to_string(location.begin.line + offset) + ", " + std::to_string(location.begin.column + offset) + ") - (" + - std::to_string(location.end.line + offset) + ", " + std::to_string(location.end.column + offset) + ")"; - } - else - { - return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; - } + return "(" + std::to_string(location.begin.line + offset) + ", " + std::to_string(location.begin.column + offset) + ") - (" + + std::to_string(location.end.line + offset) + ", " + std::to_string(location.end.column + offset) + ")"; } std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 261c578f..cbeaa1f3 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -17,6 +17,7 @@ #include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Luau/TypeFamily.h" +#include "Luau/TypeFamilyReductionGuesser.h" #include "Luau/TypeFwd.h" #include "Luau/TypePack.h" #include "Luau/TypePath.h" @@ -25,6 +26,8 @@ #include "Luau/VisitType.h" #include +#include +#include LUAU_FASTFLAG(DebugLuauMagicTypes) @@ -36,6 +39,7 @@ namespace Luau using PrintLineProc = void (*)(const std::string&); extern PrintLineProc luauPrintLine; + /* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. * TypeChecker2 uses this to maintain knowledge about which scope encloses every * given AstNode. 
@@ -1271,13 +1275,13 @@ struct TypeChecker2 { switch (shouldSuppressErrors(NotNull{&normalizer}, fnTy)) { - case ErrorSuppression::Suppress: - break; - case ErrorSuppression::NormalizationFailed: - reportError(NormalizationTooComplex{}, call->func->location); - // fallthrough intentional - case ErrorSuppression::DoNotSuppress: - reportError(OptionalValueAccess{fnTy}, call->func->location); + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, call->func->location); + // fallthrough intentional + case ErrorSuppression::DoNotSuppress: + reportError(OptionalValueAccess{fnTy}, call->func->location); } return; } @@ -1528,6 +1532,7 @@ struct TypeChecker2 functionDeclStack.push_back(inferredFnTy); const NormalizedType* normalizedFnTy = normalizer.normalize(inferredFnTy); + const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); if (!normalizedFnTy) { reportError(CodeTooComplex{}, fn->location); @@ -1622,6 +1627,19 @@ struct TypeChecker2 if (fn->returnAnnotation) visit(*fn->returnAnnotation); + // If the function type has a family annotation, we need to see if we can suggest an annotation + TypeFamilyReductionGuesser guesser{builtinTypes, NotNull{&normalizer}}; + for (TypeId retTy : inferredFtv->retTypes) + { + if (get(follow(retTy))) + { + TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy); + if (result.shouldRecommendAnnotation) + reportError( + ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, fn->location); + } + } + functionDeclStack.pop_back(); } diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 986fcacd..e5df313e 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -8,6 +8,8 @@ #include "Luau/Instantiation.h" #include "Luau/Normalize.h" #include "Luau/NotNull.h" +#include "Luau/OverloadResolution.h" 
+#include "Luau/Set.h" #include "Luau/Simplify.h" #include "Luau/Substitution.h" #include "Luau/Subtyping.h" @@ -19,7 +21,6 @@ #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" #include "Luau/VecDeque.h" -#include "Luau/Set.h" #include "Luau/VisitType.h" LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); @@ -514,6 +515,18 @@ TypeFamilyReductionResult unmFamilyFn( return {std::nullopt, true, {}, {}}; } +NotNull TypeFamilyContext::pushConstraint(ConstraintV&& c) +{ + NotNull newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c)); + + // Every constraint that is blocked on the current constraint must also be + // blocked on this new one. + if (constraint) + solver->inheritBlocks(NotNull{constraint}, newConstraint); + + return newConstraint; +} + TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx, const std::string metamethod) { @@ -526,6 +539,8 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const st TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; + // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) return {std::nullopt, false, {lhsTy}, {}}; @@ -555,11 +570,11 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const st // the necessary state to do that, even if we intend to just eat the errors. 
ErrorVec dummy; - std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, metamethod, Location{}); + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, metamethod, location); bool reversed = false; if (!mmType) { - mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, Location{}); + mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, location); reversed = true; } @@ -570,33 +585,26 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const st if (isPending(*mmType, ctx->solver)) return {std::nullopt, false, {*mmType}, {}}; - const FunctionType* mmFtv = get(*mmType); - if (!mmFtv) - return {std::nullopt, true, {}, {}}; + TypePackId argPack = ctx->arena->addTypePack({lhsTy, rhsTy}); + SolveResult solveResult; - std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); - if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; - - const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); - if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; - - std::vector inferredArgs; if (!reversed) - inferredArgs = {lhsTy, rhsTy}; + solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); else - inferredArgs = {rhsTy, lhsTy}; + { + TypePack* p = getMutable(argPack); + std::swap(p->head.front(), p->head.back()); + solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); + } - TypePackId inferredArgPack = ctx->arena->addTypePack(std::move(inferredArgs)); - Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; - if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed - - if (std::optional ret = first(instantiatedMmFtv->retTypes)) - return {*ret, false, {}, {}}; - else + 
if (!solveResult.typePackId.has_value()) return {std::nullopt, true, {}, {}}; + + TypePack extracted = extendTypePack(*ctx->arena, ctx->builtins, *solveResult.typePackId, 1); + if (extracted.head.empty()) + return {std::nullopt, true, {}, {}}; + + return {extracted.head.front(), false, {}, {}}; } TypeFamilyReductionResult addFamilyFn( @@ -855,6 +863,11 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, con TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + // Algebra Reduction Rules for comparison family functions // Note that comparing to never tells you nothing about the other operand // lt< 'a , never> -> continue @@ -875,12 +888,12 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, con asMutable(rhsTy)->ty.emplace(ctx->builtins->numberType); else if (lhsFree && get(rhsTy) == nullptr) { - auto c1 = ctx->solver->pushConstraint(ctx->scope, {}, EqualityConstraint{lhsTy, rhsTy}); + auto c1 = ctx->pushConstraint(EqualityConstraint{lhsTy, rhsTy}); const_cast(ctx->constraint)->dependencies.emplace_back(c1); } else if (rhsFree && get(lhsTy) == nullptr) { - auto c1 = ctx->solver->pushConstraint(ctx->scope, {}, EqualityConstraint{rhsTy, lhsTy}); + auto c1 = ctx->pushConstraint(EqualityConstraint{rhsTy, lhsTy}); const_cast(ctx->constraint)->dependencies.emplace_back(c1); } } @@ -890,10 +903,6 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, con rhsTy = follow(rhsTy); // check to see if both operand types are resolved enough, and wait to reduce if not - if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; - else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); const NormalizedType* 
normRhsTy = ctx->normalizer->normalize(rhsTy); diff --git a/Analysis/src/TypeFamilyReductionGuesser.cpp b/Analysis/src/TypeFamilyReductionGuesser.cpp new file mode 100644 index 00000000..50c9b4f5 --- /dev/null +++ b/Analysis/src/TypeFamilyReductionGuesser.cpp @@ -0,0 +1,409 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeFamilyReductionGuesser.h" + +#include "Luau/DenseHash.h" +#include "Luau/Normalize.h" +#include "Luau/TypeFamily.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/VecDeque.h" +#include "Luau/VisitType.h" + +#include +#include + +namespace Luau +{ +struct InstanceCollector2 : TypeOnceVisitor +{ + VecDeque tys; + VecDeque tps; + DenseHashSet cyclicInstance{nullptr}; + DenseHashSet instanceArguments{nullptr}; + + bool visit(TypeId ty, const TypeFamilyInstanceType& it) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + tys.push_front(ty); + for (auto t : it.typeArguments) + instanceArguments.insert(follow(t)); + return true; + } + + void cycle(TypeId ty) override + { + /// Detected cyclic type pack + TypeId t = follow(ty); + if (get(t)) + cyclicInstance.insert(t); + } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override + { + // TypeOnceVisitor performs a depth-first traversal in the absence of + // cycles. This means that by pushing to the front of the queue, we will + // try to reduce deeper instances first if we start with the first thing + // in the queue. 
Consider Add, number>, number>: + // we want to reduce the innermost Add instantiation + // first. + tps.push_front(tp); + return true; + } +}; + + + +TypeFamilyReductionGuesser::TypeFamilyReductionGuesser(NotNull builtins, NotNull normalizer) + : builtins(builtins) + , normalizer(normalizer) +{ +} + +bool TypeFamilyReductionGuesser::isFunctionGenericsSaturated(const FunctionType& ftv, DenseHashSet& argsUsed) +{ + bool sameSize = ftv.generics.size() == argsUsed.size(); + bool allGenericsAppear = true; + for (auto gt : ftv.generics) + allGenericsAppear = allGenericsAppear || argsUsed.contains(gt); + return sameSize && allGenericsAppear; +} + +void TypeFamilyReductionGuesser::dumpGuesses() +{ + for (auto [tf, t] : familyReducesTo) + printf("Type family %s ~~> %s\n", toString(tf).c_str(), toString(t).c_str()); + for (auto [t, t_] : substitutable) + printf("Substitute %s for %s\n", toString(t).c_str(), toString(t_).c_str()); +} + +TypeFamilyReductionGuessResult TypeFamilyReductionGuesser::guessTypeFamilyReductionForFunction( + const AstExprFunction& expr, const FunctionType* ftv, TypeId retTy) +{ + InstanceCollector2 collector; + collector.traverse(retTy); + toInfer = std::move(collector.tys); + cyclicInstances = std::move(collector.cyclicInstance); + + if (isFunctionGenericsSaturated(*ftv, collector.instanceArguments)) + return TypeFamilyReductionGuessResult{{}, nullptr, false}; + infer(); + + std::vector> results; + std::vector args; + for (TypeId t : ftv->argTypes) + args.push_back(t); + + // Submit a guess for arg types + for (size_t i = 0; i < expr.args.size; i++) + { + TypeId argTy; + AstLocal* local = expr.args.data[i]; + if (i >= args.size()) + continue; + + argTy = args[i]; + std::optional guessedType = guessType(argTy); + if (!guessedType.has_value()) + continue; + TypeId guess = follow(*guessedType); + if (get(guess)) + continue; + + results.push_back({local->name.value, guess}); + } + + // Submit a guess for return types + TypeId recommendedAnnotation; + 
std::optional guessedReturnType = guessType(retTy); + if (!guessedReturnType.has_value()) + recommendedAnnotation = builtins->unknownType; + else + recommendedAnnotation = follow(*guessedReturnType); + if (auto t = get(recommendedAnnotation)) + recommendedAnnotation = builtins->unknownType; + + toInfer.clear(); + cyclicInstances.clear(); + familyReducesTo.clear(); + substitutable.clear(); + + return TypeFamilyReductionGuessResult{results, recommendedAnnotation}; +} + +std::optional TypeFamilyReductionGuesser::guessType(TypeId arg) +{ + TypeId t = follow(arg); + if (substitutable.contains(t)) + { + TypeId subst = follow(substitutable[t]); + if (subst == t || substitutable.contains(subst)) + return subst; + else if (!get(subst)) + return subst; + else + return guessType(subst); + } + if (get(t)) + { + if (familyReducesTo.contains(t)) + return familyReducesTo[t]; + } + return {}; +} + +bool TypeFamilyReductionGuesser::isNumericBinopFamily(const TypeFamilyInstanceType& instance) +{ + return instance.family->name == "add" || instance.family->name == "sub" || instance.family->name == "mul" || instance.family->name == "div" || + instance.family->name == "idiv" || instance.family->name == "pow" || instance.family->name == "mod"; +} + +bool TypeFamilyReductionGuesser::isComparisonFamily(const TypeFamilyInstanceType& instance) +{ + return instance.family->name == "lt" || instance.family->name == "le" || instance.family->name == "eq"; +} + +bool TypeFamilyReductionGuesser::isOrAndFamily(const TypeFamilyInstanceType& instance) +{ + return instance.family->name == "or" || instance.family->name == "and"; +} + +bool TypeFamilyReductionGuesser::isNotFamily(const TypeFamilyInstanceType& instance) +{ + return instance.family->name == "not"; +} + +bool TypeFamilyReductionGuesser::isLenFamily(const TypeFamilyInstanceType& instance) +{ + return instance.family->name == "len"; +} + +bool TypeFamilyReductionGuesser::isUnaryMinus(const TypeFamilyInstanceType& instance) +{ + return 
instance.family->name == "unm"; +} + +// Operand is assignable if it looks like a cyclic family instance, or a generic type +bool TypeFamilyReductionGuesser::operandIsAssignable(TypeId ty) +{ + if (get(ty)) + return true; + if (get(ty)) + return true; + if (cyclicInstances.contains(ty)) + return true; + return false; +} + +const NormalizedType* TypeFamilyReductionGuesser::normalize(TypeId ty) +{ + return normalizer->normalize(ty); +} + + +std::optional TypeFamilyReductionGuesser::tryAssignOperandType(TypeId ty) +{ + // Because we collect innermost instances first, if we see a typefamily instance as an operand, + // We try to check if we guessed a type for it + if (auto tfit = get(ty)) + { + if (familyReducesTo.contains(ty)) + return {familyReducesTo[ty]}; + } + + // If ty is a generic, we need to check if we inferred a substitution + if (auto gt = get(ty)) + { + if (substitutable.contains(ty)) + return {substitutable[ty]}; + } + + // If we cannot substitute a type for this value, we return an empty optional + return {}; +} + +void TypeFamilyReductionGuesser::step() +{ + TypeId t = toInfer.front(); + toInfer.pop_front(); + t = follow(t); + if (auto tf = get(t)) + inferTypeFamilySubstitutions(t, tf); +} + +void TypeFamilyReductionGuesser::infer() +{ + while (!done()) + step(); +} + +bool TypeFamilyReductionGuesser::done() +{ + return toInfer.empty(); +} + +void TypeFamilyReductionGuesser::inferTypeFamilySubstitutions(TypeId ty, const TypeFamilyInstanceType* instance) +{ + + TypeFamilyInferenceResult result; + LUAU_ASSERT(instance); + // TODO: Make an inexhaustive version of this warn in the compiler? 
+ if (isNumericBinopFamily(*instance)) + result = inferNumericBinopFamily(instance); + else if (isComparisonFamily(*instance)) + result = inferComparisonFamily(instance); + else if (isOrAndFamily(*instance)) + result = inferOrAndFamily(instance); + else if (isNotFamily(*instance)) + result = inferNotFamily(instance); + else if (isLenFamily(*instance)) + result = inferLenFamily(instance); + else if (isUnaryMinus(*instance)) + result = inferUnaryMinusFamily(instance); + else + result = {{}, builtins->unknownType}; + + TypeId resultInference = follow(result.familyResultInference); + if (!familyReducesTo.contains(resultInference)) + familyReducesTo[ty] = resultInference; + + for (size_t i = 0; i < instance->typeArguments.size(); i++) + { + if (i < result.operandInference.size()) + { + TypeId arg = follow(instance->typeArguments[i]); + TypeId inference = follow(result.operandInference[i]); + if (auto tfit = get(arg)) + { + if (!familyReducesTo.contains(arg)) + familyReducesTo.try_insert(arg, inference); + } + else if (auto gt = get(arg)) + substitutable[arg] = inference; + } + } +} + +TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferNumericBinopFamily(const TypeFamilyInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 2); + TypeFamilyInferenceResult defaultNumericBinopInference{{builtins->numberType, builtins->numberType}, builtins->numberType}; + return defaultNumericBinopInference; +} + +TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferComparisonFamily(const TypeFamilyInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 2); + // Comparison families are lt/le/eq. 
+ // Heuristic: these are type functions from t -> t -> bool + + TypeId lhsTy = follow(instance->typeArguments[0]); + TypeId rhsTy = follow(instance->typeArguments[1]); + + auto comparisonInference = [&](TypeId op) -> TypeFamilyInferenceResult { + return TypeFamilyInferenceResult{{op, op}, builtins->booleanType}; + }; + + if (std::optional ty = tryAssignOperandType(lhsTy)) + lhsTy = follow(*ty); + if (std::optional ty = tryAssignOperandType(rhsTy)) + rhsTy = follow(*ty); + if (operandIsAssignable(lhsTy) && !operandIsAssignable(rhsTy)) + return comparisonInference(rhsTy); + if (operandIsAssignable(rhsTy) && !operandIsAssignable(lhsTy)) + return comparisonInference(lhsTy); + return comparisonInference(builtins->numberType); +} + +TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferOrAndFamily(const TypeFamilyInstanceType* instance) +{ + + LUAU_ASSERT(instance->typeArguments.size() == 2); + + TypeId lhsTy = follow(instance->typeArguments[0]); + TypeId rhsTy = follow(instance->typeArguments[1]); + + if (std::optional ty = tryAssignOperandType(lhsTy)) + lhsTy = follow(*ty); + if (std::optional ty = tryAssignOperandType(rhsTy)) + rhsTy = follow(*ty); + TypeFamilyInferenceResult defaultAndOrInference{{builtins->unknownType, builtins->unknownType}, builtins->booleanType}; + + const NormalizedType* lty = normalize(lhsTy); + const NormalizedType* rty = normalize(lhsTy); + bool lhsTruthy = lty ? lty->isTruthy() : false; + bool rhsTruthy = rty ? 
rty->isTruthy() : false; + // If at the end, we still don't have good substitutions, return the default type + if (instance->family->name == "or") + { + if (operandIsAssignable(lhsTy) && operandIsAssignable(rhsTy)) + return defaultAndOrInference; + if (operandIsAssignable(lhsTy)) + return TypeFamilyInferenceResult{{builtins->unknownType, rhsTy}, rhsTy}; + if (operandIsAssignable(rhsTy)) + return TypeFamilyInferenceResult{{lhsTy, builtins->unknownType}, lhsTy}; + if (lhsTruthy) + return {{lhsTy, rhsTy}, lhsTy}; + if (rhsTruthy) + return {{builtins->unknownType, rhsTy}, rhsTy}; + } + + if (instance->family->name == "and") + { + + if (operandIsAssignable(lhsTy) && operandIsAssignable(rhsTy)) + return defaultAndOrInference; + if (operandIsAssignable(lhsTy)) + return TypeFamilyInferenceResult{{}, rhsTy}; + if (operandIsAssignable(rhsTy)) + return TypeFamilyInferenceResult{{}, lhsTy}; + if (lhsTruthy) + return {{lhsTy, rhsTy}, rhsTy}; + else + return {{lhsTy, rhsTy}, lhsTy}; + } + + return defaultAndOrInference; +} + +TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferNotFamily(const TypeFamilyInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 1); + TypeId opTy = follow(instance->typeArguments[0]); + if (std::optional ty = tryAssignOperandType(opTy)) + opTy = follow(*ty); + return {{opTy}, builtins->booleanType}; +} + +TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferLenFamily(const TypeFamilyInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 1); + TypeId opTy = follow(instance->typeArguments[0]); + if (std::optional ty = tryAssignOperandType(opTy)) + opTy = follow(*ty); + return {{opTy}, builtins->numberType}; +} + +TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferUnaryMinusFamily(const TypeFamilyInstanceType* instance) +{ + LUAU_ASSERT(instance->typeArguments.size() == 1); + TypeId opTy = follow(instance->typeArguments[0]); + if (std::optional ty = tryAssignOperandType(opTy)) + opTy = 
follow(*ty); + if (isNumber(opTy)) + return {{builtins->numberType}, builtins->numberType}; + return {{builtins->unknownType}, builtins->numberType}; +} + + +} // namespace Luau diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 5faa9553..51922602 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -106,6 +106,18 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy) if (subFree && superFree) { + DenseHashSet seen{nullptr}; + if (OccursCheckResult::Fail == occursCheck(seen, subTy, superTy)) + { + asMutable(subTy)->ty.emplace(builtinTypes->errorRecoveryType()); + return false; + } + else if (OccursCheckResult::Fail == occursCheck(seen, superTy, subTy)) + { + asMutable(subTy)->ty.emplace(builtinTypes->errorRecoveryType()); + return false; + } + superFree->lowerBound = mkUnion(subFree->lowerBound, superFree->lowerBound); superFree->upperBound = mkIntersection(subFree->upperBound, superFree->upperBound); asMutable(subTy)->ty.emplace(superTy); @@ -821,6 +833,53 @@ TypeId Unifier2::mkIntersection(TypeId left, TypeId right) return simplifyIntersection(builtinTypes, arena, left, right).result; } +OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) +{ + RecursionLimiter _ra(&recursionCount, recursionLimit); + + OccursCheckResult occurrence = OccursCheckResult::Pass; + + auto check = [&](TypeId ty) { + if (occursCheck(seen, needle, ty) == OccursCheckResult::Fail) + occurrence = OccursCheckResult::Fail; + }; + + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return OccursCheckResult::Pass; + + seen.insert(haystack); + + if (get(needle)) + return OccursCheckResult::Pass; + + if (!get(needle)) + ice->ice("Expected needle to be free"); + + if (needle == haystack) + return OccursCheckResult::Fail; + + if (auto haystackFree = get(haystack)) + { + check(haystackFree->lowerBound); + check(haystackFree->upperBound); + } + else if (auto ut = get(haystack)) + { + 
for (TypeId ty : ut->options) + check(ty); + } + else if (auto it = get(haystack)) + { + for (TypeId ty : it->parts) + check(ty); + } + + return occurrence; +} + OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { needle = follow(needle); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 1d9bbc73..79d06e5a 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -398,6 +398,7 @@ enum class IrCmd : uint8_t // A, B: tag // C: block/vmexit/undef // In final x64 lowering, A can also be Rn + // When DebugLuauAbortingChecks flag is enabled, A can also be Rn // When undef is specified instead of a block, execution is aborted on check failure CHECK_TAG, diff --git a/CodeGen/include/Luau/IrVisitUseDef.h b/CodeGen/include/Luau/IrVisitUseDef.h index acff0d76..27244cb8 100644 --- a/CodeGen/include/Luau/IrVisitUseDef.h +++ b/CodeGen/include/Luau/IrVisitUseDef.h @@ -4,6 +4,8 @@ #include "Luau/Common.h" #include "Luau/IrData.h" +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores2) + namespace Luau { namespace CodeGen @@ -186,7 +188,15 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.def(inst.b); break; case IrCmd::FALLBACK_FORGPREP: - visitor.use(inst.b); + if (FFlag::LuauCodegenRemoveDeadStores2) + { + // This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use + visitor.useRange(vmRegOp(inst.b), 3); + } + else + { + visitor.use(inst.b); + } visitor.defRange(vmRegOp(inst.b), 3); break; @@ -204,6 +214,11 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.use(inst.a); break; + // After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG Rn, tag, block instructions are generated + case IrCmd::CHECK_TAG: + visitor.maybeUse(inst.a); + break; + default: // All instructions which reference registers have to be handled explicitly CODEGEN_ASSERT(inst.a.kind != 
IrOpKind::VmReg); diff --git a/CodeGen/include/Luau/OptimizeDeadStore.h b/CodeGen/include/Luau/OptimizeDeadStore.h new file mode 100644 index 00000000..45395a51 --- /dev/null +++ b/CodeGen/include/Luau/OptimizeDeadStore.h @@ -0,0 +1,16 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrData.h" + +namespace Luau +{ +namespace CodeGen +{ + +struct IrBuilder; + +void markDeadStoresInBlockChains(IrBuilder& build); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index ab623b42..ca667ae7 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -27,6 +27,10 @@ const size_t kPageSize = sysconf(_SC_PAGESIZE); #endif #endif +#ifdef __APPLE__ +extern "C" void sys_icache_invalidate(void* start, size_t len); +#endif + static size_t alignToPageSize(size_t size) { return (size + kPageSize - 1) & ~(kPageSize - 1); @@ -98,7 +102,11 @@ static void makePagesExecutable(uint8_t* mem, size_t size) static void flushInstructionCache(uint8_t* mem, size_t size) { +#ifdef __APPLE__ + sys_icache_invalidate(mem, size); +#else __builtin___clear_cache((char*)mem, (char*)mem + size); +#endif } #endif diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index c011981b..33f395ff 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -8,6 +8,7 @@ #include "Luau/IrDump.h" #include "Luau/IrUtils.h" #include "Luau/OptimizeConstProp.h" +#include "Luau/OptimizeDeadStore.h" #include "Luau/OptimizeFinalX64.h" #include "EmitCommon.h" @@ -26,6 +27,7 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores2) namespace Luau { @@ -309,6 +311,9 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& 
build, ModuleHelpers& stats->blockLinearizationStats.constPropInstructionCount += constPropInstructionCount; } } + + if (FFlag::LuauCodegenRemoveDeadStores2) + markDeadStoresInBlockChains(ir); } std::vector sortedBlocks = getSortedBlockOrder(ir.function); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index bf82be52..cfa9ba98 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -17,6 +17,7 @@ LUAU_FASTFLAG(LuauCodegenVectorTag2) LUAU_FASTFLAGVARIABLE(LuauCodegenVectorOptAnd, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenSmallerUnm, false) namespace Luau { @@ -542,18 +543,24 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); - RegisterX64 src = regOp(inst.a); - - if (inst.regX64 == src) + if (FFlag::LuauCodegenSmallerUnm) { - build.vxorpd(inst.regX64, inst.regX64, build.f64(-0.0)); + build.vxorpd(inst.regX64, regOp(inst.a), build.f64(-0.0)); } else { - build.vmovsd(inst.regX64, src, src); - build.vxorpd(inst.regX64, inst.regX64, build.f64(-0.0)); - } + RegisterX64 src = regOp(inst.a); + if (inst.regX64 == src) + { + build.vxorpd(inst.regX64, inst.regX64, build.f64(-0.0)); + } + else + { + build.vmovsd(inst.regX64, src, src); + build.vxorpd(inst.regX64, inst.regX64, build.f64(-0.0)); + } + } break; } case IrCmd::FLOOR_NUM: @@ -604,13 +611,26 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); - ScopedRegX64 tmp1{regs}; - ScopedRegX64 tmp2{regs}; + if (FFlag::LuauCodegenVectorOptAnd) + { + ScopedRegX64 tmp1{regs}; + ScopedRegX64 tmp2{regs}; - RegisterX64 tmpa = vecOp(inst.a, tmp1); - RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); + RegisterX64 tmpa = vecOp(inst.a, tmp1); + RegisterX64 tmpb = (inst.a == inst.b) ? 
tmpa : vecOp(inst.b, tmp2); - build.vaddps(inst.regX64, tmpa, tmpb); + build.vaddps(inst.regX64, tmpa, tmpb); + } + else + { + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; + + // Fourth component is the tag number which is interpreted as a denormal and has to be filtered out + build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp()); + build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp()); + build.vaddps(inst.regX64, tmp1.reg, tmp2.reg); + } if (!FFlag::LuauCodegenVectorTag2) build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp()); @@ -620,13 +640,27 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); - ScopedRegX64 tmp1{regs}; - ScopedRegX64 tmp2{regs}; + if (FFlag::LuauCodegenVectorOptAnd) + { + ScopedRegX64 tmp1{regs}; + ScopedRegX64 tmp2{regs}; - RegisterX64 tmpa = vecOp(inst.a, tmp1); - RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); + RegisterX64 tmpa = vecOp(inst.a, tmp1); + RegisterX64 tmpb = (inst.a == inst.b) ? 
tmpa : vecOp(inst.b, tmp2); + + build.vsubps(inst.regX64, tmpa, tmpb); + } + else + { + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; + + // Fourth component is the tag number which is interpreted as a denormal and has to be filtered out + build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp()); + build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp()); + build.vsubps(inst.regX64, tmp1.reg, tmp2.reg); + } - build.vsubps(inst.regX64, tmpa, tmpb); if (!FFlag::LuauCodegenVectorTag2) build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp()); break; @@ -635,13 +669,27 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); - ScopedRegX64 tmp1{regs}; - ScopedRegX64 tmp2{regs}; + if (FFlag::LuauCodegenVectorOptAnd) + { + ScopedRegX64 tmp1{regs}; + ScopedRegX64 tmp2{regs}; - RegisterX64 tmpa = vecOp(inst.a, tmp1); - RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); + RegisterX64 tmpa = vecOp(inst.a, tmp1); + RegisterX64 tmpb = (inst.a == inst.b) ? 
tmpa : vecOp(inst.b, tmp2); + + build.vmulps(inst.regX64, tmpa, tmpb); + } + else + { + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; + + // Fourth component is the tag number which is interpreted as a denormal and has to be filtered out + build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp()); + build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp()); + build.vmulps(inst.regX64, tmp1.reg, tmp2.reg); + } - build.vmulps(inst.regX64, tmpa, tmpb); if (!FFlag::LuauCodegenVectorTag2) build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp()); break; @@ -650,13 +698,27 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); - ScopedRegX64 tmp1{regs}; - ScopedRegX64 tmp2{regs}; + if (FFlag::LuauCodegenVectorOptAnd) + { + ScopedRegX64 tmp1{regs}; + ScopedRegX64 tmp2{regs}; - RegisterX64 tmpa = vecOp(inst.a, tmp1); - RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); + RegisterX64 tmpa = vecOp(inst.a, tmp1); + RegisterX64 tmpb = (inst.a == inst.b) ? 
tmpa : vecOp(inst.b, tmp2); + + build.vdivps(inst.regX64, tmpa, tmpb); + } + else + { + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; + + // Fourth component is the tag number which is interpreted as a denormal and has to be filtered out + build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp()); + build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp()); + build.vdivps(inst.regX64, tmp1.reg, tmp2.reg); + } - build.vdivps(inst.regX64, tmpa, tmpb); if (!FFlag::LuauCodegenVectorTag2) build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3); break; @@ -665,16 +727,23 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); - RegisterX64 src = regOp(inst.a); - - if (inst.regX64 == src) + if (FFlag::LuauCodegenSmallerUnm) { - build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0)); + build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0)); } else { - build.vmovsd(inst.regX64, src, src); - build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0)); + RegisterX64 src = regOp(inst.a); + + if (inst.regX64 == src) + { + build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0)); + } + else + { + build.vmovsd(inst.regX64, src, src); + build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0)); + } } if (!FFlag::LuauCodegenVectorTag2) @@ -2299,6 +2368,7 @@ OperandX64 IrLoweringX64::vectorAndMaskOp() OperandX64 IrLoweringX64::vectorOrMaskOp() { CODEGEN_ASSERT(!FFlag::LuauCodegenVectorTag2); + if (vectorOrMask.base == noreg) vectorOrMask = build.u32x4(0, 0, 0, LUA_TVECTOR); diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp new file mode 100644 index 00000000..3ea066e4 --- /dev/null +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -0,0 +1,530 @@ +// This file is part of the Luau programming language and is 
licensed under MIT License; see LICENSE.txt for details +#include "Luau/OptimizeDeadStore.h" + +#include "Luau/IrBuilder.h" +#include "Luau/IrVisitUseDef.h" +#include "Luau/IrUtils.h" + +#include + +#include "lobject.h" + +LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores2, false) +LUAU_FASTFLAG(LuauCodegenVectorTag2) + +// TODO: optimization can be improved by knowing which registers are live in at each VM exit + +namespace Luau +{ +namespace CodeGen +{ + +// Luau value structure reminder: +// [ TValue ] +// [ Value ][ Extra ][ Tag ] +// Storing individual components will not kill any previous TValue stores +// Storing TValue will kill any full store or a component store ('extra' excluded because it's rare) + +struct StoreRegInfo +{ + // Indices of the last unused store instructions + uint32_t tagInstIdx = ~0u; + uint32_t valueInstIdx = ~0u; + uint32_t tvalueInstIdx = ~0u; + + // This register might contain a GC object + bool maybeGco = false; +}; + +struct RemoveDeadStoreState +{ + RemoveDeadStoreState(IrFunction& function) + : function(function) + { + maxReg = function.proto ? 
function.proto->maxstacksize : 255; + } + + void killTagStore(StoreRegInfo& regInfo) + { + if (regInfo.tagInstIdx != ~0u) + { + kill(function, function.instructions[regInfo.tagInstIdx]); + + regInfo.tagInstIdx = ~0u; + regInfo.maybeGco = false; + } + } + + void killValueStore(StoreRegInfo& regInfo) + { + if (regInfo.valueInstIdx != ~0u) + { + kill(function, function.instructions[regInfo.valueInstIdx]); + + regInfo.valueInstIdx = ~0u; + regInfo.maybeGco = false; + } + } + + void killTValueStore(StoreRegInfo& regInfo) + { + if (regInfo.tvalueInstIdx != ~0u) + { + kill(function, function.instructions[regInfo.tvalueInstIdx]); + + regInfo.tvalueInstIdx = ~0u; + regInfo.maybeGco = false; + } + } + + // When a register value is being defined, it kills previous stores + void defReg(uint8_t reg) + { + StoreRegInfo& regInfo = info[reg]; + + // Stores to captured registers are not removed since we don't track their uses outside of function + if (function.cfg.captured.regs.test(reg)) + return; + + killTagStore(regInfo); + killValueStore(regInfo); + killTValueStore(regInfo); + } + + // When a register value is being used, we forget about the last store location to not kill them + void useReg(uint8_t reg) + { + info[reg] = StoreRegInfo{}; + } + + // When checking control flow, such as exit to fallback blocks: + // For VM exits, we keep all stores because we don't have information on what registers are live at the start of the VM assist + // For regular blocks, we check which registers are expected to be live at entry (if we have CFG information available) + void checkLiveIns(IrOp op) + { + if (op.kind == IrOpKind::VmExit) + { + clear(); + } + else if (op.kind == IrOpKind::Block) + { + if (op.index < function.cfg.in.size()) + { + const RegisterSet& in = function.cfg.in[op.index]; + + for (int i = 0; i <= maxReg; i++) + { + if (in.regs.test(i) || (in.varargSeq && i >= in.varargStart)) + useReg(i); + } + } + else + { + clear(); + } + } + else if (op.kind == IrOpKind::Undef) + { + 
// Nothing to do for a debug abort + } + else + { + CODEGEN_ASSERT(!"unexpected jump target type"); + } + } + + // When checking block terminators, any registers that are not live out can be removed by saying that a new value is being 'defined' + void checkLiveOuts(const IrBlock& block) + { + uint32_t index = function.getBlockIndex(block); + + if (index < function.cfg.out.size()) + { + const RegisterSet& out = function.cfg.out[index]; + + for (int i = 0; i <= maxReg; i++) + { + bool isOut = out.regs.test(i) || (out.varargSeq && i >= out.varargStart); + + if (!isOut) + defReg(i); + } + } + } + + // Common instruction visitor handling + void defVarargs(uint8_t varargStart) + { + for (int i = varargStart; i <= maxReg; i++) + defReg(uint8_t(i)); + } + + void useVarargs(uint8_t varargStart) + { + for (int i = varargStart; i <= maxReg; i++) + useReg(uint8_t(i)); + } + + void def(IrOp op, int offset = 0) + { + defReg(vmRegOp(op) + offset); + } + + void use(IrOp op, int offset = 0) + { + useReg(vmRegOp(op) + offset); + } + + void maybeDef(IrOp op) + { + if (op.kind == IrOpKind::VmReg) + defReg(vmRegOp(op)); + } + + void maybeUse(IrOp op) + { + if (op.kind == IrOpKind::VmReg) + useReg(vmRegOp(op)); + } + + void defRange(int start, int count) + { + if (count == -1) + { + defVarargs(start); + } + else + { + for (int i = start; i < start + count; i++) + defReg(i); + } + } + + void useRange(int start, int count) + { + if (count == -1) + { + useVarargs(start); + } + else + { + for (int i = start; i < start + count; i++) + useReg(i); + } + } + + // Required for a full visitor interface + void capture(int reg) {} + + // Full clear of the tracked information + void clear() + { + for (int i = 0; i <= maxReg; i++) + info[i] = StoreRegInfo(); + + hasGcoToClear = false; + } + + // Partial clear of information about registers that might contain a GC object + // This is used by instructions that might perform a GC assist and GC needs all pointers to be pinned to stack + void 
flushGcoRegs() + { + for (int i = 0; i <= maxReg; i++) + { + if (info[i].maybeGco) + info[i] = StoreRegInfo(); + } + + hasGcoToClear = false; + } + + IrFunction& function; + + std::array info; + int maxReg = 255; + + // Some of the registers contain values which might be a GC object + bool hasGcoToClear = false; +}; + +static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index) +{ + switch (inst.cmd) + { + case IrCmd::STORE_TAG: + if (inst.a.kind == IrOpKind::VmReg) + { + int reg = vmRegOp(inst.a); + + if (function.cfg.captured.regs.test(reg)) + return; + + StoreRegInfo& regInfo = state.info[reg]; + + state.killTagStore(regInfo); + + uint8_t tag = function.tagOp(inst.b); + + regInfo.tagInstIdx = index; + regInfo.maybeGco = isGCO(tag); + state.hasGcoToClear |= regInfo.maybeGco; + } + break; + case IrCmd::STORE_EXTRA: + // To simplify, extra field store is preserved along with all other stores made so far + if (inst.a.kind == IrOpKind::VmReg) + { + state.useReg(vmRegOp(inst.a)); + } + break; + case IrCmd::STORE_POINTER: + if (inst.a.kind == IrOpKind::VmReg) + { + int reg = vmRegOp(inst.a); + + if (function.cfg.captured.regs.test(reg)) + return; + + StoreRegInfo& regInfo = state.info[reg]; + + state.killValueStore(regInfo); + + regInfo.valueInstIdx = index; + regInfo.maybeGco = true; + state.hasGcoToClear = true; + } + break; + case IrCmd::STORE_DOUBLE: + case IrCmd::STORE_INT: + case IrCmd::STORE_VECTOR: + if (inst.a.kind == IrOpKind::VmReg) + { + int reg = vmRegOp(inst.a); + + if (function.cfg.captured.regs.test(reg)) + return; + + StoreRegInfo& regInfo = state.info[reg]; + + state.killValueStore(regInfo); + + regInfo.valueInstIdx = index; + } + break; + case IrCmd::STORE_TVALUE: + if (inst.a.kind == IrOpKind::VmReg) + { + int reg = vmRegOp(inst.a); + + if (function.cfg.captured.regs.test(reg)) + return; + + StoreRegInfo& regInfo = state.info[reg]; + + 
state.killTagStore(regInfo); + state.killValueStore(regInfo); + state.killTValueStore(regInfo); + + regInfo.tvalueInstIdx = index; + regInfo.maybeGco = true; + + // If the argument is a vector, it's not a GC object + // Note that for known boolean/number/GCO, we already optimize into STORE_SPLIT_TVALUE form + // TODO: this can be removed if TAG_VECTOR+STORE_TVALUE is replaced with STORE_SPLIT_TVALUE + if (IrInst* arg = function.asInstOp(inst.b)) + { + if (FFlag::LuauCodegenVectorTag2) + { + if (arg->cmd == IrCmd::TAG_VECTOR) + regInfo.maybeGco = false; + } + else + { + if (arg->cmd == IrCmd::ADD_VEC || arg->cmd == IrCmd::SUB_VEC || arg->cmd == IrCmd::MUL_VEC || arg->cmd == IrCmd::DIV_VEC || + arg->cmd == IrCmd::UNM_VEC) + regInfo.maybeGco = false; + } + } + + state.hasGcoToClear |= regInfo.maybeGco; + } + break; + case IrCmd::STORE_SPLIT_TVALUE: + if (inst.a.kind == IrOpKind::VmReg) + { + int reg = vmRegOp(inst.a); + + if (function.cfg.captured.regs.test(reg)) + return; + + StoreRegInfo& regInfo = state.info[reg]; + + state.killTagStore(regInfo); + state.killValueStore(regInfo); + state.killTValueStore(regInfo); + + regInfo.tvalueInstIdx = index; + regInfo.maybeGco = isGCO(function.tagOp(inst.b)); + state.hasGcoToClear |= regInfo.maybeGco; + } + break; + + // Guard checks can jump to a block which might be using some or all the values we stored + case IrCmd::CHECK_TAG: + // After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG might use a VM register + visitVmRegDefsUses(state, function, inst); + + state.checkLiveIns(inst.c); + break; + case IrCmd::TRY_NUM_TO_INDEX: + state.checkLiveIns(inst.b); + break; + case IrCmd::TRY_CALL_FASTGETTM: + state.checkLiveIns(inst.c); + break; + case IrCmd::CHECK_FASTCALL_RES: + state.checkLiveIns(inst.b); + break; + case IrCmd::CHECK_TRUTHY: + state.checkLiveIns(inst.c); + break; + case IrCmd::CHECK_READONLY: + state.checkLiveIns(inst.b); + break; + case IrCmd::CHECK_NO_METATABLE: + state.checkLiveIns(inst.b); + 
break; + case IrCmd::CHECK_SAFE_ENV: + state.checkLiveIns(inst.a); + break; + case IrCmd::CHECK_ARRAY_SIZE: + state.checkLiveIns(inst.c); + break; + case IrCmd::CHECK_SLOT_MATCH: + state.checkLiveIns(inst.c); + break; + case IrCmd::CHECK_NODE_NO_NEXT: + state.checkLiveIns(inst.b); + break; + case IrCmd::CHECK_NODE_VALUE: + state.checkLiveIns(inst.b); + break; + case IrCmd::CHECK_BUFFER_LEN: + state.checkLiveIns(inst.d); + break; + + case IrCmd::JUMP: + // Ideally, we would be able to remove stores to registers that are not live out from a block + // But during chain optimizations, we rely on data stored in the predecessor even when it's not an explicit live out + break; + case IrCmd::RETURN: + visitVmRegDefsUses(state, function, inst); + + // At the end of a function, we can kill stores to registers that are not live out + state.checkLiveOuts(block); + break; + case IrCmd::ADJUST_STACK_TO_REG: + // visitVmRegDefsUses considers adjustment as the fast call register definition point, but for dead store removal, we count the actual writes + break; + + // This group of instructions can trigger GC assist internally + // For GC to work correctly, all values containing a GCO have to be stored on stack - otherwise a live reference might be missed + case IrCmd::CMP_ANY: + case IrCmd::DO_ARITH: + case IrCmd::DO_LEN: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: + case IrCmd::GET_IMPORT: + case IrCmd::CONCAT: + case IrCmd::INTERRUPT: + case IrCmd::CHECK_GC: + case IrCmd::CALL: + case IrCmd::FORGLOOP_FALLBACK: + case IrCmd::FALLBACK_GETGLOBAL: + case IrCmd::FALLBACK_SETGLOBAL: + case IrCmd::FALLBACK_GETTABLEKS: + case IrCmd::FALLBACK_SETTABLEKS: + case IrCmd::FALLBACK_NAMECALL: + case IrCmd::FALLBACK_DUPCLOSURE: + case IrCmd::FALLBACK_FORGPREP: + if (state.hasGcoToClear) + state.flushGcoRegs(); + + visitVmRegDefsUses(state, function, inst); + break; + + default: + // Guards have to be covered explicitly + CODEGEN_ASSERT(!isNonTerminatingJump(inst.cmd)); + + 
visitVmRegDefsUses(state, function, inst); + break; + } +} + +static void markDeadStoresInBlock(IrBuilder& build, IrBlock& block, RemoveDeadStoreState& state) +{ + IrFunction& function = build.function; + + for (uint32_t index = block.start; index <= block.finish; index++) + { + CODEGEN_ASSERT(index < function.instructions.size()); + IrInst& inst = function.instructions[index]; + + markDeadStoresInInst(state, build, function, block, inst, index); + } +} + +static void markDeadStoresInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block) +{ + IrFunction& function = build.function; + + RemoveDeadStoreState state{function}; + + while (block) + { + uint32_t blockIdx = function.getBlockIndex(*block); + CODEGEN_ASSERT(!visited[blockIdx]); + visited[blockIdx] = true; + + markDeadStoresInBlock(build, *block, state); + + IrInst& termInst = function.instructions[block->finish]; + + IrBlock* nextBlock = nullptr; + + // Unconditional jump into a block with a single user (current block) allows us to continue optimization + // with the information we have gathered so far (unless we have already visited that block earlier) + if (termInst.cmd == IrCmd::JUMP && termInst.a.kind == IrOpKind::Block) + { + IrBlock& target = function.blockOp(termInst.a); + uint32_t targetIdx = function.getBlockIndex(target); + + if (target.useCount == 1 && !visited[targetIdx] && target.kind != IrBlockKind::Fallback) + nextBlock = ⌖ + } + + block = nextBlock; + } +} + +void markDeadStoresInBlockChains(IrBuilder& build) +{ + IrFunction& function = build.function; + + std::vector visited(function.blocks.size(), false); + + for (IrBlock& block : function.blocks) + { + if (block.kind == IrBlockKind::Fallback || block.kind == IrBlockKind::Dead) + continue; + + if (visited[function.getBlockIndex(block)]) + continue; + + markDeadStoresInBlockChain(build, visited, &block); + } +} + +} // namespace CodeGen +} // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index 5745f6dd..294ce7aa 100644 
--- a/Sources.cmake +++ b/Sources.cmake @@ -88,6 +88,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/Label.h CodeGen/include/Luau/OperandX64.h CodeGen/include/Luau/OptimizeConstProp.h + CodeGen/include/Luau/OptimizeDeadStore.h CodeGen/include/Luau/OptimizeFinalX64.h CodeGen/include/Luau/RegisterA64.h CodeGen/include/Luau/RegisterX64.h @@ -125,6 +126,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/lcodegen.cpp CodeGen/src/NativeState.cpp CodeGen/src/OptimizeConstProp.cpp + CodeGen/src/OptimizeDeadStore.cpp CodeGen/src/OptimizeFinalX64.cpp CodeGen/src/UnwindBuilderDwarf2.cpp CodeGen/src/UnwindBuilderWin.cpp @@ -210,6 +212,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeCheckLimits.h Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeFamily.h + Analysis/include/Luau/TypeFamilyReductionGuesser.h Analysis/include/Luau/TypeFwd.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypeOrPack.h @@ -271,6 +274,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypeChecker2.cpp Analysis/src/TypedAllocator.cpp Analysis/src/TypeFamily.cpp + Analysis/src/TypeFamilyReductionGuesser.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypeOrPack.cpp Analysis/src/TypePack.cpp diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 6628918f..d52d3794 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -260,7 +260,10 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int return page; } -static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, uint8_t sizeClass, bool storeMetadata) +// this is part of a cold path in newblock and newgcoblock +// it is marked as noinline to prevent it from being inlined into those functions +// if it is inlined, then the compiler may determine those functions are "too big" to be profitably inlined, which results in reduced performance +LUAU_NOINLINE static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, 
uint8_t sizeClass, bool storeMetadata) { if (FFlag::LuauExtendedSizeClasses) { diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index db6cd327..980df711 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -103,6 +103,12 @@ ClassFixture::ClassFixture() }; getMutable(vector2MetaType)->props = { {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, + {"__mul", { + arena.addType(IntersectionType{{ + makeFunction(arena, vector2InstanceType, {vector2InstanceType}, {vector2InstanceType}), + makeFunction(arena, vector2InstanceType, {builtinTypes->numberType}, {vector2InstanceType}), + }}) + }} }; globals.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; addGlobalBinding(globals, "Vector2", vector2Type, "@test"); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 78f51809..83d84f63 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -4,6 +4,7 @@ #include "Luau/IrDump.h" #include "Luau/IrUtils.h" #include "Luau/OptimizeConstProp.h" +#include "Luau/OptimizeDeadStore.h" #include "Luau/OptimizeFinalX64.h" #include "ScopedFlags.h" @@ -15,6 +16,10 @@ LUAU_FASTFLAG(LuauCodegenVectorTag2) using namespace Luau::CodeGen; +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores2) + +LUAU_FASTFLAG(DebugLuauAbortingChecks) + class IrBuilderFixture { public: @@ -2538,6 +2543,8 @@ bb_0: ; useCount: 0 TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation") { + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + IrOp block = build.block(IrBlockKind::Internal); IrOp followup = build.block(IrBlockKind::Internal); @@ -2560,7 +2567,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: ; successors: bb_1 -; in regs: R0, R1 +; in regs: R0, R1, R2, R3 ; out regs: R1, R2, R3 %0 = LOAD_POINTER R0 CHECK_READONLY %0, exit(1) @@ -2884,6 
+2891,65 @@ bb_1: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepImplicitUse") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp direct = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(10.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.constDouble(1.0)); + IrOp tag = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + build.inst(IrCmd::JUMP_EQ_TAG, tag, build.constTag(tnumber), direct, fallback); + + build.beginBlock(direct); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::FALLBACK_FORGPREP, build.constUint(0), build.vmReg(1), exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(3)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0 +; out regs: R0, R1, R2, R3 + STORE_DOUBLE R1, 1 + STORE_DOUBLE R2, 10 + STORE_DOUBLE R3, 1 + %3 = LOAD_TAG R0 + JUMP_EQ_TAG %3, tnumber, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; in regs: R0 + RETURN R0, 1i + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R2, R3 +; out regs: R1, R2, R3 + FALLBACK_FORGPREP 0u, R1, bb_3 + +bb_3: +; predecessors: bb_2 +; in regs: R1, R2, R3 + RETURN R1, 3i + +)"); +} + TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") { IrOp entry = build.block(IrBlockKind::Internal); @@ -3333,6 +3399,358 @@ bb_1: TEST_SUITE_END(); +TEST_SUITE_BEGIN("DeadStoreRemoval"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDoubleStore") +{ + ScopedFastFlag 
luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(2.0)); // Should remove previous store + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(1.0)); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); // Should remove previous store of different type + + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnil)); + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); // Should remove previous store + + build.inst(IrCmd::STORE_TAG, build.vmReg(4), build.constTag(tnil)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(4), build.constDouble(1.0)); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(4), build.constTag(tnumber), build.constDouble(2.0)); // Should remove two previous stores + + IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(5), build.constTag(tnil)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(5), someTv); // Should remove two previous stores + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(5)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0 + STORE_DOUBLE R1, 2 + STORE_INT R2, 4i + STORE_TAG R3, tnumber + STORE_SPLIT_TVALUE R4, tnumber, 2 + %9 = LOAD_TVALUE R0 + STORE_TVALUE R5, %9 + RETURN R1, 5i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturn") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + 
build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(4), build.constTag(tnumber), build.constDouble(2.0)); + + IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(5), someTv); + + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0 + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse1") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp somePtr = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtr); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::CALL, build.vmReg(2), build.constInt(0), build.constInt(1)); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0, R2 + %0 = LOAD_POINTER R0 + STORE_POINTER R1, %0 + STORE_TAG R1, ttable + CALL R2, 0i, 1i + RETURN R2, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse2") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp somePtrA = 
build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtrA); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::CALL, build.vmReg(2), build.constInt(0), build.constInt(1)); + IrOp somePtrB = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtrB); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Stores to pointers can be safely removed at 'return' point, but have to preserved for any GC assist trigger (such as a call) + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0, R2 + %0 = LOAD_POINTER R0 + STORE_POINTER R1, %0 + STORE_TAG R1, ttable + CALL R2, 0i, 1i + RETURN R2, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse3") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp somePtrA = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtrA); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(2)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), someTv); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Stores to pointers can be safely removed if there are no potential implicit uses by any GC assists + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0, R2 + %3 = 
LOAD_TVALUE R2 + STORE_TVALUE R1, %3 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "IgnoreFastcallAdjustment") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(-1.0)); + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(1), build.constInt(1)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_TAG R1, tnumber + ADJUST_STACK_TO_REG R1, 1i + STORE_DOUBLE R1, 1 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "JumpImplicitLiveOut") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp next = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Even though bb_0 doesn't have R1 as a live out, chain optimization 
used the knowledge of those writes happening to optimize duplicate stores + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1 + STORE_TAG R1, tnumber + STORE_DOUBLE R1, 1 + JUMP bb_1 + +bb_1: +; predecessors: bb_0 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "KeepCapturedRegisterStores") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::CAPTURE, build.vmReg(1), build.constUint(1)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::DO_ARITH, build.vmReg(0), build.vmReg(2), build.vmReg(3), build.constInt(0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(-1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::DO_ARITH, build.vmReg(1), build.vmReg(4), build.vmReg(5), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Captured registers may be modified from called user functions (plain or hidden in metamethods) + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +; captured regs: R1 + +bb_0: +; in regs: R1, R2, R3, R4, R5 + CAPTURE R1, 1u + STORE_DOUBLE R1, 1 + STORE_TAG R1, tnumber + DO_ARITH R0, R2, R3, 0i + STORE_DOUBLE R1, -1 + STORE_TAG R1, tnumber + DO_ARITH R1, R4, R5, 0i + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "AbortingChecksRequireStores") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + ScopedFastFlag debugLuauAbortingChecks{FFlag::DebugLuauAbortingChecks, true}; + + IrOp block = 
build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5)); + + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); + + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5)); + + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnil)); + + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(6)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R1, R4 + STORE_TAG R0, tnumber + STORE_DOUBLE R2, 0.5 + STORE_TAG R3, tnumber + STORE_DOUBLE R5, 0.5 + CHECK_TAG R0, tnumber, undef + STORE_TAG R0, tnil + RETURN R0, 6i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "PartialOverFullValue") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(4.0)); + build.inst( + IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(ttable), build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32))); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), build.inst(IrCmd::NEW_TABLE, build.constUint(8), build.constUint(16))); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), build.inst(IrCmd::NEW_TABLE, build.constUint(4), 
build.constUint(8))); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tstring)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(ttable)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_SPLIT_TVALUE R0, tnumber, 1 + STORE_TAG R0, ttable + RETURN R0, 1i + +)"); +} + +TEST_SUITE_END(); + TEST_SUITE_BEGIN("Dump"); TEST_CASE_FIXTURE(IrBuilderFixture, "ToDot") diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 31661711..e8140637 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -13,6 +13,7 @@ #include LUAU_FASTFLAG(LuauCodegenVectorTag2) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores2) static std::string getCodegenAssembly(const char* source) { @@ -90,6 +91,8 @@ bb_bytecode_1: TEST_CASE("VectorComponentRead") { + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + CHECK_EQ("\n" + getCodegenAssembly(R"( local function compsum(a: vector) return a.X + a.Y + a.Z @@ -104,16 +107,9 @@ bb_2: JUMP bb_bytecode_1 bb_bytecode_1: %6 = LOAD_FLOAT R0, 0i - STORE_DOUBLE R3, %6 - STORE_TAG R3, tnumber %11 = LOAD_FLOAT R0, 4i - STORE_DOUBLE R4, %11 - STORE_TAG R4, tnumber %20 = ADD_NUM %6, %11 - STORE_DOUBLE R2, %20 - STORE_TAG R2, tnumber %25 = LOAD_FLOAT R0, 8i - STORE_DOUBLE R3, %25 %34 = ADD_NUM %20, %25 STORE_DOUBLE R1, %34 STORE_TAG R1, tnumber @@ -179,6 +175,7 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv") { ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, 
d: vector) @@ -199,13 +196,9 @@ bb_bytecode_1: %14 = LOAD_TVALUE R0 %15 = LOAD_TVALUE R1 %16 = MUL_VEC %14, %15 - %17 = TAG_VECTOR %16 - STORE_TVALUE R5, %17 %23 = LOAD_TVALUE R2 %24 = LOAD_TVALUE R3 %25 = DIV_VEC %23, %24 - %26 = TAG_VECTOR %25 - STORE_TVALUE R6, %26 %34 = SUB_VEC %16, %25 %35 = TAG_VECTOR %34 STORE_TVALUE R4, %35 @@ -217,6 +210,7 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv2") { ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector) @@ -234,14 +228,8 @@ bb_2: bb_bytecode_1: %8 = LOAD_TVALUE R0 %10 = MUL_VEC %8, %8 - %11 = TAG_VECTOR %10 - STORE_TVALUE R1, %11 %19 = SUB_VEC %10, %10 - %20 = TAG_VECTOR %19 - STORE_TVALUE R3, %20 %28 = ADD_VEC %10, %10 - %29 = TAG_VECTOR %28 - STORE_TVALUE R4, %29 %37 = DIV_VEC %19, %28 %38 = TAG_VECTOR %37 STORE_TVALUE R2, %38 @@ -253,6 +241,7 @@ bb_bytecode_1: TEST_CASE("VectorMulDivMixed") { ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) @@ -273,31 +262,17 @@ bb_bytecode_1: %12 = LOAD_TVALUE R0 %13 = NUM_TO_VEC 2 %14 = MUL_VEC %12, %13 - %15 = TAG_VECTOR %14 - STORE_TVALUE R7, %15 %19 = LOAD_TVALUE R1 %20 = NUM_TO_VEC 4 %21 = DIV_VEC %19, %20 - %22 = TAG_VECTOR %21 - STORE_TVALUE R8, %22 %30 = ADD_VEC %14, %21 - %31 = TAG_VECTOR %30 - STORE_TVALUE R6, %31 - STORE_DOUBLE R8, 0.5 - STORE_TAG R8, tnumber %40 = NUM_TO_VEC 0.5 %41 = LOAD_TVALUE R2 %42 = MUL_VEC %40, %41 - %43 = TAG_VECTOR %42 - STORE_TVALUE R7, %43 %51 = ADD_VEC %30, %42 - %52 = TAG_VECTOR %51 - STORE_TVALUE R5, %52 %56 = NUM_TO_VEC 40 %57 = LOAD_TVALUE R3 %58 = DIV_VEC %56, %57 - %59 = TAG_VECTOR %58 - STORE_TVALUE R6, %59 %67 = ADD_VEC %51, %58 %68 = 
TAG_VECTOR %67 STORE_TVALUE R4, %68 @@ -308,6 +283,8 @@ bb_bytecode_1: TEST_CASE("ExtraMathMemoryOperands") { + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true}; + CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: number, b: number, c: number, d: number, e: number) return math.floor(a) + math.ceil(b) + math.round(c) + math.sqrt(d) + math.abs(e) @@ -327,26 +304,13 @@ bb_2: bb_bytecode_1: CHECK_SAFE_ENV exit(1) %16 = FLOOR_NUM R0 - STORE_DOUBLE R9, %16 - STORE_TAG R9, tnumber %23 = CEIL_NUM R1 - STORE_DOUBLE R10, %23 - STORE_TAG R10, tnumber %32 = ADD_NUM %16, %23 - STORE_DOUBLE R8, %32 - STORE_TAG R8, tnumber %39 = ROUND_NUM R2 - STORE_DOUBLE R9, %39 %48 = ADD_NUM %32, %39 - STORE_DOUBLE R7, %48 - STORE_TAG R7, tnumber %55 = SQRT_NUM R3 - STORE_DOUBLE R8, %55 %64 = ADD_NUM %48, %55 - STORE_DOUBLE R6, %64 - STORE_TAG R6, tnumber %71 = ABS_NUM R4 - STORE_DOUBLE R7, %71 %80 = ADD_NUM %64, %71 STORE_DOUBLE R5, %80 STORE_TAG R5, tnumber diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index 12ad0dde..a98c8dbd 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -89,6 +89,9 @@ declare function @checked optionalArg(x: string?) : number declare foo: { bar: @checked (number) -> number, } + +declare function @checked optionalArgsAtTheEnd1(x: string, y: number?, z: number?) 
: number +declare function @checked optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number )BUILTIN_SRC"; }; @@ -474,4 +477,32 @@ abs(3, "hi"); CHECK_EQ("foo.bar", r2->functionName); } +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "optionals_in_checked_function_can_be_omitted") +{ + CheckResult result = checkNonStrict(R"( +optionalArgsAtTheEnd1("a") +optionalArgsAtTheEnd1("a", 3) +optionalArgsAtTheEnd1("a", nil, 3) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "optionals_in_checked_function_in_middle_cannot_be_omitted") +{ + CheckResult result = checkNonStrict(R"( +optionalArgsAtTheEnd2("a", "a") -- error +optionalArgsAtTheEnd2("a", nil, "b") +optionalArgsAtTheEnd2("a", 3, "b") +optionalArgsAtTheEnd2("a", "b", "c") -- error +)"); + LUAU_REQUIRE_ERROR_COUNT(3, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(1, 27), "optionalArgsAtTheEnd2", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(4, 27), "optionalArgsAtTheEnd2", result); + auto r1 = get(result.errors[2]); + LUAU_ASSERT(r1); + CHECK_EQ(3, r1->expected); + CHECK_EQ(2, r1->actual); +} + TEST_SUITE_END(); diff --git a/tests/ScopedFlags.h b/tests/ScopedFlags.h index 13059498..beb3cc06 100644 --- a/tests/ScopedFlags.h +++ b/tests/ScopedFlags.h @@ -6,7 +6,7 @@ #include template -struct ScopedFValue +struct [[nodiscard]] ScopedFValue { private: Luau::FValue* value = nullptr; diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp index 536de39d..66f7dcd3 100644 --- a/tests/Simplify.test.cpp +++ b/tests/Simplify.test.cpp @@ -563,4 +563,10 @@ TEST_CASE_FIXTURE(SimplifyFixture, "free_type_bound_by_any_with_any") CHECK("'a | *error-type*" == intersectStr(anyTy, freeTy)); } +TEST_CASE_FIXTURE(SimplifyFixture, "bound_intersected_by_itself_should_be_itself") +{ + TypeId blocked = arena->addType(BlockedType{}); + CHECK(toString(blocked) == intersectStr(blocked, blocked)); +} + TEST_SUITE_END(); diff --git a/tests/TypeFamily.test.cpp 
b/tests/TypeFamily.test.cpp index 734ff036..8ce5d28d 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -552,6 +552,29 @@ TEST_CASE_FIXTURE(ClassFixture, "keyof_type_family_common_subset_if_union_of_dif LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(ClassFixture, "vector2_multiply_is_overloaded") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local v = Vector2.New(1, 2) + + local v2 = v * 1.5 + local v3 = v * v + local v4 = v * "Hello" -- line 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(5 == result.errors[0].location.begin.line); + CHECK(5 == result.errors[0].location.end.line); + + CHECK("Vector2" == toString(requireType("v2"))); + CHECK("Vector2" == toString(requireType("v3"))); + CHECK("mul" == toString(requireType("v4"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_rfc_example") { if (!FFlag::DebugLuauDeferredConstraintResolution) diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 1cb97ea2..fe7ff512 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2309,4 +2309,60 @@ end CHECK_EQ("(number) -> boolean", toString(requireType("odd"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + CheckResult result = check(R"( +function fib(n) + return n < 2 and 1 or fib(n-1) + fib(n-2) +end +)"); + + LUAU_REQUIRE_ERRORS(result); + auto err = get(result.errors.back()); + LUAU_ASSERT(err); + CHECK("false | number" == toString(err->recommendedReturn)); + CHECK(err->recommendedArgs.size() == 0); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + CheckResult result = check(R"( +function fib(n, u) + return (n or u) and (n < u and n + fib(n,u)) +end +)"); + + LUAU_REQUIRE_ERRORS(result); + auto err = get(result.errors.back()); + 
LUAU_ASSERT(err); + CHECK("number" == toString(err->recommendedReturn)); + CHECK(err->recommendedArgs.size() == 2); + CHECK("number" == toString(err->recommendedArgs[0].second)); + CHECK("number" == toString(err->recommendedArgs[1].second)); +} + +TEST_CASE_FIXTURE(Fixture, "local_function_fwd_decl_doesnt_crash") +{ + CheckResult result = check(R"( + local foo + + local function bar() + foo() + end + + function foo() + end + + bar() + )"); + + // This test verifies that an ICE doesn't occur, so the bulk of the test is + // just from running check above. + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index eb0a7898..7f681023 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -460,6 +460,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_while") CHECK_EQ(us->name, "a"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "trivial_ipairs_usage") +{ + CheckResult result = check(R"( + local next, t, s = ipairs({1, 2, 3}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("({number}, number) -> (number?, number)", toString(requireType("next"))); + REQUIRE_EQ("{number}", toString(requireType("t"))); + REQUIRE_EQ("number", toString(requireType("s"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_produces_integral_indices") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 6d6658da..667a1ebe 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -502,4 +502,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "promise_type_error_too_complex" * doctest::t LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "method_should_not_create_cyclic_type") +{ + ScopedFastFlag sff(FFlag::DebugLuauDeferredConstraintResolution, true); + + CheckResult result = check(R"( + local Component = {} + + function Component:__resolveUpdate(incomingState) + local oldState = self.state + 
incomingState = oldState + self.state = incomingState + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index ba3c8216..56548608 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -1474,4 +1474,22 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compare_singleton_string_to_string") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "no_infinite_expansion_of_free_type" * doctest::timeout(1.0)) +{ + ScopedFastFlag sff(FFlag::DebugLuauDeferredConstraintResolution, true); + check(R"( + local tooltip = {} + + function tooltip:Show() + local playerGui = self.Player:FindFirstChild("PlayerGui") + for _,c in ipairs(playerGui:GetChildren()) do + if c:IsA("ScreenGui") and c.DisplayOrder > self.Gui.DisplayOrder then + end + end + end + )"); + + // just type-checking this code is enough +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 7a3397ce..9a356c59 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -2022,4 +2022,39 @@ end CHECK("string" == toString(t)); } +TEST_CASE_FIXTURE(RefinementClassFixture, "mutate_prop_of_some_refined_symbol") +{ + CheckResult result = check(R"( + local function instances(): {Instance} error("") end + local function vec3(x, y, z): Vector3 error("") end + + for _, object in ipairs(instances()) do + if object:IsA("Part") then + object.Position = vec3(1, 2, 3) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "mutate_prop_of_some_refined_symbol_2") +{ + CheckResult result = check(R"( + type Result = never + | { tag: "ok", value: T } + | { tag: "err", error: E } + + local function results(): {Result} error("") end + + for _, res in ipairs(results()) do + if res.tag == "ok" then + res.value = 7 + end + end + )"); + + 
LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 6d300769..7ac67935 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4171,17 +4171,9 @@ TEST_CASE_FIXTURE(Fixture, "table_writes_introduce_write_properties") LUAU_REQUIRE_NO_ERRORS(result); - CHECK("({{ read Character: a }}, { Character: t1 }) -> () " + CHECK("({{ read Character: t1 }}, { Character: t1 }) -> () " "where " - "t1 = a & { read FindFirstChild: (t1, string) -> (b, c...) }" == toString(requireType("oc"))); - -// We currently get -// ({{ read Character: a }}, { Character: t1 }) -> () where t1 = { read FindFirstChild: (t1, string) -> (b, c...) } - -// But we'd like to see -// ({{ read Character: t1 }}, { Character: t1 }) -> () where t1 = { read FindFirstChild: (t1, string) -> (a, b...) } - -// The type of speaker.Character should be the same as player[1].Character + "t1 = { read FindFirstChild: (t1, string) -> (a, b...) 
}" == toString(requireType("oc"))); } TEST_SUITE_END(); diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua index 6ebe85d6..17064213 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.lua @@ -407,4 +407,24 @@ end bufferbounds(0) +function deadStoreChecks1() + local a = 1.0 + local b = 0.0 + + local function update() + b += a + for i = 1, 100 do print(`{b} is {b}`) end + end + + update() + a = 10 + update() + a = 100 + update() + + return b +end + +assert(deadStoreChecks1() == 111) + return('OK') diff --git a/tools/faillist.txt b/tools/faillist.txt index 008f0e9c..c4c06cc1 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -11,7 +11,6 @@ BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_th BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy BuiltinTests.bad_select_should_not_crash BuiltinTests.coroutine_resume_anything_goes -BuiltinTests.global_singleton_types_are_sealed BuiltinTests.gmatch_capture_types BuiltinTests.gmatch_capture_types2 BuiltinTests.gmatch_capture_types_balanced_escaped_parens @@ -20,26 +19,18 @@ BuiltinTests.gmatch_capture_types_parens_in_sets_are_ignored BuiltinTests.gmatch_capture_types_set_containing_lbracket BuiltinTests.gmatch_definition BuiltinTests.ipairs_iterator_should_infer_types_and_type_check -BuiltinTests.next_iterator_should_infer_types_and_type_check BuiltinTests.os_time_takes_optional_date_table -BuiltinTests.pairs_iterator_should_infer_types_and_type_check BuiltinTests.select_slightly_out_of_range BuiltinTests.select_way_out_of_range BuiltinTests.select_with_variadic_typepack_tail_and_string_head BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types -BuiltinTests.sort BuiltinTests.sort_with_bad_predicate -BuiltinTests.sort_with_predicate BuiltinTests.string_format_as_method BuiltinTests.string_format_correctly_ordered_types 
BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument2 -BuiltinTests.table_concat_returns_string -BuiltinTests.table_dot_remove_optionally_returns_generic BuiltinTests.table_freeze_is_generic -BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload -BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload BuiltinTests.tonumber_returns_optional_number_type ControlFlowAnalysis.if_not_x_break_elif_not_y_break ControlFlowAnalysis.if_not_x_break_elif_not_y_continue @@ -98,12 +89,10 @@ GenericsTests.generic_type_pack_unification1 GenericsTests.generic_type_pack_unification2 GenericsTests.generic_type_pack_unification3 GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments -GenericsTests.hof_subtype_instantiation_regression GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument_2 GenericsTests.infer_generic_function_function_argument_3 GenericsTests.infer_generic_function_function_argument_overloaded -GenericsTests.infer_generic_lib_function_function_argument GenericsTests.instantiated_function_argument_names GenericsTests.mutable_state_polymorphism GenericsTests.no_stack_overflow_from_quantifying @@ -408,7 +397,6 @@ TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_anonymous_function_arguments_outside_call TypeInferFunctions.infer_generic_function_function_argument TypeInferFunctions.infer_generic_function_function_argument_overloaded -TypeInferFunctions.infer_generic_lib_function_function_argument TypeInferFunctions.infer_return_type_from_selected_overload TypeInferFunctions.infer_return_value_type TypeInferFunctions.inferred_higher_order_functions_are_quantified_at_the_right_time3 @@ -432,7 +420,6 @@ TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function 
TypeInferLoops.cli_68448_iterators_need_not_accept_nil -TypeInferLoops.dcr_iteration_explore_raycast_minimization TypeInferLoops.dcr_iteration_fragmented_keys TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.dcr_xpath_candidates @@ -443,7 +430,6 @@ TypeInferLoops.for_in_loop_on_error TypeInferLoops.for_in_loop_on_non_function TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_an_iterator_of_type_any -TypeInferLoops.for_in_with_generic_next TypeInferLoops.for_loop TypeInferLoops.ipairs_produces_integral_indices TypeInferLoops.iterate_over_free_table @@ -486,7 +472,6 @@ TypeInferOperators.error_on_invalid_operand_types_to_relational_operators2 TypeInferOperators.luau_polyfill_is_array TypeInferOperators.mm_comparisons_must_return_a_boolean TypeInferOperators.operator_eq_verifies_types_do_intersect -TypeInferOperators.reducing_and TypeInferOperators.refine_and_or TypeInferOperators.reworked_and TypeInferOperators.reworked_or From 209fd506c9017b9e252408834811c8f5c5529158 Mon Sep 17 00:00:00 2001 From: Maxwell Ruben <89617289+mxruben@users.noreply.github.com> Date: Mon, 11 Mar 2024 08:28:40 -0400 Subject: [PATCH 2/5] Fix REPL help message formatting (#1186) The last line of the help message was missing a newline character. I feel a little silly creating a pull request for a 2 character change but it was bothering me. 
Fixes #1185 --- CLI/Repl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 84e4a654..d1122ae6 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -726,7 +726,7 @@ static void displayHelp(const char* argv0) printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); printf(" --codegen: execute code using native code generation\n"); - printf(" --program-args,-a: declare start of arguments to be passed to the Luau program"); + printf(" --program-args,-a: declare start of arguments to be passed to the Luau program\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) From 9aa82c6fb90e1dcd6e7f60626255d597ef0fdea1 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Wed, 13 Mar 2024 12:56:11 -0700 Subject: [PATCH 3/5] CodeGen: Improve lowering of NUM_TO_VEC on A64 for constants (#1194) When the input is a constant, we use a fairly inefficient sequence of fmov+fcvt+dup or, when the double isn't encodable in fmov, adr+ldr+fcvt+dup. Instead, we can use the same lowering as X64 when the input is a constant, and load the vector from memory. However, if the constant is encodable via fmov, we can use a vector fmov instead (which is just one instruction and doesn't need constant space). Fortunately the bit encoding of fmov for 32-bit floating point numbers matches that of 64-bit: the decoding algorithm is a little different because it expands into a larger exponent, but the values are compatible, so if a double can be encoded into a scalar fmov with a given abcdefgh pattern, the same pattern should encode the same float; due to the very limited number of mantissa and exponent bits, all values that are encodable are also exact in both 32-bit and 64-bit floats. This strategy is ~same as what gcc uses. 
For complex vectors, we previously used 4 instructions and 8 bytes of constant storage, and now we use 2 instructions and 16 bytes of constant storage, so the memory footprint is the same; for simple vectors we just need 1 instruction (4 bytes). clang lowers vector constants a little differently, opting to synthesize a 64-bit integer using 4 instructions (mov/movk) and then move it to the vector register - this requires 5 instructions and 20 bytes, vs ours/gcc 2 instructions and 8+16=24 bytes. I tried a simpler version of this that would be more compact - synthesize a 32-bit integer constant with mov+movk, and move it to vector register via dup.4s - but this was a little slower on M2, so for now we prefer the slightly larger version as it's not a regression vs current implementation. On the vector approximation benchmark we get: - Before this PR (flag=false): ~7.85 ns/op - After this PR (flag=true): ~7.74 ns/op - After this PR, with 0.125 instead of 0.123 in the benchmark code (to use fmov): ~7.52 ns/op - Not part of this PR, but the mov/dup strategy described above: ~8.00 ns/op --- CodeGen/include/Luau/AssemblyBuilderA64.h | 5 +-- CodeGen/src/AssemblyBuilderA64.cpp | 20 ++++++++--- CodeGen/src/IrLoweringA64.cpp | 44 +++++++++++++++++------ tests/AssemblyBuilderA64.test.cpp | 6 ++++ tests/conformance/vector.lua | 6 ++++ 5 files changed, 64 insertions(+), 17 deletions(-) diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index bea70fd0..a4d857a4 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -125,12 +125,12 @@ public: // Address of code (label) void adr(RegisterA64 dst, Label& label); - // Floating-point scalar moves + // Floating-point scalar/vector moves // Note: constant must be compatible with immediate floating point moves (see isFmovSupported) void fmov(RegisterA64 dst, RegisterA64 src); void fmov(RegisterA64 dst, double src); - // Floating-point scalar 
math + // Floating-point scalar/vector math void fabs(RegisterA64 dst, RegisterA64 src); void fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); @@ -139,6 +139,7 @@ public: void fsqrt(RegisterA64 dst, RegisterA64 src); void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + // Vector component manipulation void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); void ins_4s(RegisterA64 dst, uint8_t dstIndex, RegisterA64 src, uint8_t srcIndex); void dup_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index ffb0a774..9d0522c0 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -557,16 +557,26 @@ void AssemblyBuilderA64::fmov(RegisterA64 dst, RegisterA64 src) void AssemblyBuilderA64::fmov(RegisterA64 dst, double src) { - CODEGEN_ASSERT(dst.kind == KindA64::d); + CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::q); int imm = getFmovImm(src); CODEGEN_ASSERT(imm >= 0 && imm <= 256); - // fmov can't encode 0, but movi can; movi is otherwise not useful for 64-bit fp immediates because it encodes repeating patterns - if (imm == 256) - placeFMOV("movi", dst, src, 0b001'0111100000'000'1110'01'00000); + // fmov can't encode 0, but movi can; movi is otherwise not useful for fp immediates because it encodes repeating patterns + if (dst.kind == KindA64::d) + { + if (imm == 256) + placeFMOV("movi", dst, src, 0b001'0111100000'000'1110'01'00000); + else + placeFMOV("fmov", dst, src, 0b000'11110'01'1'00000000'100'00000 | (imm << 8)); + } else - placeFMOV("fmov", dst, src, 0b000'11110'01'1'00000000'100'00000 | (imm << 8)); + { + if (imm == 256) + placeFMOV("movi.4s", dst, src, 0b010'0111100000'000'0000'01'00000); + else + placeFMOV("fmov.4s", dst, src, 0b010'0111100000'000'1111'0'1'00000 | ((imm >> 5) << 11) | (imm & 31)); + } } void 
AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src) diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 2a296949..284cef4d 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -12,6 +12,7 @@ #include "lgc.h" LUAU_FASTFLAGVARIABLE(LuauCodeGenVectorA64, false) +LUAU_FASTFLAGVARIABLE(LuauCodeGenOptVecA64, false) LUAU_FASTFLAG(LuauCodegenVectorTag2) @@ -1176,17 +1177,40 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regA64 = regs.allocReg(KindA64::q, index); - RegisterA64 tempd = tempDouble(inst.a); - RegisterA64 temps = castReg(KindA64::s, tempd); - RegisterA64 tempw = regs.allocTemp(KindA64::w); - - build.fcvt(temps, tempd); - build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0); - - if (!FFlag::LuauCodegenVectorTag2) + if (FFlag::LuauCodeGenOptVecA64 && FFlag::LuauCodegenVectorTag2 && inst.a.kind == IrOpKind::Constant) { - build.mov(tempw, LUA_TVECTOR); - build.ins_4s(inst.regA64, tempw, 3); + float value = float(doubleOp(inst.a)); + uint32_t asU32; + static_assert(sizeof(asU32) == sizeof(value), "Expecting float to be 32-bit"); + memcpy(&asU32, &value, sizeof(value)); + + if (AssemblyBuilderA64::isFmovSupported(value)) + { + build.fmov(inst.regA64, value); + } + else + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + + uint32_t vec[4] = { asU32, asU32, asU32, 0 }; + build.adr(temp, vec, sizeof(vec)); + build.ldr(inst.regA64, temp); + } + } + else + { + RegisterA64 tempd = tempDouble(inst.a); + RegisterA64 temps = castReg(KindA64::s, tempd); + RegisterA64 tempw = regs.allocTemp(KindA64::w); + + build.fcvt(temps, tempd); + build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0); + + if (!FFlag::LuauCodegenVectorTag2) + { + build.mov(tempw, LUA_TVECTOR); + build.ins_4s(inst.regA64, tempw, 3); + } } break; } diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index 320a7a6a..3942003b 100644 --- 
a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -451,6 +451,12 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPImm") SINGLE_COMPARE(fmov(d0, 0), 0x2F00E400); SINGLE_COMPARE(fmov(d0, 0.125), 0x1E681000); SINGLE_COMPARE(fmov(d0, -0.125), 0x1E781000); + SINGLE_COMPARE(fmov(d0, 1.9375), 0x1E6FF000); + + SINGLE_COMPARE(fmov(q0, 0), 0x4F000400); + SINGLE_COMPARE(fmov(q0, 0.125), 0x4F02F400); + SINGLE_COMPARE(fmov(q0, -0.125), 0x4F06F400); + SINGLE_COMPARE(fmov(q0, 1.9375), 0x4F03F7E0); CHECK(!AssemblyBuilderA64::isFmovSupported(-0.0)); CHECK(!AssemblyBuilderA64::isFmovSupported(0.12389)); diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index c9cc47aa..9be88f69 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -51,6 +51,12 @@ assert(8 * vector(8, 16, 24) == vector(64, 128, 192)); assert(vector(1, 2, 4) * '8' == vector(8, 16, 32)); assert('8' * vector(8, 16, 24) == vector(64, 128, 192)); +assert(vector(1, 2, 4) * -0.125 == vector(-0.125, -0.25, -0.5)) +assert(-0.125 * vector(1, 2, 4) == vector(-0.125, -0.25, -0.5)) + +assert(vector(1, 2, 4) * 100 == vector(100, 200, 400)) +assert(100 * vector(1, 2, 4) == vector(100, 200, 400)) + if vector_size == 4 then assert(vector(1, 2, 4, 8) / vector(8, 16, 24, 32) == vector(1/8, 2/16, 4/24, 8/32)); assert(8 / vector(8, 16, 24, 32) == vector(1, 1/2, 1/3, 1/4)); From d2ed2150ca4c6dc84e77cbae930f6233354d9961 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 15 Mar 2024 09:32:27 -0700 Subject: [PATCH 4/5] Work around ASLR+ASAN compatibility issues in GHA (#1203) vm.mmap_rnd_bits has been recently changed to 32 on GHA, which triggers issues in ASAN builds that spuriously fail on startup. The fix requires a more recent clang/gcc than the agents have available (clang 17, not sure what GCC version), so for now we need to work around this by restricting the ASLR randomness. 
See https://github.com/google/sanitizers/issues/1614 --- .github/workflows/build.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e6a01255..7a2b5f10 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,6 +25,9 @@ jobs: runs-on: ${{matrix.os.version}} steps: - uses: actions/checkout@v1 + - name: work around ASLR+ASAN compatibility + run: sudo sysctl -w vm.mmap_rnd_bits=28 + if: matrix.os.name == 'ubuntu' - name: make tests run: | make -j2 config=sanitize werror=1 native=1 luau-tests From a7683110d71a15bfc823688191476d4c822565cf Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 15 Mar 2024 09:49:00 -0700 Subject: [PATCH 5/5] CodeGen: Preserve known tags for LOAD_TVALUE synthesized from LOADK (#1201) When lowering LOADK for booleans/numbers/nils, we deconstruct the operation using STORE_TAG which informs the rest of the optimization pipeline about the tag of the value. This is helpful to remove various tag checks. When the constant is a string or a vector, we just use LOAD_TVALUE/STORE_TVALUE. For strings, this could be replaced by pointer load/store, but for vectors there's no great alternative using current IR ops; in either case, the optimization needs to be carefully examined for profitability as simply copying constants into registers for function calls could become more expensive. However, there are cases where it's still valuable to preserve the tag. For vectors, doing any math with vector constants contains tag checks that could be removed. For both strings and vectors, storing them into a table has a barrier that for vectors could be elided, and for strings could be simplified as there's no need to confirm the tag. With this change we now carry the optional tag of the value with LOAD_TVALUE. 
This has no performance effect on existing benchmarks but does reduce the generated code for benchmarks by ~0.1%, and it makes vector code more efficient (~5% lift on X64 log1p approximation). --- CodeGen/include/Luau/IrData.h | 3 ++- CodeGen/src/IrTranslation.cpp | 8 ++++++++ CodeGen/src/OptimizeConstProp.cpp | 4 ++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 79d06e5a..a950370b 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -59,7 +59,8 @@ enum class IrCmd : uint8_t // Load a TValue from memory // A: Rn or Kn or pointer (TValue) - // B: int (optional 'A' pointer offset) + // B: int/none (optional 'A' pointer offset) + // C: tag/none (tag of the value being loaded) LOAD_TVALUE, // Load current environment table diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 995225a6..5d55c877 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenVectorTag2, false) LUAU_FASTFLAGVARIABLE(LuauCodegenVectorTag, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTVTag, false) namespace Luau { @@ -111,6 +112,13 @@ static void translateInstLoadConstant(IrBuilder& build, int ra, int k) build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), build.constDouble(protok.value.n)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); } + else if (FFlag::LuauCodegenLoadTVTag) + { + // Tag could be LUA_TSTRING or LUA_TVECTOR; for TSTRING we could generate LOAD_POINTER/STORE_POINTER/STORE_TAG, but it's not profitable; + // however, it's still valuable to preserve the tag throughout the optimization pipeline to eliminate tag checks. 
+ IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(k), build.constInt(0), build.constTag(protok.tt)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); + } else { // Remaining tag here right now is LUA_TSTRING, while it can be transformed to LOAD_POINTER/STORE_POINTER/STORE_TAG, it's not profitable right diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index d765b800..ff4f7bfc 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -19,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAG(LuauCodegenVectorTag2) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCodeGenCoverForgprepEffect, false) +LUAU_FASTFLAG(LuauCodegenLoadTVTag) namespace Luau { @@ -726,6 +727,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& arg->cmd == IrCmd::UNM_VEC) tag = LUA_TVECTOR; } + + if (FFlag::LuauCodegenLoadTVTag && arg->cmd == IrCmd::LOAD_TVALUE && arg->c.kind != IrOpKind::None) + tag = function.tagOp(arg->c); } }