diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h
index 16213958..d4457638 100644
--- a/Analysis/include/Luau/BuiltinDefinitions.h
+++ b/Analysis/include/Luau/BuiltinDefinitions.h
@@ -16,9 +16,7 @@ struct TypeArena;
 
 void registerBuiltinTypes(GlobalTypes& globals);
 
-void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals);
-void registerBuiltinGlobals(Frontend& frontend);
-
+void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false);
 
 TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types);
 TypeId makeIntersection(TypeArena& arena, std::vector<TypeId>&& types);
diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h
index 51f1e7a6..b3cbe467 100644
--- a/Analysis/include/Luau/Clone.h
+++ b/Analysis/include/Luau/Clone.h
@@ -26,7 +26,4 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
 TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
 TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
 
-TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false);
-TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest);
-
 } // namespace Luau
diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h
index e9e1e884..2feee236 100644
--- a/Analysis/include/Luau/ConstraintSolver.h
+++ b/Analysis/include/Luau/ConstraintSolver.h
@@ -53,7 +53,6 @@ struct ConstraintSolver
     NotNull<BuiltinTypes> builtinTypes;
     InternalErrorReporter iceReporter;
     NotNull<Normalizer> normalizer;
-    NotNull<TypeReduction> reducer;
     // The entire set of constraints that the solver is trying to resolve.
     std::vector<NotNull<Constraint>> constraints;
     NotNull<Scope> rootScope;
@@ -85,8 +84,7 @@ struct ConstraintSolver
     DcrLogger* logger;
 
     explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
-        ModuleName moduleName, NotNull<TypeReduction> reducer, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles,
-        DcrLogger* logger);
+        ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger);
 
     // Randomize the order in which to dispatch constraints
     void randomize(unsigned seed);
@@ -219,6 +217,20 @@ struct ConstraintSolver
     void reportError(TypeError e);
 
 private:
+
+    /** Helper used by tryDispatch(SubtypeConstraint) and
+     * tryDispatch(PackSubtypeConstraint)
+     *
+     * Attempts to unify subTy with superTy.  If doing so would require unifying
+     * BlockedTypes, fail and block the constraint on those BlockedTypes.
+     *
+     * If unification fails, replace all free types with errorType.
+     *
+     * If unification succeeds, unblock every type changed by the unification.
+     */
+    template <typename TID>
+    bool tryUnify(NotNull<const Constraint> constraint, TID subTy, TID superTy);
+
     /**
      * Marks a constraint as being blocked on a type or type pack. The constraint
      * solver will not attempt to dispatch blocked constraints until their
diff --git a/Analysis/include/Luau/ControlFlow.h b/Analysis/include/Luau/ControlFlow.h
index 8272bd53..566d77bd 100644
--- a/Analysis/include/Luau/ControlFlow.h
+++ b/Analysis/include/Luau/ControlFlow.h
@@ -11,10 +11,10 @@ using ScopePtr = std::shared_ptr<Scope>;
 
 enum class ControlFlow
 {
-    None = 0b00001,
-    Returns = 0b00010,
-    Throws = 0b00100,
-    Break = 0b01000, // Currently unused.
+    None = 0b00001,
+    Returns = 0b00010,
+    Throws = 0b00100,
+    Break = 0b01000, // Currently unused.
     Continue = 0b10000, // Currently unused.
 };
 
diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h
index 68ba8ff5..82251378 100644
--- a/Analysis/include/Luau/Frontend.h
+++ b/Analysis/include/Luau/Frontend.h
@@ -8,7 +8,6 @@
 #include "Luau/Scope.h"
 #include "Luau/TypeInfer.h"
 #include "Luau/Variant.h"
-
 #include
 #include
 #include
@@ -36,9 +35,6 @@ struct LoadDefinitionFileResult
     ModulePtr module;
 };
 
-LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition,
-    const std::string& packageName, bool captureComments);
-
 std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments);
 
 std::vector<std::string_view> parsePathExpr(const AstExpr& pathExpr);
@@ -55,7 +51,9 @@ std::optional<ModuleName> pathExprToModuleName(const ModuleName& currentModuleNa
  * error when we try during typechecking.
  */
 std::optional<ModuleName> pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr);
-
+// TODO: Deprecate this code path when we move away from the old solver
+LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition,
+    const std::string& packageName, bool captureComments);
 struct SourceNode
 {
     bool hasDirtySourceModule() const
@@ -140,10 +138,6 @@ struct Frontend
     CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
 
-    // Use 'check' with 'runLintChecks' set to true in FrontendOptions (enabledLintWarnings be set there as well)
-    LintResult lint_DEPRECATED(const ModuleName& name, std::optional<LintOptions> enabledLintWarnings = {});
-    LintResult lint_DEPRECATED(const SourceModule& module, std::optional<LintOptions> enabledLintWarnings = {});
-
     bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;
     void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr);
@@ -164,10 +158,11 @@ struct Frontend
     ScopePtr addEnvironment(const std::string& environmentName);
     ScopePtr getEnvironmentScope(const std::string& environmentName) const;
 
-    void registerBuiltinDefinition(const std::string& name, std::function<void(TypeChecker&, GlobalTypes&, ScopePtr)>);
+    void registerBuiltinDefinition(const std::string& name, std::function<void(Frontend&, GlobalTypes&, ScopePtr)>);
     void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName);
 
-    LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments);
+    LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName,
+        bool captureComments, bool typeCheckForAutocomplete = false);
 
 private:
     ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector<RequireCycle> requireCycles, bool forAutocomplete = false, bool recordJsonLog = false);
@@ -182,7 +177,7 @@ private:
     ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const;
 
     std::unordered_map<std::string, ScopePtr> environments;
-    std::unordered_map<std::string, std::function<void(TypeChecker&, GlobalTypes&, ScopePtr)>> builtinDefinitions;
+    std::unordered_map<std::string, std::function<void(Frontend&, GlobalTypes&, ScopePtr)>> builtinDefinitions;
 
     BuiltinTypes builtinTypes_;
diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h
index 15404707..efcb5108 100644
--- a/Analysis/include/Luau/Normalize.h
+++ b/Analysis/include/Luau/Normalize.h
@@ -191,12 +191,8 @@ struct NormalizedClassType
 // this type may contain `error`.
 struct NormalizedFunctionType
 {
-    NormalizedFunctionType();
-
     bool isTop = false;
-    // TODO: Remove this wrapping optional when clipping
-    // FFlagLuauNegatedFunctionTypes.
- std::optional parts; + TypeIds parts; void resetToNever(); void resetToTop(); diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 745ea47a..c3038fac 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -55,11 +55,11 @@ struct Scope std::optional lookup(DefId def) const; std::optional> lookupEx(Symbol sym); - std::optional lookupType(const Name& name); - std::optional lookupImportedType(const Name& moduleAlias, const Name& name); + std::optional lookupType(const Name& name) const; + std::optional lookupImportedType(const Name& moduleAlias, const Name& name) const; std::unordered_map privateTypePackBindings; - std::optional lookupPack(const Name& name); + std::optional lookupPack(const Name& name) const; // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index dba2a8de..cff86df4 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -75,14 +75,45 @@ using TypeId = const Type*; using Name = std::string; // A free type var is one whose exact shape has yet to be fully determined. -using FreeType = Unifiable::Free; +struct FreeType +{ + explicit FreeType(TypeLevel level); + explicit FreeType(Scope* scope); + FreeType(Scope* scope, TypeLevel level); -// When a free type var is unified with any other, it is then "bound" -// to that type var, indicating that the two types are actually the same type. + int index; + TypeLevel level; + Scope* scope = nullptr; + + // True if this free type variable is part of a mutually + // recursive type alias whose definitions haven't been + // resolved yet. + bool forwardedTypeAlias = false; +}; + +struct GenericType +{ + // By default, generics are global, with a synthetic name + GenericType(); + + explicit GenericType(TypeLevel level); + explicit GenericType(const Name& name); + explicit GenericType(Scope* scope); + + GenericType(TypeLevel level, const Name& name); + GenericType(Scope* scope, const Name& name); + + int index; + TypeLevel level; + Scope* scope = nullptr; + Name name; + bool explicitName = false; +}; + +// When an equality constraint is found, it is then "bound" to that type, +// indicating that the two types are actually the same type. using BoundType = Unifiable::Bound; -using GenericType = Unifiable::Generic; - using Tags = std::vector; using ModuleName = std::string; @@ -395,9 +426,11 @@ struct TableType // Represents a metatable attached to a table type. Somewhat analogous to a bound type. struct MetatableType { - // Always points to a TableType. + // Should always be a TableType. TypeId table; - // Always points to either a TableType or a MetatableType. + // Should almost always either be a TableType or another MetatableType, + // though it is possible for other types (like AnyType and ErrorType) to + // find their way here sometimes. TypeId metatable; std::optional syntheticName; @@ -536,8 +569,8 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct Type final { diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 68161794..7dae79c3 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -79,7 +79,8 @@ struct GlobalTypes // within a program are borrowed pointers into this set. 
struct TypeChecker { - explicit TypeChecker(const GlobalTypes& globals, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); + explicit TypeChecker( + const ScopePtr& globalScope, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); TypeChecker(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete; @@ -367,8 +368,7 @@ public: */ std::vector unTypePack(const ScopePtr& scope, TypePackId pack, size_t expectedLength, const Location& location); - // TODO: only const version of global scope should be available to make sure nothing else is modified inside of from users of TypeChecker - const GlobalTypes& globals; + const ScopePtr& globalScope; ModuleResolver* resolver; ModulePtr currentModule; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 4831f233..2ae56e5f 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -12,20 +12,48 @@ namespace Luau { struct TypeArena; +struct TxnLog; struct TypePack; struct VariadicTypePack; struct BlockedTypePack; struct TypePackVar; - -struct TxnLog; - using TypePackId = const TypePackVar*; -using FreeTypePack = Unifiable::Free; + +struct FreeTypePack +{ + explicit FreeTypePack(TypeLevel level); + explicit FreeTypePack(Scope* scope); + FreeTypePack(Scope* scope, TypeLevel level); + + int index; + TypeLevel level; + Scope* scope = nullptr; +}; + +struct GenericTypePack +{ + // By default, generics are global, with a synthetic name + GenericTypePack(); + explicit GenericTypePack(TypeLevel level); + explicit GenericTypePack(const Name& name); + explicit GenericTypePack(Scope* scope); + GenericTypePack(TypeLevel level, const Name& name); + GenericTypePack(Scope* scope, const Name& name); + + int index; + TypeLevel level; + Scope* scope = nullptr; + Name name; + bool explicitName = false; +}; + using BoundTypePack = Unifiable::Bound; -using GenericTypePack = Unifiable::Generic; -using TypePackVariant = Unifiable::Variant; + +using ErrorTypePack = Unifiable::Error; + +using TypePackVariant = Unifiable::Variant; /* A TypePack is a rope-like string of TypeIds. We use this structure to encode * notions like packs of unknown length and packs of any length, as well as more diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 9c4f0132..79b3b7de 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -82,24 +82,6 @@ namespace Luau::Unifiable using Name = std::string; int freshIndex(); - -struct Free -{ - explicit Free(TypeLevel level); - explicit Free(Scope* scope); - explicit Free(Scope* scope, TypeLevel level); - - int index; - TypeLevel level; - Scope* scope = nullptr; - // True if this free type variable is part of a mutually - // recursive type alias whose definitions haven't been - // resolved yet. 
- bool forwardedTypeAlias = false; - -private: - static int DEPRECATED_nextIndex; -}; template struct Bound @@ -112,26 +94,6 @@ struct Bound Id boundTo; }; -struct Generic -{ - // By default, generics are global, with a synthetic name - Generic(); - explicit Generic(TypeLevel level); - explicit Generic(const Name& name); - explicit Generic(Scope* scope); - Generic(TypeLevel level, const Name& name); - Generic(Scope* scope, const Name& name); - - int index; - TypeLevel level; - Scope* scope = nullptr; - Name name; - bool explicitName = false; - -private: - static int DEPRECATED_nextIndex; -}; - struct Error { // This constructor has to be public, since it's used in Type and TypePack, @@ -145,6 +107,6 @@ private: }; template -using Variant = Luau::Variant, Generic, Error, Value...>; +using Variant = Luau::Variant, Error, Value...>; } // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index ff4dfc3c..95b2b050 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -341,10 +341,10 @@ struct GenericTypeVisitor traverse(btv->boundTo); } - else if (auto ftv = get(tp)) + else if (auto ftv = get(tp)) visit(tp, *ftv); - else if (auto gtv = get(tp)) + else if (auto gtv = get(tp)) visit(tp, *gtv); else if (auto etv = get(tp)) diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index b0c3750b..dc07a35c 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -11,8 +11,6 @@ #include -LUAU_FASTFLAG(LuauCompleteTableKeysBetter); - namespace Luau { @@ -31,24 +29,12 @@ struct AutocompleteNodeFinder : public AstVisitor bool visit(AstExpr* expr) override { - if (FFlag::LuauCompleteTableKeysBetter) + if (expr->location.begin <= pos && pos <= expr->location.end) { - if (expr->location.begin <= pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; - } - else - { - if (expr->location.begin < pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; + ancestry.push_back(expr); + return true; } + return false; } bool visit(AstStat* stat) override diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1df4d3d7..42fc9a71 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,9 +13,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false); - static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -144,12 +141,9 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); - if (FFlag::LuauAutocompleteSkipNormalization) - { - // Cost of normalization can be too high for autocomplete response time requirements - unifier.normalize = false; - unifier.checkInhabited = false; - } + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; return unifier.canUnify(subTy, superTy).empty(); } @@ -981,25 +975,14 @@ T* extractStat(const std::vector& ancestry) AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; AstNode* greatGrandParent = ancestry.size() >= 4 ? 
ancestry.rbegin()[3] : nullptr; - if (FFlag::LuauCompleteTableKeysBetter) - { - if (!grandParent) - return nullptr; + if (!grandParent) + return nullptr; - if (T* t = parent->as(); t && grandParent->is()) - return t; + if (T* t = parent->as(); t && grandParent->is()) + return t; - if (!greatGrandParent) - return nullptr; - } - else - { - if (T* t = parent->as(); t && parent->is()) - return t; - - if (!grandParent || !greatGrandParent) - return nullptr; - } + if (!greatGrandParent) + return nullptr; if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) return t; @@ -1533,23 +1516,20 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); - if (FFlag::LuauCompleteTableKeysBetter) - { - if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*nodeIt, !node->is(), result); + if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*nodeIt, !node->is(), result); - if (!key) + if (!key) + { + // If there is "no key," it may be that the user + // intends for the current token to be the key, but + // has yet to type the `=` sign. + // + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) { - // If there is "no key," it may be that the user - // intends for the current token to be the key, but - // has yet to type the `=` sign. - // - // If the key type is a union of singleton strings, - // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) - { - autocompleteStringSingleton(ttv->indexer->indexType, false, result); - } + autocompleteStringSingleton(ttv->indexer->indexType, false, result); } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index d2ace49b..7ed92fb4 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -15,8 +15,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauDeprecateTableGetnForeach, false) - /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. 
For instance, we do not yet have a way to talk @@ -214,7 +212,7 @@ void registerBuiltinTypes(GlobalTypes& globals) globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType}); } -void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) +void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete) { LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); @@ -222,8 +220,8 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) TypeArena& arena = globals.globalTypes; NotNull builtinTypes = globals.builtinTypes; - LoadDefinitionFileResult loadResult = - Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); + LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile( + globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete); LUAU_ASSERT(loadResult.success); TypeId genericK = arena.addType(GenericType{"K"}); @@ -298,13 +296,10 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); - if (FFlag::LuauDeprecateTableGetnForeach) - { - ttv->props["getn"].deprecated = true; - ttv->props["getn"].deprecatedSuggestion = "#"; - ttv->props["foreach"].deprecated = true; - ttv->props["foreachi"].deprecated = true; - } + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + ttv->props["foreach"].deprecated = true; + ttv->props["foreachi"].deprecated = true; attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); @@ -314,108 +309,6 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } -void registerBuiltinGlobals(Frontend& frontend) -{ - GlobalTypes& globals = frontend.globals; - - LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); - LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); - - registerBuiltinTypes(globals); - - TypeArena& arena = globals.globalTypes; - NotNull builtinTypes = globals.builtinTypes; - - LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); - LUAU_ASSERT(loadResult.success); - - TypeId genericK = arena.addType(GenericType{"K"}); - TypeId genericV = arena.addType(GenericType{"V"}); - TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic}); - - std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); - LUAU_ASSERT(stringMetatableTy); - const TableType* stringMetatableTable = get(follow(*stringMetatableTy)); - LUAU_ASSERT(stringMetatableTable); - - auto it = stringMetatableTable->props.find("__index"); - LUAU_ASSERT(it != stringMetatableTable->props.end()); - - addGlobalBinding(globals, "string", it->second.type, "@luau"); - - // next(t: Table, i: K?) 
-> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); - addGlobalBinding(globals, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - - TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, builtinTypes->nilType}}); - - // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) - addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); - - TypeId genericMT = arena.addType(GenericType{"MT"}); - - TableType tab{TableState::Generic, globals.globalScope->level}; - TypeId tabTy = arena.addType(tab); - - TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); - - addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - - // clang-format off - // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(globals, "setmetatable", - arena.addType( - FunctionType{ - {genericMT}, - {}, - arena.addTypePack(TypePack{{tabTy, genericMT}}), - arena.addTypePack(TypePack{{tableMetaMT}}) - } - ), "@luau" - ); - // clang-format on - - for (const auto& pair : globals.globalScope->bindings) - { - persist(pair.second.typeId); - - if (TableType* ttv = getMutable(pair.second.typeId)) - { - if (!ttv->name) - ttv->name = "typeof(" + toString(pair.first) + ")"; - } - } - - attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); - attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); - - if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) - { - // tabTy is a generic table type which we can't express via declaration syntax yet - ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); - ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); - - if (FFlag::LuauDeprecateTableGetnForeach) - { - ttv->props["getn"].deprecated = true; - ttv->props["getn"].deprecatedSuggestion = "#"; - ttv->props["foreach"].deprecated = true; - ttv->props["foreachi"].deprecated = true; - } - - attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); - } - - attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); -} - static std::optional> magicFunctionSelect( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index ff8e0c3c..ac73622d 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -7,7 +7,7 @@ #include "Luau/Unifiable.h" LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) -LUAU_FASTFLAG(LuauClonePublicInterfaceLess) +LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) @@ -44,10 +44,10 @@ 
struct TypeCloner template void defaultClone(const T& t); - void operator()(const Unifiable::Free& t); - void operator()(const Unifiable::Generic& t); - void operator()(const Unifiable::Bound& t); - void operator()(const Unifiable::Error& t); + void operator()(const FreeType& t); + void operator()(const GenericType& t); + void operator()(const BoundType& t); + void operator()(const ErrorType& t); void operator()(const BlockedType& t); void operator()(const PendingExpansionType& t); void operator()(const PrimitiveType& t); @@ -89,15 +89,15 @@ struct TypePackCloner seenTypePacks[typePackId] = cloned; } - void operator()(const Unifiable::Free& t) + void operator()(const FreeTypePack& t) { defaultClone(t); } - void operator()(const Unifiable::Generic& t) + void operator()(const GenericTypePack& t) { defaultClone(t); } - void operator()(const Unifiable::Error& t) + void operator()(const ErrorTypePack& t) { defaultClone(t); } @@ -145,12 +145,12 @@ void TypeCloner::defaultClone(const T& t) seenTypes[typeId] = cloned; } -void TypeCloner::operator()(const Unifiable::Free& t) +void TypeCloner::operator()(const FreeType& t) { defaultClone(t); } -void TypeCloner::operator()(const Unifiable::Generic& t) +void TypeCloner::operator()(const GenericType& t) { defaultClone(t); } @@ -422,86 +422,4 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) return result; } -TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) -{ - ty = log->follow(ty); - - TypeId result = ty; - - if (auto pty = log->pending(ty)) - ty = &pty->pending; - - if (const FunctionType* ftv = get(ty)) - { - FunctionType clone = FunctionType{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; - clone.generics = ftv->generics; - clone.genericPacks = ftv->genericPacks; - clone.magicFunction = ftv->magicFunction; - clone.dcrMagicFunction = ftv->dcrMagicFunction; - clone.dcrMagicRefinement = ftv->dcrMagicRefinement; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - result = dest.addType(std::move(clone)); - } - else if (const TableType* ttv = get(ty)) - { - LUAU_ASSERT(!ttv->boundTo); - TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; - clone.definitionModuleName = ttv->definitionModuleName; - clone.definitionLocation = ttv->definitionLocation; - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - clone.tags = ttv->tags; - result = dest.addType(std::move(clone)); - } - else if (const MetatableType* mtv = get(ty)) - { - MetatableType clone = MetatableType{mtv->table, mtv->metatable}; - clone.syntheticName = mtv->syntheticName; - result = dest.addType(std::move(clone)); - } - else if (const UnionType* utv = get(ty)) - { - UnionType clone; - clone.options = utv->options; - result = dest.addType(std::move(clone)); - } - else if (const IntersectionType* itv = get(ty)) - { - IntersectionType clone; - clone.parts = itv->parts; - result = dest.addType(std::move(clone)); - } - else if (const PendingExpansionType* petv = get(ty)) - { - PendingExpansionType clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; - result = dest.addType(std::move(clone)); - } - else if (const ClassType* ctv = get(ty); FFlag::LuauClonePublicInterfaceLess && ctv && alwaysClone) - { - ClassType clone{ctv->name, ctv->props, ctv->parent, ctv->metatable, ctv->tags, 
ctv->userData, ctv->definitionModuleName}; - result = dest.addType(std::move(clone)); - } - else if (FFlag::LuauClonePublicInterfaceLess && alwaysClone) - { - result = dest.addType(*ty); - } - else if (const NegationType* ntv = get(ty)) - { - result = dest.addType(NegationType{ntv->ty}); - } - else - return result; - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; -} - -TypeId shallowClone(TypeId ty, NotNull dest) -{ - return shallowClone(ty, *dest, TxnLog::empty()); -} - } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index e90cb7d3..474d3923 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -23,7 +23,7 @@ LUAU_FASTFLAG(LuauNegatedClassTypes); namespace Luau { -bool doesCallError(const AstExprCall* call); // TypeInfer.cpp +bool doesCallError(const AstExprCall* call); // TypeInfer.cpp const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp static std::optional matchRequire(const AstExprCall& call) @@ -1359,10 +1359,34 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa if (argTail && args.size() < 2) argTailPack = extendTypePack(*arena, builtinTypes, *argTail, 2 - args.size()); - LUAU_ASSERT(args.size() + argTailPack.head.size() == 2); + TypeId target = nullptr; + TypeId mt = nullptr; - TypeId target = args.size() > 0 ? args[0] : argTailPack.head[0]; - TypeId mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; + if (args.size() + argTailPack.head.size() == 2) + { + target = args.size() > 0 ? args[0] : argTailPack.head[0]; + mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; + } + else + { + std::vector unpackedTypes; + if (args.size() > 0) + target = args[0]; + else + { + target = arena->addType(BlockedType{}); + unpackedTypes.emplace_back(target); + } + + mt = arena->addType(BlockedType{}); + unpackedTypes.emplace_back(mt); + TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes)); + + addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail}); + } + + LUAU_ASSERT(target); + LUAU_ASSERT(mt); AstExpr* targetExpr = call->args.data[0]; @@ -2090,6 +2114,19 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS TypePack expectedArgPack; const FunctionType* expectedFunction = expectedType ? get(*expectedType) : nullptr; + // This check ensures that expectedType is precisely optional and not any (since any is also an optional type) + if (expectedType && isOptional(*expectedType) && !get(*expectedType)) + { + auto ut = get(*expectedType); + for (auto u : ut) + { + if (get(u) && !isNil(u)) + { + expectedFunction = get(u); + break; + } + } + } if (expectedFunction) { diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 5662cf04..d2bed2da 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -3,6 +3,7 @@ #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" #include "Luau/Instantiation.h" @@ -221,27 +222,14 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) auto it = cs->blockedConstraints.find(c); int blockCount = it == cs->blockedConstraints.end() ? 
0 : int(it->second); printf("\t%d\t%s\n", blockCount, toString(*c, opts).c_str()); - - for (NotNull dep : c->dependencies) - { - auto unsolvedIter = std::find(begin(cs->unsolvedConstraints), end(cs->unsolvedConstraints), dep); - if (unsolvedIter == cs->unsolvedConstraints.end()) - continue; - - auto it = cs->blockedConstraints.find(dep); - int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); - printf("\t%d\t\t%s\n", blockCount, toString(*dep, opts).c_str()); - } } } ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, - DcrLogger* logger) + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) - , reducer(reducer) , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) @@ -468,40 +456,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNullscope, Location{}, Covariant}; - u.useScopes = true; - - u.tryUnify(c.subType, c.superType); - - if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) - { - for (TypeId bt : u.blockedTypes) - block(bt, constraint); - for (TypePackId btp : u.blockedTypePacks) - block(btp, constraint); - return false; - } - - if (const auto& e = hasUnificationTooComplex(u.errors)) - reportError(*e); - - if (!u.errors.empty()) - { - TypeId errorType = errorRecoveryType(); - u.tryUnify(c.subType, errorType); - u.tryUnify(c.superType, errorType); - } - - const auto [changedTypes, changedPacks] = u.log.getChanges(); - - u.log.commit(); - - unblock(changedTypes); - unblock(changedPacks); - - // unify(c.subType, c.superType, constraint->scope); - - return true; + return tryUnify(constraint, c.subType, c.superType); } bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) @@ -511,9 +466,7 @@ bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNullscope); - - return true; + return tryUnify(constraint, c.subPack, c.superPack); } bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force) @@ -578,12 +531,16 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(builtinTypes->booleanType); + + unblock(c.resultType); return true; } case AstExprUnary::Len: { // __len must return a number. 
asMutable(c.resultType)->ty.emplace(builtinTypes->numberType); + + unblock(c.resultType); return true; } case AstExprUnary::Minus: @@ -613,6 +570,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(builtinTypes->errorRecoveryType()); } + unblock(c.resultType); return true; } } @@ -868,7 +826,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope, constraint->location, this}; queuer.traverse(target); - if (target->persistent) + if (target->persistent || target->owningArena != arena) { bindResult(target); return true; @@ -1249,35 +1207,63 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType}; } - TypeId instantiatedTy = arena->addType(BlockedType{}); TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); - auto pushConstraintGreedy = [this, constraint](ConstraintV cv) -> Constraint* { - std::unique_ptr c = std::make_unique(constraint->scope, constraint->location, std::move(cv)); - NotNull borrow{c.get()}; + std::vector overloads = flattenIntersection(fn); - bool ok = tryDispatch(borrow, false); - if (ok) - return nullptr; + Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); - solverConstraints.push_back(std::move(c)); - unsolvedConstraints.push_back(borrow); + for (TypeId overload : overloads) + { + overload = follow(overload); - return borrow; - }; + std::optional instantiated = inst.substitute(overload); + LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - // HACK: We don't want other constraints to act on the free type pack - // created above until after these two constraints are solved, so we try to - // dispatch them directly. + Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + u.useScopes = true; - auto ic = pushConstraintGreedy(InstantiationConstraint{instantiatedTy, fn}); - auto sc = pushConstraintGreedy(SubtypeConstraint{instantiatedTy, inferredTy}); + u.tryUnify(*instantiated, inferredTy, /* isFunctionCall */ true); - if (ic) - inheritBlocks(constraint, NotNull{ic}); + if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) + { + for (TypeId bt : u.blockedTypes) + block(bt, constraint); + for (TypePackId btp : u.blockedTypePacks) + block(btp, constraint); + return false; + } - if (sc) - inheritBlocks(constraint, NotNull{sc}); + if (const auto& e = hasUnificationTooComplex(u.errors)) + reportError(*e); + + if (u.errors.empty()) + { + // We found a matching overload. + const auto [changedTypes, changedPacks] = u.log.getChanges(); + u.log.commit(); + unblock(changedTypes); + unblock(changedPacks); + + unblock(c.result); + return true; + } + } + + // We found no matching overloads. 
+ Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(inferredTy, builtinTypes->anyType); + u.tryUnify(fn, builtinTypes->anyType); + + LUAU_ASSERT(u.errors.empty()); // unifying with any should never fail + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); unblock(c.result); return true; @@ -1291,6 +1277,7 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNullty.emplace(bindTo); + unblock(c.resultType); return true; } @@ -1311,8 +1298,6 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullreduce(subjectType).value_or(subjectType); - auto [blocked, result] = lookupTableProp(subjectType, c.prop); if (!blocked.empty()) { @@ -1335,20 +1320,17 @@ static bool isUnsealedTable(TypeId ty) } /** - * Create a shallow copy of `ty` and its properties along `path`. Insert a new - * property (the last segment of `path`) into the tail table with the value `t`. + * Given a path into a set of nested unsealed tables `ty`, insert a new property `replaceTy` as the leaf-most property. * - * On success, returns the new outermost table type. If the root table or any - * of its subkeys are not unsealed tables, the function fails and returns - * std::nullopt. + * Fails and does nothing if every table along the way is not unsealed. * - * TODO: Prove that we completely give up in the face of indexers and - * metatables. + * Mutates the innermost table type in-place. */ -static std::optional updateTheTableType(NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) +static void updateTheTableType( + NotNull builtinTypes, NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) { if (path.empty()) - return std::nullopt; + return; // First walk the path and ensure that it's unsealed tables all the way // to the end. @@ -1357,12 +1339,12 @@ static std::optional updateTheTableType(NotNull arena, TypeId for (size_t i = 0; i < path.size() - 1; ++i) { if (!isUnsealedTable(t)) - return std::nullopt; + return; const TableType* tbl = get(t); auto it = tbl->props.find(path[i]); if (it == tbl->props.end()) - return std::nullopt; + return; t = follow(it->second.type); } @@ -1371,40 +1353,37 @@ static std::optional updateTheTableType(NotNull arena, TypeId // We are not changing property types. We are only admitting this one // new property to be appended. 
if (!isUnsealedTable(t)) - return std::nullopt; + return; const TableType* tbl = get(t); if (0 != tbl->props.count(path.back())) - return std::nullopt; + return; } - const TypeId res = shallowClone(ty, arena); - TypeId t = res; + TypeId t = ty; + ErrorVec dummy; for (size_t i = 0; i < path.size() - 1; ++i) { - const std::string segment = path[i]; + auto propTy = findTablePropertyRespectingMeta(builtinTypes, dummy, t, path[i], Location{}); + dummy.clear(); - TableType* ttv = getMutable(t); - LUAU_ASSERT(ttv); + if (!propTy) + return; - auto propIt = ttv->props.find(segment); - if (propIt != ttv->props.end()) - { - LUAU_ASSERT(isUnsealedTable(propIt->second.type)); - t = shallowClone(follow(propIt->second.type), arena); - ttv->props[segment].type = t; - } - else - return std::nullopt; + t = *propTy; } - TableType* ttv = getMutable(t); - LUAU_ASSERT(ttv); + const std::string& lastSegment = path.back(); - const std::string lastSegment = path.back(); - LUAU_ASSERT(0 == ttv->props.count(lastSegment)); - ttv->props[lastSegment] = Property{replaceTy}; - return res; + t = follow(t); + TableType* tt = getMutable(t); + if (auto mt = get(t)) + tt = getMutable(mt->table); + + if (!tt) + return; + + tt->props[lastSegment].type = replaceTy; } bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force) @@ -1443,6 +1422,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullscope); bind(c.resultType, c.subjectType); + unblock(c.resultType); return true; } @@ -1467,6 +1447,8 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) @@ -1477,20 +1459,23 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullprops[c.path[0]] = Property{c.propType}; bind(c.resultType, c.subjectType); + unblock(c.resultType); return true; } else if (ttv->state == TableState::Unsealed) { LUAU_ASSERT(!subjectType->persistent); - std::optional augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType); - bind(c.resultType, augmented.value_or(subjectType)); - bind(subjectType, c.resultType); + updateTheTableType(builtinTypes, NotNull{arena}, subjectType, c.path, c.propType); + bind(c.resultType, c.subjectType); + unblock(subjectType); + unblock(c.resultType); return true; } else { bind(c.resultType, subjectType); + unblock(c.resultType); return true; } } @@ -1499,6 +1484,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(iteratorTy)) { - if (iteratorTable->state == TableState::Free) - return block_(iteratorTy); + /* + * We try not to dispatch IterableConstraints over free tables because + * it's possible that there are other constraints on the table that will + * clarify what we should do. + * + * We should eventually introduce a type family to talk about iteration. 
+ */ + if (iteratorTable->state == TableState::Free && !force) + return block(iteratorTy, constraint); if (iteratorTable->indexer) { @@ -1932,14 +1925,14 @@ std::pair, std::optional> ConstraintSolver::lookupTa else if (auto utv = get(subjectType)) { std::vector blocked; - std::vector options; + std::set options; for (TypeId ty : utv) { auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) - options.push_back(*innerResult); + options.insert(*innerResult); } if (!blocked.empty()) @@ -1948,21 +1941,21 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (options.empty()) return {{}, std::nullopt}; else if (options.size() == 1) - return {{}, options[0]}; + return {{}, *begin(options)}; else - return {{}, arena->addType(UnionType{std::move(options)})}; + return {{}, arena->addType(UnionType{std::vector(begin(options), end(options))})}; } else if (auto itv = get(subjectType)) { std::vector blocked; - std::vector options; + std::set options; for (TypeId ty : itv) { auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) - options.push_back(*innerResult); + options.insert(*innerResult); } if (!blocked.empty()) @@ -1971,14 +1964,61 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (options.empty()) return {{}, std::nullopt}; else if (options.size() == 1) - return {{}, options[0]}; + return {{}, *begin(options)}; else - return {{}, arena->addType(IntersectionType{std::move(options)})}; + return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; } return {{}, std::nullopt}; } +static TypeId getErrorType(NotNull builtinTypes, TypeId) +{ + return builtinTypes->errorRecoveryType(); +} + +static TypePackId getErrorType(NotNull builtinTypes, TypePackId) +{ + return builtinTypes->errorRecoveryTypePack(); +} + +template +bool ConstraintSolver::tryUnify(NotNull constraint, TID subTy, TID superTy) +{ + Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(subTy, superTy); + + if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) + { + for (TypeId bt : u.blockedTypes) + block(bt, constraint); + for (TypePackId btp : u.blockedTypePacks) + block(btp, constraint); + return false; + } + + if (const auto& e = hasUnificationTooComplex(u.errors)) + reportError(*e); + + if (!u.errors.empty()) + { + TID errorType = getErrorType(builtinTypes, TID{}); + u.tryUnify(subTy, errorType); + u.tryUnify(superTy, errorType); + } + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); + + return true; +} + void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index de79e0be..98022d86 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -29,10 +29,8 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauLintInTypecheck, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) -LUAU_FASTFLAGVARIABLE(LuauDefinitionFileSourceModule, false) 
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau @@ -85,179 +83,92 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) } } -LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments) +static ParseResult parseSourceForModule(std::string_view source, Luau::SourceModule& sourceModule, bool captureComments) { - if (!FFlag::DebugLuauDeferredConstraintResolution) - return Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, source, packageName, captureComments); - - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - - Luau::SourceModule sourceModule; - ParseOptions options; options.allowDeclarationSyntax = true; options.captureComments = captureComments; Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); - - if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - sourceModule.root = parseResult.root; sourceModule.mode = Mode::Definition; + return parseResult; +} + +static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, ScopePtr targetScope, const std::string& packageName) +{ + CloneState cloneState; + + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); + + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globals.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; + + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } +} + +LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, + const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete) +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return Luau::loadDefinitionFileNoDCR(typeCheckForAutocomplete ? typeCheckerForAutocomplete : typeChecker, + typeCheckForAutocomplete ? 
globalsForAutocomplete : globals, targetScope, source, packageName, captureComments); + + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); + + Luau::SourceModule sourceModule; + Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); + if (parseResult.errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}); if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; - CloneState cloneState; - - std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); - - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - globals.globalScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - - typesToPersist.push_back(globalTy); - } - - for (const auto& [name, ty] : checkedModule->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - globals.globalScope->exportedTypeBindings[name] = globalTy; - - typesToPersist.push_back(globalTy.type); - } - - for (TypeId ty : typesToPersist) - { - persist(ty); - } + persistCheckedTypes(checkedModule, globals, targetScope, packageName); return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } -LoadDefinitionFileResult loadDefinitionFile_DEPRECATED( - TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName) -{ - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); - - ParseOptions options; - options.allowDeclarationSyntax = true; - - Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); - - if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, {}, nullptr}; - - Luau::SourceModule module; - module.root = parseResult.root; - module.mode = Mode::Definition; - - ModulePtr checkedModule = typeChecker.check(module, Mode::Definition); - - if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, {}, checkedModule}; - - CloneState cloneState; - - std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); - - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - - typesToPersist.push_back(globalTy); - } - - for (const auto& [name, ty] : checkedModule->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - 
generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; - - typesToPersist.push_back(globalTy.type); - } - - for (TypeId ty : typesToPersist) - { - persist(ty); - } - - return LoadDefinitionFileResult{true, parseResult, {}, checkedModule}; -} - -LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, +LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, bool captureComments) { - if (!FFlag::LuauDefinitionFileSourceModule) - return loadDefinitionFile_DEPRECATED(typeChecker, globals, targetScope, source, packageName); - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); Luau::SourceModule sourceModule; - - ParseOptions options; - options.allowDeclarationSyntax = true; - options.captureComments = captureComments; - - Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); + Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); if (parseResult.errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - sourceModule.root = parseResult.root; - sourceModule.mode = Mode::Definition; - ModulePtr checkedModule = typeChecker.check(sourceModule, Mode::Definition); if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; - CloneState cloneState; - - std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); - - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - - typesToPersist.push_back(globalTy); - } - - for (const auto& [name, ty] : checkedModule->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; - - typesToPersist.push_back(globalTy.type); - } - - for (TypeId ty : typesToPersist) - { - persist(ty); - } + persistCheckedTypes(checkedModule, globals, targetScope, packageName); return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } @@ -378,8 +289,6 @@ static ErrorVec accumulateErrors( static void filterLintOptions(LintOptions& lintOptions, const std::vector& hotcomments, Mode mode) { - LUAU_ASSERT(FFlag::LuauLintInTypecheck); - uint64_t ignoreLints = LintWarning::parseMask(hotcomments); lintOptions.warningMask &= ~ignoreLints; @@ -497,8 +406,8 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c , moduleResolverForAutocomplete(this) , globals(builtinTypes) , globalsForAutocomplete(builtinTypes) - , typeChecker(globals, &moduleResolver, builtinTypes, &iceHandler) - , typeCheckerForAutocomplete(globalsForAutocomplete, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) + , typeChecker(globals.globalScope, 
&moduleResolver, builtinTypes, &iceHandler) + , typeCheckerForAutocomplete(globalsForAutocomplete.globalScope, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) , configResolver(configResolver) , options(options) { @@ -534,24 +443,16 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = - frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + std::unordered_map& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; - checkResult.errors = accumulateErrors(sourceNodes, modules, name); + checkResult.errors = accumulateErrors(sourceNodes, modules, name); - // Get lint result only for top checked module - if (auto it = modules.find(name); it != modules.end()) - checkResult.lintResult = it->second->lintResult; + // Get lint result only for top checked module + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; - return checkResult; - } - else - { - return CheckResult{accumulateErrors( - sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; - } + return checkResult; } std::vector buildQueue; @@ -615,9 +516,10 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = - frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + // Get lint result only for top checked module + std::unordered_map& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; - if (auto it = modules.find(name); it != modules.end()) - checkResult.lintResult = it->second->lintResult; - } + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; return checkResult; } @@ -862,59 +759,6 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config return result; } -LintResult Frontend::lint_DEPRECATED(const ModuleName& name, std::optional enabledLintWarnings) -{ - LUAU_ASSERT(!FFlag::LuauLintInTypecheck); - - LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); - - auto [_sourceNode, sourceModule] = getSourceNode(name); - - if (!sourceModule) - return LintResult{}; // FIXME: We really should do something a bit more obvious when a file is too broken to lint. 
- - return lint_DEPRECATED(*sourceModule, enabledLintWarnings); -} - -LintResult Frontend::lint_DEPRECATED(const SourceModule& module, std::optional enabledLintWarnings) -{ - LUAU_ASSERT(!FFlag::LuauLintInTypecheck); - - LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); - - const Config& config = configResolver->getConfig(module.name); - - uint64_t ignoreLints = LintWarning::parseMask(module.hotcomments); - - LintOptions options = enabledLintWarnings.value_or(config.enabledLint); - options.warningMask &= ~ignoreLints; - - Mode mode = module.mode.value_or(config.mode); - if (mode != Mode::NoCheck) - { - options.disableWarning(Luau::LintWarning::Code_UnknownGlobal); - } - - if (mode == Mode::Strict) - { - options.disableWarning(Luau::LintWarning::Code_ImplicitReturn); - } - - ScopePtr environmentScope = getModuleEnvironment(module, config, /*forAutocomplete*/ false); - - ModulePtr modulePtr = moduleResolver.getModule(module.name); - - double timestamp = getTimestamp(); - - std::vector warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), module.hotcomments, options); - - stats.timeLint += getTimestamp() - timestamp; - - return classifyLints(warnings, config); -} - bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); @@ -1032,8 +876,8 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorerrors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, - NotNull{result->reduction.get()}, moduleResolver, requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, moduleResolver, + requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); @@ -1257,7 +1101,7 @@ ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) const return {}; } -void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) +void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) { LUAU_ASSERT(builtinDefinitions.count(name) == 0); @@ -1270,7 +1114,7 @@ void Frontend::applyBuiltinDefinitionToEnvironment(const std::string& environmen LUAU_ASSERT(builtinDefinitions.count(definitionName) > 0); if (builtinDefinitions.count(definitionName) > 0) - builtinDefinitions[definitionName](typeChecker, globals, getEnvironmentScope(environmentName)); + builtinDefinitions[definitionName](*this, globals, getEnvironmentScope(environmentName)); } LintResult Frontend::classifyLints(const std::vector& warnings, const Config& config) diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 9c3ae077..7d0f0f72 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -127,7 +127,7 @@ TypeId ReplaceGenerics::clean(TypeId ty) TypePackId ReplaceGenerics::clean(TypePackId tp) { LUAU_ASSERT(isDirty(tp)); - return addTypePack(TypePackVar(FreeTypePack{level})); + return addTypePack(TypePackVar(FreeTypePack{scope, level})); } } // namespace Luau diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index f850bd3d..d6aafda6 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,8 +14,6 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) 
-LUAU_FASTFLAGVARIABLE(LuauImproveDeprecatedApiLint, false) - namespace Luau { @@ -2102,9 +2100,6 @@ class LintDeprecatedApi : AstVisitor public: LUAU_NOINLINE static void process(LintContext& context) { - if (!FFlag::LuauImproveDeprecatedApiLint && !context.module) - return; - LintDeprecatedApi pass{&context}; context.root->visit(&pass); } @@ -2122,8 +2117,7 @@ private: if (std::optional ty = context->getType(node->expr)) check(node, follow(*ty)); else if (AstExprGlobal* global = node->expr->as()) - if (FFlag::LuauImproveDeprecatedApiLint) - check(node->location, global->name, node->index); + check(node->location, global->name, node->index); return true; } @@ -2144,7 +2138,7 @@ private: if (prop != tty->props.end() && prop->second.deprecated) { // strip synthetic typeof() for builtin tables - if (FFlag::LuauImproveDeprecatedApiLint && tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')') + if (tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')') report(node->location, prop->second, tty->name->substr(7, tty->name->length() - 8).c_str(), node->index.value); else report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index b51b7c9a..fd948403 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -16,7 +16,7 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); +LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess2, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); @@ -194,7 +194,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr TxnLog log; ClonePublicInterface clonePublicInterface{&log, builtinTypes, this}; - if (FFlag::LuauClonePublicInterfaceLess) + if (FFlag::LuauClonePublicInterfaceLess2) returnType = clonePublicInterface.cloneTypePack(returnType); else returnType = clone(returnType, interfaceTypes, cloneState); @@ -202,7 +202,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr moduleScope->returnType = returnType; if (varargPack) { - if (FFlag::LuauClonePublicInterfaceLess) + if (FFlag::LuauClonePublicInterfaceLess2) varargPack = clonePublicInterface.cloneTypePack(*varargPack); else varargPack = clone(*varargPack, interfaceTypes, cloneState); @@ -211,7 +211,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr for (auto& [name, tf] : moduleScope->exportedTypeBindings) { - if (FFlag::LuauClonePublicInterfaceLess) + if (FFlag::LuauClonePublicInterfaceLess2) tf = clonePublicInterface.cloneTypeFun(tf); else tf = clone(tf, interfaceTypes, cloneState); @@ -219,7 +219,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr for (auto& [name, ty] : declaredGlobals) { - if (FFlag::LuauClonePublicInterfaceLess) + if (FFlag::LuauClonePublicInterfaceLess2) ty = clonePublicInterface.cloneType(ty); else ty = clone(ty, interfaceTypes, cloneState); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index f8f8b97f..46595b70 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -18,9 +18,9 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); 
-LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauTransitiveSubtyping) @@ -202,26 +202,21 @@ bool NormalizedClassType::isNever() const return classes.empty(); } -NormalizedFunctionType::NormalizedFunctionType() - : parts(FFlag::LuauNegatedFunctionTypes ? std::optional{TypeIds{}} : std::nullopt) -{ -} - void NormalizedFunctionType::resetToTop() { isTop = true; - parts.emplace(); + parts.clear(); } void NormalizedFunctionType::resetToNever() { isTop = false; - parts.emplace(); + parts.clear(); } bool NormalizedFunctionType::isNever() const { - return !isTop && (!parts || parts->empty()); + return !isTop && parts.empty(); } NormalizedType::NormalizedType(NotNull builtinTypes) @@ -438,13 +433,10 @@ static bool isNormalizedThread(TypeId ty) static bool areNormalizedFunctions(const NormalizedFunctionType& tys) { - if (tys.parts) + for (TypeId ty : tys.parts) { - for (TypeId ty : *tys.parts) - { - if (!get(ty) && !get(ty)) - return false; - } + if (!get(ty) && !get(ty)) + return false; } return true; } @@ -533,7 +525,7 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) static bool isPlainTyvar(TypeId ty) { - return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty))); + return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty)) || get(ty)); } static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) @@ -1170,13 +1162,10 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (FFlag::LuauNegatedFunctionTypes) - { - if (heres.isTop) - return; - if (theres.isTop) - heres.resetToTop(); - } + if (heres.isTop) + return; + if (theres.isTop) + heres.resetToTop(); if (theres.isNever()) return; @@ -1185,13 +1174,13 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF if (heres.isNever()) { - tmps.insert(theres.parts->begin(), theres.parts->end()); + tmps.insert(theres.parts.begin(), theres.parts.end()); heres.parts = std::move(tmps); return; } - for (TypeId here : *heres.parts) - for (TypeId there : *theres.parts) + for (TypeId here : heres.parts) + for (TypeId there : theres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -1213,7 +1202,7 @@ void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeI } TypeIds tmps; - for (TypeId here : *heres.parts) + for (TypeId here : heres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -1380,7 +1369,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor } else if (FFlag::LuauTransitiveSubtyping && get(here.tops)) return true; - else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there))) + else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there)) || + get(there)) { if (tyvarIndex(there) <= ignoreSmallerTyvars) return true; @@ -1419,7 +1409,6 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.threads = there; else if (ptv->type == PrimitiveType::Function) { - LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions.resetToTop(); } else if (ptv->type == 
PrimitiveType::Table && FFlag::LuauNegatedTableTypes) @@ -1460,6 +1449,10 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor } else if (!FFlag::LuauNormalizeBlockedTypes && get(there)) LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); + else if (get(there)) + { + // nothing + } else LUAU_ASSERT(!"Unreachable"); @@ -1548,15 +1541,12 @@ std::optional Normalizer::negateNormal(const NormalizedType& her * arbitrary function types. Ordinary code can never form these kinds of * types, so we decline to negate them. */ - if (FFlag::LuauNegatedFunctionTypes) - { - if (here.functions.isNever()) - result.functions.resetToTop(); - else if (here.functions.isTop) - result.functions.resetToNever(); - else - return std::nullopt; - } + if (here.functions.isNever()) + result.functions.resetToTop(); + else if (here.functions.isTop) + result.functions.resetToNever(); + else + return std::nullopt; /* * It is not possible to negate an arbitrary table type, because function @@ -2073,6 +2063,18 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (isPrim(there, PrimitiveType::Table)) return here; + if (FFlag::LuauNormalizeMetatableFixes) + { + if (get(here)) + return there; + else if (get(there)) + return here; + else if (get(here)) + return there; + else if (get(there)) + return here; + } + TypeId htable = here; TypeId hmtable = nullptr; if (const MetatableType* hmtv = get(here)) @@ -2089,9 +2091,23 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there } const TableType* httv = get(htable); - LUAU_ASSERT(httv); + if (FFlag::LuauNormalizeMetatableFixes) + { + if (!httv) + return std::nullopt; + } + else + LUAU_ASSERT(httv); + const TableType* tttv = get(ttable); - LUAU_ASSERT(tttv); + if (FFlag::LuauNormalizeMetatableFixes) + { + if (!tttv) + return std::nullopt; + } + else + LUAU_ASSERT(tttv); + if (httv->state == TableState::Free || tttv->state == TableState::Free) return std::nullopt; @@ -2385,15 +2401,15 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T heres.isTop = false; - for (auto it = heres.parts->begin(); it != heres.parts->end();) + for (auto it = heres.parts.begin(); it != heres.parts.end();) { TypeId here = *it; if (get(here)) it++; else if (std::optional tmp = intersectionOfFunctions(here, there)) { - heres.parts->erase(it); - heres.parts->insert(*tmp); + heres.parts.erase(it); + heres.parts.insert(*tmp); return; } else @@ -2401,13 +2417,13 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T } TypeIds tmps; - for (TypeId here : *heres.parts) + for (TypeId here : heres.parts) { if (std::optional tmp = unionSaturatedFunctions(here, there)) tmps.insert(*tmp); } - heres.parts->insert(there); - heres.parts->insert(tmps.begin(), tmps.end()); + heres.parts.insert(there); + heres.parts.insert(tmps.begin(), tmps.end()); } void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) @@ -2421,7 +2437,7 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali } else { - for (TypeId there : *theres.parts) + for (TypeId there : theres.parts) intersectFunctionsWithFunction(heres, there); } } @@ -2544,7 +2560,8 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) return false; return true; } - else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there))) + else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && 
get(there)) || + get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; @@ -2615,10 +2632,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) else if (ptv->type == PrimitiveType::Thread) here.threads = threads; else if (ptv->type == PrimitiveType::Function) - { - LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); here.functions = std::move(functions); - } else if (ptv->type == PrimitiveType::Table) { LUAU_ASSERT(FFlag::LuauNegatedTableTypes); @@ -2762,16 +2776,16 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) if (!get(norm.errors)) result.push_back(norm.errors); - if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop) + if (norm.functions.isTop) result.push_back(builtinTypes->functionType); else if (!norm.functions.isNever()) { - if (norm.functions.parts->size() == 1) - result.push_back(*norm.functions.parts->begin()); + if (norm.functions.parts.size() == 1) + result.push_back(*norm.functions.parts.begin()); else { std::vector parts; - parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); + parts.insert(parts.end(), norm.functions.parts.begin(), norm.functions.parts.end()); result.push_back(arena->addType(IntersectionType{std::move(parts)})); } } @@ -2856,7 +2870,8 @@ bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, Not return ok; } -bool isConsistentSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isConsistentSubtype( + TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index f54ebe2a..2de381be 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -65,7 +65,7 @@ std::optional Scope::lookup(DefId def) const return std::nullopt; } -std::optional Scope::lookupType(const Name& name) +std::optional Scope::lookupType(const Name& name) const { const Scope* scope = this; while (true) @@ -85,7 +85,7 @@ std::optional Scope::lookupType(const Name& name) } } -std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) +std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) const { const Scope* scope = this; while (scope) @@ -110,7 +110,7 @@ std::optional Scope::lookupImportedType(const Name& moduleAlias, const return std::nullopt; } -std::optional Scope::lookupPack(const Name& name) +std::optional Scope::lookupPack(const Name& name) const { const Scope* scope = this; while (true) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 160647a0..935d85d7 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -9,7 +9,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false) -LUAU_FASTFLAG(LuauClonePublicInterfaceLess) +LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) @@ -17,6 +17,181 @@ LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) namespace Luau { +static TypeId DEPRECATED_shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) +{ + ty = log->follow(ty); + + TypeId result = ty; + + if (auto pty = log->pending(ty)) + ty = &pty->pending; + + if (const FunctionType* ftv = get(ty)) + { + FunctionType 
clone = FunctionType{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + clone.generics = ftv->generics; + clone.genericPacks = ftv->genericPacks; + clone.magicFunction = ftv->magicFunction; + clone.dcrMagicFunction = ftv->dcrMagicFunction; + clone.dcrMagicRefinement = ftv->dcrMagicRefinement; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + result = dest.addType(std::move(clone)); + } + else if (const TableType* ttv = get(ty)) + { + LUAU_ASSERT(!ttv->boundTo); + TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; + clone.definitionModuleName = ttv->definitionModuleName; + clone.definitionLocation = ttv->definitionLocation; + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; + clone.tags = ttv->tags; + result = dest.addType(std::move(clone)); + } + else if (const MetatableType* mtv = get(ty)) + { + MetatableType clone = MetatableType{mtv->table, mtv->metatable}; + clone.syntheticName = mtv->syntheticName; + result = dest.addType(std::move(clone)); + } + else if (const UnionType* utv = get(ty)) + { + UnionType clone; + clone.options = utv->options; + result = dest.addType(std::move(clone)); + } + else if (const IntersectionType* itv = get(ty)) + { + IntersectionType clone; + clone.parts = itv->parts; + result = dest.addType(std::move(clone)); + } + else if (const PendingExpansionType* petv = get(ty)) + { + PendingExpansionType clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; + result = dest.addType(std::move(clone)); + } + else if (const NegationType* ntv = get(ty)) + { + result = dest.addType(NegationType{ntv->ty}); + } + else + return result; + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + +static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) +{ + if (!FFlag::LuauClonePublicInterfaceLess2) + return DEPRECATED_shallowClone(ty, dest, log, alwaysClone); + + auto go = [ty, &dest, alwaysClone](auto&& a) { + using T = std::decay_t; + + if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + { + // This should never happen, but visit() cannot see it. 
+ LUAU_ASSERT(!"shallowClone didn't follow its argument!"); + return dest.addType(BoundType{a.boundTo}); + } + else if constexpr (std::is_same_v) + return dest.addType(a); + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return ty; + else if constexpr (std::is_same_v) + return dest.addType(a); + else if constexpr (std::is_same_v) + { + FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf}; + clone.generics = a.generics; + clone.genericPacks = a.genericPacks; + clone.magicFunction = a.magicFunction; + clone.dcrMagicFunction = a.dcrMagicFunction; + clone.dcrMagicRefinement = a.dcrMagicRefinement; + clone.tags = a.tags; + clone.argNames = a.argNames; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(!a.boundTo); + TableType clone = TableType{a.props, a.indexer, a.level, a.scope, a.state}; + clone.definitionModuleName = a.definitionModuleName; + clone.definitionLocation = a.definitionLocation; + clone.name = a.name; + clone.syntheticName = a.syntheticName; + clone.instantiatedTypeParams = a.instantiatedTypeParams; + clone.instantiatedTypePackParams = a.instantiatedTypePackParams; + clone.tags = a.tags; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + MetatableType clone = MetatableType{a.table, a.metatable}; + clone.syntheticName = a.syntheticName; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + UnionType clone; + clone.options = a.options; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + IntersectionType clone; + clone.parts = a.parts; + return dest.addType(std::move(clone)); + } + else if constexpr (std::is_same_v) + { + if (alwaysClone) + { + ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName}; + return dest.addType(std::move(clone)); + } + else + return ty; + } + else if constexpr (std::is_same_v) + return dest.addType(NegationType{a.ty}); + else + static_assert(always_false_v, "Non-exhaustive shallowClone switch"); + }; + + ty = log->follow(ty); + + if (auto pty = log->pending(ty)) + ty = &pty->pending; + + TypeId resTy = visit(go, ty->ty); + if (resTy != ty) + asMutable(resTy)->documentationSymbol = ty->documentationSymbol; + + return resTy; +} + void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(ty == log->follow(ty)); @@ -469,7 +644,7 @@ std::optional Substitution::substitute(TypePackId tp) TypeId Substitution::clone(TypeId ty) { - return shallowClone(ty, *arena, log, /* alwaysClone */ FFlag::LuauClonePublicInterfaceLess); + return shallowClone(ty, *arena, log, /* alwaysClone */ FFlag::LuauClonePublicInterfaceLess2); } TypePackId Substitution::clone(TypePackId tp) @@ -494,7 +669,7 @@ TypePackId Substitution::clone(TypePackId tp) clone.hidden = vtp->hidden; return addTypePack(std::move(clone)); } - else if (FFlag::LuauClonePublicInterfaceLess) + else if (FFlag::LuauClonePublicInterfaceLess2) { return addTypePack(*tp); } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index d0c53984..fe09ef11 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ 
-14,7 +14,6 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) /* * Prefix generic typenames with gen- @@ -85,6 +84,11 @@ struct FindCyclicTypes final : TypeVisitor { return false; } + + bool visit(TypeId, const PendingExpansionType&) override + { + return false; + } }; template @@ -364,7 +368,7 @@ struct TypeStringifier state.emit(">"); } - void operator()(TypeId ty, const Unifiable::Free& ftv) + void operator()(TypeId ty, const FreeType& ftv) { state.result.invalid = true; if (FFlag::DebugLuauVerboseTypeNames) @@ -1518,7 +1522,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) } else if constexpr (std::is_same_v) { - return "call " + tos(c.fn) + " with { result = " + tos(c.result) + " }"; + return "call " + tos(c.fn) + "( " + tos(c.argsPack) + " )" + " with { result = " + tos(c.result) + " }"; } else if constexpr (std::is_same_v) { diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 42fa40a5..d70f17f5 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -26,7 +26,6 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAGVARIABLE(LuauMatchReturnsOptionalString, false); namespace Luau { @@ -431,8 +430,71 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } +FreeType::FreeType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ +} + +FreeType::FreeType(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ +} + +FreeType::FreeType(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ +} + +GenericType::GenericType() + : index(Unifiable::freshIndex()) + , name("g" + std::to_string(index)) +{ +} + +GenericType::GenericType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , name("g" + std::to_string(index)) +{ +} + +GenericType::GenericType(const Name& name) + : index(Unifiable::freshIndex()) + , name(name) + , explicitName(true) +{ +} + +GenericType::GenericType(Scope* scope) + : index(Unifiable::freshIndex()) + , scope(scope) +{ +} + +GenericType::GenericType(TypeLevel level, const Name& name) + : index(Unifiable::freshIndex()) + , level(level) + , name(name) + , explicitName(true) +{ +} + +GenericType::GenericType(Scope* scope, const Name& name) + : index(Unifiable::freshIndex()) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + BlockedType::BlockedType() - : index(FFlag::LuauNormalizeBlockedTypes ? Unifiable::freshIndex() : ++DEPRECATED_nextIndex) + : index(FFlag::LuauNormalizeBlockedTypes ? Unifiable::freshIndex() : ++DEPRECATED_nextIndex) { } @@ -972,7 +1034,7 @@ const TypeLevel* getLevel(TypeId ty) { ty = follow(ty); - if (auto ftv = get(ty)) + if (auto ftv = get(ty)) return &ftv->level; else if (auto ttv = get(ty)) return &ttv->level; @@ -991,7 +1053,7 @@ std::optional getLevel(TypePackId tp) { tp = follow(tp); - if (auto ftv = get(tp)) + if (auto ftv = get(tp)) return ftv->level; else return std::nullopt; @@ -1219,12 +1281,12 @@ static std::vector parsePatternString(NotNull builtinTypes if (i + 1 < size && data[i + 1] == ')') { i++; - result.push_back(FFlag::LuauMatchReturnsOptionalString ? 
builtinTypes->optionalNumberType : builtinTypes->numberType); + result.push_back(builtinTypes->optionalNumberType); continue; } ++depth; - result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalStringType : builtinTypes->stringType); + result.push_back(builtinTypes->optionalStringType); } else if (data[i] == ')') { @@ -1242,7 +1304,7 @@ static std::vector parsePatternString(NotNull builtinTypes return std::vector(); if (result.empty()) - result.push_back(FFlag::LuauMatchReturnsOptionalString ? builtinTypes->optionalStringType : builtinTypes->stringType); + result.push_back(builtinTypes->optionalStringType); return result; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f9a16205..d6494edf 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -35,7 +35,21 @@ using SyntheticNames = std::unordered_map; namespace Luau { -static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen) +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const GenericType& gen) +{ + size_t s = syntheticNames->size(); + char*& n = (*syntheticNames)[&gen]; + if (!n) + { + std::string str = gen.explicitName ? gen.name : generateName(s); + n = static_cast(allocator->allocate(str.size() + 1)); + strcpy(n, str.c_str()); + } + + return n; +} + +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const GenericTypePack& gen) { size_t s = syntheticNames->size(); char*& n = (*syntheticNames)[&gen]; @@ -237,7 +251,7 @@ public: size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { - if (auto gtv = get(*it)) + if (auto gtv = get(*it)) genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a160a1d2..c7d30f43 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -568,6 +568,10 @@ struct TypeChecker2 { // nothing } + else if (isOptional(iteratorTy)) + { + reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location); + } else if (std::optional iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) { @@ -973,6 +977,12 @@ struct TypeChecker2 else if (auto utv = get(functionType)) { // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. + // Another scenario we might run into it is if the union has a nil member. 
In this case, we want to throw an error + if (isOptional(functionType)) + { + reportError(OptionalValueAccess{functionType}, call->location); + return; + } std::optional fst; for (TypeId ty : utv) { @@ -1187,6 +1197,8 @@ struct TypeChecker2 else reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); } + else if (get(exprType) && isOptional(exprType)) + reportError(OptionalValueAccess{exprType}, indexExpr->location); } void visit(AstExprFunction* fn) @@ -1297,9 +1309,13 @@ struct TypeChecker2 DenseHashSet seen{nullptr}; int recursionCount = 0; + if (!hasLength(operandType, seen, &recursionCount)) { - reportError(NotATable{operandType}, expr->location); + if (isOptional(operandType)) + reportError(OptionalValueAccess{operandType}, expr->location); + else + reportError(NotATable{operandType}, expr->location); } } else if (expr->op == AstExprUnary::Op::Minus) @@ -2059,12 +2075,12 @@ struct TypeChecker2 fetch(builtinTypes->functionType); else if (!norm.functions.isNever()) { - if (norm.functions.parts->size() == 1) - fetch(norm.functions.parts->front()); + if (norm.functions.parts.size() == 1) + fetch(norm.functions.parts.front()); else { std::vector parts; - parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); + parts.insert(parts.end(), norm.functions.parts.begin(), norm.functions.parts.end()); fetch(testArena.addType(IntersectionType{std::move(parts)})); } } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index abc65286..acf70fec 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -26,7 +26,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTFLAGVARIABLE(LuauDontExtendUnsealedRValueTables, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) @@ -38,7 +37,6 @@ LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) @@ -228,8 +226,8 @@ GlobalTypes::GlobalTypes(NotNull builtinTypes) globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); } -TypeChecker::TypeChecker(const GlobalTypes& globals, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) - : globals(globals) +TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) + : globalScope(globalScope) , resolver(resolver) , builtinTypes(builtinTypes) , iceHandler(iceHandler) @@ -280,7 +278,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.iterationLimit = unifierIterationLimit ? 
*unifierIterationLimit : FInt::LuauTypeInferIterationLimit; - ScopePtr parentScope = environmentScope.value_or(globals.globalScope); + ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); if (module.cyclic) @@ -689,11 +687,10 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std if (duplicateTypeAliases.contains({typealias->exported, name})) continue; - TypeId type = bindings[name].type; - if (get(follow(type))) + TypeId type = follow(bindings[name].type); + if (get(type)) { - Type* mty = asMutable(follow(type)); - mty->reassign(*errorRecoveryType(anyType)); + asMutable(type)->ty.emplace(errorRecoveryType(anyType)); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } @@ -1023,7 +1020,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assig right = errorRecoveryType(scope); else if (auto vtp = get(tailPack)) right = vtp->ty; - else if (get(tailPack)) + else if (get(tailPack)) { *asMutable(tailPack) = TypePack{{left}}; growingPack = getMutable(tailPack); @@ -1284,7 +1281,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) callRetPack = checkExprPack(scope, *exprCall).type; callRetPack = follow(callRetPack); - if (get(callRetPack)) + if (get(callRetPack)) { iterTy = freshType(scope); unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location); @@ -1657,7 +1654,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea } else { - if (globals.globalScope->builtinTypeNames.contains(name)) + if (globalScope->builtinTypeNames.contains(name)) { reportError(typealias.location, DuplicateTypeDefinition{name}); duplicateTypeAliases.insert({typealias.exported, name}); @@ -1954,7 +1951,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return WithPredicate{errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return WithPredicate{vtp->ty}; - else if (get(varargPack)) + else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); @@ -1973,7 +1970,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (const FreeTypePack* ftp = get(retPack)) + else if (const FreeTypePack* ftp = get(retPack)) { TypeId head = freshType(scope->level); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope->level)}}); @@ -1984,7 +1981,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; - else if (get(retPack)) + else if (get(retPack)) { if (FFlag::LuauReturnAnyInsteadOfICE) return {anyType, std::move(result.predicates)}; @@ -2691,7 +2688,7 @@ TypeId TypeChecker::checkRelationalOperation( if (get(lhsType) || get(rhsType)) return booleanType; - if (FFlag::LuauIntersectionTestForEquality && isEquality) + if (isEquality) { // Unless either type is free or any, an equality comparison is only // valid when the intersection of the two operands is non-empty. 
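// Illustrative sketch of the equality rule stated in the comment above, using
// toy standalone types rather than the real TypeId/Normalizer machinery: an
// `a == b` comparison is only accepted when the operand types have a non-empty
// intersection (free or any operands always pass). This is not part of the
// patch above; names and types here are invented for illustration only.
#include <set>
#include <string>

using ToyType = std::set<std::string>; // a type modelled as a set of primitive alternatives

bool equalityComparisonIsValid(const ToyType& lhs, const ToyType& rhs)
{
    // Free or any operands are always comparable; modelled here as an empty set.
    if (lhs.empty() || rhs.empty())
        return true;

    for (const std::string& alt : lhs)
        if (rhs.count(alt))
            return true; // the intersection is inhabited, so the comparison can succeed

    return false; // e.g. {"string"} vs {"number"}: reported as an error
}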
@@ -3262,16 +3259,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { return it->second.type; } - else if (!FFlag::LuauDontExtendUnsealedRValueTables && (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free)) - { - TypeId theType = freshType(scope); - Property& property = lhsTable->props[name]; - property.type = theType; - property.location = expr.indexLocation; - return theType; - } - else if (FFlag::LuauDontExtendUnsealedRValueTables && - ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free)) + else if ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free) { TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; @@ -3392,16 +3380,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { return it->second.type; } - else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) - { - TypeId resultType = freshType(scope); - Property& property = exprTable->props[value->value.data]; - property.type = resultType; - property.location = expr.index->location; - return resultType; - } - else if (FFlag::LuauDontExtendUnsealedRValueTables && - ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) { TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; @@ -3417,14 +3396,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex unify(indexType, indexer.indexType, scope, expr.index->location); return indexer.indexResultType; } - else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) - { - TypeId resultType = freshType(exprTable->level); - exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; - return resultType; - } - else if (FFlag::LuauDontExtendUnsealedRValueTables && - ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) { TypeId indexerType = freshType(exprTable->level); unify(indexType, indexerType, scope, expr.location); @@ -3440,13 +3412,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex * has no indexer, we have no idea if it will work so we just return any * and hope for the best. 
*/ - if (FFlag::LuauDontExtendUnsealedRValueTables) - return anyType; - else - { - TypeId resultType = freshType(scope); - return resultType; - } + return anyType; } } @@ -3872,7 +3838,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (argTail) { - if (state.log.getMutable(state.log.follow(*argTail))) + if (state.log.getMutable(state.log.follow(*argTail))) { if (paramTail) state.tryUnify(*paramTail, *argTail); @@ -3887,7 +3853,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam else if (paramTail) { // argTail is definitely empty - if (state.log.getMutable(state.log.follow(*paramTail))) + if (state.log.getMutable(state.log.follow(*paramTail))) state.log.replace(*paramTail, TypePackVar(TypePack{{}})); } @@ -5604,7 +5570,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st } else { - g = addType(Unifiable::Generic{level, n}); + g = addType(GenericType{level, n}); } generics.push_back({g, defaultValue}); @@ -5632,7 +5598,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); + cached = addTypePack(TypePackVar{GenericTypePack{level, n}}); genericPacks.push_back({cached, defaultValue}); scope->privateTypePackBindings[n] = cached; @@ -5998,7 +5964,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r if (!typeguardP.isTypeof) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - auto typeFun = globals.globalScope->lookupType(typeguardP.kind); + auto typeFun = globalScope->lookupType(typeguardP.kind); if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index ccea604f..6873820a 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -9,6 +9,69 @@ namespace Luau { +FreeTypePack::FreeTypePack(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ +} + +FreeTypePack::FreeTypePack(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ +} + +FreeTypePack::FreeTypePack(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ +} + +GenericTypePack::GenericTypePack() + : index(Unifiable::freshIndex()) + , name("g" + std::to_string(index)) +{ +} + +GenericTypePack::GenericTypePack(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , name("g" + std::to_string(index)) +{ +} + +GenericTypePack::GenericTypePack(const Name& name) + : index(Unifiable::freshIndex()) + , name(name) + , explicitName(true) +{ +} + +GenericTypePack::GenericTypePack(Scope* scope) + : index(Unifiable::freshIndex()) + , scope(scope) +{ +} + +GenericTypePack::GenericTypePack(TypeLevel level, const Name& name) + : index(Unifiable::freshIndex()) + , level(level) + , name(name) + , explicitName(true) +{ +} + +GenericTypePack::GenericTypePack(Scope* scope, const Name& name) + : index(Unifiable::freshIndex()) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + BlockedTypePack::BlockedTypePack() : index(++nextIndex) { @@ -160,8 +223,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId rhsTail = *rhsIter.tail(); { - const Unifiable::Free* lf = 
get_if(&lhsTail->ty); - const Unifiable::Free* rf = get_if(&rhsTail->ty); + const FreeTypePack* lf = get_if(&lhsTail->ty); + const FreeTypePack* rf = get_if(&rhsTail->ty); if (lf && rf) return lf->index == rf->index; } @@ -174,8 +237,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) } { - const Unifiable::Generic* lg = get_if(&lhsTail->ty); - const Unifiable::Generic* rg = get_if(&rhsTail->ty); + const GenericTypePack* lg = get_if(&lhsTail->ty); + const GenericTypePack* rg = get_if(&rhsTail->ty); if (lg && rg) return lg->index == rg->index; } diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 6a9fadfa..310df766 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -331,7 +331,7 @@ TypeId TypeReducer::reduce(TypeId ty) if (edge->irreducible) return edge->type; else - ty = edge->type; + ty = follow(edge->type); } else if (cyclics->contains(ty)) return ty; diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index dcb2d367..2ceb97aa 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -12,71 +12,6 @@ int freshIndex() { return ++nextIndex; } - -Free::Free(TypeLevel level) - : index(++nextIndex) - , level(level) -{ -} - -Free::Free(Scope* scope) - : index(++nextIndex) - , scope(scope) -{ -} - -Free::Free(Scope* scope, TypeLevel level) - : index(++nextIndex) - , level(level) - , scope(scope) -{ -} - -int Free::DEPRECATED_nextIndex = 0; - -Generic::Generic() - : index(++nextIndex) - , name("g" + std::to_string(index)) -{ -} - -Generic::Generic(TypeLevel level) - : index(++nextIndex) - , level(level) - , name("g" + std::to_string(index)) -{ -} - -Generic::Generic(const Name& name) - : index(++nextIndex) - , name(name) - , explicitName(true) -{ -} - -Generic::Generic(Scope* scope) - : index(++nextIndex) - , scope(scope) -{ -} - -Generic::Generic(TypeLevel level, const Name& name) - : index(++nextIndex) - , level(level) - , name(name) - , explicitName(true) -{ -} - -Generic::Generic(Scope* scope, const Name& name) - : index(++nextIndex) - , scope(scope) - , name(name) - , explicitName(true) -{ -} - -int Generic::DEPRECATED_nextIndex = 0; Error::Error() : index(++nextIndex) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 9f30d11b..642aa399 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -21,11 +21,9 @@ LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauTinyUnifyNormalsFix, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAG(LuauNegatedFunctionTypes) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAG(LuauNegatedTableTypes) @@ -192,6 +190,18 @@ struct SkipCacheForType final : TypeOnceVisitor return false; } + bool visit(TypeId, const BlockedType&) override + { + result = true; + return false; + } + + bool visit(TypeId, const PendingExpansionType&) override + { + result = true; + return false; + } + bool visit(TypeId ty, const TableType&) override { // Types from other modules don't contain mutable elements and are ok to cache @@ -259,6 +269,12 @@ struct SkipCacheForType final : TypeOnceVisitor return false; } + bool visit(TypePackId tp, const BlockedTypePack&) override + { + result = true; + return false; + } + const 
DenseHashMap& skipCacheForType; const TypeArena* typeArena = nullptr; bool result = false; @@ -386,6 +402,12 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i tryUnify_(subTy, superTy, isFunctionCall, isIntersection); } +static bool isBlocked(const TxnLog& log, TypeId ty) +{ + ty = log.follow(ty); + return get(ty) || get(ty); +} + void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); @@ -531,11 +553,15 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (log.getMutable(subTy) && log.getMutable(superTy)) + if (isBlocked(log, subTy) && isBlocked(log, superTy)) { blockedTypes.push_back(subTy); blockedTypes.push_back(superTy); } + else if (isBlocked(log, subTy)) + blockedTypes.push_back(subTy); + else if (isBlocked(log, superTy)) + blockedTypes.push_back(superTy); else if (const UnionType* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); @@ -587,8 +613,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); - else if (auto ptv = get(superTy); - FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveType::Function && get(subTy)) + else if (auto ptv = get(superTy); ptv && ptv->type == PrimitiveType::Function && get(subTy)) { // Ok. Do nothing. forall functions F, F <: function } @@ -890,7 +915,8 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp if (!subNorm || !superNorm) return reportError(location, UnificationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + innerState.tryUnifyNormalizedTypes( + subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. 
For example:", *failedOption); else innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); if (!innerState.failure) @@ -1246,17 +1272,7 @@ void Unifier::tryUnifyNormalizedTypes( Unifier innerState = makeChildUnifier(); - if (FFlag::LuauTinyUnifyNormalsFix) - innerState.tryUnify(subTable, superTable); - else - { - if (get(superTable)) - innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); - else if (get(subTable)) - innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); - else - innerState.tryUnifyTables(subTable, superTable); - } + innerState.tryUnify(subTable, superTable); if (innerState.errors.empty()) { @@ -1275,7 +1291,7 @@ void Unifier::tryUnifyNormalizedTypes( { if (superNorm.functions.isNever()) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - for (TypeId superFun : *superNorm.functions.parts) + for (TypeId superFun : superNorm.functions.parts) { Unifier innerState = makeChildUnifier(); const FunctionType* superFtv = get(superFun); @@ -1314,7 +1330,7 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized std::optional result; const FunctionType* firstFun = nullptr; - for (TypeId overload : *overloads.parts) + for (TypeId overload : overloads.parts) { if (const FunctionType* ftv = get(overload)) { @@ -1473,7 +1489,7 @@ struct WeirdIter bool canGrow() const { - return nullptr != log.getMutable(packId); + return nullptr != log.getMutable(packId); } void grow(TypePackId newTail) @@ -1481,7 +1497,7 @@ struct WeirdIter LUAU_ASSERT(canGrow()); LUAU_ASSERT(log.getMutable(newTail)); - auto freePack = log.getMutable(packId); + auto freePack = log.getMutable(packId); level = freePack->level; if (FFlag::LuauMaintainScopesInUnifier && freePack->scope != nullptr) @@ -1575,7 +1591,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (log.haveSeen(superTp, subTp)) return; - if (log.getMutable(superTp)) + if (log.getMutable(superTp)) { if (!occursCheck(superTp, subTp)) { @@ -1583,7 +1599,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal log.replace(superTp, Unifiable::Bound(widen(subTp))); } } - else if (log.getMutable(subTp)) + else if (log.getMutable(subTp)) { if (!occursCheck(subTp, superTp)) { @@ -2551,9 +2567,9 @@ static void queueTypePack(std::vector& queue, DenseHashSet& break; seenTypePacks.insert(a); - if (state.log.getMutable(a)) + if (state.log.getMutable(a)) { - state.log.replace(a, Unifiable::Bound{anyTypePack}); + state.log.replace(a, BoundTypePack{anyTypePack}); } else if (auto tp = state.log.getMutable(a)) { @@ -2601,7 +2617,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever { tryUnify_(vtp->ty, superVariadic->ty); } - else if (get(tail)) + else if (get(tail)) { reportError(location, GenericError{"Cannot unify variadic and generic packs"}); } @@ -2761,10 +2777,10 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays seen.insert(haystack); - if (log.getMutable(needle)) + if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle)) ice("Expected needle to be free"); if (needle == haystack) @@ -2808,10 +2824,10 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ seen.insert(haystack); - if (log.getMutable(needle)) + if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if 
(!log.getMutable(needle)) ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index dac3b95b..75b4fe30 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauFixInterpStringMid, false) - namespace Luau { @@ -642,9 +640,7 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT } consume(); - Lexeme lexemeOutput(Location(start, position()), FFlag::LuauFixInterpStringMid ? formatType : Lexeme::InterpStringBegin, - &buffer[startOffset], offset - startOffset - 1); - return lexemeOutput; + return Lexeme(Location(start, position()), formatType, &buffer[startOffset], offset - startOffset - 1); } default: diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 4fdb0443..6d1f5451 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -14,7 +14,6 @@ #endif LUAU_FASTFLAG(DebugLuauTimeTracing) -LUAU_FASTFLAG(LuauLintInTypecheck) enum class ReportFormat { @@ -81,12 +80,10 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat for (auto& error : cr.errors) reportError(frontend, format, error); - Luau::LintResult lr = FFlag::LuauLintInTypecheck ? cr.lintResult : frontend.lint_DEPRECATED(name); - std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); - for (auto& error : lr.errors) + for (auto& error : cr.lintResult.errors) reportWarning(format, humanReadableName.c_str(), error); - for (auto& warning : lr.warnings) + for (auto& warning : cr.lintResult.warnings) reportWarning(format, humanReadableName.c_str(), warning); if (annotate) @@ -101,7 +98,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat printf("%s", annotated.c_str()); } - return cr.errors.empty() && lr.errors.empty(); + return cr.errors.empty() && cr.lintResult.errors.empty(); } static void displayHelp(const char* argv0) @@ -264,13 +261,13 @@ int main(int argc, char** argv) Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; - frontendOptions.runLintChecks = FFlag::LuauLintInTypecheck; + frontendOptions.runLintChecks = true; CliFileResolver fileResolver; CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); - Luau::registerBuiltinGlobals(frontend.typeChecker, frontend.globals); + Luau::registerBuiltinGlobals(frontend, frontend.globals); Luau::freeze(frontend.globals.globalTypes); #ifdef CALLGRIND diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h index 2c852046..2796ef70 100644 --- a/CodeGen/include/Luau/AddressA64.h +++ b/CodeGen/include/Luau/AddressA64.h @@ -3,6 +3,8 @@ #include "Luau/RegisterA64.h" +#include + namespace Luau { namespace CodeGen @@ -23,6 +25,10 @@ enum class AddressKindA64 : uint8_t struct AddressA64 { + // This is a little misleading since AddressA64 can encode offsets up to 1023*size where size depends on the load/store size + // For example, ldr x0, [reg+imm] is limited to 8 KB offsets assuming imm is divisible by 8, but loading into w0 reduces the range to 4 KB + static constexpr size_t kMaxOffset = 1023; + AddressA64(RegisterA64 base, int off = 0) : kind(AddressKindA64::imm) , base(base) @@ -30,7 +36,6 @@ struct AddressA64 , data(off) { LUAU_ASSERT(base.kind == KindA64::x || base == sp); - LUAU_ASSERT(off >= -256 && off < 4096); } AddressA64(RegisterA64 base, RegisterA64 offset) diff --git 
a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 0179967a..def4d0c0 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -16,32 +16,45 @@ namespace CodeGen namespace A64 { +enum FeaturesA64 +{ + Feature_JSCVT = 1 << 0, +}; + class AssemblyBuilderA64 { public: - explicit AssemblyBuilderA64(bool logText); + explicit AssemblyBuilderA64(bool logText, unsigned int features = 0); ~AssemblyBuilderA64(); // Moves void mov(RegisterA64 dst, RegisterA64 src); - void mov(RegisterA64 dst, uint16_t src, int shift = 0); + void mov(RegisterA64 dst, int src); // macro + + // Moves of 32-bit immediates get decomposed into one or more of these + void movz(RegisterA64 dst, uint16_t src, int shift = 0); + void movn(RegisterA64 dst, uint16_t src, int shift = 0); void movk(RegisterA64 dst, uint16_t src, int shift = 0); // Arithmetics + // TODO: support various kinds of shifts void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); - void add(RegisterA64 dst, RegisterA64 src1, int src2); + void add(RegisterA64 dst, RegisterA64 src1, uint16_t src2); void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); - void sub(RegisterA64 dst, RegisterA64 src1, int src2); + void sub(RegisterA64 dst, RegisterA64 src1, uint16_t src2); void neg(RegisterA64 dst, RegisterA64 src); // Comparisons // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm void cmp(RegisterA64 src1, RegisterA64 src2); - void cmp(RegisterA64 src1, int src2); + void cmp(RegisterA64 src1, uint16_t src2); + void csel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); // Bitwise - // Note: shifted-register support and bitfield operations are omitted for simplicity // TODO: support immediate arguments (they have odd encoding and forbid many values) + // TODO: support bic (andnot) + // TODO: support shifts + // TODO: support bitfield ops void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); @@ -63,14 +76,16 @@ public: void ldrsb(RegisterA64 dst, AddressA64 src); void ldrsh(RegisterA64 dst, AddressA64 src); void ldrsw(RegisterA64 dst, AddressA64 src); + void ldp(RegisterA64 dst1, RegisterA64 dst2, AddressA64 src); // Store void str(RegisterA64 src, AddressA64 dst); void strb(RegisterA64 src, AddressA64 dst); void strh(RegisterA64 src, AddressA64 dst); + void stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst); // Control flow - // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + // TODO: support tbz/tbnz; they have 15-bit offsets but they can be useful in constrained cases void b(Label& label); void b(ConditionA64 cond, Label& label); void cbz(RegisterA64 src, Label& label); @@ -84,6 +99,39 @@ public: void adr(RegisterA64 dst, uint64_t value); void adr(RegisterA64 dst, double value); + // Address of code (label) + void adr(RegisterA64 dst, Label& label); + + // Floating-point scalar moves + void fmov(RegisterA64 dst, RegisterA64 src); + + // Floating-point scalar math + void fabs(RegisterA64 dst, RegisterA64 src); + void fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void fmul(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void fneg(RegisterA64 dst, 
RegisterA64 src); + void fsqrt(RegisterA64 dst, RegisterA64 src); + void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + + // Floating-point rounding and conversions + void frinta(RegisterA64 dst, RegisterA64 src); + void frintm(RegisterA64 dst, RegisterA64 src); + void frintp(RegisterA64 dst, RegisterA64 src); + void fcvtzs(RegisterA64 dst, RegisterA64 src); + void fcvtzu(RegisterA64 dst, RegisterA64 src); + void scvtf(RegisterA64 dst, RegisterA64 src); + void ucvtf(RegisterA64 dst, RegisterA64 src); + + // Floating-point conversion to integer using JS rules (wrap around 2^32) and set Z flag + // note: this is part of ARM8.3 (JSCVT feature); support of this instruction needs to be checked at runtime + void fjcvtzs(RegisterA64 dst, RegisterA64 src); + + // Floating-point comparisons + void fcmp(RegisterA64 src1, RegisterA64 src2); + void fcmpz(RegisterA64 src); + void fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); + // Run final checks bool finalize(); @@ -112,6 +160,10 @@ public: std::string text; const bool logText = false; + const unsigned int features = 0; + + // Maximum immediate argument to functions like add/sub/cmp + static constexpr size_t kMaxImmediate = (1 << 12) - 1; private: // Instruction archetypes @@ -122,11 +174,15 @@ private: void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); - void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size, int sizelog); void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); void placeBR(const char* name, RegisterA64 src, uint32_t op); void placeADR(const char* name, RegisterA64 src, uint8_t op); + void placeADR(const char* name, RegisterA64 src, uint8_t op, Label& label); + void placeP(const char* name, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src, uint8_t op, uint8_t opc, int sizelog); + void placeCS(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc); + void placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc); void place(uint32_t word); @@ -146,9 +202,11 @@ private: LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); LUAU_NOINLINE void log(const char* opcode, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); LUAU_NOINLINE void log(Label label); LUAU_NOINLINE void log(RegisterA64 reg); LUAU_NOINLINE void log(AddressA64 addr); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 17076ed6..467be466 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -41,6 +41,7 @@ enum class ABIX64 
class AssemblyBuilderX64 { public: + explicit AssemblyBuilderX64(bool logText, ABIX64 abi); explicit AssemblyBuilderX64(bool logText); ~AssemblyBuilderX64(); @@ -120,6 +121,7 @@ public: void vcvttsd2si(OperandX64 dst, OperandX64 src); void vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vcvtsd2ss(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode); // inexact diff --git a/CodeGen/include/Luau/ConditionA64.h b/CodeGen/include/Luau/ConditionA64.h index 0beadad5..e94adbcf 100644 --- a/CodeGen/include/Luau/ConditionA64.h +++ b/CodeGen/include/Luau/ConditionA64.h @@ -8,28 +8,45 @@ namespace CodeGen namespace A64 { +// See Table C1-1 on page C1-229 of Arm ARM for A-profile architecture enum class ConditionA64 { + // EQ: integer (equal), floating-point (equal) Equal, + // NE: integer (not equal), floating-point (not equal or unordered) NotEqual, + // CS: integer (carry set), floating-point (greater than, equal or unordered) CarrySet, + // CC: integer (carry clear), floating-point (less than) CarryClear, + // MI: integer (negative), floating-point (less than) Minus, + // PL: integer (positive or zero), floating-point (greater than, equal or unordered) Plus, + // VS: integer (overflow), floating-point (unordered) Overflow, + // VC: integer (no overflow), floating-point (ordered) NoOverflow, + // HI: integer (unsigned higher), floating-point (greater than, or unordered) UnsignedGreater, + // LS: integer (unsigned lower or same), floating-point (less than or equal) UnsignedLessEqual, + // GE: integer (signed greater than or equal), floating-point (greater than or equal) GreaterEqual, + // LT: integer (signed less than), floating-point (less than, or unordered) Less, + + // GT: integer (signed greater than), floating-point (greater than) Greater, + // LE: integer (signed less than or equal), floating-point (less than, equal or unordered) LessEqual, + // AL: always Always, Count diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 5c2bc4df..75b4940a 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -19,6 +19,8 @@ void updateUseCounts(IrFunction& function); void updateLastUseLocations(IrFunction& function); +uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t startInstIdx); + // Returns how many values are coming into the block (live in) and how many are coming out of the block (live out) std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block); uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block); @@ -52,8 +54,8 @@ void computeCfgInfo(IrFunction& function); struct BlockIteratorWrapper { - uint32_t* itBegin = nullptr; - uint32_t* itEnd = nullptr; + const uint32_t* itBegin = nullptr; + const uint32_t* itEnd = nullptr; bool empty() const { @@ -65,19 +67,19 @@ struct BlockIteratorWrapper return size_t(itEnd - itBegin); } - uint32_t* begin() const + const uint32_t* begin() const { return itBegin; } - uint32_t* end() const + const uint32_t* end() const { return itEnd; } }; -BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx); -BlockIteratorWrapper successors(CfgInfo& cfg, uint32_t blockIdx); +BlockIteratorWrapper predecessors(const CfgInfo& cfg, uint32_t blockIdx); +BlockIteratorWrapper successors(const CfgInfo& cfg, uint32_t blockIdx); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrCallWrapperX64.h 
b/CodeGen/include/Luau/IrCallWrapperX64.h new file mode 100644 index 00000000..724d4624 --- /dev/null +++ b/CodeGen/include/Luau/IrCallWrapperX64.h @@ -0,0 +1,79 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrData.h" +#include "Luau/OperandX64.h" +#include "Luau/RegisterX64.h" + +#include + +// TODO: call wrapper can be used to suggest target registers for ScopedRegX64 to compute data into argument registers directly + +namespace Luau +{ +namespace CodeGen +{ +namespace X64 +{ + +struct IrRegAllocX64; +struct ScopedRegX64; + +struct CallArgument +{ + SizeX64 targetSize = SizeX64::none; + + OperandX64 source = noreg; + IrOp sourceOp; + + OperandX64 target = noreg; + bool candidate = true; +}; + +class IrCallWrapperX64 +{ +public: + IrCallWrapperX64(IrRegAllocX64& regs, AssemblyBuilderX64& build, uint32_t instIdx = kInvalidInstIdx); + + void addArgument(SizeX64 targetSize, OperandX64 source, IrOp sourceOp = {}); + void addArgument(SizeX64 targetSize, ScopedRegX64& scopedReg); + + void call(const OperandX64& func); + + IrRegAllocX64& regs; + AssemblyBuilderX64& build; + uint32_t instIdx = ~0u; + +private: + void assignTargetRegisters(); + void countRegisterUses(); + CallArgument* findNonInterferingArgument(); + bool interferesWithOperand(const OperandX64& op, RegisterX64 reg) const; + bool interferesWithActiveSources(const CallArgument& targetArg, int targetArgIndex) const; + bool interferesWithActiveTarget(RegisterX64 sourceReg) const; + void moveToTarget(CallArgument& arg); + void freeSourceRegisters(CallArgument& arg); + void renameRegister(RegisterX64& target, RegisterX64 reg, RegisterX64 replacement); + void renameSourceRegisters(RegisterX64 reg, RegisterX64 replacement); + RegisterX64 findConflictingTarget() const; + void renameConflictingRegister(RegisterX64 conflict); + + int getRegisterUses(RegisterX64 reg) const; + void addRegisterUse(RegisterX64 reg); + void removeRegisterUse(RegisterX64 reg); + + static const int kMaxCallArguments = 6; + std::array args; + int argCount = 0; + + OperandX64 funcOp; + + // Internal counters for remaining register use counts + std::array gprUses; + std::array xmmUses; +}; + +} // namespace X64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 67e70632..fcf29adb 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -62,11 +62,12 @@ enum class IrCmd : uint8_t // Get pointer (LuaNode) to table node element at the active cached slot index // A: pointer (Table) + // B: unsigned int (pcpos) GET_SLOT_NODE_ADDR, // Get pointer (LuaNode) to table node element at the main position of the specified key hash // A: pointer (Table) - // B: unsigned int + // B: unsigned int (hash) GET_HASH_NODE_ADDR, // Store a tag into TValue @@ -89,6 +90,13 @@ enum class IrCmd : uint8_t // B: int STORE_INT, + // Store a vector into TValue + // A: Rn + // B: double (x) + // C: double (y) + // D: double (z) + STORE_VECTOR, + // Store a TValue into memory // A: Rn or pointer (TValue) // B: TValue @@ -125,6 +133,26 @@ enum class IrCmd : uint8_t // A: double UNM_NUM, + // Round number to negative infinity (math.floor) + // A: double + FLOOR_NUM, + + // Round number to positive infinity (math.ceil) + // A: double + CEIL_NUM, + + // Round number to nearest integer number, rounding half-way cases away from zero (math.round) + // A: double + 
ROUND_NUM, + + // Get square root of the argument (math.sqrt) + // A: double + SQRT_NUM, + + // Get absolute value of the argument (math.abs) + // A: double + ABS_NUM, + // Compute Luau 'not' operation on destructured TValue // A: tag // B: double @@ -252,6 +280,7 @@ enum class IrCmd : uint8_t // A: Rn (where to store the result) // B: Rn (lhs) // C: Rn or Kn (rhs) + // D: int (TMS enum with arithmetic type) DO_ARITH, // Get length of a TValue of any type @@ -382,57 +411,44 @@ enum class IrCmd : uint8_t // C: Rn (source start) // D: int (count or -1 to assign values up to stack top) // E: unsigned int (table index to start from) - LOP_SETLIST, + SETLIST, // Call specified function - // A: unsigned int (bytecode instruction index) - // B: Rn (function, followed by arguments) - // C: int (argument count or -1 to use all arguments up to stack top) - // D: int (result count or -1 to preserve all results and adjust stack top) - // Note: return values are placed starting from Rn specified in 'B' - LOP_CALL, + // A: Rn (function, followed by arguments) + // B: int (argument count or -1 to use all arguments up to stack top) + // C: int (result count or -1 to preserve all results and adjust stack top) + // Note: return values are placed starting from Rn specified in 'A' + CALL, // Return specified values from the function - // A: unsigned int (bytecode instruction index) - // B: Rn (value start) - // C: int (result count or -1 to return all values up to stack top) - LOP_RETURN, + // A: Rn (value start) + // B: int (result count or -1 to return all values up to stack top) + RETURN, // Adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue // A: Rn (loop variable start, updates Rn+2 and 'B' number of registers starting from Rn+3) // B: int (loop variable count, if more than 2, registers starting from Rn+5 are set to nil) // C: block (repeat) // D: block (exit) - LOP_FORGLOOP, + FORGLOOP, // Handle LOP_FORGLOOP fallback when variable being iterated is not a table - // A: unsigned int (bytecode instruction index) - // B: Rn (loop state start, updates Rn+2 and 'C' number of registers starting from Rn+3) - // C: int (loop variable count and a MSB set when it's an ipairs-like iteration loop) - // D: block (repeat) - // E: block (exit) - LOP_FORGLOOP_FALLBACK, + // A: Rn (loop state start, updates Rn+2 and 'B' number of registers starting from Rn+3) + // B: int (loop variable count and a MSB set when it's an ipairs-like iteration loop) + // C: block (repeat) + // D: block (exit) + FORGLOOP_FALLBACK, // Fallback for generic for loop preparation when iterating over builtin pairs/ipairs // It raises an error if 'B' register is not a function // A: unsigned int (bytecode instruction index) // B: Rn // C: block (forgloop location) - LOP_FORGPREP_XNEXT_FALLBACK, - - // Perform `and` or `or` operation (selecting lhs or rhs based on whether the lhs is truthy) and put the result into target register - // A: unsigned int (bytecode instruction index) - // B: Rn (target) - // C: Rn (lhs) - // D: Rn or Kn (rhs) - LOP_AND, - LOP_ANDK, - LOP_OR, - LOP_ORK, + FORGPREP_XNEXT_FALLBACK, // Increment coverage data (saturating 24 bit add) // A: unsigned int (bytecode instruction index) - LOP_COVERAGE, + COVERAGE, // Operations that have a translation, but use a full instruction fallback @@ -605,6 +621,17 @@ struct IrOp static_assert(sizeof(IrOp) == 4); +enum class IrValueKind : uint8_t +{ + Unknown, // Used by SUBSTITUTE, argument has to be checked to get type + 
None, + Tag, + Int, + Pointer, + Double, + Tvalue, +}; + struct IrInst { IrCmd cmd; @@ -624,8 +651,12 @@ struct IrInst X64::RegisterX64 regX64 = X64::noreg; A64::RegisterA64 regA64 = A64::noreg; bool reusedReg = false; + bool spilled = false; }; +// When IrInst operands are used, current instruction index is often required to track lifetime +constexpr uint32_t kInvalidInstIdx = ~0u; + enum class IrBlockKind : uint8_t { Bytecode, @@ -679,6 +710,14 @@ struct IrFunction return instructions[op.index]; } + IrInst* asInstOp(IrOp op) + { + if (op.kind == IrOpKind::Inst) + return &instructions[op.index]; + + return nullptr; + } + IrConst& constOp(IrOp op) { LUAU_ASSERT(op.kind == IrOpKind::Constant); @@ -790,19 +829,44 @@ struct IrFunction return value.valueDouble; } - IrCondition conditionOp(IrOp op) - { - LUAU_ASSERT(op.kind == IrOpKind::Condition); - return IrCondition(op.index); - } - uint32_t getBlockIndex(const IrBlock& block) { // Can only be called with blocks from our vector LUAU_ASSERT(&block >= blocks.data() && &block <= blocks.data() + blocks.size()); return uint32_t(&block - blocks.data()); } + + uint32_t getInstIndex(const IrInst& inst) + { + // Can only be called with instructions from our vector + LUAU_ASSERT(&inst >= instructions.data() && &inst <= instructions.data() + instructions.size()); + return uint32_t(&inst - instructions.data()); + } }; +inline IrCondition conditionOp(IrOp op) +{ + LUAU_ASSERT(op.kind == IrOpKind::Condition); + return IrCondition(op.index); +} + +inline int vmRegOp(IrOp op) +{ + LUAU_ASSERT(op.kind == IrOpKind::VmReg); + return op.index; +} + +inline int vmConstOp(IrOp op) +{ + LUAU_ASSERT(op.kind == IrOpKind::VmConst); + return op.index; +} + +inline int vmUpvalueOp(IrOp op) +{ + LUAU_ASSERT(op.kind == IrOpKind::VmUpvalue); + return op.index; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index ae517e89..1bc31d9d 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -19,9 +19,9 @@ const char* getBlockKindName(IrBlockKind kind); struct IrToStringContext { std::string& result; - std::vector& blocks; - std::vector& constants; - CfgInfo& cfg; + const std::vector& blocks; + const std::vector& constants; + const CfgInfo& cfg; }; void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index); @@ -33,13 +33,13 @@ void toString(std::string& result, IrConst constant); void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo); void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo); // Block title -std::string toString(IrFunction& function, bool includeUseInfo); +std::string toString(const IrFunction& function, bool includeUseInfo); -std::string dump(IrFunction& function); +std::string dump(const IrFunction& function); -std::string toDot(IrFunction& function, bool includeInst); +std::string toDot(const IrFunction& function, bool includeInst); -std::string dumpDot(IrFunction& function, bool includeInst); +std::string dumpDot(const IrFunction& function, bool includeInst); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrRegAllocX64.h b/CodeGen/include/Luau/IrRegAllocX64.h new file mode 100644 index 00000000..dc7b48c6 --- /dev/null +++ b/CodeGen/include/Luau/IrRegAllocX64.h @@ -0,0 +1,118 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include 
"Luau/AssemblyBuilderX64.h" +#include "Luau/IrData.h" +#include "Luau/RegisterX64.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ +namespace X64 +{ + +constexpr uint8_t kNoStackSlot = 0xff; + +struct IrSpillX64 +{ + uint32_t instIdx = 0; + bool useDoubleSlot = 0; + + // Spill location can be a stack location or be empty + // When it's empty, it means that instruction value can be rematerialized + uint8_t stackSlot = kNoStackSlot; + + RegisterX64 originalLoc = noreg; +}; + +struct IrRegAllocX64 +{ + IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function); + + RegisterX64 allocGprReg(SizeX64 preferredSize, uint32_t instIdx); + RegisterX64 allocXmmReg(uint32_t instIdx); + + RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t instIdx, std::initializer_list oprefs); + RegisterX64 allocXmmRegOrReuse(uint32_t instIdx, std::initializer_list oprefs); + + RegisterX64 takeReg(RegisterX64 reg, uint32_t instIdx); + + void freeReg(RegisterX64 reg); + void freeLastUseReg(IrInst& target, uint32_t instIdx); + void freeLastUseRegs(const IrInst& inst, uint32_t instIdx); + + bool isLastUseReg(const IrInst& target, uint32_t instIdx) const; + + bool shouldFreeGpr(RegisterX64 reg) const; + + // Register used by instruction is about to be freed, have to find a way to restore value later + void preserve(IrInst& inst); + + void restore(IrInst& inst, bool intoOriginalLocation); + + void preserveAndFreeInstValues(); + + uint32_t findInstructionWithFurthestNextUse(const std::array& regInstUsers) const; + + void assertFree(RegisterX64 reg) const; + void assertAllFree() const; + void assertNoSpills() const; + + AssemblyBuilderX64& build; + IrFunction& function; + + uint32_t currInstIdx = ~0u; + + std::array freeGprMap; + std::array gprInstUsers; + std::array freeXmmMap; + std::array xmmInstUsers; + + std::bitset<256> usedSpillSlots; + unsigned maxUsedSlot = 0; + std::vector spills; +}; + +struct ScopedRegX64 +{ + explicit ScopedRegX64(IrRegAllocX64& owner); + ScopedRegX64(IrRegAllocX64& owner, SizeX64 size); + ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg); + ~ScopedRegX64(); + + ScopedRegX64(const ScopedRegX64&) = delete; + ScopedRegX64& operator=(const ScopedRegX64&) = delete; + + void alloc(SizeX64 size); + void free(); + + RegisterX64 release(); + + IrRegAllocX64& owner; + RegisterX64 reg; +}; + +// When IR instruction makes a call under a condition that's not reflected as a real branch in IR, +// spilled values have to be restored to their exact original locations, so that both after a call +// and after the skip, values are found in the same place +struct ScopedSpills +{ + explicit ScopedSpills(IrRegAllocX64& owner); + ~ScopedSpills(); + + ScopedSpills(const ScopedSpills&) = delete; + ScopedSpills& operator=(const ScopedSpills&) = delete; + + bool wasSpilledBefore(const IrSpillX64& spill) const; + + IrRegAllocX64& owner; + std::vector snapshot; +}; + +} // namespace X64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 0fc14025..09c55c79 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -99,10 +99,10 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_ANY: case IrCmd::JUMP_SLOT_MATCH: - case IrCmd::LOP_RETURN: - case IrCmd::LOP_FORGLOOP: - case IrCmd::LOP_FORGLOOP_FALLBACK: - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + case IrCmd::RETURN: + case IrCmd::FORGLOOP: + case IrCmd::FORGLOOP_FALLBACK: + case 
IrCmd::FORGPREP_XNEXT_FALLBACK: case IrCmd::FALLBACK_FORGPREP: return true; default: @@ -137,6 +137,11 @@ inline bool hasResult(IrCmd cmd) case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: + case IrCmd::FLOOR_NUM: + case IrCmd::CEIL_NUM: + case IrCmd::ROUND_NUM: + case IrCmd::SQRT_NUM: + case IrCmd::ABS_NUM: case IrCmd::NOT_ANY: case IrCmd::TABLE_LEN: case IrCmd::NEW_TABLE: @@ -170,6 +175,8 @@ inline bool isPseudo(IrCmd cmd) return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE; } +IrValueKind getCmdValueKind(IrCmd cmd); + bool isGCO(uint8_t tag); // Manually add or remove use of an operand diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index 519e83fc..99e62958 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -17,6 +17,8 @@ enum class KindA64 : uint8_t none, w, // 32-bit GPR x, // 64-bit GPR + d, // 64-bit SIMD&FP scalar + q, // 128-bit SIMD&FP vector }; struct RegisterA64 @@ -35,6 +37,15 @@ struct RegisterA64 } }; +constexpr RegisterA64 castReg(KindA64 kind, RegisterA64 reg) +{ + LUAU_ASSERT(kind != reg.kind); + LUAU_ASSERT(kind != KindA64::none && reg.kind != KindA64::none); + LUAU_ASSERT((kind == KindA64::w || kind == KindA64::x) == (reg.kind == KindA64::w || reg.kind == KindA64::x)); + + return RegisterA64{kind, reg.index}; +} + constexpr RegisterA64 noreg{KindA64::none, 0}; constexpr RegisterA64 w0{KindA64::w, 0}; @@ -105,6 +116,72 @@ constexpr RegisterA64 xzr{KindA64::x, 31}; constexpr RegisterA64 sp{KindA64::none, 31}; +constexpr RegisterA64 d0{KindA64::d, 0}; +constexpr RegisterA64 d1{KindA64::d, 1}; +constexpr RegisterA64 d2{KindA64::d, 2}; +constexpr RegisterA64 d3{KindA64::d, 3}; +constexpr RegisterA64 d4{KindA64::d, 4}; +constexpr RegisterA64 d5{KindA64::d, 5}; +constexpr RegisterA64 d6{KindA64::d, 6}; +constexpr RegisterA64 d7{KindA64::d, 7}; +constexpr RegisterA64 d8{KindA64::d, 8}; +constexpr RegisterA64 d9{KindA64::d, 9}; +constexpr RegisterA64 d10{KindA64::d, 10}; +constexpr RegisterA64 d11{KindA64::d, 11}; +constexpr RegisterA64 d12{KindA64::d, 12}; +constexpr RegisterA64 d13{KindA64::d, 13}; +constexpr RegisterA64 d14{KindA64::d, 14}; +constexpr RegisterA64 d15{KindA64::d, 15}; +constexpr RegisterA64 d16{KindA64::d, 16}; +constexpr RegisterA64 d17{KindA64::d, 17}; +constexpr RegisterA64 d18{KindA64::d, 18}; +constexpr RegisterA64 d19{KindA64::d, 19}; +constexpr RegisterA64 d20{KindA64::d, 20}; +constexpr RegisterA64 d21{KindA64::d, 21}; +constexpr RegisterA64 d22{KindA64::d, 22}; +constexpr RegisterA64 d23{KindA64::d, 23}; +constexpr RegisterA64 d24{KindA64::d, 24}; +constexpr RegisterA64 d25{KindA64::d, 25}; +constexpr RegisterA64 d26{KindA64::d, 26}; +constexpr RegisterA64 d27{KindA64::d, 27}; +constexpr RegisterA64 d28{KindA64::d, 28}; +constexpr RegisterA64 d29{KindA64::d, 29}; +constexpr RegisterA64 d30{KindA64::d, 30}; +constexpr RegisterA64 d31{KindA64::d, 31}; + +constexpr RegisterA64 q0{KindA64::q, 0}; +constexpr RegisterA64 q1{KindA64::q, 1}; +constexpr RegisterA64 q2{KindA64::q, 2}; +constexpr RegisterA64 q3{KindA64::q, 3}; +constexpr RegisterA64 q4{KindA64::q, 4}; +constexpr RegisterA64 q5{KindA64::q, 5}; +constexpr RegisterA64 q6{KindA64::q, 6}; +constexpr RegisterA64 q7{KindA64::q, 7}; +constexpr RegisterA64 q8{KindA64::q, 8}; +constexpr RegisterA64 q9{KindA64::q, 9}; +constexpr RegisterA64 q10{KindA64::q, 10}; +constexpr RegisterA64 q11{KindA64::q, 11}; +constexpr RegisterA64 q12{KindA64::q, 12}; +constexpr RegisterA64 q13{KindA64::q, 13}; +constexpr RegisterA64 
q14{KindA64::q, 14}; +constexpr RegisterA64 q15{KindA64::q, 15}; +constexpr RegisterA64 q16{KindA64::q, 16}; +constexpr RegisterA64 q17{KindA64::q, 17}; +constexpr RegisterA64 q18{KindA64::q, 18}; +constexpr RegisterA64 q19{KindA64::q, 19}; +constexpr RegisterA64 q20{KindA64::q, 20}; +constexpr RegisterA64 q21{KindA64::q, 21}; +constexpr RegisterA64 q22{KindA64::q, 22}; +constexpr RegisterA64 q23{KindA64::q, 23}; +constexpr RegisterA64 q24{KindA64::q, 24}; +constexpr RegisterA64 q25{KindA64::q, 25}; +constexpr RegisterA64 q26{KindA64::q, 26}; +constexpr RegisterA64 q27{KindA64::q, 27}; +constexpr RegisterA64 q28{KindA64::q, 28}; +constexpr RegisterA64 q29{KindA64::q, 29}; +constexpr RegisterA64 q30{KindA64::q, 30}; +constexpr RegisterA64 q31{KindA64::q, 31}; + } // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index 308747d2..a80003e9 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -21,8 +21,9 @@ static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(C const unsigned kMaxAlign = 32; -AssemblyBuilderA64::AssemblyBuilderA64(bool logText) +AssemblyBuilderA64::AssemblyBuilderA64(bool logText, unsigned int features) : logText(logText) + , features(features) { data.resize(4096); dataPos = data.size(); // data is filled backwards @@ -39,15 +40,39 @@ AssemblyBuilderA64::~AssemblyBuilderA64() void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src) { + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp); + LUAU_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x)); + if (dst == sp || src == sp) placeR1("mov", dst, src, 0b00'100010'0'000000000000); else placeSR2("mov", dst, src, 0b01'01010); } -void AssemblyBuilderA64::mov(RegisterA64 dst, uint16_t src, int shift) +void AssemblyBuilderA64::mov(RegisterA64 dst, int src) { - placeI16("mov", dst, src, 0b10'100101, shift); + if (src >= 0) + { + movz(dst, src & 0xffff); + if (src > 0xffff) + movk(dst, src >> 16, 16); + } + else + { + movn(dst, ~src & 0xffff); + if (src < -0x10000) + movk(dst, (src >> 16) & 0xffff, 16); + } +} + +void AssemblyBuilderA64::movz(RegisterA64 dst, uint16_t src, int shift) +{ + placeI16("movz", dst, src, 0b10'100101, shift); +} + +void AssemblyBuilderA64::movn(RegisterA64 dst, uint16_t src, int shift) +{ + placeI16("movn", dst, src, 0b00'100101, shift); } void AssemblyBuilderA64::movk(RegisterA64 dst, uint16_t src, int shift) @@ -60,7 +85,7 @@ void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2 placeSR3("add", dst, src1, src2, 0b00'01011, shift); } -void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, int src2) +void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, uint16_t src2) { placeI12("add", dst, src1, src2, 0b00'10001); } @@ -70,7 +95,7 @@ void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2 placeSR3("sub", dst, src1, src2, 0b10'01011, shift); } -void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, int src2) +void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, uint16_t src2) { placeI12("sub", dst, src1, src2, 0b10'10001); } @@ -87,13 +112,20 @@ void AssemblyBuilderA64::cmp(RegisterA64 src1, RegisterA64 src2) placeSR3("cmp", dst, src1, src2, 0b11'01011); } -void AssemblyBuilderA64::cmp(RegisterA64 src1, int src2) +void AssemblyBuilderA64::cmp(RegisterA64 src1, uint16_t src2) { 
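// Editorial note: as with the register form above, the immediate compare is encoded as a
// subtract-that-sets-flags (SUBS) whose result is discarded into the zero register, which is
// why the body below picks xzr or wzr to match the width of src1.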
RegisterA64 dst = src1.kind == KindA64::x ? xzr : wzr; placeI12("cmp", dst, src1, src2, 0b11'10001); } +void AssemblyBuilderA64::csel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond) +{ + LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w); + + placeCS("csel", dst, src1, src2, cond, 0b11010'10'0, 0b00); +} + void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { placeSR3("and", dst, src1, src2, 0b00'01010); @@ -136,75 +168,129 @@ void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2 void AssemblyBuilderA64::clz(RegisterA64 dst, RegisterA64 src) { + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(dst.kind == src.kind); + placeR1("clz", dst, src, 0b10'11010110'00000'00010'0); } void AssemblyBuilderA64::rbit(RegisterA64 dst, RegisterA64 src) { + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(dst.kind == src.kind); + placeR1("rbit", dst, src, 0b10'11010110'00000'0000'00); } void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src) { - LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w); + LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w || dst.kind == KindA64::d || dst.kind == KindA64::q); - placeA("ldr", dst, src, 0b11100001, 0b10 | uint8_t(dst.kind == KindA64::x)); + switch (dst.kind) + { + case KindA64::w: + placeA("ldr", dst, src, 0b11100001, 0b10, 2); + break; + case KindA64::x: + placeA("ldr", dst, src, 0b11100001, 0b11, 3); + break; + case KindA64::d: + placeA("ldr", dst, src, 0b11110001, 0b11, 3); + break; + case KindA64::q: + placeA("ldr", dst, src, 0b11110011, 0b00, 4); + break; + case KindA64::none: + LUAU_ASSERT(!"Unexpected register kind"); + } } void AssemblyBuilderA64::ldrb(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::w); - placeA("ldrb", dst, src, 0b11100001, 0b00); + placeA("ldrb", dst, src, 0b11100001, 0b00, 2); } void AssemblyBuilderA64::ldrh(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::w); - placeA("ldrh", dst, src, 0b11100001, 0b01); + placeA("ldrh", dst, src, 0b11100001, 0b01, 2); } void AssemblyBuilderA64::ldrsb(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w); - placeA("ldrsb", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b00); + placeA("ldrsb", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b00, 0); } void AssemblyBuilderA64::ldrsh(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w); - placeA("ldrsh", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b01); + placeA("ldrsh", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b01, 1); } void AssemblyBuilderA64::ldrsw(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::x); - placeA("ldrsw", dst, src, 0b11100010, 0b10); + placeA("ldrsw", dst, src, 0b11100010, 0b10, 2); +} + +void AssemblyBuilderA64::ldp(RegisterA64 dst1, RegisterA64 dst2, AddressA64 src) +{ + LUAU_ASSERT(dst1.kind == KindA64::x || dst1.kind == KindA64::w); + LUAU_ASSERT(dst1.kind == dst2.kind); + + placeP("ldp", dst1, dst2, src, 0b101'0'010'1, uint8_t(dst1.kind == KindA64::x) << 1, dst1.kind == KindA64::x ? 
3 : 2); } void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst) { - LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w); + LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w || src.kind == KindA64::d || src.kind == KindA64::q); - placeA("str", src, dst, 0b11100000, 0b10 | uint8_t(src.kind == KindA64::x)); + switch (src.kind) + { + case KindA64::w: + placeA("str", src, dst, 0b11100000, 0b10, 2); + break; + case KindA64::x: + placeA("str", src, dst, 0b11100000, 0b11, 3); + break; + case KindA64::d: + placeA("str", src, dst, 0b11110000, 0b11, 3); + break; + case KindA64::q: + placeA("str", src, dst, 0b11110010, 0b00, 4); + break; + case KindA64::none: + LUAU_ASSERT(!"Unexpected register kind"); + } } void AssemblyBuilderA64::strb(RegisterA64 src, AddressA64 dst) { LUAU_ASSERT(src.kind == KindA64::w); - placeA("strb", src, dst, 0b11100000, 0b00); + placeA("strb", src, dst, 0b11100000, 0b00, 2); } void AssemblyBuilderA64::strh(RegisterA64 src, AddressA64 dst) { LUAU_ASSERT(src.kind == KindA64::w); - placeA("strh", src, dst, 0b11100000, 0b01); + placeA("strh", src, dst, 0b11100000, 0b01, 2); +} + +void AssemblyBuilderA64::stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst) +{ + LUAU_ASSERT(src1.kind == KindA64::x || src1.kind == KindA64::w); + LUAU_ASSERT(src1.kind == src2.kind); + + placeP("stp", src1, src2, dst, 0b101'0'010'0, uint8_t(src1.kind == KindA64::x) << 1, src1.kind == KindA64::x ? 3 : 2); } void AssemblyBuilderA64::b(Label& label) @@ -276,6 +362,150 @@ void AssemblyBuilderA64::adr(RegisterA64 dst, double value) patchImm19(location, -int(location) - int((data.size() - pos) / 4)); } +void AssemblyBuilderA64::adr(RegisterA64 dst, Label& label) +{ + placeADR("adr", dst, 0b10000, label); +} + +void AssemblyBuilderA64::fmov(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("fmov", dst, src, 0b000'11110'01'1'0000'00'10000); +} + +void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000); +} + +void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeR3("fadd", dst, src1, src2, 0b11110'01'1, 0b0010'10); +} + +void AssemblyBuilderA64::fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeR3("fdiv", dst, src1, src2, 0b11110'01'1, 0b0001'10); +} + +void AssemblyBuilderA64::fmul(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeR3("fmul", dst, src1, src2, 0b11110'01'1, 0b0000'10); +} + +void AssemblyBuilderA64::fneg(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("fneg", dst, src, 0b000'11110'01'1'0000'10'10000); +} + +void AssemblyBuilderA64::fsqrt(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("fsqrt", dst, src, 0b000'11110'01'1'0000'11'10000); +} + +void AssemblyBuilderA64::fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeR3("fsub", dst, src1, src2, 0b11110'01'1, 0b0011'10); +} + 
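// --- Editorial aside (not part of this patch) ---------------------------------------------
// A minimal sketch of how the scalar floating-point helpers and the 32-bit immediate `mov`
// macro added by this change might be driven together. Register choices and offsets are
// illustrative only; the instruction methods and AddressA64 come from this change and the
// existing builder API.
//
//   AssemblyBuilderA64 build(/* logText= */ true);
//   build.mov(w0, 0x12345);            // lowered to movz w0,#0x2345 plus movk w0,#0x1,LSL #16
//   build.ldr(d0, AddressA64(x1, 0));  // ldr/str now accept d (and q) registers
//   build.frintm(d0, d0);              // round towards -infinity, i.e. math.floor rounding
//   build.fadd(d0, d0, d1);
//   build.fcmpz(d0);                   // fcmp d0, #0.0: sets flags for a later b.cond/fcsel
//   build.str(d0, AddressA64(x1, 8));
//   build.finalize();
//
// Note that fjcvtzs asserts Feature_JSCVT, so a builder that emits it must be constructed
// with that bit set in `features` (see getCpuFeaturesA64 in CodeGen.cpp further below).
// -------------------------------------------------------------------------------------------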
+void AssemblyBuilderA64::frinta(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("frinta", dst, src, 0b000'11110'01'1'001'100'10000); +} + +void AssemblyBuilderA64::frintm(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("frintm", dst, src, 0b000'11110'01'1'001'010'10000); +} + +void AssemblyBuilderA64::frintp(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d && src.kind == KindA64::d); + + placeR1("frintp", dst, src, 0b000'11110'01'1'001'001'10000); +} + +void AssemblyBuilderA64::fcvtzs(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(src.kind == KindA64::d); + + placeR1("fcvtzs", dst, src, 0b000'11110'01'1'11'000'000000); +} + +void AssemblyBuilderA64::fcvtzu(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(src.kind == KindA64::d); + + placeR1("fcvtzu", dst, src, 0b000'11110'01'1'11'001'000000); +} + +void AssemblyBuilderA64::scvtf(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d); + LUAU_ASSERT(src.kind == KindA64::w || src.kind == KindA64::x); + + placeR1("scvtf", dst, src, 0b000'11110'01'1'00'010'000000); +} + +void AssemblyBuilderA64::ucvtf(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::d); + LUAU_ASSERT(src.kind == KindA64::w || src.kind == KindA64::x); + + placeR1("ucvtf", dst, src, 0b000'11110'01'1'00'011'000000); +} + +void AssemblyBuilderA64::fjcvtzs(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(dst.kind == KindA64::w); + LUAU_ASSERT(src.kind == KindA64::d); + LUAU_ASSERT(features & Feature_JSCVT); + + placeR1("fjcvtzs", dst, src, 0b000'11110'01'1'11'110'000000); +} + +void AssemblyBuilderA64::fcmp(RegisterA64 src1, RegisterA64 src2) +{ + LUAU_ASSERT(src1.kind == KindA64::d && src2.kind == KindA64::d); + + placeFCMP("fcmp", src1, src2, 0b11110'01'1, 0b00); +} + +void AssemblyBuilderA64::fcmpz(RegisterA64 src) +{ + LUAU_ASSERT(src.kind == KindA64::d); + + placeFCMP("fcmp", src, RegisterA64{src.kind, 0}, 0b11110'01'1, 0b01); +} + +void AssemblyBuilderA64::fcsel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond) +{ + LUAU_ASSERT(dst.kind == KindA64::d); + + placeCS("fcsel", dst, src1, src2, cond, 0b11110'01'1, 0b11); +} + bool AssemblyBuilderA64::finalize() { code.resize(codePos - code.data()); @@ -387,7 +617,7 @@ void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64 if (logText) log(name, dst, src1, src2); - LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst.kind == KindA64::d); LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind); uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; @@ -401,10 +631,7 @@ void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 if (logText) log(name, dst, src); - LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp); - LUAU_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x)); - - uint32_t sf = (dst.kind != KindA64::w) ? 0x80000000 : 0; + uint32_t sf = (dst.kind == KindA64::x || src.kind == KindA64::x) ? 
0x80000000 : 0; place(dst.index | (src.index << 5) | (op << 10) | sf); commit(); @@ -440,7 +667,7 @@ void AssemblyBuilderA64::placeI16(const char* name, RegisterA64 dst, int src, ui commit(); } -void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size) +void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size, int sizelog) { if (logText) log(name, dst, src); @@ -448,8 +675,8 @@ void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 sr switch (src.kind) { case AddressKindA64::imm: - if (src.data >= 0 && src.data % (1 << size) == 0) - place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30)); + if (src.data >= 0 && (src.data >> sizelog) < 1024 && (src.data & ((1 << sizelog) - 1)) == 0) + place(dst.index | (src.base.index << 5) | ((src.data >> sizelog) << 10) | (op << 22) | (1 << 24) | (size << 30)); else if (src.data >= -256 && src.data <= 255) place(dst.index | (src.base.index << 5) | ((src.data & ((1 << 9) - 1)) << 12) | (op << 22) | (size << 30)); else @@ -511,6 +738,61 @@ void AssemblyBuilderA64::placeADR(const char* name, RegisterA64 dst, uint8_t op) commit(); } +void AssemblyBuilderA64::placeADR(const char* name, RegisterA64 dst, uint8_t op, Label& label) +{ + LUAU_ASSERT(dst.kind == KindA64::x); + + place(dst.index | (op << 24)); + commit(); + + patchLabel(label); + + if (logText) + log(name, dst, label); +} + +void AssemblyBuilderA64::placeP(const char* name, RegisterA64 src1, RegisterA64 src2, AddressA64 dst, uint8_t op, uint8_t opc, int sizelog) +{ + if (logText) + log(name, src1, src2, dst); + + LUAU_ASSERT(dst.kind == AddressKindA64::imm); + LUAU_ASSERT(dst.data >= -128 * (1 << sizelog) && dst.data <= 127 * (1 << sizelog)); + LUAU_ASSERT(dst.data % (1 << sizelog) == 0); + + place(src1.index | (dst.base.index << 5) | (src2.index << 10) | (((dst.data >> sizelog) & 127) << 15) | (op << 22) | (opc << 30)); + commit(); +} + +void AssemblyBuilderA64::placeCS(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc) +{ + if (logText) + log(name, dst, src1, src2, cond); + + LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind); + + uint32_t sf = (dst.kind == KindA64::x) ? 
0x80000000 : 0; + + place(dst.index | (src1.index << 5) | (opc << 10) | (codeForCondition[int(cond)] << 12) | (src2.index << 16) | (op << 21) | sf); + commit(); +} + +void AssemblyBuilderA64::placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc) +{ + if (logText) + { + if (opc) + log(name, src1, 0); + else + log(name, src1, src2); + } + + LUAU_ASSERT(src1.kind == src2.kind); + + place((opc << 3) | (src1.index << 5) | (0b1000 << 10) | (src2.index << 16) | (op << 21)); + commit(); +} + void AssemblyBuilderA64::place(uint32_t word) { LUAU_ASSERT(codePos < codeEnd); @@ -628,6 +910,17 @@ void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, AddressA64 src text.append("\n"); } +void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst1, RegisterA64 dst2, AddressA64 src) +{ + logAppend(" %-12s", opcode); + log(dst1); + text.append(","); + log(dst2); + text.append(","); + log(src); + text.append("\n"); +} + void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src) { logAppend(" %-12s", opcode); @@ -668,6 +961,19 @@ void AssemblyBuilderA64::log(const char* opcode, Label label) logAppend(" %-12s.L%d\n", opcode, label.id); } +void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond) +{ + logAppend(" %-12s", opcode); + log(dst); + text.append(","); + log(src1); + text.append(","); + log(src2); + text.append(","); + text.append(textForCondition[int(cond)] + 2); // skip b. + text.append("\n"); +} + void AssemblyBuilderA64::log(Label label) { logAppend(".L%d:\n", label.id); @@ -691,6 +997,14 @@ void AssemblyBuilderA64::log(RegisterA64 reg) logAppend("x%d", reg.index); break; + case KindA64::d: + logAppend("d%d", reg.index); + break; + + case KindA64::q: + logAppend("q%d", reg.index); + break; + case KindA64::none: if (reg.index == 31) text.append("sp"); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index bf7889b8..d86a37c6 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -71,9 +71,9 @@ static ABIX64 getCurrentX64ABI() #endif } -AssemblyBuilderX64::AssemblyBuilderX64(bool logText) +AssemblyBuilderX64::AssemblyBuilderX64(bool logText, ABIX64 abi) : logText(logText) - , abi(getCurrentX64ABI()) + , abi(abi) { data.resize(4096); dataPos = data.size(); // data is filled backwards @@ -83,6 +83,11 @@ AssemblyBuilderX64::AssemblyBuilderX64(bool logText) codeEnd = code.data() + code.size(); } +AssemblyBuilderX64::AssemblyBuilderX64(bool logText) + : AssemblyBuilderX64(logText, getCurrentX64ABI()) +{ +} + AssemblyBuilderX64::~AssemblyBuilderX64() { LUAU_ASSERT(finalized); @@ -671,6 +676,16 @@ void AssemblyBuilderX64::vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 s placeAvx("vcvtsi2sd", dst, src1, src2, 0x2a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2); } +void AssemblyBuilderX64::vcvtsd2ss(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + if (src2.cat == CategoryX64::reg) + LUAU_ASSERT(src2.base.size == SizeX64::xmmword); + else + LUAU_ASSERT(src2.memSize == SizeX64::qword); + + placeAvx("vcvtsd2ss", dst, src1, src2, 0x5a, (src2.cat == CategoryX64::reg ? 
src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2); +} + void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode) { placeAvx("vroundsd", dst, src1, src2, uint8_t(roundingMode) | kRoundingPrecisionInexact, 0x0b, false, AVX_0F3A, AVX_66); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index ce490f91..8e6e9493 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -6,6 +6,8 @@ #include "Luau/CodeBlockUnwind.h" #include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" +#include "Luau/IrDump.h" +#include "Luau/IrUtils.h" #include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeFinalX64.h" @@ -13,19 +15,24 @@ #include "Luau/UnwindBuilderDwarf2.h" #include "Luau/UnwindBuilderWin.h" -#include "Luau/AssemblyBuilderX64.h" #include "Luau/AssemblyBuilderA64.h" +#include "Luau/AssemblyBuilderX64.h" #include "CustomExecUtils.h" -#include "CodeGenX64.h" +#include "NativeState.h" + #include "CodeGenA64.h" +#include "EmitCommonA64.h" +#include "IrLoweringA64.h" + +#include "CodeGenX64.h" #include "EmitCommonX64.h" #include "EmitInstructionX64.h" #include "IrLoweringX64.h" -#include "NativeState.h" #include "lapi.h" +#include #include #if defined(__x86_64__) || defined(_M_X64) @@ -36,6 +43,12 @@ #endif #endif +#if defined(__aarch64__) +#ifdef __APPLE__ +#include +#endif +#endif + LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) namespace Luau @@ -60,7 +73,154 @@ static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) return result; } -[[maybe_unused]] static void lowerIr( +template +static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) +{ + // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined + std::vector sortedBlocks; + sortedBlocks.reserve(function.blocks.size()); + for (uint32_t i = 0; i < function.blocks.size(); i++) + sortedBlocks.push_back(i); + + std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { + const IrBlock& a = function.blocks[idxA]; + const IrBlock& b = function.blocks[idxB]; + + // Place fallback blocks at the end + if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) + return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); + + // Try to order by instruction order + return a.start < b.start; + }); + + DenseHashMap bcLocations{~0u}; + + // Create keys for IR assembly locations that original bytecode instruction are interested in + for (const auto& [irLocation, asmLocation] : function.bcMapping) + { + if (irLocation != ~0u) + bcLocations[irLocation] = 0; + } + + DenseHashMap indexIrToBc{~0u}; + bool outputEnabled = options.includeAssembly || options.includeIr; + + if (outputEnabled && options.annotator) + { + // Create reverse mapping from IR location to bytecode location + for (size_t i = 0; i < function.bcMapping.size(); ++i) + { + uint32_t irLocation = function.bcMapping[i].irLocation; + + if (irLocation != ~0u) + indexIrToBc[irLocation] = uint32_t(i); + } + } + + IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; + + // We use this to skip outlined fallback blocks from IR/asm text output + size_t textSize = build.text.length(); + uint32_t codeSize = build.getCodeSize(); + bool seenFallback = false; + + IrBlock dummy; + dummy.start = ~0u; + + for (size_t i = 0; i < sortedBlocks.size(); ++i) + { + 
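// Editorial note: the loop below lowers blocks in the order established by sortedBlocks;
// dead blocks are skipped, the first fallback block records where outlined code begins so the
// text output can be trimmed afterwards, and each instruction is lowered in turn with an
// early return if the target lowering reports an error.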
uint32_t blockIndex = sortedBlocks[i]; + + IrBlock& block = function.blocks[blockIndex]; + + if (block.kind == IrBlockKind::Dead) + continue; + + LUAU_ASSERT(block.start != ~0u); + LUAU_ASSERT(block.finish != ~0u); + + // If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them + if (block.kind == IrBlockKind::Fallback && !seenFallback) + { + textSize = build.text.length(); + codeSize = build.getCodeSize(); + seenFallback = true; + } + + if (options.includeIr) + { + build.logAppend("# "); + toStringDetailed(ctx, block, blockIndex, /* includeUseInfo */ true); + } + + build.setLabel(block.label); + + for (uint32_t index = block.start; index <= block.finish; index++) + { + LUAU_ASSERT(index < function.instructions.size()); + + // If IR instruction is the first one for the original bytecode, we can annotate it with source code text + if (outputEnabled && options.annotator) + { + if (uint32_t* bcIndex = indexIrToBc.find(index)) + options.annotator(options.annotatorContext, build.text, bytecodeid, *bcIndex); + } + + // If bytecode needs the location of this instruction for jumps, record it + if (uint32_t* bcLocation = bcLocations.find(index)) + { + Label label = (index == block.start) ? block.label : build.setLabel(); + *bcLocation = build.getLabelOffset(label); + } + + IrInst& inst = function.instructions[index]; + + // Skip pseudo instructions, but make sure they are not used at this stage + // This also prevents them from getting into text output when that's enabled + if (isPseudo(inst.cmd)) + { + LUAU_ASSERT(inst.useCount == 0); + continue; + } + + if (options.includeIr) + { + build.logAppend("# "); + toStringDetailed(ctx, inst, index, /* includeUseInfo */ true); + } + + IrBlock& next = i + 1 < sortedBlocks.size() ? 
function.blocks[sortedBlocks[i + 1]] : dummy; + + lowering.lowerInst(inst, index, next); + + if (lowering.hasError()) + return false; + } + + if (options.includeIr) + build.logAppend("#\n"); + } + + if (outputEnabled && !options.includeOutlinedCode && seenFallback) + { + build.text.resize(textSize); + + if (options.includeAssembly) + build.logAppend("; skipping %u bytes of outlined code\n", unsigned((build.getCodeSize() - codeSize) * sizeof(build.code[0]))); + } + + // Copy assembly locations of IR instructions that are mapped to bytecode instructions + for (auto& [irLocation, asmLocation] : function.bcMapping) + { + if (irLocation != ~0u) + asmLocation = bcLocations[irLocation]; + } + + return true; +} + +[[maybe_unused]] static bool lowerIr( X64::AssemblyBuilderX64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { constexpr uint32_t kFunctionAlignment = 32; @@ -69,24 +229,20 @@ static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); - X64::IrLoweringX64 lowering(build, helpers, data, proto, ir.function); + X64::IrLoweringX64 lowering(build, helpers, data, ir.function); - lowering.lower(options); + return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } -[[maybe_unused]] static void lowerIr( +[[maybe_unused]] static bool lowerIr( A64::AssemblyBuilderA64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { - Label start = build.setLabel(); + if (!A64::IrLoweringA64::canLower(ir.function)) + return false; - build.mov(A64::x0, 1); // finish function in VM - build.ret(); + A64::IrLoweringA64 lowering(build, helpers, data, proto, ir.function); - // TODO: This is only needed while we don't support all IR opcodes - // When we can't translate some parts of the function, we instead encode a dummy assembly sequence that hands off control to VM - // In the future we could return nullptr from assembleFunction and handle it because there may be other reasons for why we refuse to assemble. 
- for (int i = 0; i < proto->sizecode; i++) - ir.function.bcMapping[i].asmLocation = build.getLabelOffset(start); + return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } template @@ -123,16 +279,20 @@ static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, IrBuilder ir; ir.buildFunctionIr(proto); + computeCfgInfo(ir.function); + if (!FFlag::DebugCodegenNoOpt) { constPropInBlockChains(ir); } - // TODO: cfg info has to be computed earlier to use in optimizations - // It's done here to appear in text output and to measure performance impact on code generation - computeCfgInfo(ir.function); + if (!lowerIr(build, ir, data, helpers, proto, options)) + { + if (build.logText) + build.logAppend("; skipping (can't lower)\n\n"); - lowerIr(build, ir, data, helpers, proto, options); + return nullptr; + } if (build.logText) build.logAppend("\n"); @@ -188,6 +348,22 @@ static void onSetBreakpoint(lua_State* L, Proto* proto, int instruction) LUAU_ASSERT(!"native breakpoints are not implemented"); } +#if defined(__aarch64__) +static unsigned int getCpuFeaturesA64() +{ + unsigned int result = 0; + +#ifdef __APPLE__ + int jscvt = 0; + size_t jscvtLen = sizeof(jscvt); + if (sysctlbyname("hw.optional.arm.FEAT_JSCVT", &jscvt, &jscvtLen, nullptr, 0) == 0 && jscvt == 1) + result |= A64::Feature_JSCVT; +#endif + + return result; +} +#endif + bool isSupported() { #if !LUA_CUSTOM_EXECUTION @@ -217,6 +393,19 @@ bool isSupported() return true; #elif defined(__aarch64__) + if (LUA_EXTRA_SIZE != 1) + return false; + + if (sizeof(TValue) != 16) + return false; + + if (sizeof(LuaNode) != 32) + return false; + + // TODO: A64 codegen does not generate correct unwind info at the moment so it requires longjmp instead of C++ exceptions + if (!LUA_USE_LONGJMP) + return false; + return true; #else return false; @@ -289,7 +478,7 @@ void compile(lua_State* L, int idx) return; #if defined(__aarch64__) - A64::AssemblyBuilderA64 build(/* logText= */ false); + A64::AssemblyBuilderA64 build(/* logText= */ false, getCpuFeaturesA64()); #else X64::AssemblyBuilderX64 build(/* logText= */ false); #endif @@ -300,7 +489,9 @@ void compile(lua_State* L, int idx) gatherFunctions(protos, clvalue(func)->l.p); ModuleHelpers helpers; -#if !defined(__aarch64__) +#if defined(__aarch64__) + A64::assembleHelpers(build, helpers); +#else X64::assembleHelpers(build, helpers); #endif @@ -310,10 +501,15 @@ void compile(lua_State* L, int idx) // Skip protos that have been compiled during previous invocations of CodeGen::compile for (Proto* p : protos) if (p && getProtoExecData(p) == nullptr) - results.push_back(assembleFunction(build, *data, helpers, p, {})); + if (NativeProto* np = assembleFunction(build, *data, helpers, p, {})) + results.push_back(np); build.finalize(); + // If no functions were assembled, we don't need to allocate/copy executable pages for helpers + if (results.empty()) + return; + uint8_t* nativeData = nullptr; size_t sizeNativeData = 0; uint8_t* codeStart = nullptr; @@ -347,7 +543,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) const TValue* func = luaA_toobject(L, idx); #if defined(__aarch64__) - A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly); + A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, getCpuFeaturesA64()); #else X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); #endif @@ -359,22 +555,21 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) gatherFunctions(protos, 
clvalue(func)->l.p); ModuleHelpers helpers; -#if !defined(__aarch64__) +#if defined(__aarch64__) + A64::assembleHelpers(build, helpers); +#else X64::assembleHelpers(build, helpers); #endif for (Proto* p : protos) if (p) - { - NativeProto* nativeProto = assembleFunction(build, data, helpers, p, options); - destroyNativeProto(nativeProto); - } + if (NativeProto* np = assembleFunction(build, data, helpers, p, options)) + destroyNativeProto(np); build.finalize(); if (options.outputBinary) - return std::string( - reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + + return std::string(reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + std::string(build.data.begin(), build.data.end()); else return build.text; diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 94d6f2e3..e7a1e2e2 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -6,6 +6,7 @@ #include "CustomExecUtils.h" #include "NativeState.h" +#include "EmitCommonA64.h" #include "lstate.h" @@ -21,26 +22,50 @@ bool initEntryFunction(NativeState& data) AssemblyBuilderA64 build(/* logText= */ false); UnwindBuilder& unwind = *data.unwindBuilder.get(); - unwind.start(); - unwind.allocStack(8); // TODO: this is only necessary to align stack by 16 bytes, as start() allocates 8b return pointer + // Arguments: x0 = lua_State*, x1 = Proto*, x2 = native code pointer to jump to, x3 = NativeContext* - // TODO: prologue goes here + unwind.start(); + unwind.allocStack(8); // TODO: this is just a hack to make UnwindBuilder assertions cooperate + + // prologue + build.sub(sp, sp, kStackSize); + build.stp(x29, x30, mem(sp)); // fp, lr + + // stash non-volatile registers used for execution environment + build.stp(x19, x20, mem(sp, 16)); + build.stp(x21, x22, mem(sp, 32)); + build.stp(x23, x24, mem(sp, 48)); + + build.mov(x29, sp); // this is only necessary if we maintain frame pointers, which we do in the JIT for now unwind.finish(); size_t prologueSize = build.setLabel().location; // Setup native execution environment - // TODO: figure out state layout + build.mov(rState, x0); + build.mov(rNativeContext, x3); - // Jump to the specified instruction; further control flow will be handled with custom ABI with register setup from EmitCommonX64.h + build.ldr(rBase, mem(x0, offsetof(lua_State, base))); // L->base + build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k + build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code + + build.ldr(x9, mem(x0, offsetof(lua_State, ci))); // L->ci + build.ldr(x9, mem(x9, offsetof(CallInfo, func))); // L->ci->func + build.ldr(rClosure, mem(x9, offsetof(TValue, value.gc))); // L->ci->func->value.gc aka cl + + // Jump to the specified instruction; further control flow will be handled with custom ABI with register setup from EmitCommonA64.h build.br(x2); // Even though we jumped away, we will return here in the end Label returnOff = build.setLabel(); // Cleanup and exit - // TODO: epilogue + build.ldp(x23, x24, mem(sp, 48)); + build.ldp(x21, x22, mem(sp, 32)); + build.ldp(x19, x20, mem(sp, 16)); + build.ldp(x29, x30, mem(sp)); // fp, lr + build.add(sp, sp, kStackSize); build.ret(); @@ -59,11 +84,34 @@ bool initEntryFunction(NativeState& data) // specified by the unwind information of the entry function unwind.setBeginOffset(prologueSize); - data.context.gateExit = data.context.gateEntry + returnOff.location; + data.context.gateExit = data.context.gateEntry + 
build.getLabelOffset(returnOff); return true; } +void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers) +{ + if (build.logText) + build.logAppend("; exitContinueVm\n"); + helpers.exitContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ true); + + if (build.logText) + build.logAppend("; exitNoContinueVm\n"); + helpers.exitNoContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ false); + + if (build.logText) + build.logAppend("; reentry\n"); + helpers.reentry = build.setLabel(); + emitReentry(build, helpers); + + if (build.logText) + build.logAppend("; interrupt\n"); + helpers.interrupt = build.setLabel(); + emitInterrupt(build); +} + } // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenA64.h b/CodeGen/src/CodeGenA64.h index 5043e5c6..7b792cc1 100644 --- a/CodeGen/src/CodeGenA64.h +++ b/CodeGen/src/CodeGenA64.h @@ -7,11 +7,15 @@ namespace CodeGen { struct NativeState; +struct ModuleHelpers; namespace A64 { +class AssemblyBuilderA64; + bool initEntryFunction(NativeState& data); +void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers); } // namespace A64 } // namespace CodeGen diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 77047a7a..ae3dbd45 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -126,5 +126,125 @@ void callEpilogC(lua_State* L, int nresults, int n) L->top = (nresults == LUA_MULTRET) ? res : cip->top; } +// Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc +Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) +{ + // slow-path: not a function call + if (LUAU_UNLIKELY(!ttisfunction(ra))) + { + luaV_tryfuncTM(L, ra); + argtop++; // __call adds an extra self + } + + Closure* ccl = clvalue(ra); + + CallInfo* ci = incr_ci(L); + ci->func = ra; + ci->base = ra + 1; + ci->top = argtop + ccl->stacksize; // note: technically UB since we haven't reallocated the stack yet + ci->savedpc = NULL; + ci->flags = 0; + ci->nresults = nresults; + + L->base = ci->base; + L->top = argtop; + + // note: this reallocs stack, but we don't need to VM_PROTECT this + // this is because we're going to modify base/savedpc manually anyhow + // crucially, we can't use ra/argtop after this line + luaD_checkstack(L, ccl->stacksize); + + LUAU_ASSERT(ci->top <= L->stack_last); + + if (!ccl->isC) + { + Proto* p = ccl->l.p; + + // fill unused parameters with nil + StkId argi = L->top; + StkId argend = L->base + p->numparams; + while (argi < argend) + setnilvalue(argi++); // complete missing arguments + L->top = p->is_vararg ? argi : ci->top; + + // keep executing new function + ci->savedpc = p->code; + return ccl; + } + else + { + lua_CFunction func = ccl->c.f; + int n = func(L); + + // yield + if (n < 0) + return NULL; + + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally + StkId res = ci->func; + StkId vali = L->top - n; + StkId valend = L->top; + + int i; + for (i = nresults; i != 0 && vali < valend; i--) + setobj2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = (nresults == LUA_MULTRET) ? 
res : cip->top; + + // keep executing current function + LUAU_ASSERT(isLua(cip)); + return clvalue(cip->func); + } +} + +// Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts +Closure* returnFallback(lua_State* L, StkId ra, int n) +{ + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + StkId res = ci->func; // note: we assume CALL always puts func+args and expects results to start at func + + StkId vali = ra; + StkId valend = (n == LUA_MULTRET) ? L->top : ra + n; // copy as much as possible for MULTRET calls, and only as much as needed otherwise + + int nresults = ci->nresults; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally + int i; + for (i = nresults; i != 0 && vali < valend; i--) + setobj2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = (nresults == LUA_MULTRET) ? res : cip->top; + + // we're done! + if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) + { + L->top = res; + return NULL; + } + + // keep executing new function + LUAU_ASSERT(isLua(cip)); + return clvalue(cip->func); +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index ca190213..6066a691 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -16,5 +16,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); void callEpilogC(lua_State* L, int nresults, int n); +Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); +Closure* returnFallback(lua_State* L, StkId ra, int n); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index d70b6ed8..b010ce62 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -3,15 +3,14 @@ #include "Luau/AssemblyBuilderX64.h" #include "Luau/Bytecode.h" +#include "Luau/IrCallWrapperX64.h" +#include "Luau/IrRegAllocX64.h" #include "EmitCommonX64.h" -#include "IrRegAllocX64.h" #include "NativeState.h" #include "lstate.h" -// TODO: LBF_MATH_FREXP and LBF_MATH_MODF can work for 1 result case if second store is removed - namespace Luau { namespace CodeGen @@ -19,40 +18,11 @@ namespace CodeGen namespace X64 { -void emitBuiltinMathFloor(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - build.vroundsd(tmp.reg, tmp.reg, luauRegValue(arg), RoundingModeX64::RoundToNegativeInfinity); - build.vmovsd(luauRegValue(ra), tmp.reg); -} - -void emitBuiltinMathCeil(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - build.vroundsd(tmp.reg, tmp.reg, luauRegValue(arg), RoundingModeX64::RoundToPositiveInfinity); - build.vmovsd(luauRegValue(ra), tmp.reg); -} - -void emitBuiltinMathSqrt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - build.vsqrtsd(tmp.reg, tmp.reg, luauRegValue(arg)); - build.vmovsd(luauRegValue(ra), tmp.reg); -} - -void emitBuiltinMathAbs(IrRegAllocX64& regs, 
AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - build.vmovsd(tmp.reg, luauRegValue(arg)); - build.vandpd(tmp.reg, tmp.reg, build.i64(~(1LL << 63))); - build.vmovsd(luauRegValue(ra), tmp.reg); -} - static void emitBuiltinMathSingleArgFunc(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int32_t offset) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - build.call(qword[rNativeContext + offset]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.call(qword[rNativeContext + offset]); build.vmovsd(luauRegValue(ra), xmm0); } @@ -64,20 +34,10 @@ void emitBuiltinMathExp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npar void emitBuiltinMathFmod(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_fmod)]); - - build.vmovsd(luauRegValue(ra), xmm0); -} - -void emitBuiltinMathPow(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::xmmword, qword[args + offsetof(TValue, value)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_fmod)]); build.vmovsd(luauRegValue(ra), xmm0); } @@ -129,10 +89,10 @@ void emitBuiltinMathTanh(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa void emitBuiltinMathAtan2(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_atan2)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::xmmword, qword[args + offsetof(TValue, value)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_atan2)]); build.vmovsd(luauRegValue(ra), xmm0); } @@ -194,69 +154,45 @@ void emitBuiltinMathLog(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npar void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); + ScopedRegX64 tmp{regs, SizeX64::qword}; + build.vcvttsd2si(tmp.reg, qword[args + offsetof(TValue, value)]); - if (build.abi == ABIX64::Windows) - build.vcvttsd2si(rArg2, qword[args + offsetof(TValue, value)]); - else - build.vcvttsd2si(rArg1, qword[args + offsetof(TValue, value)]); - - build.call(qword[rNativeContext + offsetof(NativeContext, libm_ldexp)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_ldexp)]); build.vmovsd(luauRegValue(ra), xmm0); } -void emitBuiltinMathRound(IrRegAllocX64& 
regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - ScopedRegX64 tmp0{regs, SizeX64::xmmword}; - ScopedRegX64 tmp1{regs, SizeX64::xmmword}; - ScopedRegX64 tmp2{regs, SizeX64::xmmword}; - - build.vmovsd(tmp0.reg, luauRegValue(arg)); - build.vandpd(tmp1.reg, tmp0.reg, build.f64x2(-0.0, -0.0)); - build.vmovsd(tmp2.reg, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 - build.vorpd(tmp1.reg, tmp1.reg, tmp2.reg); - build.vaddsd(tmp0.reg, tmp0.reg, tmp1.reg); - build.vroundsd(tmp0.reg, tmp0.reg, tmp0.reg, RoundingModeX64::RoundToZero); - - build.vmovsd(luauRegValue(ra), tmp0.reg); -} - void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - - if (build.abi == ABIX64::Windows) - build.lea(rArg2, sTemporarySlot); - else - build.lea(rArg1, sTemporarySlot); - - build.call(qword[rNativeContext + offsetof(NativeContext, libm_frexp)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::qword, sTemporarySlot); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_frexp)]); build.vmovsd(luauRegValue(ra), xmm0); - build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); - build.vmovsd(luauRegValue(ra + 1), xmm0); + if (nresults > 1) + { + build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); + build.vmovsd(luauRegValue(ra + 1), xmm0); + } } void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - - if (build.abi == ABIX64::Windows) - build.lea(rArg2, sTemporarySlot); - else - build.lea(rArg1, sTemporarySlot); - - build.call(qword[rNativeContext + offsetof(NativeContext, libm_modf)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); + callWrap.addArgument(SizeX64::qword, sTemporarySlot); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_modf)]); build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); - build.vmovsd(luauRegValue(ra + 1), xmm0); + if (nresults > 1) + build.vmovsd(luauRegValue(ra + 1), xmm0); } void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) @@ -301,12 +237,10 @@ void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - regs.assertAllFree(); - - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(arg)); - - build.call(qword[rNativeContext + offsetof(NativeContext, luaT_objtypenamestr)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(arg)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_objtypenamestr)]); build.mov(luauRegValue(ra), rax); } @@ -316,9 +250,9 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r OperandX64 argsOp = 0; if (args.kind == IrOpKind::VmReg) - argsOp = luauRegAddress(args.index); + argsOp = luauRegAddress(vmRegOp(args)); else if (args.kind == IrOpKind::VmConst) - argsOp = luauConstantAddress(args.index); + argsOp = 
luauConstantAddress(vmConstOp(args)); switch (bfid) { @@ -328,22 +262,18 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r case LBF_MATH_MIN: case LBF_MATH_MAX: case LBF_MATH_CLAMP: + case LBF_MATH_FLOOR: + case LBF_MATH_CEIL: + case LBF_MATH_SQRT: + case LBF_MATH_POW: + case LBF_MATH_ABS: + case LBF_MATH_ROUND: // These instructions are fully translated to IR break; - case LBF_MATH_FLOOR: - return emitBuiltinMathFloor(regs, build, nparams, ra, arg, argsOp, nresults); - case LBF_MATH_CEIL: - return emitBuiltinMathCeil(regs, build, nparams, ra, arg, argsOp, nresults); - case LBF_MATH_SQRT: - return emitBuiltinMathSqrt(regs, build, nparams, ra, arg, argsOp, nresults); - case LBF_MATH_ABS: - return emitBuiltinMathAbs(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_EXP: return emitBuiltinMathExp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_FMOD: return emitBuiltinMathFmod(regs, build, nparams, ra, arg, argsOp, nresults); - case LBF_MATH_POW: - return emitBuiltinMathPow(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_ASIN: return emitBuiltinMathAsin(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SIN: @@ -370,8 +300,6 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r return emitBuiltinMathLog(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LDEXP: return emitBuiltinMathLdexp(regs, build, nparams, ra, arg, argsOp, nresults); - case LBF_MATH_ROUND: - return emitBuiltinMathRound(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_FREXP: return emitBuiltinMathFrexp(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_MODF: diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index 3c41c271..a71eafd4 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -20,9 +20,16 @@ constexpr unsigned kOffsetOfInstructionC = 3; // Leaf functions that are placed in every module to perform common instruction sequences struct ModuleHelpers { + // A64/X64 Label exitContinueVm; Label exitNoContinueVm; + + // X64 Label continueCallInVm; + + // A64 + Label reentry; // x0: closure + Label interrupt; // x0: pc offset, x1: return address, x2: interrupt }; } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonA64.cpp b/CodeGen/src/EmitCommonA64.cpp new file mode 100644 index 00000000..1758e4fb --- /dev/null +++ b/CodeGen/src/EmitCommonA64.cpp @@ -0,0 +1,130 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "EmitCommonA64.h" + +#include "NativeState.h" +#include "CustomExecUtils.h" + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +void emitUpdateBase(AssemblyBuilderA64& build) +{ + build.ldr(rBase, mem(rState, offsetof(lua_State, base))); +} + +void emitExit(AssemblyBuilderA64& build, bool continueInVm) +{ + build.mov(x0, continueInVm); + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, gateExit))); + build.br(x1); +} + +void emitInterrupt(AssemblyBuilderA64& build) +{ + // x0 = pc offset + // x1 = return address in native code + // x2 = interrupt + + // Stash return address in rBase; we need to reload rBase anyway + build.mov(rBase, x1); + + // Update savedpc; required in case interrupt errors + build.add(x0, rCode, x0); + build.ldr(x1, mem(rState, offsetof(lua_State, ci))); + build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); + + // Call interrupt + build.mov(x0, rState); + build.mov(w1, -1); + build.blr(x2); + + // 
Check if we need to exit + Label skip; + build.ldrb(w0, mem(rState, offsetof(lua_State, status))); + build.cbz(w0, skip); + + // L->ci->savedpc-- + // note: recomputing this avoids having to stash x0 + build.ldr(x1, mem(rState, offsetof(lua_State, ci))); + build.ldr(x0, mem(x1, offsetof(CallInfo, savedpc))); + build.sub(x0, x0, sizeof(Instruction)); + build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); + + emitExit(build, /* continueInVm */ false); + + build.setLabel(skip); + + // Return back to caller; rBase has stashed return address + build.mov(x0, rBase); + + emitUpdateBase(build); // interrupt may have reallocated stack + + build.br(x0); +} + +void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) +{ + // x0 = closure object to reentry (equal to clvalue(L->ci->func)) + + // If the fallback requested an exit, we need to do this right away + build.cbz(x0, helpers.exitNoContinueVm); + + emitUpdateBase(build); + + // Need to update state of the current function before we jump away + build.ldr(x1, mem(x0, offsetof(Closure, l.p))); // cl->l.p aka proto + + build.mov(rClosure, x0); + build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k + build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code + + // Get instruction index from instruction pointer + // To get instruction index from instruction pointer, we need to divide byte offset by 4 + // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out + build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci + build.ldr(x2, mem(x2, offsetof(CallInfo, savedpc))); // L->ci->savedpc + build.sub(x2, x2, rCode); + build.add(x2, x2, x2); // TODO: this would not be necessary if we supported shifted register offsets in loads + + // We need to check if the new function can be executed natively + // TODO: This can be done earlier in the function flow, to reduce the JIT->VM transition penalty + build.ldr(x1, mem(x1, offsetofProtoExecData)); + build.cbz(x1, helpers.exitContinueVm); + + // Get new instruction location and jump to it + build.ldr(x1, mem(x1, offsetof(NativeProto, instTargets))); + build.ldr(x1, mem(x1, x2)); + build.br(x1); +} + +void emitFallback(AssemblyBuilderA64& build, int op, int pcpos) +{ + // fallback(L, instruction, base, k) + build.mov(x0, rState); + + // TODO: refactor into a common helper + if (pcpos * sizeof(Instruction) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(x1, rCode, uint16_t(pcpos * sizeof(Instruction))); + } + else + { + build.mov(x1, pcpos * sizeof(Instruction)); + build.add(x1, rCode, x1); + } + + build.mov(x2, rBase); + build.mov(x3, rConstants); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback))); + build.blr(x4); + + emitUpdateBase(build); +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h new file mode 100644 index 00000000..2a65afa8 --- /dev/null +++ b/CodeGen/src/EmitCommonA64.h @@ -0,0 +1,53 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AssemblyBuilderA64.h" + +#include "EmitCommon.h" + +#include "lobject.h" +#include "ltm.h" + +// AArch64 ABI reminder: +// Arguments: x0-x7, v0-v7 +// Return: x0, v0 (or x8 that points to the address of the resulting structure) +// Volatile: x9-x15, v16-v31 ("caller-saved", any call may change them) +// 
Non-volatile: x19-x28, v8-v15 ("callee-saved", preserved after calls, only bottom half of SIMD registers is preserved!) +// Reserved: x16-x18: reserved for linker/platform use; x29: frame pointer (unless omitted); x30: link register; x31: stack pointer + +namespace Luau +{ +namespace CodeGen +{ + +struct NativeState; + +namespace A64 +{ + +// Data that is very common to access is placed in non-volatile registers: +// 1. Constant registers (only loaded during codegen entry) +constexpr RegisterA64 rState = x19; // lua_State* L +constexpr RegisterA64 rNativeContext = x20; // NativeContext* context + +// 2. Frame registers (reloaded when call frame changes; rBase is also reloaded after all calls that may reallocate stack) +constexpr RegisterA64 rConstants = x21; // TValue* k +constexpr RegisterA64 rClosure = x22; // Closure* cl +constexpr RegisterA64 rCode = x23; // Instruction* code +constexpr RegisterA64 rBase = x24; // StkId base + +// Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point +// See CodeGenA64.cpp for layout +constexpr unsigned kStackSize = 64; // 8 stashed registers + +void emitUpdateBase(AssemblyBuilderA64& build); + +// TODO: Move these to CodeGenA64 so that they can't be accidentally called during lowering +void emitExit(AssemblyBuilderA64& build, bool continueInVm); +void emitInterrupt(AssemblyBuilderA64& build); +void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers); +void emitFallback(AssemblyBuilderA64& build, int op, int pcpos); + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index e9cfdc48..9136add8 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -2,7 +2,9 @@ #include "EmitCommonX64.h" #include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrCallWrapperX64.h" #include "Luau/IrData.h" +#include "Luau/IrRegAllocX64.h" #include "CustomExecUtils.h" #include "NativeState.h" @@ -64,18 +66,19 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, } } -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label) +void jumpOnAnyCmpFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(rb)); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); if (cond == IrCondition::NotLessEqual || cond == IrCondition::LessEqual) - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]); else if (cond == IrCondition::NotLess || cond == IrCondition::Less) - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]); else if (cond == IrCondition::NotEqual || cond == IrCondition::Equal) - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]); else LUAU_ASSERT(!"Unsupported condition"); @@ -119,68 +122,66 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi 
build.jcc(ConditionX64::NotZero, label); } -void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm) +void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm) { - if (build.abi == ABIX64::Windows) - build.mov(sArg5, tm); - else - build.mov(rArg5, tm); - - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(rb)); - build.lea(rArg4, c); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); + callWrap.addArgument(SizeX64::qword, c); + callWrap.addArgument(SizeX64::dword, tm); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); emitUpdateBase(build); } -void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb) +void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(rb)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_dolen)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_dolen)]); emitUpdateBase(build); } -void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init) +void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(limit)); - build.lea(rArg3, luauRegAddress(step)); - build.lea(rArg4, luauRegAddress(init)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(limit)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(step)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(init)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]); } -void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) +void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(rb)); - build.lea(rArg3, c); - build.lea(rArg4, luauRegAddress(ra)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_gettable)]); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); + callWrap.addArgument(SizeX64::qword, c); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_gettable)]); emitUpdateBase(build); } -void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) +void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) { - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(rb)); - build.lea(rArg3, c); - build.lea(rArg4, luauRegAddress(ra)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_settable)]); + IrCallWrapperX64 callWrap(regs, build); + 
callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); + callWrap.addArgument(SizeX64::qword, c); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_settable)]); emitUpdateBase(build); } -// works for luaC_barriertable, luaC_barrierf -static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip, int contextOffset) +void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip) { - LUAU_ASSERT(tmp != object); - // iscollectable(ra) build.cmp(luauRegTag(ra), LUA_TSTRING); build.jcc(ConditionX64::Less, skip); @@ -193,86 +194,74 @@ static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, Register build.mov(tmp, luauRegValue(ra)); build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT)); build.jcc(ConditionX64::Zero, skip); +} + +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra) +{ + Label skip; + + ScopedRegX64 tmp{regs, SizeX64::qword}; + checkObjectBarrierConditions(build, tmp.reg, object, ra, skip); - // TODO: even with re-ordering we have a chance of failure, we have a task to fix this in the future - if (object == rArg3) { - LUAU_ASSERT(tmp != rArg2); + ScopedSpills spillGuard(regs); - if (rArg2 != object) - build.mov(rArg2, object); - - if (rArg3 != tmp) - build.mov(rArg3, tmp); - } - else - { - if (rArg3 != tmp) - build.mov(rArg3, tmp); - - if (rArg2 != object) - build.mov(rArg2, object); + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, object, objectOp); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierf)]); } - build.mov(rArg1, rState); - build.call(qword[rNativeContext + contextOffset]); + build.setLabel(skip); } -void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip) +void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp) { - callBarrierImpl(build, tmp, table, ra, skip, offsetof(NativeContext, luaC_barriertable)); -} + Label skip; -void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip) -{ - callBarrierImpl(build, tmp, object, ra, skip, offsetof(NativeContext, luaC_barrierf)); -} - -void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip) -{ // isblack(obj2gco(t)) build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT)); build.jcc(ConditionX64::Zero, skip); - // Argument setup re-ordered to avoid conflicts with table register - if (table != rArg2) - build.mov(rArg2, table); - build.lea(rArg3, addr[rArg2 + offsetof(Table, gclist)]); - build.mov(rArg1, rState); - build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, table, tableOp); + callWrap.addArgument(SizeX64::qword, addr[table + offsetof(Table, gclist)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); + } + + build.setLabel(skip); } -void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip) +void 
callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build) { - build.mov(rax, qword[rState + offsetof(lua_State, global)]); - build.mov(rdx, qword[rax + offsetof(global_State, totalbytes)]); - build.cmp(rdx, qword[rax + offsetof(global_State, GCthreshold)]); - build.jcc(ConditionX64::Below, skip); + Label skip; - if (savepc) - emitSetSavedPc(build, pcpos + 1); + { + ScopedRegX64 tmp1{regs, SizeX64::qword}; + ScopedRegX64 tmp2{regs, SizeX64::qword}; - build.mov(rArg1, rState); - build.mov(dwordReg(rArg2), 1); - build.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]); + build.mov(tmp1.reg, qword[rState + offsetof(lua_State, global)]); + build.mov(tmp2.reg, qword[tmp1.reg + offsetof(global_State, totalbytes)]); + build.cmp(tmp2.reg, qword[tmp1.reg + offsetof(global_State, GCthreshold)]); + build.jcc(ConditionX64::Below, skip); + } - emitUpdateBase(build); -} + { + ScopedSpills spillGuard(regs); -void callGetFastTmOrFallback(AssemblyBuilderX64& build, RegisterX64 table, TMS tm, Label& fallback) -{ - build.mov(rArg1, qword[table + offsetof(Table, metatable)]); - build.test(rArg1, rArg1); - build.jcc(ConditionX64::Zero, fallback); // no metatable + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, 1); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]); + emitUpdateBase(build); + } - build.test(byte[rArg1 + offsetof(Table, tmcache)], 1 << tm); - build.jcc(ConditionX64::NotZero, fallback); // no tag method - - // rArg1 is already prepared - build.mov(rArg2, tm); - build.mov(rax, qword[rState + offsetof(lua_State, global)]); - build.mov(rArg3, qword[rax + offsetof(global_State, tmname) + tm * sizeof(TString*)]); - build.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); + build.setLabel(skip); } void emitExit(AssemblyBuilderX64& build, bool continueInVm) @@ -291,7 +280,7 @@ void emitUpdateBase(AssemblyBuilderX64& build) } // Note: only uses rax/rdx, the caller may use other registers -void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos) +static void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos) { build.mov(rdx, sCode); build.add(rdx, pcpos * sizeof(Instruction)); @@ -317,6 +306,8 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos) build.mov(dwordReg(rArg2), -1); // function accepts 'int' here and using qword reg would've forced 8 byte constant here build.call(r8); + emitUpdateBase(build); // interrupt may have reallocated stack + // Check if we need to exit build.mov(al, byte[rState + offsetof(lua_State, status)]); build.test(al, al); @@ -331,9 +322,6 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos) void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos) { - if (op == LOP_CAPTURE) - return; - NativeFallback& opinfo = data.context.fallback[op]; LUAU_ASSERT(opinfo.fallback); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 6b676255..6aac5a1e 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -27,10 +27,13 @@ namespace CodeGen enum class IrCondition : uint8_t; struct NativeState; +struct IrOp; namespace X64 { +struct IrRegAllocX64; + // Data that is very common to access is placed in non-volatile registers constexpr RegisterX64 rState = r15; // lua_State* L constexpr RegisterX64 rBase = r14; // StkId base @@ -39,12 +42,14 @@ constexpr RegisterX64 rConstants = r12; // TValue* k // Native code is as stackless as the interpreter, so we can place some data on the 
stack once and have it accessible at any point // See CodeGenX64.cpp for layout -constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments -constexpr unsigned kLocalsSize = 24; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) +constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments +constexpr unsigned kSpillSlots = 4; // locations for register allocator to spill data into +constexpr unsigned kLocalsSize = 24 + 8 * kSpillSlots; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) constexpr OperandX64 sClosure = qword[rsp + kStackSize + 0]; // Closure* cl constexpr OperandX64 sCode = qword[rsp + kStackSize + 8]; // Instruction* code constexpr OperandX64 sTemporarySlot = addr[rsp + kStackSize + 16]; +constexpr OperandX64 sSpillArea = addr[rsp + kStackSize + 24]; // TODO: These should be replaced with a portable call function that checks the ABI at runtime and reorders moves accordingly to avoid conflicts #if defined(_WIN32) @@ -96,6 +101,11 @@ inline OperandX64 luauRegValueInt(int ri) return dword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)]; } +inline OperandX64 luauRegValueVector(int ri, int index) +{ + return dword[rBase + ri * sizeof(TValue) + offsetof(TValue, value) + (sizeof(float) * index)]; +} + inline OperandX64 luauConstant(int ki) { return xmmword[rConstants + ki * sizeof(TValue)]; @@ -233,25 +243,23 @@ inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX6 } void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label); -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label); +void jumpOnAnyCmpFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label); void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); -void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm); -void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb); -void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init); -void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); -void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); -void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip); -void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip); -void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip); -void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip); -void callGetFastTmOrFallback(AssemblyBuilderX64& build, RegisterX64 table, TMS tm, Label& fallback); +void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm); +void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb); +void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init); +void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); +void callSetTable(IrRegAllocX64& regs, 
AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); +void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip); +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra); +void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp); +void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); -void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos); // Note: only uses rax/rdx, the caller may use other registers void emitInterrupt(AssemblyBuilderX64& build, int pcpos); void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); diff --git a/CodeGen/src/EmitInstructionA64.cpp b/CodeGen/src/EmitInstructionA64.cpp new file mode 100644 index 00000000..400ba77e --- /dev/null +++ b/CodeGen/src/EmitInstructionA64.cpp @@ -0,0 +1,74 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "EmitInstructionA64.h" + +#include "Luau/AssemblyBuilderA64.h" + +#include "EmitCommonA64.h" +#include "NativeState.h" +#include "CustomExecUtils.h" + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +void emitInstReturn(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int n) +{ + // callFallback(L, ra, n) + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(ra * sizeof(TValue))); + build.mov(w2, n); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, returnFallback))); + build.blr(x3); + + // reentry with x0=closure (NULL will trigger exit) + build.b(helpers.reentry); +} + +void emitInstCall(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults) +{ + // argtop = (nparams == LUA_MULTRET) ? 
L->top : ra + 1 + nparams; + if (nparams == LUA_MULTRET) + build.ldr(x2, mem(rState, offsetof(lua_State, top))); + else + build.add(x2, rBase, uint16_t((ra + 1 + nparams) * sizeof(TValue))); + + // callFallback(L, ra, argtop, nresults) + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(ra * sizeof(TValue))); + build.mov(w3, nresults); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, callFallback))); + build.blr(x4); + + // reentry with x0=closure (NULL will trigger exit) + build.b(helpers.reentry); +} + +void emitInstGetImport(AssemblyBuilderA64& build, int ra, uint32_t aux) +{ + // luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false) + build.mov(x0, rState); + build.ldr(x1, mem(rClosure, offsetof(Closure, env))); + build.mov(x2, rConstants); + build.mov(w3, aux); + build.mov(w4, 0); + build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_getimport))); + build.blr(x5); + + emitUpdateBase(build); + + // setobj2s(L, ra, L->top - 1) + build.ldr(x0, mem(rState, offsetof(lua_State, top))); + build.sub(x0, x0, sizeof(TValue)); + build.ldr(q0, x0); + build.str(q0, mem(rBase, ra * sizeof(TValue))); + + // L->top-- + build.str(x0, mem(rState, offsetof(lua_State, top))); +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitInstructionA64.h b/CodeGen/src/EmitInstructionA64.h new file mode 100644 index 00000000..278d8e8e --- /dev/null +++ b/CodeGen/src/EmitInstructionA64.h @@ -0,0 +1,24 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ +namespace CodeGen +{ + +struct ModuleHelpers; + +namespace A64 +{ + +class AssemblyBuilderA64; + +void emitInstReturn(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int n); +void emitInstCall(AssemblyBuilderA64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); +void emitInstGetImport(AssemblyBuilderA64& build, int ra, uint32_t aux); + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index e8f61ebb..c0a64274 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -2,14 +2,10 @@ #include "EmitInstructionX64.h" #include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrRegAllocX64.h" #include "CustomExecUtils.h" -#include "EmitBuiltinsX64.h" #include "EmitCommonX64.h" -#include "NativeState.h" - -#include "lobject.h" -#include "ltm.h" namespace Luau { @@ -18,16 +14,8 @@ namespace CodeGen namespace X64 { -void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos) +void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults) { - int ra = LUAU_INSN_A(*pc); - int nparams = LUAU_INSN_B(*pc) - 1; - int nresults = LUAU_INSN_C(*pc) - 1; - - emitInterrupt(build, pcpos); - - emitSetSavedPc(build, pcpos + 1); - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); @@ -171,13 +159,8 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instr } } -void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos) +void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults) { - emitInterrupt(build, pcpos); - - int ra = LUAU_INSN_A(*pc); - int b = LUAU_INSN_B(*pc) - 1; - RegisterX64 ci = r8; RegisterX64 cip = r9; RegisterX64 res = rdi; @@ 
-196,7 +179,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins RegisterX64 counter = ecx; - if (b == 0) + if (actualResults == 0) { // Our instruction doesn't have any results, so just fill results expected in parent with 'nil' build.test(nresults, nresults); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0 @@ -210,7 +193,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins build.dec(counter); build.jcc(ConditionX64::NotZero, repeatNilLoop); } - else if (b == 1) + else if (actualResults == 1) { // Try setting our 1 result build.test(nresults, nresults); @@ -245,10 +228,10 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins build.lea(vali, luauRegAddress(ra)); // Copy as much as possible for MULTRET calls, and only as much as needed otherwise - if (b == LUA_MULTRET) + if (actualResults == LUA_MULTRET) build.mov(valend, qword[rState + offsetof(lua_State, top)]); // valend = L->top else - build.lea(valend, luauRegAddress(ra + b)); // valend = ra + b + build.lea(valend, luauRegAddress(ra + actualResults)); // valend = ra + actualResults build.mov(counter, nresults); @@ -333,24 +316,19 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins build.jmp(qword[rdx + rax * 2]); } -void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next) +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index) { - int ra = LUAU_INSN_A(*pc); - int rb = LUAU_INSN_B(*pc); - int c = LUAU_INSN_C(*pc) - 1; - uint32_t index = pc[1]; + OperandX64 last = index + count - 1; - OperandX64 last = index + c - 1; - - // Using non-volatile 'rbx' for dynamic 'c' value (for LUA_MULTRET) to skip later recomputation - // We also keep 'c' scaled by sizeof(TValue) here as it helps in the loop below + // Using non-volatile 'rbx' for dynamic 'count' value (for LUA_MULTRET) to skip later recomputation + // We also keep 'count' scaled by sizeof(TValue) here as it helps in the loop below RegisterX64 cscaled = rbx; - if (c == LUA_MULTRET) + if (count == LUA_MULTRET) { RegisterX64 tmp = rax; - // c = L->top - rb + // count = L->top - rb build.mov(cscaled, qword[rState + offsetof(lua_State, top)]); build.lea(tmp, luauRegAddress(rb)); build.sub(cscaled, tmp); // Using byte difference @@ -360,7 +338,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne build.mov(tmp, qword[tmp + offsetof(CallInfo, top)]); build.mov(qword[rState + offsetof(lua_State, top)], tmp); - // last = index + c - 1; + // last = index + count - 1; last = edx; build.mov(last, dwordReg(cscaled)); build.shr(last, kTValueSizeLog2); @@ -369,7 +347,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne Label skipResize; - RegisterX64 table = rax; + RegisterX64 table = regs.takeReg(rax, kInvalidInstIdx); build.mov(table, luauRegValue(ra)); @@ -394,9 +372,9 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne const int kUnrollSetListLimit = 4; - if (c != LUA_MULTRET && c <= kUnrollSetListLimit) + if (count != LUA_MULTRET && count <= kUnrollSetListLimit) { - for (int i = 0; i < c; ++i) + for (int i = 0; i < count; ++i) { // setobj2t(L, &array[index + i - 1], rb + i); build.vmovups(xmm0, luauRegValue(rb + i)); @@ -405,17 +383,17 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne } else { - LUAU_ASSERT(c != 0); + LUAU_ASSERT(count 
!= 0); build.xor_(offset, offset); if (index != 1) build.add(arrayDst, (index - 1) * sizeof(TValue)); Label repeatLoop, endLoop; - OperandX64 limit = c == LUA_MULTRET ? cscaled : OperandX64(c * sizeof(TValue)); + OperandX64 limit = count == LUA_MULTRET ? cscaled : OperandX64(count * sizeof(TValue)); // If c is static, we will always do at least one iteration - if (c == LUA_MULTRET) + if (count == LUA_MULTRET) { build.cmp(offset, limit); build.jcc(ConditionX64::NotBelow, endLoop); @@ -434,7 +412,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne build.setLabel(endLoop); } - callBarrierTableFast(build, table, next); + callBarrierTableFast(regs, build, table, {}); } void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit) @@ -506,10 +484,8 @@ void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRep build.jcc(ConditionX64::NotZero, loopRepeat); } -void emitinstForGLoopFallback(AssemblyBuilderX64& build, int pcpos, int ra, int aux, Label& loopRepeat) +void emitinstForGLoopFallback(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat) { - emitSetSavedPc(build, pcpos + 1); - build.mov(rArg1, rState); build.mov(dwordReg(rArg2), ra); build.mov(dwordReg(rArg3), aux); @@ -528,82 +504,6 @@ void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, build.jmp(target); } -static void emitInstAndX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c) -{ - Label target, fallthrough; - jumpIfFalsy(build, rb, target, fallthrough); - - build.setLabel(fallthrough); - - build.vmovups(xmm0, c); - build.vmovups(luauReg(ra), xmm0); - - if (ra == rb) - { - build.setLabel(target); - } - else - { - Label exit; - build.jmp(exit); - - build.setLabel(target); - - build.vmovups(xmm0, luauReg(rb)); - build.vmovups(luauReg(ra), xmm0); - - build.setLabel(exit); - } -} - -void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc) -{ - emitInstAndX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauReg(LUAU_INSN_C(*pc))); -} - -void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc) -{ - emitInstAndX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstant(LUAU_INSN_C(*pc))); -} - -static void emitInstOrX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c) -{ - Label target, fallthrough; - jumpIfTruthy(build, rb, target, fallthrough); - - build.setLabel(fallthrough); - - build.vmovups(xmm0, c); - build.vmovups(luauReg(ra), xmm0); - - if (ra == rb) - { - build.setLabel(target); - } - else - { - Label exit; - build.jmp(exit); - - build.setLabel(target); - - build.vmovups(xmm0, luauReg(rb)); - build.vmovups(luauReg(ra), xmm0); - - build.setLabel(exit); - } -} - -void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc) -{ - emitInstOrX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauReg(LUAU_INSN_C(*pc))); -} - -void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc) -{ - emitInstOrX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstant(LUAU_INSN_C(*pc))); -} - void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux) { build.mov(rax, sClosure); diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index 6a8a3c0e..d58e1331 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -3,11 +3,6 @@ #include -#include "ltm.h" - -typedef uint32_t Instruction; -typedef struct lua_TValue TValue; - namespace Luau { namespace CodeGen @@ -20,17 +15,14 @@ namespace X64 { class 
AssemblyBuilderX64; +struct IrRegAllocX64; -void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); -void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); -void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next); +void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); +void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults); +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index); void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit); -void emitinstForGLoopFallback(AssemblyBuilderX64& build, int pcpos, int ra, int aux, Label& loopRepeat); +void emitinstForGLoopFallback(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat); void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target); -void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc); -void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc); -void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc); -void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux); void emitInstCoverage(AssemblyBuilderX64& build, int pcpos); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index b998487f..2246e5c5 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -69,6 +69,9 @@ void updateLastUseLocations(IrFunction& function) instructions[op.index].lastUse = uint32_t(instIdx); }; + if (isPseudo(inst.cmd)) + continue; + checkOp(inst.a); checkOp(inst.b); checkOp(inst.c); @@ -78,6 +81,42 @@ void updateLastUseLocations(IrFunction& function) } } +uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t startInstIdx) +{ + LUAU_ASSERT(startInstIdx < function.instructions.size()); + IrInst& targetInst = function.instructions[targetInstIdx]; + + for (uint32_t i = startInstIdx; i <= targetInst.lastUse; i++) + { + IrInst& inst = function.instructions[i]; + + if (isPseudo(inst.cmd)) + continue; + + if (inst.a.kind == IrOpKind::Inst && inst.a.index == targetInstIdx) + return i; + + if (inst.b.kind == IrOpKind::Inst && inst.b.index == targetInstIdx) + return i; + + if (inst.c.kind == IrOpKind::Inst && inst.c.index == targetInstIdx) + return i; + + if (inst.d.kind == IrOpKind::Inst && inst.d.index == targetInstIdx) + return i; + + if (inst.e.kind == IrOpKind::Inst && inst.e.index == targetInstIdx) + return i; + + if (inst.f.kind == IrOpKind::Inst && inst.f.index == targetInstIdx) + return i; + } + + // There must be a next use since there is the last use location + LUAU_ASSERT(!"failed to find next use"); + return targetInst.lastUse; +} + std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block) { uint32_t liveIns = 0; @@ -97,6 +136,9 @@ std::pair getLiveInOutValueCount(IrFunction& function, IrBlo { IrInst& inst = function.instructions[instIdx]; + if (isPseudo(inst.cmd)) + continue; + liveOuts += inst.useCount; checkOp(inst.a); @@ -149,26 +191,24 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& RegisterSet inRs; auto def = [&](IrOp op, int offset = 0) { - LUAU_ASSERT(op.kind == IrOpKind::VmReg); - defRs.regs.set(op.index + offset, true); + defRs.regs.set(vmRegOp(op) + offset, 
true); }; auto use = [&](IrOp op, int offset = 0) { - LUAU_ASSERT(op.kind == IrOpKind::VmReg); - if (!defRs.regs.test(op.index + offset)) - inRs.regs.set(op.index + offset, true); + if (!defRs.regs.test(vmRegOp(op) + offset)) + inRs.regs.set(vmRegOp(op) + offset, true); }; auto maybeDef = [&](IrOp op) { if (op.kind == IrOpKind::VmReg) - defRs.regs.set(op.index, true); + defRs.regs.set(vmRegOp(op), true); }; auto maybeUse = [&](IrOp op) { if (op.kind == IrOpKind::VmReg) { - if (!defRs.regs.test(op.index)) - inRs.regs.set(op.index, true); + if (!defRs.regs.test(vmRegOp(op))) + inRs.regs.set(vmRegOp(op), true); } }; @@ -230,6 +270,7 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::STORE_POINTER: case IrCmd::STORE_DOUBLE: case IrCmd::STORE_INT: + case IrCmd::STORE_VECTOR: case IrCmd::STORE_TVALUE: maybeDef(inst.a); // Argument can also be a pointer value break; @@ -244,12 +285,16 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& // A <- B, C case IrCmd::DO_ARITH: case IrCmd::GET_TABLE: - case IrCmd::SET_TABLE: use(inst.b); maybeUse(inst.c); // Argument can also be a VmConst def(inst.a); break; + case IrCmd::SET_TABLE: + use(inst.a); + use(inst.b); + maybeUse(inst.c); // Argument can also be a VmConst + break; // A <- B case IrCmd::DO_LEN: use(inst.b); @@ -260,9 +305,9 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& def(inst.a); break; case IrCmd::CONCAT: - useRange(inst.a.index, function.uintOp(inst.b)); + useRange(vmRegOp(inst.a), function.uintOp(inst.b)); - defRange(inst.a.index, function.uintOp(inst.b)); + defRange(vmRegOp(inst.a), function.uintOp(inst.b)); break; case IrCmd::GET_UPVALUE: def(inst.a); @@ -294,20 +339,20 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& maybeUse(inst.a); if (function.boolOp(inst.b)) - capturedRegs.set(inst.a.index, true); + capturedRegs.set(vmRegOp(inst.a), true); break; - case IrCmd::LOP_SETLIST: + case IrCmd::SETLIST: use(inst.b); - useRange(inst.c.index, function.intOp(inst.d)); + useRange(vmRegOp(inst.c), function.intOp(inst.d)); break; - case IrCmd::LOP_CALL: - use(inst.b); - useRange(inst.b.index + 1, function.intOp(inst.c)); + case IrCmd::CALL: + use(inst.a); + useRange(vmRegOp(inst.a) + 1, function.intOp(inst.b)); - defRange(inst.b.index, function.intOp(inst.d)); + defRange(vmRegOp(inst.a), function.intOp(inst.c)); break; - case IrCmd::LOP_RETURN: - useRange(inst.b.index, function.intOp(inst.c)); + case IrCmd::RETURN: + useRange(vmRegOp(inst.a), function.intOp(inst.b)); break; case IrCmd::FASTCALL: case IrCmd::INVOKE_FASTCALL: @@ -315,9 +360,9 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& { if (count >= 3) { - LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && inst.d.index == inst.c.index + 1); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1); - useRange(inst.c.index, count); + useRange(vmRegOp(inst.c), count); } else { @@ -330,43 +375,30 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& } else { - useVarargs(inst.c.index); + useVarargs(vmRegOp(inst.c)); } - defRange(inst.b.index, function.intOp(inst.f)); + // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG + if (int count = function.intOp(inst.f); count != -1) + defRange(vmRegOp(inst.b), count); break; - case IrCmd::LOP_FORGLOOP: + case IrCmd::FORGLOOP: // First register is not used by instruction, we check that it's 
still 'nil' with CHECK_TAG use(inst.a, 1); use(inst.a, 2); def(inst.a, 2); - defRange(inst.a.index + 3, function.intOp(inst.b)); + defRange(vmRegOp(inst.a) + 3, function.intOp(inst.b)); break; - case IrCmd::LOP_FORGLOOP_FALLBACK: - useRange(inst.b.index, 3); + case IrCmd::FORGLOOP_FALLBACK: + useRange(vmRegOp(inst.a), 3); - def(inst.b, 2); - defRange(inst.b.index + 3, uint8_t(function.intOp(inst.c))); // ignore most significant bit + def(inst.a, 2); + defRange(vmRegOp(inst.a) + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit break; - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + case IrCmd::FORGPREP_XNEXT_FALLBACK: use(inst.b); break; - // B <- C, D - case IrCmd::LOP_AND: - case IrCmd::LOP_OR: - use(inst.c); - use(inst.d); - - def(inst.b); - break; - // B <- C - case IrCmd::LOP_ANDK: - case IrCmd::LOP_ORK: - use(inst.c); - - def(inst.b); - break; case IrCmd::FALLBACK_GETGLOBAL: def(inst.b); break; @@ -385,13 +417,13 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::FALLBACK_NAMECALL: use(inst.c); - defRange(inst.b.index, 2); + defRange(vmRegOp(inst.b), 2); break; case IrCmd::FALLBACK_PREPVARARGS: // No effect on explicitly referenced registers break; case IrCmd::FALLBACK_GETVARARGS: - defRange(inst.b.index, function.intOp(inst.c)); + defRange(vmRegOp(inst.b), function.intOp(inst.c)); break; case IrCmd::FALLBACK_NEWCLOSURE: def(inst.b); @@ -402,11 +434,13 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::FALLBACK_FORGPREP: use(inst.b); - defRange(inst.b.index, 3); + defRange(vmRegOp(inst.b), 3); break; case IrCmd::ADJUST_STACK_TO_REG: + defRange(vmRegOp(inst.a), -1); + break; case IrCmd::ADJUST_STACK_TO_TOP: - // While these can be considered as vararg producers and consumers, it is already handled in fastcall instruction + // While this can be considered to be a vararg consumer, it is already handled in fastcall instructions break; default: @@ -626,7 +660,7 @@ void computeCfgInfo(IrFunction& function) computeCfgLiveInOutRegSets(function); } -BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx) +BlockIteratorWrapper predecessors(const CfgInfo& cfg, uint32_t blockIdx) { LUAU_ASSERT(blockIdx < cfg.predecessorsOffsets.size()); @@ -636,7 +670,7 @@ BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx) return BlockIteratorWrapper{cfg.predecessors.data() + start, cfg.predecessors.data() + end}; } -BlockIteratorWrapper successors(CfgInfo& cfg, uint32_t blockIdx) +BlockIteratorWrapper successors(const CfgInfo& cfg, uint32_t blockIdx) { LUAU_ASSERT(blockIdx < cfg.successorsOffsets.size()); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index f1099cfa..48c0e25c 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -132,7 +132,10 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstSetGlobal(*this, pc, i); break; case LOP_CALL: - inst(IrCmd::LOP_CALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1), constInt(LUAU_INSN_C(*pc) - 1)); + inst(IrCmd::INTERRUPT, constUint(i)); + inst(IrCmd::SET_SAVEDPC, constUint(i + 1)); + + inst(IrCmd::CALL, vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1), constInt(LUAU_INSN_C(*pc) - 1)); if (activeFastcallFallback) { @@ -144,7 +147,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } break; case LOP_RETURN: - inst(IrCmd::LOP_RETURN, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); 
+ inst(IrCmd::INTERRUPT, constUint(i)); + + inst(IrCmd::RETURN, vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); break; case LOP_GETTABLE: translateInstGetTable(*this, pc, i); @@ -261,7 +266,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstDupTable(*this, pc, i); break; case LOP_SETLIST: - inst(IrCmd::LOP_SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); + inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); break; case LOP_GETUPVAL: translateInstGetUpval(*this, pc, i); @@ -342,10 +347,11 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) inst(IrCmd::INTERRUPT, constUint(i)); loadAndCheckTag(vmReg(ra), LUA_TNIL, fallback); - inst(IrCmd::LOP_FORGLOOP, vmReg(ra), constInt(aux), loopRepeat, loopExit); + inst(IrCmd::FORGLOOP, vmReg(ra), constInt(aux), loopRepeat, loopExit); beginBlock(fallback); - inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), vmReg(ra), constInt(aux), loopRepeat, loopExit); + inst(IrCmd::SET_SAVEDPC, constUint(i + 1)); + inst(IrCmd::FORGLOOP_FALLBACK, vmReg(ra), constInt(aux), loopRepeat, loopExit); beginBlock(loopExit); } @@ -358,19 +364,19 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstForGPrepInext(*this, pc, i); break; case LOP_AND: - inst(IrCmd::LOP_AND, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + translateInstAndX(*this, pc, i, vmReg(LUAU_INSN_C(*pc))); break; case LOP_ANDK: - inst(IrCmd::LOP_ANDK, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + translateInstAndX(*this, pc, i, vmConst(LUAU_INSN_C(*pc))); break; case LOP_OR: - inst(IrCmd::LOP_OR, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + translateInstOrX(*this, pc, i, vmReg(LUAU_INSN_C(*pc))); break; case LOP_ORK: - inst(IrCmd::LOP_ORK, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + translateInstOrX(*this, pc, i, vmConst(LUAU_INSN_C(*pc))); break; case LOP_COVERAGE: - inst(IrCmd::LOP_COVERAGE, constUint(i)); + inst(IrCmd::COVERAGE, constUint(i)); break; case LOP_GETIMPORT: translateInstGetImport(*this, pc, i); diff --git a/CodeGen/src/IrCallWrapperX64.cpp b/CodeGen/src/IrCallWrapperX64.cpp new file mode 100644 index 00000000..8ac5f8bc --- /dev/null +++ b/CodeGen/src/IrCallWrapperX64.cpp @@ -0,0 +1,424 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrCallWrapperX64.h" + +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrRegAllocX64.h" + +#include "EmitCommonX64.h" + +namespace Luau +{ +namespace CodeGen +{ +namespace X64 +{ + +static bool sameUnderlyingRegister(RegisterX64 a, RegisterX64 b) +{ + SizeX64 underlyingSizeA = a.size == SizeX64::xmmword ? SizeX64::xmmword : SizeX64::qword; + SizeX64 underlyingSizeB = b.size == SizeX64::xmmword ? 
SizeX64::xmmword : SizeX64::qword; + + return underlyingSizeA == underlyingSizeB && a.index == b.index; +} + +IrCallWrapperX64::IrCallWrapperX64(IrRegAllocX64& regs, AssemblyBuilderX64& build, uint32_t instIdx) + : regs(regs) + , build(build) + , instIdx(instIdx) + , funcOp(noreg) +{ + gprUses.fill(0); + xmmUses.fill(0); +} + +void IrCallWrapperX64::addArgument(SizeX64 targetSize, OperandX64 source, IrOp sourceOp) +{ + // Instruction operands rely on current instruction index for lifetime tracking + LUAU_ASSERT(instIdx != kInvalidInstIdx || sourceOp.kind == IrOpKind::None); + + LUAU_ASSERT(argCount < kMaxCallArguments); + args[argCount++] = {targetSize, source, sourceOp}; +} + +void IrCallWrapperX64::addArgument(SizeX64 targetSize, ScopedRegX64& scopedReg) +{ + LUAU_ASSERT(argCount < kMaxCallArguments); + args[argCount++] = {targetSize, scopedReg.release(), {}}; +} + +void IrCallWrapperX64::call(const OperandX64& func) +{ + funcOp = func; + + assignTargetRegisters(); + + countRegisterUses(); + + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + if (arg.sourceOp.kind != IrOpKind::None) + { + if (IrInst* inst = regs.function.asInstOp(arg.sourceOp)) + { + // Source registers are recorded separately from source operands in CallArgument + // If source is the last use of IrInst, clear the register from the operand + if (regs.isLastUseReg(*inst, instIdx)) + inst->regX64 = noreg; + // If it's not the last use and register is volatile, register ownership is taken, which also spills the operand + else if (inst->regX64.size == SizeX64::xmmword || regs.shouldFreeGpr(inst->regX64)) + regs.takeReg(inst->regX64, kInvalidInstIdx); + } + } + + // Immediate values are stored at the end since they are not interfering and target register can still be used temporarily + if (arg.source.cat == CategoryX64::imm) + { + arg.candidate = false; + } + // Arguments passed through stack can be handled immediately + else if (arg.target.cat == CategoryX64::mem) + { + if (arg.source.cat == CategoryX64::mem) + { + ScopedRegX64 tmp{regs, arg.target.memSize}; + + freeSourceRegisters(arg); + + if (arg.source.memSize == SizeX64::none) + build.lea(tmp.reg, arg.source); + else + build.mov(tmp.reg, arg.source); + + build.mov(arg.target, tmp.reg); + } + else + { + freeSourceRegisters(arg); + + build.mov(arg.target, arg.source); + } + + arg.candidate = false; + } + // Skip arguments that are already in their place + else if (arg.source.cat == CategoryX64::reg && sameUnderlyingRegister(arg.target.base, arg.source.base)) + { + freeSourceRegisters(arg); + + // If target is not used as source in other arguments, prevent register allocator from giving it out + if (getRegisterUses(arg.target.base) == 0) + regs.takeReg(arg.target.base, kInvalidInstIdx); + else // Otherwise, make sure we won't free it when last source use is completed + addRegisterUse(arg.target.base); + + arg.candidate = false; + } + } + + // Repeat until we run out of arguments to pass + while (true) + { + // Find target argument register that is not an active source + if (CallArgument* candidate = findNonInterferingArgument()) + { + // This section is only for handling register targets + LUAU_ASSERT(candidate->target.cat == CategoryX64::reg); + + freeSourceRegisters(*candidate); + + LUAU_ASSERT(getRegisterUses(candidate->target.base) == 0); + regs.takeReg(candidate->target.base, kInvalidInstIdx); + + moveToTarget(*candidate); + + candidate->candidate = false; + } + // If all registers cross-interfere (rcx <- rdx, rdx <- rcx), one has to be renamed 
+ else if (RegisterX64 conflict = findConflictingTarget(); conflict != noreg) + { + renameConflictingRegister(conflict); + } + else + { + for (int i = 0; i < argCount; ++i) + LUAU_ASSERT(!args[i].candidate); + break; + } + } + + // Handle immediate arguments last + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + if (arg.source.cat == CategoryX64::imm) + { + // There could be a conflict with the function source register, make this argument a candidate to find it + arg.candidate = true; + + if (RegisterX64 conflict = findConflictingTarget(); conflict != noreg) + renameConflictingRegister(conflict); + + if (arg.target.cat == CategoryX64::reg) + regs.takeReg(arg.target.base, kInvalidInstIdx); + + moveToTarget(arg); + + arg.candidate = false; + } + } + + // Free registers used in the function call + removeRegisterUse(funcOp.base); + removeRegisterUse(funcOp.index); + + // Just before the call is made, argument registers are all marked as free in register allocator + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + if (arg.target.cat == CategoryX64::reg) + regs.freeReg(arg.target.base); + } + + regs.preserveAndFreeInstValues(); + + regs.assertAllFree(); + + build.call(funcOp); +} + +void IrCallWrapperX64::assignTargetRegisters() +{ + static const std::array kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + 32], addr[rsp + 40]}; + static const std::array kSystemvGprOrder = {rdi, rsi, rdx, rcx, r8, r9}; + + const std::array& gprOrder = build.abi == ABIX64::Windows ? kWindowsGprOrder : kSystemvGprOrder; + static const std::array kXmmOrder = {xmm0, xmm1, xmm2, xmm3}; // Common order for first 4 fp arguments on Windows/SystemV + + int gprPos = 0; + int xmmPos = 0; + + for (int i = 0; i < argCount; i++) + { + CallArgument& arg = args[i]; + + if (arg.targetSize == SizeX64::xmmword) + { + LUAU_ASSERT(size_t(xmmPos) < kXmmOrder.size()); + arg.target = kXmmOrder[xmmPos++]; + + if (build.abi == ABIX64::Windows) + gprPos++; // On Windows, gpr/xmm register positions move in sync + } + else + { + LUAU_ASSERT(size_t(gprPos) < gprOrder.size()); + arg.target = gprOrder[gprPos++]; + + if (build.abi == ABIX64::Windows) + xmmPos++; // On Windows, gpr/xmm register positions move in sync + + // Keep requested argument size + if (arg.target.cat == CategoryX64::reg) + arg.target.base.size = arg.targetSize; + else if (arg.target.cat == CategoryX64::mem) + arg.target.memSize = arg.targetSize; + } + } +} + +void IrCallWrapperX64::countRegisterUses() +{ + for (int i = 0; i < argCount; ++i) + { + addRegisterUse(args[i].source.base); + addRegisterUse(args[i].source.index); + } + + addRegisterUse(funcOp.base); + addRegisterUse(funcOp.index); +} + +CallArgument* IrCallWrapperX64::findNonInterferingArgument() +{ + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + if (arg.candidate && !interferesWithActiveSources(arg, i) && !interferesWithOperand(funcOp, arg.target.base)) + return &arg; + } + + return nullptr; +} + +bool IrCallWrapperX64::interferesWithOperand(const OperandX64& op, RegisterX64 reg) const +{ + return sameUnderlyingRegister(op.base, reg) || sameUnderlyingRegister(op.index, reg); +} + +bool IrCallWrapperX64::interferesWithActiveSources(const CallArgument& targetArg, int targetArgIndex) const +{ + for (int i = 0; i < argCount; ++i) + { + const CallArgument& arg = args[i]; + + if (arg.candidate && i != targetArgIndex && interferesWithOperand(arg.source, targetArg.target.base)) + return true; + } + + return false; +} + +bool 
IrCallWrapperX64::interferesWithActiveTarget(RegisterX64 sourceReg) const +{ + for (int i = 0; i < argCount; ++i) + { + const CallArgument& arg = args[i]; + + if (arg.candidate && sameUnderlyingRegister(arg.target.base, sourceReg)) + return true; + } + + return false; +} + +void IrCallWrapperX64::moveToTarget(CallArgument& arg) +{ + if (arg.source.cat == CategoryX64::reg) + { + RegisterX64 source = arg.source.base; + + if (source.size == SizeX64::xmmword) + build.vmovsd(arg.target, source, source); + else + build.mov(arg.target, source); + } + else if (arg.source.cat == CategoryX64::imm) + { + build.mov(arg.target, arg.source); + } + else + { + if (arg.source.memSize == SizeX64::none) + build.lea(arg.target, arg.source); + else if (arg.target.base.size == SizeX64::xmmword && arg.source.memSize == SizeX64::xmmword) + build.vmovups(arg.target, arg.source); + else if (arg.target.base.size == SizeX64::xmmword) + build.vmovsd(arg.target, arg.source); + else + build.mov(arg.target, arg.source); + } +} + +void IrCallWrapperX64::freeSourceRegisters(CallArgument& arg) +{ + removeRegisterUse(arg.source.base); + removeRegisterUse(arg.source.index); +} + +void IrCallWrapperX64::renameRegister(RegisterX64& target, RegisterX64 reg, RegisterX64 replacement) +{ + if (sameUnderlyingRegister(target, reg)) + { + addRegisterUse(replacement); + removeRegisterUse(target); + + target.index = replacement.index; // Only change index, size is preserved + } +} + +void IrCallWrapperX64::renameSourceRegisters(RegisterX64 reg, RegisterX64 replacement) +{ + for (int i = 0; i < argCount; ++i) + { + CallArgument& arg = args[i]; + + if (arg.candidate) + { + renameRegister(arg.source.base, reg, replacement); + renameRegister(arg.source.index, reg, replacement); + } + } + + renameRegister(funcOp.base, reg, replacement); + renameRegister(funcOp.index, reg, replacement); +} + +RegisterX64 IrCallWrapperX64::findConflictingTarget() const +{ + for (int i = 0; i < argCount; ++i) + { + const CallArgument& arg = args[i]; + + if (arg.candidate) + { + if (interferesWithActiveTarget(arg.source.base)) + return arg.source.base; + + if (interferesWithActiveTarget(arg.source.index)) + return arg.source.index; + } + } + + if (interferesWithActiveTarget(funcOp.base)) + return funcOp.base; + + if (interferesWithActiveTarget(funcOp.index)) + return funcOp.index; + + return noreg; +} + +void IrCallWrapperX64::renameConflictingRegister(RegisterX64 conflict) +{ + // Get a fresh register + RegisterX64 freshReg = conflict.size == SizeX64::xmmword ? regs.allocXmmReg(kInvalidInstIdx) : regs.allocGprReg(conflict.size, kInvalidInstIdx); + + if (conflict.size == SizeX64::xmmword) + build.vmovsd(freshReg, conflict, conflict); + else + build.mov(freshReg, conflict); + + renameSourceRegisters(conflict, freshReg); +} + +int IrCallWrapperX64::getRegisterUses(RegisterX64 reg) const +{ + return reg.size == SizeX64::xmmword ? xmmUses[reg.index] : (reg.size != SizeX64::none ? 
gprUses[reg.index] : 0); +} + +void IrCallWrapperX64::addRegisterUse(RegisterX64 reg) +{ + if (reg.size == SizeX64::xmmword) + xmmUses[reg.index]++; + else if (reg.size != SizeX64::none) + gprUses[reg.index]++; +} + +void IrCallWrapperX64::removeRegisterUse(RegisterX64 reg) +{ + if (reg.size == SizeX64::xmmword) + { + LUAU_ASSERT(xmmUses[reg.index] != 0); + xmmUses[reg.index]--; + + if (xmmUses[reg.index] == 0) // we don't use persistent xmm regs so no need to call shouldFreeRegister + regs.freeReg(reg); + } + else if (reg.size != SizeX64::none) + { + LUAU_ASSERT(gprUses[reg.index] != 0); + gprUses[reg.index]--; + + if (gprUses[reg.index] == 0 && regs.shouldFreeGpr(reg)) + regs.freeReg(reg); + } +} + +} // namespace X64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 3c4e420d..8f299520 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -100,6 +100,8 @@ const char* getCmdName(IrCmd cmd) return "STORE_DOUBLE"; case IrCmd::STORE_INT: return "STORE_INT"; + case IrCmd::STORE_VECTOR: + return "STORE_VECTOR"; case IrCmd::STORE_TVALUE: return "STORE_TVALUE"; case IrCmd::STORE_NODE_VALUE_TV: @@ -126,6 +128,16 @@ const char* getCmdName(IrCmd cmd) return "MAX_NUM"; case IrCmd::UNM_NUM: return "UNM_NUM"; + case IrCmd::FLOOR_NUM: + return "FLOOR_NUM"; + case IrCmd::CEIL_NUM: + return "CEIL_NUM"; + case IrCmd::ROUND_NUM: + return "ROUND_NUM"; + case IrCmd::SQRT_NUM: + return "SQRT_NUM"; + case IrCmd::ABS_NUM: + return "ABS_NUM"; case IrCmd::NOT_ANY: return "NOT_ANY"; case IrCmd::JUMP: @@ -216,28 +228,20 @@ const char* getCmdName(IrCmd cmd) return "CLOSE_UPVALS"; case IrCmd::CAPTURE: return "CAPTURE"; - case IrCmd::LOP_SETLIST: - return "LOP_SETLIST"; - case IrCmd::LOP_CALL: - return "LOP_CALL"; - case IrCmd::LOP_RETURN: - return "LOP_RETURN"; - case IrCmd::LOP_FORGLOOP: - return "LOP_FORGLOOP"; - case IrCmd::LOP_FORGLOOP_FALLBACK: - return "LOP_FORGLOOP_FALLBACK"; - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: - return "LOP_FORGPREP_XNEXT_FALLBACK"; - case IrCmd::LOP_AND: - return "LOP_AND"; - case IrCmd::LOP_ANDK: - return "LOP_ANDK"; - case IrCmd::LOP_OR: - return "LOP_OR"; - case IrCmd::LOP_ORK: - return "LOP_ORK"; - case IrCmd::LOP_COVERAGE: - return "LOP_COVERAGE"; + case IrCmd::SETLIST: + return "SETLIST"; + case IrCmd::CALL: + return "CALL"; + case IrCmd::RETURN: + return "RETURN"; + case IrCmd::FORGLOOP: + return "FORGLOOP"; + case IrCmd::FORGLOOP_FALLBACK: + return "FORGLOOP_FALLBACK"; + case IrCmd::FORGPREP_XNEXT_FALLBACK: + return "FORGPREP_XNEXT_FALLBACK"; + case IrCmd::COVERAGE: + return "COVERAGE"; case IrCmd::FALLBACK_GETGLOBAL: return "FALLBACK_GETGLOBAL"; case IrCmd::FALLBACK_SETGLOBAL: @@ -335,13 +339,13 @@ void toString(IrToStringContext& ctx, IrOp op) append(ctx.result, "%s_%u", getBlockKindName(ctx.blocks[op.index].kind), op.index); break; case IrOpKind::VmReg: - append(ctx.result, "R%u", op.index); + append(ctx.result, "R%d", vmRegOp(op)); break; case IrOpKind::VmConst: - append(ctx.result, "K%u", op.index); + append(ctx.result, "K%d", vmConstOp(op)); break; case IrOpKind::VmUpvalue: - append(ctx.result, "U%u", op.index); + append(ctx.result, "U%d", vmUpvalueOp(op)); break; } } @@ -455,7 +459,7 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind } // Predecessor list - if (!ctx.cfg.predecessors.empty()) + if (index < ctx.cfg.predecessorsOffsets.size()) { BlockIteratorWrapper pred = predecessors(ctx.cfg, index); @@ -469,7 +473,7 @@ void toStringDetailed(IrToStringContext& 
ctx, const IrBlock& block, uint32_t ind } // Successor list - if (!ctx.cfg.successors.empty()) + if (index < ctx.cfg.successorsOffsets.size()) { BlockIteratorWrapper succ = successors(ctx.cfg, index); @@ -509,14 +513,14 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind } } -std::string toString(IrFunction& function, bool includeUseInfo) +std::string toString(const IrFunction& function, bool includeUseInfo) { std::string result; IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; for (size_t i = 0; i < function.blocks.size(); i++) { - IrBlock& block = function.blocks[i]; + const IrBlock& block = function.blocks[i]; if (block.kind == IrBlockKind::Dead) continue; @@ -532,7 +536,7 @@ std::string toString(IrFunction& function, bool includeUseInfo) // To allow dumping blocks that are still being constructed, we can't rely on terminator and need a bounds check for (uint32_t index = block.start; index <= block.finish && index < uint32_t(function.instructions.size()); index++) { - IrInst& inst = function.instructions[index]; + const IrInst& inst = function.instructions[index]; // Skip pseudo instructions unless they are still referenced if (isPseudo(inst.cmd) && inst.useCount == 0) @@ -548,7 +552,7 @@ std::string toString(IrFunction& function, bool includeUseInfo) return result; } -std::string dump(IrFunction& function) +std::string dump(const IrFunction& function) { std::string result = toString(function, /* includeUseInfo */ true); @@ -557,12 +561,12 @@ std::string dump(IrFunction& function) return result; } -std::string toDot(IrFunction& function, bool includeInst) +std::string toDot(const IrFunction& function, bool includeInst) { std::string result; IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; - auto appendLabelRegset = [&ctx](std::vector& regSets, size_t blockIdx, const char* name) { + auto appendLabelRegset = [&ctx](const std::vector& regSets, size_t blockIdx, const char* name) { if (blockIdx < regSets.size()) { const RegisterSet& rs = regSets[blockIdx]; @@ -581,7 +585,7 @@ std::string toDot(IrFunction& function, bool includeInst) for (size_t i = 0; i < function.blocks.size(); i++) { - IrBlock& block = function.blocks[i]; + const IrBlock& block = function.blocks[i]; append(ctx.result, "b%u [", unsigned(i)); @@ -599,7 +603,7 @@ std::string toDot(IrFunction& function, bool includeInst) { for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) { - IrInst& inst = function.instructions[instIdx]; + const IrInst& inst = function.instructions[instIdx]; // Skip pseudo instructions unless they are still referenced if (isPseudo(inst.cmd) && inst.useCount == 0) @@ -618,14 +622,14 @@ std::string toDot(IrFunction& function, bool includeInst) for (size_t i = 0; i < function.blocks.size(); i++) { - IrBlock& block = function.blocks[i]; + const IrBlock& block = function.blocks[i]; if (block.start == ~0u) continue; for (uint32_t instIdx = block.start; instIdx != ~0u && instIdx <= block.finish; instIdx++) { - IrInst& inst = function.instructions[instIdx]; + const IrInst& inst = function.instructions[instIdx]; auto checkOp = [&](IrOp op) { if (op.kind == IrOpKind::Block) @@ -651,7 +655,7 @@ std::string toDot(IrFunction& function, bool includeInst) return result; } -std::string dumpDot(IrFunction& function, bool includeInst) +std::string dumpDot(const IrFunction& function, bool includeInst) { std::string result = toDot(function, includeInst); diff --git a/CodeGen/src/IrLoweringA64.cpp 
b/CodeGen/src/IrLoweringA64.cpp new file mode 100644 index 00000000..7f0305cc --- /dev/null +++ b/CodeGen/src/IrLoweringA64.cpp @@ -0,0 +1,1363 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "IrLoweringA64.h" + +#include "Luau/CodeGen.h" +#include "Luau/DenseHash.h" +#include "Luau/IrAnalysis.h" +#include "Luau/IrDump.h" +#include "Luau/IrUtils.h" + +#include "EmitCommonA64.h" +#include "EmitInstructionA64.h" +#include "NativeState.h" + +#include "lstate.h" +#include "lgc.h" + +// TODO: Eventually this can go away +// #define TRACE + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +#ifdef TRACE +struct LoweringStatsA64 +{ + size_t can; + size_t total; + + ~LoweringStatsA64() + { + if (total) + printf("A64 lowering succeeded for %.1f%% functions (%d/%d)\n", double(can) / double(total) * 100, int(can), int(total)); + } +} gStatsA64; +#endif + +inline ConditionA64 getConditionFP(IrCondition cond) +{ + switch (cond) + { + case IrCondition::Equal: + return ConditionA64::Equal; + + case IrCondition::NotEqual: + return ConditionA64::NotEqual; + + case IrCondition::Less: + return ConditionA64::Minus; + + case IrCondition::NotLess: + return ConditionA64::Plus; + + case IrCondition::LessEqual: + return ConditionA64::UnsignedLessEqual; + + case IrCondition::NotLessEqual: + return ConditionA64::UnsignedGreater; + + case IrCondition::Greater: + return ConditionA64::Greater; + + case IrCondition::NotGreater: + return ConditionA64::LessEqual; + + case IrCondition::GreaterEqual: + return ConditionA64::GreaterEqual; + + case IrCondition::NotGreaterEqual: + return ConditionA64::Less; + + default: + LUAU_ASSERT(!"Unexpected condition code"); + return ConditionA64::Always; + } +} + +// TODO: instead of temp1/temp2 we can take a register that we will use for ra->value; that way callers to this function will be able to use it when +// calling luaC_barrier* +static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp1, RegisterA64 temp2, int ra, Label& skip) +{ + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2w = castReg(KindA64::w, temp2); + + // iscollectable(ra) + build.ldr(temp1w, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, tt))); + build.cmp(temp1w, LUA_TSTRING); + build.b(ConditionA64::Less, skip); + + // isblack(obj2gco(o)) + // TODO: conditional bit test with BLACKBIT + build.ldrb(temp1w, mem(object, offsetof(GCheader, marked))); + build.mov(temp2w, bitmask(BLACKBIT)); + build.and_(temp1w, temp1w, temp2w); + build.cbz(temp1w, skip); + + // iswhite(gcvalue(ra)) + // TODO: tst with bitmask(WHITE0BIT, WHITE1BIT) + build.ldr(temp1, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, value))); + build.ldrb(temp1w, mem(temp1, offsetof(GCheader, marked))); + build.mov(temp2w, bit2mask(WHITE0BIT, WHITE1BIT)); + build.and_(temp1w, temp1w, temp2w); + build.cbz(temp1w, skip); +} + +IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) + : build(build) + , helpers(helpers) + , data(data) + , proto(proto) + , function(function) + , regs(function, {{x0, x15}, {q0, q7}, {q16, q31}}) +{ + // In order to allocate registers during lowering, we need to know where instruction results are last used + updateLastUseLocations(function); +} + +// TODO: Eventually this can go away +bool IrLoweringA64::canLower(const IrFunction& function) +{ +#ifdef TRACE + gStatsA64.total++; +#endif + + 
for (const IrInst& inst : function.instructions) + { + switch (inst.cmd) + { + case IrCmd::NOP: + case IrCmd::LOAD_TAG: + case IrCmd::LOAD_POINTER: + case IrCmd::LOAD_DOUBLE: + case IrCmd::LOAD_INT: + case IrCmd::LOAD_TVALUE: + case IrCmd::LOAD_NODE_VALUE_TV: + case IrCmd::LOAD_ENV: + case IrCmd::GET_ARR_ADDR: + case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: + case IrCmd::STORE_TAG: + case IrCmd::STORE_POINTER: + case IrCmd::STORE_DOUBLE: + case IrCmd::STORE_INT: + case IrCmd::STORE_TVALUE: + case IrCmd::STORE_NODE_VALUE_TV: + case IrCmd::ADD_INT: + case IrCmd::SUB_INT: + case IrCmd::ADD_NUM: + case IrCmd::SUB_NUM: + case IrCmd::MUL_NUM: + case IrCmd::DIV_NUM: + case IrCmd::MOD_NUM: + case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: + case IrCmd::UNM_NUM: + case IrCmd::FLOOR_NUM: + case IrCmd::CEIL_NUM: + case IrCmd::ROUND_NUM: + case IrCmd::SQRT_NUM: + case IrCmd::ABS_NUM: + case IrCmd::JUMP: + case IrCmd::JUMP_IF_TRUTHY: + case IrCmd::JUMP_IF_FALSY: + case IrCmd::JUMP_EQ_TAG: + case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_EQ_POINTER: + case IrCmd::JUMP_CMP_NUM: + case IrCmd::JUMP_CMP_ANY: + case IrCmd::TABLE_LEN: + case IrCmd::NEW_TABLE: + case IrCmd::DUP_TABLE: + case IrCmd::TRY_NUM_TO_INDEX: + case IrCmd::INT_TO_NUM: + case IrCmd::ADJUST_STACK_TO_REG: + case IrCmd::ADJUST_STACK_TO_TOP: + case IrCmd::INVOKE_FASTCALL: + case IrCmd::CHECK_FASTCALL_RES: + case IrCmd::DO_ARITH: + case IrCmd::DO_LEN: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: + case IrCmd::GET_IMPORT: + case IrCmd::CONCAT: + case IrCmd::GET_UPVALUE: + case IrCmd::SET_UPVALUE: + case IrCmd::PREPARE_FORN: + case IrCmd::CHECK_TAG: + case IrCmd::CHECK_READONLY: + case IrCmd::CHECK_NO_METATABLE: + case IrCmd::CHECK_SAFE_ENV: + case IrCmd::CHECK_ARRAY_SIZE: + case IrCmd::CHECK_SLOT_MATCH: + case IrCmd::INTERRUPT: + case IrCmd::CHECK_GC: + case IrCmd::BARRIER_OBJ: + case IrCmd::BARRIER_TABLE_BACK: + case IrCmd::BARRIER_TABLE_FORWARD: + case IrCmd::SET_SAVEDPC: + case IrCmd::CLOSE_UPVALS: + case IrCmd::CAPTURE: + case IrCmd::CALL: + case IrCmd::RETURN: + case IrCmd::FALLBACK_GETGLOBAL: + case IrCmd::FALLBACK_SETGLOBAL: + case IrCmd::FALLBACK_GETTABLEKS: + case IrCmd::FALLBACK_SETTABLEKS: + case IrCmd::FALLBACK_NAMECALL: + case IrCmd::FALLBACK_PREPVARARGS: + case IrCmd::FALLBACK_GETVARARGS: + case IrCmd::FALLBACK_NEWCLOSURE: + case IrCmd::FALLBACK_DUPCLOSURE: + case IrCmd::SUBSTITUTE: + continue; + + default: +#ifdef TRACE + printf("A64 lowering missing %s\n", getCmdName(inst.cmd)); +#endif + return false; + } + } + +#ifdef TRACE + gStatsA64.can++; +#endif + + return true; +} + +void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) +{ + switch (inst.cmd) + { + case IrCmd::LOAD_TAG: + { + inst.regA64 = regs.allocReg(KindA64::w); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, tt)); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_POINTER: + { + inst.regA64 = regs.allocReg(KindA64::x); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_DOUBLE: + { + inst.regA64 = regs.allocReg(KindA64::d); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_INT: + { + inst.regA64 = regs.allocReg(KindA64::w); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_TVALUE: + { + inst.regA64 = regs.allocReg(KindA64::q); + AddressA64 addr = 
tempAddr(inst.a, 0); + build.ldr(inst.regA64, addr); + break; + } + case IrCmd::LOAD_NODE_VALUE_TV: + { + inst.regA64 = regs.allocReg(KindA64::q); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaNode, val))); + break; + } + case IrCmd::LOAD_ENV: + inst.regA64 = regs.allocReg(KindA64::x); + build.ldr(inst.regA64, mem(rClosure, offsetof(Closure, env))); + break; + case IrCmd::GET_ARR_ADDR: + { + inst.regA64 = regs.allocReg(KindA64::x); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, array))); + + if (inst.b.kind == IrOpKind::Inst) + { + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(inst.regA64, inst.regA64, castReg(KindA64::x, regOp(inst.b)), kTValueSizeLog2); + } + else if (inst.b.kind == IrOpKind::Constant) + { + LUAU_ASSERT(size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate >> kTValueSizeLog2); // TODO: handle out of range values + build.add(inst.regA64, inst.regA64, uint16_t(intOp(inst.b) << kTValueSizeLog2)); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + break; + } + case IrCmd::GET_SLOT_NODE_ADDR: + { + inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + // TODO: this can use a slightly more efficient sequence with a 4b load + and-with-right-shift for pcpos<1024 but we don't support it yet. + build.mov(temp1, uintOp(inst.b) * sizeof(Instruction) + kOffsetOfInstructionC); + build.ldrb(temp1w, mem(rCode, temp1)); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, nodemask8))); + build.and_(temp2, temp2, temp1w); + + // note: this may clobber inst.a, so it's important that we don't use it after this + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(inst.regA64, inst.regA64, castReg(KindA64::x, temp2), kLuaNodeSizeLog2); + break; + } + case IrCmd::GET_HASH_NODE_ADDR: + { + inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); + RegisterA64 temp1 = regs.allocTemp(KindA64::w); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + // TODO: this can use bic (andnot) to do hash & ~(-1 << lsizenode) instead but we don't support it yet + build.mov(temp1, 1); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, lsizenode))); + build.lsl(temp1, temp1, temp2); + build.sub(temp1, temp1, 1); + build.mov(temp2, uintOp(inst.b)); + build.and_(temp2, temp2, temp1); + + // note: this may clobber inst.a, so it's important that we don't use it after this + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. 
+ build.add(inst.regA64, inst.regA64, castReg(KindA64::x, temp2), kLuaNodeSizeLog2); + break; + } + case IrCmd::STORE_TAG: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, tt)); + build.mov(temp, tagOp(inst.b)); + build.str(temp, addr); + break; + } + case IrCmd::STORE_POINTER: + { + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.str(regOp(inst.b), addr); + break; + } + case IrCmd::STORE_DOUBLE: + { + RegisterA64 temp = tempDouble(inst.b); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.str(temp, addr); + break; + } + case IrCmd::STORE_INT: + { + RegisterA64 temp = tempInt(inst.b); + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + build.str(temp, addr); + break; + } + case IrCmd::STORE_TVALUE: + { + AddressA64 addr = tempAddr(inst.a, 0); + build.str(regOp(inst.b), addr); + break; + } + case IrCmd::STORE_NODE_VALUE_TV: + build.str(regOp(inst.b), mem(regOp(inst.a), offsetof(LuaNode, val))); + break; + case IrCmd::ADD_INT: + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values + inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a}); + build.add(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b))); + break; + case IrCmd::SUB_INT: + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values + inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a}); + build.sub(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b))); + break; + case IrCmd::ADD_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fadd(inst.regA64, temp1, temp2); + break; + } + case IrCmd::SUB_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fsub(inst.regA64, temp1, temp2); + break; + } + case IrCmd::MUL_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fmul(inst.regA64, temp1, temp2); + break; + } + case IrCmd::DIV_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fdiv(inst.regA64, temp1, temp2); + break; + } + case IrCmd::MOD_NUM: + { + inst.regA64 = regs.allocReg(KindA64::d); // can't allocReuse because both A and B are used twice + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fdiv(inst.regA64, temp1, temp2); + build.frintm(inst.regA64, inst.regA64); + build.fmul(inst.regA64, inst.regA64, temp2); + build.fsub(inst.regA64, temp1, inst.regA64); + break; + } + case IrCmd::POW_NUM: + { + // TODO: this instruction clobbers all registers because of a call but it's unclear how to assert that cleanly atm + inst.regA64 = regs.allocReg(KindA64::d); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fmov(d0, temp1); // TODO: aliasing hazard + build.fmov(d1, temp2); // TODO: aliasing hazard + build.ldr(x0, mem(rNativeContext, offsetof(NativeContext, libm_pow))); + build.blr(x0); + build.fmov(inst.regA64, d0); + break; + } + case IrCmd::MIN_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = 
tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fcmp(temp1, temp2); + build.fcsel(inst.regA64, temp1, temp2, getConditionFP(IrCondition::Less)); + break; + } + case IrCmd::MAX_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fcmp(temp1, temp2); + build.fcsel(inst.regA64, temp1, temp2, getConditionFP(IrCondition::Greater)); + break; + } + case IrCmd::UNM_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.fneg(inst.regA64, temp); + break; + } + case IrCmd::FLOOR_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.frintm(inst.regA64, temp); + break; + } + case IrCmd::CEIL_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.frintp(inst.regA64, temp); + break; + } + case IrCmd::ROUND_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.frinta(inst.regA64, temp); + break; + } + case IrCmd::SQRT_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.fsqrt(inst.regA64, temp); + break; + } + case IrCmd::ABS_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.fabs(inst.regA64, temp); + break; + } + case IrCmd::JUMP: + jumpOrFallthrough(blockOp(inst.a), next); + break; + case IrCmd::JUMP_IF_TRUTHY: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); + // nil => falsy + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp, labelOp(inst.c)); + // not boolean => truthy + build.cmp(temp, LUA_TBOOLEAN); + build.b(ConditionA64::NotEqual, labelOp(inst.b)); + // compare boolean value + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, value))); + build.cbnz(temp, labelOp(inst.b)); + jumpOrFallthrough(blockOp(inst.c), next); + break; + } + case IrCmd::JUMP_IF_FALSY: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); + // nil => falsy + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp, labelOp(inst.b)); + // not boolean => truthy + build.cmp(temp, LUA_TBOOLEAN); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + // compare boolean value + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, value))); + build.cbz(temp, labelOp(inst.b)); + jumpOrFallthrough(blockOp(inst.c), next); + break; + } + case IrCmd::JUMP_EQ_TAG: + if (inst.b.kind == IrOpKind::Constant) + build.cmp(regOp(inst.a), tagOp(inst.b)); + else if (inst.b.kind == IrOpKind::Inst) + build.cmp(regOp(inst.a), regOp(inst.b)); + else + LUAU_ASSERT(!"Unsupported instruction form"); + + if (isFallthroughBlock(blockOp(inst.d), next)) + { + build.b(ConditionA64::Equal, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + } + else + { + build.b(ConditionA64::NotEqual, labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.c), next); + } + break; + case IrCmd::JUMP_EQ_INT: + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); + build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); + build.b(ConditionA64::Equal, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + 
break; + case IrCmd::JUMP_EQ_POINTER: + build.cmp(regOp(inst.a), regOp(inst.b)); + build.b(ConditionA64::Equal, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + break; + case IrCmd::JUMP_CMP_NUM: + { + IrCondition cond = conditionOp(inst.c); + + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + + build.fcmp(temp1, temp2); + build.b(getConditionFP(cond), labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.e), next); + break; + } + case IrCmd::JUMP_CMP_ANY: + { + IrCondition cond = conditionOp(inst.c); + + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (cond == IrCondition::NotLessEqual || cond == IrCondition::LessEqual) + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_lessequal))); + else if (cond == IrCondition::NotLess || cond == IrCondition::Less) + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_lessthan))); + else if (cond == IrCondition::NotEqual || cond == IrCondition::Equal) + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_equalval))); + else + LUAU_ASSERT(!"Unsupported condition"); + + build.blr(x3); + + emitUpdateBase(build); + + if (cond == IrCondition::NotLessEqual || cond == IrCondition::NotLess || cond == IrCondition::NotEqual) + build.cbz(x0, labelOp(inst.d)); + else + build.cbnz(x0, labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.e), next); + break; + } + case IrCmd::TABLE_LEN: + { + regs.assertAllFreeExcept(regOp(inst.a)); + build.mov(x0, regOp(inst.a)); // TODO: minor aliasing hazard + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaH_getn))); + build.blr(x1); + + inst.regA64 = regs.allocReg(KindA64::d); + build.scvtf(inst.regA64, x0); + break; + } + case IrCmd::NEW_TABLE: + { + regs.assertAllFree(); + build.mov(x0, rState); + build.mov(x1, uintOp(inst.a)); + build.mov(x2, uintOp(inst.b)); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaH_new))); + build.blr(x3); + // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::x); + build.mov(inst.regA64, x0); + break; + } + case IrCmd::DUP_TABLE: + { + regs.assertAllFreeExcept(regOp(inst.a)); + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaH_clone))); + build.blr(x2); + // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::x); + build.mov(inst.regA64, x0); + break; + } + case IrCmd::TRY_NUM_TO_INDEX: + { + inst.regA64 = regs.allocReg(KindA64::w); + RegisterA64 temp1 = tempDouble(inst.a); + + if (build.features & Feature_JSCVT) + { + build.fjcvtzs(inst.regA64, temp1); // fjcvtzs sets PSTATE.Z (equal) iff conversion is exact + build.b(ConditionA64::NotEqual, labelOp(inst.b)); + } + else + { + RegisterA64 temp2 = regs.allocTemp(KindA64::d); + + build.fcvtzs(inst.regA64, temp1); + build.scvtf(temp2, inst.regA64); + build.fcmp(temp1, temp2); + build.b(ConditionA64::NotEqual, labelOp(inst.b)); + } + break; + } + case IrCmd::INT_TO_NUM: + { + inst.regA64 = regs.allocReg(KindA64::d); + RegisterA64 temp = tempInt(inst.a); + build.scvtf(inst.regA64, temp); + break; + } + case IrCmd::ADJUST_STACK_TO_REG: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + + if (inst.b.kind == 
IrOpKind::Constant) + { + build.add(temp, rBase, uint16_t((vmRegOp(inst.a) + intOp(inst.b)) * sizeof(TValue))); + build.str(temp, mem(rState, offsetof(lua_State, top))); + } + else if (inst.b.kind == IrOpKind::Inst) + { + build.add(temp, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(temp, temp, castReg(KindA64::x, regOp(inst.b)), kTValueSizeLog2); + build.str(temp, mem(rState, offsetof(lua_State, top))); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + break; + } + case IrCmd::ADJUST_STACK_TO_TOP: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.ldr(temp, mem(rState, offsetof(lua_State, ci))); + build.ldr(temp, mem(temp, offsetof(CallInfo, top))); + build.str(temp, mem(rState, offsetof(lua_State, top))); + break; + } + case IrCmd::INVOKE_FASTCALL: + { + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.mov(w3, intOp(inst.f)); // nresults + + if (inst.d.kind == IrOpKind::VmReg) + build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); + else if (inst.d.kind == IrOpKind::VmConst) + { + // TODO: refactor into a common helper + if (vmConstOp(inst.d) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(x4, rConstants, uint16_t(vmConstOp(inst.d) * sizeof(TValue))); + } + else + { + build.mov(x4, vmConstOp(inst.d) * sizeof(TValue)); + build.add(x4, rConstants, x4); + } + } + else + LUAU_ASSERT(boolOp(inst.d) == false); + + // nparams + if (intOp(inst.e) == LUA_MULTRET) + { + // L->top - (ra + 1) + build.ldr(x5, mem(rState, offsetof(lua_State, top))); + build.sub(x5, x5, rBase); + build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue))); + // TODO: this can use immediate shift right or maybe add/sub with shift right but we don't implement them yet + build.mov(x6, kTValueSizeLog2); + build.lsr(x5, x5, x6); + } + else + build.mov(w5, intOp(inst.e)); + + build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction))); + build.blr(x6); + + // TODO: we could takeReg w0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::w); + build.mov(inst.regA64, w0); + break; + } + case IrCmd::CHECK_FASTCALL_RES: + build.cmp(regOp(inst.a), 0); + build.b(ConditionA64::Less, labelOp(inst.b)); + break; + case IrCmd::DO_ARITH: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (inst.c.kind == IrOpKind::VmConst) + { + // TODO: refactor into a common helper + if (vmConstOp(inst.c) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(x3, rConstants, uint16_t(vmConstOp(inst.c) * sizeof(TValue))); + } + else + { + build.mov(x3, vmConstOp(inst.c) * sizeof(TValue)); + build.add(x3, rConstants, x3); + } + } + else + build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + + build.mov(w4, TMS(intOp(inst.d))); + build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith))); + build.blr(x5); + + emitUpdateBase(build); + break; + case IrCmd::DO_LEN: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + 
build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_dolen))); + build.blr(x3); + + emitUpdateBase(build); + break; + case IrCmd::GET_TABLE: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (inst.c.kind == IrOpKind::VmReg) + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + else if (inst.c.kind == IrOpKind::Constant) + { + TValue n; + setnvalue(&n, uintOp(inst.c)); + build.adr(x2, &n, sizeof(n)); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.add(x3, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_gettable))); + build.blr(x4); + + emitUpdateBase(build); + break; + case IrCmd::SET_TABLE: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (inst.c.kind == IrOpKind::VmReg) + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + else if (inst.c.kind == IrOpKind::Constant) + { + TValue n; + setnvalue(&n, uintOp(inst.c)); + build.adr(x2, &n, sizeof(n)); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.add(x3, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_settable))); + build.blr(x4); + + emitUpdateBase(build); + break; + case IrCmd::GET_IMPORT: + regs.assertAllFree(); + emitInstGetImport(build, vmRegOp(inst.a), uintOp(inst.b)); + break; + case IrCmd::CONCAT: + regs.assertAllFree(); + build.mov(x0, rState); + build.mov(x1, uintOp(inst.b)); + build.mov(x2, vmRegOp(inst.a) + uintOp(inst.b) - 1); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_concat))); + build.blr(x3); + + emitUpdateBase(build); + break; + case IrCmd::GET_UPVALUE: + { + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::q); + RegisterA64 temp3 = regs.allocTemp(KindA64::w); + + build.add(temp1, rClosure, uint16_t(offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.b))); + + // uprefs[] is either an actual value, or it points to UpVal object which has a pointer to value + Label skip; + build.ldr(temp3, mem(temp1, offsetof(TValue, tt))); + build.cmp(temp3, LUA_TUPVAL); + build.b(ConditionA64::NotEqual, skip); + + // UpVal.v points to the value (either on stack, or on heap inside each UpVal, but we can deref it unconditionally) + build.ldr(temp1, mem(temp1, offsetof(TValue, value.gc))); + build.ldr(temp1, mem(temp1, offsetof(UpVal, v))); + + build.setLabel(skip); + + build.ldr(temp2, temp1); + build.str(temp2, mem(rBase, vmRegOp(inst.a) * sizeof(TValue))); + break; + } + case IrCmd::SET_UPVALUE: + { + regs.assertAllFree(); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + RegisterA64 temp3 = regs.allocTemp(KindA64::q); + RegisterA64 temp4 = regs.allocTemp(KindA64::x); + + // UpVal* + build.ldr(temp1, mem(rClosure, offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.a) + offsetof(TValue, value.gc))); + + build.ldr(temp2, mem(temp1, offsetof(UpVal, v))); + build.ldr(temp3, mem(rBase, vmRegOp(inst.b) * sizeof(TValue))); + build.str(temp3, temp2); + + Label skip; + checkObjectBarrierConditions(build, temp1, temp2, temp4, vmRegOp(inst.b), skip); + + build.mov(x0, rState); + build.mov(x1, temp1); // TODO: aliasing hazard + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * 
sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::PREPARE_FORN: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_prepareFORN))); + build.blr(x4); + // note: no emitUpdateBase necessary because prepareFORN does not reallocate stack + break; + case IrCmd::CHECK_TAG: + build.cmp(regOp(inst.a), tagOp(inst.b)); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + break; + case IrCmd::CHECK_READONLY: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldrb(temp, mem(regOp(inst.a), offsetof(Table, readonly))); + build.cbnz(temp, labelOp(inst.b)); + break; + } + case IrCmd::CHECK_NO_METATABLE: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.ldr(temp, mem(regOp(inst.a), offsetof(Table, metatable))); + build.cbnz(temp, labelOp(inst.b)); + break; + } + case IrCmd::CHECK_SAFE_ENV: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + RegisterA64 tempw = castReg(KindA64::w, temp); + build.ldr(temp, mem(rClosure, offsetof(Closure, env))); + build.ldrb(tempw, mem(temp, offsetof(Table, safeenv))); + build.cbz(tempw, labelOp(inst.a)); + break; + } + case IrCmd::CHECK_ARRAY_SIZE: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldr(temp, mem(regOp(inst.a), offsetof(Table, sizearray))); + + if (inst.b.kind == IrOpKind::Inst) + build.cmp(temp, regOp(inst.b)); + else if (inst.b.kind == IrOpKind::Constant) + { + LUAU_ASSERT(size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values + build.cmp(temp, uint16_t(intOp(inst.b))); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.b(ConditionA64::UnsignedLessEqual, labelOp(inst.c)); + break; + } + case IrCmd::CHECK_SLOT_MATCH: + { + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + RegisterA64 temp2w = castReg(KindA64::w, temp2); + + build.ldr(temp1w, mem(regOp(inst.a), kOffsetOfLuaNodeTag)); + // TODO: this needs bitfield extraction, or and-immediate + build.mov(temp2w, kLuaNodeTagMask); + build.and_(temp1w, temp1w, temp2w); + build.cmp(temp1w, LUA_TSTRING); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + + AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value)); + build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaNode, key.value))); + build.ldr(temp2, addr); + build.cmp(temp1, temp2); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + + build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, val.tt))); + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp1w, labelOp(inst.c)); + break; + } + case IrCmd::INTERRUPT: + { + unsigned int pcpos = uintOp(inst.a); + regs.assertAllFree(); + + Label skip; + build.ldr(x2, mem(rState, offsetof(lua_State, global))); + build.ldr(x2, mem(x2, offsetof(global_State, cb.interrupt))); + build.cbz(x2, skip); + + // Jump to outlined interrupt handler, it will give back control to x1 + build.mov(x0, (pcpos + 1) * sizeof(Instruction)); + build.adr(x1, skip); + build.b(helpers.interrupt); + + build.setLabel(skip); + break; + } + case IrCmd::CHECK_GC: + { + 
regs.assertAllFree(); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + Label skip; + build.ldr(temp1, mem(rState, offsetof(lua_State, global))); + build.ldr(temp2, mem(temp1, offsetof(global_State, totalbytes))); + build.ldr(temp1, mem(temp1, offsetof(global_State, GCthreshold))); + build.cmp(temp1, temp2); + build.b(ConditionA64::UnsignedGreater, skip); + + build.mov(x0, rState); + build.mov(w1, 1); + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaC_step))); + build.blr(x1); + + emitUpdateBase(build); + build.setLabel(skip); + break; + } + case IrCmd::BARRIER_OBJ: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + checkObjectBarrierConditions(build, regOp(inst.a), temp1, temp2, vmRegOp(inst.b), skip); + + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::BARRIER_TABLE_BACK: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::w); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + // isblack(obj2gco(t)) + build.ldrb(temp1, mem(regOp(inst.a), offsetof(GCheader, marked))); + // TODO: conditional bit test with BLACKBIT + build.mov(temp2, bitmask(BLACKBIT)); + build.and_(temp1, temp1, temp2); + build.cbz(temp1, skip); + + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard here and below + build.add(x2, regOp(inst.a), uint16_t(offsetof(Table, gclist))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierback))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::BARRIER_TABLE_FORWARD: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + checkObjectBarrierConditions(build, regOp(inst.a), temp1, temp2, vmRegOp(inst.b), skip); + + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barriertable))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::SET_SAVEDPC: + { + unsigned int pcpos = uintOp(inst.a); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + // TODO: refactor into a common helper + if (pcpos * sizeof(Instruction) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(temp1, rCode, uint16_t(pcpos * sizeof(Instruction))); + } + else + { + build.mov(temp1, pcpos * sizeof(Instruction)); + build.add(temp1, rCode, temp1); + } + + build.ldr(temp2, mem(rState, offsetof(lua_State, ci))); + build.str(temp1, mem(temp2, offsetof(CallInfo, savedpc))); + break; + } + case IrCmd::CLOSE_UPVALS: + { + regs.assertAllFree(); + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = 
regs.allocTemp(KindA64::x); + + // L->openupval != 0 + build.ldr(temp1, mem(rState, offsetof(lua_State, openupval))); + build.cbz(temp1, skip); + + // ra <= L->openuval->v + build.ldr(temp1, mem(temp1, offsetof(UpVal, v))); + build.add(temp2, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.cmp(temp2, temp1); + build.b(ConditionA64::UnsignedGreater, skip); + + build.mov(x0, rState); + build.mov(x1, temp2); // TODO: aliasing hazard + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaF_close))); + build.blr(x2); + + build.setLabel(skip); + break; + } + case IrCmd::CAPTURE: + // no-op + break; + case IrCmd::CALL: + regs.assertAllFree(); + emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); + break; + case IrCmd::RETURN: + regs.assertAllFree(); + emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); + break; + + // Full instruction fallbacks + case IrCmd::FALLBACK_GETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_GETGLOBAL, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_SETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_SETGLOBAL, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_GETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_GETTABLEKS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_SETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_SETTABLEKS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_NAMECALL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_NAMECALL, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_PREPVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); + + regs.assertAllFree(); + emitFallback(build, LOP_PREPVARARGS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_GETVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + + regs.assertAllFree(); + emitFallback(build, LOP_GETVARARGS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_NEWCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + + regs.assertAllFree(); + emitFallback(build, LOP_NEWCLOSURE, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_DUPCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_DUPCLOSURE, uintOp(inst.a)); + break; + + default: + LUAU_ASSERT(!"Not supported yet"); + break; + } + + regs.freeLastUseRegs(inst, index); + regs.freeTempRegs(); +} + +bool IrLoweringA64::hasError() const +{ + return false; +} + +bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next) +{ + return target.start == next.start; +} + +void IrLoweringA64::jumpOrFallthrough(IrBlock& target, IrBlock& next) +{ + if (!isFallthroughBlock(target, next)) + build.b(target.label); +} + +RegisterA64 IrLoweringA64::tempDouble(IrOp op) +{ + if (op.kind == IrOpKind::Inst) + return regOp(op); + else if 
(op.kind == IrOpKind::Constant) + { + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::d); + build.adr(temp1, doubleOp(op)); + build.ldr(temp2, temp1); + return temp2; + } + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + return noreg; + } +} + +RegisterA64 IrLoweringA64::tempInt(IrOp op) +{ + if (op.kind == IrOpKind::Inst) + return regOp(op); + else if (op.kind == IrOpKind::Constant) + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, intOp(op)); + return temp; + } + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + return noreg; + } +} + +AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset) +{ + // This is needed to tighten the bounds checks in the VmConst case below + LUAU_ASSERT(offset % 4 == 0); + + if (op.kind == IrOpKind::VmReg) + return mem(rBase, vmRegOp(op) * sizeof(TValue) + offset); + else if (op.kind == IrOpKind::VmConst) + { + size_t constantOffset = vmConstOp(op) * sizeof(TValue) + offset; + + // Note: cumulative offset is guaranteed to be divisible by 4; we can use that to expand the useful range that doesn't require temporaries + if (constantOffset / 4 <= AddressA64::kMaxOffset) + return mem(rConstants, int(constantOffset)); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + + // TODO: refactor into a common helper + if (constantOffset <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(temp, rConstants, uint16_t(constantOffset)); + } + else + { + build.mov(temp, int(constantOffset)); + build.add(temp, rConstants, temp); + } + + return temp; + } + // If we have a register, we assume it's a pointer to TValue + // We might introduce explicit operand types in the future to make this more robust + else if (op.kind == IrOpKind::Inst) + return mem(regOp(op), offset); + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + return noreg; + } +} + +RegisterA64 IrLoweringA64::regOp(IrOp op) const +{ + IrInst& inst = function.instOp(op); + LUAU_ASSERT(inst.regA64 != noreg); + return inst.regA64; +} + +IrConst IrLoweringA64::constOp(IrOp op) const +{ + return function.constOp(op); +} + +uint8_t IrLoweringA64::tagOp(IrOp op) const +{ + return function.tagOp(op); +} + +bool IrLoweringA64::boolOp(IrOp op) const +{ + return function.boolOp(op); +} + +int IrLoweringA64::intOp(IrOp op) const +{ + return function.intOp(op); +} + +unsigned IrLoweringA64::uintOp(IrOp op) const +{ + return function.uintOp(op); +} + +double IrLoweringA64::doubleOp(IrOp op) const +{ + return function.doubleOp(op); +} + +IrBlock& IrLoweringA64::blockOp(IrOp op) const +{ + return function.blockOp(op); +} + +Label& IrLoweringA64::labelOp(IrOp op) const +{ + return blockOp(op).label; +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h new file mode 100644 index 00000000..b374a26a --- /dev/null +++ b/CodeGen/src/IrLoweringA64.h @@ -0,0 +1,68 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AssemblyBuilderA64.h" +#include "Luau/IrData.h" + +#include "IrRegAllocA64.h" + +#include + +struct Proto; + +namespace Luau +{ +namespace CodeGen +{ + +struct ModuleHelpers; +struct NativeState; +struct AssemblyOptions; + +namespace A64 +{ + +struct IrLoweringA64 +{ + IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function); + + static bool canLower(const IrFunction& function); + + void 
lowerInst(IrInst& inst, uint32_t index, IrBlock& next); + + bool hasError() const; + + bool isFallthroughBlock(IrBlock target, IrBlock next); + void jumpOrFallthrough(IrBlock& target, IrBlock& next); + + // Operand data build helpers + RegisterA64 tempDouble(IrOp op); + RegisterA64 tempInt(IrOp op); + AddressA64 tempAddr(IrOp op, int offset); + + // Operand data lookup helpers + RegisterA64 regOp(IrOp op) const; + + IrConst constOp(IrOp op) const; + uint8_t tagOp(IrOp op) const; + bool boolOp(IrOp op) const; + int intOp(IrOp op) const; + unsigned uintOp(IrOp op) const; + double doubleOp(IrOp op) const; + + IrBlock& blockOp(IrOp op) const; + Label& labelOp(IrOp op) const; + + AssemblyBuilderA64& build; + ModuleHelpers& helpers; + NativeState& data; + Proto* proto = nullptr; // Temporarily required to provide 'Instruction* pc' to old emitInst* methods + + IrFunction& function; + + IrRegAllocA64 regs; +}; + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index b45ce226..f2dfdb3b 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -4,6 +4,7 @@ #include "Luau/CodeGen.h" #include "Luau/DenseHash.h" #include "Luau/IrAnalysis.h" +#include "Luau/IrCallWrapperX64.h" #include "Luau/IrDump.h" #include "Luau/IrUtils.h" @@ -14,8 +15,6 @@ #include "lstate.h" -#include - namespace Luau { namespace CodeGen @@ -23,169 +22,49 @@ namespace CodeGen namespace X64 { -IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) +IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function) : build(build) , helpers(helpers) , data(data) - , proto(proto) , function(function) - , regs(function) + , regs(build, function) { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); } -void IrLoweringX64::lower(AssemblyOptions options) +void IrLoweringX64::storeDoubleAsFloat(OperandX64 dst, IrOp src) { - // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined - std::vector sortedBlocks; - sortedBlocks.reserve(function.blocks.size()); - for (uint32_t i = 0; i < function.blocks.size(); i++) - sortedBlocks.push_back(i); + ScopedRegX64 tmp{regs, SizeX64::xmmword}; - std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { - const IrBlock& a = function.blocks[idxA]; - const IrBlock& b = function.blocks[idxB]; - - // Place fallback blocks at the end - if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) - return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); - - // Try to order by instruction order - return a.start < b.start; - }); - - DenseHashMap bcLocations{~0u}; - - // Create keys for IR assembly locations that original bytecode instruction are interested in - for (const auto& [irLocation, asmLocation] : function.bcMapping) + if (src.kind == IrOpKind::Constant) { - if (irLocation != ~0u) - bcLocations[irLocation] = 0; + build.vmovss(tmp.reg, build.f32(float(doubleOp(src)))); } - - DenseHashMap indexIrToBc{~0u}; - bool outputEnabled = options.includeAssembly || options.includeIr; - - if (outputEnabled && options.annotator) + else if (src.kind == IrOpKind::Inst) { - // Create reverse mapping from IR location to bytecode location - for 
(size_t i = 0; i < function.bcMapping.size(); ++i) - { - uint32_t irLocation = function.bcMapping[i].irLocation; - - if (irLocation != ~0u) - indexIrToBc[irLocation] = uint32_t(i); - } + build.vcvtsd2ss(tmp.reg, regOp(src), regOp(src)); } - - IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; - - // We use this to skip outlined fallback blocks from IR/asm text output - size_t textSize = build.text.length(); - uint32_t codeSize = build.getCodeSize(); - bool seenFallback = false; - - IrBlock dummy; - dummy.start = ~0u; - - for (size_t i = 0; i < sortedBlocks.size(); ++i) + else { - uint32_t blockIndex = sortedBlocks[i]; - - IrBlock& block = function.blocks[blockIndex]; - - if (block.kind == IrBlockKind::Dead) - continue; - - LUAU_ASSERT(block.start != ~0u); - LUAU_ASSERT(block.finish != ~0u); - - // If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them - if (block.kind == IrBlockKind::Fallback && !seenFallback) - { - textSize = build.text.length(); - codeSize = build.getCodeSize(); - seenFallback = true; - } - - if (options.includeIr) - { - build.logAppend("# "); - toStringDetailed(ctx, block, blockIndex, /* includeUseInfo */ true); - } - - build.setLabel(block.label); - - for (uint32_t index = block.start; index <= block.finish; index++) - { - LUAU_ASSERT(index < function.instructions.size()); - - // If IR instruction is the first one for the original bytecode, we can annotate it with source code text - if (outputEnabled && options.annotator) - { - if (uint32_t* bcIndex = indexIrToBc.find(index)) - options.annotator(options.annotatorContext, build.text, proto->bytecodeid, *bcIndex); - } - - // If bytecode needs the location of this instruction for jumps, record it - if (uint32_t* bcLocation = bcLocations.find(index)) - *bcLocation = build.getCodeSize(); - - IrInst& inst = function.instructions[index]; - - // Skip pseudo instructions, but make sure they are not used at this stage - // This also prevents them from getting into text output when that's enabled - if (isPseudo(inst.cmd)) - { - LUAU_ASSERT(inst.useCount == 0); - continue; - } - - if (options.includeIr) - { - build.logAppend("# "); - toStringDetailed(ctx, inst, index, /* includeUseInfo */ true); - } - - IrBlock& next = i + 1 < sortedBlocks.size() ? 
function.blocks[sortedBlocks[i + 1]] : dummy; - - lowerInst(inst, index, next); - - regs.freeLastUseRegs(inst, index); - } - - if (options.includeIr) - build.logAppend("#\n"); - } - - if (outputEnabled && !options.includeOutlinedCode && seenFallback) - { - build.text.resize(textSize); - - if (options.includeAssembly) - build.logAppend("; skipping %u bytes of outlined code\n", build.getCodeSize() - codeSize); - } - - // Copy assembly locations of IR instructions that are mapped to bytecode instructions - for (auto& [irLocation, asmLocation] : function.bcMapping) - { - if (irLocation != ~0u) - asmLocation = bcLocations[irLocation]; + LUAU_ASSERT(!"Unsupported instruction form"); } + build.vmovss(dst, tmp.reg); } void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { + regs.currInstIdx = index; + switch (inst.cmd) { case IrCmd::LOAD_TAG: - inst.regX64 = regs.allocGprReg(SizeX64::dword); + inst.regX64 = regs.allocGprReg(SizeX64::dword, index); if (inst.a.kind == IrOpKind::VmReg) - build.mov(inst.regX64, luauRegTag(inst.a.index)); + build.mov(inst.regX64, luauRegTag(vmRegOp(inst.a))); else if (inst.a.kind == IrOpKind::VmConst) - build.mov(inst.regX64, luauConstantTag(inst.a.index)); + build.mov(inst.regX64, luauConstantTag(vmConstOp(inst.a))); // If we have a register, we assume it's a pointer to TValue // We might introduce explicit operand types in the future to make this more robust else if (inst.a.kind == IrOpKind::Inst) @@ -194,12 +73,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_POINTER: - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); if (inst.a.kind == IrOpKind::VmReg) - build.mov(inst.regX64, luauRegValue(inst.a.index)); + build.mov(inst.regX64, luauRegValue(vmRegOp(inst.a))); else if (inst.a.kind == IrOpKind::VmConst) - build.mov(inst.regX64, luauConstantValue(inst.a.index)); + build.mov(inst.regX64, luauConstantValue(vmConstOp(inst.a))); // If we have a register, we assume it's a pointer to TValue // We might introduce explicit operand types in the future to make this more robust else if (inst.a.kind == IrOpKind::Inst) @@ -208,41 +87,39 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_DOUBLE: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); if (inst.a.kind == IrOpKind::VmReg) - build.vmovsd(inst.regX64, luauRegValue(inst.a.index)); + build.vmovsd(inst.regX64, luauRegValue(vmRegOp(inst.a))); else if (inst.a.kind == IrOpKind::VmConst) - build.vmovsd(inst.regX64, luauConstantValue(inst.a.index)); + build.vmovsd(inst.regX64, luauConstantValue(vmConstOp(inst.a))); else LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_INT: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); + inst.regX64 = regs.allocGprReg(SizeX64::dword, index); - inst.regX64 = regs.allocGprReg(SizeX64::dword); - - build.mov(inst.regX64, luauRegValueInt(inst.a.index)); + build.mov(inst.regX64, luauRegValueInt(vmRegOp(inst.a))); break; case IrCmd::LOAD_TVALUE: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); if (inst.a.kind == IrOpKind::VmReg) - build.vmovups(inst.regX64, luauReg(inst.a.index)); + build.vmovups(inst.regX64, luauReg(vmRegOp(inst.a))); else if (inst.a.kind == IrOpKind::VmConst) - build.vmovups(inst.regX64, luauConstant(inst.a.index)); + 
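// Note: vmRegOp/vmConstOp/vmUpvalueOp are presumably small accessors that assert the operand kind before returning op.index, which would be why the explicit LUAU_ASSERT(inst.X.kind == IrOpKind::...) checks are dropped throughout this change.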
build.vmovups(inst.regX64, luauConstant(vmConstOp(inst.a))); else if (inst.a.kind == IrOpKind::Inst) build.vmovups(inst.regX64, xmmword[regOp(inst.a)]); else LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_NODE_VALUE_TV: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); build.vmovups(inst.regX64, luauNodeValue(regOp(inst.a))); break; case IrCmd::LOAD_ENV: - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); build.mov(inst.regX64, sClosure); build.mov(inst.regX64, qword[inst.regX64 + offsetof(Closure, env)]); @@ -274,7 +151,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::GET_SLOT_NODE_ADDR: { - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); ScopedRegX64 tmp{regs, SizeX64::qword}; @@ -283,10 +160,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::GET_HASH_NODE_ADDR: { - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); // Custom bit shift value can only be placed in cl - ScopedRegX64 shiftTmp{regs, regs.takeGprReg(rcx)}; + ScopedRegX64 shiftTmp{regs, regs.takeReg(rcx, kInvalidInstIdx)}; ScopedRegX64 tmp{regs, SizeX64::qword}; @@ -301,31 +178,25 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; }; case IrCmd::STORE_TAG: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - if (inst.b.kind == IrOpKind::Constant) - build.mov(luauRegTag(inst.a.index), tagOp(inst.b)); + build.mov(luauRegTag(vmRegOp(inst.a)), tagOp(inst.b)); else LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::STORE_POINTER: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - - build.mov(luauRegValue(inst.a.index), regOp(inst.b)); + build.mov(luauRegValue(vmRegOp(inst.a)), regOp(inst.b)); break; case IrCmd::STORE_DOUBLE: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - if (inst.b.kind == IrOpKind::Constant) { ScopedRegX64 tmp{regs, SizeX64::xmmword}; build.vmovsd(tmp.reg, build.f64(doubleOp(inst.b))); - build.vmovsd(luauRegValue(inst.a.index), tmp.reg); + build.vmovsd(luauRegValue(vmRegOp(inst.a)), tmp.reg); } else if (inst.b.kind == IrOpKind::Inst) { - build.vmovsd(luauRegValue(inst.a.index), regOp(inst.b)); + build.vmovsd(luauRegValue(vmRegOp(inst.a)), regOp(inst.b)); } else { @@ -334,19 +205,24 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::STORE_INT: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - if (inst.b.kind == IrOpKind::Constant) - build.mov(luauRegValueInt(inst.a.index), intOp(inst.b)); + build.mov(luauRegValueInt(vmRegOp(inst.a)), intOp(inst.b)); else if (inst.b.kind == IrOpKind::Inst) - build.mov(luauRegValueInt(inst.a.index), regOp(inst.b)); + build.mov(luauRegValueInt(vmRegOp(inst.a)), regOp(inst.b)); else LUAU_ASSERT(!"Unsupported instruction form"); break; } + case IrCmd::STORE_VECTOR: + { + storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 0), inst.b); + storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 1), inst.c); + storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 2), inst.d); + break; + } case IrCmd::STORE_TVALUE: if (inst.a.kind == IrOpKind::VmReg) - build.vmovups(luauReg(inst.a.index), regOp(inst.b)); + build.vmovups(luauReg(vmRegOp(inst.a)), regOp(inst.b)); else if (inst.a.kind == IrOpKind::Inst) build.vmovups(xmmword[regOp(inst.a)], regOp(inst.b)); else @@ -478,82 +354,11 @@ void 
IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::POW_NUM: { - inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); - - ScopedRegX64 optLhsTmp{regs}; - RegisterX64 lhs; - - if (inst.a.kind == IrOpKind::Constant) - { - optLhsTmp.alloc(SizeX64::xmmword); - - build.vmovsd(optLhsTmp.reg, memRegDoubleOp(inst.a)); - lhs = optLhsTmp.reg; - } - else - { - lhs = regOp(inst.a); - } - - if (inst.b.kind == IrOpKind::Inst) - { - // TODO: this doesn't happen with current local-only register allocation, but has to be handled in the future - LUAU_ASSERT(regOp(inst.b) != xmm0); - - if (lhs != xmm0) - build.vmovsd(xmm0, lhs, lhs); - - if (regOp(inst.b) != xmm1) - build.vmovsd(xmm1, regOp(inst.b), regOp(inst.b)); - - build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - - if (inst.regX64 != xmm0) - build.vmovsd(inst.regX64, xmm0, xmm0); - } - else if (inst.b.kind == IrOpKind::Constant) - { - double rhs = doubleOp(inst.b); - - if (rhs == 2.0) - { - build.vmulsd(inst.regX64, lhs, lhs); - } - else if (rhs == 0.5) - { - build.vsqrtsd(inst.regX64, lhs, lhs); - } - else if (rhs == 3.0) - { - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - - build.vmulsd(tmp.reg, lhs, lhs); - build.vmulsd(inst.regX64, lhs, tmp.reg); - } - else - { - if (lhs != xmm0) - build.vmovsd(xmm0, xmm0, lhs); - - build.vmovsd(xmm1, build.f64(rhs)); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - - if (inst.regX64 != xmm0) - build.vmovsd(inst.regX64, xmm0, xmm0); - } - } - else - { - if (lhs != xmm0) - build.vmovsd(xmm0, lhs, lhs); - - build.vmovsd(xmm1, memRegDoubleOp(inst.b)); - build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - - if (inst.regX64 != xmm0) - build.vmovsd(inst.regX64, xmm0, xmm0); - } - + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.a), inst.a); + callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); + inst.regX64 = regs.takeReg(xmm0, index); break; } case IrCmd::MIN_NUM: @@ -604,6 +409,50 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } + case IrCmd::FLOOR_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + + build.vroundsd(inst.regX64, inst.regX64, memRegDoubleOp(inst.a), RoundingModeX64::RoundToNegativeInfinity); + break; + case IrCmd::CEIL_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + + build.vroundsd(inst.regX64, inst.regX64, memRegDoubleOp(inst.a), RoundingModeX64::RoundToPositiveInfinity); + break; + case IrCmd::ROUND_NUM: + { + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; + + if (inst.a.kind != IrOpKind::Inst) + build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + else if (regOp(inst.a) != inst.regX64) + build.vmovsd(inst.regX64, inst.regX64, regOp(inst.a)); + + build.vandpd(tmp1.reg, inst.regX64, build.f64x2(-0.0, -0.0)); + build.vmovsd(tmp2.reg, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 + build.vorpd(tmp1.reg, tmp1.reg, tmp2.reg); + build.vaddsd(inst.regX64, inst.regX64, tmp1.reg); + build.vroundsd(inst.regX64, inst.regX64, inst.regX64, RoundingModeX64::RoundToZero); + break; + } + case IrCmd::SQRT_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); + + build.vsqrtsd(inst.regX64, inst.regX64, memRegDoubleOp(inst.a)); + break; + case IrCmd::ABS_NUM: + inst.regX64 
= regs.allocXmmRegOrReuse(index, {inst.a}); + + if (inst.a.kind != IrOpKind::Inst) + build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + else if (regOp(inst.a) != inst.regX64) + build.vmovsd(inst.regX64, inst.regX64, regOp(inst.a)); + + build.vandpd(inst.regX64, inst.regX64, build.i64(~(1LL << 63))); + break; case IrCmd::NOT_ANY: { // TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target @@ -642,15 +491,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.a), next); break; case IrCmd::JUMP_IF_TRUTHY: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - - jumpIfTruthy(build, inst.a.index, labelOp(inst.b), labelOp(inst.c)); + jumpIfTruthy(build, vmRegOp(inst.a), labelOp(inst.b), labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.c), next); break; case IrCmd::JUMP_IF_FALSY: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - - jumpIfFalsy(build, inst.a.index, labelOp(inst.b), labelOp(inst.c)); + jumpIfFalsy(build, vmRegOp(inst.a), labelOp(inst.b), labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.c), next); break; case IrCmd::JUMP_EQ_TAG: @@ -686,9 +531,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::JUMP_CMP_NUM: { - LUAU_ASSERT(inst.c.kind == IrOpKind::Condition); - - IrCondition cond = IrCondition(inst.c.index); + IrCondition cond = conditionOp(inst.c); ScopedRegX64 tmp{regs, SizeX64::xmmword}; @@ -698,59 +541,49 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::JUMP_CMP_ANY: - { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Condition); - - IrCondition cond = IrCondition(inst.c.index); - - jumpOnAnyCmpFallback(build, inst.a.index, inst.b.index, cond, labelOp(inst.d)); + jumpOnAnyCmpFallback(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), conditionOp(inst.c), labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; - } case IrCmd::JUMP_SLOT_MATCH: { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmConst); - ScopedRegX64 tmp{regs, SizeX64::qword}; - jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.d)); + jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(vmConstOp(inst.b)), labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.c), next); break; } case IrCmd::TABLE_LEN: - inst.regX64 = regs.allocXmmReg(); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); - build.mov(rArg1, regOp(inst.a)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); + inst.regX64 = regs.allocXmmReg(index); build.vcvtsi2sd(inst.regX64, inst.regX64, eax); break; + } case IrCmd::NEW_TABLE: - inst.regX64 = regs.allocGprReg(SizeX64::qword); - - build.mov(rArg1, rState); - build.mov(dwordReg(rArg2), uintOp(inst.a)); - build.mov(dwordReg(rArg3), uintOp(inst.b)); - build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); - - if (inst.regX64 != rax) - build.mov(inst.regX64, rax); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.a)), inst.a); + callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.b)), inst.b); + callWrap.call(qword[rNativeContext + 
offsetof(NativeContext, luaH_new)]); + inst.regX64 = regs.takeReg(rax, index); break; + } case IrCmd::DUP_TABLE: - inst.regX64 = regs.allocGprReg(SizeX64::qword); - - // Re-ordered to avoid register conflict - build.mov(rArg2, regOp(inst.a)); - build.mov(rArg1, rState); - build.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]); - - if (inst.regX64 != rax) - build.mov(inst.regX64, rax); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]); + inst.regX64 = regs.takeReg(rax, index); break; + } case IrCmd::TRY_NUM_TO_INDEX: { - inst.regX64 = regs.allocGprReg(SizeX64::dword); + inst.regX64 = regs.allocGprReg(SizeX64::dword, index); ScopedRegX64 tmp{regs, SizeX64::xmmword}; @@ -759,37 +592,53 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::TRY_CALL_FASTGETTM: { - inst.regX64 = regs.allocGprReg(SizeX64::qword); + ScopedRegX64 tmp{regs, SizeX64::qword}; - callGetFastTmOrFallback(build, regOp(inst.a), TMS(intOp(inst.b)), labelOp(inst.c)); + build.mov(tmp.reg, qword[regOp(inst.a) + offsetof(Table, metatable)]); + regs.freeLastUseReg(function.instOp(inst.a), index); // Release before the call if it's the last use - if (inst.regX64 != rax) - build.mov(inst.regX64, rax); + build.test(tmp.reg, tmp.reg); + build.jcc(ConditionX64::Zero, labelOp(inst.c)); // No metatable + + build.test(byte[tmp.reg + offsetof(Table, tmcache)], 1 << intOp(inst.b)); + build.jcc(ConditionX64::NotZero, labelOp(inst.c)); // No tag method + + ScopedRegX64 tmp2{regs, SizeX64::qword}; + build.mov(tmp2.reg, qword[rState + offsetof(lua_State, global)]); + + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.addArgument(SizeX64::qword, intOp(inst.b)); + callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); + } + + inst.regX64 = regs.takeReg(rax, index); break; } case IrCmd::INT_TO_NUM: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); build.vcvtsi2sd(inst.regX64, inst.regX64, regOp(inst.a)); break; case IrCmd::ADJUST_STACK_TO_REG: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); + ScopedRegX64 tmp{regs, SizeX64::qword}; if (inst.b.kind == IrOpKind::Constant) { - ScopedRegX64 tmp{regs, SizeX64::qword}; - - build.lea(tmp.reg, addr[rBase + (inst.a.index + intOp(inst.b)) * sizeof(TValue)]); + build.lea(tmp.reg, addr[rBase + (vmRegOp(inst.a) + intOp(inst.b)) * sizeof(TValue)]); build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg); } else if (inst.b.kind == IrOpKind::Inst) { - ScopedRegX64 tmp(regs, regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.b})); - - build.shl(qwordReg(tmp.reg), kTValueSizeLog2); - build.lea(qwordReg(tmp.reg), addr[rBase + qwordReg(tmp.reg) + inst.a.index * sizeof(TValue)]); - build.mov(qword[rState + offsetof(lua_State, top)], qwordReg(tmp.reg)); + build.mov(dwordReg(tmp.reg), regOp(inst.b)); + build.shl(tmp.reg, kTValueSizeLog2); + build.lea(tmp.reg, addr[rBase + tmp.reg + vmRegOp(inst.a) * sizeof(TValue)]); + build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg); } else { @@ -807,77 +656,57 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case 
IrCmd::FASTCALL: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - emitBuiltin(regs, build, uintOp(inst.a), inst.b.index, inst.c.index, inst.d, intOp(inst.e), intOp(inst.f)); + emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); break; case IrCmd::INVOKE_FASTCALL: { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - unsigned bfid = uintOp(inst.a); OperandX64 args = 0; if (inst.d.kind == IrOpKind::VmReg) - args = luauRegAddress(inst.d.index); + args = luauRegAddress(vmRegOp(inst.d)); else if (inst.d.kind == IrOpKind::VmConst) - args = luauConstantAddress(inst.d.index); + args = luauConstantAddress(vmConstOp(inst.d)); else LUAU_ASSERT(boolOp(inst.d) == false); - int ra = inst.b.index; - int arg = inst.c.index; + int ra = vmRegOp(inst.b); + int arg = vmRegOp(inst.c); int nparams = intOp(inst.e); int nresults = intOp(inst.f); - regs.assertAllFree(); + ScopedRegX64 func{regs, SizeX64::qword}; + build.mov(func.reg, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); - build.mov(rax, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); - - // 5th parameter (args) is left unset for LOP_FASTCALL1 - if (args.cat == CategoryX64::mem) - { - if (build.abi == ABIX64::Windows) - { - build.lea(rcx, args); - build.mov(sArg5, rcx); - } - else - { - build.lea(rArg5, args); - } - } + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(arg)); + callWrap.addArgument(SizeX64::dword, nresults); + callWrap.addArgument(SizeX64::qword, args); if (nparams == LUA_MULTRET) { - // L->top - (ra + 1) - RegisterX64 reg = (build.abi == ABIX64::Windows) ? rcx : rArg6; + // Compute 'L->top - (ra + 1)', on SystemV, take r9 register to compute directly into the argument + // TODO: IrCallWrapperX64 should provide a way to 'guess' target argument register correctly + RegisterX64 reg = build.abi == ABIX64::Windows ? 
regs.allocGprReg(SizeX64::qword, kInvalidInstIdx) : regs.takeReg(rArg6, kInvalidInstIdx); + ScopedRegX64 tmp{regs, SizeX64::qword}; + build.mov(reg, qword[rState + offsetof(lua_State, top)]); - build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]); - build.sub(reg, rdx); + build.lea(tmp.reg, addr[rBase + (ra + 1) * sizeof(TValue)]); + build.sub(reg, tmp.reg); build.shr(reg, kTValueSizeLog2); - if (build.abi == ABIX64::Windows) - build.mov(sArg6, reg); + callWrap.addArgument(SizeX64::dword, dwordReg(reg)); } else { - if (build.abi == ABIX64::Windows) - build.mov(sArg6, nparams); - else - build.mov(rArg6, nparams); + callWrap.addArgument(SizeX64::dword, nparams); } - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(arg)); - build.mov(dwordReg(rArg4), nresults); - - build.call(rax); - - inst.regX64 = regs.takeGprReg(eax); // Result of a builtin call is returned in eax + callWrap.call(func.release()); + inst.regX64 = regs.takeReg(eax, index); // Result of a builtin call is returned in eax break; } case IrCmd::CHECK_FASTCALL_RES: @@ -889,34 +718,24 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::DO_ARITH: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg || inst.c.kind == IrOpKind::VmConst); - if (inst.c.kind == IrOpKind::VmReg) - callArithHelper(build, inst.a.index, inst.b.index, luauRegAddress(inst.c.index), TMS(intOp(inst.d))); + callArithHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), TMS(intOp(inst.d))); else - callArithHelper(build, inst.a.index, inst.b.index, luauConstantAddress(inst.c.index), TMS(intOp(inst.d))); + callArithHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), luauConstantAddress(vmConstOp(inst.c)), TMS(intOp(inst.d))); break; case IrCmd::DO_LEN: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - - callLengthHelper(build, inst.a.index, inst.b.index); + callLengthHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b)); break; case IrCmd::GET_TABLE: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - if (inst.c.kind == IrOpKind::VmReg) { - callGetTable(build, inst.b.index, luauRegAddress(inst.c.index), inst.a.index); + callGetTable(regs, build, vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), vmRegOp(inst.a)); } else if (inst.c.kind == IrOpKind::Constant) { TValue n; setnvalue(&n, uintOp(inst.c)); - callGetTable(build, inst.b.index, build.bytes(&n, sizeof(n)), inst.a.index); + callGetTable(regs, build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); } else { @@ -924,18 +743,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::SET_TABLE: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - if (inst.c.kind == IrOpKind::VmReg) { - callSetTable(build, inst.b.index, luauRegAddress(inst.c.index), inst.a.index); + callSetTable(regs, build, vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), vmRegOp(inst.a)); } else if (inst.c.kind == IrOpKind::Constant) { TValue n; setnvalue(&n, uintOp(inst.c)); - callSetTable(build, inst.b.index, build.bytes(&n, sizeof(n)), inst.a.index); + callSetTable(regs, build, vmRegOp(inst.b), build.bytes(&n, sizeof(n)), vmRegOp(inst.a)); } else { @@ -943,30 +759,27 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& 
next) } break; case IrCmd::GET_IMPORT: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - - emitInstGetImportFallback(build, inst.a.index, uintOp(inst.b)); + regs.assertAllFree(); + emitInstGetImportFallback(build, vmRegOp(inst.a), uintOp(inst.b)); break; case IrCmd::CONCAT: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - - build.mov(rArg1, rState); - build.mov(dwordReg(rArg2), uintOp(inst.b)); - build.mov(dwordReg(rArg3), inst.a.index + uintOp(inst.b) - 1); - build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]); + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.b))); + callWrap.addArgument(SizeX64::dword, int32_t(vmRegOp(inst.a) + uintOp(inst.b) - 1)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]); emitUpdateBase(build); break; + } case IrCmd::GET_UPVALUE: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmUpvalue); - ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::xmmword}; build.mov(tmp1.reg, sClosure); - build.add(tmp1.reg, offsetof(Closure, l.uprefs) + sizeof(TValue) * inst.b.index); + build.add(tmp1.reg, offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.b)); // uprefs[] is either an actual value, or it points to UpVal object which has a pointer to value Label skip; @@ -981,32 +794,32 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.setLabel(skip); build.vmovups(tmp2.reg, xmmword[tmp1.reg]); - build.vmovups(luauReg(inst.a.index), tmp2.reg); + build.vmovups(luauReg(vmRegOp(inst.a)), tmp2.reg); break; } case IrCmd::SET_UPVALUE: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmUpvalue); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - - Label next; ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; - ScopedRegX64 tmp3{regs, SizeX64::xmmword}; build.mov(tmp1.reg, sClosure); - build.mov(tmp2.reg, qword[tmp1.reg + offsetof(Closure, l.uprefs) + sizeof(TValue) * inst.a.index + offsetof(TValue, value.gc)]); + build.mov(tmp2.reg, qword[tmp1.reg + offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.a) + offsetof(TValue, value.gc)]); build.mov(tmp1.reg, qword[tmp2.reg + offsetof(UpVal, v)]); - build.vmovups(tmp3.reg, luauReg(inst.b.index)); - build.vmovups(xmmword[tmp1.reg], tmp3.reg); - callBarrierObject(build, tmp1.reg, tmp2.reg, inst.b.index, next); - build.setLabel(next); + { + ScopedRegX64 tmp3{regs, SizeX64::xmmword}; + build.vmovups(tmp3.reg, luauReg(vmRegOp(inst.b))); + build.vmovups(xmmword[tmp1.reg], tmp3.reg); + } + + tmp1.free(); + + callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b)); break; } case IrCmd::PREPARE_FORN: - callPrepareForN(build, inst.a.index, inst.b.index, inst.c.index); + callPrepareForN(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); break; case IrCmd::CHECK_TAG: if (inst.a.kind == IrOpKind::Inst) @@ -1016,11 +829,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else if (inst.a.kind == IrOpKind::VmReg) { - jumpIfTagIsNot(build, inst.a.index, lua_Type(tagOp(inst.b)), labelOp(inst.c)); + jumpIfTagIsNot(build, vmRegOp(inst.a), lua_Type(tagOp(inst.b)), labelOp(inst.c)); } else if (inst.a.kind == IrOpKind::VmConst) { - build.cmp(luauConstantTag(inst.a.index), tagOp(inst.b)); + build.cmp(luauConstantTag(vmConstOp(inst.a)), tagOp(inst.b)); build.jcc(ConditionX64::NotEqual, labelOp(inst.c)); } else @@ 
-1053,53 +866,44 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::CHECK_SLOT_MATCH: { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmConst); - ScopedRegX64 tmp{regs, SizeX64::qword}; - jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.c)); + jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(vmConstOp(inst.b)), labelOp(inst.c)); break; } case IrCmd::CHECK_NODE_NO_NEXT: jumpIfNodeHasNext(build, regOp(inst.a), labelOp(inst.b)); break; case IrCmd::INTERRUPT: + regs.assertAllFree(); emitInterrupt(build, uintOp(inst.a)); break; case IrCmd::CHECK_GC: - { - Label skip; - callCheckGc(build, -1, false, skip); - build.setLabel(skip); + callStepGc(regs, build); break; - } case IrCmd::BARRIER_OBJ: - { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - - Label skip; - ScopedRegX64 tmp{regs, SizeX64::qword}; - - callBarrierObject(build, tmp.reg, regOp(inst.a), inst.b.index, skip); - build.setLabel(skip); + callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b)); break; - } case IrCmd::BARRIER_TABLE_BACK: - { - Label skip; - - callBarrierTableFast(build, regOp(inst.a), skip); - build.setLabel(skip); + callBarrierTableFast(regs, build, regOp(inst.a), inst.a); break; - } case IrCmd::BARRIER_TABLE_FORWARD: { - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - Label skip; - ScopedRegX64 tmp{regs, SizeX64::qword}; - callBarrierTable(build, tmp.reg, regOp(inst.a), inst.b.index, skip); + ScopedRegX64 tmp{regs, SizeX64::qword}; + checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), skip); + + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barriertable)]); + } + build.setLabel(skip); break; } @@ -1117,8 +921,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::CLOSE_UPVALS: { - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - Label next; ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; @@ -1129,15 +931,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.jcc(ConditionX64::Zero, next); // ra <= L->openuval->v - build.lea(tmp2.reg, addr[rBase + inst.a.index * sizeof(TValue)]); + build.lea(tmp2.reg, addr[rBase + vmRegOp(inst.a) * sizeof(TValue)]); build.cmp(tmp2.reg, qword[tmp1.reg + offsetof(UpVal, v)]); build.jcc(ConditionX64::Above, next); - if (rArg2 != tmp2.reg) - build.mov(rArg2, tmp2.reg); + tmp1.free(); - build.mov(rArg1, rState); - build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); + } build.setLabel(next); break; @@ -1147,64 +954,35 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; // Fallbacks to non-IR instruction implementations - case IrCmd::LOP_SETLIST: - { - const Instruction* pc = proto->code + uintOp(inst.a); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.d.kind == IrOpKind::Constant); - 
LUAU_ASSERT(inst.e.kind == IrOpKind::Constant); - - Label next; - emitInstSetList(build, pc, next); - build.setLabel(next); + case IrCmd::SETLIST: + regs.assertAllFree(); + emitInstSetList(regs, build, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); break; - } - case IrCmd::LOP_CALL: - { - const Instruction* pc = proto->code + uintOp(inst.a); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - LUAU_ASSERT(inst.d.kind == IrOpKind::Constant); - - emitInstCall(build, helpers, pc, uintOp(inst.a)); + case IrCmd::CALL: + regs.assertAllFree(); + regs.assertNoSpills(); + emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); break; - } - case IrCmd::LOP_RETURN: - { - const Instruction* pc = proto->code + uintOp(inst.a); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - - emitInstReturn(build, helpers, pc, uintOp(inst.a)); + case IrCmd::RETURN: + regs.assertAllFree(); + regs.assertNoSpills(); + emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); break; - } - case IrCmd::LOP_FORGLOOP: - LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); - emitinstForGLoop(build, inst.a.index, intOp(inst.b), labelOp(inst.c), labelOp(inst.d)); + case IrCmd::FORGLOOP: + regs.assertAllFree(); + emitinstForGLoop(build, vmRegOp(inst.a), intOp(inst.b), labelOp(inst.c), labelOp(inst.d)); break; - case IrCmd::LOP_FORGLOOP_FALLBACK: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - emitinstForGLoopFallback(build, uintOp(inst.a), inst.b.index, intOp(inst.c), labelOp(inst.d)); - build.jmp(labelOp(inst.e)); + case IrCmd::FORGLOOP_FALLBACK: + regs.assertAllFree(); + emitinstForGLoopFallback(build, vmRegOp(inst.a), intOp(inst.b), labelOp(inst.c)); + build.jmp(labelOp(inst.d)); break; - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - emitInstForGPrepXnextFallback(build, uintOp(inst.a), inst.b.index, labelOp(inst.c)); + case IrCmd::FORGPREP_XNEXT_FALLBACK: + regs.assertAllFree(); + emitInstForGPrepXnextFallback(build, uintOp(inst.a), vmRegOp(inst.b), labelOp(inst.c)); break; - case IrCmd::LOP_AND: - emitInstAnd(build, proto->code + uintOp(inst.a)); - break; - case IrCmd::LOP_ANDK: - emitInstAndK(build, proto->code + uintOp(inst.a)); - break; - case IrCmd::LOP_OR: - emitInstOr(build, proto->code + uintOp(inst.a)); - break; - case IrCmd::LOP_ORK: - emitInstOrK(build, proto->code + uintOp(inst.a)); - break; - case IrCmd::LOP_COVERAGE: + case IrCmd::COVERAGE: + regs.assertAllFree(); emitInstCoverage(build, uintOp(inst.a)); break; @@ -1213,12 +991,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_GETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETGLOBAL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_SETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETTABLEKS: @@ -1226,6 +1006,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_GETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETTABLEKS: @@ -1233,6 +1014,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) 
LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_SETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NAMECALL: @@ -1240,38 +1022,55 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_NAMECALL, uintOp(inst.a)); break; case IrCmd::FALLBACK_PREPVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); + regs.assertAllFree(); emitFallback(build, data, LOP_PREPVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + regs.assertAllFree(); emitFallback(build, data, LOP_GETVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NEWCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + regs.assertAllFree(); emitFallback(build, data, LOP_NEWCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_DUPCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + regs.assertAllFree(); emitFallback(build, data, LOP_DUPCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_FORGPREP: + regs.assertAllFree(); emitFallback(build, data, LOP_FORGPREP, uintOp(inst.a)); break; default: LUAU_ASSERT(!"Not supported yet"); break; } + + regs.freeLastUseRegs(inst, index); +} + +bool IrLoweringX64::hasError() const +{ + // If register allocator had to use more stack slots than we have available, this function can't run natively + if (regs.maxUsedSlot > kSpillSlots) + return true; + + return false; } bool IrLoweringX64::isFallthroughBlock(IrBlock target, IrBlock next) @@ -1285,7 +1084,7 @@ void IrLoweringX64::jumpOrFallthrough(IrBlock& target, IrBlock& next) build.jmp(target.label); } -OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) const +OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) { switch (op.kind) { @@ -1294,9 +1093,9 @@ OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) const case IrOpKind::Constant: return build.f64(doubleOp(op)); case IrOpKind::VmReg: - return luauRegValue(op.index); + return luauRegValue(vmRegOp(op)); case IrOpKind::VmConst: - return luauConstantValue(op.index); + return luauConstantValue(vmConstOp(op)); default: LUAU_ASSERT(!"Unsupported operand kind"); } @@ -1304,16 +1103,16 @@ OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) const return noreg; } -OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const +OperandX64 IrLoweringX64::memRegTagOp(IrOp op) { switch (op.kind) { case IrOpKind::Inst: return regOp(op); case IrOpKind::VmReg: - return luauRegTag(op.index); + return luauRegTag(vmRegOp(op)); case IrOpKind::VmConst: - return luauConstantTag(op.index); + return luauConstantTag(vmConstOp(op)); default: LUAU_ASSERT(!"Unsupported operand kind"); } @@ -1321,9 +1120,15 @@ OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const return noreg; } -RegisterX64 IrLoweringX64::regOp(IrOp op) const +RegisterX64 IrLoweringX64::regOp(IrOp op) { - return function.instOp(op).regX64; + IrInst& inst = function.instOp(op); + + if (inst.spilled) + regs.restore(inst, false); + + LUAU_ASSERT(inst.regX64 != noreg); + return inst.regX64; } IrConst IrLoweringX64::constOp(IrOp op) const diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index a0ad3eab..42d26277 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -3,8 
+3,7 @@ #include "Luau/AssemblyBuilderX64.h" #include "Luau/IrData.h" - -#include "IrRegAllocX64.h" +#include "Luau/IrRegAllocX64.h" #include @@ -24,20 +23,21 @@ namespace X64 struct IrLoweringX64 { - // Some of these arguments are only required while we re-use old direct bytecode to x64 lowering - IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function); - - void lower(AssemblyOptions options); + IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, IrFunction& function); void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); + bool hasError() const; + bool isFallthroughBlock(IrBlock target, IrBlock next); void jumpOrFallthrough(IrBlock& target, IrBlock& next); + void storeDoubleAsFloat(OperandX64 dst, IrOp src); + // Operand data lookup helpers - OperandX64 memRegDoubleOp(IrOp op) const; - OperandX64 memRegTagOp(IrOp op) const; - RegisterX64 regOp(IrOp op) const; + OperandX64 memRegDoubleOp(IrOp op); + OperandX64 memRegTagOp(IrOp op); + RegisterX64 regOp(IrOp op); IrConst constOp(IrOp op) const; uint8_t tagOp(IrOp op) const; @@ -52,7 +52,6 @@ struct IrLoweringX64 AssemblyBuilderX64& build; ModuleHelpers& helpers; NativeState& data; - Proto* proto = nullptr; // Temporarily required to provide 'Instruction* pc' to old emitInst* methods IrFunction& function; diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp new file mode 100644 index 00000000..c6db9e9e --- /dev/null +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -0,0 +1,183 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "IrRegAllocA64.h" + +#ifdef _MSC_VER +#include +#endif + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +inline int setBit(uint32_t n) +{ + LUAU_ASSERT(n); + +#ifdef _MSC_VER + unsigned long rl; + _BitScanReverse(&rl, n); + return int(rl); +#else + return 31 - __builtin_clz(n); +#endif +} + +IrRegAllocA64::IrRegAllocA64(IrFunction& function, std::initializer_list> regs) + : function(function) +{ + for (auto& p : regs) + { + LUAU_ASSERT(p.first.kind == p.second.kind && p.first.index <= p.second.index); + + Set& set = getSet(p.first.kind); + + for (int i = p.first.index; i <= p.second.index; ++i) + set.base |= 1u << i; + } + + gpr.free = gpr.base; + simd.free = simd.base; +} + +RegisterA64 IrRegAllocA64::allocReg(KindA64 kind) +{ + Set& set = getSet(kind); + + if (set.free == 0) + { + LUAU_ASSERT(!"Out of registers to allocate"); + return noreg; + } + + int index = setBit(set.free); + set.free &= ~(1u << index); + + return RegisterA64{kind, uint8_t(index)}; +} + +RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) +{ + Set& set = getSet(kind); + + if (set.free == 0) + { + LUAU_ASSERT(!"Out of registers to allocate"); + return noreg; + } + + int index = setBit(set.free); + + set.free &= ~(1u << index); + set.temp |= 1u << index; + + return RegisterA64{kind, uint8_t(index)}; +} + +RegisterA64 IrRegAllocA64::allocReuse(KindA64 kind, uint32_t index, std::initializer_list oprefs) +{ + for (IrOp op : oprefs) + { + if (op.kind != IrOpKind::Inst) + continue; + + IrInst& source = function.instructions[op.index]; + + if (source.lastUse == index && !source.reusedReg) + { + LUAU_ASSERT(source.regA64.kind == kind); + + source.reusedReg = true; + return source.regA64; + } + } + + return allocReg(kind); +} + +void IrRegAllocA64::freeReg(RegisterA64 reg) +{ + Set& set = getSet(reg.kind); + + LUAU_ASSERT((set.base & (1u << reg.index)) 
!= 0); + LUAU_ASSERT((set.free & (1u << reg.index)) == 0); + set.free |= 1u << reg.index; +} + +void IrRegAllocA64::freeLastUseReg(IrInst& target, uint32_t index) +{ + if (target.lastUse == index && !target.reusedReg) + { + // Register might have already been freed if it had multiple uses inside a single instruction + if (target.regA64 == noreg) + return; + + freeReg(target.regA64); + target.regA64 = noreg; + } +} + +void IrRegAllocA64::freeLastUseRegs(const IrInst& inst, uint32_t index) +{ + auto checkOp = [this, index](IrOp op) { + if (op.kind == IrOpKind::Inst) + freeLastUseReg(function.instructions[op.index], index); + }; + + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + checkOp(inst.f); +} + +void IrRegAllocA64::freeTempRegs() +{ + LUAU_ASSERT((gpr.free & gpr.temp) == 0); + gpr.free |= gpr.temp; + gpr.temp = 0; + + LUAU_ASSERT((simd.free & simd.temp) == 0); + simd.free |= simd.temp; + simd.temp = 0; +} + +void IrRegAllocA64::assertAllFree() const +{ + LUAU_ASSERT(gpr.free == gpr.base); + LUAU_ASSERT(simd.free == simd.base); +} + +void IrRegAllocA64::assertAllFreeExcept(RegisterA64 reg) const +{ + const Set& set = const_cast(this)->getSet(reg.kind); + const Set& other = &set == &gpr ? simd : gpr; + + LUAU_ASSERT(set.free == (set.base & ~(1u << reg.index))); + LUAU_ASSERT(other.free == other.base); +} + +IrRegAllocA64::Set& IrRegAllocA64::getSet(KindA64 kind) +{ + switch (kind) + { + case KindA64::x: + case KindA64::w: + return gpr; + + case KindA64::d: + case KindA64::q: + return simd; + + default: + LUAU_ASSERT(!"Unexpected register kind"); + LUAU_UNREACHABLE(); + } +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrRegAllocA64.h b/CodeGen/src/IrRegAllocA64.h new file mode 100644 index 00000000..9ff03552 --- /dev/null +++ b/CodeGen/src/IrRegAllocA64.h @@ -0,0 +1,56 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrData.h" +#include "Luau/RegisterA64.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +struct IrRegAllocA64 +{ + IrRegAllocA64(IrFunction& function, std::initializer_list> regs); + + RegisterA64 allocReg(KindA64 kind); + RegisterA64 allocTemp(KindA64 kind); + RegisterA64 allocReuse(KindA64 kind, uint32_t index, std::initializer_list oprefs); + + void freeReg(RegisterA64 reg); + + void freeLastUseReg(IrInst& target, uint32_t index); + void freeLastUseRegs(const IrInst& inst, uint32_t index); + + void freeTempRegs(); + + void assertAllFree() const; + void assertAllFreeExcept(RegisterA64 reg) const; + + IrFunction& function; + + struct Set + { + // which registers are in the set that the allocator manages (initialized at construction) + uint32_t base = 0; + + // which subset of initial set is free + uint32_t free = 0; + + // which subset of initial set is allocated as temporary + uint32_t temp = 0; + }; + + Set gpr, simd; + + Set& getSet(KindA64 kind); +}; + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index c527d033..dc9e7f90 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -1,19 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "IrRegAllocX64.h" - -#include "Luau/CodeGen.h" -#include "Luau/DenseHash.h" -#include "Luau/IrAnalysis.h" 
-#include "Luau/IrDump.h" -#include "Luau/IrUtils.h" +#include "Luau/IrRegAllocX64.h" #include "EmitCommonX64.h" -#include "EmitInstructionX64.h" -#include "NativeState.h" - -#include "lstate.h" - -#include namespace Luau { @@ -24,14 +12,22 @@ namespace X64 static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; -IrRegAllocX64::IrRegAllocX64(IrFunction& function) - : function(function) +static bool isFullTvalueOperand(IrCmd cmd) { - freeGprMap.fill(true); - freeXmmMap.fill(true); + return cmd == IrCmd::LOAD_TVALUE || cmd == IrCmd::LOAD_NODE_VALUE_TV; } -RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize) +IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function) + : build(build) + , function(function) +{ + freeGprMap.fill(true); + gprInstUsers.fill(kInvalidInstIdx); + freeXmmMap.fill(true); + xmmInstUsers.fill(kInvalidInstIdx); +} + +RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize, uint32_t instIdx) { LUAU_ASSERT( preferredSize == SizeX64::byte || preferredSize == SizeX64::word || preferredSize == SizeX64::dword || preferredSize == SizeX64::qword); @@ -41,30 +37,40 @@ RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize) if (freeGprMap[reg.index]) { freeGprMap[reg.index] = false; + gprInstUsers[reg.index] = instIdx; return RegisterX64{preferredSize, reg.index}; } } + // If possible, spill the value with the furthest next use + if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(gprInstUsers); furthestUseTarget != kInvalidInstIdx) + return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); + LUAU_ASSERT(!"Out of GPR registers to allocate"); return noreg; } -RegisterX64 IrRegAllocX64::allocXmmReg() +RegisterX64 IrRegAllocX64::allocXmmReg(uint32_t instIdx) { for (size_t i = 0; i < freeXmmMap.size(); ++i) { if (freeXmmMap[i]) { freeXmmMap[i] = false; + xmmInstUsers[i] = instIdx; return RegisterX64{SizeX64::xmmword, uint8_t(i)}; } } + // Out of registers, spill the value with the furthest next use + if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(xmmInstUsers); furthestUseTarget != kInvalidInstIdx) + return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); + LUAU_ASSERT(!"Out of XMM registers to allocate"); return noreg; } -RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list oprefs) +RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t instIdx, std::initializer_list oprefs) { for (IrOp op : oprefs) { @@ -73,20 +79,21 @@ RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t in IrInst& source = function.instructions[op.index]; - if (source.lastUse == index && !source.reusedReg) + if (source.lastUse == instIdx && !source.reusedReg && !source.spilled) { LUAU_ASSERT(source.regX64.size != SizeX64::xmmword); LUAU_ASSERT(source.regX64 != noreg); source.reusedReg = true; + gprInstUsers[source.regX64.index] = instIdx; return RegisterX64{preferredSize, source.regX64.index}; } } - return allocGprReg(preferredSize); + return allocGprReg(preferredSize, instIdx); } -RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_list oprefs) +RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t instIdx, std::initializer_list oprefs) { for (IrOp op : oprefs) { @@ -95,26 +102,47 @@ RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_l IrInst& source = function.instructions[op.index]; - if (source.lastUse == 
index && !source.reusedReg) + if (source.lastUse == instIdx && !source.reusedReg && !source.spilled) { LUAU_ASSERT(source.regX64.size == SizeX64::xmmword); LUAU_ASSERT(source.regX64 != noreg); source.reusedReg = true; + xmmInstUsers[source.regX64.index] = instIdx; return source.regX64; } } - return allocXmmReg(); + return allocXmmReg(instIdx); } -RegisterX64 IrRegAllocX64::takeGprReg(RegisterX64 reg) +RegisterX64 IrRegAllocX64::takeReg(RegisterX64 reg, uint32_t instIdx) { - // In a more advanced register allocator, this would require a spill for the current register user - // But at the current stage we don't have register live ranges intersecting forced register uses - LUAU_ASSERT(freeGprMap[reg.index]); + if (reg.size == SizeX64::xmmword) + { + if (!freeXmmMap[reg.index]) + { + LUAU_ASSERT(xmmInstUsers[reg.index] != kInvalidInstIdx); + preserve(function.instructions[xmmInstUsers[reg.index]]); + } + + LUAU_ASSERT(freeXmmMap[reg.index]); + freeXmmMap[reg.index] = false; + xmmInstUsers[reg.index] = instIdx; + } + else + { + if (!freeGprMap[reg.index]) + { + LUAU_ASSERT(gprInstUsers[reg.index] != kInvalidInstIdx); + preserve(function.instructions[gprInstUsers[reg.index]]); + } + + LUAU_ASSERT(freeGprMap[reg.index]); + freeGprMap[reg.index] = false; + gprInstUsers[reg.index] = instIdx; + } - freeGprMap[reg.index] = false; return reg; } @@ -124,17 +152,19 @@ void IrRegAllocX64::freeReg(RegisterX64 reg) { LUAU_ASSERT(!freeXmmMap[reg.index]); freeXmmMap[reg.index] = true; + xmmInstUsers[reg.index] = kInvalidInstIdx; } else { LUAU_ASSERT(!freeGprMap[reg.index]); freeGprMap[reg.index] = true; + gprInstUsers[reg.index] = kInvalidInstIdx; } } -void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t index) +void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t instIdx) { - if (target.lastUse == index && !target.reusedReg) + if (isLastUseReg(target, instIdx)) { // Register might have already been freed if it had multiple uses inside a single instruction if (target.regX64 == noreg) @@ -145,11 +175,11 @@ void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t index) } } -void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index) +void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t instIdx) { - auto checkOp = [this, index](IrOp op) { + auto checkOp = [this, instIdx](IrOp op) { if (op.kind == IrOpKind::Inst) - freeLastUseReg(function.instructions[op.index], index); + freeLastUseReg(function.instructions[op.index], instIdx); }; checkOp(inst.a); @@ -160,6 +190,185 @@ void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index) checkOp(inst.f); } +bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t instIdx) const +{ + return target.lastUse == instIdx && !target.reusedReg; +} + +void IrRegAllocX64::preserve(IrInst& inst) +{ + bool doubleSlot = isFullTvalueOperand(inst.cmd); + + // Find a free stack slot. 
Two consecutive slots might be required for 16 byte TValues, so '- 1' is used + for (unsigned i = 0; i < unsigned(usedSpillSlots.size() - 1); ++i) + { + if (usedSpillSlots.test(i)) + continue; + + if (doubleSlot && usedSpillSlots.test(i + 1)) + { + ++i; // No need to retest this double position + continue; + } + + if (inst.regX64.size == SizeX64::xmmword && doubleSlot) + { + build.vmovups(xmmword[sSpillArea + i * 8], inst.regX64); + } + else if (inst.regX64.size == SizeX64::xmmword) + { + build.vmovsd(qword[sSpillArea + i * 8], inst.regX64); + } + else + { + OperandX64 location = addr[sSpillArea + i * 8]; + location.memSize = inst.regX64.size; // Override memory access size + build.mov(location, inst.regX64); + } + + usedSpillSlots.set(i); + + if (i + 1 > maxUsedSlot) + maxUsedSlot = i + 1; + + if (doubleSlot) + { + usedSpillSlots.set(i + 1); + + if (i + 2 > maxUsedSlot) + maxUsedSlot = i + 2; + } + + IrSpillX64 spill; + spill.instIdx = function.getInstIndex(inst); + spill.useDoubleSlot = doubleSlot; + spill.stackSlot = uint8_t(i); + spill.originalLoc = inst.regX64; + + spills.push_back(spill); + + freeReg(inst.regX64); + + inst.regX64 = noreg; + inst.spilled = true; + return; + } + + LUAU_ASSERT(!"nowhere to spill"); +} + +void IrRegAllocX64::restore(IrInst& inst, bool intoOriginalLocation) +{ + uint32_t instIdx = function.getInstIndex(inst); + + for (size_t i = 0; i < spills.size(); i++) + { + const IrSpillX64& spill = spills[i]; + + if (spill.instIdx == instIdx) + { + LUAU_ASSERT(spill.stackSlot != kNoStackSlot); + RegisterX64 reg; + + if (spill.originalLoc.size == SizeX64::xmmword) + { + reg = intoOriginalLocation ? takeReg(spill.originalLoc, instIdx) : allocXmmReg(instIdx); + + if (spill.useDoubleSlot) + build.vmovups(reg, xmmword[sSpillArea + spill.stackSlot * 8]); + else + build.vmovsd(reg, qword[sSpillArea + spill.stackSlot * 8]); + } + else + { + reg = intoOriginalLocation ? 
takeReg(spill.originalLoc, instIdx) : allocGprReg(spill.originalLoc.size, instIdx); + + OperandX64 location = addr[sSpillArea + spill.stackSlot * 8]; + location.memSize = reg.size; // Override memory access size + build.mov(reg, location); + } + + inst.regX64 = reg; + inst.spilled = false; + + usedSpillSlots.set(spill.stackSlot, false); + + if (spill.useDoubleSlot) + usedSpillSlots.set(spill.stackSlot + 1, false); + + spills[i] = spills.back(); + spills.pop_back(); + return; + } + } +} + +void IrRegAllocX64::preserveAndFreeInstValues() +{ + for (uint32_t instIdx : gprInstUsers) + { + if (instIdx != kInvalidInstIdx) + preserve(function.instructions[instIdx]); + } + + for (uint32_t instIdx : xmmInstUsers) + { + if (instIdx != kInvalidInstIdx) + preserve(function.instructions[instIdx]); + } +} + +bool IrRegAllocX64::shouldFreeGpr(RegisterX64 reg) const +{ + if (reg == noreg) + return false; + + LUAU_ASSERT(reg.size != SizeX64::xmmword); + + for (RegisterX64 gpr : kGprAllocOrder) + { + if (reg.index == gpr.index) + return true; + } + + return false; +} + +uint32_t IrRegAllocX64::findInstructionWithFurthestNextUse(const std::array& regInstUsers) const +{ + uint32_t furthestUseTarget = kInvalidInstIdx; + uint32_t furthestUseLocation = 0; + + for (uint32_t regInstUser : regInstUsers) + { + // Cannot spill temporary registers or the register of the value that's defined in the current instruction + if (regInstUser == kInvalidInstIdx || regInstUser == currInstIdx) + continue; + + uint32_t nextUse = getNextInstUse(function, regInstUser, currInstIdx); + + // Cannot spill value that is about to be used in the current instruction + if (nextUse == currInstIdx) + continue; + + if (furthestUseTarget == kInvalidInstIdx || nextUse > furthestUseLocation) + { + furthestUseLocation = nextUse; + furthestUseTarget = regInstUser; + } + } + + return furthestUseTarget; +} + +void IrRegAllocX64::assertFree(RegisterX64 reg) const +{ + if (reg.size == SizeX64::xmmword) + LUAU_ASSERT(freeXmmMap[reg.index]); + else + LUAU_ASSERT(freeGprMap[reg.index]); +} + void IrRegAllocX64::assertAllFree() const { for (RegisterX64 reg : kGprAllocOrder) @@ -169,6 +378,11 @@ void IrRegAllocX64::assertAllFree() const LUAU_ASSERT(free); } +void IrRegAllocX64::assertNoSpills() const +{ + LUAU_ASSERT(spills.empty()); +} + ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner) : owner(owner) , reg(noreg) @@ -199,9 +413,9 @@ void ScopedRegX64::alloc(SizeX64 size) LUAU_ASSERT(reg == noreg); if (size == SizeX64::xmmword) - reg = owner.allocXmmReg(); + reg = owner.allocXmmReg(kInvalidInstIdx); else - reg = owner.allocGprReg(size); + reg = owner.allocGprReg(size, kInvalidInstIdx); } void ScopedRegX64::free() @@ -211,6 +425,48 @@ void ScopedRegX64::free() reg = noreg; } +RegisterX64 ScopedRegX64::release() +{ + RegisterX64 tmp = reg; + reg = noreg; + return tmp; +} + +ScopedSpills::ScopedSpills(IrRegAllocX64& owner) + : owner(owner) +{ + snapshot = owner.spills; +} + +ScopedSpills::~ScopedSpills() +{ + // Taking a copy of current spills because we are going to potentially restore them + std::vector current = owner.spills; + + // Restore registers that were spilled inside scope protected by this object + for (IrSpillX64& curr : current) + { + // If spill existed before current scope, it can be restored outside of it + if (!wasSpilledBefore(curr)) + { + IrInst& inst = owner.function.instructions[curr.instIdx]; + + owner.restore(inst, /*intoOriginalLocation*/ true); + } + } +} + +bool ScopedSpills::wasSpilledBefore(const IrSpillX64& spill) const +{ + for 
(const IrSpillX64& preexisting : snapshot) + { + if (spill.instIdx == preexisting.instIdx) + return true; + } + + return false; +} + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrRegAllocX64.h b/CodeGen/src/IrRegAllocX64.h deleted file mode 100644 index 497bb035..00000000 --- a/CodeGen/src/IrRegAllocX64.h +++ /dev/null @@ -1,60 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "Luau/IrData.h" -#include "Luau/RegisterX64.h" - -#include -#include - -namespace Luau -{ -namespace CodeGen -{ -namespace X64 -{ - -struct IrRegAllocX64 -{ - IrRegAllocX64(IrFunction& function); - - RegisterX64 allocGprReg(SizeX64 preferredSize); - RegisterX64 allocXmmReg(); - - RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list oprefs); - RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list oprefs); - - RegisterX64 takeGprReg(RegisterX64 reg); - - void freeReg(RegisterX64 reg); - void freeLastUseReg(IrInst& target, uint32_t index); - void freeLastUseRegs(const IrInst& inst, uint32_t index); - - void assertAllFree() const; - - IrFunction& function; - - std::array freeGprMap; - std::array freeXmmMap; -}; - -struct ScopedRegX64 -{ - explicit ScopedRegX64(IrRegAllocX64& owner); - ScopedRegX64(IrRegAllocX64& owner, SizeX64 size); - ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg); - ~ScopedRegX64(); - - ScopedRegX64(const ScopedRegX64&) = delete; - ScopedRegX64& operator=(const ScopedRegX64&) = delete; - - void alloc(SizeX64 size); - void free(); - - IrRegAllocX64& owner; - RegisterX64 reg; -}; - -} // namespace X64 -} // namespace CodeGen -} // namespace Luau diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index d9f935c4..ba491564 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -6,7 +6,6 @@ #include "lstate.h" -// TODO: should be possible to handle fastcalls in contexts where nresults is -1 by adding the adjustment instruction // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results namespace Luau @@ -26,8 +25,8 @@ BuiltinImplResult translateBuiltinNumberToNumber( build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 1}; } @@ -43,8 +42,8 @@ BuiltinImplResult translateBuiltin2NumberToNumber( build.loadAndCheckTag(args, LUA_TNUMBER, fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO:tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 1}; } @@ -59,9 +58,11 @@ BuiltinImplResult translateBuiltinNumberTo2Number( 
build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: some tag updates might not be required, we place them here now because FASTCALL is not modeled in constant propagation yet - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + if (nresults > 1) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 2}; } @@ -131,8 +132,8 @@ BuiltinImplResult translateBuiltinMathLog( build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 1}; } @@ -190,10 +191,10 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.loadAndCheckTag(build.vmReg(args.index + 1), LUA_TNUMBER, fallback); + build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); IrOp min = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 1)); + IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::JUMP_CMP_NUM, min, max, build.cond(IrCondition::NotLessEqual), fallback, block); build.beginBlock(block); @@ -210,6 +211,44 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r return {BuiltinImplType::UsesFallback, 1}; } +BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + + IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp result = build.inst(cmd, varg); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathBinary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + + IrOp lhs = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp rhs = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp result = build.inst(cmd, lhs, rhs); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int 
nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) @@ -218,7 +257,6 @@ BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, in build.inst( IrCmd::FASTCALL, build.constUint(LBF_TYPE), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); return {BuiltinImplType::UsesFallback, 1}; @@ -232,14 +270,38 @@ BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, build.inst( IrCmd::FASTCALL, build.constUint(LBF_TYPEOF), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); - // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); return {BuiltinImplType::UsesFallback, 1}; } +BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 3 || nresults > 1) + return {BuiltinImplType::None, -1}; + + LUAU_ASSERT(LUA_VECTOR_SIZE == 3); + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); + + IrOp x = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp y = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp z = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), x, y, z); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) { + // Builtins are not allowed to handle variadic arguments + if (nparams == LUA_MULTRET) + return {BuiltinImplType::None, -1}; + switch (bfid) { case LBF_ASSERT: @@ -257,9 +319,17 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_CLAMP: return translateBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_FLOOR: + return translateBuiltinMathUnary(build, IrCmd::FLOOR_NUM, nparams, ra, arg, nresults, fallback); case LBF_MATH_CEIL: + return translateBuiltinMathUnary(build, IrCmd::CEIL_NUM, nparams, ra, arg, nresults, fallback); case LBF_MATH_SQRT: + return translateBuiltinMathUnary(build, IrCmd::SQRT_NUM, nparams, ra, arg, nresults, fallback); case LBF_MATH_ABS: + return translateBuiltinMathUnary(build, IrCmd::ABS_NUM, nparams, ra, arg, nresults, fallback); + case LBF_MATH_ROUND: + return translateBuiltinMathUnary(build, IrCmd::ROUND_NUM, nparams, ra, arg, nresults, fallback); + case LBF_MATH_POW: + return translateBuiltinMathBinary(build, IrCmd::POW_NUM, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_EXP: case LBF_MATH_ASIN: case LBF_MATH_SIN: @@ -271,11 +341,9 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_TAN: case LBF_MATH_TANH: case LBF_MATH_LOG10: - case LBF_MATH_ROUND: case LBF_MATH_SIGN: return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); case LBF_MATH_FMOD: - case LBF_MATH_POW: case LBF_MATH_ATAN2: case 
LBF_MATH_LDEXP: return translateBuiltin2NumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); @@ -286,6 +354,8 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, return translateBuiltinType(build, nparams, ra, arg, args, nresults, fallback); case LBF_TYPEOF: return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults, fallback); + case LBF_VECTOR: + return translateBuiltinVector(build, nparams, ra, arg, args, nresults, fallback); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 28c6aca1..a985318b 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -296,46 +296,60 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb)); IrOp vc; + IrOp result; + if (opc.kind == IrOpKind::VmConst) { LUAU_ASSERT(build.function.proto); - TValue protok = build.function.proto->k[opc.index]; + TValue protok = build.function.proto->k[vmConstOp(opc)]; LUAU_ASSERT(protok.tt == LUA_TNUMBER); - vc = build.constDouble(protok.value.n); + + // VM has special cases for exponentiation with constants + if (tm == TM_POW && protok.value.n == 0.5) + result = build.inst(IrCmd::SQRT_NUM, vb); + else if (tm == TM_POW && protok.value.n == 2.0) + result = build.inst(IrCmd::MUL_NUM, vb, vb); + else if (tm == TM_POW && protok.value.n == 3.0) + result = build.inst(IrCmd::MUL_NUM, vb, build.inst(IrCmd::MUL_NUM, vb, vb)); + else + vc = build.constDouble(protok.value.n); } else { vc = build.inst(IrCmd::LOAD_DOUBLE, opc); } - IrOp va; - - switch (tm) + if (result.kind == IrOpKind::None) { - case TM_ADD: - va = build.inst(IrCmd::ADD_NUM, vb, vc); - break; - case TM_SUB: - va = build.inst(IrCmd::SUB_NUM, vb, vc); - break; - case TM_MUL: - va = build.inst(IrCmd::MUL_NUM, vb, vc); - break; - case TM_DIV: - va = build.inst(IrCmd::DIV_NUM, vb, vc); - break; - case TM_MOD: - va = build.inst(IrCmd::MOD_NUM, vb, vc); - break; - case TM_POW: - va = build.inst(IrCmd::POW_NUM, vb, vc); - break; - default: - LUAU_ASSERT(!"unsupported binary op"); + LUAU_ASSERT(vc.kind != IrOpKind::None); + + switch (tm) + { + case TM_ADD: + result = build.inst(IrCmd::ADD_NUM, vb, vc); + break; + case TM_SUB: + result = build.inst(IrCmd::SUB_NUM, vb, vc); + break; + case TM_MUL: + result = build.inst(IrCmd::MUL_NUM, vb, vc); + break; + case TM_DIV: + result = build.inst(IrCmd::DIV_NUM, vb, vc); + break; + case TM_MOD: + result = build.inst(IrCmd::MOD_NUM, vb, vc); + break; + case TM_POW: + result = build.inst(IrCmd::POW_NUM, vb, vc); + break; + default: + LUAU_ASSERT(!"unsupported binary op"); + } } - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), va); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); if (ra != rb && ra != rc) // TODO: optimization should handle second check, but we'll test this later build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -501,10 +515,10 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool if (br.type == BuiltinImplType::UsesFallback) { + LUAU_ASSERT(nparams != LUA_MULTRET && "builtins are not allowed to handle variadic arguments"); + if (nresults == LUA_MULTRET) build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(ra), build.constInt(br.actualResultCount)); - else if (nparams == LUA_MULTRET) - build.inst(IrCmd::ADJUST_STACK_TO_TOP); } else { @@ -638,7 +652,7 @@ void 
translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo build.inst(IrCmd::JUMP, target); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); + build.inst(IrCmd::FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); } void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcpos) @@ -670,7 +684,7 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp build.inst(IrCmd::JUMP, target); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); + build.inst(IrCmd::FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); } void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pcpos) @@ -721,7 +735,8 @@ void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pc build.inst(IrCmd::JUMP, loopRepeat); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGLOOP_FALLBACK, build.constUint(pcpos), build.vmReg(ra), build.constInt(int(pc[1])), loopRepeat, loopExit); + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::FORGLOOP_FALLBACK, build.vmReg(ra), build.constInt(int(pc[1])), loopRepeat, loopExit); // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(loopExit)) @@ -1093,5 +1108,71 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(next); } +void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + IrOp fallthrough = build.block(IrBlockKind::Internal); + IrOp next = build.blockAtInst(pcpos + 1); + + IrOp target = (ra == rb) ? next : build.block(IrBlockKind::Internal); + + build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(rb), target, fallthrough); + build.beginBlock(fallthrough); + + IrOp load = build.inst(IrCmd::LOAD_TVALUE, c); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); + build.inst(IrCmd::JUMP, next); + + if (ra == rb) + { + build.beginBlock(next); + } + else + { + build.beginBlock(target); + + IrOp load1 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load1); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); + } +} + +void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + IrOp fallthrough = build.block(IrBlockKind::Internal); + IrOp next = build.blockAtInst(pcpos + 1); + + IrOp target = (ra == rb) ? 
next : build.block(IrBlockKind::Internal); + + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(rb), target, fallthrough); + build.beginBlock(fallthrough); + + IrOp load = build.inst(IrCmd::LOAD_TVALUE, c); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); + build.inst(IrCmd::JUMP, next); + + if (ra == rb) + { + build.beginBlock(next); + } + else + { + build.beginBlock(target); + + IrOp load1 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load1); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); + } +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 0be111dc..87a530b5 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -61,6 +61,8 @@ void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); +void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index e29a5b02..c5e7c887 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -14,6 +14,134 @@ namespace Luau namespace CodeGen { +IrValueKind getCmdValueKind(IrCmd cmd) +{ + switch (cmd) + { + case IrCmd::NOP: + return IrValueKind::None; + case IrCmd::LOAD_TAG: + return IrValueKind::Tag; + case IrCmd::LOAD_POINTER: + return IrValueKind::Pointer; + case IrCmd::LOAD_DOUBLE: + return IrValueKind::Double; + case IrCmd::LOAD_INT: + return IrValueKind::Int; + case IrCmd::LOAD_TVALUE: + case IrCmd::LOAD_NODE_VALUE_TV: + return IrValueKind::Tvalue; + case IrCmd::LOAD_ENV: + case IrCmd::GET_ARR_ADDR: + case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: + return IrValueKind::Pointer; + case IrCmd::STORE_TAG: + case IrCmd::STORE_POINTER: + case IrCmd::STORE_DOUBLE: + case IrCmd::STORE_INT: + case IrCmd::STORE_VECTOR: + case IrCmd::STORE_TVALUE: + case IrCmd::STORE_NODE_VALUE_TV: + return IrValueKind::None; + case IrCmd::ADD_INT: + case IrCmd::SUB_INT: + return IrValueKind::Int; + case IrCmd::ADD_NUM: + case IrCmd::SUB_NUM: + case IrCmd::MUL_NUM: + case IrCmd::DIV_NUM: + case IrCmd::MOD_NUM: + case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: + case IrCmd::UNM_NUM: + case IrCmd::FLOOR_NUM: + case IrCmd::CEIL_NUM: + case IrCmd::ROUND_NUM: + case IrCmd::SQRT_NUM: + case IrCmd::ABS_NUM: + return IrValueKind::Double; + case IrCmd::NOT_ANY: + return IrValueKind::Int; + case IrCmd::JUMP: + case IrCmd::JUMP_IF_TRUTHY: + case IrCmd::JUMP_IF_FALSY: + case IrCmd::JUMP_EQ_TAG: + case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_EQ_POINTER: + case IrCmd::JUMP_CMP_NUM: + case IrCmd::JUMP_CMP_ANY: + case IrCmd::JUMP_SLOT_MATCH: + return IrValueKind::None; + case IrCmd::TABLE_LEN: + return IrValueKind::Double; + case IrCmd::NEW_TABLE: + case IrCmd::DUP_TABLE: + return IrValueKind::Pointer; + case IrCmd::TRY_NUM_TO_INDEX: + return IrValueKind::Int; + case IrCmd::TRY_CALL_FASTGETTM: + return IrValueKind::Pointer; + case IrCmd::INT_TO_NUM: + return IrValueKind::Double; + case IrCmd::ADJUST_STACK_TO_REG: + case IrCmd::ADJUST_STACK_TO_TOP: + return IrValueKind::None; + case IrCmd::FASTCALL: + return 
IrValueKind::None; + case IrCmd::INVOKE_FASTCALL: + return IrValueKind::Int; + case IrCmd::CHECK_FASTCALL_RES: + case IrCmd::DO_ARITH: + case IrCmd::DO_LEN: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: + case IrCmd::GET_IMPORT: + case IrCmd::CONCAT: + case IrCmd::GET_UPVALUE: + case IrCmd::SET_UPVALUE: + case IrCmd::PREPARE_FORN: + case IrCmd::CHECK_TAG: + case IrCmd::CHECK_READONLY: + case IrCmd::CHECK_NO_METATABLE: + case IrCmd::CHECK_SAFE_ENV: + case IrCmd::CHECK_ARRAY_SIZE: + case IrCmd::CHECK_SLOT_MATCH: + case IrCmd::CHECK_NODE_NO_NEXT: + case IrCmd::INTERRUPT: + case IrCmd::CHECK_GC: + case IrCmd::BARRIER_OBJ: + case IrCmd::BARRIER_TABLE_BACK: + case IrCmd::BARRIER_TABLE_FORWARD: + case IrCmd::SET_SAVEDPC: + case IrCmd::CLOSE_UPVALS: + case IrCmd::CAPTURE: + case IrCmd::SETLIST: + case IrCmd::CALL: + case IrCmd::RETURN: + case IrCmd::FORGLOOP: + case IrCmd::FORGLOOP_FALLBACK: + case IrCmd::FORGPREP_XNEXT_FALLBACK: + case IrCmd::COVERAGE: + case IrCmd::FALLBACK_GETGLOBAL: + case IrCmd::FALLBACK_SETGLOBAL: + case IrCmd::FALLBACK_GETTABLEKS: + case IrCmd::FALLBACK_SETTABLEKS: + case IrCmd::FALLBACK_NAMECALL: + case IrCmd::FALLBACK_PREPVARARGS: + case IrCmd::FALLBACK_GETVARARGS: + case IrCmd::FALLBACK_NEWCLOSURE: + case IrCmd::FALLBACK_DUPCLOSURE: + case IrCmd::FALLBACK_FORGPREP: + return IrValueKind::None; + case IrCmd::SUBSTITUTE: + return IrValueKind::Unknown; + } + + LUAU_UNREACHABLE(); +} + static void removeInstUse(IrFunction& function, uint32_t instIdx) { IrInst& inst = function.instructions[instIdx]; @@ -320,6 +448,26 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 if (inst.a.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(-function.doubleOp(inst.a))); break; + case IrCmd::FLOOR_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(floor(function.doubleOp(inst.a)))); + break; + case IrCmd::CEIL_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(ceil(function.doubleOp(inst.a)))); + break; + case IrCmd::ROUND_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(round(function.doubleOp(inst.a)))); + break; + case IrCmd::SQRT_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(sqrt(function.doubleOp(inst.a)))); + break; + case IrCmd::ABS_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(fabs(function.doubleOp(inst.a)))); + break; case IrCmd::NOT_ANY: if (inst.a.kind == IrOpKind::Constant) { @@ -354,7 +502,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 case IrCmd::JUMP_CMP_NUM: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { - if (compare(function.doubleOp(inst.a), function.doubleOp(inst.b), function.conditionOp(inst.c))) + if (compare(function.doubleOp(inst.a), function.doubleOp(inst.b), conditionOp(inst.c))) replace(function, block, index, {IrCmd::JUMP, inst.d}); else replace(function, block, index, {IrCmd::JUMP, inst.e}); diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index f79bcab8..52479692 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -45,6 +45,7 @@ void initFallbackTable(NativeState& data) CODEGEN_SET_FALLBACK(LOP_BREAK, 0); // Fallbacks that are called from partial implementation of an instruction + // TODO: these fallbacks should be replaced with special functions that exclude the 
(redundantly executed) fast path from the fallback CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0); CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0); CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0); @@ -109,6 +110,9 @@ void initHelperFunctions(NativeState& data) data.context.forgPrepXnextFallback = forgPrepXnextFallback; data.context.callProlog = callProlog; data.context.callEpilogC = callEpilogC; + + data.context.callFallback = callFallback; + data.context.returnFallback = returnFallback; } } // namespace CodeGen diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index bebf421b..2d97e63c 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -47,12 +47,6 @@ struct NativeContext uint8_t* gateEntry = nullptr; uint8_t* gateExit = nullptr; - // Opcode fallbacks, implemented in C - NativeFallback fallback[LOP__COUNT] = {}; - - // Fast call methods, implemented in C - luau_FastFunction luauF_table[256] = {}; - // Helper functions, implemented in C int (*luaV_lessthan)(lua_State* L, const TValue* l, const TValue* r) = nullptr; int (*luaV_lessequal)(lua_State* L, const TValue* l, const TValue* r) = nullptr; @@ -107,6 +101,15 @@ struct NativeContext void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr; + + Closure* (*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr; + Closure* (*returnFallback)(lua_State* L, StkId ra, int n) = nullptr; + + // Opcode fallbacks, implemented in C + NativeFallback fallback[LOP__COUNT] = {}; + + // Fast call methods, implemented in C + luau_FastFunction luauF_table[256] = {}; }; struct NativeState diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index b12a9b94..7157a18c 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -42,6 +42,11 @@ struct RegisterLink // Data we know about the current VM state struct ConstPropState { + ConstPropState(const IrFunction& function) + : function(function) + { + } + uint8_t tryGetTag(IrOp op) { if (RegisterInfo* info = tryGetRegisterInfo(op)) @@ -91,30 +96,42 @@ struct ConstPropState void invalidateTag(IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); - invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ false); + invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ false); } void invalidateValue(IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); - invalidate(regs[regOp.index], /* invalidateTag */ false, /* invalidateValue */ true); + invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ false, /* invalidateValue */ true); } void invalidate(IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); - invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ true); + invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ true); } - void invalidateRegistersFrom(uint32_t firstReg) + void invalidateRegistersFrom(int firstReg) { - for (int i = int(firstReg); i <= maxReg; ++i) + for (int i = firstReg; i <= maxReg; ++i) invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true); maxReg = int(firstReg) - 1; } + void invalidateRegisterRange(int firstReg, int count) + { + for (int i = firstReg; i < firstReg + count && i <= maxReg; ++i) + invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true); + } + + void 
invalidateCapturedRegisters()
+    {
+        for (int i = 0; i <= maxReg; ++i)
+        {
+            if (function.cfg.captured.regs.test(i))
+                invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true);
+        }
+    }
+
     void invalidateHeap()
     {
         for (int i = 0; i <= maxReg; ++i)
@@ -127,26 +144,25 @@ struct ConstPropState
         reg.knownNoMetatable = false;
     }

-    void invalidateAll()
+    void invalidateUserCall()
     {
-        // Invalidating registers also invalidates what we know about the heap (stored in RegisterInfo)
-        invalidateRegistersFrom(0u);
+        invalidateHeap();
+        invalidateCapturedRegisters();
         inSafeEnv = false;
     }

     void createRegLink(uint32_t instIdx, IrOp regOp)
     {
-        LUAU_ASSERT(regOp.kind == IrOpKind::VmReg);
         LUAU_ASSERT(!instLink.contains(instIdx));
-        instLink[instIdx] = RegisterLink{uint8_t(regOp.index), regs[regOp.index].version};
+        instLink[instIdx] = RegisterLink{uint8_t(vmRegOp(regOp)), regs[vmRegOp(regOp)].version};
     }

     RegisterInfo* tryGetRegisterInfo(IrOp op)
     {
         if (op.kind == IrOpKind::VmReg)
         {
-            maxReg = int(op.index) > maxReg ? int(op.index) : maxReg;
-            return &regs[op.index];
+            maxReg = vmRegOp(op) > maxReg ? vmRegOp(op) : maxReg;
+            return &regs[vmRegOp(op)];
         }

         if (RegisterLink* link = tryGetRegLink(op))
@@ -175,6 +191,8 @@ struct ConstPropState
         return nullptr;
     }

+    const IrFunction& function;
+
     RegisterInfo regs[256];

     // For range/full invalidations, we only want to visit a limited number of data that we have recorded
@@ -346,6 +364,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
             }
         }
         break;
+    case IrCmd::STORE_VECTOR:
+        state.invalidateValue(inst.a);
+        break;
     case IrCmd::STORE_TVALUE:
         if (inst.a.kind == IrOpKind::VmReg)
         {
@@ -411,7 +432,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
         if (valueA && valueB)
         {
-            if (compare(*valueA, *valueB, function.conditionOp(inst.c)))
+            if (compare(*valueA, *valueB, conditionOp(inst.c)))
                 replace(function, block, index, {IrCmd::JUMP, inst.d});
             else
                 replace(function, block, index, {IrCmd::JUMP, inst.e});
@@ -481,15 +502,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
             }
         }
         break;
-    case IrCmd::LOP_AND:
-    case IrCmd::LOP_ANDK:
-    case IrCmd::LOP_OR:
-    case IrCmd::LOP_ORK:
-        state.invalidate(inst.b);
-        break;
     case IrCmd::FASTCALL:
     case IrCmd::INVOKE_FASTCALL:
-        handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), inst.b.index, function.intOp(inst.f));
+        handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), function.intOp(inst.f));
         break;

     // These instructions don't have an effect on register/memory state we are tracking
@@ -511,6 +526,11 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
     case IrCmd::MIN_NUM:
     case IrCmd::MAX_NUM:
     case IrCmd::UNM_NUM:
+    case IrCmd::FLOOR_NUM:
+    case IrCmd::CEIL_NUM:
+    case IrCmd::ROUND_NUM:
+    case IrCmd::SQRT_NUM:
+    case IrCmd::ABS_NUM:
     case IrCmd::NOT_ANY:
     case IrCmd::JUMP:
     case IrCmd::JUMP_EQ_POINTER:
@@ -525,10 +545,10 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
     case IrCmd::CHECK_SLOT_MATCH:
     case IrCmd::CHECK_NODE_NO_NEXT:
     case IrCmd::BARRIER_TABLE_BACK:
-    case IrCmd::LOP_RETURN:
-    case IrCmd::LOP_COVERAGE:
+    case IrCmd::RETURN:
+    case IrCmd::COVERAGE:
     case IrCmd::SET_UPVALUE:
-    case IrCmd::LOP_SETLIST: // We don't track table state that this can invalidate
+    case IrCmd::SETLIST: // We don't track table state that this can invalidate
     case IrCmd::SET_SAVEDPC: // TODO: we may be able to remove some updates to PC
case IrCmd::CLOSE_UPVALS: // Doesn't change memory that we track case IrCmd::CAPTURE: @@ -538,35 +558,93 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::CHECK_FASTCALL_RES: // Changes stack top, but not the values break; - // We don't model the following instructions, so we just clear all the knowledge we have built up - // Many of these call user functions that can change memory and captured registers - // Some of these might yield with similar effects case IrCmd::JUMP_CMP_ANY: + state.invalidateUserCall(); // TODO: if arguments are strings, there will be no user calls + break; case IrCmd::DO_ARITH: + state.invalidate(inst.a); + state.invalidateUserCall(); + break; case IrCmd::DO_LEN: + state.invalidate(inst.a); + state.invalidateUserCall(); // TODO: if argument is a string, there will be no user call + + state.saveTag(inst.a, LUA_TNUMBER); + break; case IrCmd::GET_TABLE: + state.invalidate(inst.a); + state.invalidateUserCall(); + break; case IrCmd::SET_TABLE: + state.invalidateUserCall(); + break; case IrCmd::GET_IMPORT: + state.invalidate(inst.a); + state.invalidateUserCall(); + break; case IrCmd::CONCAT: + state.invalidateRegisterRange(vmRegOp(inst.a), function.uintOp(inst.b)); + state.invalidateUserCall(); // TODO: if only strings and numbers are concatenated, there will be no user calls + break; case IrCmd::PREPARE_FORN: - case IrCmd::INTERRUPT: // TODO: it will be important to keep tag/value state, but we have to track register capture - case IrCmd::LOP_CALL: - case IrCmd::LOP_FORGLOOP: - case IrCmd::LOP_FORGLOOP_FALLBACK: - case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + state.invalidateValue(inst.a); + state.saveTag(inst.a, LUA_TNUMBER); + state.invalidateValue(inst.b); + state.saveTag(inst.b, LUA_TNUMBER); + state.invalidateValue(inst.c); + state.saveTag(inst.c, LUA_TNUMBER); + break; + case IrCmd::INTERRUPT: + state.invalidateUserCall(); + break; + case IrCmd::CALL: + state.invalidateRegistersFrom(vmRegOp(inst.a)); + state.invalidateUserCall(); + break; + case IrCmd::FORGLOOP: + state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified + break; + case IrCmd::FORGLOOP_FALLBACK: + state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified + state.invalidateUserCall(); + break; + case IrCmd::FORGPREP_XNEXT_FALLBACK: + // This fallback only conditionally throws an exception + break; case IrCmd::FALLBACK_GETGLOBAL: + state.invalidate(inst.b); + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_SETGLOBAL: + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_GETTABLEKS: + state.invalidate(inst.b); + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_SETTABLEKS: + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_NAMECALL: + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 0u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 1u}); + state.invalidateUserCall(); + break; case IrCmd::FALLBACK_PREPVARARGS: + break; case IrCmd::FALLBACK_GETVARARGS: + state.invalidateRegistersFrom(vmRegOp(inst.b)); + break; case IrCmd::FALLBACK_NEWCLOSURE: + state.invalidate(inst.b); + break; case IrCmd::FALLBACK_DUPCLOSURE: + state.invalidate(inst.b); + break; case IrCmd::FALLBACK_FORGPREP: - // TODO: this is very conservative, some of there instructions can be tracked better - // TODO: non-captured register tags and values should not be cleared here - state.invalidateAll(); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 0u}); + state.invalidate(IrOp{inst.b.kind, 
vmRegOp(inst.b) + 1u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 2u}); break; } } @@ -592,7 +670,7 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite { IrFunction& function = build.function; - ConstPropState state; + ConstPropState state{function}; while (block) { @@ -698,7 +776,7 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited return; // Initialize state with the knowledge of our current block - ConstPropState state; + ConstPropState state{function}; constPropInBlock(build, startingBlock, state); // Veryfy that target hasn't changed diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 03f4b3e6..9478404a 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -25,8 +25,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinArity, false) - namespace Luau { @@ -295,7 +293,7 @@ struct Compiler // handles builtin calls that can't be constant-folded but are known to return one value // note: optimizationLevel check is technically redundant but it's important that we never optimize based on builtins in O1 - if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2) + if (options.optimizationLevel >= 2) if (int* bfid = builtins.find(expr)) return getBuiltinInfo(*bfid).results != 1; @@ -766,7 +764,7 @@ struct Compiler { if (!isExprMultRet(expr->args.data[expr->args.size - 1])) return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); - else if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2 && int(expr->args.size) == getBuiltinInfo(bfid).params) + else if (options.optimizationLevel >= 2 && int(expr->args.size) == getBuiltinInfo(bfid).params) return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); } diff --git a/Makefile b/Makefile index 58512293..bbc66c2e 100644 --- a/Makefile +++ b/Makefile @@ -117,6 +117,11 @@ ifneq ($(native),) TESTS_ARGS+=--codegen endif +ifneq ($(nativelj),) + CXXFLAGS+=-DLUA_CUSTOM_EXECUTION=1 -DLUA_USE_LONGJMP=1 + TESTS_ARGS+=--codegen +endif + # target-specific flags $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include diff --git a/Sources.cmake b/Sources.cmake index 6e0a32ed..3508ec39 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -65,8 +65,10 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/ConditionX64.h CodeGen/include/Luau/IrAnalysis.h CodeGen/include/Luau/IrBuilder.h + CodeGen/include/Luau/IrCallWrapperX64.h CodeGen/include/Luau/IrDump.h CodeGen/include/Luau/IrData.h + CodeGen/include/Luau/IrRegAllocX64.h CodeGen/include/Luau/IrUtils.h CodeGen/include/Luau/Label.h CodeGen/include/Luau/OperandX64.h @@ -84,15 +86,21 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/CodeBlockUnwind.cpp CodeGen/src/CodeGen.cpp CodeGen/src/CodeGenUtils.cpp + CodeGen/src/CodeGenA64.cpp CodeGen/src/CodeGenX64.cpp CodeGen/src/EmitBuiltinsX64.cpp + CodeGen/src/EmitCommonA64.cpp CodeGen/src/EmitCommonX64.cpp + CodeGen/src/EmitInstructionA64.cpp CodeGen/src/EmitInstructionX64.cpp CodeGen/src/Fallbacks.cpp CodeGen/src/IrAnalysis.cpp CodeGen/src/IrBuilder.cpp + CodeGen/src/IrCallWrapperX64.cpp CodeGen/src/IrDump.cpp + CodeGen/src/IrLoweringA64.cpp CodeGen/src/IrLoweringX64.cpp + CodeGen/src/IrRegAllocA64.cpp 
CodeGen/src/IrRegAllocX64.cpp CodeGen/src/IrTranslateBuiltins.cpp CodeGen/src/IrTranslation.cpp @@ -106,15 +114,19 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/ByteUtils.h CodeGen/src/CustomExecUtils.h CodeGen/src/CodeGenUtils.h + CodeGen/src/CodeGenA64.h CodeGen/src/CodeGenX64.h CodeGen/src/EmitBuiltinsX64.h CodeGen/src/EmitCommon.h + CodeGen/src/EmitCommonA64.h CodeGen/src/EmitCommonX64.h + CodeGen/src/EmitInstructionA64.h CodeGen/src/EmitInstructionX64.h CodeGen/src/Fallbacks.h CodeGen/src/FallbacksProlog.h + CodeGen/src/IrLoweringA64.h CodeGen/src/IrLoweringX64.h - CodeGen/src/IrRegAllocX64.h + CodeGen/src/IrRegAllocA64.h CodeGen/src/IrTranslateBuiltins.h CodeGen/src/IrTranslation.h CodeGen/src/NativeState.h @@ -334,6 +346,7 @@ if(TARGET Luau.UnitTest) tests/Fixture.h tests/IostreamOptional.h tests/ScopedFlags.h + tests/AssemblyBuilderA64.test.cpp tests/AssemblyBuilderX64.test.cpp tests/AstJsonEncoder.test.cpp tests/AstQuery.test.cpp @@ -350,6 +363,7 @@ if(TARGET Luau.UnitTest) tests/Error.test.cpp tests/Frontend.test.cpp tests/IrBuilder.test.cpp + tests/IrCallWrapperX64.test.cpp tests/JsonEmitter.test.cpp tests/Lexer.test.cpp tests/Linter.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index 32d31bdb..94e8cfd9 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -29,7 +29,7 @@ enum lua_Status LUA_OK = 0, LUA_YIELD, LUA_ERRRUN, - LUA_ERRSYNTAX, + LUA_ERRSYNTAX, // legacy error code, preserved for compatibility LUA_ERRMEM, LUA_ERRERR, LUA_BREAK, // yielded for a debug breakpoint diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 3c669bff..e0dc8a38 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -23,8 +23,6 @@ #endif #endif -LUAU_FASTFLAGVARIABLE(LuauBuiltinSSE41, false) - // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. 
// If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -105,9 +103,7 @@ static int luauF_atan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } -// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN -LUAU_NOINLINE static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -170,9 +166,7 @@ static int luauF_exp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } -// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN -LUAU_NOINLINE static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -949,9 +943,7 @@ static int luauF_sign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } -// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN -LUAU_NOINLINE static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -1271,9 +1263,6 @@ LUAU_TARGET_SSE41 inline double roundsd_sse41(double v) LUAU_TARGET_SSE41 static int luauF_floor_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (!FFlag::LuauBuiltinSSE41) - return luauF_floor(L, res, arg0, nresults, args, nparams); - if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) { double a1 = nvalue(arg0); @@ -1286,9 +1275,6 @@ LUAU_TARGET_SSE41 static int luauF_floor_sse41(lua_State* L, StkId res, TValue* LUAU_TARGET_SSE41 static int luauF_ceil_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (!FFlag::LuauBuiltinSSE41) - return luauF_ceil(L, res, arg0, nresults, args, nparams); - if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) { double a1 = nvalue(arg0); @@ -1301,9 +1287,6 @@ LUAU_TARGET_SSE41 static int luauF_ceil_sse41(lua_State* L, StkId res, TValue* a LUAU_TARGET_SSE41 static int luauF_round_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (!FFlag::LuauBuiltinSSE41) - return luauF_round(L, res, arg0, nresults, args, nparams); - if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) { double a1 = nvalue(arg0); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index ff8105b8..264388bc 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,6 +17,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBetterOOMHandling, false) + /* ** {====================================================== ** Error-recovery functions @@ -79,22 +81,17 @@ public: const char* what() const throw() override { - // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. - if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) - { - // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. 
+ // LUA_ERRRUN passes error object on the stack + if (status == LUA_ERRRUN || (status == LUA_ERRSYNTAX && !FFlag::LuauBetterOOMHandling)) if (const char* str = lua_tostring(L, -1)) - { return str; - } - } switch (status) { case LUA_ERRRUN: - return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; + return "lua_exception: runtime error"; case LUA_ERRSYNTAX: - return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; + return "lua_exception: syntax error"; case LUA_ERRMEM: return "lua_exception: " LUA_MEMERRMSG; case LUA_ERRERR: @@ -550,19 +547,42 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e int status = luaD_rawrunprotected(L, func, u); if (status != 0) { + int errstatus = status; + // call user-defined error function (used in xpcall) if (ef) { - // if errfunc fails, we fail with "error in error handling" - if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) - status = LUA_ERRERR; + if (FFlag::LuauBetterOOMHandling) + { + // push error object to stack top if it's not already there + if (status != LUA_ERRRUN) + seterrorobj(L, status, L->top); + + // if errfunc fails, we fail with "error in error handling" or "not enough memory" + int err = luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)); + + // in general we preserve the status, except for cases when the error handler fails + // out of memory is treated specially because it's common for it to be cascading, in which case we preserve the code + if (err == 0) + errstatus = LUA_ERRRUN; + else if (status == LUA_ERRMEM && err == LUA_ERRMEM) + errstatus = LUA_ERRMEM; + else + errstatus = status = LUA_ERRERR; + } + else + { + // if errfunc fails, we fail with "error in error handling" + if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) + status = LUA_ERRERR; + } } // since the call failed with an error, we might have to reset the 'active' thread state if (!oldactive) L->isactive = false; - // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + // restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. L->nCcalls = oldnCcalls; // an error occurred, check if we have a protected error callback @@ -577,7 +597,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); // close eventual pending closures - seterrorobj(L, status, oldtop); + seterrorobj(L, FFlag::LuauBetterOOMHandling ? 
errstatus : status, oldtop); L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 1d324896..32a240bf 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -208,14 +208,14 @@ typedef struct global_State uint64_t rngstate; // PCG random number generator state uint64_t ptrenckey[4]; // pointer encoding key for display - void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory - lua_Callbacks cb; #if LUA_CUSTOM_EXECUTION lua_ExecutionCallbacks ecb; #endif + void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory + GCStats gcstats; #ifdef LUAI_GCMETRICS diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index ddee3a71..4443be34 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,7 +10,7 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauOptimizedSort, false) +LUAU_FASTFLAGVARIABLE(LuauIntrosort, false) static int foreachi(lua_State* L) { @@ -298,120 +298,6 @@ static int tunpack(lua_State* L) return (int)n; } -/* -** {====================================================== -** Quicksort -** (based on `Algorithms in MODULA-3', Robert Sedgewick; -** Addison-Wesley, 1993.) -*/ - -static void set2(lua_State* L, int i, int j) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - lua_rawseti(L, 1, i); - lua_rawseti(L, 1, j); -} - -static int sort_comp(lua_State* L, int a, int b) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - if (!lua_isnil(L, 2)) - { // function? - int res; - lua_pushvalue(L, 2); - lua_pushvalue(L, a - 1); // -1 to compensate function - lua_pushvalue(L, b - 2); // -2 to compensate function and `a' - lua_call(L, 2, 1); - res = lua_toboolean(L, -1); - lua_pop(L, 1); - return res; - } - else // a < b? - return lua_lessthan(L, a, b); -} - -static void auxsort(lua_State* L, int l, int u) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - while (l < u) - { // for tail recursion - int i, j; - // sort elements a[l], a[(l+u)/2] and a[u] - lua_rawgeti(L, 1, l); - lua_rawgeti(L, 1, u); - if (sort_comp(L, -1, -2)) // a[u] < a[l]? 
- set2(L, l, u); // swap a[l] - a[u] - else - lua_pop(L, 2); - if (u - l == 1) - break; // only 2 elements - i = (l + u) / 2; - lua_rawgeti(L, 1, i); - lua_rawgeti(L, 1, l); - if (sort_comp(L, -2, -1)) // a[i]= P - while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) - { - if (i >= u) - luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); // remove a[i] - } - // repeat --j until a[j] <= P - while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) - { - if (j <= l) - luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); // remove a[j] - } - if (j < i) - { - lua_pop(L, 3); // pop pivot, a[i], a[j] - break; - } - set2(L, i, j); - } - lua_rawgeti(L, 1, u - 1); - lua_rawgeti(L, 1, i); - set2(L, u - 1, i); // swap pivot (a[u-1]) with a[i] - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) - { - j = l; - i = i - 1; - l = i + 2; - } - else - { - j = i + 1; - i = u; - u = j - 2; - } - auxsort(L, j, i); // call recursively the smaller one - } // repeat the routine for the larger one -} - typedef int (*SortPredicate)(lua_State* L, const TValue* l, const TValue* r); static int sort_func(lua_State* L, const TValue* l, const TValue* r) @@ -456,30 +342,77 @@ inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) return res; } -static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) +static void sort_siftheap(lua_State* L, Table* t, int l, int u, SortPredicate pred, int root) +{ + LUAU_ASSERT(l <= u); + int count = u - l + 1; + + // process all elements with two children + while (root * 2 + 2 < count) + { + int left = root * 2 + 1, right = root * 2 + 2; + int next = root; + next = sort_less(L, t, l + next, l + left, pred) ? left : next; + next = sort_less(L, t, l + next, l + right, pred) ? right : next; + + if (next == root) + break; + + sort_swap(L, t, l + root, l + next); + root = next; + } + + // process last element if it has just one child + int lastleft = root * 2 + 1; + if (lastleft == count - 1 && sort_less(L, t, l + root, l + lastleft, pred)) + sort_swap(L, t, l + root, l + lastleft); +} + +static void sort_heap(lua_State* L, Table* t, int l, int u, SortPredicate pred) +{ + LUAU_ASSERT(l <= u); + int count = u - l + 1; + + for (int i = count / 2 - 1; i >= 0; --i) + sort_siftheap(L, t, l, u, pred, i); + + for (int i = count - 1; i > 0; --i) + { + sort_swap(L, t, l, l + i); + sort_siftheap(L, t, l, l + i - 1, pred, 0); + } +} + +static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredicate pred) { // sort range [l..u] (inclusive, 0-based) while (l < u) { - int i, j; + // if the limit has been reached, quick sort is going over the permitted nlogn complexity, so we fall back to heap sort + if (FFlag::LuauIntrosort && limit == 0) + return sort_heap(L, t, l, u, pred); + // sort elements a[l], a[(l+u)/2] and a[u] + // note: this simultaneously acts as a small sort and a median selector if (sort_less(L, t, u, l, pred)) // a[u] < a[l]? 
sort_swap(L, t, u, l); // swap a[l] - a[u] if (u - l == 1) break; // only 2 elements - i = l + ((u - l) >> 1); // midpoint - if (sort_less(L, t, i, l, pred)) // a[i]> 1); // midpoint + if (sort_less(L, t, m, l, pred)) // a[m]= P @@ -498,63 +431,72 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) break; sort_swap(L, t, i, j); } - // swap pivot (a[u-1]) with a[i], which is the new midpoint - sort_swap(L, t, u - 1, i); - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) + + // swap pivot a[p] with a[i], which is the new midpoint + sort_swap(L, t, p, i); + + if (FFlag::LuauIntrosort) { - j = l; - i = i - 1; - l = i + 2; + // adjust limit to allow 1.5 log2N recursive steps + limit = (limit >> 1) + (limit >> 2); + + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // sort smaller half recursively; the larger half is sorted in the next loop iteration + if (i - l < u - i) + { + sort_rec(L, t, l, i - 1, limit, pred); + l = i + 1; + } + else + { + sort_rec(L, t, i + 1, u, limit, pred); + u = i - 1; + } } else { - j = i + 1; - i = u; - u = j - 2; + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // adjust so that smaller half is in [j..i] and larger one in [l..u] + if (i - l < u - i) + { + j = l; + i = i - 1; + l = i + 2; + } + else + { + j = i + 1; + i = u; + u = j - 2; + } + + // sort smaller half recursively; the larger half is sorted in the next loop iteration + sort_rec(L, t, j, i, limit, pred); } - sort_rec(L, t, j, i, pred); // call recursively the smaller one - } // repeat the routine for the larger one + } } static int tsort(lua_State* L) { - if (FFlag::LuauOptimizedSort) - { - luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); - int n = luaH_getn(t); - if (t->readonly) - luaG_readonlyerror(L); + luaL_checktype(L, 1, LUA_TTABLE); + Table* t = hvalue(L->base); + int n = luaH_getn(t); + if (t->readonly) + luaG_readonlyerror(L); - SortPredicate pred = luaV_lessthan; - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? - { - luaL_checktype(L, 2, LUA_TFUNCTION); - pred = sort_func; - } - lua_settop(L, 2); // make sure there are two arguments - - if (n > 0) - sort_rec(L, t, 0, n - 1, pred); - return 0; - } - else + SortPredicate pred = luaV_lessthan; + if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? { - luaL_checktype(L, 1, LUA_TTABLE); - int n = lua_objlen(L, 1); - luaL_checkstack(L, 40, ""); // assume array is smaller than 2^40 - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? 
- luaL_checktype(L, 2, LUA_TFUNCTION); - lua_settop(L, 2); // make sure there is two arguments - auxsort(L, 1, n); - return 0; + luaL_checktype(L, 2, LUA_TFUNCTION); + pred = sort_func; } + lua_settop(L, 2); // make sure there are two arguments + + if (n > 0) + sort_rec(L, t, 0, n - 1, n, pred); + return 0; } -// }====================================================== - static int tcreate(lua_State* L) { int size = luaL_checkinteger(L, 1); diff --git a/docs/_pages/compatibility.md b/docs/_pages/compatibility.md index 1c15e9b8..b57a9af6 100644 --- a/docs/_pages/compatibility.md +++ b/docs/_pages/compatibility.md @@ -87,7 +87,7 @@ Ephemeron tables may be implemented at some point since they do have valid uses | bitwise operators | ❌ | `bit32` library covers this in absence of 64-bit integers | | basic utf-8 support | ✔️ | we include `utf8` library and other UTF8 features | | functions for packing and unpacking values (string.pack/unpack/packsize) | ✔️ | | -| floor division | ❌ | no strong use cases, syntax overlaps with C comments | +| floor division | 🔜 | | | `ipairs` and the `table` library respect metamethods | ❌ | no strong use cases, performance implications | | new function `table.move` | ✔️ | | | `collectgarbage("count")` now returns only one result | ✔️ | | @@ -98,8 +98,6 @@ It's important to highlight integer support and bitwise operators. For Luau, it' If integers are taken out of the equation, bitwise operators make less sense, as integers aren't a first class feature; additionally, `bit32` library is more fully featured (includes commonly used operations such as rotates and arithmetic shift; bit extraction/replacement is also more readable). Adding operators along with metamethods for all of them increases complexity, which means this feature isn't worth it on the balance. Common arguments for this include a more familiar syntax, which, while true, gets more nuanced as `^` isn't available as a xor operator, and arithmetic right shift isn't expressible without yet another operator, and performance, which in Luau is substantially better than in Lua because `bit32` library uses VM builtins instead of expensive function calls. -Floor division is much less complex, but it's used rarely enough that `math.floor(a/b)` seems like an adequate replacement; additionally, `//` is a comment in C-derived languages and we may decide to adopt it in addition to `--` at some point. - ## Lua 5.4 | feature | status | notes | diff --git a/docs/_posts/2023-03-31-luau-recap-march-2023.md b/docs/_posts/2023-03-31-luau-recap-march-2023.md new file mode 100644 index 00000000..a951b8e6 --- /dev/null +++ b/docs/_posts/2023-03-31-luau-recap-march-2023.md @@ -0,0 +1,143 @@ +--- +layout: single +title: "Luau Recap: March 2023" +--- + +How the time flies! The team has been busy since the last November Luau Recap working on some large updates that are coming in the future, but before those arrive, we have some improvements that you can already use! + +[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-march-2023/).] + +## Improved type refinements + +Type refinements handle constraints placed on variables inside conditional blocks. + +In the following example, while variable `a` is declared to have type `number?`, inside the `if` block we know that it cannot be `nil`: + +```lua +local function f(a: number?) + if a ~= nil then + a *= 2 -- no type errors + end + ... +end +``` + +One limitation we had previously is that after a conditional block, refinements were discarded. 
+
+But there are cases where `if` is used to exit the function early, making the following code essentially act as a hidden `else` block.
+
+We now correctly preserve such refinements and you should be able to remove `assert` function calls that were only used to get rid of false positive errors about types being `nil`.
+
+```lua
+local function f(x: string?)
+    if not x then return end
+
+    -- x is a 'string' here
+end
+```
+
+Throwing calls like `error()` or `assert(false)` instead of a `return` statement are also recognized.
+
+```lua
+local function f(x: string?)
+    if not x then error('first argument is nil') end
+
+    -- x is 'string' here
+end
+```
+
+Existing complex refinements like `type`/`typeof` and tagged union checks, among others, are expected to keep working as before.
+
+## Marking table.getn/foreach/foreachi as deprecated
+
+`table.getn`, `table.foreach` and `table.foreachi` were deprecated in Lua 5.1, which Luau is based on, and removed in Lua 5.2.
+
+`table.getn(x)` is equivalent to `rawlen(x)` when 'x' is a table; when 'x' is not a table, `table.getn` produces an error.
+
+It's difficult to imagine code where `table.getn(x)` is better than either `#x` (idiomatic) or `rawlen(x)` (fully compatible replacement).
+
+`table.getn` is also slower than both alternatives and was marked as deprecated.
+
+`table.foreach` is equivalent to a `for .. pairs` loop; `table.foreachi` is equivalent to a `for .. ipairs` loop; both may also be replaced by generalized iteration.
+
+Both functions are significantly slower than the equivalent `for` loop replacements and are more restrictive because the callback can't yield.
+
+Because both functions bring no value over other library or language alternatives, they were marked deprecated as well.
+
+You may have noticed linter warnings about places where these functions are used. For compatibility, these functions are not going to be removed.
+
+## Autocomplete improvements
+
+When a table key type is defined to be a union of string singletons, those keys can now be autocompleted in the locations marked with '^':
+
+```lua
+type Direction = "north" | "south" | "east" | "west"
+
+local a: {[Direction]: boolean} = {[^] = true}
+local b: {[Direction]: boolean} = {["^"]}
+local c: {[Direction]: boolean} = {^}
+```
+
+We also fixed incorrect and incomplete suggestions inside the header of `if`, `for` and `while` statements.
+
+## Runtime improvements
+
+On the runtime side, we added multiple optimizations.
+
+`table.sort` is now ~4.1x faster (when not using a predicate) and ~2.1x faster when using a simple predicate.
+
+We also have ideas on how to improve the sorting performance in the future.
+
+`math.floor`, `math.ceil` and `math.round` now use specialized processor instructions. We have measured a ~7-9% speedup in math benchmarks that heavily use those functions.
+
+A small optimization was made to builtin library function calls, giving a 1-2% improvement in code that contains a lot of fastcalls.
+
+Finally, a fix was made to table array part resizing that brings a large improvement to the performance of large tables filled as an array, but at an offset (for example, starting at 10000 instead of 1).
+
+Aside from performance, a correctness issue was fixed in multi-assignment expressions.
+
+```lua
+arr[1], n = n, n - 1
+```
+
+In this example, `n - 1` was assigned to `n` before `n` was assigned to `arr[1]`. This issue has now been fixed.
+
+## Analysis improvements
+
+Multiple changes were made to improve error messages and type presentation.
+
+* Table type strings are now shown with newlines, to make them easier to read
+* Fixed unions of `nil` types displaying as a single `?` character
+* A "Type pack A cannot be converted to B" error is now reported instead of a cryptic "Failed to unify type packs"
+* Improved error message for value count mismatch in assignments like `local a, b = 2`
+
+You may have seen error messages like `Type 'string' cannot be converted to 'string?'` even though usually it is valid to assign `local s: string? = 'hello'` because `string` is a sub-type of `string?`.
+
+This is true in what is called Covariant use contexts, but doesn't hold in Invariant use contexts, like in the example below:
+
+```lua
+local a: { x: Model }
+local b: { x: Instance } = a -- Type 'Model' could not be converted into 'Instance' in an invariant context
+```
+
+In this example, `Model` is a sub-type of `Instance` and can be used where `Instance` is required.
+
+The same is not true for a table field because when using table `b`, `b.x` can be assigned an `Instance` that is not a `Model`. When `b` is an alias to `a`, this assignment is not compatible with `a`'s type annotation.
+
+---
+
+Some other light changes to type inference include:
+
+* `string.match` and `string.gmatch` are now defined to return optional values as a match is not guaranteed at runtime
+* Added an error when unrelated types are compared with `==`/`~=`
+* Fixed issues where a variable could not be used as a table after a `typeof(x) == 'table'` check
+
+## Thanks
+
+A very special thanks to all of our open source contributors:
+
+* [niansa/tuxifan](https://github.com/niansa)
+* [B. Gibbons](https://github.com/bmg817)
+* [Epix](https://github.com/EpixScripts)
+* [Harold Cindy](https://github.com/HaroldCindy)
+* [Qualadore](https://github.com/Qualadore)
diff --git a/fuzz/format.cpp b/fuzz/format.cpp
index 3ad3912f..4b943bf1 100644
--- a/fuzz/format.cpp
+++ b/fuzz/format.cpp
@@ -1,6 +1,7 @@
 // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
 #include "Luau/Common.h"
+#include
 #include
 #include
diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp
index 66ca5bb1..854c6327 100644
--- a/fuzz/linter.cpp
+++ b/fuzz/linter.cpp
@@ -3,10 +3,10 @@
 #include "Luau/BuiltinDefinitions.h"
 #include "Luau/Common.h"
+#include "Luau/Frontend.h"
 #include "Luau/Linter.h"
 #include "Luau/ModuleResolver.h"
 #include "Luau/Parser.h"
-#include "Luau/TypeInfer.h"

 extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size)
 {
@@ -18,18 +18,17 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size)
     Luau::ParseResult parseResult = Luau::Parser::parse(reinterpret_cast(Data), Size, names, allocator, options);

     // "static" here is to accelerate fuzzing process by only creating and populating the type environment once
-    static Luau::NullModuleResolver moduleResolver;
-    static Luau::InternalErrorReporter iceHandler;
-    static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler);
-    static int once = (Luau::registerBuiltinGlobals(sharedEnv), 1);
+    static Luau::NullFileResolver fileResolver;
+    static Luau::NullConfigResolver configResolver;
+    static Luau::Frontend frontend{&fileResolver, &configResolver};
+    static int once = (Luau::registerBuiltinGlobals(frontend), 1);
     (void)once;
-    static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1);
+    static int once2 = (Luau::freeze(frontend.globals.globalTypes), 1);
     (void)once2;

     if (parseResult.errors.empty())
     {
-        Luau::TypeChecker typeck(&moduleResolver, 
&iceHandler); - typeck.globalScope = sharedEnv.globalScope; + Luau::TypeChecker typeck(frontend.globals.globalScope, &frontend.moduleResolver, frontend.builtinTypes, &frontend.iceHandler); Luau::LintOptions lintOptions; lintOptions.warningMask = ~0ull; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index c94f0889..ffeb4919 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -261,8 +261,8 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) { static FuzzFileResolver fileResolver; static FuzzConfigResolver configResolver; - static Luau::FrontendOptions options{true, true}; - static Luau::Frontend frontend(&fileResolver, &configResolver, options); + static Luau::FrontendOptions defaultOptions{/*retainFullTypeGraphs*/ true, /*forAutocomplete*/ false, /*runLintChecks*/ kFuzzLinter}; + static Luau::Frontend frontend(&fileResolver, &configResolver, defaultOptions); static int once = (setupFrontend(frontend), 0); (void)once; @@ -285,16 +285,12 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) try { - Luau::CheckResult result = frontend.check(name, std::nullopt); - - // lint (note that we need access to types so we need to do this with typeck in scope) - if (kFuzzLinter && result.errors.empty()) - frontend.lint(name, std::nullopt); + frontend.check(name); // Second pass in strict mode (forced by auto-complete) - Luau::FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(name, opts); + Luau::FrontendOptions options = defaultOptions; + options.forAutocomplete = true; + frontend.check(name, options); } catch (std::exception&) { diff --git a/fuzz/typeck.cpp b/fuzz/typeck.cpp index a6c9ae28..4f8f8857 100644 --- a/fuzz/typeck.cpp +++ b/fuzz/typeck.cpp @@ -3,9 +3,9 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/Frontend.h" #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" -#include "Luau/TypeInfer.h" LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) @@ -23,23 +23,22 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) Luau::ParseResult parseResult = Luau::Parser::parse(reinterpret_cast(Data), Size, names, allocator, options); // "static" here is to accelerate fuzzing process by only creating and populating the type environment once - static Luau::NullModuleResolver moduleResolver; - static Luau::InternalErrorReporter iceHandler; - static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); - static int once = (Luau::registerBuiltinGlobals(sharedEnv), 1); + static Luau::NullFileResolver fileResolver; + static Luau::NullConfigResolver configResolver; + static Luau::Frontend frontend{&fileResolver, &configResolver}; + static int once = (Luau::registerBuiltinGlobals(frontend), 1); (void)once; - static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1); + static int once2 = (Luau::freeze(frontend.globals.globalTypes), 1); (void)once2; if (parseResult.errors.empty()) { + Luau::TypeChecker typeck(frontend.globals.globalScope, &frontend.moduleResolver, frontend.builtinTypes, &frontend.iceHandler); + Luau::SourceModule module; module.root = parseResult.root; module.mode = Luau::Mode::Nonstrict; - Luau::TypeChecker typeck(&moduleResolver, &iceHandler); - typeck.globalScope = sharedEnv.globalScope; - try { typeck.check(module, Luau::Mode::Nonstrict); diff --git a/rfcs/syntax-floor-division-operator.md b/rfcs/syntax-floor-division-operator.md new file mode 100644 index 00000000..8ec3913f --- /dev/null +++ b/rfcs/syntax-floor-division-operator.md @@ -0,0 +1,62 @@ +# Floor division operator 
+
+## Summary
+
+Add a floor division operator `//` to ease computing with integers.
+
+## Motivation
+
+Integers are everywhere. Indices, pixel coordinates, offsets, ranges, quantities, counters, rationals, fixed point arithmetic and bitwise operations all use integers.
+
+Luau is generally well suited to work with integers. The math operators +, -, \*, ^ and % support integers. That is, given integer operands these operators produce an integer result (provided that the result fits into the representable range of integers). However, that is not the case with the division operator `/`, which in the general case produces numbers with fractional parts.
+
+To overcome this, typical Luau code performing integer computations needs to wrap the result of division inside a call to `math.floor`. This has a number of issues and can be error-prone in practice.
+
+A typical mistake is to forget to use `math.floor`. This can produce subtle issues ranging from slightly wrong results to script errors. A script error could occur, for example, when the result of division is used to fetch from a table with only integer keys, which produces `nil`, and a script error happens soon after. Another type of error occurs when an accidental fractional number is passed to a C function. Depending on the implementation, the C function could raise an error (if it checks that the number is actually an integer) or cause logic errors due to rounding.
+
+Particularly problematic is incorrect code which seems to work with frequently used data, only to fail with some rare input. For example, image sizes often have power of two dimensions, so code dealing with them may appear to work fine until, much later, some rare image has an odd size and a division by two in the code does not produce the correct result. Due to the better ergonomics of the floor division operator, it becomes second nature to write `//` everywhere when integers are involved, and thus this class of bugs is much less likely to happen.
+
+Another issue with using `math.floor` as a workaround is that code performing a lot of integer calculations is harder to understand, write and maintain.
+
+Especially with applications dealing with pixel graphics, such as 2D games, integer math is so common that `math.floor` could easily become the most commonly used math library function. For these applications, avoiding the calls to `math.floor` is alluring from the performance perspective.
+
+> Non-normative: Here are the top math library functions used by a shipped game that heavily uses Lua:
+> `floor`: 461 matches, `max`: 224 matches, `sin`: 197 matches, `min`: 195 matches, `clamp`: 171 matches, `cos`: 106 matches, `abs`: 85 matches.
+> The majority of `math.floor` calls disappear from this codebase with the floor division operator.
+
+Lua has had a floor division operator since version 5.3, so its addition to Luau makes it easier to migrate from Lua to Luau and, perhaps more importantly, to use the wide variety of existing Lua libraries in Luau. Among other languages, Python most notably has a floor division operator with the same semantics and syntax. R and Julia also have a similar operator.
+
+## Design
+
+The design mirrors Lua 5.3:
+
+New operators `//` and `//=` will be added to the language. The operator `//` performs division of two operands and rounds the result towards negative infinity. By default, the operator is only defined for numbers. The operator has the same precedence as the normal division operator `/`. `//=` is the compound-assignment operator for floor division, similar to the existing operator `/=`.
+
+A new metamethod `__idiv` will be added. The metamethod is invoked when any operand of floor division is not a number. The metamethod can be used to implement floor division for any user-defined data type as well as the built-in vector type.
+
+The typechecker does not need special handling for the new operators. It can simply apply the same rules for floor division as it does for normal division operators.
+
+Examples of usage:
+
+```
+-- Convert offset into 2d indices
+local i, j = offset % 5, offset // 5
+
+-- Halve dimensions of an image or UI element
+width, height = width // 2, height // 2
+
+-- Draw an image to the center of the window
+draw_image(image, window_width // 2 - element_width // 2, window_height // 2 - element_height // 2)
+```
+
+## Drawbacks
+
+The addition of the new operator adds some complexity to the implementation (mostly to the VM) and to the language, which can be seen as a drawback.
+
+C-like languages use `//` for line comments. Using the symbol `//` for floor division closes the door on using it for line comments in Luau. On the other hand, Luau already has long and short comment syntax, so adding yet another syntax for comments would add complexity to the language for little benefit. Moreover, it would make it harder to translate code from Lua to Luau and use existing Lua libraries if the symbol `//` had a completely different meaning in Lua and Luau.
+
+## Alternatives
+
+An alternative would be to do nothing, but this would not solve the issues that the lack of floor division currently causes.
+
+An alternative implementation would treat `//` and `//=` only as syntactic sugar. The addition of a new VM opcode for floor division could be omitted and the compiler could simply be modified to automatically emit a call to `math.floor` when necessary. This would require only minimal changes to Luau, but it would not support overloading the floor division operator using metamethods and would not have the performance benefits of the full implementation.
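+
+As a non-normative illustration of the proposed semantics: the comments below relate `//` to today's `math.floor(a / b)` idiom, and the `Fixed` table with its `__idiv` handler is made up purely for this sketch; it is not part of the proposal.
+
+```
+-- for numbers, `a // b` is intended to behave like math.floor(a / b)
+print(7 // 2)   --> 3
+print(-7 // 2)  --> -4 (rounds towards negative infinity, not towards zero)
+print(7.5 // 2) --> 3  (Luau numbers are doubles; the result is still a number, just rounded down)
+
+-- a user-defined type could opt in through the proposed __idiv metamethod
+local Fixed = {}
+Fixed.__idiv = function(a, b)
+    -- divide the wrapped raw value and round down, returning another Fixed
+    return setmetatable({raw = math.floor(a.raw / b)}, Fixed)
+end
+
+local x = setmetatable({raw = 1024}, Fixed)
+print((x // 3).raw) --> 341
+```
+
+This mirrors Lua 5.3, where `a // b` on numbers is defined as the division rounded towards negative infinity and `__idiv` handles non-number operands.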
diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index e23b965b..1690c748 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -32,9 +32,9 @@ static std::string bytecodeAsArray(const std::vector& code) class AssemblyBuilderA64Fixture { public: - bool check(void (*f)(AssemblyBuilderA64& build), std::vector code, std::vector data = {}) + bool check(void (*f)(AssemblyBuilderA64& build), std::vector code, std::vector data = {}, unsigned int features = 0) { - AssemblyBuilderA64 build(/* logText= */ false); + AssemblyBuilderA64 build(/* logText= */ false, features); f(build); @@ -120,6 +120,10 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") SINGLE_COMPARE(ldrsh(x0, x1), 0x79800020); SINGLE_COMPARE(ldrsh(w0, x1), 0x79C00020); SINGLE_COMPARE(ldrsw(x0, x1), 0xB9800020); + + // paired loads + SINGLE_COMPARE(ldp(x0, x1, mem(x2, 8)), 0xA9408440); + SINGLE_COMPARE(ldp(w0, w1, mem(x2, -8)), 0x297F0440); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores") @@ -135,15 +139,58 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores") SINGLE_COMPARE(str(w0, x1), 0xB9000020); SINGLE_COMPARE(strb(w0, x1), 0x39000020); SINGLE_COMPARE(strh(w0, x1), 0x79000020); + + // paired stores + SINGLE_COMPARE(stp(x0, x1, mem(x2, 8)), 0xA9008440); + SINGLE_COMPARE(stp(w0, w1, mem(x2, -8)), 0x293F0440); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Moves") { SINGLE_COMPARE(mov(x0, x1), 0xAA0103E0); SINGLE_COMPARE(mov(w0, w1), 0x2A0103E0); - SINGLE_COMPARE(mov(x0, 42), 0xD2800540); - SINGLE_COMPARE(mov(w0, 42), 0x52800540); + + SINGLE_COMPARE(movz(x0, 42), 0xD2800540); + SINGLE_COMPARE(movz(w0, 42), 0x52800540); + SINGLE_COMPARE(movn(x0, 42), 0x92800540); + SINGLE_COMPARE(movn(w0, 42), 0x12800540); SINGLE_COMPARE(movk(x0, 42, 16), 0xF2A00540); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, 42); + }, + {0xD2800540})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, 424242); + }, + {0xD28F2640, 0xF2A000C0})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, -42); + }, + {0x92800520})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, -424242); + }, + {0x928F2620, 0xF2BFFF20})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, -65536); + }, + {0x929FFFE0})); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.mov(x0, -65537); + }, + {0x92800000, 0xF2BFFFC0})); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ControlFlow") @@ -222,6 +269,103 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Constants") // clang-format on } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "AddressOfLabel") +{ + // clang-format off + CHECK(check( + [](AssemblyBuilderA64& build) { + Label label; + build.adr(x0, label); + build.add(x0, x0, x0); + build.setLabel(label); + }, + { + 0x10000040, 0x8b000000, + })); + // clang-format on +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPBasic") +{ + SINGLE_COMPARE(fmov(d0, d1), 0x1E604020); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath") +{ + SINGLE_COMPARE(fabs(d1, d2), 0x1E60C041); + SINGLE_COMPARE(fadd(d1, d2, d3), 0x1E632841); + SINGLE_COMPARE(fdiv(d1, d2, d3), 0x1E631841); + SINGLE_COMPARE(fmul(d1, d2, d3), 0x1E630841); + SINGLE_COMPARE(fneg(d1, d2), 0x1E614041); + SINGLE_COMPARE(fsqrt(d1, d2), 0x1E61C041); + SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841); + + SINGLE_COMPARE(frinta(d1, d2), 0x1E664041); + SINGLE_COMPARE(frintm(d1, d2), 0x1E654041); + SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041); + + 
SINGLE_COMPARE(fcvtzs(w1, d2), 0x1E780041); + SINGLE_COMPARE(fcvtzs(x1, d2), 0x9E780041); + SINGLE_COMPARE(fcvtzu(w1, d2), 0x1E790041); + SINGLE_COMPARE(fcvtzu(x1, d2), 0x9E790041); + + SINGLE_COMPARE(scvtf(d1, w2), 0x1E620041); + SINGLE_COMPARE(scvtf(d1, x2), 0x9E620041); + SINGLE_COMPARE(ucvtf(d1, w2), 0x1E630041); + SINGLE_COMPARE(ucvtf(d1, x2), 0x9E630041); + + CHECK(check( + [](AssemblyBuilderA64& build) { + build.fjcvtzs(w1, d2); + }, + {0x1E7E0041}, {}, A64::Feature_JSCVT)); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPLoadStore") +{ + // address forms + SINGLE_COMPARE(ldr(d0, x1), 0xFD400020); + SINGLE_COMPARE(ldr(d0, mem(x1, 8)), 0xFD400420); + SINGLE_COMPARE(ldr(d0, mem(x1, x7)), 0xFC676820); + SINGLE_COMPARE(ldr(d0, mem(x1, -7)), 0xFC5F9020); + SINGLE_COMPARE(str(d0, x1), 0xFD000020); + SINGLE_COMPARE(str(d0, mem(x1, 8)), 0xFD000420); + SINGLE_COMPARE(str(d0, mem(x1, x7)), 0xFC276820); + SINGLE_COMPARE(str(d0, mem(x1, -7)), 0xFC1F9020); + + // load/store sizes + SINGLE_COMPARE(ldr(d0, x1), 0xFD400020); + SINGLE_COMPARE(ldr(q0, x1), 0x3DC00020); + SINGLE_COMPARE(str(d0, x1), 0xFD000020); + SINGLE_COMPARE(str(q0, x1), 0x3D800020); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPCompare") +{ + SINGLE_COMPARE(fcmp(d0, d1), 0x1E612000); + SINGLE_COMPARE(fcmpz(d1), 0x1E602028); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "AddressOffsetSize") +{ + SINGLE_COMPARE(ldr(w0, mem(x1, 16)), 0xB9401020); + SINGLE_COMPARE(ldr(x0, mem(x1, 16)), 0xF9400820); + SINGLE_COMPARE(ldr(d0, mem(x1, 16)), 0xFD400820); + SINGLE_COMPARE(ldr(q0, mem(x1, 16)), 0x3DC00420); + + SINGLE_COMPARE(str(w0, mem(x1, 16)), 0xB9001020); + SINGLE_COMPARE(str(x0, mem(x1, 16)), 0xF9000820); + SINGLE_COMPARE(str(d0, mem(x1, 16)), 0xFD000820); + SINGLE_COMPARE(str(q0, mem(x1, 16)), 0x3D800420); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ConditionalSelect") +{ + SINGLE_COMPARE(csel(x0, x1, x2, ConditionA64::Equal), 0x9A820020); + SINGLE_COMPARE(csel(w0, w1, w2, ConditionA64::Equal), 0x1A820020); + SINGLE_COMPARE(fcsel(d0, d1, d2, ConditionA64::Equal), 0x1E620C20); +} + TEST_CASE("LogTest") { AssemblyBuilderA64 build(/* logText= */ true); @@ -243,6 +387,17 @@ TEST_CASE("LogTest") build.b(ConditionA64::Plus, l); build.cbz(x7, l); + build.ldp(x0, x1, mem(x8, 8)); + build.adr(x0, l); + + build.fabs(d1, d2); + build.ldr(q1, x2); + + build.csel(x0, x1, x2, ConditionA64::Equal); + + build.fcmp(d0, d1); + build.fcmpz(d0); + build.setLabel(l); build.ret(); @@ -263,6 +418,13 @@ TEST_CASE("LogTest") blr x0 b.pl .L1 cbz x7,.L1 + ldp x0,x1,[x8,#8] + adr x0,.L1 + fabs d1,d2 + ldr q1,[x2] + csel x0,x1,x2,eq + fcmp d0,d1 + fcmp d0,#0 .L1: ret )"; diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 6aa7aa56..054eca7b 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -507,6 +507,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms") SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, dword[rcx + rdx]), 0xc4, 0xe1, 0x23, 0x2a, 0x34, 0x11); SINGLE_COMPARE(vcvtsi2sd(xmm5, xmm10, r13), 0xc4, 0xc1, 0xab, 0x2a, 0xed); SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, qword[rcx + rdx]), 0xc4, 0xe1, 0xa3, 0x2a, 0x34, 0x11); + SINGLE_COMPARE(vcvtsd2ss(xmm5, xmm10, xmm11), 0xc4, 0xc1, 0x2b, 0x5a, 0xeb); + SINGLE_COMPARE(vcvtsd2ss(xmm6, xmm11, qword[rcx + rdx]), 0xc4, 0xe1, 0xa3, 0x5a, 0x34, 0x11); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp 
index aedb50ab..c79bf35e 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -85,8 +85,8 @@ struct ACFixtureImpl : BaseType { GlobalTypes& globals = this->frontend.globalsForAutocomplete; unfreeze(globals.globalTypes); - LoadDefinitionFileResult result = - loadDefinitionFile(this->frontend.typeChecker, globals, globals.globalScope, source, "@test", /* captureComments */ false); + LoadDefinitionFileResult result = this->frontend.loadDefinitionFile( + globals, globals.globalScope, source, "@test", /* captureComments */ false, /* typeCheckForAutocomplete */ true); freeze(globals.globalTypes); REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); @@ -2995,8 +2995,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") { - ScopedFastFlag sff{"LuauCompleteTableKeysBetter", true}; - check(R"( type Direction = "up" | "down" @@ -3450,8 +3448,6 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") TEST_CASE_FIXTURE(ACFixture, "autocomplete_response_perf1" * doctest::timeout(0.5)) { - ScopedFastFlag luauAutocompleteSkipNormalization{"LuauAutocompleteSkipNormalization", true}; - // Build a function type with a large overload set const int parts = 100; std::string source; diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index a6ed96f0..359f2ba1 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -465,6 +465,7 @@ TEST_CASE("GeneratedCodeExecutionA64") build.add(x1, x1, 2); build.add(x0, x0, x1, /* LSL */ 1); + build.ret(); build.finalize(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index c9d0c01d..cabf1cce 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4691,8 +4691,6 @@ RETURN R0 0 TEST_CASE("LoopUnrollCost") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - ScopedFastInt sfis[] = { {"LuauCompileLoopUnrollThreshold", 25}, {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, @@ -5962,8 +5960,6 @@ RETURN R2 1 TEST_CASE("InlineMultret") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -6301,8 +6297,6 @@ RETURN R0 52 TEST_CASE("BuiltinFoldingProhibited") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - CHECK_EQ("\n" + compileFunction(R"( return math.abs(), @@ -6905,8 +6899,6 @@ L3: RETURN R0 0 TEST_CASE("BuiltinArity") { - ScopedFastFlag sff("LuauCompileBuiltinArity", true); - // by default we can't assume that we know parameter/result count for builtins as they can be overridden at runtime CHECK_EQ("\n" + compileFunction(R"( return math.abs(unknown()) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 58fb7d91..cd9a2f12 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -9,6 +9,7 @@ #include "Luau/StringUtils.h" #include "Luau/BytecodeBuilder.h" #include "Luau/CodeGen.h" +#include "Luau/Frontend.h" #include "doctest.h" #include "ScopedFlags.h" @@ -243,6 +244,24 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n return globalState; } +static void* limitedRealloc(void* ud, void* ptr, size_t osize, size_t nsize) +{ + if (nsize == 0) + { + free(ptr); + return nullptr; + } + else if (nsize > 8 * 1024 * 1024) + { + // For testing purposes return null for large allocations so we can generate 
errors related to memory allocation failures + return nullptr; + } + else + { + return realloc(ptr, nsize); + } +} + TEST_SUITE_BEGIN("Conformance"); TEST_CASE("Assert") @@ -381,6 +400,8 @@ static int cxxthrow(lua_State* L) TEST_CASE("PCall") { + ScopedFastFlag sff("LuauBetterOOMHandling", true); + runConformance("pcall.lua", [](lua_State* L) { lua_pushcfunction(L, cxxthrow, "cxxthrow"); lua_setglobal(L, "cxxthrow"); @@ -395,7 +416,7 @@ TEST_CASE("PCall") }, "resumeerror"); lua_setglobal(L, "resumeerror"); - }); + }, nullptr, lua_newstate(limitedRealloc, nullptr)); } TEST_CASE("Pack") @@ -501,17 +522,15 @@ TEST_CASE("Types") { runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; - Luau::InternalErrorReporter iceHandler; - Luau::BuiltinTypes builtinTypes; - Luau::GlobalTypes globals{Luau::NotNull{&builtinTypes}}; - Luau::TypeChecker env(globals, &moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); - - Luau::registerBuiltinGlobals(env, globals); - Luau::freeze(globals.globalTypes); + Luau::NullFileResolver fileResolver; + Luau::NullConfigResolver configResolver; + Luau::Frontend frontend{&fileResolver, &configResolver}; + Luau::registerBuiltinGlobals(frontend, frontend.globals); + Luau::freeze(frontend.globals.globalTypes); lua_newtable(L); - for (const auto& [name, binding] : globals.globalScope->bindings) + for (const auto& [name, binding] : frontend.globals.globalScope->bindings) { populateRTTI(L, binding.typeId); lua_setfield(L, -2, toString(name).c_str()); @@ -882,7 +901,7 @@ TEST_CASE("ApiIter") TEST_CASE("ApiCalls") { - StateRef globalState = runConformance("apicalls.lua"); + StateRef globalState = runConformance("apicalls.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); // lua_call @@ -981,6 +1000,55 @@ TEST_CASE("ApiCalls") CHECK(lua_tonumber(L, -1) == 4); lua_pop(L, 1); } + + ScopedFastFlag sff("LuauBetterOOMHandling", true); + + // lua_pcall on OOM + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 0, 0); + CHECK(res == LUA_ERRMEM); + } + + // lua_pcall on OOM with an error handler + { + lua_getfield(L, LUA_GLOBALSINDEX, "oops"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRMEM); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "oops") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on OOM with an error handler that errors + { + lua_getfield(L, LUA_GLOBALSINDEX, "error"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRERR); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "error in error handling") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on OOM with an error handler that OOMs + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRMEM); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "not enough memory") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on error with an error handler that OOMs + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + lua_getfield(L, LUA_GLOBALSINDEX, "error"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRERR); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "error in error handling") == 0)); + lua_pop(L, 1); + } } TEST_CASE("ApiAtoms") @@ -1051,26 +1119,7 @@ TEST_CASE("ExceptionObject") return ExceptionResult{false, ""}; }; - auto 
reallocFunc = [](void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { - if (nsize == 0) - { - free(ptr); - return nullptr; - } - else if (nsize > 512 * 1024) - { - // For testing purposes return null for large allocations - // so we can generate exceptions related to memory allocation - // failures. - return nullptr; - } - else - { - return realloc(ptr, nsize); - } - }; - - StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(reallocFunc, nullptr)); + StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); { @@ -1250,7 +1299,9 @@ TEST_CASE("Interrupt") 13, 13, 16, - 20, + 23, + 21, + 25, }; static int index; diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index cc239b7e..d34b86bd 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -31,8 +31,7 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) void ConstraintGraphBuilderFixture::solve(const std::string& code) { generateConstraints(code); - ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull{mainModule->reduction.get()}, - NotNull(&moduleResolver), {}, &logger}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger}; cs.run(); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 4d2e83fc..aebf177c 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -506,7 +506,8 @@ void Fixture::validateErrors(const std::vector& errors) LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test", /* captureComments */ false); + LoadDefinitionFileResult result = + frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, source, "@test", /* captureComments */ false); freeze(frontend.globals.globalTypes); if (result.module) @@ -521,9 +522,9 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::unfreeze(frontend.globals.globalTypes); Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); - registerBuiltinGlobals(frontend); + registerBuiltinGlobals(frontend, frontend.globals); if (prepareAutocomplete) - registerBuiltinGlobals(frontend.typeCheckerForAutocomplete, frontend.globalsForAutocomplete); + registerBuiltinGlobals(frontend, frontend.globalsForAutocomplete, /*typeCheckForAutocomplete*/ true); registerTestTypes(); Luau::freeze(frontend.globals.globalTypes); @@ -594,8 +595,12 @@ void registerHiddenTypes(Frontend* frontend) TypeId t = globals.globalTypes.addType(GenericType{"T"}); GenericTypeDefinition genericT{t}; + TypeId u = globals.globalTypes.addType(GenericType{"U"}); + GenericTypeDefinition genericU{u}; + ScopePtr globalScope = globals.globalScope; globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, globals.globalTypes.addType(NegationType{t})}; + globalScope->exportedTypeBindings["Mt"] = TypeFun{{genericT, genericU}, globals.globalTypes.addType(MetatableType{t, u})}; globalScope->exportedTypeBindings["fun"] = TypeFun{{}, frontend->builtinTypes->functionType}; globalScope->exportedTypeBindings["cls"] = TypeFun{{}, frontend->builtinTypes->classType}; globalScope->exportedTypeBindings["err"] = TypeFun{{}, frontend->builtinTypes->errorType}; diff --git 
a/tests/Fixture.h b/tests/Fixture.h index 4c49593c..8d48ab1d 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -94,7 +94,6 @@ struct Fixture TypeId requireTypeAlias(const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; - ScopedFastFlag luauLintInTypecheck{"LuauLintInTypecheck", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 3b1ec4ad..13fd6e0f 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -877,7 +877,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "environments") ScopePtr testScope = frontend.addEnvironment("test"); unfreeze(frontend.globals.globalTypes); - loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( + frontend.loadDefinitionFile(frontend.globals, testScope, R"( export type Foo = number | string )", "@test", /* captureComments */ false); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 37c12dc9..c1392c9d 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -42,7 +42,7 @@ public: f(a); build.beginBlock(a); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); }; template @@ -56,10 +56,10 @@ public: f(a, b); build.beginBlock(a); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(b); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); }; void checkEq(IrOp instOp, const IrInst& inst) @@ -94,10 +94,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag") build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(0), fallback); IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmConst(5)); build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(0), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -107,10 +107,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag") bb_0: CHECK_TAG R2, tnil, bb_fallback_1 CHECK_TAG K5, tnil, bb_fallback_1 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -123,7 +123,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptBinaryArith") IrOp opA = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); IrOp opB = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); build.inst(IrCmd::ADD_NUM, opA, opB); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -133,7 +133,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptBinaryArith") bb_0: %0 = LOAD_DOUBLE R1 %2 = ADD_NUM %0, R2 - LOP_RETURN 0u + RETURN 0u )"); } @@ -150,10 +150,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1") build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -165,10 +165,10 @@ bb_0: JUMP_EQ_TAG R1, %1, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + 
RETURN 0u )"); } @@ -186,10 +186,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2") build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -203,10 +203,10 @@ bb_0: JUMP_EQ_TAG R2, %0, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -224,10 +224,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") build.inst(IrCmd::JUMP_EQ_TAG, opA, build.constTag(0), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -241,10 +241,10 @@ bb_0: JUMP_EQ_TAG %2, tnil, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -261,10 +261,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum") build.inst(IrCmd::JUMP_CMP_NUM, opA, opB, trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); optimizeMemoryOperandsX64(build.function); @@ -276,10 +276,10 @@ bb_0: JUMP_CMP_NUM R1, %1, bb_1, bb_2 bb_1: - LOP_RETURN 0u + RETURN 0u bb_2: - LOP_RETURN 0u + RETURN 0u )"); } @@ -317,7 +317,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::INT_TO_NUM, build.constInt(8))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constantFold(); @@ -342,7 +342,7 @@ bb_0: STORE_INT R0, 1i STORE_INT R0, 0i STORE_DOUBLE R0, 8 - LOP_RETURN 0u + RETURN 0u )"); } @@ -373,25 +373,25 @@ bb_0: JUMP bb_1 bb_1: - LOP_RETURN 1u + RETURN 1u bb_3: JUMP bb_5 bb_5: - LOP_RETURN 2u + RETURN 2u bb_6: JUMP bb_7 bb_7: - LOP_RETURN 1u + RETURN 1u bb_9: JUMP bb_11 bb_11: - LOP_RETURN 2u + RETURN 2u )"); } @@ -400,18 +400,18 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") { withOneBlock([this](IrOp a) { build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(4), a)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(1.2), a)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, nan, a)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); updateUseCounts(build.function); @@ -420,19 +420,19 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") CHECK("\n" + 
toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R0, 4i - LOP_RETURN 0u + RETURN 0u bb_2: JUMP bb_3 bb_3: - LOP_RETURN 1u + RETURN 1u bb_4: JUMP bb_5 bb_5: - LOP_RETURN 1u + RETURN 1u )"); } @@ -441,12 +441,12 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Guards") { withOneBlock([this](IrOp a) { build.inst(IrCmd::CHECK_TAG, build.constTag(tnumber), build.constTag(tnumber), a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { build.inst(IrCmd::CHECK_TAG, build.constTag(tnil), build.constTag(tnumber), a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); }); updateUseCounts(build.function); @@ -454,13 +454,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Guards") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: - LOP_RETURN 0u + RETURN 0u bb_2: JUMP bb_3 bb_3: - LOP_RETURN 1u + RETURN 1u )"); } @@ -568,7 +568,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::LOAD_INT, build.vmReg(1))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(11), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -593,7 +593,7 @@ bb_0: STORE_INT R10, %20 %22 = LOAD_DOUBLE R2 STORE_DOUBLE R11, %22 - LOP_RETURN 0u + RETURN 0u )"); } @@ -614,7 +614,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue") build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(1))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -627,7 +627,7 @@ bb_0: STORE_TVALUE R1, %2 STORE_TAG R3, tnumber STORE_DOUBLE R3, 0.5 - LOP_RETURN 0u + RETURN 0u )"); } @@ -641,10 +641,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag") build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -652,7 +652,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber - LOP_RETURN 0u + RETURN 0u )"); } @@ -671,7 +671,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks") build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); // Can make env unsafe build.inst(IrCmd::CHECK_SAFE_ENV); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -682,7 +682,7 @@ bb_0: CHECK_GC DO_LEN R1, R2 CHECK_SAFE_ENV - LOP_RETURN 0u + RETURN 0u )"); } @@ -707,10 +707,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState") build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); - build.inst(IrCmd::LOP_RETURN, 
build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -723,10 +723,10 @@ bb_0: DO_LEN R1, R2 CHECK_NO_METATABLE %0, bb_fallback_1 CHECK_READONLY %0, bb_fallback_1 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -742,7 +742,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, build.vmReg(0)); IrOp something = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); build.inst(IrCmd::BARRIER_OBJ, something, build.vmReg(0)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -750,7 +750,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber - LOP_RETURN 0u + RETURN 0u )"); } @@ -764,14 +764,16 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(10)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.constDouble(2.0)); - build.inst(IrCmd::CONCAT, build.vmReg(0), build.vmReg(3)); // Concat invalidates more than the target register + build.inst(IrCmd::CONCAT, build.vmReg(0), build.constUint(3)); - build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::LOAD_INT, build.vmReg(1))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); + build.inst(IrCmd::STORE_TAG, build.vmReg(4), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(5), build.inst(IrCmd::LOAD_INT, build.vmReg(1))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(6), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3))); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -781,14 +783,16 @@ bb_0: STORE_TAG R0, tnumber STORE_INT R1, 10i STORE_DOUBLE R2, 0.5 - CONCAT R0, R3 - %4 = LOAD_TAG R0 - STORE_TAG R3, %4 - %6 = LOAD_INT R1 - STORE_INT R4, %6 - %8 = LOAD_DOUBLE R2 - STORE_DOUBLE R5, %8 - LOP_RETURN 0u + STORE_DOUBLE R3, 2 + CONCAT R0, 3u + %5 = LOAD_TAG R0 + STORE_TAG R4, %5 + %7 = LOAD_INT R1 + STORE_INT R5, %7 + %9 = LOAD_DOUBLE R2 + STORE_DOUBLE R6, %9 + STORE_DOUBLE R7, 2 + RETURN 0u )"); } @@ -815,10 +819,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0))); // At least R0 wasn't touched - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -833,10 +837,10 @@ bb_0: CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 STORE_DOUBLE R1, 
0.5 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -851,7 +855,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10)); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -861,7 +865,7 @@ bb_0: STORE_INT R0, 10i STORE_DOUBLE R0, 0.5 STORE_INT R0, 10i - LOP_RETURN 0u + RETURN 0u )"); } @@ -878,10 +882,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation") build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback); build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -890,10 +894,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation") bb_0: %0 = LOAD_TAG R0 CHECK_TAG %0, tnumber, bb_fallback_1 - LOP_RETURN 0u + RETURN 0u bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -910,10 +914,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting") build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback); build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnil), fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + build.inst(IrCmd::RETURN, build.constUint(0)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -925,7 +929,7 @@ bb_0: JUMP bb_fallback_1 bb_fallback_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -943,13 +947,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval") build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(1), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(3)); + build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -961,10 +965,10 @@ bb_0: JUMP bb_1 bb_1: - LOP_RETURN 1u + RETURN 1u bb_fallback_3: - LOP_RETURN 3u + RETURN 3u )"); } @@ -982,13 +986,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval") build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(1), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); build.beginBlock(fallback); - build.inst(IrCmd::LOP_RETURN, build.constUint(3)); + build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1000,10 +1004,10 @@ bb_0: JUMP bb_2 bb_2: - LOP_RETURN 2u + RETURN 2u bb_fallback_3: - LOP_RETURN 3u + RETURN 3u )"); } @@ -1020,10 +1024,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval") build.inst(IrCmd::JUMP_EQ_TAG, tag, 
build.constTag(tnumber), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1035,7 +1039,7 @@ bb_0: JUMP bb_2 bb_2: - LOP_RETURN 2u + RETURN 2u )"); } @@ -1052,10 +1056,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") build.inst(IrCmd::JUMP_EQ_INT, value, build.constInt(5), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1066,7 +1070,7 @@ bb_0: JUMP bb_1 bb_1: - LOP_RETURN 1u + RETURN 1u )"); } @@ -1083,10 +1087,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") build.inst(IrCmd::JUMP_CMP_NUM, value, build.constDouble(8.0), build.cond(IrCondition::Greater), trueBlock, falseBlock); build.beginBlock(trueBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(falseBlock); - build.inst(IrCmd::LOP_RETURN, build.constUint(2)); + build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1097,7 +1101,7 @@ bb_0: JUMP bb_2 bb_2: - LOP_RETURN 2u + RETURN 2u )"); } @@ -1114,7 +1118,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor build.beginBlock(block2); build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1126,7 +1130,7 @@ bb_0: bb_1: STORE_TAG R1, tnumber - LOP_RETURN 1u + RETURN 1u )"); } @@ -1144,7 +1148,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique build.beginBlock(block2); build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::LOP_RETURN, build.constUint(1)); + build.inst(IrCmd::RETURN, build.constUint(1)); build.beginBlock(block3); build.inst(IrCmd::JUMP, block2); @@ -1160,7 +1164,7 @@ bb_0: bb_1: %2 = LOAD_TAG R0 STORE_TAG R1, %2 - LOP_RETURN 1u + RETURN 1u bb_2: JUMP bb_1 @@ -1179,7 +1183,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(repeat); build.inst(IrCmd::INTERRUPT, build.constUint(0)); @@ -1194,7 +1198,7 @@ bb_0: JUMP bb_1 bb_1: - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i )"); } @@ -1207,14 +1211,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") IrOp repeat = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block); build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, 
repeat); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(repeat); build.inst(IrCmd::INTERRUPT, build.constUint(0)); @@ -1225,14 +1229,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i bb_1: STORE_TAG R0, tnumber JUMP bb_2 bb_2: - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i )"); } @@ -1249,14 +1253,14 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), block, exit1); build.beginBlock(exit1); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block); build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit2, repeat); build.beginBlock(exit2); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(repeat); build.inst(IrCmd::INTERRUPT, build.constUint(0)); @@ -1270,14 +1274,14 @@ bb_0: JUMP bb_1 bb_1: - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i bb_2: STORE_TAG R0, tnumber JUMP bb_3 bb_3: - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i )"); } @@ -1318,7 +1322,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") build.inst(IrCmd::JUMP, block4); build.beginBlock(block4); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1346,10 +1350,10 @@ bb_4: JUMP bb_5 bb_5: - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i bb_linear_6: - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i )"); } @@ -1389,11 +1393,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" build.beginBlock(block4a); build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.beginBlock(block4b); build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); constPropInBlockChains(build); @@ -1423,11 +1427,11 @@ bb_4: bb_5: STORE_TAG R0, %10 - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i bb_6: STORE_TAG R0, %10 - LOP_RETURN 0u, R0, 0i + RETURN R0, 0i )"); } @@ -1484,7 +1488,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDiamond") build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(2), build.constInt(2)); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(2)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1518,7 +1522,7 @@ bb_2: bb_3: ; predecessors: bb_1, bb_2 ; in regs: R2, R3 - LOP_RETURN 0u, R2, 2i + RETURN R2, 2i )"); } @@ -1530,11 +1534,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ImplicitFixedRegistersInVarargCall") build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(3), build.constInt(-1)); - build.inst(IrCmd::LOP_CALL, build.constUint(0), 
build.vmReg(0), build.constInt(-1), build.constInt(5)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(5)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(5)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(5)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1545,13 +1549,13 @@ bb_0: ; in regs: R0, R1, R2 ; out regs: R0, R1, R2, R3, R4 FALLBACK_GETVARARGS 0u, R3, -1i - LOP_CALL 0u, R0, -1i, 5i + CALL R0, -1i, 5i JUMP bb_1 bb_1: ; predecessors: bb_0 ; in regs: R0, R1, R2, R3, R4 - LOP_RETURN 0u, R0, 5i + RETURN R0, 5i )"); } @@ -1563,11 +1567,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); - build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.constInt(-1), build.constInt(-1)); + IrOp results = build.inst( + IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(0), results); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1578,12 +1584,13 @@ bb_0: ; out regs: R0... FALLBACK_GETVARARGS 0u, R1, -1i %1 = INVOKE_FASTCALL 0u, R0, R1, R2, -1i, -1i + ADJUST_STACK_TO_REG R0, %1 JUMP bb_1 bb_1: ; predecessors: bb_0 ; in regs: R0... - LOP_RETURN 0u, R0, -1i + RETURN R0, -1i )"); } @@ -1594,12 +1601,12 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequenceRestart") IrOp exit = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(1), build.constInt(0), build.constInt(-1)); - build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(1), build.constInt(0), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1609,14 +1616,14 @@ bb_0: ; successors: bb_1 ; in regs: R0, R1 ; out regs: R0... - LOP_CALL 0u, R1, 0i, -1i - LOP_CALL 0u, R0, -1i, -1i + CALL R1, 0i, -1i + CALL R0, -1i, -1i JUMP bb_1 bb_1: ; predecessors: bb_0 ; in regs: R0... 
- LOP_RETURN 0u, R0, -1i + RETURN R0, -1i )"); } @@ -1630,15 +1637,15 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback); - build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(fallback); - build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::CALL, build.vmReg(0), build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1651,7 +1658,7 @@ bb_0: FALLBACK_GETVARARGS 0u, R1, -1i %1 = LOAD_TAG R0 CHECK_TAG %1, tnumber, bb_fallback_1 - LOP_CALL 0u, R0, -1i, -1i + CALL R0, -1i, -1i JUMP bb_2 bb_fallback_1: @@ -1659,13 +1666,13 @@ bb_fallback_1: ; successors: bb_2 ; in regs: R0, R1... ; out regs: R0... - LOP_CALL 0u, R0, -1i, -1i + CALL R0, -1i, -1i JUMP bb_2 bb_2: ; predecessors: bb_0, bb_fallback_1 ; in regs: R0... - LOP_RETURN 0u, R0, -1i + RETURN R0, -1i )"); } @@ -1690,7 +1697,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequencePeeling") build.inst(IrCmd::JUMP, exit); build.beginBlock(exit); - build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(2), build.constInt(-1)); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(-1)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -1725,7 +1732,65 @@ bb_2: bb_3: ; predecessors: bb_1, bb_2 ; in regs: R2... 
- LOP_RETURN 0u, R2, -1i + RETURN R2, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinVariadicStart") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(2.0)); + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(2), build.constInt(1)); + build.inst(IrCmd::CALL, build.vmReg(1), build.constInt(-1), build.constInt(1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1 +; in regs: R0 +; out regs: R0, R1 + STORE_DOUBLE R1, 1 + STORE_DOUBLE R2, 2 + ADJUST_STACK_TO_REG R2, 1i + CALL R1, -1i, 1i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + + +TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::SET_TABLE, build.vmReg(0), build.vmReg(1), build.constUint(1)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; in regs: R0, R1 + SET_TABLE R0, R1, 1u + RETURN R0, 1i )"); } diff --git a/tests/IrCallWrapperX64.test.cpp b/tests/IrCallWrapperX64.test.cpp new file mode 100644 index 00000000..c8918dbd --- /dev/null +++ b/tests/IrCallWrapperX64.test.cpp @@ -0,0 +1,522 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrCallWrapperX64.h" +#include "Luau/IrRegAllocX64.h" + +#include "doctest.h" + +using namespace Luau::CodeGen; +using namespace Luau::CodeGen::X64; + +class IrCallWrapperX64Fixture +{ +public: + IrCallWrapperX64Fixture() + : build(/* logText */ true, ABIX64::Windows) + , regs(build, function) + , callWrap(regs, build, ~0u) + { + } + + void checkMatch(std::string expected) + { + regs.assertAllFree(); + + build.finalize(); + + CHECK("\n" + build.text == expected); + } + + AssemblyBuilderX64 build; + IrFunction function; + IrRegAllocX64 regs; + IrCallWrapperX64 callWrap; + + // Tests rely on these to force interference between registers + static constexpr RegisterX64 rArg1 = rcx; + static constexpr RegisterX64 rArg1d = ecx; + static constexpr RegisterX64 rArg2 = rdx; + static constexpr RegisterX64 rArg2d = edx; + static constexpr RegisterX64 rArg3 = r8; + static constexpr RegisterX64 rArg3d = r8d; + static constexpr RegisterX64 rArg4 = r9; + static constexpr RegisterX64 rArg4d = r9d; +}; + +TEST_SUITE_BEGIN("IrCallWrapperX64"); + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleRegs") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rax, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1); + callWrap.addArgument(SizeX64::qword, tmp2); // Already in its place + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,rax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse1") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1.reg); // Already in its place + 
callWrap.addArgument(SizeX64::qword, tmp1.release()); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rdx,rcx + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg]); + callWrap.addArgument(SizeX64::qword, tmp1.release()); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rdx,rcx + mov rcx,qword ptr [rcx] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleMemImm") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rax, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::dword, 32); + callWrap.addArgument(SizeX64::dword, -1); + callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); + callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + tmp2.release()]); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov r8,qword ptr [r14+020h] + mov r9,qword ptr [rax+rsi] + mov ecx,20h + mov edx,FFFFFFFFh + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleStackArgs") +{ + ScopedRegX64 tmp{regs, regs.takeReg(rax, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.addArgument(SizeX64::qword, qword[r14 + 16]); + callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); + callWrap.addArgument(SizeX64::qword, qword[r14 + 48]); + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::qword, qword[r13]); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rdx,qword ptr [r13] + mov qword ptr [rsp+028h],rdx + mov rcx,rax + mov rdx,qword ptr [r14+010h] + mov r8,qword ptr [r14+020h] + mov r9,qword ptr [r14+030h] + mov dword ptr [rsp+020h],1 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FixedRegisters") +{ + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::qword, 2); + callWrap.addArgument(SizeX64::qword, 3); + callWrap.addArgument(SizeX64::qword, 4); + callWrap.addArgument(SizeX64::qword, r14); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov qword ptr [rsp+020h],r14 + mov ecx,1 + mov rdx,2 + mov r8,3 + mov r9,4 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "EasyInterference") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rdi, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.addArgument(SizeX64::qword, tmp3); + callWrap.addArgument(SizeX64::qword, tmp4); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov r8,rdx + mov rdx,rsi + mov r9,rcx + mov rcx,rdi + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeInterference") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + 8]); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,qword ptr [rcx+8] + mov rdx,qword ptr [rdx+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3, 
kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.addArgument(SizeX64::qword, tmp3); + callWrap.addArgument(SizeX64::qword, tmp4); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,r9 + mov r9,rcx + mov rcx,rax + mov rax,r8 + mov r8,rdx + mov rdx,rax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4d, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3d, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2d, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1d, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::dword, tmp1); + callWrap.addArgument(SizeX64::dword, tmp2); + callWrap.addArgument(SizeX64::dword, tmp3); + callWrap.addArgument(SizeX64::dword, tmp4); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov eax,r9d + mov r9d,ecx + mov ecx,eax + mov eax,r8d + mov r8d,edx + mov edx,eax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceFp") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(xmm1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(xmm0, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::xmmword, tmp1); + callWrap.addArgument(SizeX64::xmmword, tmp2); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm2,xmm1,xmm1 + vmovsd xmm1,xmm0,xmm0 + vmovsd xmm0,xmm2,xmm2 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceBoth") +{ + ScopedRegX64 int1{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 int2{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 fp1{regs, regs.takeReg(xmm3, kInvalidInstIdx)}; + ScopedRegX64 fp2{regs, regs.takeReg(xmm2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, int1); + callWrap.addArgument(SizeX64::qword, int2); + callWrap.addArgument(SizeX64::xmmword, fp1); + callWrap.addArgument(SizeX64::xmmword, fp2); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,rdx + mov rdx,rcx + mov rcx,rax + vmovsd xmm0,xmm3,xmm3 + vmovsd xmm3,xmm2,xmm2 + vmovsd xmm2,xmm0,xmm0 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeMultiuseInterferenceMem") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + 16]); + tmp1.release(); + tmp2.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,qword ptr [rcx+rdx+8] + mov rdx,qword ptr [rdx+010h] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem1") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 16]); + tmp1.release(); + tmp2.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,rcx + mov rcx,qword ptr [rax+rdx+8] + mov rdx,qword ptr [rax+010h] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + 
ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 16]); + tmp1.release(); + tmp2.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,rcx + mov rcx,qword ptr [rax+rdx+8] + mov rdx,qword ptr [rax+rdx+010h] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem3") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg3, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); + callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + tmp3.reg + 16]); + callWrap.addArgument(SizeX64::qword, qword[tmp3.reg + tmp1.reg + 16]); + tmp1.release(); + tmp2.release(); + tmp3.release(); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rax,r8 + mov r8,qword ptr [rcx+rax+010h] + mov rbx,rdx + mov rdx,qword ptr [rbx+rcx+010h] + mov rcx,qword ptr [rax+rbx+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg1") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 8]); + callWrap.call(qword[tmp1.release() + 16]); + + checkMatch(R"( + mov rax,rcx + mov rcx,qword ptr [rax+8] + call qword ptr [rax+010h] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg2") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.call(qword[tmp1.release() + 16]); + + checkMatch(R"( + mov rax,rcx + mov rcx,rdx + call qword ptr [rax+010h] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg3") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, tmp1.reg); + callWrap.call(qword[tmp1.release() + 16]); + + checkMatch(R"( + call qword ptr [rcx+010h] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse1") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); // Already in its place + callWrap.addArgument(SizeX64::xmmword, qword[r12 + 8]); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm1,qword ptr [r12+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse2") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + callWrap.addArgument(SizeX64::xmmword, qword[r12 + 8]); + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm1,xmm0,xmm0 + vmovsd xmm0,qword ptr [r12+8] + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse3") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + 
callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); + callWrap.addArgument(SizeX64::xmmword, irInst1.regX64, irOp1); + callWrap.call(qword[r12]); + + checkMatch(R"( + vmovsd xmm1,xmm0,xmm0 + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse4") +{ + IrInst irInst1; + IrOp irOp1 = {IrOpKind::Inst, 0}; + irInst1.regX64 = regs.takeReg(rax, irOp1.index); + irInst1.lastUse = 1; + function.instructions.push_back(irInst1); + callWrap.instIdx = irInst1.lastUse; + + ScopedRegX64 tmp{regs, regs.takeReg(rdx, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, r15); + callWrap.addArgument(SizeX64::qword, irInst1.regX64, irOp1); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[r12]); + + checkMatch(R"( + mov rcx,r15 + mov r8,rdx + mov rdx,rax + call qword ptr [r12] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ExtraCoverage") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + callWrap.addArgument(SizeX64::qword, addr[r12 + 8]); + callWrap.addArgument(SizeX64::qword, addr[r12 + 16]); + callWrap.addArgument(SizeX64::xmmword, xmmword[r13]); + callWrap.call(qword[tmp1.release() + tmp2.release()]); + + checkMatch(R"( + vmovups xmm2,xmmword ptr [r13] + mov rax,rcx + lea rcx,none ptr [r12+8] + mov rbx,rdx + lea rdx,none ptr [r12+010h] + call qword ptr [rax+rbx] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "AddressInStackArguments") +{ + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.addArgument(SizeX64::dword, 3); + callWrap.addArgument(SizeX64::dword, 4); + callWrap.addArgument(SizeX64::qword, addr[r12 + 16]); + callWrap.call(qword[r14]); + + checkMatch(R"( + lea rax,none ptr [r12+010h] + mov qword ptr [rsp+020h],rax + mov ecx,1 + mov edx,2 + mov r8d,3 + mov r9d,4 + call qword ptr [r14] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ImmediateConflictWithFunction") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.call(qword[tmp1.release() + tmp2.release()]); + + checkMatch(R"( + mov rax,rcx + mov ecx,1 + mov rbx,rdx + mov edx,2 + call qword ptr [rax+rbx] +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 7fcc1e54..78d1389a 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -157,8 +157,6 @@ TEST_CASE("string_interpolation_basic") TEST_CASE("string_interpolation_full") { - ScopedFastFlag sff("LuauFixInterpStringMid", true); - const std::string testInput = R"(`foo {"bar"} {"baz"} end`)"; Luau::Allocator alloc; AstNameTable table(alloc); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 0f134616..54a1f44c 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1273,7 +1273,7 @@ TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") { ScopePtr testScope = frontend.addEnvironment("Test"); unfreeze(frontend.globals.globalTypes); - loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( + frontend.loadDefinitionFile(frontend.globals, testScope, R"( declare Foo: number )", "@test", /* captureComments */ false); @@ -1444,8 +1444,6 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") { - ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - 
unfreeze(frontend.globals.globalTypes); TypeId instanceType = frontend.globals.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); persist(instanceType); @@ -1496,8 +1494,6 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiUntyped") { - ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - if (TableType* ttv = getMutable(getGlobalBinding(frontend.globals, "table"))) { ttv->props["foreach"].deprecated = true; diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index d2796b6d..7e61235a 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -338,7 +338,7 @@ type B = A TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") { ScopedFastFlag flags[] = { - {"LuauClonePublicInterfaceLess", true}, + {"LuauClonePublicInterfaceLess2", true}, {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, @@ -376,7 +376,7 @@ return {} TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") { ScopedFastFlag flags[] = { - {"LuauClonePublicInterfaceLess", true}, + {"LuauClonePublicInterfaceLess2", true}, {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 384a39fe..6552a24d 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -470,7 +470,6 @@ TEST_SUITE_END(); struct NormalizeFixture : Fixture { - ScopedFastFlag sff1{"LuauNegatedFunctionTypes", true}; ScopedFastFlag sff2{"LuauNegatedClassTypes", true}; TypeArena arena; @@ -749,6 +748,20 @@ TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") CHECK("Child" == toString(normal("(Child | Unrelated) & Child"))); } +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_metatables_where_the_metatable_is_top_or_bottom") +{ + ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; + + CHECK("{ @metatable *error-type*, {| |} }" == toString(normal("Mt<{}, any> & Mt<{}, err>"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") +{ + ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; + + CHECK("never" == toString(normal("Mt<{}, number> & Mt<{}, string>"))); +} + TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") { ScopedFastFlag sffs[] = { @@ -802,7 +815,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_tables") TEST_CASE_FIXTURE(NormalizeFixture, "normalize_blocked_types") { - ScopedFastFlag sff[] { + ScopedFastFlag sff[]{ {"LuauNormalizeBlockedTypes", true}, }; @@ -813,4 +826,14 @@ TEST_CASE_FIXTURE(NormalizeFixture, "normalize_blocked_types") CHECK_EQ(normalizer.typeFromNormal(*norm), &blocked); } +TEST_CASE_FIXTURE(NormalizeFixture, "normalize_pending_expansion_types") +{ + AstName name; + Type pending{PendingExpansionType{std::nullopt, name, {}, {}}}; + + const NormalizedType* norm = normalizer.normalize(&pending); + + CHECK_EQ(normalizer.typeFromNormal(*norm), &pending); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 9ff16d16..ef5aabbe 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1040,8 +1040,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_call_without_parens") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_expression") { - ScopedFastFlag sff("LuauFixInterpStringMid", true); - try { parse(R"( diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index b55c7746..52de15c7 
100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -947,4 +947,101 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_locations") CHECK(mod->scopes[3].second->typeAliasNameLocations["X"] == Location(Position(5, 17), 1)); } +/* + * We had a bug in DCR where substitution would improperly clone a + * PendingExpansionType. + * + * This cloned type did not have a matching constraint to expand it, so it was + * left dangling and unexpanded forever. + * + * We must also delay the dispatch of a constraint if doing so would require + * unifying a PendingExpansionType. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_lose_track_of_PendingExpansionTypes_after_substitution") +{ + fileResolver.source["game/ReactCurrentDispatcher"] = R"( + export type BasicStateAction<S> = ((S) -> S) | S + export type Dispatch<A> = (A) -> () + + export type Dispatcher = { + useState: <S>(initialState: (() -> S) | S) -> (S, Dispatch<BasicStateAction<S>>), + } + + return {} + )"; + + // Note: This script path is actually as short as it can be. Any shorter + // and we somehow fail to surface the bug. + fileResolver.source["game/React/React/ReactHooks"] = R"( + local RCD = require(script.Parent.Parent.Parent.ReactCurrentDispatcher) + + local function resolveDispatcher(): RCD.Dispatcher + return (nil :: any) :: RCD.Dispatcher + end + + function useState<S>( + initialState: (() -> S) | S + ): (S, RCD.Dispatch<RCD.BasicStateAction<S>>) + local dispatcher = resolveDispatcher() + return dispatcher.useState(initialState) + end + )"; + + CheckResult result = frontend.check("game/React/React/ReactHooks"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_thing_from_roact") +{ + CheckResult result = check(R"( + type Map<K, V> = { [K]: V } + type Set<T> = { [T]: boolean } + + type FiberRoot = { + pingCache: Map<Wakeable, (Set<any> | Map<Wakeable, Set<any>>)> | nil, + } + + type Wakeable = { + andThen: (self: Wakeable) -> nil | Wakeable, + } + + local function attachPingListener(root: FiberRoot, wakeable: Wakeable, lanes: number) + local pingCache: Map<Wakeable, (Set<any> | Map<Wakeable, Set<any>>)> | nil = root.pingCache + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * It is sometimes possible for type alias resolution to produce a TypeId that + * belongs to a different module. + * + * We must not mutate any fields of the resulting type when this happens. The + * memory has been frozen.
+ */ +TEST_CASE_FIXTURE(BuiltinsFixture, "alias_expands_to_bare_reference_to_imported_type") +{ + fileResolver.source["game/A"] = R"( + --!strict + export type Object = {[string]: any} + return {} + )"; + + fileResolver.source["game/B"] = R"( + local A = require(script.Parent.A) + + type Object = A.Object + type ReadOnly<T> = T + + local function f(): ReadOnly<Object> + return nil :: any + end + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp index 49209a4d..79d9108d 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -9,7 +9,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauMatchReturnsOptionalString); TEST_SUITE_BEGIN("BuiltinTests"); @@ -1064,10 +1063,7 @@ TEST_CASE_FIXTURE(Fixture, "string_match") )"); LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - CHECK_EQ(toString(requireType("p")), "string?"); - else - CHECK_EQ(toString(requireType("p")), "string"); + CHECK_EQ(toString(requireType("p")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") @@ -1078,18 +1074,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(Fixture, "gmatch_capture_types2") @@ -1100,18 +1087,9 @@ TEST_CASE_FIXTURE(Fixture, "gmatch_capture_types2") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") @@ -1128,10 +1106,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") CHECK_EQ(acm->expected, 1); CHECK_EQ(acm->actual, 4); - if (FFlag::LuauMatchReturnsOptionalString) - CHECK_EQ(toString(requireType("a")), "string?"); - else - CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("a")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens") @@ -1148,18 +1123,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens CHECK_EQ(acm->expected, 3); CHECK_EQ(acm->actual, 4); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "string?"); - CHECK_EQ(toString(requireType("c")), "number?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "string"); -
CHECK_EQ(toString(requireType("c")), "number"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "string?"); + CHECK_EQ(toString(requireType("c")), "number?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_ignored") @@ -1176,16 +1142,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_igno CHECK_EQ(acm->expected, 2); CHECK_EQ(acm->actual, 3); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket") @@ -1196,16 +1154,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "number?"); - CHECK_EQ(toString(requireType("b")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "number"); - CHECK_EQ(toString(requireType("b")), "string"); - } + CHECK_EQ(toString(requireType("a")), "number?"); + CHECK_EQ(toString(requireType("b")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_leading_end_bracket_is_part_of_set") @@ -1253,18 +1203,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") @@ -1280,18 +1221,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") CHECK_EQ(toString(tm->wantedType), "number?"); CHECK_EQ(toString(tm->givenType), "string"); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); } TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") @@ -1302,18 +1234,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + 
CHECK_EQ(toString(requireType("c")), "string?"); CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } @@ -1331,18 +1254,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types2") CHECK_EQ(toString(tm->wantedType), "number?"); CHECK_EQ(toString(tm->givenType), "string"); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } @@ -1360,18 +1274,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") CHECK_EQ(toString(tm->wantedType), "boolean?"); CHECK_EQ(toString(tm->givenType), "string"); - if (FFlag::LuauMatchReturnsOptionalString) - { - CHECK_EQ(toString(requireType("a")), "string?"); - CHECK_EQ(toString(requireType("b")), "number?"); - CHECK_EQ(toString(requireType("c")), "string?"); - } - else - { - CHECK_EQ(toString(requireType("a")), "string"); - CHECK_EQ(toString(requireType("b")), "number"); - CHECK_EQ(toString(requireType("c")), "string"); - } + CHECK_EQ(toString(requireType("a")), "string?"); + CHECK_EQ(toString(requireType("b")), "number?"); + CHECK_EQ(toString(requireType("c")), "string?"); CHECK_EQ(toString(requireType("d")), "number?"); CHECK_EQ(toString(requireType("e")), "number?"); } diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index f3f46413..d6799757 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -78,7 +78,7 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_scope") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult parseFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult parseFailResult = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare foo )", "@test", /* captureComments */ false); @@ -88,7 +88,7 @@ TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_sc std::optional fooTy = tryGetGlobalBinding(frontend.globals, "foo"); CHECK(!fooTy.has_value()); - LoadDefinitionFileResult checkFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult checkFailResult = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( local foo: string = 123 declare bar: typeof(foo) )", @@ -140,7 +140,7 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_classes") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare class A X: number X: string @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") 
TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( type NotAClass = {} declare class Foo extends NotAClass @@ -182,7 +182,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare class Foo extends Bar end @@ -397,7 +397,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare class Channel Messages: { Message } OnMessage: (message: Message) -> () diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 482a6b7f..f1d42c6a 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1784,7 +1784,6 @@ z = y -- Not OK, so the line is colorable TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") { - ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; registerHiddenTypes(&frontend); CheckResult result = check(R"( @@ -1803,7 +1802,6 @@ TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") { - ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; registerHiddenTypes(&frontend); CheckResult result = check(R"( @@ -1824,7 +1822,6 @@ TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function") { - ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; registerHiddenTypes(&frontend); CheckResult result = check(R"( @@ -1860,7 +1857,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceede ScopedFastInt sfi{"LuauTarjanChildLimit", 2}; ScopedFastFlag sff[] = { {"DebugLuauDeferredConstraintResolution", true}, - {"LuauClonePublicInterfaceLess", true}, + {"LuauClonePublicInterfaceLess2", true}, {"LuauSubstitutionReentrant", true}, {"LuauSubstitutionFixMissingFields", true}, }; @@ -1880,4 +1877,33 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceede CHECK(Location({0, 0}, {4, 4}) == result.errors[1].location); } +/* We had a bug under DCR where instantiated type packs had a nullptr scope. + * + * This caused an issue with promotion. + */ +TEST_CASE_FIXTURE(Fixture, "instantiated_type_packs_must_have_a_non_null_scope") +{ + CheckResult result = check(R"( + function pcall(...: A...): R... 
+ end + + type Dispatch<A> = (A) -> () + + function mountReducer() + dispatchAction() + return nil :: any + end + + function dispatchAction() + end + + function useReducer(): Dispatch + local result, setResult = pcall(mountReducer) + return setResult + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp index b3b2e4c9..b9784817 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -874,7 +874,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_table_method") std::vector<TypeId> args = flatten(ftv->argTypes).first; TypeId argType = args.at(1); - CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); + CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); } TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") diff --git a/tests/TypeInfer.intersectionTypes.test.cpp index f6d04a95..b682e5f6 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -327,7 +327,12 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + auto e = toString(result.errors[0]); + // In DCR, because of type normalization, we print a different error message + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Cannot add property 'z' to table '{| x: number, y: number |}'", e); + else + CHECK_EQ("Cannot add property 'z' to table 'X & Y'", e); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") diff --git a/tests/TypeInfer.loops.test.cpp index 511cbc76..7a134358 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -707,4 +707,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil") CHECK(toString(requireType("makeEnum"), {true}) == "<a>({a}) -> {| [a]: a |}"); } +TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table") +{ + CheckResult result = check(R"( + function print(x) end + + function dump(tbl) + print(tbl.whatever) + for k, v in tbl do + print(k) + print(v) + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get<GenericError>(result.errors[0]); + REQUIRE(ge); + + CHECK("Cannot iterate over a table without indexer" == ge->message); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp index 0f540f68..f2b3d055 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -326,4 +326,84 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "flag_when_index_metamethod_returns_0_values" CHECK("nil" == toString(requireType("p"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "augmenting_an_unsealed_table_with_a_metatable") +{ + CheckResult result = check(R"( + local A = {number = 8} + + local B = setmetatable({}, A) + + function B:method() + return "hello!!"
+        end
+    )");
+
+    CHECK("{ @metatable { number: number }, { method: (a) -> string } }" == toString(requireType("B"), {true}));
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "react_style_oo")
+{
+    CheckResult result = check(R"(
+        local Prototype = {}
+
+        local ClassMetatable = {
+            __index = Prototype
+        }
+
+        local BaseClass = (setmetatable({}, ClassMetatable))
+
+        function BaseClass:extend(name)
+            local class = {
+                name=name
+            }
+
+            class.__index = class
+
+            function class.ctor(props)
+                return setmetatable({props=props}, class)
+            end
+
+            return setmetatable(class, getmetatable(self))
+        end
+
+        local C = BaseClass:extend('C')
+        local i = C.ctor({hello='world'})
+
+        local iName = i.name
+        local cName = C.name
+        local hello = i.props.hello
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+
+    CHECK("string" == toString(requireType("iName")));
+    CHECK("string" == toString(requireType("cName")));
+    CHECK("string" == toString(requireType("hello")));
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "cycle_between_object_constructor_and_alias")
+{
+    CheckResult result = check(R"(
+        local T = {}
+        T.__index = T
+
+        function T.new(): T
+            return setmetatable({}, T)
+        end
+
+        export type T = typeof(T.new())
+
+        return T
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+
+    auto module = getMainModule();
+
+    REQUIRE(module->exportedTypeBindings.count("T"));
+
+    TypeId aliasType = module->exportedTypeBindings["T"].type;
+    CHECK_MESSAGE(get(follow(aliasType)), "Expected metatable type but got: " << toString(aliasType));
+}
+
 TEST_SUITE_END();
diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp
index 720784c3..174bc310 100644
--- a/tests/TypeInfer.operators.test.cpp
+++ b/tests/TypeInfer.operators.test.cpp
@@ -526,7 +526,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error")
 
     LUAU_REQUIRE_ERROR_COUNT(1, result);
 
-    CHECK_EQ("string", toString(requireType("a")));
+    if (FFlag::DebugLuauDeferredConstraintResolution)
+    {
+        // Under DCR, this currently functions as a failed overload resolution, and so we can't say
+        // anything about the result type of the unary minus.
+        CHECK_EQ("any", toString(requireType("a")));
+    }
+    else
+    {
+
+        CHECK_EQ("string", toString(requireType("a")));
+    }
 
     TypeMismatch* tm = get(result.errors[0]);
     REQUIRE_EQ(*tm->wantedType, *builtinTypes->booleanType);
@@ -850,8 +860,6 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_
 
 TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible")
 {
-    ScopedFastFlag sff{"LuauIntersectionTestForEquality", true};
-
     CheckResult result = check(R"(
         local a: string | number = "hi"
         local b: {x: string}? = {x = "bye"}
@@ -960,8 +968,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or")
 
 TEST_CASE_FIXTURE(ClassFixture, "unrelated_classes_cannot_be_compared")
 {
-    ScopedFastFlag sff{"LuauIntersectionTestForEquality", true};
-
     CheckResult result = check(R"(
         local a = BaseClass.New()
         local b = UnrelatedClass.New()
@@ -974,8 +980,6 @@ TEST_CASE_FIXTURE(Fixture, "unrelated_primitives_cannot_be_compared")
 {
-    ScopedFastFlag sff{"LuauIntersectionTestForEquality", true};
-
     CheckResult result = check(R"(
         local c = 5 == true
     )");
diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp
index 30f77d68..87419deb 100644
--- a/tests/TypeInfer.provisional.test.cpp
+++ b/tests/TypeInfer.provisional.test.cpp
@@ -176,8 +176,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_othe
 // We need refine both operands as `never` in the `==` branch.
 TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap")
 {
-    ScopedFastFlag sff{"LuauIntersectionTestForEquality", true};
-
     CheckResult result = check(R"(
         local function f(a: string, b: boolean?)
             if a == b then
@@ -479,10 +477,10 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together")
 
     std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack);
 
-    TypeId free1 = arena.addType(FreeTypePack{scope.get()});
+    TypeId free1 = arena.addType(FreeType{scope.get()});
     TypeId option1 = arena.addType(UnionType{{nilType, free1}});
 
-    TypeId free2 = arena.addType(FreeTypePack{scope.get()});
+    TypeId free2 = arena.addType(FreeType{scope.get()});
     TypeId option2 = arena.addType(UnionType{{nilType, free2}});
 
     InternalErrorReporter iceHandler;
diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp
index 21ac6421..468adc2c 100644
--- a/tests/TypeInfer.tables.test.cpp
+++ b/tests/TypeInfer.tables.test.cpp
@@ -18,7 +18,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation);
 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
 LUAU_FASTFLAG(LuauInstantiateInSubtyping)
 LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError)
-LUAU_FASTFLAG(LuauDontExtendUnsealedRValueTables)
 
 TEST_SUITE_BEGIN("TableTests");
 
@@ -913,10 +912,7 @@ TEST_CASE_FIXTURE(Fixture, "disallow_indexing_into_an_unsealed_table_with_no_ind
         local k1 = getConstant("key1")
     )");
 
-    if (FFlag::LuauDontExtendUnsealedRValueTables)
-        CHECK("any" == toString(requireType("k1")));
-    else
-        CHECK("a" == toString(requireType("k1")));
+    CHECK("any" == toString(requireType("k1")));
 
     LUAU_REQUIRE_NO_ERRORS(result);
 }
@@ -3542,8 +3538,6 @@ _ = {_,}
 
 TEST_CASE_FIXTURE(Fixture, "when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type")
 {
-    ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true};
-
     CheckResult result = check(R"(
         local events = {}
         local mockObserveEvent = function(_, key, callback)
@@ -3572,8 +3566,6 @@ TEST_CASE_FIXTURE(Fixture, "when_augmenting_an_unsealed_table_with_an_indexer_ap
 
 TEST_CASE_FIXTURE(Fixture, "dont_extend_unsealed_tables_in_rvalue_position")
 {
-    ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true};
-
     CheckResult result = check(R"(
         local testDictionary = {
             FruitName = "Lemon",
diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp
index 7c4bfb2e..3088235a 100644
--- a/tests/TypeInfer.test.cpp
+++ b/tests/TypeInfer.test.cpp
@@ -490,8 +490,13 @@ struct FindFreeTypes
         return !foundOne;
     }
 
-    template
-    bool operator()(ID, Unifiable::Free)
+    bool operator()(TypeId, FreeType)
+    {
+        foundOne = true;
+        return false;
+    }
+
+    bool operator()(TypePackId, FreeTypePack)
     {
         foundOne = true;
         return false;
     }
@@ -1194,7 +1199,6 @@ TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_
 {
     ScopedFastFlag sff[] = {
         {"DebugLuauDeferredConstraintResolution", true},
-        {"LuauTinyUnifyNormalsFix", true},
         // If we run this with error-suppression, it triggers an assertion.
         // FATAL ERROR: Assertion failed: !"Internal error: Trying to normalize a BlockedType"
         {"LuauTransitiveSubtyping", false},
     };
diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp
index d49f0044..19a19e45 100644
--- a/tests/TypeInfer.unionTypes.test.cpp
+++ b/tests/TypeInfer.unionTypes.test.cpp
@@ -196,7 +196,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property")
     REQUIRE(bTy);
     CHECK_EQ(mup->missing[0], *bTy);
     CHECK_EQ(mup->key, "x");
-
     CHECK_EQ("*error-type*", toString(requireType("r")));
 }
 
@@ -354,7 +353,11 @@ a.x = 2
     )");
 
     LUAU_REQUIRE_ERROR_COUNT(1, result);
-    CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0]));
+    auto s = toString(result.errors[0]);
+    if (FFlag::DebugLuauDeferredConstraintResolution)
+        CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", s);
+    else
+        CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", s);
 }
 
 TEST_CASE_FIXTURE(Fixture, "optional_length_error")
diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp
index 20404434..7d8ed38f 100644
--- a/tests/TypePack.test.cpp
+++ b/tests/TypePack.test.cpp
@@ -25,7 +25,7 @@ struct TypePackFixture
     TypePackId freshTypePack()
     {
-        typePacks.emplace_back(new TypePackVar{Unifiable::Free{TypeLevel{}}});
+        typePacks.emplace_back(new TypePackVar{FreeTypePack{TypeLevel{}}});
         return typePacks.back().get();
     }
 
diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp
index 64ba63c8..3f0becc5 100644
--- a/tests/TypeVar.test.cpp
+++ b/tests/TypeVar.test.cpp
@@ -74,7 +74,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_not_just
 TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_free")
 {
     auto emptyArgumentPack = TypePackVar{TypePack{}};
-    auto free = Unifiable::Free(TypeLevel());
+    auto free = FreeTypePack(TypeLevel());
     auto freePack = TypePackVar{TypePackVariant{free}};
     auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType}, &freePack}};
     auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack));
diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua
index 27416623..8db62d96 100644
--- a/tests/conformance/apicalls.lua
+++ b/tests/conformance/apicalls.lua
@@ -22,4 +22,12 @@ function getpi()
     return pi
 end
 
+function largealloc()
+    table.create(1000000)
+end
+
+function oops()
+    return "oops"
+end
+
 return('OK')
diff --git a/tests/conformance/interrupt.lua b/tests/conformance/interrupt.lua
index d4b7c80a..c07f57e7 100644
--- a/tests/conformance/interrupt.lua
+++ b/tests/conformance/interrupt.lua
@@ -17,4 +17,9 @@ end
 
 bar()
 
+function baz()
+end
+
+baz()
+
 return "OK"
diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua
index 18ed1370..ea3b5c87 100644
--- a/tests/conformance/math.lua
+++ b/tests/conformance/math.lua
@@ -345,5 +345,7 @@ assert(math.round("1.8") == 2)
 assert(select('#', math.floor(1.4)) == 1)
 assert(select('#', math.ceil(1.6)) == 1)
 assert(select('#', math.sqrt(9)) == 1)
+assert(select('#', math.deg(9)) == 1)
+assert(select('#', math.rad(9)) == 1)
 
 return('OK')
diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua
index 969209fc..b94f7972 100644
--- a/tests/conformance/pcall.lua
+++ b/tests/conformance/pcall.lua
@@ -161,4 +161,11 @@ checkresults({ false, "ok" }, xpcall(recurse, function() return string.reverse("
 -- however, if xpcall handler itself runs out of extra stack space, we get "error in error handling"
 checkresults({ false, "error in error handling" }, xpcall(recurse, function() return recurse(calllimit) end, calllimit - 2))
 
+-- simulate OOM and make sure we can catch it with pcall or xpcall
+checkresults({ false, "not enough memory" }, pcall(function() table.create(1e6) end))
+checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) return e end))
+checkresults({ false, "oops" }, xpcall(function() table.create(1e6) end, function(e) return "oops" end))
+checkresults({ false, "error in error handling" }, xpcall(function() error("oops") end, function(e) table.create(1e6) end))
+checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) table.create(1e6) end))
+
 return 'OK'
diff --git a/tests/conformance/sort.lua b/tests/conformance/sort.lua
index 693a10dc..3c2c20dd 100644
--- a/tests/conformance/sort.lua
+++ b/tests/conformance/sort.lua
@@ -99,12 +99,12 @@ a = {"
 table.sort(a)
 check(a)
 
--- TODO: assert that pcall returns false for new sort implementation (table is modified during sorting)
-pcall(table.sort, a, function (x, y)
+local ok = pcall(table.sort, a, function (x, y)
   loadstring(string.format("a[%q] = ''", x))()
   collectgarbage()
   return x
[The tail of this hunk, plus an adjacent natvis diff whose header was lost, were mangled by markup stripping; only DisplayString fragments survive: x64 register names (noreg, rip, al/cl/dl/bl, eax-edi, rax-rdi, e{(int)index,d}d, r{(int)index,d}, xmm{(int)index,d}, ymm{(int)index,d}) and memory-operand formats such as "{memSize,en} ptr[{base} + {index}*{(int)scale,d} + {imm}]".]
diff --git a/tools/natvis/Common.natvis b/tools/natvis/Common.natvis
new file mode 100644
index 00000000..fe3a96d5
--- /dev/null
+++ b/tools/natvis/Common.natvis
@@ -0,0 +1,27 @@
[The 27 added lines of natvis XML were likewise stripped of their markup; the surviving fragments reference count, capacity, data, and impl fields.]
diff --git a/tools/test_dcr.py b/tools/test_dcr.py
index d30490b3..817d0831 100644
--- a/tools/test_dcr.py
+++ b/tools/test_dcr.py
@@ -107,6 +107,12 @@ def main():
         action="store_true",
         help="Write a new faillist.txt after running tests.",
     )
+    parser.add_argument(
+        "--lti",
+        dest="lti",
+        action="store_true",
+        help="Run the tests with local type inference enabled.",
+    )
 
     parser.add_argument("--randomize", action="store_true", help="Pick a random seed")
 
@@ -120,13 +126,19 @@ def main():
 
     args = parser.parse_args()
 
+    if args.write and args.lti:
+        print_stderr(
+            "Cannot run test_dcr.py with --write *and* --lti. You don't want to commit local type inference faillist.txt yet."
+        )
+        sys.exit(1)
+
     failList = loadFailList()
 
-    commandLine = [
-        args.path,
-        "--reporters=xml",
-        "--fflags=true,DebugLuauDeferredConstraintResolution=true",
-    ]
+    flags = ["true", "DebugLuauDeferredConstraintResolution"]
+    if args.lti:
+        flags.append("DebugLuauLocalTypeInference")
+
+    commandLine = [args.path, "--reporters=xml", "--fflags=" + ",".join(flags)]
 
     if args.random_seed:
         commandLine.append("--random-seed=" + str(args.random_seed))
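
For reference, a minimal standalone sketch (not part of the patch) of what the flag assembly in tools/test_dcr.py produces once --lti is passed; the binary path "build/luau-tests" below is a placeholder, not something the script defines:

    # Sketch only: mirrors the --fflags list construction added above.
    flags = ["true", "DebugLuauDeferredConstraintResolution"]
    lti = True  # stands in for args.lti
    if lti:
        flags.append("DebugLuauLocalTypeInference")

    # "build/luau-tests" is an assumed stand-in for args.path.
    commandLine = ["build/luau-tests", "--reporters=xml", "--fflags=" + ",".join(flags)]
    print(commandLine)
    # ['build/luau-tests', '--reporters=xml',
    #  '--fflags=true,DebugLuauDeferredConstraintResolution,DebugLuauLocalTypeInference']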