diff --git a/Analysis/include/Luau/Breadcrumb.h b/Analysis/include/Luau/Breadcrumb.h new file mode 100644 index 00000000..59b293a0 --- /dev/null +++ b/Analysis/include/Luau/Breadcrumb.h @@ -0,0 +1,75 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Def.h" +#include "Luau/NotNull.h" +#include "Luau/Variant.h" + +#include +#include + +namespace Luau +{ + +using NullableBreadcrumbId = const struct Breadcrumb*; +using BreadcrumbId = NotNull; + +struct FieldMetadata +{ + std::string prop; +}; + +struct SubscriptMetadata +{ + BreadcrumbId key; +}; + +using Metadata = Variant; + +struct Breadcrumb +{ + NullableBreadcrumbId previous; + DefId def; + std::optional metadata; + std::vector children; +}; + +inline Breadcrumb* asMutable(NullableBreadcrumbId breadcrumb) +{ + LUAU_ASSERT(breadcrumb); + return const_cast(breadcrumb); +} + +template +const T* getMetadata(NullableBreadcrumbId breadcrumb) +{ + if (!breadcrumb || !breadcrumb->metadata) + return nullptr; + + return get_if(&*breadcrumb->metadata); +} + +struct BreadcrumbArena +{ + TypedAllocator allocator; + + template + BreadcrumbId add(NullableBreadcrumbId previous, DefId def, Args&&... args) + { + Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, std::forward(args)...}); + if (previous) + asMutable(previous)->children.push_back(NotNull{bc}); + return NotNull{bc}; + } + + template + BreadcrumbId emplace(NullableBreadcrumbId previous, DefId def, Args&&... args) + { + Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, Metadata{T{std::forward(args)...}}}); + if (previous) + asMutable(previous)->children.push_back(NotNull{bc}); + return NotNull{bc}; + } +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 1c41bbb7..2223c29e 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -2,7 +2,6 @@ #pragma once #include "Luau/Ast.h" // Used for some of the enumerations -#include "Luau/Def.h" #include "Luau/DenseHash.h" #include "Luau/NotNull.h" #include "Luau/Type.h" diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 7b2711f8..e79c4c91 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -224,7 +224,7 @@ struct ConstraintGraphBuilder * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type of the AST annotation. **/ - TypeId resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments); + TypeId resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh = false); /** * Resolves a type pack from its AST annotation. @@ -233,7 +233,7 @@ struct ConstraintGraphBuilder * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type pack of the AST annotation. **/ - TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments); + TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments, bool replaceErrorWithFresh = false); /** * Resolves a type pack from its AST annotation. @@ -242,7 +242,7 @@ struct ConstraintGraphBuilder * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type pack of the AST annotation. **/ - TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments); + TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh = false); /** * Creates generic types given a list of AST definitions, resolving default @@ -282,10 +282,17 @@ struct ConstraintGraphBuilder * initial scan of the AST and note what globals are defined. */ void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); + + /** Given a function type annotation, return a vector describing the expected types of the calls to the function + * For example, calling a function with annotation ((number) -> string & ((string) -> number)) + * yields a vector of size 1, with value: [number | string] + */ + std::vector> getExpectedCallTypesForFunctionOverloads(const TypeId fnType); }; /** Borrow a vector of pointers from a vector of owning pointers to constraints. */ std::vector> borrowConstraints(const std::vector& constraints); + } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 62687ae4..4fd7d0d1 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -132,8 +132,8 @@ struct ConstraintSolver bool tryDispatchIterableFunction( TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); - std::optional lookupTableProp(TypeId subjectType, const std::string& propName); - std::optional lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); + std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName); + std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); void block(NotNull target, NotNull constraint); /** diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index bd096ea9..ce4ecb04 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -3,6 +3,7 @@ // Do not include LValue. It should never be used here. #include "Luau/Ast.h" +#include "Luau/Breadcrumb.h" #include "Luau/DenseHash.h" #include "Luau/Def.h" #include "Luau/Symbol.h" @@ -17,16 +18,14 @@ struct DataFlowGraph DataFlowGraph(DataFlowGraph&&) = default; DataFlowGraph& operator=(DataFlowGraph&&) = default; - // TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt. - // We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId. - std::optional getDef(const AstExpr* expr) const; - std::optional getDef(const AstLocal* local) const; + NullableBreadcrumbId getBreadcrumb(const AstExpr* expr) const; - /// Retrieve the Def that corresponds to the given Symbol. - /// - /// We do not perform dataflow analysis on globals, so this function always - /// yields nullopt when passed a global Symbol. - std::optional getDef(const Symbol& symbol) const; + BreadcrumbId getBreadcrumb(const AstLocal* local) const; + BreadcrumbId getBreadcrumb(const AstExprLocal* local) const; + BreadcrumbId getBreadcrumb(const AstExprGlobal* global) const; + + BreadcrumbId getBreadcrumb(const AstStatDeclareGlobal* global) const; + BreadcrumbId getBreadcrumb(const AstStatDeclareFunction* func) const; private: DataFlowGraph() = default; @@ -34,9 +33,17 @@ private: DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete; - DefArena arena; - DenseHashMap astDefs{nullptr}; - DenseHashMap localDefs{nullptr}; + DefArena defs; + BreadcrumbArena breadcrumbs; + + DenseHashMap astBreadcrumbs{nullptr}; + + // Sometimes we don't have the AstExprLocal* but we have AstLocal*, and sometimes we need to extract that DefId. + DenseHashMap localBreadcrumbs{nullptr}; + + // There's no AstStatDeclaration, and it feels useless to introduce it just to enforce an invariant in one place. + // All keys in this maps are really only statements that ambiently declares a symbol. + DenseHashMap declaredBreadcrumbs{nullptr}; friend struct DataFlowGraphBuilder; }; @@ -44,12 +51,11 @@ private: struct DfgScope { DfgScope* parent; - DenseHashMap bindings{Symbol{}}; -}; + DenseHashMap bindings{Symbol{}}; + DenseHashMap> props{nullptr}; -struct ExpressionFlowGraph -{ - std::optional def; + NullableBreadcrumbId lookup(Symbol symbol) const; + NullableBreadcrumbId lookup(DefId def, const std::string& key) const; }; // Currently unsound. We do not presently track the control flow of the program. @@ -65,23 +71,19 @@ private: DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraph graph; - NotNull arena{&graph.arena}; - struct InternalErrorReporter* handle; + NotNull defs{&graph.defs}; + NotNull breadcrumbs{&graph.breadcrumbs}; + + struct InternalErrorReporter* handle = nullptr; + DfgScope* moduleScope = nullptr; + std::vector> scopes; - // Does not belong in DataFlowGraphBuilder, but the old solver allows properties to escape the scope they were defined in, - // so we will need to be able to emulate this same behavior here too. We can kill this once we have better flow sensitivity. - DenseHashMap> props{nullptr}; - DfgScope* childScope(DfgScope* scope); - std::optional use(DfgScope* scope, Symbol symbol, AstExpr* e); - DefId use(DefId def, AstExprIndexName* e); - void visit(DfgScope* scope, AstStatBlock* b); void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); - // TODO: visit type aliases void visit(DfgScope* scope, AstStat* s); void visit(DfgScope* scope, AstStatIf* i); void visit(DfgScope* scope, AstStatWhile* w); @@ -97,24 +99,52 @@ private: void visit(DfgScope* scope, AstStatCompoundAssign* c); void visit(DfgScope* scope, AstStatFunction* f); void visit(DfgScope* scope, AstStatLocalFunction* l); + void visit(DfgScope* scope, AstStatTypeAlias* t); + void visit(DfgScope* scope, AstStatDeclareGlobal* d); + void visit(DfgScope* scope, AstStatDeclareFunction* d); + void visit(DfgScope* scope, AstStatDeclareClass* d); + void visit(DfgScope* scope, AstStatError* error); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i); - ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i); + BreadcrumbId visitExpr(DfgScope* scope, AstExpr* e); + BreadcrumbId visitExpr(DfgScope* scope, AstExprLocal* l); + BreadcrumbId visitExpr(DfgScope* scope, AstExprGlobal* g); + BreadcrumbId visitExpr(DfgScope* scope, AstExprCall* c); + BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexName* i); + BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexExpr* i); + BreadcrumbId visitExpr(DfgScope* scope, AstExprFunction* f); + BreadcrumbId visitExpr(DfgScope* scope, AstExprTable* t); + BreadcrumbId visitExpr(DfgScope* scope, AstExprUnary* u); + BreadcrumbId visitExpr(DfgScope* scope, AstExprBinary* b); + BreadcrumbId visitExpr(DfgScope* scope, AstExprTypeAssertion* t); + BreadcrumbId visitExpr(DfgScope* scope, AstExprIfElse* i); + BreadcrumbId visitExpr(DfgScope* scope, AstExprInterpString* i); + BreadcrumbId visitExpr(DfgScope* scope, AstExprError* error); - // TODO: visitLValue - // TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope) - // TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope) + void visitLValue(DfgScope* scope, AstExpr* e); + void visitLValue(DfgScope* scope, AstExprLocal* l); + void visitLValue(DfgScope* scope, AstExprGlobal* g); + void visitLValue(DfgScope* scope, AstExprIndexName* i); + void visitLValue(DfgScope* scope, AstExprIndexExpr* i); + void visitLValue(DfgScope* scope, AstExprError* e); + + void visitType(DfgScope* scope, AstType* t); + void visitType(DfgScope* scope, AstTypeReference* r); + void visitType(DfgScope* scope, AstTypeTable* t); + void visitType(DfgScope* scope, AstTypeFunction* f); + void visitType(DfgScope* scope, AstTypeTypeof* t); + void visitType(DfgScope* scope, AstTypeUnion* u); + void visitType(DfgScope* scope, AstTypeIntersection* i); + void visitType(DfgScope* scope, AstTypeError* error); + + void visitTypePack(DfgScope* scope, AstTypePack* p); + void visitTypePack(DfgScope* scope, AstTypePackExplicit* e); + void visitTypePack(DfgScope* scope, AstTypePackVariadic* v); + void visitTypePack(DfgScope* scope, AstTypePackGeneric* g); + + void visitTypeList(DfgScope* scope, AstTypeList l); + + void visitGenerics(DfgScope* scope, AstArray g); + void visitGenericPacks(DfgScope* scope, AstArray g); }; } // namespace Luau diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h index 1eef7dfd..10d81367 100644 --- a/Analysis/include/Luau/Def.h +++ b/Analysis/include/Luau/Def.h @@ -14,12 +14,6 @@ namespace Luau struct Def; using DefId = NotNull; -struct FieldMetadata -{ - DefId parent; - std::string propName; -}; - /** * A cell is a "single-object" value. * @@ -29,7 +23,6 @@ struct FieldMetadata */ struct Cell { - std::optional field; }; /** @@ -83,7 +76,6 @@ struct DefArena TypedAllocator allocator; DefId freshCell(); - DefId freshCell(DefId parent, const std::string& prop); // TODO: implement once we have cases where we need to merge in definitions // DefId phi(const std::vector& defs); }; diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 403551f6..7c5dc4a0 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -162,8 +162,7 @@ struct Frontend ScopePtr getGlobalScope(); private: - ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, - bool forAutocomplete = false); + ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete = false, bool recordJsonLog = false); std::pair getSourceNode(const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); @@ -202,16 +201,12 @@ private: ScopePtr globalScope; }; -ModulePtr check( - const SourceModule& sourceModule, - const std::vector& requireCycles, - NotNull builtinTypes, - NotNull iceHandler, - NotNull moduleResolver, - NotNull fileResolver, - const ScopePtr& globalScope, - NotNull unifierState, - FrontendOptions options -); +ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, + NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, + const ScopePtr& globalScope, FrontendOptions options); + +ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, + NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, + const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog); } // namespace Luau diff --git a/Analysis/include/Luau/Refinement.h b/Analysis/include/Luau/Refinement.h index e7d3cf23..fecf459a 100644 --- a/Analysis/include/Luau/Refinement.h +++ b/Analysis/include/Luau/Refinement.h @@ -1,13 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Def.h" +#include "Luau/NotNull.h" #include "Luau/TypedAllocator.h" #include "Luau/Variant.h" namespace Luau { +using BreadcrumbId = NotNull; + struct Type; using TypeId = const Type*; @@ -50,7 +52,7 @@ struct Equivalence struct Proposition { - DefId def; + BreadcrumbId breadcrumb; TypeId discriminantTy; }; @@ -67,7 +69,7 @@ struct RefinementArena RefinementId conjunction(RefinementId lhs, RefinementId rhs); RefinementId disjunction(RefinementId lhs, RefinementId rhs); RefinementId equivalence(RefinementId lhs, RefinementId rhs); - RefinementId proposition(DefId def, TypeId discriminantTy); + RefinementId proposition(BreadcrumbId breadcrumb, TypeId discriminantTy); private: TypedAllocator allocator; diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 85a36fc9..0d397267 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Def.h" #include "Luau/Location.h" #include "Luau/NotNull.h" #include "Luau/Type.h" diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index d009001b..cf1f8dae 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -4,9 +4,7 @@ #include "Luau/Ast.h" #include "Luau/Common.h" #include "Luau/Refinement.h" -#include "Luau/DataFlowGraph.h" #include "Luau/DenseHash.h" -#include "Luau/Def.h" #include "Luau/NotNull.h" #include "Luau/Predicate.h" #include "Luau/Unifiable.h" @@ -662,6 +660,7 @@ public: const TypeId functionType; const TypeId classType; const TypeId tableType; + const TypeId emptyTableType; const TypeId trueType; const TypeId falseType; const TypeId anyType; diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h index 80a7ac59..3f64870a 100644 --- a/Analysis/include/Luau/TypeReduction.h +++ b/Analysis/include/Luau/TypeReduction.h @@ -54,8 +54,8 @@ struct TypeReductionOptions struct TypeReduction { - explicit TypeReduction( - NotNull arena, NotNull builtinTypes, NotNull handle, const TypeReductionOptions& opts = {}); + explicit TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle, + const TypeReductionOptions& opts = {}); TypeReduction(const TypeReduction&) = delete; TypeReduction& operator=(const TypeReduction&) = delete; diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 3f535a03..42ba4052 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -27,7 +27,8 @@ std::pair> getParameterExtents(const TxnLog* log, // Extend the provided pack to at least `length` types. // Returns a temporary TypePack that contains those types plus a tail. -TypePack extendTypePack(TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length); +TypePack extendTypePack( + TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length, std::vector> overrides = {}); /** * Reduces a union by decomposing to the any/error type if it appears in the diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index ebfff4c2..50024e3f 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -61,8 +61,9 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; - bool normalize; // Normalize unions and intersections if necessary - bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels + bool normalize = true; // Normalize unions and intersections if necessary + bool checkInhabited = true; // Normalize types to check if they are inhabited + bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; @@ -155,5 +156,6 @@ private: }; void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); +std::optional hasUnificationTooComplex(const ErrorVec& errors); } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 85e27168..1e094971 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInWhile, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInFor, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -145,6 +146,13 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); + if (FFlag::LuauAutocompleteSkipNormalization) + { + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; + } + return unifier.canUnify(subTy, superTy).empty(); } @@ -314,7 +322,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul { autocompleteProps(module, typeArena, builtinTypes, rootTy, mt->table, indexType, nodes, result, seen); - if (auto mtable = get(mt->metatable)) + if (auto mtable = get(follow(mt->metatable))) fillMetatableProps(mtable); } else if (auto i = get(ty)) @@ -1528,9 +1536,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; } - else if (AstStatIf* statIf = extractStat(ancestry); - statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && - (statIf->condition && !statIf->condition->location.containsClosed(position))) + else if (AstStatIf* statIf = extractStat(ancestry); statIf && + (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && + (statIf->condition && !statIf->condition->location.containsClosed(position))) { AutocompleteEntryMap ret; ret["then"] = {AutocompleteEntryKind::Keyword}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 7bb57208..b111c504 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -15,6 +15,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauDeprecateTableGetnForeach, false) + /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk @@ -335,6 +337,14 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + if (FFlag::LuauDeprecateTableGetnForeach) + { + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + ttv->props["foreach"].deprecated = true; + ttv->props["foreachi"].deprecated = true; + } + attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); } @@ -428,6 +438,14 @@ void registerBuiltinGlobals(Frontend& frontend) ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + if (FFlag::LuauDeprecateTableGetnForeach) + { + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + ttv->props["foreach"].deprecated = true; + ttv->props["foreachi"].deprecated = true; + } + attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index fe412632..9ee2b088 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -2,11 +2,13 @@ #include "Luau/ConstraintGraphBuilder.h" #include "Luau/Ast.h" +#include "Luau/Breadcrumb.h" #include "Luau/Common.h" #include "Luau/Constraint.h" #include "Luau/DcrLogger.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" +#include "Luau/Refinement.h" #include "Luau/Scope.h" #include "Luau/TypeUtils.h" #include "Luau/Type.h" @@ -14,7 +16,6 @@ #include LUAU_FASTINT(LuauCheckRecursionLimit); -LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauNegatedClassTypes); LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration); @@ -145,9 +146,6 @@ ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, Mod , globalScope(globalScope) , logger(logger) { - if (FFlag::DebugLuauLogSolverToJson) - LUAU_ASSERT(logger); - LUAU_ASSERT(module); } @@ -186,29 +184,42 @@ NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, return NotNull{constraints.emplace_back(std::move(c)).get()}; } -static void unionRefinements(const std::unordered_map& lhs, const std::unordered_map& rhs, - std::unordered_map& dest, NotNull arena) +struct RefinementPartition { - for (auto [def, ty] : lhs) + // Types that we want to intersect against the type of the expression. + std::vector discriminantTypes; + + // Sometimes the type we're discriminating against is implicitly nil. + bool shouldAppendNilType = false; +}; + +using RefinementContext = std::unordered_map; + +static void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, NotNull arena) +{ + for (auto& [def, partition] : lhs) { auto rhsIt = rhs.find(def); if (rhsIt == rhs.end()) continue; - std::vector discriminants{{ty, rhsIt->second}}; + LUAU_ASSERT(!partition.discriminantTypes.empty()); + LUAU_ASSERT(!rhsIt->second.discriminantTypes.empty()); - if (auto destIt = dest.find(def); destIt != dest.end()) - discriminants.push_back(destIt->second); + TypeId leftDiscriminantTy = + partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : arena->addType(IntersectionType{partition.discriminantTypes}); - dest[def] = arena->addType(UnionType{std::move(discriminants)}); + TypeId rightDiscriminantTy = rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] + : arena->addType(IntersectionType{rhsIt->second.discriminantTypes}); + + dest[def].discriminantTypes.push_back(arena->addType(UnionType{{leftDiscriminantTy, rightDiscriminantTy}})); + dest[def].shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; } } -static void computeRefinement(const ScopePtr& scope, RefinementId refinement, std::unordered_map* refis, bool sense, - NotNull arena, bool eq, std::vector* constraints) +static void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, NotNull arena, bool eq, + std::vector* constraints) { - using RefinementMap = std::unordered_map; - if (!refinement) return; else if (auto variadic = get(refinement)) @@ -220,8 +231,8 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, st return computeRefinement(scope, negation->refinement, refis, !sense, arena, eq, constraints); else if (auto conjunction = get(refinement)) { - RefinementMap lhsRefis; - RefinementMap rhsRefis; + RefinementContext lhsRefis; + RefinementContext rhsRefis; computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); @@ -231,8 +242,8 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, st } else if (auto disjunction = get(refinement)) { - RefinementMap lhsRefis; - RefinementMap rhsRefis; + RefinementContext lhsRefis; + RefinementContext rhsRefis; computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); @@ -256,50 +267,59 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, st constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy, !sense}); } - if (auto it = refis->find(proposition->def); it != refis->end()) - (*refis)[proposition->def] = arena->addType(IntersectionType{{discriminantTy, it->second}}); - else - (*refis)[proposition->def] = discriminantTy; + RefinementContext uncommittedRefis; + uncommittedRefis[proposition->breadcrumb->def].discriminantTypes.push_back(discriminantTy); + + // When the top-level expression is `t[x]`, we want to refine it into `nil`, not `never`. + if ((sense || !eq) && getMetadata(proposition->breadcrumb)) + uncommittedRefis[proposition->breadcrumb->def].shouldAppendNilType = true; + + for (NullableBreadcrumbId current = proposition->breadcrumb; current && current->previous; current = current->previous) + { + LUAU_ASSERT(get(current->def)); + + // If this current breadcrumb has no metadata, it's no-op for the purpose of building a discriminant type. + if (!current->metadata) + continue; + else if (auto field = getMetadata(current)) + { + TableType::Props props{{field->prop, Property{discriminantTy}}}; + discriminantTy = arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); + uncommittedRefis[current->previous->def].discriminantTypes.push_back(discriminantTy); + } + } + + // And now it's time to commit it. + for (auto& [def, partition] : uncommittedRefis) + { + for (TypeId discriminantTy : partition.discriminantTypes) + (*refis)[def].discriminantTypes.push_back(discriminantTy); + + (*refis)[def].shouldAppendNilType |= partition.shouldAppendNilType; + } } } -static std::pair computeDiscriminantType(NotNull arena, const ScopePtr& scope, DefId def, TypeId discriminantTy) -{ - LUAU_ASSERT(get(def)); - - while (const Cell* current = get(def)) - { - if (!current->field) - break; - - TableType::Props props{{current->field->propName, Property{discriminantTy}}}; - discriminantTy = arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); - - def = current->field->parent; - current = get(def); - } - - return {def, discriminantTy}; -} - void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement) { if (!refinement) return; - std::unordered_map refinements; + RefinementContext refinements; std::vector constraints; computeRefinement(scope, refinement, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); - for (auto [def, discriminantTy] : refinements) + for (auto& [def, partition] : refinements) { - auto [def2, discriminantTy2] = computeDiscriminantType(arena, scope, def, discriminantTy); - std::optional defTy = scope->lookup(def2); - if (!defTy) - ice->ice("Every DefId must map to a type!"); + if (std::optional defTy = scope->lookup(def)) + { + TypeId ty = *defTy; + if (partition.shouldAppendNilType) + ty = arena->addType(UnionType{{ty, builtinTypes->nilType}}); - TypeId resultTy = arena->addType(IntersectionType{{*defTy, discriminantTy2}}); - scope->dcrRefinements[def2] = resultTy; + partition.discriminantTypes.push_back(ty); + scope->dcrRefinements[def] = arena->addType(IntersectionType{std::move(partition.discriminantTypes)}); + } } for (auto& c : constraints) @@ -321,7 +341,7 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) visitBlockWithoutChildScope(scope, block); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->captureGenerationModule(module); } @@ -543,8 +563,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) // HACK: In the greedy solver, we say the type state of a variable is the type annotation itself, but // the actual type state is the corresponding initializer expression (if it exists) or nil otherwise. - if (auto def = dfg->getDef(l)) - scope->dcrRefinements[*def] = varTypes[i]; + BreadcrumbId bc = dfg->getBreadcrumb(l); + scope->dcrRefinements[bc->def] = varTypes[i]; } if (local->values.size > 0) @@ -578,8 +598,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) { + TypeId annotationTy = builtinTypes->numberType; if (for_->var->annotation) - resolveType(scope, for_->var->annotation, /* inTypeArguments */ false); + annotationTy = resolveType(scope, for_->var->annotation, /* inTypeArguments */ false); auto inferNumber = [&](AstExpr* expr) { if (!expr) @@ -594,7 +615,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) inferNumber(for_->step); ScopePtr forScope = childScope(for_, scope); - forScope->bindings[for_->var] = Binding{builtinTypes->numberType, for_->var->location}; + forScope->bindings[for_->var] = Binding{annotationTy, for_->var->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(for_->var); + forScope->dcrRefinements[bc->def] = annotationTy; visit(forScope, for_->body); } @@ -613,8 +637,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) loopScope->bindings[var] = Binding{ty, var->location}; variableTypes.push_back(ty); - if (auto def = dfg->getDef(var)) - loopScope->dcrRefinements[*def] = ty; + BreadcrumbId bc = dfg->getBreadcrumb(var); + loopScope->dcrRefinements[bc->def] = ty; } // It is always ok to provide too few variables, so we give this pack a free tail. @@ -638,10 +662,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) { ScopePtr repeatScope = childScope(repeat, scope); - visit(repeatScope, repeat->body); + visitBlockWithoutChildScope(repeatScope, repeat->body); - // The condition does indeed have access to bindings from within the body of - // the loop. check(repeatScope, repeat->condition); } @@ -662,6 +684,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* FunctionSignature sig = checkFunctionSignature(scope, function->func); sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; + BreadcrumbId bc = dfg->getBreadcrumb(function->name); + scope->dcrRefinements[bc->def] = functionType; + sig.bodyScope->dcrRefinements[bc->def] = sig.signature; + Checkpoint start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); Checkpoint end = checkpoint(this); @@ -697,10 +723,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct addConstraint(scope, function->name->location, SubtypeConstraint{generalizedType, *existingFunctionTy}); Symbol sym{localName->local}; - std::optional def = dfg->getDef(sym); - LUAU_ASSERT(def); scope->bindings[sym].typeId = generalizedType; - scope->dcrRefinements[*def] = generalizedType; } else scope->bindings[localName->local] = Binding{generalizedType, localName->location}; @@ -742,6 +765,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct if (generalizedType == nullptr) ice->ice("generalizedType == nullptr", function->location); + if (NullableBreadcrumbId bc = dfg->getBreadcrumb(function->name)) + scope->dcrRefinements[bc->def] = generalizedType; + checkFunctionBody(sig.bodyScope, function->func); Checkpoint end = checkpoint(this); @@ -821,19 +847,19 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* { // We need to tweak the BinaryConstraint that we emit, so we cannot use the // strategy of falsifying an AST fragment. - TypeId varId = checkLValue(scope, assign->var); - Inference valueInf = check(scope, assign->value); + TypeId varTy = checkLValue(scope, assign->var); + TypeId valueTy = check(scope, assign->value).ty; TypeId resultType = arena->addType(BlockedType{}); addConstraint(scope, assign->location, - BinaryConstraint{assign->op, varId, valueInf.ty, resultType, assign, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); - addConstraint(scope, assign->location, SubtypeConstraint{resultType, varId}); + BinaryConstraint{assign->op, varTy, valueTy, resultType, assign, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); + addConstraint(scope, assign->location, SubtypeConstraint{resultType, varTy}); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { ScopePtr condScope = childScope(ifStatement->condition, scope); - auto [_, refinement] = check(condScope, ifStatement->condition, std::nullopt); + RefinementId refinement = check(condScope, ifStatement->condition, std::nullopt).refinement; ScopePtr thenScope = childScope(ifStatement->thenbody, scope); applyRefinements(thenScope, ifStatement->condition->location, refinement); @@ -921,7 +947,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* Name globalName(global->name.value); module->declaredGlobals[globalName] = globalTy; - scope->bindings[global->name] = Binding{globalTy, global->location}; + rootScope->bindings[global->name] = Binding{globalTy, global->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(global); + rootScope->dcrRefinements[bc->def] = globalTy; } static bool isMetamethod(const Name& name) @@ -1067,6 +1096,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction module->declaredGlobals[fnName] = fnType; scope->bindings[global->name] = Binding{fnType, global->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(global); + rootScope->dcrRefinements[bc->def] = fnType; } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) @@ -1158,10 +1190,10 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa exprArgs.push_back(indexExpr->expr); - if (auto def = dfg->getDef(indexExpr->expr)) + if (auto bc = dfg->getBreadcrumb(indexExpr->expr)) { TypeId discriminantTy = arena->addType(BlockedType{}); - returnRefinements.push_back(refinementArena.proposition(*def, discriminantTy)); + returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); discriminantTypes.push_back(discriminantTy); } else @@ -1172,10 +1204,10 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa { exprArgs.push_back(arg); - if (auto def = dfg->getDef(arg)) + if (auto bc = dfg->getBreadcrumb(arg)) { TypeId discriminantTy = arena->addType(BlockedType{}); - returnRefinements.push_back(refinementArena.proposition(*def, discriminantTy)); + returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); discriminantTypes.push_back(discriminantTy); } else @@ -1186,6 +1218,8 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypeId fnType = check(scope, call->func).ty; Checkpoint fnEndCheckpoint = checkpoint(this); + std::vector> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); + module->astOriginalCallTypes[call->func] = fnType; TypePackId expectedArgPack = arena->freshTypePack(scope.get()); @@ -1208,9 +1242,9 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypePack expectedArgs; if (!needTail) - expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size()); + expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size(), expectedTypesForCall); else - expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size() - 1); + expectedArgs = extendTypePack(*arena, builtinTypes, expectedArgPack, exprArgs.size() - 1, expectedTypesForCall); std::vector args; std::optional argTail; @@ -1278,9 +1312,9 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa if (AstExprLocal* targetLocal = targetExpr->as()) { scope->bindings[targetLocal->local].typeId = resultTy; - auto def = dfg->getDef(targetLocal->local); - if (def) - scope->dcrRefinements[*def] = resultTy; // TODO: typestates: track this as an assignment + + BreadcrumbId bc = dfg->getBreadcrumb(targetLocal); + scope->dcrRefinements[bc->def] = resultTy; // TODO: typestates: track this as an assignment } return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; @@ -1451,36 +1485,35 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) { - std::optional resultTy; - auto def = dfg->getDef(local); - if (def) - resultTy = scope->lookup(*def); + BreadcrumbId bc = dfg->getBreadcrumb(local); - if (!resultTy) - { - if (auto ty = scope->lookup(local->local)) - resultTy = *ty; - } - - if (!resultTy) - return Inference{builtinTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. - - if (def) - return Inference{*resultTy, refinementArena.proposition(*def, builtinTypes->truthyType)}; + if (auto ty = scope->lookup(bc->def)) + return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; + else if (auto ty = scope->lookup(local->local)) + return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; else - return Inference{*resultTy}; + ice->ice("AstExprLocal came before its declaration?"); } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) { - if (std::optional ty = scope->lookup(global->name)) - return Inference{*ty}; + BreadcrumbId bc = dfg->getBreadcrumb(global); /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any * global that is not already in-scope is definitely an unknown symbol. */ - reportError(global->location, UnknownSymbol{global->name.value}); - return Inference{builtinTypes->errorRecoveryType()}; + if (auto ty = scope->lookup(bc->def)) + return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; + else if (auto ty = scope->lookup(global->name)) + { + rootScope->dcrRefinements[bc->def] = *ty; + return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; + } + else + { + reportError(global->location, UnknownSymbol{global->name.value}); + return Inference{builtinTypes->errorRecoveryType()}; + } } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) @@ -1488,19 +1521,19 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* TypeId obj = check(scope, indexName->expr).ty; TypeId result = arena->addType(BlockedType{}); - std::optional def = dfg->getDef(indexName); - if (def) + NullableBreadcrumbId bc = dfg->getBreadcrumb(indexName); + if (bc) { - if (auto ty = scope->lookup(*def)) - return Inference{*ty, refinementArena.proposition(*def, builtinTypes->truthyType)}; - else - scope->dcrRefinements[*def] = result; + if (auto ty = scope->lookup(bc->def)) + return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; + + scope->dcrRefinements[bc->def] = result; } addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value}); - if (def) - return Inference{result, refinementArena.proposition(*def, builtinTypes->truthyType)}; + if (bc) + return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; else return Inference{result}; } @@ -1509,15 +1542,26 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* { TypeId obj = check(scope, indexExpr->expr).ty; TypeId indexType = check(scope, indexExpr->index).ty; - TypeId result = freshType(scope); + NullableBreadcrumbId bc = dfg->getBreadcrumb(indexExpr); + if (bc) + { + if (auto ty = scope->lookup(bc->def)) + return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; + + scope->dcrRefinements[bc->def] = result; + } + TableIndexer indexer{indexType, result}; TypeId tableType = arena->addType(TableType{TableType::Props{}, TableIndexer{indexType, result}, TypeLevel{}, scope.get(), TableState::Free}); addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); - return Inference{result}; + if (bc) + return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; + else + return Inference{result}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) @@ -1545,7 +1589,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* bi Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) { ScopePtr condScope = childScope(ifElse->condition, scope); - auto [_, refinement] = check(scope, ifElse->condition); + RefinementId refinement = check(condScope, ifElse->condition).refinement; ScopePtr thenScope = childScope(ifElse->trueExpr, scope); applyRefinements(thenScope, ifElse->trueExpr->location, refinement); @@ -1600,8 +1644,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( TypeId leftType = check(scope, binary->left).ty; TypeId rightType = check(scope, binary->right).ty; - std::optional def = dfg->getDef(typeguard->target); - if (!def) + NullableBreadcrumbId bc = dfg->getBreadcrumb(typeguard->target); + if (!bc) return {leftType, rightType, nullptr}; TypeId discriminantTy = builtinTypes->neverType; @@ -1637,7 +1681,7 @@ std::tuple ConstraintGraphBuilder::checkBinary( discriminantTy = ty; } - RefinementId proposition = refinementArena.proposition(*def, discriminantTy); + RefinementId proposition = refinementArena.proposition(NotNull{bc}, discriminantTy); if (binary->op == AstExprBinary::CompareEq) return {leftType, rightType, proposition}; else if (binary->op == AstExprBinary::CompareNe) @@ -1651,12 +1695,12 @@ std::tuple ConstraintGraphBuilder::checkBinary( TypeId rightType = check(scope, binary->right, expectedType, true).ty; RefinementId leftRefinement = nullptr; - if (auto def = dfg->getDef(binary->left)) - leftRefinement = refinementArena.proposition(*def, rightType); + if (auto bc = dfg->getBreadcrumb(binary->left)) + leftRefinement = refinementArena.proposition(NotNull{bc}, rightType); RefinementId rightRefinement = nullptr; - if (auto def = dfg->getDef(binary->right)) - rightRefinement = refinementArena.proposition(*def, leftType); + if (auto bc = dfg->getBreadcrumb(binary->right)) + rightRefinement = refinementArena.proposition(NotNull{bc}, leftType); if (binary->op == AstExprBinary::CompareNe) { @@ -1685,6 +1729,21 @@ std::vector ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, return types; } +static bool isIndexNameEquivalent(AstExpr* expr) +{ + if (expr->is()) + return true; + + AstExprIndexExpr* e = expr->as(); + if (e == nullptr) + return false; + + if (!e->index->is()) + return false; + + return true; +} + /** * This function is mostly about identifying properties that are being inserted into unsealed tables. * @@ -1692,16 +1751,8 @@ std::vector ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, */ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) { - if (auto indexExpr = expr->as()) + if (auto indexExpr = expr->as(); indexExpr && !indexExpr->index->is()) { - if (auto constantString = indexExpr->index->as()) - { - AstName syntheticIndex{constantString->value.data}; - AstExprIndexName synthetic{ - indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'}; - return checkLValue(scope, &synthetic); - } - // An indexer is only interesting in an lvalue-ey way if it is at the // tail of an expression. // @@ -1724,7 +1775,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) return propType; } - else if (!expr->is()) + else if (!isIndexNameEquivalent(expr)) return check(scope, expr).ty; Symbol sym; @@ -1750,6 +1801,19 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) exprs.push_back(e); e = indexName->expr; } + else if (auto indexExpr = e->as()) + { + if (auto strIndex = indexExpr->index->as()) + { + segments.push_back(std::string(strIndex->value.data, strIndex->value.size)); + exprs.push_back(e); + e = indexExpr->expr; + } + else + { + return check(scope, expr).ty; + } + } else return check(scope, expr).ty; } @@ -1788,13 +1852,10 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) { symbolScope->bindings[sym].typeId = updatedType; - std::optional def = dfg->getDef(sym); - if (def) - { - // This can fail if the user is erroneously trying to augment a builtin - // table like os or string. - symbolScope->dcrRefinements[*def] = updatedType; - } + // This can fail if the user is erroneously trying to augment a builtin + // table like os or string. + if (auto bc = dfg->getBreadcrumb(e)) + symbolScope->dcrRefinements[bc->def] = updatedType; } return propTy; @@ -1984,36 +2045,32 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS argTypes.push_back(selfType); argNames.emplace_back(FunctionArgument{fn->self->name.value, fn->self->location}); signatureScope->bindings[fn->self] = Binding{selfType, fn->self->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(fn->self); + signatureScope->dcrRefinements[bc->def] = selfType; } for (size_t i = 0; i < fn->args.size; ++i) { AstLocal* local = fn->args.data[i]; - TypeId t = freshType(signatureScope); - argTypes.push_back(t); - argNames.emplace_back(FunctionArgument{local->name.value, local->location}); - signatureScope->bindings[local] = Binding{t, local->location}; - - auto def = dfg->getDef(local); - LUAU_ASSERT(def); - signatureScope->dcrRefinements[*def] = t; - - TypeId annotationTy = t; - + TypeId argTy = nullptr; if (local->annotation) + argTy = resolveType(signatureScope, local->annotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); + else { - annotationTy = resolveType(signatureScope, local->annotation, /* inTypeArguments */ false); - // If we provide an annotation that is wrong, type inference should ignore the annotation - // and try to infer a fresh type, like in the old solver - if (get(follow(annotationTy))) - annotationTy = freshType(signatureScope); - addConstraint(signatureScope, local->annotation->location, SubtypeConstraint{t, annotationTy}); - } - else if (i < expectedArgPack.head.size()) - { - addConstraint(signatureScope, local->location, SubtypeConstraint{t, expectedArgPack.head[i]}); + argTy = freshType(signatureScope); + + if (i < expectedArgPack.head.size()) + addConstraint(signatureScope, local->location, SubtypeConstraint{argTy, expectedArgPack.head[i]}); } + + argTypes.push_back(argTy); + argNames.emplace_back(FunctionArgument{local->name.value, local->location}); + signatureScope->bindings[local] = Binding{argTy, local->location}; + + BreadcrumbId bc = dfg->getBreadcrumb(local); + signatureScope->dcrRefinements[bc->def] = argTy; } TypePackId varargPack = nullptr; @@ -2022,7 +2079,8 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS { if (fn->varargAnnotation) { - TypePackId annotationType = resolveTypePack(signatureScope, fn->varargAnnotation, /* inTypeArguments */ false); + TypePackId annotationType = + resolveTypePack(signatureScope, fn->varargAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh */ true); varargPack = annotationType; } else if (expectedArgPack.tail && get(*expectedArgPack.tail)) @@ -2049,8 +2107,8 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // Type checking will sort out any discrepancies later. if (fn->returnAnnotation) { - TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation, /* inTypeArguments */ false); - + TypePackId annotatedRetType = + resolveTypePack(signatureScope, *fn->returnAnnotation, /* inTypeArguments */ false, /* replaceErrorWithFresh*/ true); // We bind the annotated type directly here so that, when we need to // generate constraints for return types, we have a guarantee that we // know the annotated return type already, if one was provided. @@ -2098,7 +2156,7 @@ void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFun } } -TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments) +TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh) { TypeId result = nullptr; @@ -2176,6 +2234,8 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b else { result = builtinTypes->errorRecoveryType(); + if (replaceErrorWithFresh) + result = freshType(scope); } } else if (auto tab = ty->as()) @@ -2239,8 +2299,8 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b signatureScope = scope; } - TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments); - TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments); + TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments, replaceErrorWithFresh); + TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments, replaceErrorWithFresh); // TODO: FunctionType needs a pointer to the scope so that we know // how to quantify/instantiate it. @@ -2307,6 +2367,8 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b else if (ty->is()) { result = builtinTypes->errorRecoveryType(); + if (replaceErrorWithFresh) + result = freshType(scope); } else { @@ -2318,18 +2380,16 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument) +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument, bool replaceErrorWithFresh) { TypePackId result; if (auto expl = tp->as()) { - result = resolveTypePack(scope, expl->typeList, inTypeArgument); + result = resolveTypePack(scope, expl->typeList, inTypeArgument, replaceErrorWithFresh); } else if (auto var = tp->as()) { - TypeId ty = resolveType(scope, var->variadicType, inTypeArgument); - if (get(follow(ty))) - ty = freshType(scope); + TypeId ty = resolveType(scope, var->variadicType, inTypeArgument, replaceErrorWithFresh); result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); } else if (auto gen = tp->as()) @@ -2354,19 +2414,19 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTyp return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments) +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh) { std::vector head; for (AstType* headTy : list.types) { - head.push_back(resolveType(scope, headTy, inTypeArguments)); + head.push_back(resolveType(scope, headTy, inTypeArguments, replaceErrorWithFresh)); } std::optional tail = std::nullopt; if (list.tailType) { - tail = resolveTypePack(scope, list.tailType, inTypeArguments); + tail = resolveTypePack(scope, list.tailType, inTypeArguments, replaceErrorWithFresh); } return arena->addTypePack(TypePack{head, tail}); @@ -2454,7 +2514,7 @@ void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) { errors.push_back(TypeError{location, moduleName, std::move(err)}); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->captureGenerationError(errors.back()); } @@ -2462,7 +2522,7 @@ void ConstraintGraphBuilder::reportCodeTooComplex(Location location) { errors.push_back(TypeError{location, moduleName, CodeTooComplex{}}); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->captureGenerationError(errors.back()); } @@ -2493,6 +2553,69 @@ void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, program->visit(&gp); } +std::vector> ConstraintGraphBuilder::getExpectedCallTypesForFunctionOverloads(const TypeId fnType) +{ + std::vector funTys; + if (auto it = get(follow(fnType))) + { + for (TypeId intersectionComponent : it) + { + funTys.push_back(intersectionComponent); + } + } + + std::vector> expectedTypes; + // For a list of functions f_0 : e_0 -> r_0, ... f_n : e_n -> r_n, + // emit a list of arguments that the function could take at each position + // by unioning the arguments at each place + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { + if (index == expectedTypes.size()) + expectedTypes.push_back(ty); + else if (ty) + { + auto& el = expectedTypes[index]; + + if (!el) + el = ty; + else + { + std::vector result = reduceUnion({*el, ty}); + if (result.empty()) + el = builtinTypes->neverType; + else if (result.size() == 1) + el = result[0]; + else + el = module->internalTypes.addType(UnionType{std::move(result)}); + } + } + }; + + for (const TypeId overload : funTys) + { + if (const FunctionType* ftv = get(follow(overload))) + { + auto [argsHead, argsTail] = flatten(ftv->argTypes); + size_t start = ftv->hasSelf ? 1 : 0; + size_t index = 0; + for (size_t i = start; i < argsHead.size(); ++i) + assignOption(index++, argsHead[i]); + if (argsTail) + { + argsTail = follow(*argsTail); + if (const VariadicTypePack* vtp = get(*argsTail)) + { + while (index < funTys.size()) + assignOption(index++, vtp->ty); + } + } + } + } + + // TODO vvijay Feb 24, 2023 apparently we have to demote the types here? + + return expectedTypes; +} + std::vector> borrowConstraints(const std::vector& constraints) { std::vector> result; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 96673e3d..3cb4e4e7 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -17,7 +17,6 @@ #include "Luau/VisitType.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); -LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { @@ -261,9 +260,6 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNullcaptureInitialSolverState(rootScope, unsolvedConstraints); } @@ -320,7 +316,7 @@ void ConstraintSolver::run() std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; StepSnapshot snapshot; - if (FFlag::DebugLuauLogSolverToJson) + if (logger) { snapshot = logger->prepareStepSnapshot(rootScope, c, force, unsolvedConstraints); } @@ -334,7 +330,7 @@ void ConstraintSolver::run() unblock(c); unsolvedConstraints.erase(unsolvedConstraints.begin() + i); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) { logger->commitStepSnapshot(snapshot); } @@ -393,7 +389,7 @@ void ConstraintSolver::run() dumpBindings(rootScope, opts); } - if (FFlag::DebugLuauLogSolverToJson) + if (logger) { logger->captureFinalSolverState(rootScope, unsolvedConstraints); } @@ -486,6 +482,9 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNullreduce(subjectType).value_or(subjectType); - std::optional resultType = lookupTableProp(subjectType, c.prop); - if (!resultType) + auto [blocked, result] = lookupTableProp(subjectType, c.prop); + if (!blocked.empty()) { - asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); - unblock(c.resultType); - return true; - } + for (TypeId blocked : blocked) + block(blocked, constraint); - if (isBlocked(*resultType)) - { - block(*resultType, constraint); return false; } - asMutable(c.resultType)->ty.emplace(*resultType); + asMutable(c.resultType)->ty.emplace(result.value_or(builtinTypes->errorRecoveryType())); + unblock(c.resultType); return true; } @@ -1426,17 +1421,18 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull existingPropType = subjectType; for (const std::string& segment : c.path) { - ErrorVec e; - std::optional propTy = lookupTableProp(*existingPropType, segment); - if (!propTy) - { - existingPropType = std::nullopt; + if (!existingPropType) break; + + auto [blocked, result] = lookupTableProp(*existingPropType, segment); + if (!blocked.empty()) + { + for (TypeId blocked : blocked) + block(blocked, constraint); + return false; } - else if (isBlocked(*propTy)) - return block(*propTy, constraint); - else - existingPropType = follow(*propTy); + + existingPropType = result; } auto bind = [](TypeId a, TypeId b) { @@ -1451,6 +1447,9 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) + subjectType = follow(mt->table); + if (get(subjectType) || get(subjectType) || get(subjectType)) { bind(c.resultType, subjectType); @@ -1504,8 +1503,8 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) { - // Classes never change shape as a result of property assignments. - // The result is always the subject. + // Classes and intersections never change shape as a result of property + // assignments. The result is always the subject. bind(c.resultType, subjectType); return true; } @@ -1833,122 +1832,68 @@ bool ConstraintSolver::tryDispatchIterableFunction( return true; } -std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) +std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) { std::unordered_set seen; return lookupTableProp(subjectType, propName, seen); } -std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) +std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) { if (!seen.insert(subjectType).second) - return std::nullopt; + return {}; - auto collectParts = [&](auto&& unionOrIntersection) -> std::pair, std::vector> { - std::optional blocked; + subjectType = follow(subjectType); - std::vector parts; - std::vector freeParts; - for (TypeId expectedPart : unionOrIntersection) - { - expectedPart = follow(expectedPart); - if (isBlocked(expectedPart) || get(expectedPart)) - blocked = expectedPart; - else if (const TableType* ttv = get(follow(expectedPart))) - { - if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) - parts.push_back(prop->second.type); - else if (ttv->indexer && maybeString(ttv->indexer->indexType)) - parts.push_back(ttv->indexer->indexResultType); - } - else if (get(expectedPart)) - { - freeParts.push_back(expectedPart); - } - } - - // If the only thing resembling a match is a single fresh type, we can - // confidently tablify it. If other types match or if there are more - // than one free type, we can't do anything. - if (parts.empty() && 1 == freeParts.size()) - { - TypeId freePart = freeParts.front(); - const FreeType* ft = get(freePart); - LUAU_ASSERT(ft); - Scope* scope = ft->scope; - - TableType* tt = &asMutable(freePart)->ty.emplace(); - tt->state = TableState::Free; - tt->scope = scope; - TypeId propType = arena->freshType(scope); - tt->props[propName] = Property{propType}; - - parts.push_back(propType); - } - - return {blocked, parts}; - }; - - std::optional resultType; - - if (get(subjectType) || get(subjectType)) + if (isBlocked(subjectType)) + return {{subjectType}, std::nullopt}; + else if (get(subjectType) || get(subjectType)) { - return subjectType; + return {{}, subjectType}; } else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) - resultType = prop->second.type; + return {{}, prop->second.type}; else if (ttv->indexer && maybeString(ttv->indexer->indexType)) - resultType = ttv->indexer->indexResultType; + return {{}, ttv->indexer->indexResultType}; else if (ttv->state == TableState::Free) { - resultType = arena->addType(FreeType{ttv->scope}); - ttv->props[propName] = Property{*resultType}; + TypeId result = arena->freshType(ttv->scope); + ttv->props[propName] = Property{result}; + return {{}, result}; } } else if (auto mt = get(subjectType)) { - if (auto p = lookupTableProp(mt->table, propName, seen)) - return p; + auto [blocked, result] = lookupTableProp(mt->table, propName, seen); + if (!blocked.empty() || result) + return {blocked, result}; TypeId mtt = follow(mt->metatable); if (get(mtt)) - return mtt; + return {{mtt}, std::nullopt}; else if (auto metatable = get(mtt)) { auto indexProp = metatable->props.find("__index"); if (indexProp == metatable->props.end()) - return std::nullopt; + return {{}, result}; // TODO: __index can be an overloaded function. TypeId indexType = follow(indexProp->second.type); if (auto ft = get(indexType)) - { - std::optional ret = first(ft->retTypes); - if (ret) - return *ret; - else - return std::nullopt; - } - - return lookupTableProp(indexType, propName, seen); + return {{}, first(ft->retTypes)}; + else + return lookupTableProp(indexType, propName, seen); } } else if (auto ct = get(subjectType)) { - while (ct) - { - if (auto prop = ct->props.find(propName); prop != ct->props.end()) - return prop->second.type; - else if (ct->parent) - ct = get(follow(*ct->parent)); - else - break; - } + if (auto p = lookupClassProp(ct, propName)) + return {{}, p->type}; } else if (auto pt = get(subjectType); pt && pt->metatable) { @@ -1957,38 +1902,70 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons auto indexProp = metatable->props.find("__index"); if (indexProp == metatable->props.end()) - return std::nullopt; + return {{}, std::nullopt}; return lookupTableProp(indexProp->second.type, propName, seen); } + else if (auto ft = get(subjectType)) + { + Scope* scope = ft->scope; + + TableType* tt = &asMutable(subjectType)->ty.emplace(); + tt->state = TableState::Free; + tt->scope = scope; + TypeId propType = arena->freshType(scope); + tt->props[propName] = Property{propType}; + + return {{}, propType}; + } else if (auto utv = get(subjectType)) { - auto [blocked, parts] = collectParts(utv); + std::vector blocked; + std::vector options; - if (blocked) - resultType = *blocked; - else if (parts.size() == 1) - resultType = parts[0]; - else if (parts.size() > 1) - resultType = arena->addType(UnionType{std::move(parts)}); + for (TypeId ty : utv) + { + auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); + if (innerResult) + options.push_back(*innerResult); + } - // otherwise, nothing: no matching property + if (!blocked.empty()) + return {blocked, std::nullopt}; + + if (options.empty()) + return {{}, std::nullopt}; + else if (options.size() == 1) + return {{}, options[0]}; + else + return {{}, arena->addType(UnionType{std::move(options)})}; } else if (auto itv = get(subjectType)) { - auto [blocked, parts] = collectParts(itv); + std::vector blocked; + std::vector options; - if (blocked) - resultType = *blocked; - else if (parts.size() == 1) - resultType = parts[0]; - else if (parts.size() > 1) - resultType = arena->addType(IntersectionType{std::move(parts)}); + for (TypeId ty : itv) + { + auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); + if (innerResult) + options.push_back(*innerResult); + } - // otherwise, nothing: no matching property + if (!blocked.empty()) + return {blocked, std::nullopt}; + + if (options.empty()) + return {{}, std::nullopt}; + else if (options.size() == 1) + return {{}, options[0]}; + else + return {{}, arena->addType(IntersectionType{std::move(options)})}; } - return resultType; + return {{}, std::nullopt}; } void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) @@ -2001,7 +1978,7 @@ void ConstraintSolver::block_(BlockedConstraintId target, NotNull target, NotNull constraint) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->pushBlock(constraint, target); if (FFlag::DebugLuauLogSolver) @@ -2012,7 +1989,7 @@ void ConstraintSolver::block(NotNull target, NotNull constraint) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->pushBlock(constraint, target); if (FFlag::DebugLuauLogSolver) @@ -2024,7 +2001,7 @@ bool ConstraintSolver::block(TypeId target, NotNull constraint bool ConstraintSolver::block(TypePackId target, NotNull constraint) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->pushBlock(constraint, target); if (FFlag::DebugLuauLogSolver) @@ -2102,7 +2079,7 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) void ConstraintSolver::unblock(NotNull progressed) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->popBlock(progressed); return unblock_(progressed.get()); @@ -2110,7 +2087,7 @@ void ConstraintSolver::unblock(NotNull progressed) void ConstraintSolver::unblock(TypeId progressed) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->popBlock(progressed); unblock_(progressed); @@ -2121,7 +2098,7 @@ void ConstraintSolver::unblock(TypeId progressed) void ConstraintSolver::unblock(TypePackId progressed) { - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->popBlock(progressed); return unblock_(progressed); diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 7e716603..e73c7e8c 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -1,7 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/DataFlowGraph.h" +#include "Luau/Breadcrumb.h" #include "Luau/Error.h" +#include "Luau/Refinement.h" LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -9,29 +11,74 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau { -std::optional DataFlowGraph::getDef(const AstExpr* expr) const +NullableBreadcrumbId DataFlowGraph::getBreadcrumb(const AstExpr* expr) const { // We need to skip through AstExprGroup because DFG doesn't try its best to transitively while (auto group = expr->as()) expr = group->expr; - if (auto def = astDefs.find(expr)) - return NotNull{*def}; - return std::nullopt; + if (auto bc = astBreadcrumbs.find(expr)) + return *bc; + return nullptr; } -std::optional DataFlowGraph::getDef(const AstLocal* local) const +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstLocal* local) const { - if (auto def = localDefs.find(local)) - return NotNull{*def}; - return std::nullopt; + auto bc = localBreadcrumbs.find(local); + LUAU_ASSERT(bc); + return NotNull{*bc}; } -std::optional DataFlowGraph::getDef(const Symbol& symbol) const +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprLocal* local) const { - if (symbol.local) - return getDef(symbol.local); - else - return std::nullopt; + auto bc = astBreadcrumbs.find(local); + LUAU_ASSERT(bc); + return NotNull{*bc}; +} + +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprGlobal* global) const +{ + auto bc = astBreadcrumbs.find(global); + LUAU_ASSERT(bc); + return NotNull{*bc}; +} + +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareGlobal* global) const +{ + auto bc = declaredBreadcrumbs.find(global); + LUAU_ASSERT(bc); + return NotNull{*bc}; +} + +BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareFunction* func) const +{ + auto bc = declaredBreadcrumbs.find(func); + LUAU_ASSERT(bc); + return NotNull{*bc}; +} + +NullableBreadcrumbId DfgScope::lookup(Symbol symbol) const +{ + for (const DfgScope* current = this; current; current = current->parent) + { + if (auto breadcrumb = current->bindings.find(symbol)) + return *breadcrumb; + } + + return nullptr; +} + +NullableBreadcrumbId DfgScope::lookup(DefId def, const std::string& key) const +{ + for (const DfgScope* current = this; current; current = current->parent) + { + if (auto map = props.find(def)) + { + if (auto it = map->find(key); it != map->end()) + return it->second; + } + } + + return nullptr; } DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) @@ -40,9 +87,15 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNullallocator.freeze(); + { + builder.defs->allocator.freeze(); + builder.breadcrumbs->allocator.freeze(); + } + return std::move(builder.graph); } @@ -51,29 +104,6 @@ DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope) return scopes.emplace_back(new DfgScope{scope}).get(); } -std::optional DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e) -{ - for (DfgScope* current = scope; current; current = current->parent) - { - if (auto def = current->bindings.find(symbol)) - { - graph.astDefs[e] = *def; - return NotNull{*def}; - } - } - - return std::nullopt; -} - -DefId DataFlowGraphBuilder::use(DefId def, AstExprIndexName* e) -{ - auto& propertyDef = props[def][e->index.value]; - if (!propertyDef) - propertyDef = arena->freshCell(def, e->index.value); - graph.astDefs[e] = propertyDef; - return NotNull{propertyDef}; -} - void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) { DfgScope* child = childScope(scope); @@ -119,27 +149,24 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) else if (auto l = s->as()) return visit(scope, l); else if (auto t = s->as()) - return; // ok - else if (auto d = s->as()) - return; // ok + return visit(scope, t); else if (auto d = s->as()) - return; // ok + return visit(scope, d); else if (auto d = s->as()) - return; // ok + return visit(scope, d); else if (auto d = s->as()) - return; // ok - else if (auto _ = s->as()) - return; // ok + return visit(scope, d); + else if (auto error = s->as()) + return visit(scope, error); else - handle->ice("Unknown AstStat in DataFlowGraphBuilder"); + handle->ice("Unknown AstStat in DataFlowGraphBuilder::visit"); } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) { - DfgScope* condScope = childScope(scope); - visitExpr(condScope, i->condition); - visit(condScope, i->thenbody); - + // TODO: type states and control flow analysis + visitExpr(scope, i->condition); + visit(scope, i->thenbody); if (i->elsebody) visit(scope, i->elsebody); } @@ -186,24 +213,41 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) { - // TODO: alias tracking + // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) + std::vector bcs; + bcs.reserve(l->values.size); for (AstExpr* e : l->values) - visitExpr(scope, e); + bcs.push_back(visitExpr(scope, e)); - for (AstLocal* local : l->vars) + for (size_t i = 0; i < l->vars.size; ++i) { - DefId def = arena->freshCell(); - graph.localDefs[local] = def; - scope->bindings[local] = def; + AstLocal* local = l->vars.data[i]; + if (local->annotation) + visitType(scope, local->annotation); + + // We need to create a new breadcrumb with new defs to intentionally avoid alias tracking. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell(), i < bcs.size() ? bcs[i]->metadata : std::nullopt); + graph.localBreadcrumbs[local] = bc; + scope->bindings[local] = bc; } } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) { DfgScope* forScope = childScope(scope); // TODO: loop scope. - DefId def = arena->freshCell(); - graph.localDefs[f->var] = def; - scope->bindings[f->var] = def; + + visitExpr(scope, f->from); + visitExpr(scope, f->to); + if (f->step) + visitExpr(scope, f->step); + + if (f->var->annotation) + visitType(forScope, f->var->annotation); + + // TODO: RangeMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[f->var] = bc; + scope->bindings[f->var] = bc; // TODO(controlflow): entry point has a back edge from exit point visit(forScope, f->body); @@ -215,12 +259,17 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) for (AstLocal* local : f->vars) { - DefId def = arena->freshCell(); - graph.localDefs[local] = def; - forScope->bindings[local] = def; + if (local->annotation) + visitType(forScope, local->annotation); + + // TODO: IterMetadata (different from RangeMetadata) + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[local] = bc; + forScope->bindings[local] = bc; } // TODO(controlflow): entry point has a back edge from exit point + // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) for (AstExpr* e : f->values) visitExpr(forScope, e); @@ -233,87 +282,117 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) visitExpr(scope, r); for (AstExpr* l : a->vars) - { - AstExpr* root = l; - - bool isUpdatable = true; - while (true) - { - if (root->is() || root->is()) - break; - - AstExprIndexName* indexName = root->as(); - if (!indexName) - { - isUpdatable = false; - break; - } - - root = indexName->expr; - } - - if (isUpdatable) - { - // TODO global? - if (auto exprLocal = root->as()) - { - DefId def = arena->freshCell(); - graph.astDefs[exprLocal] = def; - - // Update the def in the scope that introduced the local. Not - // the current scope. - AstLocal* local = exprLocal->local; - DfgScope* s = scope; - while (s && !s->bindings.find(local)) - s = s->parent; - LUAU_ASSERT(s && s->bindings.find(local)); - s->bindings[local] = def; - } - } - - visitExpr(scope, l); // TODO: they point to a new def!! - } + visitLValue(scope, l); } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) { - // TODO(typestates): The lhs is being read and written to. This might or might not be annoying. + // TODO: This needs revisiting because this is incorrect. The `c->var` part is both being read and written to, + // but the `c->var` only has one pointer address, so we need to come up with a way to store both. + // For now, it's not important because we don't have type states, but it is going to be important, e.g. + // + // local a = 5 -- a[1] + // a += 5 -- a[2] = a[1] + 5 + // + // We can't just visit `c->var` as a rvalue and then separately traverse `c->var` as an lvalue, since that's O(n^2). + visitLValue(scope, c->var); visitExpr(scope, c->value); } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) { - visitExpr(scope, f->name); + // In the old solver, we assumed that the name of the function is always a function in the body + // but this isn't true, e.g. the following example will print `5`, not a function address. + // + // local function f() print(f) end + // local g = f + // f = 5 + // g() --> 5 + // + // which is evidence that references to variables must be a phi node of all possible definitions, + // but for bug compatibility, we'll assume the same thing here. + visitLValue(scope, f->name); visitExpr(scope, f->func); } void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) { - DefId def = arena->freshCell(); - graph.localDefs[l->name] = def; - scope->bindings[l->name] = def; + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[l->name] = bc; + scope->bindings[l->name] = bc; visitExpr(scope, l->func); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t) +{ + DfgScope* unreachable = childScope(scope); + visitGenerics(unreachable, t->generics); + visitGenericPacks(unreachable, t->genericPacks); + visitType(unreachable, t->type); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d) +{ + // TODO: AmbientDeclarationMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.declaredBreadcrumbs[d] = bc; + scope->bindings[d->name] = bc; + + visitType(scope, d->type); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d) +{ + // TODO: AmbientDeclarationMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.declaredBreadcrumbs[d] = bc; + scope->bindings[d->name] = bc; + + DfgScope* unreachable = childScope(scope); + visitGenerics(unreachable, d->generics); + visitGenericPacks(unreachable, d->genericPacks); + visitTypeList(unreachable, d->params); + visitTypeList(unreachable, d->retTypes); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d) +{ + // This declaration does not "introduce" any bindings in value namespace, + // so there's no symbolic value to begin with. We'll traverse the properties + // because their type annotations may depend on something in the value namespace. + DfgScope* unreachable = childScope(scope); + for (AstDeclaredClassProp prop : d->props) + visitType(unreachable, prop.ty); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error) +{ + DfgScope* unreachable = childScope(scope); + for (AstStat* s : error->statements) + visit(unreachable, s); + for (AstExpr* e : error->expressions) + visitExpr(unreachable, e); +} + +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) { if (auto g = e->as()) return visitExpr(scope, g->expr); else if (auto c = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto l = e->as()) return visitExpr(scope, l); else if (auto g = e->as()) return visitExpr(scope, g); else if (auto v = e->as()) - return {}; // ok + return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as()) return visitExpr(scope, c); else if (auto i = e->as()) @@ -334,76 +413,123 @@ ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) return visitExpr(scope, i); else if (auto i = e->as()) return visitExpr(scope, i); - else if (auto _ = e->as()) - return {}; // ok + else if (auto error = e->as()) + return visitExpr(scope, error); else - handle->ice("Unknown AstExpr in DataFlowGraphBuilder"); + handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr"); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) { - return {use(scope, l->local, l)}; + NullableBreadcrumbId breadcrumb = scope->lookup(l->local); + if (!breadcrumb) + handle->ice("AstExprLocal came before its declaration?"); + + graph.astBreadcrumbs[l] = breadcrumb; + return NotNull{breadcrumb}; } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) { - return {use(scope, g->name, g)}; + NullableBreadcrumbId bc = scope->lookup(g->name); + if (!bc) + { + bc = breadcrumbs->add(nullptr, defs->freshCell()); + moduleScope->bindings[g->name] = bc; + } + + graph.astBreadcrumbs[g] = bc; + return NotNull{bc}; } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) { visitExpr(scope, c->func); for (AstExpr* arg : c->args) visitExpr(scope, arg); - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) { - std::optional def = visitExpr(scope, i->expr).def; - if (!def) - return {}; + BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); - return {use(*def, i)}; + std::string key = i->index.value; + NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; + if (!propBreadcrumb) + propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + + graph.astBreadcrumbs[i] = propBreadcrumb; + return NotNull{propBreadcrumb}; } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) { - visitExpr(scope, i->expr); - visitExpr(scope, i->index); + BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + BreadcrumbId key = visitExpr(scope, i->index); - if (i->index->as()) + if (auto string = i->index->as()) { - // TODO: properties for the def + std::string key{string->value.data, string->value.size}; + NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; + if (!propBreadcrumb) + propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + + graph.astBreadcrumbs[i] = NotNull{propBreadcrumb}; + return NotNull{propBreadcrumb}; } - return {}; + return breadcrumbs->emplace(nullptr, defs->freshCell(), key); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) { + DfgScope* signatureScope = childScope(scope); + if (AstLocal* self = f->self) { - DefId def = arena->freshCell(); - graph.localDefs[self] = def; - scope->bindings[self] = def; + // There's no syntax for `self` to have an annotation if using `function t:m()` + LUAU_ASSERT(!self->annotation); + + // TODO: ParameterMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[self] = bc; + signatureScope->bindings[self] = bc; } for (AstLocal* param : f->args) { - DefId def = arena->freshCell(); - graph.localDefs[param] = def; - scope->bindings[param] = def; + if (param->annotation) + visitType(signatureScope, param->annotation); + + // TODO: ParameterMetadata. + BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); + graph.localBreadcrumbs[param] = bc; + signatureScope->bindings[param] = bc; } - visit(scope, f->body); + if (f->varargAnnotation) + visitTypePack(scope, f->varargAnnotation); - return {}; + if (f->returnAnnotation) + visitTypeList(signatureScope, *f->returnAnnotation); + + // TODO: function body can be re-entrant, as in mutations that occurs at the end of the function can also be + // visible to the beginning of the function, so statically speaking, the body of the function has an exit point + // that points back to itself, e.g. + // + // local function f() print(f) f = 5 end + // local g = f + // g() --> function: address + // g() --> 5 + visit(signatureScope, f->body); + + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) { for (AstExprTable::Item item : t->items) { @@ -412,47 +538,259 @@ ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTabl visitExpr(scope, item.value); } - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) { visitExpr(scope, u->expr); - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) { visitExpr(scope, b->left); visitExpr(scope, b->right); - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) { - ExpressionFlowGraph result = visitExpr(scope, t->expr); - // TODO: visit type - return result; + // TODO: TypeAssertionMetadata? + BreadcrumbId bc = visitExpr(scope, t->expr); + visitType(scope, t->annotation); + + return bc; } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) { - DfgScope* condScope = childScope(scope); - visitExpr(condScope, i->condition); - visitExpr(condScope, i->trueExpr); - + visitExpr(scope, i->condition); + visitExpr(scope, i->trueExpr); visitExpr(scope, i->falseExpr); - return {}; + return breadcrumbs->add(nullptr, defs->freshCell()); } -ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) { for (AstExpr* e : i->expressions) visitExpr(scope, e); - return {}; + + return breadcrumbs->add(nullptr, defs->freshCell()); +} + +BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* error) +{ + DfgScope* unreachable = childScope(scope); + for (AstExpr* e : error->expressions) + visitExpr(unreachable, e); + + return breadcrumbs->add(nullptr, defs->freshCell()); +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e) +{ + if (auto l = e->as()) + return visitLValue(scope, l); + else if (auto g = e->as()) + return visitLValue(scope, g); + else if (auto i = e->as()) + return visitLValue(scope, i); + else if (auto i = e->as()) + return visitLValue(scope, i); + else if (auto error = e->as()) + { + visitExpr(scope, error); // TODO: is this right? + return; + } + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l) +{ + // Bug compatibility: we don't support type states yet, so we need to do this. + NullableBreadcrumbId bc = scope->lookup(l->local); + LUAU_ASSERT(bc); + + graph.astBreadcrumbs[l] = bc; + scope->bindings[l->local] = bc; +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g) +{ + // Bug compatibility: we don't support type states yet, so we need to do this. + NullableBreadcrumbId bc = scope->lookup(g->name); + if (!bc) + bc = breadcrumbs->add(nullptr, defs->freshCell()); + + graph.astBreadcrumbs[g] = bc; + scope->bindings[g->name] = bc; +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i) +{ + // Bug compatibility: we don't support type states yet, so we need to do this. + BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + + std::string key = i->index.value; + NullableBreadcrumbId propBreadcrumb = scope->lookup(parentBreadcrumb->def, key); + if (!propBreadcrumb) + { + propBreadcrumb = breadcrumbs->emplace(parentBreadcrumb, defs->freshCell(), key); + moduleScope->props[parentBreadcrumb->def][key] = propBreadcrumb; + } + + graph.astBreadcrumbs[i] = propBreadcrumb; +} + +void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i) +{ + BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); + visitExpr(scope, i->index); + + if (auto string = i->index->as()) + { + std::string key{string->value.data, string->value.size}; + NullableBreadcrumbId propBreadcrumb = scope->lookup(parentBreadcrumb->def, key); + if (!propBreadcrumb) + { + propBreadcrumb = breadcrumbs->add(parentBreadcrumb, parentBreadcrumb->def); + moduleScope->props[parentBreadcrumb->def][key] = propBreadcrumb; + } + + graph.astBreadcrumbs[i] = propBreadcrumb; + } +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) +{ + if (auto r = t->as()) + return visitType(scope, r); + else if (auto table = t->as()) + return visitType(scope, table); + else if (auto f = t->as()) + return visitType(scope, f); + else if (auto tyof = t->as()) + return visitType(scope, tyof); + else if (auto u = t->as()) + return visitType(scope, u); + else if (auto i = t->as()) + return visitType(scope, i); + else if (auto e = t->as()) + return visitType(scope, e); + else if (auto s = t->as()) + return; // ok + else if (auto s = t->as()) + return; // ok + else + handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType"); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeReference* r) +{ + for (AstTypeOrPack param : r->parameters) + { + if (param.type) + visitType(scope, param.type); + else + visitTypePack(scope, param.typePack); + } +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTable* t) +{ + for (AstTableProp p : t->props) + visitType(scope, p.type); + + if (t->indexer) + { + visitType(scope, t->indexer->indexType); + visitType(scope, t->indexer->resultType); + } +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeFunction* f) +{ + visitGenerics(scope, f->generics); + visitGenericPacks(scope, f->genericPacks); + visitTypeList(scope, f->argTypes); + visitTypeList(scope, f->returnTypes); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTypeof* t) +{ + visitExpr(scope, t->expr); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeUnion* u) +{ + for (AstType* t : u->types) + visitType(scope, t); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeIntersection* i) +{ + for (AstType* t : i->types) + visitType(scope, t); +} + +void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeError* error) +{ + for (AstType* t : error->types) + visitType(scope, t); +} + +void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePack* p) +{ + if (auto e = p->as()) + return visitTypePack(scope, e); + else if (auto v = p->as()) + return visitTypePack(scope, v); + else if (auto g = p->as()) + return; // ok + else + handle->ice("Unknown AstTypePack in DataFlowGraphBuilder::visitTypePack"); +} + +void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackExplicit* e) +{ + visitTypeList(scope, e->typeList); +} + +void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackVariadic* v) +{ + visitType(scope, v->variadicType); +} + +void DataFlowGraphBuilder::visitTypeList(DfgScope* scope, AstTypeList l) +{ + for (AstType* t : l.types) + visitType(scope, t); + + if (l.tailType) + visitTypePack(scope, l.tailType); +} + +void DataFlowGraphBuilder::visitGenerics(DfgScope* scope, AstArray g) +{ + for (AstGenericType generic : g) + { + if (generic.defaultValue) + visitType(scope, generic.defaultValue); + } +} + +void DataFlowGraphBuilder::visitGenericPacks(DfgScope* scope, AstArray g) +{ + for (AstGenericTypePack generic : g) + { + if (generic.defaultValue) + visitTypePack(scope, generic.defaultValue); + } } } // namespace Luau diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp index 8ce1129c..7be075c2 100644 --- a/Analysis/src/Def.cpp +++ b/Analysis/src/Def.cpp @@ -6,12 +6,7 @@ namespace Luau DefId DefArena::freshCell() { - return NotNull{allocator.allocate(Def{Cell{std::nullopt}})}; -} - -DefId DefArena::freshCell(DefId parent, const std::string& prop) -{ - return NotNull{allocator.allocate(Def{Cell{FieldMetadata{parent, prop}}})}; + return NotNull{allocator.allocate(Def{Cell{}})}; } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 91c72e44..b3e453db 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -25,12 +25,13 @@ #include LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) -LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { @@ -517,7 +518,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& requireCycles, - NotNull builtinTypes, - NotNull iceHandler, - NotNull moduleResolver, - NotNull fileResolver, - const ScopePtr& globalScope, - NotNull unifierState, - FrontendOptions options -) { +ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, + NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, + const ScopePtr& globalScope, FrontendOptions options) +{ + const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; + return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, globalScope, options, recordJsonLog); +} + +ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, + NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, + const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog) +{ ModulePtr result = std::make_shared(); result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, iceHandler); std::unique_ptr logger; - if (FFlag::DebugLuauLogSolverToJson) + if (recordJsonLog) { logger = std::make_unique(); std::optional source = fileResolver->readSource(sourceModule.name); @@ -882,7 +886,11 @@ ModulePtr check( DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); - Normalizer normalizer{&result->internalTypes, builtinTypes, unifierState}; + UnifierSharedState unifierState{iceHandler}; + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + + Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; ConstraintGraphBuilder cgb{ sourceModule.name, @@ -925,7 +933,7 @@ ModulePtr check( freeze(result->internalTypes); freeze(result->interfaceTypes); - if (FFlag::DebugLuauLogSolverToJson) + if (recordJsonLog) { std::string output = logger->compileOutput(); printf("%s\n", output.c_str()); @@ -934,20 +942,11 @@ ModulePtr check( return result; } -ModulePtr Frontend::check( - const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete) +ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete, bool recordJsonLog) { - return Luau::check( - sourceModule, - requireCycles, - builtinTypes, - NotNull{&iceHandler}, - NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, - NotNull{fileResolver}, - forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope, - NotNull{&typeChecker.unifierState}, - options - ); + return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, + NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, + forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope, options, recordJsonLog); } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 65ad8a82..f850bd3d 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,6 +14,8 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) +LUAU_FASTFLAGVARIABLE(LuauImproveDeprecatedApiLint, false) + namespace Luau { @@ -2100,7 +2102,7 @@ class LintDeprecatedApi : AstVisitor public: LUAU_NOINLINE static void process(LintContext& context) { - if (!context.module) + if (!FFlag::LuauImproveDeprecatedApiLint && !context.module) return; LintDeprecatedApi pass{&context}; @@ -2117,26 +2119,51 @@ private: bool visit(AstExprIndexName* node) override { - std::optional ty = context->getType(node->expr); - if (!ty) - return true; + if (std::optional ty = context->getType(node->expr)) + check(node, follow(*ty)); + else if (AstExprGlobal* global = node->expr->as()) + if (FFlag::LuauImproveDeprecatedApiLint) + check(node->location, global->name, node->index); - if (const ClassType* cty = get(follow(*ty))) + return true; + } + + void check(AstExprIndexName* node, TypeId ty) + { + if (const ClassType* cty = get(ty)) { const Property* prop = lookupClassProp(cty, node->index.value); if (prop && prop->deprecated) report(node->location, *prop, cty->name.c_str(), node->index.value); } - else if (const TableType* tty = get(follow(*ty))) + else if (const TableType* tty = get(ty)) { auto prop = tty->props.find(node->index.value); if (prop != tty->props.end() && prop->second.deprecated) - report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); + { + // strip synthetic typeof() for builtin tables + if (FFlag::LuauImproveDeprecatedApiLint && tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')') + report(node->location, prop->second, tty->name->substr(7, tty->name->length() - 8).c_str(), node->index.value); + else + report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); + } } + } - return true; + void check(const Location& location, AstName global, AstName index) + { + if (const LintContext::Global* gv = context->builtinGlobals.find(global)) + { + if (const TableType* tty = get(gv->type)) + { + auto prop = tty->props.find(index.value); + + if (prop != tty->props.end() && prop->second.deprecated) + report(location, prop->second, global.value, index.value); + } + } } void report(const Location& location, const Property& prop, const char* container, const char* field) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 0b760810..0552bec0 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -2653,6 +2653,14 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) intersectNormals(here, *negated); } } + else if (get(t)) + { + // HACK: Refinements sometimes intersect with ~any under the + // assumption that it is the same as any. + return true; + } + else if (auto nt = get(t)) + return intersectNormalWithTy(here, nt->ty); else { // TODO negated unions, intersections, table, and function. diff --git a/Analysis/src/Refinement.cpp b/Analysis/src/Refinement.cpp index 459379ad..a81063c7 100644 --- a/Analysis/src/Refinement.cpp +++ b/Analysis/src/Refinement.cpp @@ -4,6 +4,11 @@ namespace Luau { +RefinementId RefinementArena::variadic(const std::vector& refis) +{ + return NotNull{allocator.allocate(Variadic{refis})}; +} + RefinementId RefinementArena::negation(RefinementId refinement) { return NotNull{allocator.allocate(Negation{refinement})}; @@ -24,14 +29,9 @@ RefinementId RefinementArena::equivalence(RefinementId lhs, RefinementId rhs) return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; } -RefinementId RefinementArena::proposition(DefId def, TypeId discriminantTy) +RefinementId RefinementArena::proposition(BreadcrumbId breadcrumb, TypeId discriminantTy) { - return NotNull{allocator.allocate(Proposition{def, discriminantTy})}; -} - -RefinementId RefinementArena::variadic(const std::vector& refis) -{ - return NotNull{allocator.allocate(Variadic{refis})}; + return NotNull{allocator.allocate(Proposition{breadcrumb, discriminantTy})}; } } // namespace Luau diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 2f69f698..f15f8c4c 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -745,6 +745,7 @@ BuiltinTypes::BuiltinTypes() , functionType(arena->addType(Type{PrimitiveType{PrimitiveType::Function}, /*persistent*/ true})) , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}}, /*persistent*/ true})) , tableType(arena->addType(Type{PrimitiveType{PrimitiveType::Table}, /*persistent*/ true})) + , emptyTableType(arena->addType(Type{TableType{TableState::Sealed, TypeLevel{}, nullptr}, /*persistent*/ true})) , trueType(arena->addType(Type{SingletonType{BooleanSingleton{true}}, /*persistent*/ true})) , falseType(arena->addType(Type{SingletonType{BooleanSingleton{false}}, /*persistent*/ true})) , anyType(arena->addType(Type{AnyType{}, /*persistent*/ true})) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index f23fad78..aacfd729 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -19,7 +19,6 @@ #include -LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauDontReduceTypes) @@ -105,8 +104,6 @@ struct TypeChecker2 , sourceModule(sourceModule) , module(module) { - if (FFlag::DebugLuauLogSolverToJson) - LUAU_ASSERT(logger); } std::optional pushStack(AstNode* node) @@ -918,13 +915,9 @@ struct TypeChecker2 reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location); } - void visit(AstExprCall* call) + // Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes. + void visitCall(AstExprCall* call) { - visit(call->func, RValue); - - for (AstExpr* arg : call->args) - visit(arg, RValue); - TypeArena* arena = &testArena; Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; @@ -1099,6 +1092,16 @@ struct TypeChecker2 reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); } + void visit(AstExprCall* call) + { + visit(call->func, RValue); + + for (AstExpr* arg : call->args) + visit(arg, RValue); + + visitCall(call); + } + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) { visit(expr, RValue); @@ -1169,9 +1172,9 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back())) + if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back())) { - reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); + reportError(TypeMismatch{inferredArgTy, annotatedArgTy}, arg->location); } } @@ -1726,7 +1729,7 @@ struct TypeChecker2 } } - for (size_t i = packsProvided; i < packsProvided; ++i) + for (size_t i = packsProvided; i < packsRequired; ++i) { if (alias->typePackParams[i].defaultValue) { @@ -1948,7 +1951,7 @@ struct TypeChecker2 { module->errors.emplace_back(location, sourceModule->name, std::move(data)); - if (FFlag::DebugLuauLogSolverToJson) + if (logger) logger->captureTypeCheckError(module->errors.back()); } @@ -2053,8 +2056,8 @@ struct TypeChecker2 if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)) return true; - else if (tt->indexer && isPrim(tt->indexer->indexResultType, PrimitiveType::String)) - return tt->indexer->indexResultType; + else if (tt->indexer && isPrim(tt->indexer->indexType, PrimitiveType::String)) + return true; else return false; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index adca034c..6aa8e6ca 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -937,9 +937,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) TypeId right = nullptr; - Location loc = 0 == assign.values.size - ? assign.location - : i < assign.values.size ? assign.values.data[i]->location : assign.values.data[assign.values.size - 1]->location; + Location loc = 0 == assign.values.size ? assign.location + : i < assign.values.size ? assign.values.data[i]->location + : assign.values.data[assign.values.size - 1]->location; if (valueIter != valueEnd) { @@ -3170,7 +3170,8 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex property.location = expr.indexLocation; return theType; } - else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free)) + else if (FFlag::LuauDontExtendUnsealedRValueTables && + ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free)) { TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; @@ -3299,7 +3300,8 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex property.location = expr.index->location; return resultType; } - else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + else if (FFlag::LuauDontExtendUnsealedRValueTables && + ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) { TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; @@ -3321,7 +3323,8 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return resultType; } - else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + else if (FFlag::LuauDontExtendUnsealedRValueTables && + ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) { TypeId indexerType = freshType(exprTable->level); unify(indexType, indexerType, scope, expr.location); diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 2393829d..abafa9fb 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -434,6 +434,14 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) return std::nullopt; // error & T ~ error & T else if (get(right)) return std::nullopt; // T & error ~ T & error + else if (get(left)) + return std::nullopt; // *blocked* & T ~ *blocked* & T + else if (get(right)) + return std::nullopt; // T & *blocked* ~ T & *blocked* + else if (get(left)) + return std::nullopt; // *pending* & T ~ *pending* & T + else if (get(right)) + return std::nullopt; // T & *pending* ~ T & *pending* else if (auto ut = get(left)) return reduce(distribute(begin(ut), end(ut), &TypeReducer::intersectionType, right)); // (A | B) & T ~ (A & T) | (B & T) else if (get(right)) diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index f8f51bcf..e5029e58 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -117,7 +117,8 @@ std::pair> getParameterExtents(const TxnLog* log, return {minCount, minCount + optionalCount}; } -TypePack extendTypePack(TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length) +TypePack extendTypePack( + TypeArena& arena, NotNull builtinTypes, TypePackId pack, size_t length, std::vector> overrides) { TypePack result; @@ -179,11 +180,22 @@ TypePack extendTypePack(TypeArena& arena, NotNull builtinTypes, Ty TypePack newPack; newPack.tail = arena.freshTypePack(ftp->scope); - + size_t overridesIndex = 0; while (result.head.size() < length) { - newPack.head.push_back(arena.freshType(ftp->scope)); + TypeId t; + if (overridesIndex < overrides.size() && overrides[overridesIndex]) + { + t = *overrides[overridesIndex]; + } + else + { + t = arena.freshType(ftp->scope); + } + + newPack.head.push_back(t); result.head.push_back(newPack.head.back()); + overridesIndex++; } asMutable(pack)->ty.emplace(std::move(newPack)); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6364a5aa..aba64271 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -312,7 +312,7 @@ TypePackId Widen::operator()(TypePackId tp) return substitute(tp).value_or(tp); } -static std::optional hasUnificationTooComplex(const ErrorVec& errors) +std::optional hasUnificationTooComplex(const ErrorVec& errors) { auto isUnificationTooComplex = [](const TypeError& te) { return nullptr != get(te); @@ -375,7 +375,6 @@ Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope , variance(variance) , sharedState(*normalizer->sharedState) { - normalize = true; LUAU_ASSERT(sharedState.iceHandler); } @@ -561,6 +560,11 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); + else if (auto table = log.get(superTy); table && table->type == PrimitiveType::Table) + tryUnify(subTy, builtinTypes->emptyTableType, isFunctionCall, isIntersection); + else if (auto table = log.get(subTy); table && table->type == PrimitiveType::Table) + tryUnify(builtinTypes->emptyTableType, superTy, isFunctionCall, isIntersection); + else if (log.getMutable(superTy) && log.getMutable(subTy)) { tryUnifyTables(subTy, superTy, isIntersection); @@ -591,7 +595,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.get(superTy) || log.get(subTy)) tryUnifyNegations(subTy, superTy); - else if (FFlag::LuauUninhabitedSubAnything2 && !normalizer->isInhabited(subTy)) + else if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) { } @@ -1769,6 +1773,12 @@ struct Resetter void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { + if (isPrim(log.follow(subTy), PrimitiveType::Table)) + subTy = builtinTypes->emptyTableType; + + if (isPrim(log.follow(superTy), PrimitiveType::Table)) + superTy = builtinTypes->emptyTableType; + TypeId activeSubTy = subTy; TableType* superTable = log.getMutable(superTy); TableType* subTable = log.getMutable(subTy); @@ -2092,7 +2102,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) TypeId osubTy = subTy; TypeId osuperTy = superTy; - if (FFlag::LuauUninhabitedSubAnything2 && !normalizer->isInhabited(subTy)) + if (FFlag::LuauUninhabitedSubAnything2 && checkInhabited && !normalizer->isInhabited(subTy)) return; if (reversed) @@ -2682,6 +2692,7 @@ Unifier Unifier::makeChildUnifier() { Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; u.normalize = normalize; + u.checkInhabited = checkInhabited; u.useScopes = useScopes; return u; } diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 118b0679..dac3b95b 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauFixInterpStringMid, false) + namespace Luau { @@ -640,7 +642,8 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT } consume(); - Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset - 1); + Lexeme lexemeOutput(Location(start, position()), FFlag::LuauFixInterpStringMid ? formatType : Lexeme::InterpStringBegin, + &buffer[startOffset], offset - startOffset - 1); return lexemeOutput; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 4d61914f..4c347712 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1248,7 +1248,11 @@ std::pair Parser::parseReturnTypeAnnotation() { AstType* returnType = parseTypeAnnotation(result, innerBegin); - return {Location{location, returnType->location}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + // If parseTypeAnnotation parses nothing, then returnType->location.end only points at the last non-type-pack + // type to successfully parse. We need the span of the whole annotation. + Position endPos = result.size() == 1 ? location.end : returnType->location.end; + + return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; } return {location, AstTypeList{copy(result), varargAnnotation}}; @@ -2623,8 +2627,6 @@ AstExpr* Parser::parseInterpString() endLocation = currentLexeme.location; - Location startOfBrace = Location(endLocation.end, 1); - scratchData.assign(currentLexeme.data, currentLexeme.length); if (!Lexer::fixupQuotedString(scratchData)) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4255c7c2..6e15e5f8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,9 +14,11 @@ option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) option(LUAU_NATIVE "Enable support for native code generation" OFF) +cmake_policy(SET CMP0054 NEW) +cmake_policy(SET CMP0091 NEW) + if(LUAU_STATIC_CRT) cmake_minimum_required(VERSION 3.15) - cmake_policy(SET CMP0091 NEW) set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") endif() @@ -88,9 +90,15 @@ set(LUAU_OPTIONS) if(MSVC) list(APPEND LUAU_OPTIONS /D_CRT_SECURE_NO_WARNINGS) # We need to use the portable CRT functions. - list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores + list(APPEND LUAU_OPTIONS "/we4018") # Signed/unsigned mismatch + list(APPEND LUAU_OPTIONS "/we4388") # Also signed/unsigned mismatch else() list(APPEND LUAU_OPTIONS -Wall) # All warnings + list(APPEND LUAU_OPTIONS -Wsign-compare) # This looks to be included in -Wall for GCC but not clang +endif() + +if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores endif() if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") @@ -115,7 +123,7 @@ endif() set(ISOCLINE_OPTIONS) -if (NOT MSVC) +if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") list(APPEND ISOCLINE_OPTIONS -Wno-unused-function) endif() @@ -137,7 +145,7 @@ if(LUAU_NATIVE) target_compile_definitions(Luau.VM PUBLIC LUA_CUSTOM_EXECUTION=1) endif() -if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) +if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" AND MSVC_VERSION GREATER_EQUAL 1924) # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: # https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863 set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h index 53efd3c3..2c852046 100644 --- a/CodeGen/include/Luau/AddressA64.h +++ b/CodeGen/include/Luau/AddressA64.h @@ -7,6 +7,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ enum class AddressKindA64 : uint8_t { @@ -49,5 +51,6 @@ struct AddressA64 using mem = AddressA64; +} // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 9e12168a..94d8f811 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -13,6 +13,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ class AssemblyBuilderA64 { @@ -157,5 +159,6 @@ private: uint32_t* codeEnd = nullptr; }; +} // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 235f1a84..597f2b2c 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -14,6 +14,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ enum class RoundingModeX64 { @@ -242,5 +244,6 @@ private: uint8_t* codeEnd = nullptr; }; +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/ConditionA64.h b/CodeGen/include/Luau/ConditionA64.h index e208d8cb..0beadad5 100644 --- a/CodeGen/include/Luau/ConditionA64.h +++ b/CodeGen/include/Luau/ConditionA64.h @@ -5,6 +5,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ enum class ConditionA64 { @@ -33,5 +35,6 @@ enum class ConditionA64 Count }; +} // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 0941d475..d3e1a933 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -1,16 +1,26 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include + +#include + namespace Luau { namespace CodeGen { +struct IrBlock; struct IrFunction; void updateUseCounts(IrFunction& function); void updateLastUseLocations(IrFunction& function); +// Returns how many values are coming into the block (live in) and how many are coming out of the block (live out) +std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block); +uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block); +uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 29553421..916c6eed 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -27,6 +27,12 @@ struct IrBuilder bool isInternalBlock(IrOp block); void beginBlock(IrOp block); + void loadAndCheckTag(IrOp loc, uint8_t tag, IrOp fallback); + + // Clones all instructions into the current block + // Source block that is cloned cannot use values coming in from a predecessor + void clone(const IrBlock& source, bool removeCurrentTerminator); + IrOp constBool(bool value); IrOp constInt(int value); IrOp constUint(unsigned value); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 18d510cc..049d700a 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -108,6 +108,13 @@ enum class IrCmd : uint8_t MOD_NUM, POW_NUM, + // Get the minimum/maximum of two numbers + // If one of the values is NaN, 'B' is returned as the result + // A, B: double + // In final x64 lowering, B can also be Rn or Kn + MIN_NUM, + MAX_NUM, + // Negate a double number // A: double UNM_NUM, @@ -197,6 +204,29 @@ enum class IrCmd : uint8_t // This is used to recover after calling a variadic function ADJUST_STACK_TO_TOP, + // Execute fastcall builtin function in-place + // A: builtin + // B: Rn (result start) + // C: Rn (argument start) + // D: Rn or Kn or a boolean that's false (optional second argument) + // E: int (argument count or -1 to use all arguments up to stack top) + // F: int (result count or -1 to preserve all results and adjust stack top) + FASTCALL, + + // Call the fastcall builtin function + // A: builtin + // B: Rn (result start) + // C: Rn (argument start) + // D: Rn or Kn or a boolean that's false (optional second argument) + // E: int (argument count or -1 to use all arguments up to stack top) + // F: int (result count or -1 to preserve all results and adjust stack top) + INVOKE_FASTCALL, + + // Check that fastcall builtin function invocation was successful (negative result count jumps to fallback) + // A: int (result count) + // B: block (fallback) + CHECK_FASTCALL_RES, + // Fallback functions // Perform an arithmetic operation on TValues of any type @@ -351,39 +381,26 @@ enum class IrCmd : uint8_t // C: int (result count or -1 to return all values up to stack top) LOP_RETURN, - // Perform a fast call of a built-in function - // A: unsigned int (bytecode instruction index) - // B: Rn (argument start) - // C: int (argument count or -1 use all arguments up to stack top) - // D: block (fallback) - // Note: return values are placed starting from Rn specified in 'B' - LOP_FASTCALL, - - // Perform a fast call of a built-in function using 1 register argument - // A: unsigned int (bytecode instruction index) - // B: Rn (result start) - // C: Rn (arg1) - // D: block (fallback) - LOP_FASTCALL1, - - // Perform a fast call of a built-in function using 2 register arguments - // A: unsigned int (bytecode instruction index) - // B: Rn (result start) - // C: Rn (arg1) - // D: Rn (arg2) - // E: block (fallback) - LOP_FASTCALL2, - - // Perform a fast call of a built-in function using 1 register argument and 1 constant argument - // A: unsigned int (bytecode instruction index) - // B: Rn (result start) - // C: Rn (arg1) - // D: Kn (arg2) - // E: block (fallback) - LOP_FASTCALL2K, - + // Adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue + // A: Rn (loop variable start, updates Rn+2 Rn+3 Rn+4) + // B: int (loop variable count, is more than 2, additional registers are set to nil) + // C: block (repeat) + // D: block (exit) LOP_FORGLOOP, + + // Handle LOP_FORGLOOP fallback when variable being iterated is not a table + // A: unsigned int (bytecode instruction index) + // B: Rn (loop state start, updates Rn+2 Rn+3 Rn+4 Rn+5) + // C: int (extra variable count or -1 for ipairs-style iteration) + // D: block (repeat) + // E: block (exit) LOP_FORGLOOP_FALLBACK, + + // Fallback for generic for loop preparation when iterating over builtin pairs/ipairs + // It raises an error if 'B' register is not a function + // A: unsigned int (bytecode instruction index) + // B: Rn + // C: block (forgloop location) LOP_FORGPREP_XNEXT_FALLBACK, // Perform `and` or `or` operation (selecting lhs or rhs based on whether the lhs is truthy) and put the result into target register @@ -462,7 +479,7 @@ enum class IrCmd : uint8_t // Prepare loop variables for a generic for loop, jump to the loop backedge unconditionally // A: unsigned int (bytecode instruction index) - // B: Rn (loop state, updates Rn Rn+1 Rn+2) + // B: Rn (loop state start, updates Rn Rn+1 Rn+2) // C: block FALLBACK_FORGPREP, @@ -577,8 +594,8 @@ struct IrInst uint16_t useCount = 0; // Location of the result (optional) - RegisterX64 regX64 = noreg; - RegisterA64 regA64{KindA64::none, 0}; + X64::RegisterX64 regX64 = X64::noreg; + A64::RegisterA64 regA64 = A64::noreg; bool reusedReg = false; }; @@ -587,6 +604,7 @@ enum class IrBlockKind : uint8_t Bytecode, Fallback, Internal, + Linearized, Dead, }; diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 0a23b3f7..153cf7ad 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -133,6 +133,8 @@ inline bool hasResult(IrCmd cmd) case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: case IrCmd::NOT_ANY: case IrCmd::TABLE_LEN: @@ -141,6 +143,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::NUM_TO_INDEX: case IrCmd::INT_TO_NUM: case IrCmd::SUBSTITUTE: + case IrCmd::INVOKE_FASTCALL: return true; default: break; @@ -151,6 +154,9 @@ inline bool hasResult(IrCmd cmd) inline bool hasSideEffects(IrCmd cmd) { + if (cmd == IrCmd::INVOKE_FASTCALL) + return true; + // Instructions that don't produce a result most likely have other side-effects to make them useful // Right now, a full switch would mirror the 'hasResult' function, so we use this simple condition return !hasResult(cmd); @@ -164,6 +170,10 @@ inline bool isPseudo(IrCmd cmd) bool isGCO(uint8_t tag); +// Manually add or remove use of an operand +void addUse(IrFunction& function, IrOp op); +void removeUse(IrFunction& function, IrOp op); + // Remove a single instruction void kill(IrFunction& function, IrInst& inst); diff --git a/CodeGen/include/Luau/OperandX64.h b/CodeGen/include/Luau/OperandX64.h index 5ad38e90..b9aa8f54 100644 --- a/CodeGen/include/Luau/OperandX64.h +++ b/CodeGen/include/Luau/OperandX64.h @@ -10,6 +10,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ enum class CategoryX64 : uint8_t { @@ -138,5 +140,6 @@ constexpr OperandX64 operator+(RegisterX64 base, OperandX64 op) return op; } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index 2d56276f..519e83fc 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -9,6 +9,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ enum class KindA64 : uint8_t { @@ -33,6 +35,8 @@ struct RegisterA64 } }; +constexpr RegisterA64 noreg{KindA64::none, 0}; + constexpr RegisterA64 w0{KindA64::w, 0}; constexpr RegisterA64 w1{KindA64::w, 1}; constexpr RegisterA64 w2{KindA64::w, 2}; @@ -101,5 +105,6 @@ constexpr RegisterA64 xzr{KindA64::x, 31}; constexpr RegisterA64 sp{KindA64::none, 31}; +} // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/RegisterX64.h b/CodeGen/include/Luau/RegisterX64.h index adc2db0c..9d76b116 100644 --- a/CodeGen/include/Luau/RegisterX64.h +++ b/CodeGen/include/Luau/RegisterX64.h @@ -9,6 +9,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ enum class SizeX64 : uint8_t { @@ -133,5 +135,6 @@ constexpr RegisterX64 qwordReg(RegisterX64 reg) return RegisterX64{SizeX64::qword, reg.index}; } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index b7237318..98e60498 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -21,10 +21,10 @@ public: virtual void start() = 0; - virtual void spill(int espOffset, RegisterX64 reg) = 0; - virtual void save(RegisterX64 reg) = 0; + virtual void spill(int espOffset, X64::RegisterX64 reg) = 0; + virtual void save(X64::RegisterX64 reg) = 0; virtual void allocStack(int size) = 0; - virtual void setupFrameReg(RegisterX64 reg, int espOffset) = 0; + virtual void setupFrameReg(X64::RegisterX64 reg, int espOffset) = 0; virtual void finish() = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index dab6e957..972f7423 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -17,10 +17,10 @@ public: void start() override; - void spill(int espOffset, RegisterX64 reg) override; - void save(RegisterX64 reg) override; + void spill(int espOffset, X64::RegisterX64 reg) override; + void save(X64::RegisterX64 reg) override; void allocStack(int size) override; - void setupFrameReg(RegisterX64 reg, int espOffset) override; + void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; void finish() override; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index 00513771..1cd750a1 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -27,10 +27,10 @@ public: void start() override; - void spill(int espOffset, RegisterX64 reg) override; - void save(RegisterX64 reg) override; + void spill(int espOffset, X64::RegisterX64 reg) override; + void save(X64::RegisterX64 reg) override; void allocStack(int size) override; - void setupFrameReg(RegisterX64 reg, int espOffset) override; + void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; void finish() override; @@ -45,7 +45,7 @@ private: std::vector unwindCodes; uint8_t prologSize = 0; - RegisterX64 frameReg = rax; // rax means that frame register is not used + X64::RegisterX64 frameReg = X64::rax; // rax means that frame register is not used uint8_t frameRegOffset = 0; uint32_t stackOffset = 0; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index 286800d6..308747d2 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -9,6 +9,8 @@ namespace Luau { namespace CodeGen { +namespace A64 +{ static const uint8_t codeForCondition[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}; static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered"); @@ -719,5 +721,6 @@ void AssemblyBuilderA64::log(AddressA64 addr) text.append("]"); } +} // namespace A64 } // namespace CodeGen -} // namespace Luau \ No newline at end of file +} // namespace Luau diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 71bfaec1..bf7889b8 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -11,6 +11,9 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ + // TODO: more assertions on operand sizes static const uint8_t codeForCondition[] = { @@ -1475,5 +1478,6 @@ const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg) const return names[size_t(reg.size)][reg.index]; } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 5076cba2..51bf1746 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -41,7 +41,7 @@ namespace CodeGen constexpr uint32_t kFunctionAlignment = 32; -static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) +static void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) { if (build.logText) build.logAppend("; exitContinueVm\n"); @@ -59,7 +59,7 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) emitContinueCallInVm(build); } -static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { NativeProto* result = new NativeProto(); @@ -78,7 +78,7 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat build.logAppend("\n"); } - build.align(kFunctionAlignment, AlignmentDataX64::Ud2); + build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); Label start = build.setLabel(); @@ -92,7 +92,7 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat optimizeMemoryOperandsX64(builder.function); - IrLoweringX64 lowering(build, helpers, data, proto, builder.function); + X64::IrLoweringX64 lowering(build, helpers, data, proto, builder.function); lowering.lower(options); @@ -213,7 +213,7 @@ void create(lua_State* L) initFallbackTable(data); initHelperFunctions(data); - if (!x64::initEntryFunction(data)) + if (!X64::initEntryFunction(data)) { destroyNativeState(L); return; @@ -251,7 +251,7 @@ void compile(lua_State* L, int idx) if (!getNativeState(L)) return; - AssemblyBuilderX64 build(/* logText= */ false); + X64::AssemblyBuilderX64 build(/* logText= */ false); NativeState* data = getNativeState(L); std::vector protos; @@ -302,7 +302,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) LUAU_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); - AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); + X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); NativeState data; initFallbackTable(data); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index b23d2b38..ac6c9416 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -38,7 +38,7 @@ namespace Luau { namespace CodeGen { -namespace x64 +namespace X64 { bool initEntryFunction(NativeState& data) @@ -143,6 +143,6 @@ bool initEntryFunction(NativeState& data) return true; } -} // namespace x64 +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenX64.h b/CodeGen/src/CodeGenX64.h index 6791f7f3..b82266af 100644 --- a/CodeGen/src/CodeGenX64.h +++ b/CodeGen/src/CodeGenX64.h @@ -8,11 +8,11 @@ namespace CodeGen struct NativeState; -namespace x64 +namespace X64 { bool initEntryFunction(NativeState& data); -} // namespace x64 +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 0a3b3609..05b63551 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -5,7 +5,7 @@ #include "Luau/Bytecode.h" #include "EmitCommonX64.h" -#include "IrTranslateBuiltins.h" // Used temporarily for shared definition of BuiltinImplResult +#include "IrRegAllocX64.h" #include "NativeState.h" #include "lstate.h" @@ -16,343 +16,135 @@ namespace Luau { namespace CodeGen { - -BuiltinImplResult emitBuiltinMathFloor(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +namespace X64 { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - if (build.logText) - build.logAppend("; inlined LBF_MATH_FLOOR\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vroundsd(xmm0, xmm0, luauRegValue(arg), RoundingModeX64::RoundToNegativeInfinity); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; +void emitBuiltinMathFloor(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +{ + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + build.vroundsd(tmp.reg, tmp.reg, luauRegValue(arg), RoundingModeX64::RoundToNegativeInfinity); + build.vmovsd(luauRegValue(ra), tmp.reg); } -BuiltinImplResult emitBuiltinMathCeil(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathCeil(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_CEIL\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vroundsd(xmm0, xmm0, luauRegValue(arg), RoundingModeX64::RoundToPositiveInfinity); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + build.vroundsd(tmp.reg, tmp.reg, luauRegValue(arg), RoundingModeX64::RoundToPositiveInfinity); + build.vmovsd(luauRegValue(ra), tmp.reg); } -BuiltinImplResult emitBuiltinMathSqrt(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathSqrt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_SQRT\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vsqrtsd(xmm0, xmm0, luauRegValue(arg)); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + build.vsqrtsd(tmp.reg, tmp.reg, luauRegValue(arg)); + build.vmovsd(luauRegValue(ra), tmp.reg); } -BuiltinImplResult emitBuiltinMathAbs(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAbs(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_ABS\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vmovsd(xmm0, luauRegValue(arg)); - build.vandpd(xmm0, xmm0, build.i64(~(1LL << 63))); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + build.vmovsd(tmp.reg, luauRegValue(arg)); + build.vandpd(tmp.reg, tmp.reg, build.i64(~(1LL << 63))); + build.vmovsd(luauRegValue(ra), tmp.reg); } -static BuiltinImplResult emitBuiltinMathSingleArgFunc( - AssemblyBuilderX64& build, int nparams, int ra, int arg, int nresults, Label& fallback, const char* name, int32_t offset) +static void emitBuiltinMathSingleArgFunc(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int32_t offset) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined %s\n", name); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); build.call(qword[rNativeContext + offset]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathExp(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathExp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_EXP", offsetof(NativeContext, libm_exp)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_exp)); } -BuiltinImplResult emitBuiltinMathDeg(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathFmod(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_DEG\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - const double rpd = (3.14159265358979323846 / 180.0); - - build.vmovsd(xmm0, luauRegValue(arg)); - build.vdivsd(xmm0, xmm0, build.f64(rpd)); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; -} - -BuiltinImplResult emitBuiltinMathRad(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_RAD\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - const double rpd = (3.14159265358979323846 / 180.0); - - build.vmovsd(xmm0, luauRegValue(arg)); - build.vmulsd(xmm0, xmm0, build.f64(rpd)); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; -} - -BuiltinImplResult emitBuiltinMathFmod(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_FMOD\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); build.call(qword[rNativeContext + offsetof(NativeContext, libm_fmod)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathPow(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathPow(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_POW\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathMin(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAsin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams != 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_MIN\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - - build.vmovsd(xmm0, qword[args + offsetof(TValue, value)]); - build.vminsd(xmm0, xmm0, luauRegValue(arg)); - - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_asin)); } -BuiltinImplResult emitBuiltinMathMax(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathSin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams != 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_MAX\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - - build.vmovsd(xmm0, qword[args + offsetof(TValue, value)]); - build.vmaxsd(xmm0, xmm0, luauRegValue(arg)); - - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_sin)); } -BuiltinImplResult emitBuiltinMathAsin(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathSinh(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_ASIN", offsetof(NativeContext, libm_asin)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_sinh)); } -BuiltinImplResult emitBuiltinMathSin(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAcos(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_SIN", offsetof(NativeContext, libm_sin)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_acos)); } -BuiltinImplResult emitBuiltinMathSinh(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathCos(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_SINH", offsetof(NativeContext, libm_sinh)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_cos)); } -BuiltinImplResult emitBuiltinMathAcos(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathCosh(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_ACOS", offsetof(NativeContext, libm_acos)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_cosh)); } -BuiltinImplResult emitBuiltinMathCos(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAtan(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_COS", offsetof(NativeContext, libm_cos)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_atan)); } -BuiltinImplResult emitBuiltinMathCosh(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathTan(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_COSH", offsetof(NativeContext, libm_cosh)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_tan)); } -BuiltinImplResult emitBuiltinMathAtan(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathTanh(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_ATAN", offsetof(NativeContext, libm_atan)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_tanh)); } -BuiltinImplResult emitBuiltinMathTan(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathAtan2(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_TAN", offsetof(NativeContext, libm_tan)); -} - -BuiltinImplResult emitBuiltinMathTanh(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_TANH", offsetof(NativeContext, libm_tanh)); -} - -BuiltinImplResult emitBuiltinMathAtan2(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_ATAN2\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); build.vmovsd(xmm1, qword[args + offsetof(TValue, value)]); build.call(qword[rNativeContext + offsetof(NativeContext, libm_atan2)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathLog10(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathLog10(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - return emitBuiltinMathSingleArgFunc(build, nparams, ra, arg, nresults, fallback, "LBF_MATH_LOG10", offsetof(NativeContext, libm_log10)); + emitBuiltinMathSingleArgFunc(regs, build, ra, arg, offsetof(NativeContext, libm_log10)); } -BuiltinImplResult emitBuiltinMathLog(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathLog(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_LOG\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); if (nparams == 1) @@ -367,19 +159,15 @@ BuiltinImplResult emitBuiltinMathLog(AssemblyBuilderX64& build, int nparams, int RegisterX64 tmp = rbx; OperandX64 arg2value = qword[args + offsetof(TValue, value)]; - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - build.vmovsd(xmm1, arg2value); - jumpOnNumberCmp(build, noreg, build.f64(2.0), xmm1, ConditionX64::NotEqual, log10check); + jumpOnNumberCmp(build, noreg, build.f64(2.0), xmm1, IrCondition::NotEqual, log10check); build.call(qword[rNativeContext + offsetof(NativeContext, libm_log2)]); build.jmp(exit); build.setLabel(log10check); - jumpOnNumberCmp(build, noreg, build.f64(10.0), xmm1, ConditionX64::NotEqual, logdivlog); + jumpOnNumberCmp(build, noreg, build.f64(10.0), xmm1, IrCondition::NotEqual, logdivlog); build.call(qword[rNativeContext + offsetof(NativeContext, libm_log10)]); build.jmp(exit); @@ -402,28 +190,11 @@ BuiltinImplResult emitBuiltinMathLog(AssemblyBuilderX64& build, int nparams, int } build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } - -BuiltinImplResult emitBuiltinMathLdexp(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_LDEXP\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); if (build.abi == ABIX64::Windows) @@ -434,48 +205,27 @@ BuiltinImplResult emitBuiltinMathLdexp(AssemblyBuilderX64& build, int nparams, i build.call(qword[rNativeContext + offsetof(NativeContext, libm_ldexp)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult emitBuiltinMathRound(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathRound(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; + ScopedRegX64 tmp0{regs, SizeX64::xmmword}; + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; - if (build.logText) - build.logAppend("; inlined LBF_MATH_ROUND\n"); + build.vmovsd(tmp0.reg, luauRegValue(arg)); + build.vandpd(tmp1.reg, tmp0.reg, build.f64x2(-0.0, -0.0)); + build.vmovsd(tmp2.reg, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 + build.vorpd(tmp1.reg, tmp1.reg, tmp2.reg); + build.vaddsd(tmp0.reg, tmp0.reg, tmp1.reg); + build.vroundsd(tmp0.reg, tmp0.reg, tmp0.reg, RoundingModeX64::RoundToZero); - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vmovsd(xmm0, luauRegValue(arg)); - build.vandpd(xmm1, xmm0, build.f64x2(-0.0, -0.0)); - build.vmovsd(xmm2, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 - build.vorpd(xmm1, xmm1, xmm2); - build.vaddsd(xmm0, xmm0, xmm1); - build.vroundsd(xmm0, xmm0, xmm0, RoundingModeX64::RoundToZero); - - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + build.vmovsd(luauRegValue(ra), tmp0.reg); } -BuiltinImplResult emitBuiltinMathFrexp(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 2) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_FREXP\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); if (build.abi == ABIX64::Windows) @@ -487,26 +237,13 @@ BuiltinImplResult emitBuiltinMathFrexp(AssemblyBuilderX64& build, int nparams, i build.vmovsd(luauRegValue(ra), xmm0); - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra + 1), xmm0); - build.mov(luauRegTag(ra + 1), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 2}; } -BuiltinImplResult emitBuiltinMathModf(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 2) - return {BuiltinImplType::None, -1}; - - if (build.logText) - build.logAppend("; inlined LBF_MATH_MODF\n"); - - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - + regs.assertAllFree(); build.vmovsd(xmm0, luauRegValue(arg)); if (build.abi == ABIX64::Windows) @@ -519,156 +256,109 @@ BuiltinImplResult emitBuiltinMathModf(AssemblyBuilderX64& build, int nparams, in build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - build.vmovsd(luauRegValue(ra + 1), xmm0); - build.mov(luauRegTag(ra + 1), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 2}; } -BuiltinImplResult emitBuiltinMathSign(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) { - if (nparams < 1 || nresults > 1) - return {BuiltinImplType::None, -1}; + ScopedRegX64 tmp0{regs, SizeX64::xmmword}; + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; + ScopedRegX64 tmp3{regs, SizeX64::xmmword}; - if (build.logText) - build.logAppend("; inlined LBF_MATH_SIGN\n"); + build.vmovsd(tmp0.reg, luauRegValue(arg)); + build.vxorpd(tmp1.reg, tmp1.reg, tmp1.reg); - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - build.vmovsd(xmm0, luauRegValue(arg)); - build.vxorpd(xmm1, xmm1, xmm1); - - // Set xmm2 to -1 if arg < 0, else 0 - build.vcmpltsd(xmm2, xmm0, xmm1); - build.vmovsd(xmm3, build.f64(-1)); - build.vandpd(xmm2, xmm2, xmm3); + // Set tmp2 to -1 if arg < 0, else 0 + build.vcmpltsd(tmp2.reg, tmp0.reg, tmp1.reg); + build.vmovsd(tmp3.reg, build.f64(-1)); + build.vandpd(tmp2.reg, tmp2.reg, tmp3.reg); // Set mask bit to 1 if 0 < arg, else 0 - build.vcmpltsd(xmm0, xmm1, xmm0); + build.vcmpltsd(tmp0.reg, tmp1.reg, tmp0.reg); - // Result = (mask-bit == 1) ? 1.0 : xmm2 - // If arg < 0 then xmm2 is -1 and mask-bit is 0, result is -1 - // If arg == 0 then xmm2 is 0 and mask-bit is 0, result is 0 - // If arg > 0 then xmm2 is 0 and mask-bit is 1, result is 1 - build.vblendvpd(xmm0, xmm2, build.f64x2(1, 1), xmm0); + // Result = (mask-bit == 1) ? 1.0 : tmp2 + // If arg < 0 then tmp2 is -1 and mask-bit is 0, result is -1 + // If arg == 0 then tmp2 is 0 and mask-bit is 0, result is 0 + // If arg > 0 then tmp2 is 0 and mask-bit is 1, result is 1 + build.vblendvpd(tmp0.reg, tmp2.reg, build.f64x2(1, 1), tmp0.reg); - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; + build.vmovsd(luauRegValue(ra), tmp0.reg); } -BuiltinImplResult emitBuiltinMathClamp(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults) { - if (nparams < 3 || nresults > 1) - return {BuiltinImplType::None, -1}; + OperandX64 argsOp = 0; - if (build.logText) - build.logAppend("; inlined LBF_MATH_CLAMP\n"); + if (args.kind == IrOpKind::VmReg) + argsOp = luauRegAddress(args.index); + else if (args.kind == IrOpKind::VmConst) + argsOp = luauConstantAddress(args.index); - jumpIfTagIsNot(build, arg, LUA_TNUMBER, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - - // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though - build.cmp(dword[args + sizeof(TValue) + offsetof(TValue, tt)], LUA_TNUMBER); - build.jcc(ConditionX64::NotEqual, fallback); - - RegisterX64 min = xmm1; - RegisterX64 max = xmm2; - build.vmovsd(min, qword[args + offsetof(TValue, value)]); - build.vmovsd(max, qword[args + sizeof(TValue) + offsetof(TValue, value)]); - - jumpOnNumberCmp(build, noreg, min, max, ConditionX64::NotLessEqual, fallback); - - build.vmaxsd(xmm0, min, luauRegValue(arg)); - build.vminsd(xmm0, max, xmm0); - - build.vmovsd(luauRegValue(ra), xmm0); - - if (ra != arg) - build.mov(luauRegTag(ra), LUA_TNUMBER); - - return {BuiltinImplType::UsesFallback, 1}; -} - - -BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback) -{ switch (bfid) { case LBF_ASSERT: - // This builtin fast-path was already translated to IR - return {BuiltinImplType::None, -1}; - case LBF_MATH_FLOOR: - return emitBuiltinMathFloor(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_CEIL: - return emitBuiltinMathCeil(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_SQRT: - return emitBuiltinMathSqrt(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_ABS: - return emitBuiltinMathAbs(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_EXP: - return emitBuiltinMathExp(build, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_DEG: - return emitBuiltinMathDeg(build, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_RAD: - return emitBuiltinMathRad(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_FMOD: - return emitBuiltinMathFmod(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_POW: - return emitBuiltinMathPow(build, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_MIN: - return emitBuiltinMathMin(build, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_MAX: - return emitBuiltinMathMax(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_ASIN: - return emitBuiltinMathAsin(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_SIN: - return emitBuiltinMathSin(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_SINH: - return emitBuiltinMathSinh(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_ACOS: - return emitBuiltinMathAcos(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_COS: - return emitBuiltinMathCos(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_COSH: - return emitBuiltinMathCosh(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_ATAN: - return emitBuiltinMathAtan(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_TAN: - return emitBuiltinMathTan(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_TANH: - return emitBuiltinMathTanh(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_ATAN2: - return emitBuiltinMathAtan2(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_LOG10: - return emitBuiltinMathLog10(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_LOG: - return emitBuiltinMathLog(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_LDEXP: - return emitBuiltinMathLdexp(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_ROUND: - return emitBuiltinMathRound(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_FREXP: - return emitBuiltinMathFrexp(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_MODF: - return emitBuiltinMathModf(build, nparams, ra, arg, args, nresults, fallback); - case LBF_MATH_SIGN: - return emitBuiltinMathSign(build, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_CLAMP: - return emitBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback); + // These instructions are fully translated to IR + break; + case LBF_MATH_FLOOR: + return emitBuiltinMathFloor(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_CEIL: + return emitBuiltinMathCeil(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_SQRT: + return emitBuiltinMathSqrt(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_ABS: + return emitBuiltinMathAbs(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_EXP: + return emitBuiltinMathExp(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_FMOD: + return emitBuiltinMathFmod(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_POW: + return emitBuiltinMathPow(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_ASIN: + return emitBuiltinMathAsin(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_SIN: + return emitBuiltinMathSin(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_SINH: + return emitBuiltinMathSinh(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_ACOS: + return emitBuiltinMathAcos(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_COS: + return emitBuiltinMathCos(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_COSH: + return emitBuiltinMathCosh(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_ATAN: + return emitBuiltinMathAtan(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_TAN: + return emitBuiltinMathTan(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_TANH: + return emitBuiltinMathTanh(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_ATAN2: + return emitBuiltinMathAtan2(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_LOG10: + return emitBuiltinMathLog10(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_LOG: + return emitBuiltinMathLog(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_LDEXP: + return emitBuiltinMathLdexp(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_ROUND: + return emitBuiltinMathRound(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_FREXP: + return emitBuiltinMathFrexp(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_MODF: + return emitBuiltinMathModf(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_MATH_SIGN: + return emitBuiltinMathSign(regs, build, nparams, ra, arg, argsOp, nresults); default: - return {BuiltinImplType::None, -1}; + LUAU_ASSERT(!"missing x64 lowering"); + break; } } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitBuiltinsX64.h b/CodeGen/src/EmitBuiltinsX64.h index 1ee04e92..5925a2b3 100644 --- a/CodeGen/src/EmitBuiltinsX64.h +++ b/CodeGen/src/EmitBuiltinsX64.h @@ -6,12 +6,18 @@ namespace Luau namespace CodeGen { -class AssemblyBuilderX64; struct Label; +struct IrOp; + +namespace X64 +{ + +class AssemblyBuilderX64; struct OperandX64; -struct BuiltinImplResult; +struct IrRegAllocX64; -BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback); +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults); +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h new file mode 100644 index 00000000..3c41c271 --- /dev/null +++ b/CodeGen/src/EmitCommon.h @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Label.h" + +namespace Luau +{ +namespace CodeGen +{ + +constexpr unsigned kTValueSizeLog2 = 4; +constexpr unsigned kLuaNodeSizeLog2 = 5; +constexpr unsigned kLuaNodeTagMask = 0xf; +constexpr unsigned kNextBitOffset = 4; + +constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field +constexpr unsigned kOffsetOfLuaNodeNext = 12; // offsetof cannot be used on a bit field +constexpr unsigned kOffsetOfInstructionC = 3; + +// Leaf functions that are placed in every module to perform common instruction sequences +struct ModuleHelpers +{ + Label exitContinueVm; + Label exitNoContinueVm; + Label continueCallInVm; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 7d36e17d..e9cfdc48 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -2,6 +2,7 @@ #include "EmitCommonX64.h" #include "Luau/AssemblyBuilderX64.h" +#include "Luau/IrData.h" #include "CustomExecUtils.h" #include "NativeState.h" @@ -13,8 +14,10 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ -void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label) +void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label) { // Refresher on comi/ucomi EFLAGS: // CF only: less @@ -35,23 +38,23 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, // And because of NaN, integer check interchangeability like 'not less or equal' <-> 'greater' does not hold switch (cond) { - case ConditionX64::NotLessEqual: + case IrCondition::NotLessEqual: // (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN build.jcc(ConditionX64::NotAboveEqual, label); break; - case ConditionX64::LessEqual: + case IrCondition::LessEqual: // (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN build.jcc(ConditionX64::AboveEqual, label); break; - case ConditionX64::NotLess: + case IrCondition::NotLess: // (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN build.jcc(ConditionX64::NotAbove, label); break; - case ConditionX64::Less: + case IrCondition::Less: // (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN build.jcc(ConditionX64::Above, label); break; - case ConditionX64::NotEqual: + case IrCondition::NotEqual: // ZF=0 or PF=1 means != or NaN build.jcc(ConditionX64::NotZero, label); build.jcc(ConditionX64::Parity, label); @@ -61,25 +64,25 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, } } -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label) +void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label) { build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); build.lea(rArg3, luauRegAddress(rb)); - if (cond == ConditionX64::NotLessEqual || cond == ConditionX64::LessEqual) + if (cond == IrCondition::NotLessEqual || cond == IrCondition::LessEqual) build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]); - else if (cond == ConditionX64::NotLess || cond == ConditionX64::Less) + else if (cond == IrCondition::NotLess || cond == IrCondition::Less) build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]); - else if (cond == ConditionX64::NotEqual || cond == ConditionX64::Equal) + else if (cond == IrCondition::NotEqual || cond == IrCondition::Equal) build.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]); else LUAU_ASSERT(!"Unsupported condition"); emitUpdateBase(build); build.test(eax, eax); - build.jcc(cond == ConditionX64::NotLessEqual || cond == ConditionX64::NotLess || cond == ConditionX64::NotEqual ? ConditionX64::Zero - : ConditionX64::NotZero, + build.jcc(cond == IrCondition::NotLessEqual || cond == IrCondition::NotLess || cond == IrCondition::NotEqual ? ConditionX64::Zero + : ConditionX64::NotZero, label); } @@ -377,5 +380,6 @@ void emitContinueCallInVm(AssemblyBuilderX64& build) emitExit(build, /* continueInVm */ true); } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 8d6e36d6..6b676255 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -3,6 +3,8 @@ #include "Luau/AssemblyBuilderX64.h" +#include "EmitCommon.h" + #include "lobject.h" #include "ltm.h" @@ -23,8 +25,12 @@ namespace Luau namespace CodeGen { +enum class IrCondition : uint8_t; struct NativeState; +namespace X64 +{ + // Data that is very common to access is placed in non-volatile registers constexpr RegisterX64 rState = r15; // lua_State* L constexpr RegisterX64 rBase = r14; // StkId base @@ -65,23 +71,6 @@ constexpr OperandX64 sArg6 = noreg; #endif -constexpr unsigned kTValueSizeLog2 = 4; -constexpr unsigned kLuaNodeSizeLog2 = 5; -constexpr unsigned kLuaNodeTagMask = 0xf; -constexpr unsigned kNextBitOffset = 4; - -constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field -constexpr unsigned kOffsetOfLuaNodeNext = 12; // offsetof cannot be used on a bit field -constexpr unsigned kOffsetOfInstructionC = 3; - -// Leaf functions that are placed in every module to perform common instruction sequences -struct ModuleHelpers -{ - Label exitContinueVm; - Label exitNoContinueVm; - Label continueCallInVm; -}; - inline OperandX64 luauReg(int ri) { return xmmword[rBase + ri * sizeof(TValue)]; @@ -243,8 +232,8 @@ inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX6 jumpIfNodeValueTagIs(build, node, LUA_TNIL, label); } -void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label); -void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label); +void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label); +void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, IrCondition cond, Label& label); void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); @@ -268,5 +257,6 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo void emitContinueCallInVm(AssemblyBuilderX64& build); +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 97ff9f59..3b0aa258 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -7,7 +7,6 @@ #include "EmitBuiltinsX64.h" #include "EmitCommonX64.h" #include "NativeState.h" -#include "IrTranslateBuiltins.h" // Used temporarily until emitInstFastCallN is removed #include "lobject.h" #include "ltm.h" @@ -16,6 +15,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback) { @@ -481,137 +482,12 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne callBarrierTableFast(build, table, next); } -static void emitInstFastCallN( - AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs, int pcpos, Label& fallback) +void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit) { - int bfid = LUAU_INSN_A(*pc); - int skip = LUAU_INSN_C(*pc); + // ipairs-style traversal is handled in IR + LUAU_ASSERT(aux >= 0); - Instruction call = pc[skip + 1]; - LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); - int ra = LUAU_INSN_A(call); - - int nparams = customParams ? customParamCount : LUAU_INSN_B(call) - 1; - int nresults = LUAU_INSN_C(call) - 1; - int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1; - OperandX64 args = customParams ? customArgs : luauRegAddress(ra + 2); - - BuiltinImplResult br = emitBuiltin(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); - - if (br.type == BuiltinImplType::UsesFallback) - { - if (nresults == LUA_MULTRET) - { - // L->top = ra + n; - build.lea(rax, addr[rBase + (ra + br.actualResultCount) * sizeof(TValue)]); - build.mov(qword[rState + offsetof(lua_State, top)], rax); - } - else if (nparams == LUA_MULTRET) - { - // L->top = L->ci->top; - build.mov(rax, qword[rState + offsetof(lua_State, ci)]); - build.mov(rax, qword[rax + offsetof(CallInfo, top)]); - build.mov(qword[rState + offsetof(lua_State, top)], rax); - } - - return; - } - - // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline - emitSetSavedPc(build, pcpos + 1); // uses rax/rdx - - build.mov(rax, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); - - // 5th parameter (args) is left unset for LOP_FASTCALL1 - if (args.cat == CategoryX64::mem) - { - if (build.abi == ABIX64::Windows) - { - build.lea(rcx, args); - build.mov(sArg5, rcx); - } - else - { - build.lea(rArg5, args); - } - } - - if (nparams == LUA_MULTRET) - { - // L->top - (ra + 1) - RegisterX64 reg = (build.abi == ABIX64::Windows) ? rcx : rArg6; - build.mov(reg, qword[rState + offsetof(lua_State, top)]); - build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]); - build.sub(reg, rdx); - build.shr(reg, kTValueSizeLog2); - - if (build.abi == ABIX64::Windows) - build.mov(sArg6, reg); - } - else - { - if (build.abi == ABIX64::Windows) - build.mov(sArg6, nparams); - else - build.mov(rArg6, nparams); - } - - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(arg)); - build.mov(dwordReg(rArg4), nresults); - - build.call(rax); - - build.test(eax, eax); // test here will set SF=1 for a negative number and it always sets OF to 0 - build.jcc(ConditionX64::Less, fallback); // jl jumps if SF != OF - - if (nresults == LUA_MULTRET) - { - // L->top = ra + n; - build.shl(rax, kTValueSizeLog2); - build.lea(rax, addr[rBase + rax + ra * sizeof(TValue)]); - build.mov(qword[rState + offsetof(lua_State, top)], rax); - } - else if (nparams == LUA_MULTRET) - { - // L->top = L->ci->top; - build.mov(rax, qword[rState + offsetof(lua_State, ci)]); - build.mov(rax, qword[rax + offsetof(CallInfo, top)]); - build.mov(qword[rState + offsetof(lua_State, top)], rax); - } -} - -void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) -{ - return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 1, /* customArgs */ 0, pcpos, fallback); -} - -void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) -{ - return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegAddress(pc[1]), pcpos, fallback); -} - -void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) -{ - return emitInstFastCallN( - build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantAddress(pc[1]), pcpos, fallback); -} - -void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) -{ - return emitInstFastCallN(build, pc, /* customParams */ false, /* customParamCount */ 0, /* customArgs */ 0, pcpos, fallback); -} - -void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback) -{ - int ra = LUAU_INSN_A(*pc); - int aux = pc[1]; - - emitInterrupt(build, pcpos); - - // fast-path: builtin table iteration - jumpIfTagIsNot(build, ra, LUA_TNIL, fallback); + // This is a fast-path for builtin table iteration, tag check for 'ra' has to be performed before emitting this instruction // Registers are chosen in this way to simplify fallback code for the node part RegisterX64 table = rArg2; @@ -630,22 +506,19 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo for (int i = 2; i < aux; ++i) build.mov(luauRegTag(ra + 3 + i), LUA_TNIL); - // ipairs-style traversal is terminated early when array part ends of nil array element is encountered - bool isIpairsIter = aux < 0; - Label skipArray, skipArrayNil; // First we advance index through the array portion // while (unsigned(index) < unsigned(sizearray)) Label arrayLoop = build.setLabel(); build.cmp(dwordReg(index), dword[table + offsetof(Table, sizearray)]); - build.jcc(ConditionX64::NotBelow, isIpairsIter ? loopExit : skipArray); + build.jcc(ConditionX64::NotBelow, skipArray); // If element is nil, we increment the index; if it's not, we still need 'index + 1' inside build.inc(index); build.cmp(dword[elemPtr + offsetof(TValue, tt)], LUA_TNIL); - build.jcc(ConditionX64::Equal, isIpairsIter ? loopExit : skipArrayNil); + build.jcc(ConditionX64::Equal, skipArrayNil); // setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); build.mov(luauRegValue(ra + 2), index); @@ -661,31 +534,25 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo build.jmp(loopRepeat); - if (!isIpairsIter) - { - build.setLabel(skipArrayNil); + build.setLabel(skipArrayNil); - // Index already incremented, advance to next array element - build.add(elemPtr, sizeof(TValue)); - build.jmp(arrayLoop); + // Index already incremented, advance to next array element + build.add(elemPtr, sizeof(TValue)); + build.jmp(arrayLoop); - build.setLabel(skipArray); + build.setLabel(skipArray); - // Call helper to assign next node value or to signal loop exit - build.mov(rArg1, rState); - // rArg2 and rArg3 are already set - build.lea(rArg4, luauRegAddress(ra)); - build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNodeIter)]); - build.test(al, al); - build.jcc(ConditionX64::NotZero, loopRepeat); - } + // Call helper to assign next node value or to signal loop exit + build.mov(rArg1, rState); + // rArg2 and rArg3 are already set + build.lea(rArg4, luauRegAddress(ra)); + build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNodeIter)]); + build.test(al, al); + build.jcc(ConditionX64::NotZero, loopRepeat); } -void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat) +void emitinstForGLoopFallback(AssemblyBuilderX64& build, int pcpos, int ra, int aux, Label& loopRepeat) { - int ra = LUAU_INSN_A(*pc); - int aux = pc[1]; - emitSetSavedPc(build, pcpos + 1); build.mov(rArg1, rState); @@ -697,10 +564,8 @@ void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, build.jcc(ConditionX64::NotZero, loopRepeat); } -void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target) +void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target) { - int ra = LUAU_INSN_A(*pc); - build.mov(rArg1, rState); build.lea(rArg2, luauRegAddress(ra)); build.mov(dwordReg(rArg3), pcpos + 1); @@ -836,5 +701,6 @@ void emitInstCoverage(AssemblyBuilderX64& build, int pcpos) build.mov(dword[rcx], eax); } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index 5fbfb56d..dcca52ab 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -13,21 +13,21 @@ namespace Luau namespace CodeGen { -class AssemblyBuilderX64; struct Label; struct ModuleHelpers; +namespace X64 +{ + +class AssemblyBuilderX64; + void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback); void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next); -void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback); -void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback); -void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat); -void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target); +void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit); +void emitinstForGLoopFallback(AssemblyBuilderX64& build, int pcpos, int ra, int aux, Label& loopRepeat); +void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target); void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc); void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc); @@ -35,5 +35,6 @@ void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux); void emitInstCoverage(AssemblyBuilderX64& build, int pcpos); +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index b494f2af..aa3e19f7 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrAnalysis.h" +#include "Luau/DenseHash.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" @@ -73,5 +74,47 @@ void updateLastUseLocations(IrFunction& function) } } +std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block) +{ + uint32_t liveIns = 0; + uint32_t liveOuts = 0; + + auto checkOp = [&](IrOp op) { + if (op.kind == IrOpKind::Inst) + { + if (op.index >= block.start && op.index <= block.finish) + liveOuts--; + else + liveIns++; + } + }; + + for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) + { + IrInst& inst = function.instructions[instIdx]; + + liveOuts += inst.useCount; + + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + checkOp(inst.f); + } + + return std::make_pair(liveIns, liveOuts); +} + +uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block) +{ + return getLiveInOutValueCount(function, block).first; +} + +uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block) +{ + return getLiveInOutValueCount(function, block).second; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 2e7c75d1..056ea600 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -2,6 +2,7 @@ #include "Luau/IrBuilder.h" #include "Luau/Common.h" +#include "Luau/DenseHash.h" #include "Luau/IrAnalysis.h" #include "Luau/IrUtils.h" @@ -271,7 +272,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, false, 0, {}, next, IrCmd::LOP_FASTCALL); + translateFastCallN(*this, pc, i, false, 0, {}, next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -282,7 +283,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, true, 1, constBool(false), next, IrCmd::LOP_FASTCALL1); + translateFastCallN(*this, pc, i, true, 1, constBool(false), next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -293,7 +294,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), next, IrCmd::LOP_FASTCALL2); + translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -304,7 +305,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), next, IrCmd::LOP_FASTCALL2K); + translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -318,21 +319,28 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) break; case LOP_FORGLOOP: { + int aux = int(pc[1]); + // We have a translation for ipairs-style traversal, general loop iteration is still too complex - if (int(pc[1]) < 0) + if (aux < 0) { translateInstForGLoopIpairs(*this, pc, i); } else { + int ra = LUAU_INSN_A(*pc); + IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORGLOOP)); IrOp fallback = block(IrBlockKind::Fallback); - inst(IrCmd::LOP_FORGLOOP, constUint(i), loopRepeat, loopExit, fallback); + inst(IrCmd::INTERRUPT, constUint(i)); + loadAndCheckTag(vmReg(ra), LUA_TNIL, fallback); + + inst(IrCmd::LOP_FORGLOOP, vmReg(ra), constInt(aux), loopRepeat, loopExit); beginBlock(fallback); - inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), loopRepeat, loopExit); + inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), vmReg(ra), constInt(aux), loopRepeat, loopExit); beginBlock(loopExit); } @@ -426,6 +434,68 @@ void IrBuilder::beginBlock(IrOp block) inTerminatedBlock = false; } +void IrBuilder::loadAndCheckTag(IrOp loc, uint8_t tag, IrOp fallback) +{ + inst(IrCmd::CHECK_TAG, inst(IrCmd::LOAD_TAG, loc), constTag(tag), fallback); +} + +void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) +{ + DenseHashMap instRedir{~0u}; + + auto redirect = [&instRedir](IrOp& op) { + if (op.kind == IrOpKind::Inst) + { + if (const uint32_t* newIndex = instRedir.find(op.index)) + op.index = *newIndex; + else + LUAU_ASSERT(!"values can only be used if they are defined in the same block"); + } + }; + + if (removeCurrentTerminator && inTerminatedBlock) + { + IrBlock& active = function.blocks[activeBlockIdx]; + IrInst& term = function.instructions[active.finish]; + + kill(function, term); + inTerminatedBlock = false; + } + + for (uint32_t index = source.start; index <= source.finish; index++) + { + LUAU_ASSERT(index < function.instructions.size()); + IrInst clone = function.instructions[index]; + + // Skip pseudo instructions to make clone more compact, but validate that they have no users + if (isPseudo(clone.cmd)) + { + LUAU_ASSERT(clone.useCount == 0); + continue; + } + + redirect(clone.a); + redirect(clone.b); + redirect(clone.c); + redirect(clone.d); + redirect(clone.e); + redirect(clone.f); + + addUse(function, clone.a); + addUse(function, clone.b); + addUse(function, clone.c); + addUse(function, clone.d); + addUse(function, clone.e); + addUse(function, clone.f); + + // Instructions that referenced the original will have to be adjusted to use the clone + instRedir[index] = uint32_t(function.instructions.size()); + + // Reconstruct the fresh clone + inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f); + } +} + IrOp IrBuilder::constBool(bool value) { IrConst constant; diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 681de286..cb203f7a 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -118,6 +118,10 @@ const char* getCmdName(IrCmd cmd) return "MOD_NUM"; case IrCmd::POW_NUM: return "POW_NUM"; + case IrCmd::MIN_NUM: + return "MIN_NUM"; + case IrCmd::MAX_NUM: + return "MAX_NUM"; case IrCmd::UNM_NUM: return "UNM_NUM"; case IrCmd::NOT_ANY: @@ -152,6 +156,12 @@ const char* getCmdName(IrCmd cmd) return "ADJUST_STACK_TO_REG"; case IrCmd::ADJUST_STACK_TO_TOP: return "ADJUST_STACK_TO_TOP"; + case IrCmd::FASTCALL: + return "FASTCALL"; + case IrCmd::INVOKE_FASTCALL: + return "INVOKE_FASTCALL"; + case IrCmd::CHECK_FASTCALL_RES: + return "CHECK_FASTCALL_RES"; case IrCmd::DO_ARITH: return "DO_ARITH"; case IrCmd::DO_LEN: @@ -206,14 +216,6 @@ const char* getCmdName(IrCmd cmd) return "LOP_CALL"; case IrCmd::LOP_RETURN: return "LOP_RETURN"; - case IrCmd::LOP_FASTCALL: - return "LOP_FASTCALL"; - case IrCmd::LOP_FASTCALL1: - return "LOP_FASTCALL1"; - case IrCmd::LOP_FASTCALL2: - return "LOP_FASTCALL2"; - case IrCmd::LOP_FASTCALL2K: - return "LOP_FASTCALL2K"; case IrCmd::LOP_FORGLOOP: return "LOP_FORGLOOP"; case IrCmd::LOP_FORGLOOP_FALLBACK: @@ -267,6 +269,8 @@ const char* getBlockKindName(IrBlockKind kind) return "bb_fallback"; case IrBlockKind::Internal: return "bb"; + case IrBlockKind::Linearized: + return "bb_linear"; case IrBlockKind::Dead: return "dead"; } diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index d81240ff..38337575 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -20,6 +20,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) : build(build) @@ -517,6 +519,36 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } + case IrCmd::MIN_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + + if (inst.a.kind == IrOpKind::Constant) + { + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + build.vmovsd(tmp.reg, memRegDoubleOp(inst.a)); + build.vminsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b)); + } + else + { + build.vminsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b)); + } + break; + case IrCmd::MAX_NUM: + inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); + + if (inst.a.kind == IrOpKind::Constant) + { + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + build.vmovsd(tmp.reg, memRegDoubleOp(inst.a)); + build.vmaxsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b)); + } + else + { + build.vmaxsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b)); + } + break; case IrCmd::UNM_NUM: { inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); @@ -624,7 +656,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp{regs, SizeX64::xmmword}; // TODO: jumpOnNumberCmp should work on IrCondition directly - jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), getX64Condition(cond), labelOp(inst.d)); + jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), cond, labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; } @@ -636,7 +668,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) IrCondition cond = IrCondition(inst.c.index); - jumpOnAnyCmpFallback(build, inst.a.index, inst.b.index, getX64Condition(cond), labelOp(inst.d)); + jumpOnAnyCmpFallback(build, inst.a.index, inst.b.index, cond, labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; } @@ -716,6 +748,89 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg); break; } + + case IrCmd::FASTCALL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + emitBuiltin(regs, build, uintOp(inst.a), inst.b.index, inst.c.index, inst.d, intOp(inst.e), intOp(inst.f)); + break; + case IrCmd::INVOKE_FASTCALL: + { + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + + unsigned bfid = uintOp(inst.a); + + OperandX64 args = 0; + + if (inst.d.kind == IrOpKind::VmReg) + args = luauRegAddress(inst.d.index); + else if (inst.d.kind == IrOpKind::VmConst) + args = luauConstantAddress(inst.d.index); + else + LUAU_ASSERT(boolOp(inst.d) == false); + + int ra = inst.b.index; + int arg = inst.c.index; + int nparams = intOp(inst.e); + int nresults = intOp(inst.f); + + regs.assertAllFree(); + + build.mov(rax, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); + + // 5th parameter (args) is left unset for LOP_FASTCALL1 + if (args.cat == CategoryX64::mem) + { + if (build.abi == ABIX64::Windows) + { + build.lea(rcx, args); + build.mov(sArg5, rcx); + } + else + { + build.lea(rArg5, args); + } + } + + if (nparams == LUA_MULTRET) + { + // L->top - (ra + 1) + RegisterX64 reg = (build.abi == ABIX64::Windows) ? rcx : rArg6; + build.mov(reg, qword[rState + offsetof(lua_State, top)]); + build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]); + build.sub(reg, rdx); + build.shr(reg, kTValueSizeLog2); + + if (build.abi == ABIX64::Windows) + build.mov(sArg6, reg); + } + else + { + if (build.abi == ABIX64::Windows) + build.mov(sArg6, nparams); + else + build.mov(rArg6, nparams); + } + + build.mov(rArg1, rState); + build.lea(rArg2, luauRegAddress(ra)); + build.lea(rArg3, luauRegAddress(arg)); + build.mov(dwordReg(rArg4), nresults); + + build.call(rax); + + inst.regX64 = regs.takeGprReg(eax); // Result of a builtin call is returned in eax + break; + } + case IrCmd::CHECK_FASTCALL_RES: + { + RegisterX64 res = regOp(inst.a); + + build.test(res, res); // test here will set SF=1 for a negative number and it always sets OF to 0 + build.jcc(ConditionX64::Less, labelOp(inst.b)); // jl jumps if SF != OF + break; + } case IrCmd::DO_ARITH: LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -1014,41 +1129,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitInstReturn(build, helpers, pc, uintOp(inst.a)); break; } - case IrCmd::LOP_FASTCALL: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - - emitInstFastCall(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.d)); - break; - case IrCmd::LOP_FASTCALL1: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - - emitInstFastCall1(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.d)); - break; - case IrCmd::LOP_FASTCALL2: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg); - - emitInstFastCall2(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.e)); - break; - case IrCmd::LOP_FASTCALL2K: - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - - emitInstFastCall2K(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.e)); - break; case IrCmd::LOP_FORGLOOP: - emitinstForGLoop(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b), labelOp(inst.c), labelOp(inst.d)); + LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); + emitinstForGLoop(build, inst.a.index, intOp(inst.b), labelOp(inst.c), labelOp(inst.d)); break; case IrCmd::LOP_FORGLOOP_FALLBACK: - emitinstForGLoopFallback(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); - build.jmp(labelOp(inst.c)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + emitinstForGLoopFallback(build, uintOp(inst.a), inst.b.index, intOp(inst.c), labelOp(inst.d)); + build.jmp(labelOp(inst.e)); break; case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: - emitInstForGPrepXnextFallback(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + emitInstForGPrepXnextFallback(build, uintOp(inst.a), inst.b.index, labelOp(inst.c)); break; case IrCmd::LOP_AND: emitInstAnd(build, proto->code + uintOp(inst.a)); @@ -1224,38 +1316,6 @@ Label& IrLoweringX64::labelOp(IrOp op) const return blockOp(op).label; } -ConditionX64 IrLoweringX64::getX64Condition(IrCondition cond) const -{ - // TODO: this function will not be required when jumpOnNumberCmp starts accepting an IrCondition - switch (cond) - { - case IrCondition::Equal: - return ConditionX64::Equal; - case IrCondition::NotEqual: - return ConditionX64::NotEqual; - case IrCondition::Less: - return ConditionX64::Less; - case IrCondition::NotLess: - return ConditionX64::NotLess; - case IrCondition::LessEqual: - return ConditionX64::LessEqual; - case IrCondition::NotLessEqual: - return ConditionX64::NotLessEqual; - case IrCondition::Greater: - return ConditionX64::Greater; - case IrCondition::NotGreater: - return ConditionX64::NotGreater; - case IrCondition::GreaterEqual: - return ConditionX64::GreaterEqual; - case IrCondition::NotGreaterEqual: - return ConditionX64::NotGreaterEqual; - default: - LUAU_ASSERT(!"unsupported condition"); - break; - } - - return ConditionX64::Count; -} - +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index e47c3978..a0ad3eab 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -19,6 +19,9 @@ struct ModuleHelpers; struct NativeState; struct AssemblyOptions; +namespace X64 +{ + struct IrLoweringX64 { // Some of these arguments are only required while we re-use old direct bytecode to x64 lowering @@ -46,8 +49,6 @@ struct IrLoweringX64 IrBlock& blockOp(IrOp op) const; Label& labelOp(IrOp op) const; - ConditionX64 getX64Condition(IrCondition cond) const; - AssemblyBuilderX64& build; ModuleHelpers& helpers; NativeState& data; @@ -58,5 +59,6 @@ struct IrLoweringX64 IrRegAllocX64 regs; }; +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index de60159f..91867806 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -19,6 +19,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; @@ -106,6 +108,16 @@ RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_l return allocXmmReg(); } +RegisterX64 IrRegAllocX64::takeGprReg(RegisterX64 reg) +{ + // In a more advanced register allocator, this would require a spill for the current register user + // But at the current stage we don't have register live ranges intersecting forced register uses + LUAU_ASSERT(freeGprMap[reg.index]); + + freeGprMap[reg.index] = false; + return reg; +} + void IrRegAllocX64::freeReg(RegisterX64 reg) { if (reg.size == SizeX64::xmmword) @@ -148,6 +160,15 @@ void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index) checkOp(inst.f); } +void IrRegAllocX64::assertAllFree() const +{ + for (RegisterX64 reg : kGprAllocOrder) + LUAU_ASSERT(freeGprMap[reg.index]); + + for (bool free : freeXmmMap) + LUAU_ASSERT(free); +} + ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, SizeX64 size) : owner(owner) { @@ -157,7 +178,6 @@ ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, SizeX64 size) reg = owner.allocGprReg(size); } - ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg) : owner(owner) , reg(reg) @@ -177,5 +197,6 @@ void ScopedRegX64::free() reg = noreg; } +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrRegAllocX64.h b/CodeGen/src/IrRegAllocX64.h index a532c3b3..ac072a32 100644 --- a/CodeGen/src/IrRegAllocX64.h +++ b/CodeGen/src/IrRegAllocX64.h @@ -11,6 +11,8 @@ namespace Luau { namespace CodeGen { +namespace X64 +{ struct IrRegAllocX64 { @@ -22,10 +24,14 @@ struct IrRegAllocX64 RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list oprefs); RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list oprefs); + RegisterX64 takeGprReg(RegisterX64 reg); + void freeReg(RegisterX64 reg); void freeLastUseReg(IrInst& target, uint32_t index); void freeLastUseRegs(const IrInst& inst, uint32_t index); + void assertAllFree() const; + IrFunction& function; std::array freeGprMap; @@ -47,5 +53,6 @@ struct ScopedRegX64 RegisterX64 reg; }; +} // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index ccd743ed..bc909105 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -6,11 +6,66 @@ #include "lstate.h" +// TODO: should be possible to handle fastcalls in contexts where nresults is -1 by adding the adjustment instruction +// TODO: when nresults is less than our actual result count, we can skip computing/writing unused results + namespace Luau { namespace CodeGen { +// Wrapper code for all builtins with a fixed signature and manual assembly lowering of the body + +// (number, ...) -> number +BuiltinImplResult translateBuiltinNumberToNumber( + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +// (number, number, ...) -> number +BuiltinImplResult translateBuiltin2NumberToNumber( + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO:tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +// (number, ...) -> (number, number) +BuiltinImplResult translateBuiltinNumberTo2Number( + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 2) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: some tag updates might not be required, we place them here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 2}; +} + BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults != 0) @@ -25,12 +80,180 @@ BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 0}; } +BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + + const double rpd = (3.14159265358979323846 / 180.0); + + IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp value = build.inst(IrCmd::DIV_NUM, varg, build.constDouble(rpd)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + + const double rpd = (3.14159265358979323846 / 180.0); + + IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp value = build.inst(IrCmd::MUL_NUM, varg, build.constDouble(rpd)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathLog( + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + + if (nparams != 1) + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + // TODO: this can be extended for other number of arguments + if (nparams != 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + + IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); + + IrOp res = build.inst(IrCmd::MIN_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + // TODO: this can be extended for other number of arguments + if (nparams != 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + + IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); + + IrOp res = build.inst(IrCmd::MAX_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 3 || nresults > 1) + return {BuiltinImplType::None, -1}; + + IrOp block = build.block(IrBlockKind::Internal); + + LUAU_ASSERT(args.kind == IrOpKind::VmReg); + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + build.loadAndCheckTag(build.vmReg(args.index + 1), LUA_TNUMBER, fallback); + + IrOp min = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 1)); + + build.inst(IrCmd::JUMP_CMP_NUM, min, max, build.cond(IrCondition::NotLessEqual), fallback, block); + build.beginBlock(block); + + IrOp v = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp r = build.inst(IrCmd::MAX_NUM, min, v); + IrOp clamped = build.inst(IrCmd::MIN_NUM, max, r); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), clamped); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) { switch (bfid) { case LBF_ASSERT: return translateBuiltinAssert(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_DEG: + return translateBuiltinMathDeg(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_RAD: + return translateBuiltinMathRad(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_LOG: + return translateBuiltinMathLog(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_MIN: + return translateBuiltinMathMin(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_MAX: + return translateBuiltinMathMax(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_CLAMP: + return translateBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_FLOOR: + case LBF_MATH_CEIL: + case LBF_MATH_SQRT: + case LBF_MATH_ABS: + case LBF_MATH_EXP: + case LBF_MATH_ASIN: + case LBF_MATH_SIN: + case LBF_MATH_SINH: + case LBF_MATH_ACOS: + case LBF_MATH_COS: + case LBF_MATH_COSH: + case LBF_MATH_ATAN: + case LBF_MATH_TAN: + case LBF_MATH_TANH: + case LBF_MATH_LOG10: + case LBF_MATH_ROUND: + case LBF_MATH_SIGN: + return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_FMOD: + case LBF_MATH_POW: + case LBF_MATH_ATAN2: + case LBF_MATH_LDEXP: + return translateBuiltin2NumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_FREXP: + case LBF_MATH_MODF: + return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 68c6c402..48ca3975 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -479,8 +479,7 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra)); } -void translateFastCallN( - IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next, IrCmd fallbackCmd) +void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next) { int bfid = LUAU_INSN_A(*pc); int skip = LUAU_INSN_C(*pc); @@ -509,23 +508,17 @@ void translateFastCallN( } else { - switch (fallbackCmd) - { - case IrCmd::LOP_FASTCALL: - build.inst(IrCmd::LOP_FASTCALL, build.constUint(pcpos), build.vmReg(ra), build.constInt(nparams), fallback); - break; - case IrCmd::LOP_FASTCALL1: - build.inst(IrCmd::LOP_FASTCALL1, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), fallback); - break; - case IrCmd::LOP_FASTCALL2: - build.inst(IrCmd::LOP_FASTCALL2, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), build.vmReg(pc[1]), fallback); - break; - case IrCmd::LOP_FASTCALL2K: - build.inst(IrCmd::LOP_FASTCALL2K, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), build.vmConst(pc[1]), fallback); - break; - default: - LUAU_ASSERT(!"unexpected command"); - } + // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + + IrOp res = build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), + build.constInt(nresults)); + build.inst(IrCmd::CHECK_FASTCALL_RES, res, fallback); + + if (nresults == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(ra), res); + else if (nparams == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_TOP); } build.inst(IrCmd::JUMP, next); @@ -645,7 +638,7 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo build.inst(IrCmd::JUMP, target); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target); + build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); } void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcpos) @@ -677,7 +670,7 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp build.inst(IrCmd::JUMP, target); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target); + build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), build.vmReg(ra), target); } void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pcpos) @@ -728,7 +721,7 @@ void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pc build.inst(IrCmd::JUMP, loopRepeat); build.beginBlock(fallback); - build.inst(IrCmd::LOP_FORGLOOP_FALLBACK, build.constUint(pcpos), loopRepeat, loopExit); + build.inst(IrCmd::LOP_FORGLOOP_FALLBACK, build.constUint(pcpos), build.vmReg(ra), build.constInt(int(pc[1])), loopRepeat, loopExit); // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(loopExit)) diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 5b3f78f2..0d4a5096 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -43,8 +43,7 @@ void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc); -void translateFastCallN( - IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next, IrCmd fallbackCmd); +void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next); void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 6ccbc8ce..0808ad07 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -286,6 +286,24 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(pow(function.doubleOp(inst.a), function.doubleOp(inst.b)))); break; + case IrCmd::MIN_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + double a1 = function.doubleOp(inst.a); + double a2 = function.doubleOp(inst.b); + + substitute(function, inst, build.constDouble((a2 < a1) ? a2 : a1)); + } + break; + case IrCmd::MAX_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + double a1 = function.doubleOp(inst.a); + double a2 = function.doubleOp(inst.b); + + substitute(function, inst, build.constDouble((a2 > a1) ? a2 : a1)); + } + break; case IrCmd::UNM_NUM: if (inst.a.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(-function.doubleOp(inst.a))); diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index c9c7f6c4..956c96d6 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -2,11 +2,16 @@ #include "Luau/OptimizeConstProp.h" #include "Luau/DenseHash.h" +#include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" #include "Luau/IrUtils.h" #include "lua.h" +#include + +LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) + namespace Luau { namespace CodeGen @@ -181,6 +186,82 @@ struct ConstPropState DenseHashMap instLink{~0u}; }; +static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) +{ + // Switch over all values is used to force new items to be handled + switch (bfid) + { + case LBF_NONE: + case LBF_ASSERT: + case LBF_MATH_ABS: + case LBF_MATH_ACOS: + case LBF_MATH_ASIN: + case LBF_MATH_ATAN2: + case LBF_MATH_ATAN: + case LBF_MATH_CEIL: + case LBF_MATH_COSH: + case LBF_MATH_COS: + case LBF_MATH_DEG: + case LBF_MATH_EXP: + case LBF_MATH_FLOOR: + case LBF_MATH_FMOD: + case LBF_MATH_FREXP: + case LBF_MATH_LDEXP: + case LBF_MATH_LOG10: + case LBF_MATH_LOG: + case LBF_MATH_MAX: + case LBF_MATH_MIN: + case LBF_MATH_MODF: + case LBF_MATH_POW: + case LBF_MATH_RAD: + case LBF_MATH_SINH: + case LBF_MATH_SIN: + case LBF_MATH_SQRT: + case LBF_MATH_TANH: + case LBF_MATH_TAN: + case LBF_BIT32_ARSHIFT: + case LBF_BIT32_BAND: + case LBF_BIT32_BNOT: + case LBF_BIT32_BOR: + case LBF_BIT32_BXOR: + case LBF_BIT32_BTEST: + case LBF_BIT32_EXTRACT: + case LBF_BIT32_LROTATE: + case LBF_BIT32_LSHIFT: + case LBF_BIT32_REPLACE: + case LBF_BIT32_RROTATE: + case LBF_BIT32_RSHIFT: + case LBF_TYPE: + case LBF_STRING_BYTE: + case LBF_STRING_CHAR: + case LBF_STRING_LEN: + case LBF_TYPEOF: + case LBF_STRING_SUB: + case LBF_MATH_CLAMP: + case LBF_MATH_SIGN: + case LBF_MATH_ROUND: + case LBF_RAWSET: + case LBF_RAWGET: + case LBF_RAWEQUAL: + case LBF_TABLE_INSERT: + case LBF_TABLE_UNPACK: + case LBF_VECTOR: + case LBF_BIT32_COUNTLZ: + case LBF_BIT32_COUNTRZ: + case LBF_SELECT_VARARG: + case LBF_RAWLEN: + case LBF_BIT32_EXTRACTK: + case LBF_GETMETATABLE: + break; + case LBF_SETMETATABLE: + state.invalidateHeap(); // TODO: only knownNoMetatable is affected and we might know which one + break; + } + + // TODO: classify further using switch above, some fastcalls only modify the value, not the tag + state.invalidateRegistersFrom(firstReturnReg); +} + static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index) { switch (inst.cmd) @@ -406,20 +487,16 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } break; - case IrCmd::LOP_FASTCALL: - case IrCmd::LOP_FASTCALL1: - case IrCmd::LOP_FASTCALL2: - case IrCmd::LOP_FASTCALL2K: - // TODO: classify fast call behaviors to avoid heap invalidation - state.invalidateHeap(); // Even a builtin method can change table properties - state.invalidateRegistersFrom(inst.b.index); - break; case IrCmd::LOP_AND: case IrCmd::LOP_ANDK: case IrCmd::LOP_OR: case IrCmd::LOP_ORK: state.invalidate(inst.b); break; + case IrCmd::FASTCALL: + case IrCmd::INVOKE_FASTCALL: + handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), inst.b.index, function.intOp(inst.f)); + break; // These instructions don't have an effect on register/memory state we are tracking case IrCmd::NOP: @@ -436,6 +513,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: case IrCmd::NOT_ANY: case IrCmd::JUMP: @@ -458,6 +537,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SUBSTITUTE: case IrCmd::ADJUST_STACK_TO_REG: // Changes stack top, but not the values case IrCmd::ADJUST_STACK_TO_TOP: // Changes stack top, but not the values + case IrCmd::CHECK_FASTCALL_RES: // Changes stack top, but not the values break; // We don't model the following instructions, so we just clear all the knowledge we have built up @@ -534,8 +614,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite if (termInst.cmd == IrCmd::JUMP) { IrBlock& target = function.blockOp(termInst.a); + uint32_t targetIdx = function.getBlockIndex(target); - if (target.useCount == 1 && !visited[function.getBlockIndex(target)] && target.kind != IrBlockKind::Fallback) + if (target.useCount == 1 && !visited[targetIdx] && target.kind != IrBlockKind::Fallback) nextBlock = ⌖ } @@ -543,12 +624,114 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite } } +// Note that blocks in the collected path are marked as visited +static std::vector collectDirectBlockJumpPath(IrFunction& function, std::vector& visited, IrBlock* block) +{ + // Path shouldn't be built starting with a block that has 'live out' values. + // One theoretical way to get it is if we had 'block' jumping unconditionally into a successor that uses values from 'block' + // * if the successor has only one use, the starting conditions of 'tryCreateLinearBlock' would prevent this + // * if the successor has multiple uses, it can't have such 'live in' values without phi nodes that we don't have yet + // Another possibility is to have two paths from 'block' into the target through two intermediate blocks + // Usually that would mean that we would have a conditional jump at the end of 'block' + // But using check guards and fallback clocks it becomes a possible setup + // We avoid this by making sure fallbacks rejoin the other immediate successor of 'block' + LUAU_ASSERT(getLiveOutValueCount(function, *block) == 0); + + std::vector path; + + while (block) + { + IrInst& termInst = function.instructions[block->finish]; + IrBlock* nextBlock = nullptr; + + // A chain is made from internal blocks that were not a part of bytecode CFG + if (termInst.cmd == IrCmd::JUMP) + { + IrBlock& target = function.blockOp(termInst.a); + uint32_t targetIdx = function.getBlockIndex(target); + + if (!visited[targetIdx] && target.kind == IrBlockKind::Internal) + { + // Additional restriction is that to join a block, it cannot produce values that are used in other blocks + // And it also can't use values produced in other blocks + auto [liveIns, liveOuts] = getLiveInOutValueCount(function, target); + + if (liveIns == 0 && liveOuts == 0) + { + visited[targetIdx] = true; + path.push_back(targetIdx); + + nextBlock = ⌖ + } + } + } + + block = nextBlock; + } + + return path; +} + +static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock) +{ + IrFunction& function = build.function; + + uint32_t blockIdx = function.getBlockIndex(startingBlock); + LUAU_ASSERT(!visited[blockIdx]); + visited[blockIdx] = true; + + IrInst& termInst = function.instructions[startingBlock.finish]; + + // Block has to end with an unconditional jump + if (termInst.cmd != IrCmd::JUMP) + return; + + // And it has to jump to a block with more than one user + // If there's only one use, it should already be optimized by constPropInBlockChain + if (function.blockOp(termInst.a).useCount == 1) + return; + + uint32_t targetBlockIdx = termInst.a.index; + + // Check if this path is worth it (and it will also mark path blocks as visited) + std::vector path = collectDirectBlockJumpPath(function, visited, &startingBlock); + + // If path is too small, we are not going to linearize it + if (int(path.size()) < FInt::LuauCodeGenMinLinearBlockPath) + return; + + // Initialize state with the knowledge of our current block + ConstPropState state; + constPropInBlock(build, startingBlock, state); + + // Veryfy that target hasn't changed + LUAU_ASSERT(function.instructions[startingBlock.finish].a.index == targetBlockIdx); + + // Create new linearized block into which we are going to redirect starting block jump + IrOp newBlock = build.block(IrBlockKind::Linearized); + visited.push_back(false); + + // TODO: placement of linear blocks in final lowering is sub-optimal, it should follow our predecessor + build.beginBlock(newBlock); + + replace(function, termInst.a, newBlock); + + // Clone the collected path int our fresh block + for (uint32_t pathBlockIdx : path) + build.clone(function.blocks[pathBlockIdx], /* removeCurrentTerminator */ true); + + // Optimize our linear block + IrBlock& linearBlock = function.blockOp(newBlock); + constPropInBlock(build, linearBlock, state); +} + void constPropInBlockChains(IrBuilder& build) { IrFunction& function = build.function; std::vector visited(function.blocks.size(), false); + // First pass: go over existing blocks once and propagate constants for (IrBlock& block : function.blocks) { if (block.kind == IrBlockKind::Fallback || block.kind == IrBlockKind::Dead) @@ -559,6 +742,26 @@ void constPropInBlockChains(IrBuilder& build) constPropInBlockChain(build, visited, &block); } + + // Second pass: go through internal block chains and outline them into a single new block + // Outlining will be able to linearize the execution, even if there was a jump to a block with multiple users, + // new 'block' will only be reachable from a single one and all gathered information can be preserved. + std::fill(visited.begin(), visited.end(), false); + + // This next loop can create new 'linear' blocks, so index-based loop has to be used (and it intentionally won't reach those new blocks) + size_t originalBlockCount = function.blocks.size(); + for (size_t i = 0; i < originalBlockCount; i++) + { + IrBlock& block = function.blocks[i]; + + if (block.kind == IrBlockKind::Fallback || block.kind == IrBlockKind::Dead) + continue; + + if (visited[function.getBlockIndex(block)]) + continue; + + tryCreateLinearBlock(build, visited, block); + } } } // namespace CodeGen diff --git a/CodeGen/src/OptimizeFinalX64.cpp b/CodeGen/src/OptimizeFinalX64.cpp index 2b7c9652..dd31fcc4 100644 --- a/CodeGen/src/OptimizeFinalX64.cpp +++ b/CodeGen/src/OptimizeFinalX64.cpp @@ -41,6 +41,8 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block) case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: { if (inst.b.kind == IrOpKind::Inst) { diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index 7dc86d3e..a95ed094 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -168,12 +168,12 @@ void UnwindBuilderDwarf2::start() // Function call frame instructions to follow } -void UnwindBuilderDwarf2::spill(int espOffset, RegisterX64 reg) +void UnwindBuilderDwarf2::spill(int espOffset, X64::RegisterX64 reg) { pos = advanceLocation(pos, 5); // REX.W mov [rsp + imm8], reg } -void UnwindBuilderDwarf2::save(RegisterX64 reg) +void UnwindBuilderDwarf2::save(X64::RegisterX64 reg) { stackOffset += 8; pos = advanceLocation(pos, 2); // REX.W push reg @@ -188,7 +188,7 @@ void UnwindBuilderDwarf2::allocStack(int size) pos = defineCfaExpressionOffset(pos, stackOffset); } -void UnwindBuilderDwarf2::setupFrameReg(RegisterX64 reg, int espOffset) +void UnwindBuilderDwarf2::setupFrameReg(X64::RegisterX64 reg, int espOffset) { if (espOffset != 0) pos = advanceLocation(pos, 5); // REX.W lea rbp, [rsp + imm8] diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 13e92ab0..21733001 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -49,12 +49,12 @@ void UnwindBuilderWin::start() unwindCodes.reserve(16); } -void UnwindBuilderWin::spill(int espOffset, RegisterX64 reg) +void UnwindBuilderWin::spill(int espOffset, X64::RegisterX64 reg) { prologSize += 5; // REX.W mov [rsp + imm8], reg } -void UnwindBuilderWin::save(RegisterX64 reg) +void UnwindBuilderWin::save(X64::RegisterX64 reg) { prologSize += 2; // REX.W push reg stackOffset += 8; @@ -70,7 +70,7 @@ void UnwindBuilderWin::allocStack(int size) unwindCodes.push_back({prologSize, UWOP_ALLOC_SMALL, uint8_t((size - 8) / 8)}); } -void UnwindBuilderWin::setupFrameReg(RegisterX64 reg, int espOffset) +void UnwindBuilderWin::setupFrameReg(X64::RegisterX64 reg, int espOffset) { LUAU_ASSERT(espOffset < 256 && espOffset % 16 == 0); diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index ae963370..073bb1c7 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -244,5 +244,158 @@ void analyzeBuiltins(DenseHashMap& result, const DenseHashMap root->visit(&visitor); } +BuiltinInfo getBuiltinInfo(int bfid) +{ + switch (LuauBuiltinFunction(bfid)) + { + case LBF_NONE: + return {-1, -1}; + + case LBF_ASSERT: + return {-1, -1}; + ; // assert() returns all values when first value is truthy + + case LBF_MATH_ABS: + case LBF_MATH_ACOS: + case LBF_MATH_ASIN: + return {1, 1}; + + case LBF_MATH_ATAN2: + return {2, 1}; + + case LBF_MATH_ATAN: + case LBF_MATH_CEIL: + case LBF_MATH_COSH: + case LBF_MATH_COS: + case LBF_MATH_DEG: + case LBF_MATH_EXP: + case LBF_MATH_FLOOR: + return {1, 1}; + + case LBF_MATH_FMOD: + return {2, 1}; + + case LBF_MATH_FREXP: + return {1, 2}; + + case LBF_MATH_LDEXP: + return {2, 1}; + + case LBF_MATH_LOG10: + return {1, 1}; + + case LBF_MATH_LOG: + return {-1, 1}; // 1 or 2 parameters + + case LBF_MATH_MAX: + case LBF_MATH_MIN: + return {-1, 1}; // variadic + + case LBF_MATH_MODF: + return {1, 2}; + + case LBF_MATH_POW: + return {2, 1}; + + case LBF_MATH_RAD: + case LBF_MATH_SINH: + case LBF_MATH_SIN: + case LBF_MATH_SQRT: + case LBF_MATH_TANH: + case LBF_MATH_TAN: + return {1, 1}; + + case LBF_BIT32_ARSHIFT: + return {2, 1}; + + case LBF_BIT32_BAND: + return {-1, 1}; // variadic + + case LBF_BIT32_BNOT: + return {1, 1}; + + case LBF_BIT32_BOR: + case LBF_BIT32_BXOR: + case LBF_BIT32_BTEST: + return {-1, 1}; // variadic + + case LBF_BIT32_EXTRACT: + return {-1, 1}; // 2 or 3 parameters + + case LBF_BIT32_LROTATE: + case LBF_BIT32_LSHIFT: + return {2, 1}; + + case LBF_BIT32_REPLACE: + return {-1, 1}; // 3 or 4 parameters + + case LBF_BIT32_RROTATE: + case LBF_BIT32_RSHIFT: + return {2, 1}; + + case LBF_TYPE: + return {1, 1}; + + case LBF_STRING_BYTE: + return {-1, -1}; // 1, 2 or 3 parameters + + case LBF_STRING_CHAR: + return {-1, 1}; // variadic + + case LBF_STRING_LEN: + return {1, 1}; + + case LBF_TYPEOF: + return {1, 1}; + + case LBF_STRING_SUB: + return {-1, 1}; // 2 or 3 parameters + + case LBF_MATH_CLAMP: + return {3, 1}; + + case LBF_MATH_SIGN: + case LBF_MATH_ROUND: + return {1, 1}; + + case LBF_RAWSET: + return {3, 1}; + + case LBF_RAWGET: + case LBF_RAWEQUAL: + return {2, 1}; + + case LBF_TABLE_INSERT: + return {-1, 0}; // 2 or 3 parameters + + case LBF_TABLE_UNPACK: + return {-1, -1}; // 1, 2 or 3 parameters + + case LBF_VECTOR: + return {-1, 1}; // 3 or 4 parameters in some configurations + + case LBF_BIT32_COUNTLZ: + case LBF_BIT32_COUNTRZ: + return {1, 1}; + + case LBF_SELECT_VARARG: + return {-1, -1}; // variadic + + case LBF_RAWLEN: + return {1, 1}; + + case LBF_BIT32_EXTRACTK: + return {3, 1}; + + case LBF_GETMETATABLE: + return {1, 1}; + + case LBF_SETMETATABLE: + return {2, 1}; + }; + + LUAU_UNREACHABLE(); +} + } // namespace Compile } // namespace Luau diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h index 4399c532..e179218a 100644 --- a/Compiler/src/Builtins.h +++ b/Compiler/src/Builtins.h @@ -39,5 +39,13 @@ Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, void analyzeBuiltins(DenseHashMap& result, const DenseHashMap& globals, const DenseHashMap& variables, const CompileOptions& options, AstNode* root); +struct BuiltinInfo +{ + int params; + int results; +}; + +BuiltinInfo getBuiltinInfo(int bfid); + } // namespace Compile } // namespace Luau diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 11bf2429..8e450f4f 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -2038,7 +2038,10 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, case LOP_CAPTURE: formatAppend(result, "CAPTURE %s %c%d\n", - LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL" : LUAU_INSN_A(insn) == LCT_REF ? "REF" : LUAU_INSN_A(insn) == LCT_VAL ? "VAL" : "", + LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL" + : LUAU_INSN_A(insn) == LCT_REF ? "REF" + : LUAU_INSN_A(insn) == LCT_VAL ? "VAL" + : "", LUAU_INSN_A(insn) == LCT_UPVAL ? 'U' : 'R', LUAU_INSN_B(insn)); break; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8a017f48..78896d31 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,6 +26,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAGVARIABLE(LuauCompileTerminateBC, false) +LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinArity, false) namespace Luau { @@ -293,6 +294,12 @@ struct Compiler if (isConstant(expr)) return false; + // handles builtin calls that can't be constant-folded but are known to return one value + // note: optimizationLevel check is technically redundant but it's important that we never optimize based on builtins in O1 + if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2) + if (int* bfid = builtins.find(expr)) + return getBuiltinInfo(*bfid).results != 1; + // handles local function calls where we know only one argument is returned AstExprFunction* func = getFunctionExpr(expr->func); Function* fi = func ? functions.find(func) : nullptr; @@ -506,6 +513,7 @@ struct Compiler // we can't inline multret functions because the caller expects L->top to be adjusted: // - inlined return compiles to a JUMP, and we don't have an instruction that adjusts L->top arbitrarily // - even if we did, right now all L->top adjustments are immediately consumed by the next instruction, and for now we want to preserve that + // - additionally, we can't easily compile multret expressions into designated target as computed call arguments will get clobbered if (multRet) { bytecode.addDebugRemark("inlining failed: can't convert fixed returns to multret"); @@ -755,8 +763,13 @@ struct Compiler } // Optimization: for 1/2 argument fast calls use specialized opcodes - if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2 && !isExprMultRet(expr->args.data[expr->args.size - 1])) - return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) + { + if (!isExprMultRet(expr->args.data[expr->args.size - 1])) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + else if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2 && int(expr->args.size) == getBuiltinInfo(bfid).params) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + } if (expr->self) { diff --git a/Sources.cmake b/Sources.cmake index 22197e0e..88c6e9b6 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -108,6 +108,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/CodeGenUtils.h CodeGen/src/CodeGenX64.h CodeGen/src/EmitBuiltinsX64.h + CodeGen/src/EmitCommon.h CodeGen/src/EmitCommonX64.h CodeGen/src/EmitInstructionX64.h CodeGen/src/Fallbacks.h @@ -126,6 +127,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h + Analysis/include/Luau/Breadcrumb.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Clone.h Analysis/include/Luau/Config.h diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index cf7381ae..875a479a 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -1445,7 +1445,7 @@ static int str_pack(lua_State* L) const char* s = luaL_checklstring(L, arg, &len); luaL_argcheck(L, len <= (size_t)size, arg, "string longer than given size"); luaL_addlstring(&b, s, len, -1); // add string - while (len++ < (size_t)size) // pad extra space + while (len++ < (size_t)size) // pad extra space luaL_addchar(&b, LUAL_PACKPADBYTE); break; } diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index d808ac49..e23b965b 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -7,6 +7,7 @@ #include using namespace Luau::CodeGen; +using namespace Luau::CodeGen::A64; static std::string bytecodeAsArray(const std::vector& bytecode) { diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index c4d2a1c7..6aa7aa56 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -7,6 +7,7 @@ #include using namespace Luau::CodeGen; +using namespace Luau::CodeGen::X64; static std::string bytecodeAsArray(const std::vector& bytecode) { diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index 44e9e5e4..a0127eef 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -445,7 +445,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,35","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; CHECK(toJson(statement) == expected); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index d238e9ec..85bd5507 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3198,6 +3198,20 @@ t.@1 } } +TEST_CASE_FIXTURE(ACFixture, "simple") +{ + check(R"( +local t = {} +function t:m() end +t:m() + )"); + + // auto ac = autocomplete('1'); + + // REQUIRE(ac.entryMap.count("m")); + // CHECK(!ac.entryMap["m"].wrongIndexType); +} + TEST_CASE_FIXTURE(ACFixture, "do_compatible_self_calls") { check(R"( @@ -3466,4 +3480,33 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") CHECK(isCorrect); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_response_perf1" * doctest::timeout(0.5)) +{ + ScopedFastFlag luauAutocompleteSkipNormalization{"LuauAutocompleteSkipNormalization", true}; + + // Build a function type with a large overload set + const int parts = 100; + std::string source; + + for (int i = 0; i < parts; i++) + formatAppend(source, "type T%d = { f%d: number }\n", i, i); + + source += "type Instance = { new: (('s0', extra: Instance?) -> T0)"; + + for (int i = 1; i < parts; i++) + formatAppend(source, " & (('s%d', extra: Instance?) -> T%d)", i, i); + + source += " }\n"; + + source += "local Instance: Instance = {} :: any\n"; + source += "local function c(): boolean return t@1 end\n"; + + check(source); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("Instance")); +} + TEST_SUITE_END(); diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 65b485a7..a6ed96f0 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -131,6 +131,8 @@ TEST_CASE("CodeAllocationWithUnwindCallbacks") #if !defined(LUAU_BIG_ENDIAN) TEST_CASE("WindowsUnwindCodesX64") { + using namespace X64; + UnwindBuilderWin unwind; unwind.start(); @@ -162,6 +164,8 @@ TEST_CASE("WindowsUnwindCodesX64") TEST_CASE("Dwarf2UnwindCodesX64") { + using namespace X64; + UnwindBuilderDwarf2 unwind; unwind.start(); @@ -195,21 +199,23 @@ TEST_CASE("Dwarf2UnwindCodesX64") #if defined(_WIN32) // Windows x64 ABI -constexpr RegisterX64 rArg1 = rcx; -constexpr RegisterX64 rArg2 = rdx; -constexpr RegisterX64 rArg3 = r8; +constexpr X64::RegisterX64 rArg1 = X64::rcx; +constexpr X64::RegisterX64 rArg2 = X64::rdx; +constexpr X64::RegisterX64 rArg3 = X64::r8; #else // System V AMD64 ABI -constexpr RegisterX64 rArg1 = rdi; -constexpr RegisterX64 rArg2 = rsi; -constexpr RegisterX64 rArg3 = rdx; +constexpr X64::RegisterX64 rArg1 = X64::rdi; +constexpr X64::RegisterX64 rArg2 = X64::rsi; +constexpr X64::RegisterX64 rArg3 = X64::rdx; #endif -constexpr RegisterX64 rNonVol1 = r12; -constexpr RegisterX64 rNonVol2 = rbx; +constexpr X64::RegisterX64 rNonVol1 = X64::r12; +constexpr X64::RegisterX64 rNonVol2 = X64::rbx; TEST_CASE("GeneratedCodeExecutionX64") { + using namespace X64; + AssemblyBuilderX64 build(/* logText= */ false); build.mov(rax, rArg1); @@ -244,6 +250,8 @@ void throwing(int64_t arg) TEST_CASE("GeneratedCodeExecutionWithThrowX64") { + using namespace X64; + AssemblyBuilderX64 build(/* logText= */ false); #if defined(_WIN32) @@ -320,6 +328,8 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") { + using namespace X64; + AssemblyBuilderX64 build(/* logText= */ false); #if defined(_WIN32) @@ -437,6 +447,8 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") TEST_CASE("GeneratedCodeExecutionA64") { + using namespace A64; + AssemblyBuilderA64 build(/* logText= */ false); Label skip; diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 87a782ad..135a555a 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4655,6 +4655,8 @@ RETURN R0 0 TEST_CASE("LoopUnrollCost") { + ScopedFastFlag sff("LuauCompileBuiltinArity", true); + ScopedFastInt sfis[] = { {"LuauCompileLoopUnrollThreshold", 25}, {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, @@ -4796,10 +4798,10 @@ FORNPREP R1 L3 L0: FASTCALL1 24 R3 L1 MOVE R6 R3 GETIMPORT R5 2 [math.sin] -CALL R5 1 -1 -L1: FASTCALL 2 L2 +CALL R5 1 1 +L1: FASTCALL1 2 R5 L2 GETIMPORT R4 4 [math.abs] -CALL R4 -1 1 +CALL R4 1 1 L2: SETTABLE R4 R0 R3 FORNLOOP R1 L0 L3: RETURN R0 1 @@ -5924,6 +5926,8 @@ RETURN R2 1 TEST_CASE("InlineMultret") { + ScopedFastFlag sff("LuauCompileBuiltinArity", true); + // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -5994,7 +5998,7 @@ CALL R1 1 -1 RETURN R1 -1 )"); - // and unfortunately we can't do this analysis for builtins or method calls due to getfenv + // we do this for builtins though as we assume getfenv is not used or is not changing arity CHECK_EQ("\n" + compileFunction(R"( local function foo(a) return math.abs(a) @@ -6005,10 +6009,8 @@ return foo(42) 1, 2), R"( DUPCLOSURE R0 K0 ['foo'] -MOVE R1 R0 -LOADN R2 42 -CALL R1 1 -1 -RETURN R1 -1 +LOADN R1 42 +RETURN R1 1 )"); } @@ -6263,6 +6265,8 @@ RETURN R0 52 TEST_CASE("BuiltinFoldingProhibited") { + ScopedFastFlag sff("LuauCompileBuiltinArity", true); + CHECK_EQ("\n" + compileFunction(R"( return math.abs(), @@ -6326,8 +6330,8 @@ L8: LOADN R10 1 FASTCALL2K 19 R10 K3 L9 [true] LOADK R11 K3 [true] GETIMPORT R9 26 [math.min] -CALL R9 2 -1 -L9: RETURN R0 -1 +CALL R9 2 1 +L9: RETURN R0 10 )"); } @@ -6865,4 +6869,111 @@ L3: RETURN R0 0 )"); } +TEST_CASE("BuiltinArity") +{ + ScopedFastFlag sff("LuauCompileBuiltinArity", true); + + // by default we can't assume that we know parameter/result count for builtins as they can be overridden at runtime + CHECK_EQ("\n" + compileFunction(R"( +return math.abs(unknown()) +)", + 0, 1), + R"( +GETIMPORT R1 1 [unknown] +CALL R1 0 -1 +FASTCALL 2 L0 +GETIMPORT R0 4 [math.abs] +CALL R0 -1 -1 +L0: RETURN R0 -1 +)"); + + // however, when using optimization level 2, we assume compile time knowledge about builtin behavior even if we can't deoptimize that with fenv + // in the test case below, this allows us to synthesize a more efficient FASTCALL1 (and use a fixed-return call to unknown) + CHECK_EQ("\n" + compileFunction(R"( +return math.abs(unknown()) +)", + 0, 2), + R"( +GETIMPORT R1 1 [unknown] +CALL R1 0 1 +FASTCALL1 2 R1 L0 +GETIMPORT R0 4 [math.abs] +CALL R0 1 1 +L0: RETURN R0 1 +)"); + + // some builtins are variadic, and as such they can't use fixed-length fastcall variants + CHECK_EQ("\n" + compileFunction(R"( +return math.max(0, unknown()) +)", + 0, 2), + R"( +LOADN R1 0 +GETIMPORT R2 1 [unknown] +CALL R2 0 -1 +FASTCALL 18 L0 +GETIMPORT R0 4 [math.max] +CALL R0 -1 1 +L0: RETURN R0 1 +)"); + + // some builtins are not variadic but don't have a fixed number of arguments; we currently don't optimize this although we might start to in the + // future + CHECK_EQ("\n" + compileFunction(R"( +return bit32.extract(0, 1, unknown()) +)", + 0, 2), + R"( +LOADN R1 0 +LOADN R2 1 +GETIMPORT R3 1 [unknown] +CALL R3 0 -1 +FASTCALL 34 L0 +GETIMPORT R0 4 [bit32.extract] +CALL R0 -1 1 +L0: RETURN R0 1 +)"); + + // importantly, this optimization also helps us get around the multret inlining restriction for builtin wrappers + CHECK_EQ("\n" + compileFunction(R"( +local function new() + return setmetatable({}, MT) +end + +return new() +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 ['new'] +NEWTABLE R2 0 0 +GETIMPORT R3 2 [MT] +FASTCALL2 61 R2 R3 L0 +GETIMPORT R1 4 [setmetatable] +CALL R1 2 1 +L0: RETURN R1 1 +)"); + + // note that the results of this optimization are benign in fixed-arg contexts which dampens the effect of fenv substitutions on correctness in + // practice + CHECK_EQ("\n" + compileFunction(R"( +local x = ... +local y, z = type(x) +return type(y, z) +)", + 0, 2), + R"( +GETVARARGS R0 1 +FASTCALL1 40 R0 L0 +MOVE R2 R0 +GETIMPORT R1 1 [type] +CALL R1 1 2 +L0: FASTCALL2 40 R1 R2 L1 +MOVE R4 R1 +MOVE R5 R2 +GETIMPORT R3 1 [type] +CALL R3 2 1 +L1: RETURN R3 1 +)"); +} + TEST_SUITE_END(); diff --git a/tests/DataFlowGraph.test.cpp b/tests/DataFlowGraph.test.cpp index d8230700..bd5fe562 100644 --- a/tests/DataFlowGraph.test.cpp +++ b/tests/DataFlowGraph.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; -class DataFlowGraphFixture +struct DataFlowGraphFixture { // Only needed to fix the operator== reflexivity of an empty Symbol. ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; @@ -23,7 +23,6 @@ class DataFlowGraphFixture std::optional graph; -public: void dfg(const std::string& code) { ParseResult parseResult = Parser::parse(code.c_str(), code.size(), names, allocator); @@ -34,19 +33,19 @@ public: } template - std::optional getDef(const std::vector& nths = {nth(N)}) + NullableBreadcrumbId getBreadcrumb(const std::vector& nths = {nth(N)}) { T* node = query(module, nths); REQUIRE(node); - return graph->getDef(node); + return graph->getBreadcrumb(node); } template - DefId requireDef(const std::vector& nths = {nth(N)}) + BreadcrumbId requireBreadcrumb(const std::vector& nths = {nth(N)}) { - auto loc = getDef(nths); - REQUIRE(loc); - return NotNull{*loc}; + auto bc = getBreadcrumb(nths); + REQUIRE(bc); + return NotNull{bc}; } }; @@ -59,7 +58,7 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_locals_in_local_stat") local y = x )"); - REQUIRE(getDef()); + REQUIRE(getBreadcrumb()); } TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions") @@ -70,7 +69,7 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions") end )"); - REQUIRE(getDef()); + REQUIRE(getBreadcrumb()); } TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases") @@ -81,9 +80,9 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases") local z = y )"); - DefId x = requireDef(); - DefId y = requireDef(); - REQUIRE(x != y); // TODO: they should be equal but it's not just locals that can alias, so we'll support this later. + BreadcrumbId x = requireBreadcrumb(); + BreadcrumbId y = requireBreadcrumb(); + REQUIRE(x != y); } TEST_CASE_FIXTURE(DataFlowGraphFixture, "independent_locals") @@ -96,8 +95,8 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "independent_locals") local b = y )"); - DefId x = requireDef(); - DefId y = requireDef(); + BreadcrumbId x = requireBreadcrumb(); + BreadcrumbId y = requireBreadcrumb(); REQUIRE(x != y); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index f245ca93..cbceabbd 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -178,17 +178,8 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars if (FFlag::DebugLuauDeferredConstraintResolution) { - Luau::check( - *sourceModule, - {}, - frontend.builtinTypes, - NotNull{&ice}, - NotNull{&moduleResolver}, - NotNull{&fileResolver}, - typeChecker.globalScope, - NotNull{&typeChecker.unifierState}, - frontend.options - ); + Luau::check(*sourceModule, {}, frontend.builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, + typeChecker.globalScope, frontend.options); } else typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 9da34367..0896517f 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -311,6 +311,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::DIV_NUM, build.constDouble(2), build.constDouble(5))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MOD_NUM, build.constDouble(5), build.constDouble(2))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::POW_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, build.constDouble(5), build.constDouble(2))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::UNM_NUM, build.constDouble(5))); @@ -338,6 +340,8 @@ bb_0: STORE_DOUBLE R0, 0.40000000000000002 STORE_DOUBLE R0, 1 STORE_DOUBLE R0, 25 + STORE_DOUBLE R0, 2 + STORE_DOUBLE R0, 5 STORE_DOUBLE R0, -5 STORE_INT R0, 1i STORE_INT R0, 0i @@ -809,7 +813,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); - build.inst(IrCmd::LOP_FASTCALL1, build.constUint(0), build.vmReg(1), build.vmReg(2), fallback); + build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(LBF_SETMETATABLE), build.vmReg(1), build.vmReg(2), build.vmReg(3), build.constInt(3), + build.constInt(1)); build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); @@ -830,7 +835,7 @@ bb_0: %1 = LOAD_POINTER R0 CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 - LOP_FASTCALL1 0u, R1, R2, bb_fallback_1 + %4 = INVOKE_FASTCALL 61u, R1, R2, R3, 3i, 1i CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 STORE_DOUBLE R1, 0.5 @@ -1195,3 +1200,182 @@ bb_2: } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") +{ + IrOp block1 = build.block(IrBlockKind::Internal); + IrOp fallback1 = build.block(IrBlockKind::Fallback); + IrOp block2 = build.block(IrBlockKind::Internal); + IrOp fallback2 = build.block(IrBlockKind::Fallback); + IrOp block3 = build.block(IrBlockKind::Internal); + IrOp block4 = build.block(IrBlockKind::Internal); + + build.beginBlock(block1); + + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(tnumber), fallback1); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(fallback1); + build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(block2); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(tnumber), fallback2); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(fallback2); + build.inst(IrCmd::DO_LEN, build.vmReg(0), build.vmReg(2)); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(block3); + build.inst(IrCmd::JUMP, block4); + + build.beginBlock(block4); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_TAG R2 + CHECK_TAG %0, tnumber, bb_fallback_1 + JUMP bb_linear_6 + +bb_fallback_1: + DO_LEN R1, R2 + JUMP bb_2 + +bb_2: + %5 = LOAD_TAG R2 + CHECK_TAG %5, tnumber, bb_fallback_3 + JUMP bb_4 + +bb_fallback_3: + DO_LEN R0, R2 + JUMP bb_4 + +bb_4: + JUMP bb_5 + +bb_5: + LOP_RETURN 0u, R0, 0i + +bb_linear_6: + LOP_RETURN 0u, R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues") +{ + IrOp block1 = build.block(IrBlockKind::Internal); + IrOp fallback1 = build.block(IrBlockKind::Fallback); + IrOp block2 = build.block(IrBlockKind::Internal); + IrOp fallback2 = build.block(IrBlockKind::Fallback); + IrOp block3 = build.block(IrBlockKind::Internal); + IrOp block4a = build.block(IrBlockKind::Internal); + IrOp block4b = build.block(IrBlockKind::Internal); + + build.beginBlock(block1); + + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(tnumber), fallback1); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(fallback1); + build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(block2); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(tnumber), fallback2); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(fallback2); + build.inst(IrCmd::DO_LEN, build.vmReg(0), build.vmReg(2)); + build.inst(IrCmd::JUMP, block3); + + build.beginBlock(block3); + IrOp tag3a = build.inst(IrCmd::LOAD_TAG, build.vmReg(3)); + build.inst(IrCmd::JUMP_EQ_TAG, tag3a, build.constTag(tnil), block4a, block4b); + + build.beginBlock(block4a); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(block4b); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), tag3a); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_TAG R2 + CHECK_TAG %0, tnumber, bb_fallback_1 + JUMP bb_2 + +bb_fallback_1: + DO_LEN R1, R2 + JUMP bb_2 + +bb_2: + %5 = LOAD_TAG R2 + CHECK_TAG %5, tnumber, bb_fallback_3 + JUMP bb_4 + +bb_fallback_3: + DO_LEN R0, R2 + JUMP bb_4 + +bb_4: + %10 = LOAD_TAG R3 + JUMP_EQ_TAG %10, tnil, bb_5, bb_6 + +bb_5: + STORE_TAG R0, %10 + LOP_RETURN 0u, R0, 0i + +bb_6: + STORE_TAG R0, %10 + LOP_RETURN 0u, R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") +{ + IrOp block1 = build.block(IrBlockKind::Internal); + IrOp block2 = build.block(IrBlockKind::Internal); + + build.beginBlock(block1); + + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP, block2); + + build.beginBlock(block2); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tboolean)); + build.inst(IrCmd::JUMP, block2); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + STORE_TAG R0, tnumber + JUMP bb_1 + +bb_1: + STORE_TAG R1, tboolean + JUMP bb_1 + +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 784fadad..7fcc1e54 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -155,6 +155,36 @@ TEST_CASE("string_interpolation_basic") CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); } +TEST_CASE("string_interpolation_full") +{ + ScopedFastFlag sff("LuauFixInterpStringMid", true); + + const std::string testInput = R"(`foo {"bar"} {"baz"} end`)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme interpBegin = lexer.next(); + CHECK_EQ(interpBegin.type, Lexeme::InterpStringBegin); + CHECK_EQ(interpBegin.toString(), "`foo {"); + + Lexeme quote1 = lexer.next(); + CHECK_EQ(quote1.type, Lexeme::QuotedString); + CHECK_EQ(quote1.toString(), "\"bar\""); + + Lexeme interpMid = lexer.next(); + CHECK_EQ(interpMid.type, Lexeme::InterpStringMid); + CHECK_EQ(interpMid.toString(), "} {"); + + Lexeme quote2 = lexer.next(); + CHECK_EQ(quote2.type, Lexeme::QuotedString); + CHECK_EQ(quote2.toString(), "\"baz\""); + + Lexeme interpEnd = lexer.next(); + CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); + CHECK_EQ(interpEnd.toString(), "} end`"); +} + TEST_CASE("string_interpolation_double_brace") { const std::string testInput = R"(`foo{{bad}}bar`)"; diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index afd5a4e4..c716982e 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -37,7 +37,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") // Normally this would be defined externally, so hack it in for testing addGlobalBinding(frontend, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); - LintResult result = lintTyped("Wait(5)"); + LintResult result = lint("Wait(5)"); REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'Wait' is deprecated, use 'wait' instead"); @@ -49,7 +49,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobalNoReplacement") const char* deprecationReplacementString = ""; addGlobalBinding(frontend, "Version", Binding{typeChecker.anyType, {}, true, deprecationReplacementString}); - LintResult result = lintTyped("Version()"); + LintResult result = lint("Version()"); REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'Version' is deprecated"); @@ -1440,8 +1440,10 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") REQUIRE(12 == result.warnings.size()); } -TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") +TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") { + ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); + unfreeze(typeChecker.globalTypes); TypeId instanceType = typeChecker.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); persist(instanceType); @@ -1459,6 +1461,13 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") addGlobalBinding(frontend, "Color3", Binding{colorType, {}}); + if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + { + ttv->props["foreach"].deprecated = true; + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + } + freeze(typeChecker.globalTypes); LintResult result = lintTyped(R"( @@ -1467,14 +1476,43 @@ return function (i: Instance) print(i.Name) print(Color3.toHSV()) print(Color3.doesntexist, i.doesntexist) -- type error, but this verifies we correctly handle non-existent members + print(table.getn({})) + table.foreach({}, function() end) + print(table.nogetn()) -- verify that we correctly handle non-existent members return i.DataCost end )"); - REQUIRE(3 == result.warnings.size()); + REQUIRE(5 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Member 'Instance.Wait' is deprecated"); CHECK_EQ(result.warnings[1].text, "Member 'toHSV' is deprecated, use 'Color3:ToHSV' instead"); - CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); + CHECK_EQ(result.warnings[2].text, "Member 'table.getn' is deprecated, use '#' instead"); + CHECK_EQ(result.warnings[3].text, "Member 'table.foreach' is deprecated"); + CHECK_EQ(result.warnings[4].text, "Member 'Instance.DataCost' is deprecated"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiUntyped") +{ + ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); + + if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + { + ttv->props["foreach"].deprecated = true; + ttv->props["getn"].deprecated = true; + ttv->props["getn"].deprecatedSuggestion = "#"; + } + + LintResult result = lint(R"( +return function () + print(table.getn({})) + table.foreach({}, function() end) + print(table.nogetn()) -- verify that we correctly handle non-existent members +end +)"); + + REQUIRE(2 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Member 'table.getn' is deprecated, use '#' instead"); + CHECK_EQ(result.warnings[1].text, "Member 'table.foreach' is deprecated"); } TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperations") diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index ca06046a..c45932c6 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -488,6 +488,20 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean_2") )"))); } +TEST_CASE_FIXTURE(NormalizeFixture, "double_negation") +{ + CHECK("number" == toString(normal(R"( + number & Not> + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negate_any") +{ + CHECK("number" == toString(normal(R"( + number & Not + )"))); +} + TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function") { CHECK("() -> ()" == toString(normal(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index c72cbcce..9ff16d16 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -458,6 +458,24 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_should_work_when_name_is_also_local") REQUIRE(block->body.data[1]->is()); } +TEST_CASE_FIXTURE(Fixture, "type_alias_span_is_correct") +{ + AstStatBlock* block = parse(R"( + type Packed1 = (T...) -> (T...) + type Packed2 = (Packed1, T...) -> (Packed1, T...) + )"); + + REQUIRE(block != nullptr); + REQUIRE(2 == block->body.size); + AstStatTypeAlias* t1 = block->body.data[0]->as(); + REQUIRE(t1); + REQUIRE(Location{Position{1, 8}, Position{1, 45}} == t1->location); + + AstStatTypeAlias* t2 = block->body.data[1]->as(); + REQUIRE(t2); + REQUIRE(Location{Position{2, 8}, Position{2, 75}} == t2->location); +} + TEST_CASE_FIXTURE(Fixture, "parse_error_messages") { CHECK_EQ(getParseError(R"( @@ -1020,6 +1038,35 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_call_without_parens") } } +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_expression") +{ + ScopedFastFlag sff("LuauFixInterpStringMid", true); + + try + { + parse(R"( + print(`{}`) + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Malformed interpolated string, expected expression inside '{}'", e.getErrors().front().getMessage()); + } + + try + { + parse(R"( + print(`{}{1}`) + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Malformed interpolated string, expected expression inside '{}'", e.getErrors().front().getMessage()); + } +} + TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection") { try diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 088b4d56..cf27518a 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -272,4 +272,22 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "set_prop_of_intersection_containing_metatable") +{ + CheckResult result = check(R"( + export type Set = typeof(setmetatable( + {} :: { + add: (self: Set, T) -> Set, + }, + {} + )) + + local Set = {} :: Set & {} + + function Set:add(t) + return self + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index c389f325..0aacb8ae 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -814,18 +814,4 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") } } -TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_clone_it") -{ - CheckResult result = check(R"( - local function f(x: unknown) - if typeof(x) == "table" then - local cloned: {} = table.clone(x) - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - // LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index ba0f975e..570cf278 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -535,13 +535,20 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + } } TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") { - CheckResult result = check(R"( local t local u: {x: number?} = {x = nil} @@ -804,7 +811,10 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("{| x: true |}?", toString(requireTypeAtPosition({3, 28}))); + else + CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") @@ -1523,7 +1533,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length if (FFlag::DebugLuauDeferredConstraintResolution) { LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("a & table", toString(requireTypeAtPosition({3, 29}))); + CHECK_EQ("table", toString(requireTypeAtPosition({3, 29}))); } else { @@ -1532,6 +1542,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length } } +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_clone_it") +{ + CheckResult result = check(R"( + local function f(x: unknown) + if typeof(x) == "table" then + local cloned: {} = table.clone(x) + end + end + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + } +} + TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_during_constraint_solving_stage") { CheckResult result = check(R"( @@ -1573,4 +1603,150 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_duri CHECK_EQ("Instance", toString(requireTypeAtPosition({7, 28}))); } +TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + foo = { bar = 5 :: number? } + + if foo.bar then + local bar = foo.bar + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + CHECK_EQ("*error-type*", toString(requireTypeAtPosition({4, 30}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never") +{ + CheckResult result = check(R"( + local function f(t: {string}, s: string) + local v1 = t[5] + local v2 = v1 + + if typeof(v1) == "nil" then + local foo = v1 + else + local foo = v1 + end + + if typeof(v2) == "nil" then + local foo = v2 + else + local foo = v2 + end + + if typeof(s) == "nil" then + local foo = s + else + local foo = s + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 28}))); + + CHECK_EQ("nil", toString(requireTypeAtPosition({12, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({14, 28}))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("never", toString(requireTypeAtPosition({18, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({20, 28}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({18, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({20, 28}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "cat_or_dog_through_a_local") +{ + CheckResult result = check(R"( + type Cat = { tag: "cat", catfood: string } + type Dog = { tag: "dog", dogfood: string } + type Animal = Cat | Dog + + local function f(animal: Animal) + local tag = animal.tag + if tag == "dog" then + local dog = animal + elseif tag == "cat" then + local cat = animal + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Cat | Dog", toString(requireTypeAtPosition({8, 28}))); + CHECK_EQ("Cat | Dog", toString(requireTypeAtPosition({10, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "prove_that_dataflow_analysis_isnt_doing_alias_tracking_yet") +{ + CheckResult result = check(R"( + local function f(tag: "cat" | "dog") + local tag2 = tag + + if tag2 == "cat" then + local foo = tag + else + local foo = tag + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("cat" | "dog")", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"("cat" | "dog")", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "fail_to_refine_a_property_of_subscript_expression") +{ + CheckResult result = check(R"( + type Foo = { foo: number? } + local function f(t: {Foo}) + if t[1].foo then + local foo = t[1].foo + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 34}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_annotations_arent_relevant_when_doing_dataflow_analysis") +{ + CheckResult result = check(R"( + local function s() return "hello" end + + local function f(t: {string}) + local s1: string = t[5] + local s2: string = s() + + if typeof(s1) == "nil" and typeof(s2) == "nil" then + local foo = s1 + local bar = s2 + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({8, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("never", toString(requireTypeAtPosition({9, 28}))); + else + CHECK_EQ("nil", toString(requireTypeAtPosition({9, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 27b43aa9..2a87f0e3 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -519,10 +519,16 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ local x = get(t) )"); - // Currently this errors but it shouldn't, since set only needs write access - // TODO: file a JIRA for this - LUAU_REQUIRE_ERRORS(result); - // CHECK_EQ("number?", toString(requireType("x"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("x"))); + } + else + { + LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ("number?", toString(requireType("x"))); + } } TEST_CASE_FIXTURE(Fixture, "width_subtyping") @@ -2646,7 +2652,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_sc const MetatableType* newRet = get(follow(*newRetType)); REQUIRE(newRet); - const TableType* newRetMeta = get(newRet->metatable); + const TableType* newRetMeta = get(follow(newRet->metatable)); REQUIRE(newRetMeta); CHECK(newRetMeta->props.count("incr")); @@ -3601,4 +3607,42 @@ TEST_CASE_FIXTURE(Fixture, "dont_extend_unsealed_tables_in_rvalue_position") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "extend_unsealed_table_with_metatable") +{ + CheckResult result = check(R"( + local T = setmetatable({}, { + __call = function(_, name: string?) + end, + }) + + T.for_ = "for_" + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "top_table_type_is_isomorphic_to_empty_sealed_table_type") +{ + CheckResult result = check(R"( + local None = newproxy(true) + local mt = getmetatable(None) + mt.__tostring = function() + return "Object.None" + end + + function assign(...) + for index = 1, select("#", ...) do + local rest = select(index, ...) + + if rest ~= nil and typeof(rest) == "table" then + for key, value in pairs(rest) do + end + end + end + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 16797ee4..3865e83a 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -439,7 +439,10 @@ end _ += not _ do end )"); +} +TEST_CASE_FIXTURE(Fixture, "cyclic_follow_2") +{ check(R"( --!nonstrict n13,_,table,_,l0,_,_ = ... diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 972c399b..e2f68e65 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -261,6 +261,11 @@ assert(math.sign(inf) == 1) assert(math.sign(-inf) == -1) assert(math.sign(nan) == 0) +assert(math.min(nan, 2) ~= math.min(nan, 2)) +assert(math.min(1, nan) == 1) +assert(math.max(nan, 2) ~= math.max(nan, 2)) +assert(math.max(1, nan) == 1) + -- clamp assert(math.clamp(-1, 0, 1) == 0) assert(math.clamp(0.5, 0, 1) == 0.5) diff --git a/tools/faillist.txt b/tools/faillist.txt index 5c84d168..c6831298 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,13 +1,9 @@ -AnnotationTests.instantiate_type_fun_should_not_trip_rbxassert AnnotationTests.too_many_type_params -AnnotationTests.two_type_params AstQuery.last_argument_function_call_type AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop -AutocompleteTest.autocomplete_oop_implicit_self -AutocompleteTest.type_correct_expected_return_type_suggestion -AutocompleteTest.type_correct_suggestion_for_overloads +AutocompleteTest.autocomplete_response_perf1 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 @@ -43,7 +39,6 @@ GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_infer_generic_functions -GenericsTests.dont_unify_bound_types GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_functions_should_be_memory_safe @@ -56,7 +51,6 @@ GenericsTests.infer_generic_lib_function_function_argument GenericsTests.instantiated_function_argument_names GenericsTests.no_stack_overflow_from_quantifying GenericsTests.self_recursive_instantiated_param -IntersectionTypes.overload_is_not_a_function IntersectionTypes.table_intersection_write_sealed IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect @@ -72,7 +66,6 @@ NonstrictModeTests.local_tables_are_not_any NonstrictModeTests.locals_are_any_by_default NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon NonstrictModeTests.parameters_having_type_any_are_optional -NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.bail_early_if_unification_is_too_complicated @@ -85,20 +78,11 @@ ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete -RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string -RefinementTest.discriminate_from_isa_of_x -RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil -RefinementTest.narrow_property_of_a_bounded_variable -RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true -RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage -RefinementTest.refine_param_of_type_folder_or_part_without_using_typeof -RefinementTest.refine_unknowns RefinementTest.type_guard_can_filter_for_intersection_of_tables RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table -RefinementTest.x_is_not_instance_or_else_not_part RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.accidentally_checked_prop_in_opposite_branch @@ -109,7 +93,6 @@ TableTests.checked_prop_too_early TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index -TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop TableTests.expected_indexer_from_table_union @@ -134,7 +117,6 @@ TableTests.less_exponential_blowup_please TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys TableTests.nil_assign_doesnt_hit_indexer -TableTests.nil_assign_doesnt_hit_no_indexer TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_polymorphic @@ -153,7 +135,6 @@ TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.table_unification_4 TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon -TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 ToString.toStringErrorPack @@ -187,11 +168,9 @@ TypeInfer.no_stack_overflow_from_isoptional TypeInfer.no_stack_overflow_from_isoptional2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice -TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.classes_without_overloaded_operators_cannot_be_added -TypeInferClasses.higher_order_function_arguments_are_contravariant TypeInferClasses.index_instance_property TypeInferClasses.optional_class_field_access_error TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties @@ -239,7 +218,6 @@ TypeInferModules.module_type_conflict_instantiated TypeInferModules.type_error_of_unknown_qualified_type TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted -TypeInferOOP.object_constructor_can_refer_to_method_of_self TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.CallOrOfFunctions TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable @@ -265,25 +243,15 @@ TypeInferUnknownNever.math_operators_and_never TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check TypePackTests.type_alias_backwards_compatible -TypePackTests.type_alias_default_mixed_self TypePackTests.type_alias_default_type_errors -TypePackTests.type_alias_default_type_pack_self_chained_tp -TypePackTests.type_alias_default_type_pack_self_tp -TypePackTests.type_alias_defaults_confusing_types -TypePackTests.type_alias_type_pack_multi -TypePackTests.type_alias_type_pack_variadic TypePackTests.type_alias_type_packs_errors -TypePackTests.type_alias_type_packs_nested TypePackTests.unify_variadic_tails_in_arguments TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.indexing_on_union_of_string_singletons TypeSingletons.no_widening_from_callsites -TypeSingletons.overloaded_function_call_with_singletons -TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened -TypeSingletons.table_properties_singleton_strings_mismatch TypeSingletons.table_properties_type_error_escapes TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton