From 23b872620394c04592a0eae81491b4d2dcf5dfab Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 7 Jun 2024 11:56:00 -0500 Subject: [PATCH 1/3] Fix incorrect comment in lgc.h (#1288) The comment gave an incorrect (reversed) version of the invariant, which could be confusing for people who haven't read the full description in lgc.cpp. Unfortunately this change is difficult to flag. Fixes #1282. --- VM/src/lgc.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index ba433c67..010d7e86 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -23,11 +23,10 @@ #define GCSsweep 4 /* -** macro to tell when main invariant (white objects cannot point to black -** ones) must be kept. During a collection, the sweep -** phase may break the invariant, as objects turned white may point to -** still-black objects. The invariant is restored when sweep ends and -** all objects are white again. +** The main invariant of the garbage collector, while marking objects, +** is that a black object can never point to a white one. This invariant +** is not being enforced during a sweep phase, and is restored when sweep +** ends. */ #define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic) From 81b2cc7dbe928f78c30b2a2c5629a84e014a750e Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 7 Jun 2024 12:05:50 -0500 Subject: [PATCH 2/3] tests: Adjust conformance tests to account for array invariant (#1289) These were written before compiler optimizations and array invariant. It is now impossible for t[1] to be stored in the hash part, as this would violate the array invariant that says that elements 1..#t are stored in the array. For ipairs, it doesn't traverse the hash part anymore now, so we adjust the code to make sure no elements outside of the 1..#t slice are covered. For table.find, we can use find-with-offset to still access the hash part. Fixes #1283. --- tests/conformance/basic.lua | 7 +++---- tests/conformance/tables.lua | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 17f4497a..98f8000e 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -726,16 +726,15 @@ assert((function() return sum end)() == 15) --- the reason why this test is interesting is that the table created here has arraysize=0 and a single hash element with key = 1.0 --- ipairs must iterate through that +-- ipairs will not iterate through hash part assert((function() - local arr = { [1] = 42 } + local arr = { [1] = 1, [42] = 42, x = 10 } local sum = 0 for i,v in ipairs(arr) do sum = sum + v end return sum -end)() == 42) +end)() == 1) -- the reason why this test is interesting is it ensures we do correct mutability analysis for locals local function chainTest(n) diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 3f1efd8e..c739f555 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -412,8 +412,8 @@ do assert(table.find({false, true}, true) == 2) - -- make sure table.find checks the hash portion as well by constructing a table literal that forces the value into the hash part - assert(table.find({[(1)] = true}, true) == 1) + -- make sure table.find checks the hash portion as well + assert(table.find({[(2)] = true}, true, 2) == 2) end -- test table.concat From 0fa6a51c914b94c47019cdf5b7c6cfc9358855fb Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 7 Jun 2024 10:51:12 -0700 Subject: [PATCH 3/3] Sync to upstream/release/629 (#1290) ### What's new * Implemented parsing logic for attributes * Added `lua_setuserdatametatable` and `lua_getuserdatametatable` C API methods for a faster userdata metatable fetch compared to `luaL_getmetatable`. Note that metatable reference has to still be pinned in memory! ### New Solver * Further improvement to the assignment inference logic * Fix many bugs surrounding constraint dispatch order ### Native Codegen * Add IR lowering hooks for custom host userdata types * Add IR to create new tagged userdata objects * Remove outdated NativeState --- ### Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Aviral Goel Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Constraint.h | 22 +- Analysis/include/Luau/ConstraintGenerator.h | 4 + Analysis/include/Luau/ConstraintSolver.h | 14 +- Analysis/include/Luau/Type.h | 30 +- Analysis/include/Luau/Unifier2.h | 1 - Analysis/include/Luau/VisitType.h | 9 - Analysis/src/Clone.cpp | 5 - Analysis/src/Constraint.cpp | 11 +- Analysis/src/ConstraintGenerator.cpp | 211 ++++++++-- Analysis/src/ConstraintSolver.cpp | 193 ++++------ Analysis/src/EmbeddedBuiltinDefinitions.cpp | 9 +- Analysis/src/Frontend.cpp | 6 - Analysis/src/Normalize.cpp | 12 +- Analysis/src/Substitution.cpp | 2 - Analysis/src/ToDot.cpp | 8 - Analysis/src/ToString.cpp | 43 +-- Analysis/src/Transpiler.cpp | 6 +- Analysis/src/Type.cpp | 5 + Analysis/src/TypeAttach.cpp | 4 - Analysis/src/TypeFamily.cpp | 10 +- Analysis/src/Unifier2.cpp | 6 - Ast/include/Luau/Ast.h | 60 ++- Ast/include/Luau/Lexer.h | 14 +- Ast/include/Luau/Parser.h | 35 +- Ast/src/Ast.cpp | 66 +++- Ast/src/Lexer.cpp | 32 +- Ast/src/Parser.cpp | 212 ++++++++-- CodeGen/include/Luau/CodeGen.h | 55 +++ CodeGen/include/Luau/IrData.h | 12 + CodeGen/include/Luau/IrUtils.h | 5 + CodeGen/src/BytecodeAnalysis.cpp | 137 ++++++- CodeGen/src/CodeGenA64.cpp | 33 -- CodeGen/src/CodeGenA64.h | 2 - CodeGen/src/CodeGenContext.cpp | 4 +- CodeGen/src/CodeGenUtils.cpp | 15 + CodeGen/src/CodeGenUtils.h | 2 + CodeGen/src/CodeGenX64.cpp | 33 -- CodeGen/src/CodeGenX64.h | 2 - CodeGen/src/EmitCommonA64.h | 2 - CodeGen/src/EmitCommonX64.h | 1 - CodeGen/src/IrDump.cpp | 4 + CodeGen/src/IrLoweringA64.cpp | 132 +++++-- CodeGen/src/IrLoweringA64.h | 2 +- CodeGen/src/IrLoweringX64.cpp | 72 +++- CodeGen/src/IrLoweringX64.h | 2 +- CodeGen/src/IrTranslation.cpp | 85 +++- CodeGen/src/IrUtils.cpp | 40 ++ CodeGen/src/NativeState.cpp | 106 +---- CodeGen/src/NativeState.h | 17 +- CodeGen/src/OptimizeConstProp.cpp | 54 +++ CodeGen/src/OptimizeDeadStore.cpp | 6 + Compiler/src/Compiler.cpp | 3 +- Config/src/Config.cpp | 6 +- VM/include/lua.h | 4 + VM/src/lapi.cpp | 27 ++ VM/src/lstate.cpp | 3 + VM/src/lstate.h | 1 + tests/Conformance.test.cpp | 313 ++++++++++++++- tests/ConformanceIrHooks.h | 405 +++++++++++++++++++- tests/IrLowering.test.cpp | 357 +++++++++++++++++ tests/Lexer.test.cpp | 4 +- tests/NonStrictTypeChecker.test.cpp | 26 +- tests/Parser.test.cpp | 313 ++++++++++++++- tests/ToString.test.cpp | 4 +- tests/TypeInfer.loops.test.cpp | 12 +- tests/TypeInfer.tables.test.cpp | 10 +- tests/conformance/native_userdata.lua | 42 ++ tools/faillist.txt | 16 +- 68 files changed, 2712 insertions(+), 687 deletions(-) create mode 100644 tests/conformance/native_userdata.lua diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 77810516..3f3ad641 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -57,7 +57,7 @@ struct GeneralizationConstraint struct IterableConstraint { TypePackId iterator; - TypePackId variables; + std::vector variables; const AstNode* nextAstFragment; DenseHashMap* astForInNextTypes; @@ -192,13 +192,7 @@ struct HasIndexerConstraint TypeId indexType; }; -struct AssignConstraint -{ - TypeId lhsType; - TypeId rhsType; -}; - -// assign lhsType propName rhsType +// assignProp lhsType propName rhsType // // Assign a value of type rhsType into the named property of lhsType. @@ -212,6 +206,12 @@ struct AssignPropConstraint /// populate astTypes during constraint resolution. Nothing should ever /// block on it. TypeId propType; + + // When we generate constraints, we increment the remaining prop count on + // the table if we are able. This flag informs the solver as to whether or + // not it should in turn decrement the prop count when this constraint is + // dispatched. + bool decrementPropCount = false; }; struct AssignIndexConstraint @@ -226,13 +226,13 @@ struct AssignIndexConstraint TypeId propType; }; -// resultType ~ unpack sourceTypePack +// resultTypes ~ unpack sourceTypePack // // Similar to PackSubtypeConstraint, but with one important difference: If the // sourcePack is blocked, this constraint blocks. struct UnpackConstraint { - TypePackId resultPack; + std::vector resultPack; TypePackId sourcePack; }; @@ -254,7 +254,7 @@ struct ReducePackConstraint using ConstraintV = Variant; + AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 3e1861ea..b540b82f 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -118,6 +118,8 @@ struct ConstraintGenerator std::function prepareModuleScope; std::vector requireCycles; + DenseHashMap> localTypes{nullptr}; + DcrLogger* logger; ConstraintGenerator(ModulePtr module, NotNull normalizer, NotNull moduleResolver, NotNull builtinTypes, @@ -354,6 +356,8 @@ private: */ void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); + bool recordPropertyAssignment(TypeId ty); + // Record the fact that a particular local has a particular type in at least // one of its states. void recordInferredBinding(AstLocal* local, TypeId ty); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 58361dde..902dd15d 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -142,7 +142,6 @@ struct ConstraintSolver std::pair> tryDispatchSetIndexer( NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds); - bool tryDispatch(const AssignConstraint& c, NotNull constraint); bool tryDispatch(const AssignPropConstraint& c, NotNull constraint); bool tryDispatch(const AssignIndexConstraint& c, NotNull constraint); @@ -158,8 +157,7 @@ struct ConstraintSolver bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); // for a, ... in next_function, t, ... do - bool tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); + bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force); std::pair, std::optional> lookupTableProp(NotNull constraint, TypeId subjectType, const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false); @@ -168,14 +166,18 @@ struct ConstraintSolver /** * Generate constraints to unpack the types of srcTypes and assign each - * value to the corresponding LocalType in destTypes. + * value to the corresponding BlockedType in destTypes. * - * @param destTypes A finite TypePack comprised of LocalTypes. + * This function also overwrites the owners of each BlockedType. This is + * okay because this function is only used to decompose IterableConstraint + * into an UnpackConstraint. + * + * @param destTypes A vector of types comprised of BlockedTypes. * @param srcTypes A TypePack that represents rvalues to be assigned. * @returns The underlying UnpackConstraint. There's a bit of code in * iteration that needs to pass blocks on to this constraint. */ - NotNull unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull constraint); + NotNull unpackAndAssign(const std::vector destTypes, TypePackId srcTypes, NotNull constraint); void block(NotNull target, NotNull constraint); /** diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 47161886..6105ede3 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -86,24 +86,6 @@ struct FreeType TypeId upperBound = nullptr; }; -/** A type that tracks the domain of a local variable. - * - * We consider each local's domain to be the union of all types assigned to it. - * We accomplish this with LocalType. Each time we dispatch an assignment to a - * local, we accumulate this union and decrement blockCount. - * - * When blockCount reaches 0, we can consider the LocalType to be "fully baked" - * and replace it with the union we've built. - */ -struct LocalType -{ - TypeId domain; - int blockCount = 0; - - // Used for debugging - std::string name; -}; - struct GenericType { // By default, generics are global, with a synthetic name @@ -148,6 +130,7 @@ struct BlockedType Constraint* getOwner() const; void setOwner(Constraint* newOwner); + void replaceOwner(Constraint* newOwner); private: // The constraint that is intended to unblock this type. Other constraints @@ -471,6 +454,11 @@ struct TableType // Methods of this table that have an untyped self will use the same shared self type. std::optional selfTy; + + // We track the number of as-yet-unadded properties to unsealed tables. + // Some constraints will use this information to decide whether or not they + // are able to dispatch. + size_t remainingProps = 0; }; // Represents a metatable attached to a table type. Somewhat analogous to a bound type. @@ -669,9 +657,9 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; struct Type final { diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index 130c0c3c..bbf3a63a 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -69,7 +69,6 @@ struct Unifier2 */ bool unify(TypeId subTy, TypeId superTy); bool unifyFreeWithType(TypeId subTy, TypeId superTy); - bool unify(const LocalType* subTy, TypeId superFn); bool unify(TypeId subTy, const FunctionType* superFn); bool unify(const UnionType* subUnion, TypeId superTy); bool unify(TypeId subTy, const UnionType* superUnion); diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index 40dccbd2..ff0656d6 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -100,10 +100,6 @@ struct GenericTypeVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const LocalType& ftv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const GenericType& gtv) { return visit(ty); @@ -248,11 +244,6 @@ struct GenericTypeVisitor else visit(ty, *ftv); } - else if (auto lt = get(ty)) - { - if (visit(ty, *lt)) - traverse(lt->domain); - } else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index a96e5866..371ace2e 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -271,11 +271,6 @@ private: t->upperBound = shallowClone(t->upperBound); } - void cloneChildren(LocalType* t) - { - t->domain = shallowClone(t->domain); - } - void cloneChildren(GenericType* t) { // TOOD: clone upper bounds. diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index bd31beff..7b3377cb 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -81,7 +81,8 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const } else if (auto itc = get(*this)) { - rci.traverse(itc->variables); + for (TypeId ty : itc->variables) + rci.traverse(ty); // `IterableConstraints` should not mutate `iterator`. } else if (auto nc = get(*this)) @@ -106,11 +107,6 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const rci.traverse(hic->resultType); // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`. } - else if (auto ac = get(*this)) - { - rci.traverse(ac->lhsType); - rci.traverse(ac->rhsType); - } else if (auto apc = get(*this)) { rci.traverse(apc->lhsType); @@ -124,7 +120,8 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const } else if (auto uc = get(*this)) { - rci.traverse(uc->resultPack); + for (TypeId ty : uc->resultPack) + rci.traverse(ty); // `UnpackConstraint` should not mutate `sourcePack`. } else if (auto rpc = get(*this)) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 12648eb0..9d825408 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -28,6 +28,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau { @@ -246,6 +247,17 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) if (logger) logger->captureGenerationModule(module); + + for (const auto& [ty, domain] : localTypes) + { + // FIXME: This isn't the most efficient thing. + TypeId domainTy = builtinTypes->neverType; + for (TypeId d : domain) + domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + + LUAU_ASSERT(get(ty)); + asMutable(ty)->ty.emplace(domainTy); + } } TypeId ConstraintGenerator::freshType(const ScopePtr& scope) @@ -310,7 +322,8 @@ std::optional ConstraintGenerator::lookup(const ScopePtr& scope, Locatio std::optional ty = lookup(scope, location, operand, /*prototype*/ false); if (!ty) { - ty = arena->addType(LocalType{builtinTypes->neverType}); + ty = arena->addType(BlockedType{}); + localTypes[*ty] = {}; rootScope->lvalueTypes[operand] = *ty; } @@ -703,7 +716,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat { const Location location = local->location; - TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->name.value}); + TypeId assignee = arena->addType(BlockedType{}); + localTypes[assignee] = {}; assignees.push_back(assignee); @@ -740,7 +754,12 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat if (hasAnnotation) { for (size_t i = 0; i < statLocal->vars.size; ++i) - addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], annotatedTypes[i]}); + { + LUAU_ASSERT(get(assignees[i])); + std::vector* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->push_back(annotatedTypes[i]); + } TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); @@ -750,15 +769,30 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat std::vector valueTypes; valueTypes.reserve(statLocal->vars.size); - for (size_t i = 0; i < statLocal->vars.size; ++i) - valueTypes.push_back(arena->addType(BlockedType{})); + auto [head, tail] = flatten(rvaluePack); - auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(valueTypes), rvaluePack}); + if (head.size() >= statLocal->vars.size) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + + for (TypeId t: valueTypes) + getMutable(t)->setOwner(uc); + } for (size_t i = 0; i < statLocal->vars.size; ++i) { - getMutable(valueTypes[i])->setOwner(uc); - addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], valueTypes[i]}); + LUAU_ASSERT(get(assignees[i])); + std::vector* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->push_back(valueTypes[i]); } } @@ -860,25 +894,34 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI for (AstLocal* var : forIn->vars) { - TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, var->name.value}); + TypeId assignee = arena->addType(BlockedType{}); variableTypes.push_back(assignee); + TypeId loopVar = arena->addType(BlockedType{}); + localTypes[loopVar].push_back(assignee); + if (var->annotation) { TypeId annotationTy = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false); loopScope->bindings[var] = Binding{annotationTy, var->location}; - addConstraint(scope, var->location, SubtypeConstraint{assignee, annotationTy}); + addConstraint(scope, var->location, SubtypeConstraint{loopVar, annotationTy}); } else - loopScope->bindings[var] = Binding{assignee, var->location}; + loopScope->bindings[var] = Binding{loopVar, var->location}; DefId def = dfg->getDef(var); - loopScope->lvalueTypes[def] = assignee; + loopScope->lvalueTypes[def] = loopVar; } - TypePackId variablePack = arena->addTypePack(std::move(variableTypes)); auto iterable = addConstraint( - loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astForInNextTypes}); + loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes}); + + for (TypeId var: variableTypes) + { + auto bt = getMutable(var); + LUAU_ASSERT(bt); + bt->setOwner(iterable); + } Checkpoint start = checkpoint(this); visit(loopScope, forIn->body); @@ -1105,14 +1148,31 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass std::vector valueTypes; valueTypes.reserve(assign->vars.size); - for (size_t i = 0; i < assign->vars.size; ++i) - valueTypes.push_back(arena->addType(BlockedType{})); + auto [head, tail] = flatten(resultPack); + if (head.size() >= assign->vars.size) + { + // If the resultPack is definitely long enough for each variable, we can + // skip the UnpackConstraint and use the result types directly. - auto uc = addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(valueTypes), resultPack}); + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + // We're not sure how many types are produced by the right-side + // expressions. We'll use an UnpackConstraint to defer this until + // later. + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, assign->location, UnpackConstraint{valueTypes, resultPack}); + + for (TypeId t: valueTypes) + getMutable(t)->setOwner(uc); + } for (size_t i = 0; i < assign->vars.size; ++i) { - getMutable(valueTypes[i])->setOwner(uc); visitLValue(scope, assign->vars.data[i], valueTypes[i]); } @@ -1393,7 +1453,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunc TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); FunctionType* ftv = getMutable(fnType); - ftv->isCheckedFunction = global->checkedFunction; + ftv->isCheckedFunction = FFlag::LuauAttributeSyntax ? global->isCheckedFunction() : false; ftv->argNames.reserve(global->paramNames.size); for (const auto& el : global->paramNames) @@ -1599,9 +1659,8 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* mt = arena->addType(BlockedType{}); unpackedTypes.emplace_back(mt); - TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes)); - auto c = addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail}); + auto c = addConstraint(scope, call->location, UnpackConstraint{unpackedTypes, *argTail}); getMutable(mt)->setOwner(c); if (auto b = getMutable(target); b && b->getOwner() == nullptr) b->setOwner(c); @@ -1842,7 +1901,37 @@ Inference ConstraintGenerator::checkIndexName( const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) { TypeId obj = check(scope, indexee).ty; - TypeId result = arena->addType(BlockedType{}); + TypeId result = nullptr; + + // We optimize away the HasProp constraint in simple cases so that we can + // reason about updates to unsealed tables more accurately. + + const TableType* tt = getTableType(obj); + + // This is a little bit iffy but I *believe* it is okay because, if the + // local's domain is going to be extended at all, it will be someplace after + // the current lexical position within the script. + if (!tt) + { + if (auto localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) + tt = getTableType(localDomain->front()); + } + + if (tt) + { + auto it = tt->props.find(index); + if (it != tt->props.end() && it->second.readTy.has_value()) + result = *it->second.readTy; + } + + if (!result) + { + result = arena->addType(BlockedType{}); + + auto c = addConstraint( + scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}); + getMutable(result)->setOwner(c); + } if (key) { @@ -1852,10 +1941,6 @@ Inference ConstraintGenerator::checkIndexName( scope->rvalueRefinements[key->def] = result; } - auto c = - addConstraint(scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}); - getMutable(result)->setOwner(c); - if (key) return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)}; else @@ -2242,18 +2327,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (ty) { - if (auto lt = getMutable(*ty)) - ++lt->blockCount; - else if (auto ut = getMutable(*ty)) - { - for (TypeId optTy : ut->options) - if (auto lt = getMutable(optTy)) - ++lt->blockCount; - } + std::vector* localDomain = localTypes.find(*ty); + if (localDomain) + localDomain->push_back(rhsType); } else { - ty = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->local->name.value}); + ty = arena->addType(BlockedType{}); + localTypes[*ty].push_back(rhsType); if (annotatedTy) { @@ -2277,7 +2358,9 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (annotatedTy) addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy}); - addConstraint(scope, local->location, AssignConstraint{*ty, rhsType}); + + if (auto localDomain = localTypes.find(*ty)) + localDomain->push_back(rhsType); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) @@ -2289,7 +2372,6 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob rootScope->lvalueTypes[def] = rhsType; addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); - addConstraint(scope, global->location, AssignConstraint{*annotatedTy, rhsType}); } } @@ -2298,7 +2380,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* e TypeId lhsTy = check(scope, expr->expr).ty; TypeId propTy = arena->addType(BlockedType{}); module->astTypes[expr] = propTy; - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy}); + + bool incremented = recordPropertyAssignment(lhsTy); + + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented}); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) @@ -2310,7 +2395,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e module->astTypes[expr] = propTy; module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist. std::string propName{constantString->value.data, constantString->value.size}; - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy}); + + bool incremented = recordPropertyAssignment(lhsTy); + + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, incremented}); return; } @@ -2775,7 +2863,7 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool // TODO: FunctionType needs a pointer to the scope so that we know // how to quantify/instantiate it. FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; - ftv.isCheckedFunction = fn->checkedFunction; + ftv.isCheckedFunction = FFlag::LuauAttributeSyntax ? fn->isCheckedFunction() : false; // This replicates the behavior of the appropriate FunctionType // constructors. @@ -2977,8 +3065,7 @@ Inference ConstraintGenerator::flattenPack(const ScopePtr& scope, Location locat return Inference{*f, refinement}; TypeId typeResult = arena->addType(BlockedType{}); - TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); - auto c = addConstraint(scope, location, UnpackConstraint{resultPack, tp}); + auto c = addConstraint(scope, location, UnpackConstraint{{typeResult}, tp}); getMutable(typeResult)->setOwner(c); return Inference{typeResult, refinement}; @@ -3075,6 +3162,46 @@ void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, As program->visit(&gp); } +bool ConstraintGenerator::recordPropertyAssignment(TypeId ty) +{ + DenseHashSet seen{nullptr}; + VecDeque queue; + + queue.push_back(ty); + + bool incremented = false; + + while (!queue.empty()) + { + const TypeId t = follow(queue.front()); + queue.pop_front(); + + if (seen.find(t)) + continue; + seen.insert(t); + + if (auto tt = getMutable(t); tt && tt->state == TableState::Unsealed) + { + tt->remainingProps += 1; + incremented = true; + } + else if (auto mt = get(t)) + queue.push_back(mt->table); + else if (auto localDomain = localTypes.find(t)) + { + for (TypeId domainTy : *localDomain) + queue.push_back(domainTy); + } + else if (auto ut = get(t)) + { + for (TypeId part : ut) + queue.push_back(part); + } + } + + return incremented; +} + void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty) { if (InferredBinding* ib = inferredBindings.find(local)) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index b0f27911..07fc26fb 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -532,8 +532,6 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*hpc, constraint); else if (auto spc = get(*constraint)) success = tryDispatch(*spc, constraint); - else if (auto uc = get(*constraint)) - success = tryDispatch(*uc, constraint); else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); else if (auto uc = get(*constraint)) @@ -686,7 +684,8 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullanyTypePack, c.variables); + for (TypeId ty : c.variables) + unify(constraint, builtinTypes->errorRecoveryType(), ty); return true; } @@ -696,21 +695,35 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope); TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); - TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); - getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; + TypeId tableTy = arena->addType(TableType{ + TableType::Props{}, + TableIndexer{keyTy, valueTy}, + TypeLevel{}, + constraint->scope, + TableState::Free + }); - pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{nextTy, tableTy}); + unify(constraint, nextTy, tableTy); auto it = begin(c.variables); auto endIt = end(c.variables); if (it != endIt) { - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy}); + bindBlockedType(*it, keyTy, keyTy, constraint); ++it; } if (it != endIt) - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); + { + bindBlockedType(*it, valueTy, valueTy, constraint); + ++it; + } + + while (it != endIt) + { + bindBlockedType(*it, builtinTypes->nilType, builtinTypes->nilType, constraint); + ++it; + } return true; } @@ -721,11 +734,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull= 2) tableTy = iterator.head[1]; - TypeId firstIndexTy = builtinTypes->nilType; - if (iterator.head.size() >= 3) - firstIndexTy = iterator.head[2]; - - return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force); + return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force); } else @@ -1310,6 +1319,14 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(subjectType) || get(subjectType)) return block(subjectType, constraint); + if (const TableType* subjectTable = getTableType(subjectType)) + { + if (subjectTable->state == TableState::Unsealed && subjectTable->remainingProps > 0 && subjectTable->props.count(c.prop) == 0) + { + return block(subjectType, constraint); + } + } + auto [blocked, result] = lookupTableProp(constraint, subjectType, c.prop, c.context, c.inConditional, c.suppressSimplification); if (!blocked.empty()) { @@ -1517,7 +1534,10 @@ bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull seen{nullptr}; - return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); + bool ok = tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); + if (ok) + unblock(c.resultType, constraint->location); + return ok; } std::pair> ConstraintSolver::tryDispatchSetIndexer( @@ -1596,46 +1616,6 @@ std::pair> ConstraintSolver::tryDispatchSetIndexer( return {true, std::nullopt}; } -bool ConstraintSolver::tryDispatch(const AssignConstraint& c, NotNull constraint) -{ - const TypeId lhsTy = follow(c.lhsType); - const TypeId rhsTy = follow(c.rhsType); - - if (!get(lhsTy) && isBlocked(lhsTy)) - return block(lhsTy, constraint); - - auto tryExpand = [&](TypeId ty) { - LocalType* lt = getMutable(ty); - if (!lt) - return; - - lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, rhsTy).result; - LUAU_ASSERT(lt->blockCount > 0); - --lt->blockCount; - - if (0 == lt->blockCount) - { - shiftReferences(ty, lt->domain); - emplaceType(asMutable(ty), lt->domain); - } - }; - - if (auto ut = get(lhsTy)) - { - // FIXME: I suspect there's a bug here where lhsTy is a union that contains no LocalTypes. - for (TypeId t : ut) - tryExpand(t); - } - else if (get(lhsTy)) - tryExpand(lhsTy); - else - unify(constraint, rhsTy, lhsTy); - - unblock(lhsTy, constraint->location); - - return true; -} - bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull constraint) { TypeId lhsType = follow(c.lhsType); @@ -1753,6 +1733,14 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull(asMutable(c.propType), rhsType); lhsTable->props[propName] = Property::rw(rhsType); + + if (lhsTable->state == TableState::Unsealed && c.decrementPropCount) + { + LUAU_ASSERT(lhsTable->remainingProps > 0); + lhsTable->remainingProps -= 1; + unblock(lhsType, constraint->location); + } + return true; } } @@ -1927,24 +1915,14 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) { TypePackId sourcePack = follow(c.sourcePack); - TypePackId resultPack = follow(c.resultPack); if (isBlocked(sourcePack)) return block(sourcePack, constraint); - if (isBlocked(resultPack)) - { - LUAU_ASSERT(canMutate(resultPack, constraint)); - LUAU_ASSERT(resultPack != sourcePack); - emplaceTypePack(asMutable(resultPack), sourcePack); - unblock(resultPack, constraint->location); - return true; - } + TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, c.resultPack.size()); - TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); - - auto resultIter = begin(resultPack); - auto resultEnd = end(resultPack); + auto resultIter = begin(c.resultPack); + auto resultEnd = end(c.resultPack); size_t i = 0; while (resultIter != resultEnd) @@ -2080,18 +2058,22 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl auto endIt = end(c.variables); if (it != endIt) { - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy}); + bindBlockedType(*it, keyTy, keyTy, constraint); ++it; } if (it != endIt) - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); + bindBlockedType(*it, valueTy, valueTy, constraint); return true; } auto unpack = [&](TypeId ty) { - for (TypeId varTy : c.variables) - pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, ty}); + for (TypeId varTy: c.variables) + { + LUAU_ASSERT(get(varTy)); + LUAU_ASSERT(varTy != ty); + bindBlockedType(varTy, ty, ty, constraint); + } }; if (get(iteratorTy)) @@ -2129,27 +2111,18 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (iteratorTable->indexer) { - TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); - unify(constraint, c.variables, expectedVariablePack); + std::vector expectedVariables{iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}; + while (c.variables.size() >= expectedVariables.size()) + expectedVariables.push_back(builtinTypes->errorRecoveryType()); - auto [variableTys, variablesTail] = flatten(c.variables); - - // the local types for the indexer _should_ be all set after unification - for (TypeId ty : variableTys) + for (size_t i = 0; i < c.variables.size(); ++i) { - if (auto lt = getMutable(ty)) - { - LUAU_ASSERT(lt->blockCount > 0); - --lt->blockCount; + LUAU_ASSERT(c.variables[i] != expectedVariables[i]); - LUAU_ASSERT(0 <= lt->blockCount); + unify(constraint, c.variables[i], expectedVariables[i]); - if (0 == lt->blockCount) - { - shiftReferences(ty, lt->domain); - emplaceType(asMutable(ty), lt->domain); - } - } + bindBlockedType(c.variables[i], expectedVariables[i], expectedVariables[i], constraint); + unblock(c.variables[i], constraint->location); } } else @@ -2213,26 +2186,16 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl else if (auto primitiveTy = get(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table) unpack(builtinTypes->unknownType); else + { unpack(builtinTypes->errorType); + } return true; } bool ConstraintSolver::tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force) + TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force) { - // We need to know whether or not this type is nil or not. - // If we don't know, block and reschedule ourselves. - firstIndexTy = follow(firstIndexTy); - if (get(firstIndexTy)) - { - if (force) - LUAU_ASSERT(false); - else - block(firstIndexTy, constraint); - return false; - } - const FunctionType* nextFn = get(nextTy); // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. LUAU_ASSERT(nextFn); @@ -2267,27 +2230,18 @@ bool ConstraintSolver::tryDispatchIterableFunction( return true; } -NotNull ConstraintSolver::unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull constraint) +NotNull ConstraintSolver::unpackAndAssign(const std::vector destTypes, TypePackId srcTypes, NotNull constraint) { - std::vector unpackedTys; - for (TypeId _ty : destTypes) + auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes}); + + for (TypeId t: destTypes) { - (void) _ty; - unpackedTys.push_back(arena->addType(BlockedType{})); + BlockedType* bt = getMutable(t); + LUAU_ASSERT(bt); + bt->replaceOwner(c); } - TypePackId unpackedTp = arena->addTypePack(TypePack{unpackedTys}); - auto unpackConstraint = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{unpackedTp, srcTypes}); - - size_t i = 0; - for (TypeId varTy : destTypes) - { - pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, unpackedTys[i]}); - getMutable(unpackedTys[i])->setOwner(unpackConstraint); - ++i; - } - - return unpackConstraint; + return c; } std::pair, std::optional> ConstraintSolver::lookupTableProp(NotNull constraint, TypeId subjectType, @@ -2808,9 +2762,6 @@ bool ConstraintSolver::isBlocked(TypeId ty) { ty = follow(ty); - if (auto lt = get(ty)) - return lt->blockCount > 0; - if (auto tfit = get(ty)) return uninhabitedTypeFamilies.contains(ty) == false; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 78b76a78..91d8006a 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,6 +2,7 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false); +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau { @@ -319,9 +320,9 @@ declare os: { clock: () -> number, } -declare function @checked require(target: any): any +@checked declare function require(target: any): any -declare function @checked getfenv(target: any): { [string]: any } +@checked declare function getfenv(target: any): { [string]: any } declare _G: any declare _VERSION: string @@ -363,7 +364,7 @@ declare function select(i: string | number, ...: A...): ...any -- (nil, string). declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) -declare function @checked newproxy(mt: boolean?): any +@checked declare function newproxy(mt: boolean?): any declare coroutine: { create: (f: (A...) -> R...) -> thread, @@ -451,7 +452,7 @@ std::string getBuiltinDefinitionSource() std::string result = kBuiltinDefinitionLuaSrc; // Annotates each non generic function as checked - if (FFlag::LuauCheckedEmbeddedDefinitions2) + if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauAttributeSyntax) result = kBuiltinDefinitionLuaSrcChecked; return result; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5261c211..7823f3d4 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1196,12 +1196,6 @@ struct InternalTypeFinder : TypeOnceVisitor return false; } - bool visit(TypeId, const LocalType&) override - { - LUAU_ASSERT(false); - return false; - } - bool visit(TypePackId, const BlockedTypePack&) override { LUAU_ASSERT(false); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 5b14fd5f..7ce50284 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1815,12 +1815,6 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t if (!isCacheable(there)) here.isCacheable = false; } - else if (auto lt = get(there)) - { - // FIXME? This is somewhat questionable. - // Maybe we should assert because this should never happen? - unionNormalWithTy(here, lt->domain, seenSetTypes, ignoreSmallerTyvars); - } else if (get(there)) unionFunctionsWithFunction(here.functions, there); else if (get(there) || get(there)) @@ -3095,7 +3089,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type return NormalizationResult::True; } else if (get(there) || get(there) || get(there) || get(there) || - get(there) || get(there)) + get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; @@ -3104,10 +3098,6 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type here.isCacheable = false; return intersectNormals(here, thereNorm); } - else if (auto lt = get(there)) - { - return intersectNormalWithTy(here, lt->domain, seenSetTypes); - } NormalizedTyvars tyvars = std::move(here.tyvars); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index bc899798..ea9c3178 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -24,8 +24,6 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a // We decline to copy them. if constexpr (std::is_same_v) return ty; - else if constexpr (std::is_same_v) - return ty; else if constexpr (std::is_same_v) { // This should never happen, but visit() cannot see it. diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 9093b38a..17b595b1 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -262,14 +262,6 @@ void StateDot::visitChildren(TypeId ty, int index) visitChild(t.upperBound, index, "[upperBound]"); } } - else if constexpr (std::is_same_v) - { - formatAppend(result, "LocalType"); - finishNodeLabel(ty); - finishNode(); - - visitChild(t.domain, 1, "[domain]"); - } else if constexpr (std::is_same_v) { formatAppend(result, "AnyType %d", index); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 4e81a870..dca041a2 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -100,16 +100,6 @@ struct FindCyclicTypes final : TypeVisitor return false; } - bool visit(TypeId ty, const LocalType& lt) override - { - if (!visited.insert(ty)) - return false; - - traverse(lt.domain); - - return false; - } - bool visit(TypeId ty, const TableType& ttv) override { if (!visited.insert(ty)) @@ -525,21 +515,6 @@ struct TypeStringifier } } - void operator()(TypeId ty, const LocalType& lt) - { - state.emit("l-"); - state.emit(lt.name); - if (FInt::DebugLuauVerboseTypeNames >= 1) - { - state.emit("["); - state.emit(lt.blockCount); - state.emit("]"); - } - state.emit("=["); - stringify(lt.domain); - state.emit("]"); - } - void operator()(TypeId, const BoundType& btv) { stringify(btv.boundTo); @@ -1724,6 +1699,18 @@ std::string generateName(size_t i) return n; } +std::string toStringVector(const std::vector& types, ToStringOptions& opts) +{ + std::string s; + for (TypeId ty : types) + { + if (!s.empty()) + s += ", "; + s += toString(ty, opts); + } + return s; +} + std::string toString(const Constraint& constraint, ToStringOptions& opts) { auto go = [&opts](auto&& c) -> std::string { @@ -1754,7 +1741,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) else if constexpr (std::is_same_v) { std::string iteratorStr = tos(c.iterator); - std::string variableStr = tos(c.variables); + std::string variableStr = toStringVector(c.variables, opts); return variableStr + " ~ iterate " + iteratorStr; } @@ -1791,14 +1778,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType); } - else if constexpr (std::is_same_v) - return "assign " + tos(c.lhsType) + " " + tos(c.rhsType); else if constexpr (std::is_same_v) return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType); else if constexpr (std::is_same_v) return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType); else if constexpr (std::is_same_v) - return tos(c.resultPack) + " ~ ...unpack " + tos(c.sourcePack); + return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack); else if constexpr (std::is_same_v) return "reduce " + tos(c.ty); else if constexpr (std::is_same_v) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 85b8849f..d78bf157 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -1182,11 +1182,11 @@ std::string toString(AstNode* node) Printer printer(writer); printer.writeTypes = true; - if (auto statNode = dynamic_cast(node)) + if (auto statNode = node->asStat()) printer.visualize(*statNode); - else if (auto exprNode = dynamic_cast(node)) + else if (auto exprNode = node->asExpr()) printer.visualize(*exprNode); - else if (auto typeNode = dynamic_cast(node)) + else if (auto typeNode = node->asType()) printer.visualizeTypeAnnotation(*typeNode); return writer.str(); diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index b7a54e3d..71cac6fd 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -561,6 +561,11 @@ void BlockedType::setOwner(Constraint* newOwner) owner = newOwner; } +void BlockedType::replaceOwner(Constraint* newOwner) +{ + owner = newOwner; +} + PendingExpansionType::PendingExpansionType( std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) : prefix(prefix) diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f1fe83ee..c0294fc9 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -338,10 +338,6 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName("free"), std::nullopt, Location()); } - AstType* operator()(const LocalType& lt) - { - return Luau::visit(*this, lt.domain->ty); - } AstType* operator()(const UnionType& uv) { AstArray unionTypes; diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 3a0483a6..89de1912 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -447,7 +447,7 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati bool isPending(TypeId ty, ConstraintSolver* solver) { - return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); + return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); } template @@ -567,7 +567,7 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, const std::vector // check to see if the operand type is resolved enough, and wait to reduce if not // the use of `typeFromNormal` later necessitates blocking on local types. - if (isPending(operandTy, ctx->solver) || get(operandTy)) + if (isPending(operandTy, ctx->solver)) return {std::nullopt, false, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. @@ -1427,12 +1427,6 @@ struct FindRefinementBlockers : TypeOnceVisitor return false; } - bool visit(TypeId ty, const LocalType&) override - { - found.insert(ty); - return false; - } - bool visit(TypeId ty, const ClassType&) override { return false; diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index c8db5335..f46c3372 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -158,12 +158,6 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy) if (subFree || superFree) return true; - if (auto subLocal = getMutable(subTy)) - { - subLocal->domain = mkUnion(subLocal->domain, superTy); - expandedFreeTypes[subTy].push_back(superTy); - } - auto subFn = get(subTy); auto superFn = get(superTy); if (subFn && superFn) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 993116d6..e8479e09 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -60,6 +60,8 @@ class AstStat; class AstStatBlock; class AstExpr; class AstTypePack; +class AstAttr; +class AstExprTable; struct AstLocal { @@ -172,6 +174,10 @@ public: { return nullptr; } + virtual AstAttr* asAttr() + { + return nullptr; + } template bool is() const @@ -193,6 +199,28 @@ public: Location location; }; +class AstAttr : public AstNode +{ +public: + LUAU_RTTI(AstAttr) + + enum Type + { + Checked, + }; + + AstAttr(const Location& location, Type type); + + AstAttr* asAttr() override + { + return this; + } + + void visit(AstVisitor* visitor) override; + + Type type; +}; + class AstExpr : public AstNode { public: @@ -384,13 +412,15 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, + AstExprFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, + const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname, + const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, const std::optional& argLocation = std::nullopt); void visit(AstVisitor* visitor) override; + AstArray attributes; AstArray generics; AstArray genericPacks; AstLocal* self; @@ -810,20 +840,22 @@ public: const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes); - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes, bool checkedFunction); + AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, + const AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, + const AstArray& paramNames, const AstTypeList& retTypes); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstName name; AstArray generics; AstArray genericPacks; AstTypeList params; AstArray paramNames; AstTypeList retTypes; - bool checkedFunction; }; struct AstDeclaredClassProp @@ -936,17 +968,20 @@ public: AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes, bool checkedFunction); + AstTypeFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, + const AstTypeList& returnTypes); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstArray generics; AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; - bool checkedFunction; }; class AstTypeTypeof : public AstType @@ -1105,6 +1140,11 @@ public: return true; } + virtual bool visit(class AstAttr* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExpr* node) { return visit(static_cast(node)); diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index e111030d..f6ac28ad 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -87,6 +87,8 @@ struct Lexeme Comment, BlockComment, + Attribute, + BrokenString, BrokenComment, BrokenUnicode, @@ -115,14 +117,20 @@ struct Lexeme ReservedTrue, ReservedUntil, ReservedWhile, - ReservedChecked, Reserved_END }; Type type; Location location; + + // Field declared here, before the union, to ensure that Lexeme size is 32 bytes. +private: + // length is used to extract a slice from the input buffer. + // This field is only valid for certain lexeme types which don't duplicate portions of input + // but instead store a pointer to a location in the input buffer and the length of lexeme. unsigned int length; +public: union { const char* data; // String, Number, Comment @@ -135,9 +143,13 @@ struct Lexeme Lexeme(const Location& location, Type type, const char* data, size_t size); Lexeme(const Location& location, Type type, const char* name); + unsigned int getLength() const; + std::string toString() const; }; +static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes."); + class AstNameTable { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index e97df66b..c1fd43ea 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -82,8 +82,8 @@ private: // if exp then block {elseif exp then block} [else block] end | // for Name `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | - // function funcname funcbody | - // local function Name funcbody | + // [attributes] function funcname funcbody | + // [attributes] local function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* parseStat(); @@ -114,11 +114,25 @@ private: AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); // function funcname funcbody - AstStat* parseFunctionStat(); + LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray& attributes = {nullptr, 0}); + + std::pair validateAttribute(const char* attributeName, const TempVector& attributes); + + // attribute ::= '@' NAME + void parseAttribute(TempVector& attribute); + + // attributes ::= {attribute} + AstArray parseAttributes(); + + // attributes local function Name funcbody + // attributes function funcname funcbody + // attributes `declare function' Name`(' [parlist] `)' [`:` Type] + // declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' + AstStat* parseAttributeStat(); // local function Name funcbody | // local namelist [`=' explist] - AstStat* parseLocal(); + AstStat* parseLocal(const AstArray& attributes); // return [explist] AstStat* parseReturn(); @@ -130,7 +144,7 @@ private: // `declare global' Name: Type | // `declare function' Name`(' [parlist] `)' [`:` Type] - AstStat* parseDeclaration(const Location& start); + AstStat* parseDeclaration(const Location& start, const AstArray& attributes); // varlist `=' explist AstStat* parseAssignment(AstExpr* initial); @@ -143,7 +157,7 @@ private: // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray& attributes); // explist ::= {exp `,'} exp void parseExprList(TempVector& result); @@ -176,10 +190,10 @@ private: AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional accessLocation); - AstTypeOrPack parseFunctionType(bool allowPack, bool isCheckedFunction = false); - AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation, - bool isCheckedFunction = false); + AstTypeOrPack parseFunctionType(bool allowPack, const AstArray& attributes); + AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray& attributes, AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, + AstTypePack* varargAnnotation); AstType* parseTableType(bool inDeclarationContext = false); AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false); @@ -393,6 +407,7 @@ private: std::vector matchRecoveryStopOnToken; + std::vector scratchAttr; std::vector scratchStat; std::vector> scratchString; std::vector scratchExpr; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index bb82e0be..4c956307 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -3,6 +3,7 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau { @@ -16,6 +17,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list) list.tailType->visit(visitor); } +AstAttr::AstAttr(const Location& location, Type type) + : AstNode(ClassIndex(), location) + , type(type) +{ +} + +void AstAttr::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + int gAstRttiIndex = 0; AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) @@ -161,11 +173,12 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, - const std::optional& argLocation) +AstExprFunction::AstExprFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, + AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional& returnAnnotation, + AstTypePack* varargAnnotation, const std::optional& argLocation) : AstExpr(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , self(self) @@ -696,27 +709,27 @@ AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const A const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) + , attributes() , name(name) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) , retTypes(retTypes) - , checkedFunction(false) { } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes, bool checkedFunction) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, + const AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, + const AstArray& paramNames, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) + , attributes(attributes) , name(name) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) , retTypes(retTypes) - , checkedFunction(checkedFunction) { } @@ -729,6 +742,19 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor) } } +bool AstStatDeclareFunction::isCheckedFunction() const +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props, AstTableIndexer* indexer) : AstStat(ClassIndex(), location) @@ -820,25 +846,26 @@ void AstTypeTable::visit(AstVisitor* visitor) AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) : AstType(ClassIndex(), location) + , attributes() , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) , argNames(argNames) , returnTypes(returnTypes) - , checkedFunction(false) { LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes, bool checkedFunction) +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, + const AstTypeList& returnTypes) : AstType(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) , argNames(argNames) , returnTypes(returnTypes) - , checkedFunction(checkedFunction) { LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); } @@ -852,6 +879,19 @@ void AstTypeFunction::visit(AstVisitor* visitor) } } +bool AstTypeFunction::isCheckedFunction() const +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr) : AstType(ClassIndex(), location) , expr(expr) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 71577459..8e9b3be9 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) +LUAU_FASTFLAGVARIABLE(LuauAttributeSyntax, false) namespace Luau { @@ -102,11 +103,19 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name) , length(0) , name(name) { - LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); + LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); +} + +unsigned int Lexeme::getLength() const +{ + LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); + + return length; } static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", - "repeat", "return", "then", "true", "until", "while", "@checked"}; + "repeat", "return", "then", "true", "until", "while"}; std::string Lexeme::toString() const { @@ -191,6 +200,10 @@ std::string Lexeme::toString() const case Comment: return "comment"; + case Attribute: + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + return name ? format("'%s'", name) : "attribute"; + case BrokenString: return "malformed string"; @@ -278,7 +291,7 @@ std::pair AstNameTable::getOrAddWithType(const char* name nameData[length] = 0; const_cast(entry).value = AstName(nameData); - const_cast(entry).type = Lexeme::Name; + const_cast(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name); return std::make_pair(entry.value, entry.type); } @@ -994,14 +1007,11 @@ Lexeme Lexer::readNext() } case '@': { - // We're trying to lex the token @checked - LUAU_ASSERT(peekch() == '@'); - - std::pair maybeChecked = readName(); - if (maybeChecked.second != Lexeme::ReservedChecked) - return Lexeme(Location(start, position()), Lexeme::Error); - - return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value); + if (FFlag::LuauAttributeSyntax) + { + std::pair attribute = readName(); + return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value); + } } default: if (isDigit(peekch())) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 5ca480e8..d80878d5 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -17,11 +17,20 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAG(LuauAttributeSyntax) LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand, false) namespace Luau { +struct AttributeEntry +{ + const char* name; + AstAttr::Type type; +}; + +AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {nullptr, AstAttr::Type::Checked}}; + ParseError::ParseError(const Location& location, const std::string& message) : location(location) , message(message) @@ -280,7 +289,9 @@ AstStatBlock* Parser::parseBlockNoScope() // for binding `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | // function funcname funcbody | +// attributes function funcname funcbody | // local function Name funcbody | +// local attributes function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* Parser::parseStat() @@ -299,13 +310,16 @@ AstStat* Parser::parseStat() case Lexeme::ReservedRepeat: return parseRepeat(); case Lexeme::ReservedFunction: - return parseFunctionStat(); + return parseFunctionStat(AstArray({nullptr, 0})); case Lexeme::ReservedLocal: - return parseLocal(); + return parseLocal(AstArray({nullptr, 0})); case Lexeme::ReservedReturn: return parseReturn(); case Lexeme::ReservedBreak: return parseBreak(); + case Lexeme::Attribute: + if (FFlag::LuauAttributeSyntax) + return parseAttributeStat(); default:; } @@ -343,7 +357,7 @@ AstStat* Parser::parseStat() if (options.allowDeclarationSyntax) { if (ident == "declare") - return parseDeclaration(expr->location); + return parseDeclaration(expr->location, AstArray({nullptr, 0})); } // skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop) @@ -652,7 +666,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug } // function funcname funcbody -AstStat* Parser::parseFunctionStat() +AstStat* Parser::parseFunctionStat(const AstArray& attributes) { Location start = lexer.current().location; @@ -665,16 +679,125 @@ AstStat* Parser::parseFunctionStat() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; + AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; return allocator.alloc(Location(start, body->location), expr, body); } + +std::pair Parser::validateAttribute(const char* attributeName, const TempVector& attributes) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + AstAttr::Type type; + + // check if the attribute name is valid + + bool found = false; + + for (int i = 0; kAttributeEntries[i].name; ++i) + { + found = !strcmp(attributeName, kAttributeEntries[i].name); + if (found) + { + type = kAttributeEntries[i].type; + break; + } + } + + if (!found) + { + if (strlen(attributeName) == 1) + report(lexer.current().location, "Attribute name is missing"); + else + report(lexer.current().location, "Invalid attribute '%s'", attributeName); + } + else + { + // check that attribute is not duplicated + for (const AstAttr* attr : attributes) + { + if (attr->type == type) + { + report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName); + } + } + } + + return {found, type}; +} + +// attribute ::= '@' NAME +void Parser::parseAttribute(TempVector& attributes) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute); + + Location loc = lexer.current().location; + + const char* name = lexer.current().name; + const auto [found, type] = validateAttribute(name, attributes); + + nextLexeme(); + + if (found) + attributes.push_back(allocator.alloc(loc, type)); +} + +// attributes ::= {attribute} +AstArray Parser::parseAttributes() +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + Lexeme::Type type = lexer.current().type; + + LUAU_ASSERT(type == Lexeme::Attribute); + + TempVector attributes(scratchAttr); + + while (lexer.current().type == Lexeme::Attribute) + parseAttribute(attributes); + + return copy(attributes); +} + +// attributes local function Name funcbody +// attributes function funcname funcbody +// attributes `declare function' Name`(' [parlist] `)' [`:` Type] +// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' +AstStat* Parser::parseAttributeStat() +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + AstArray attributes = Parser::parseAttributes(); + + Lexeme::Type type = lexer.current().type; + + switch (type) + { + case Lexeme::Type::ReservedFunction: + return parseFunctionStat(attributes); + case Lexeme::Type::ReservedLocal: + return parseLocal(attributes); + case Lexeme::Type::Name: + if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data)) + { + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); + return parseDeclaration(expr->location, attributes); + } + default: + return reportStatError(lexer.current().location, {}, {}, + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s intead", + lexer.current().toString().c_str()); + } +} + // local function Name funcbody | // local bindinglist [`=' explist] -AstStat* Parser::parseLocal() +AstStat* Parser::parseLocal(const AstArray& attributes) { Location start = lexer.current().location; @@ -694,7 +817,7 @@ AstStat* Parser::parseLocal() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name); + auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes); matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -704,6 +827,12 @@ AstStat* Parser::parseLocal() } else { + if (FFlag::LuauAttributeSyntax && attributes.size != 0) + { + return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s intead", + lexer.current().toString().c_str()); + } + matchRecoveryStopOnToken['=']++; TempVector names(scratchBinding); @@ -831,18 +960,17 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() return AstDeclaredClassProp{fnName.name, fnType, true}; } -AstStat* Parser::parseDeclaration(const Location& start) +AstStat* Parser::parseDeclaration(const Location& start, const AstArray& attributes) { // `declare` token is already parsed at this point + + if (FFlag::LuauAttributeSyntax && (attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction)) + return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s intead", + lexer.current().toString().c_str()); + if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - bool checkedFunction = false; - if (lexer.current().type == Lexeme::ReservedChecked) - { - checkedFunction = true; - nextLexeme(); - } Name globalName = parseName("global function name"); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); @@ -880,8 +1008,8 @@ AstStat* Parser::parseDeclaration(const Location& start) if (vararg && !varargAnnotation) return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); - return allocator.alloc(Location(start, end), globalName.name, generics, genericPacks, - AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes, checkedFunction); + return allocator.alloc(Location(start, end), attributes, globalName.name, generics, genericPacks, + AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); } else if (AstName(lexer.current().name) == "class") { @@ -1035,7 +1163,7 @@ std::pair> Parser::prepareFunctionArguments(const // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // parlist ::= bindinglist [`,' `...'] | `...' std::pair Parser::parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray& attributes) { Location start = matchFunction.location; @@ -1087,7 +1215,7 @@ std::pair Parser::parseFunctionBody( bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); body->hasEnd = hasEnd; - return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, + return {allocator.alloc(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body, functionStack.size(), debugname, typelist, varargAnnotation, argLocation), funLocal}; } @@ -1296,7 +1424,7 @@ std::pair Parser::parseReturnType() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); + AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation); return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } @@ -1435,7 +1563,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) // ReturnType ::= Type | `(' TypeList `)' // FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) +AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray& attributes) { incrementRecursionCounter("type annotation"); @@ -1483,11 +1611,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) AstArray> paramNames = copy(names); - return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation, isCheckedFunction), {}}; + return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation, bool isCheckedFunction) +AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray& attributes, AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, + AstTypePack* varargAnnotation) { incrementRecursionCounter("type annotation"); @@ -1512,7 +1641,7 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray( - Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList, isCheckedFunction); + Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList); } // Type ::= @@ -1666,7 +1795,21 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) Location start = lexer.current().location; - if (lexer.current().type == Lexeme::ReservedNil) + AstArray attributes{nullptr, 0}; + + if (lexer.current().type == Lexeme::Attribute) + { + if (!inDeclarationContext || !FFlag::LuauAttributeSyntax) + { + return {reportTypeError(start, {}, "attributes are not allowed in declaration context")}; + } + else + { + attributes = Parser::parseAttributes(); + return parseFunctionType(allowPack, attributes); + } + } + else if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); return {allocator.alloc(start, std::nullopt, nameNil, std::nullopt, start), {}}; @@ -1754,14 +1897,9 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) { return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}}; } - else if (inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked) - { - nextLexeme(); - return parseFunctionType(allowPack, /* isCheckedFunction */ true); - } else if (lexer.current().type == '(' || lexer.current().type == '<') { - return parseFunctionType(allowPack); + return parseFunctionType(allowPack, AstArray({nullptr, 0})); } else if (lexer.current().type == Lexeme::ReservedFunction) { @@ -2259,7 +2397,7 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr, AstArray({nullptr, 0})).first; } else if (lexer.current().type == Lexeme::Number) { @@ -2689,7 +2827,7 @@ std::optional> Parser::parseCharArray() LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::InterpStringSimple); - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { @@ -2734,7 +2872,7 @@ AstExpr* Parser::parseInterpString() endLocation = currentLexeme.location; - scratchData.assign(currentLexeme.data, currentLexeme.length); + scratchData.assign(currentLexeme.data, currentLexeme.getLength()); if (!Lexer::fixupQuotedString(scratchData)) { @@ -2807,7 +2945,7 @@ AstExpr* Parser::parseNumber() { Location start = lexer.current().location; - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al if (scratchData.find('_') != std::string::npos) @@ -3162,11 +3300,11 @@ void Parser::nextLexeme() return; // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling - if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!') { const char* text = lexeme.data; - unsigned int end = lexeme.length; + unsigned int end = lexeme.getLength(); while (end > 0 && isSpace(text[end - 1])) --end; diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 171e9197..7dd05660 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -73,12 +73,39 @@ struct CompilationResult }; struct IrBuilder; +struct IrOp; using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); using HostVectorNamecallHandler = bool (*)( IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); +enum class HostMetamethod +{ + Add, + Sub, + Mul, + Div, + Idiv, + Mod, + Pow, + Minus, + Equal, + LessThan, + LessEqual, + Length, + Concat, +}; + +using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength); +using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method); +using HostUserdataAccessHandler = bool (*)( + IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostUserdataMetamethodHandler = bool (*)( + IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos); +using HostUserdataNamecallHandler = bool (*)( + IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); + struct HostIrHooks { // Suggest result type of a vector field access @@ -97,6 +124,34 @@ struct HostIrHooks // All other arguments can be of any type // Guards should take a VM exit to 'pcpos' HostVectorNamecallHandler vectorNamecall = nullptr; + + // Suggest result type of a userdata field access + HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr; + + // Suggest result type of a metamethod call + HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr; + + // Suggest result type of a userdata namecall + HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr; + + // Handle userdata value field access + // 'sourceReg' is guaranteed to be a userdata, but tag has to be checked + // Write to 'resultReg' might invalidate 'sourceReg' + // Guards should take a VM exit to 'pcpos' + HostUserdataAccessHandler userdataAccess = nullptr; + + // Handle metamethod operation on a userdata value + // 'lhs' and 'rhs' operands can be VM registers of constants + // Operand types have to be checked and userdata operand tags have to be checked + // Write to 'resultReg' might invalidate source operands + // Guards should take a VM exit to 'pcpos' + HostUserdataMetamethodHandler userdataMetamethod = nullptr; + + // Handle namecall performed on a userdata value + // 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostUserdataNamecallHandler userdataNamecall = nullptr; }; struct CompilationOptions diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index b00fffab..60af706f 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -290,6 +290,11 @@ enum class IrCmd : uint8_t // C: block TRY_CALL_FASTGETTM, + // Create new tagged userdata + // A: int (size) + // B: int (tag) + NEW_USERDATA, + // Convert integer into a double number // A: int INT_TO_NUM, @@ -460,6 +465,13 @@ enum class IrCmd : uint8_t // When undef is specified instead of a block, execution is aborted on check failure CHECK_BUFFER_LEN, + // Guard against userdata tag mismatch + // A: pointer (userdata) + // B: int (tag) + // C: block/vmexit/undef + // When undef is specified instead of a block, execution is aborted on check failure + CHECK_USERDATA_TAG, + // Special operations // Check interrupt handler diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 55b86822..8486921e 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -11,6 +11,7 @@ namespace CodeGen { struct IrBuilder; +enum class HostMetamethod; inline bool isJumpD(LuauOpcode op) { @@ -129,6 +130,7 @@ inline bool isNonTerminatingJump(IrCmd cmd) case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_BUFFER_LEN: + case IrCmd::CHECK_USERDATA_TAG: return true; default: break; @@ -182,6 +184,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::DUP_TABLE: case IrCmd::TRY_NUM_TO_INDEX: case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::NEW_USERDATA: case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: case IrCmd::NUM_TO_INT: @@ -245,6 +248,8 @@ bool isGCO(uint8_t tag); bool isUserdataBytecodeType(uint8_t ty); bool isCustomUserdataBytecodeType(uint8_t ty); +HostMetamethod tmToHostMetamethod(int tm); + // Manually add or remove use of an operand void addUse(IrFunction& function, IrOp op); void removeUse(IrFunction& function, IrOp op); diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index aed8c763..fc8eb900 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo loa LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false) LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false) namespace Luau { @@ -546,6 +547,49 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types) } } +static HostMetamethod opcodeToHostMetamethod(LuauOpcode op) +{ + switch (op) + { + case LOP_ADD: + return HostMetamethod::Add; + case LOP_SUB: + return HostMetamethod::Sub; + case LOP_MUL: + return HostMetamethod::Mul; + case LOP_DIV: + return HostMetamethod::Div; + case LOP_IDIV: + return HostMetamethod::Idiv; + case LOP_MOD: + return HostMetamethod::Mod; + case LOP_POW: + return HostMetamethod::Pow; + case LOP_ADDK: + return HostMetamethod::Add; + case LOP_SUBK: + return HostMetamethod::Sub; + case LOP_MULK: + return HostMetamethod::Mul; + case LOP_DIVK: + return HostMetamethod::Div; + case LOP_IDIVK: + return HostMetamethod::Idiv; + case LOP_MODK: + return HostMetamethod::Mod; + case LOP_POWK: + return HostMetamethod::Pow; + case LOP_SUBRK: + return HostMetamethod::Sub; + case LOP_DIVRK: + return HostMetamethod::Div; + default: + CODEGEN_ASSERT(!"opcode is not assigned to a host metamethod"); + } + + return HostMetamethod::Add; +} + void buildBytecodeBlocks(IrFunction& function, const std::vector& jumpTargets) { Proto* proto = function.proto; @@ -760,22 +804,50 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_ANY; - if (bcType.a == LBC_TYPE_VECTOR) + if (FFlag::LuauCodegenUserdataOps) { TString* str = gco2ts(function.proto->k[kc].value.gc); const char* field = getstr(str); - if (str->len == 1) + if (bcType.a == LBC_TYPE_VECTOR) { - // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" - char ch = field[0] | ' '; + if (str->len == 1) + { + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; - if (ch == 'x' || ch == 'y' || ch == 'z') - regTags[ra] = LBC_TYPE_NUMBER; + if (ch == 'x' || ch == 'y' || ch == 'z') + regTags[ra] = LBC_TYPE_NUMBER; + } + + if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); } + else if (isCustomUserdataBytecodeType(bcType.a)) + { + if (regTags[ra] == LBC_TYPE_ANY && hostHooks.userdataAccessBytecodeType) + regTags[ra] = hostHooks.userdataAccessBytecodeType(bcType.a, field, str->len); + } + } + else + { + if (bcType.a == LBC_TYPE_VECTOR) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); - if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) - regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); + if (str->len == 1) + { + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; + + if (ch == 'x' || ch == 'y' || ch == 'z') + regTags[ra] = LBC_TYPE_NUMBER; + } + + if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); + } } bcType.result = regTags[ra]; @@ -812,6 +884,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -841,6 +916,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -859,6 +939,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -879,6 +962,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -908,6 +994,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -926,6 +1017,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -945,6 +1039,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -972,6 +1069,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -1000,6 +1102,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && isCustomUserdataBytecodeType(bcType.a)) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, LBC_TYPE_ANY, HostMetamethod::Minus); bcType.result = regTags[ra]; break; @@ -1140,12 +1244,25 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) bcType.result = LBC_TYPE_FUNCTION; - if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + if (FFlag::LuauCodegenUserdataOps) { TString* str = gco2ts(function.proto->k[kc].value.gc); const char* field = getstr(str); - knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len)); + } + else + { + if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + } } } break; diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 05ac9013..06f64955 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -258,39 +258,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde return locations; } -bool initHeaderFunctions(NativeState& data) -{ - AssemblyBuilderA64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); - - unwind.startInfo(UnwindBuilder::A64); - - EntryLocations entryLocations = buildEntryFunction(build, unwind); - - build.finalize(); - - unwind.finishInfo(); - - CODEGEN_ASSERT(build.data.empty()); - - uint8_t* codeStart = nullptr; - if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), - int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart)) - { - CODEGEN_ASSERT(!"Failed to create entry function"); - return false; - } - - // Set the offset at the begining so that functions in new blocks will not overlay the locations - // specified by the unwind information of the entry function - unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - - data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); - data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); - - return true; -} - bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) { AssemblyBuilderA64 build(/* logText= */ false); diff --git a/CodeGen/src/CodeGenA64.h b/CodeGen/src/CodeGenA64.h index 24fedd9a..2633f5ba 100644 --- a/CodeGen/src/CodeGenA64.h +++ b/CodeGen/src/CodeGenA64.h @@ -7,7 +7,6 @@ namespace CodeGen { class BaseCodeGenContext; -struct NativeState; struct ModuleHelpers; namespace A64 @@ -15,7 +14,6 @@ namespace A64 class AssemblyBuilderA64; -bool initHeaderFunctions(NativeState& data); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index 7788d099..ae9e41f1 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -14,8 +14,8 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false) -LUAU_FASTINT(LuauCodeGenBlockSize) -LUAU_FASTINT(LuauCodeGenMaxTotalSize) +LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) +LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) namespace Luau { diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 973829ca..ad231e76 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -14,6 +14,7 @@ #include "lstate.h" #include "lstring.h" #include "ltable.h" +#include "ludata.h" #include @@ -219,6 +220,20 @@ void callEpilogC(lua_State* L, int nresults, int n) L->top = (nresults == LUA_MULTRET) ? res : cip->top; } +Udata* newUserdata(lua_State* L, size_t s, int tag) +{ + Udata* u = luaU_newudata(L, s, tag); + + if (Table* h = L->global->udatamt[tag]) + { + u->metatable = h; + + luaC_objbarrier(L, u, h); + } + + return u; +} + // Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) { diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 515a81f0..15d4c95d 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -17,6 +17,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); void callEpilogC(lua_State* L, int nresults, int n); +Udata* newUserdata(lua_State* L, size_t s, int tag); + #define CALL_FALLBACK_YIELD 1 Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 7f4a9e0c..b8df3774 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -186,39 +186,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde return locations; } -bool initHeaderFunctions(NativeState& data) -{ - AssemblyBuilderX64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); - - unwind.startInfo(UnwindBuilder::X64); - - EntryLocations entryLocations = buildEntryFunction(build, unwind); - - build.finalize(); - - unwind.finishInfo(); - - CODEGEN_ASSERT(build.data.empty()); - - uint8_t* codeStart = nullptr; - if (!data.codeAllocator.allocate( - build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart)) - { - CODEGEN_ASSERT(!"Failed to create entry function"); - return false; - } - - // Set the offset at the begining so that functions in new blocks will not overlay the locations - // specified by the unwind information of the entry function - unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - - data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); - data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); - - return true; -} - bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) { AssemblyBuilderX64 build(/* logText= */ false); diff --git a/CodeGen/src/CodeGenX64.h b/CodeGen/src/CodeGenX64.h index eb6ab81c..ce360b23 100644 --- a/CodeGen/src/CodeGenX64.h +++ b/CodeGen/src/CodeGenX64.h @@ -7,7 +7,6 @@ namespace CodeGen { class BaseCodeGenContext; -struct NativeState; struct ModuleHelpers; namespace X64 @@ -15,7 +14,6 @@ namespace X64 class AssemblyBuilderX64; -bool initHeaderFunctions(NativeState& data); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index 894570d9..d61fd2a7 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -22,8 +22,6 @@ namespace Luau namespace CodeGen { -struct NativeState; - namespace A64 { diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index c29479e1..f88944e5 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -26,7 +26,6 @@ namespace CodeGen { enum class IrCondition : uint8_t; -struct NativeState; struct IrOp; namespace X64 diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index c47a0b8f..a82ee894 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -199,6 +199,8 @@ const char* getCmdName(IrCmd cmd) return "TRY_NUM_TO_INDEX"; case IrCmd::TRY_CALL_FASTGETTM: return "TRY_CALL_FASTGETTM"; + case IrCmd::NEW_USERDATA: + return "NEW_USERDATA"; case IrCmd::INT_TO_NUM: return "INT_TO_NUM"; case IrCmd::UINT_TO_NUM: @@ -257,6 +259,8 @@ const char* getCmdName(IrCmd cmd) return "CHECK_NODE_VALUE"; case IrCmd::CHECK_BUFFER_LEN: return "CHECK_BUFFER_LEN"; + case IrCmd::CHECK_USERDATA_TAG: + return "CHECK_USERDATA_TAG"; case IrCmd::INTERRUPT: return "INTERRUPT"; case IrCmd::CHECK_GC: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index c8cc07f4..ea83bb99 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -13,6 +13,9 @@ LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenSplitDoarith) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false) namespace Luau { @@ -1083,6 +1086,19 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) inst.regA64 = regs.takeReg(x0, index); break; } + case IrCmd::NEW_USERDATA: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + regs.spill(build, index); + build.mov(x0, rState); + build.mov(x1, intOp(inst.a)); + build.mov(x2, intOp(inst.b)); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, newUserdata))); + build.blr(x3); + inst.regA64 = regs.takeReg(x0, index); + break; + } case IrCmd::INT_TO_NUM: { inst.regA64 = regs.allocReg(KindA64::d, index); @@ -1677,6 +1693,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) finalizeTargetLabel(inst.d, fresh); break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + Label fresh; // used when guard aborts execution or jumps to a VM exit + Label& fail = getTargetLabel(inst.c, fresh); + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldrb(temp, mem(regOp(inst.a), offsetof(Udata, tag))); + + if (FFlag::LuauCodegenUserdataOpsFixA64) + build.cmp(temp, intOp(inst.b)); + else + build.cmp(temp, tagOp(inst.b)); + + build.b(ConditionA64::NotEqual, fail); + finalizeTargetLabel(inst.c, fresh); + break; + } case IrCmd::INTERRUPT: { regs.spill(build, index); @@ -2308,7 +2342,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI8: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrsb(inst.regA64, addr); break; @@ -2317,7 +2351,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READU8: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrb(inst.regA64, addr); break; @@ -2326,7 +2360,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI8: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.strb(temp, addr); break; @@ -2335,7 +2369,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI16: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrsh(inst.regA64, addr); break; @@ -2344,7 +2378,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READU16: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrh(inst.regA64, addr); break; @@ -2353,7 +2387,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI16: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.strh(temp, addr); break; @@ -2362,7 +2396,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI32: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(inst.regA64, addr); break; @@ -2371,7 +2405,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI32: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.str(temp, addr); break; @@ -2381,7 +2415,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regA64 = regs.allocReg(KindA64::d, index); RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(temp, addr); build.fcvt(inst.regA64, temp); @@ -2392,7 +2426,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { RegisterA64 temp1 = tempDouble(inst.c); RegisterA64 temp2 = regs.allocTemp(KindA64::s); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.fcvt(temp2, temp1); build.str(temp2, addr); @@ -2402,7 +2436,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READF64: { inst.regA64 = regs.allocReg(KindA64::d, index); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(inst.regA64, addr); break; @@ -2411,7 +2445,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEF64: { RegisterA64 temp = tempDouble(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.str(temp, addr); break; @@ -2639,32 +2673,68 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset) } } -AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp) +AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag) { - if (indexOp.kind == IrOpKind::Inst) + if (FFlag::LuauCodegenUserdataOps) { - RegisterA64 temp = regs.allocTemp(KindA64::x); - build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw - return mem(temp, offsetof(Buffer, data)); - } - else if (indexOp.kind == IrOpKind::Constant) - { - // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled encoding - if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255) - return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data))); + CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER); + int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data); - // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset - if (intOp(indexOp) < 0) - return mem(regOp(bufferOp), offsetof(Buffer, data)); + if (indexOp.kind == IrOpKind::Inst) + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw + return mem(temp, dataOffset); + } + else if (indexOp.kind == IrOpKind::Constant) + { + // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled + // encoding + if (unsigned(intOp(indexOp)) + dataOffset <= 255) + return mem(regOp(bufferOp), int(intOp(indexOp) + dataOffset)); - RegisterA64 temp = regs.allocTemp(KindA64::x); - emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); - return mem(temp, offsetof(Buffer, data)); + // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset + if (intOp(indexOp) < 0) + return mem(regOp(bufferOp), dataOffset); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); + return mem(temp, dataOffset); + } + else + { + CODEGEN_ASSERT(!"Unsupported instruction form"); + return noreg; + } } else { - CODEGEN_ASSERT(!"Unsupported instruction form"); - return noreg; + if (indexOp.kind == IrOpKind::Inst) + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw + return mem(temp, offsetof(Buffer, data)); + } + else if (indexOp.kind == IrOpKind::Constant) + { + // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled + // encoding + if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255) + return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data))); + + // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset + if (intOp(indexOp) < 0) + return mem(regOp(bufferOp), offsetof(Buffer, data)); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); + return mem(temp, offsetof(Buffer, data)); + } + else + { + CODEGEN_ASSERT(!"Unsupported instruction form"); + return noreg; + } } } diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 5fb7f2b8..5f13f58e 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -44,7 +44,7 @@ struct IrLoweringA64 RegisterA64 tempInt(IrOp op); RegisterA64 tempUint(IrOp op); AddressA64 tempAddr(IrOp op, int offset); - AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp); + AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag); // May emit restore instructions RegisterA64 regOp(IrOp op); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 66609cb7..00768c70 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -15,6 +15,9 @@ #include "lstate.h" #include "lgc.h" +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) + namespace Luau { namespace CodeGen @@ -905,6 +908,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) inst.regX64 = regs.takeReg(rax, index); break; } + case IrCmd::NEW_USERDATA: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, intOp(inst.a)); + callWrap.addArgument(SizeX64::dword, intOp(inst.b)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, newUserdata)]); + inst.regX64 = regs.takeReg(rax, index); + break; + } case IrCmd::INT_TO_NUM: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); @@ -1350,6 +1365,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) } break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + build.cmp(byte[regOp(inst.a) + offsetof(Udata, tag)], intOp(inst.b)); + jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next); + break; + } case IrCmd::INTERRUPT: { unsigned pcpos = uintOp(inst.a); @@ -1895,71 +1918,71 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI8: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]); + build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_READU8: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]); + build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI8: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c))); - build.mov(byte[bufferAddrOp(inst.a, inst.b)], value); + build.mov(byte[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READI16: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); + build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_READU16: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); + build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI16: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c))); - build.mov(word[bufferAddrOp(inst.a, inst.b)], value); + build.mov(word[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READI32: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); + build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI32: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c)); - build.mov(dword[bufferAddrOp(inst.a, inst.b)], value); + build.mov(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READF32: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); + build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEF32: - storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b)], inst.c); + storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], inst.c); break; case IrCmd::BUFFER_READF64: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b)]); + build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEF64: @@ -1967,11 +1990,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { ScopedRegX64 tmp{regs, SizeX64::xmmword}; build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c))); - build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], tmp.reg); + build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], tmp.reg); } else if (inst.c.kind == IrOpKind::Inst) { - build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], regOp(inst.c)); + build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], regOp(inst.c)); } else { @@ -2190,12 +2213,25 @@ RegisterX64 IrLoweringX64::regOp(IrOp op) return inst.regX64; } -OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp) +OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag) { - if (indexOp.kind == IrOpKind::Inst) - return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data); - else if (indexOp.kind == IrOpKind::Constant) - return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data); + if (FFlag::LuauCodegenUserdataOps) + { + CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER); + int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data); + + if (indexOp.kind == IrOpKind::Inst) + return regOp(bufferOp) + qwordReg(regOp(indexOp)) + dataOffset; + else if (indexOp.kind == IrOpKind::Constant) + return regOp(bufferOp) + intOp(indexOp) + dataOffset; + } + else + { + if (indexOp.kind == IrOpKind::Inst) + return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data); + else if (indexOp.kind == IrOpKind::Constant) + return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data); + } CODEGEN_ASSERT(!"Unsupported instruction form"); return noreg; diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index 5fb7b0fa..8fb311ea 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -50,7 +50,7 @@ struct IrLoweringX64 OperandX64 memRegUintOp(IrOp op); OperandX64 memRegTagOp(IrOp op); RegisterX64 regOp(IrOp op); - OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp); + OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag); RegisterX64 vecOp(IrOp op, ScopedRegX64& tmp); IrConst constOp(IrOp op) const; diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 93073a92..5798f3e9 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) +LUAU_FASTFLAG(LuauCodegenUserdataOps) namespace Luau { @@ -444,6 +445,17 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, return; } + if (FFlag::LuauCodegenUserdataOps && (isUserdataBytecodeType(bcTypes.a) || isUserdataBytecodeType(bcTypes.b))) + { + if (build.hostHooks.userdataMetamethod && + build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, opb, opc, tmToHostMetamethod(tm), pcpos)) + return; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), opb, opc, build.constInt(tm)); + return; + } + IrOp fallback; // fast-path: number @@ -585,6 +597,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos) return; } + if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a)) + { + if (build.hostHooks.userdataMetamethod && + build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_UNM), pcpos)) + return; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM)); + return; + } + IrOp fallback; IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); @@ -606,8 +629,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos) FallbackStreamScope scope(build, fallback, next); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst( - IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM)); + + if (FFlag::LuauCodegenUserdataOps) + { + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM)); + } + else + { + build.inst( + IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM)); + } + build.inst(IrCmd::JUMP, next); } } @@ -619,6 +651,17 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos) int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); + if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a)) + { + if (build.hostHooks.userdataMetamethod && + build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_LEN), pcpos)) + return; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb)); + return; + } + IrOp fallback = build.block(IrBlockKind::Fallback); IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); @@ -638,7 +681,12 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos) FallbackStreamScope scope(build, fallback, next); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc))); + + if (FFlag::LuauCodegenUserdataOps) + build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb)); + else + build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc))); + build.inst(IrCmd::JUMP, next); } @@ -1229,10 +1277,19 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) return; } - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); + if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataAccess) + { + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.userdataAccess(build, bcTypes.a, field, str->len, ra, rb, pcpos)) + return; + } + build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); return; } @@ -1267,7 +1324,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); @@ -1413,10 +1470,26 @@ bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) return false; } - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) { build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos)); + if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataNamecall) + { + Instruction call = pc[2]; + CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + int callra = LUAU_INSN_A(call); + int nparams = LUAU_INSN_B(call) - 1; + int nresults = LUAU_INSN_C(call) - 1; + + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.userdataNamecall(build, bcTypes.a, field, str->len, callra, rb, nparams, nresults, pcpos)) + return true; + } + build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); return false; } diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index afc6ba5a..d1bfca45 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -99,6 +99,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::TRY_NUM_TO_INDEX: return IrValueKind::Int; case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::NEW_USERDATA: return IrValueKind::Pointer; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: @@ -135,6 +136,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_BUFFER_LEN: + case IrCmd::CHECK_USERDATA_TAG: case IrCmd::INTERRUPT: case IrCmd::CHECK_GC: case IrCmd::BARRIER_OBJ: @@ -262,6 +264,44 @@ bool isCustomUserdataBytecodeType(uint8_t ty) return ty >= LBC_TYPE_TAGGED_USERDATA_BASE && ty < LBC_TYPE_TAGGED_USERDATA_END; } +HostMetamethod tmToHostMetamethod(int tm) +{ + switch (TMS(tm)) + { + case TM_ADD: + return HostMetamethod::Add; + case TM_SUB: + return HostMetamethod::Sub; + case TM_MUL: + return HostMetamethod::Mul; + case TM_DIV: + return HostMetamethod::Div; + case TM_IDIV: + return HostMetamethod::Idiv; + case TM_MOD: + return HostMetamethod::Mod; + case TM_POW: + return HostMetamethod::Pow; + case TM_UNM: + return HostMetamethod::Minus; + case TM_EQ: + return HostMetamethod::Equal; + case TM_LT: + return HostMetamethod::LessThan; + case TM_LE: + return HostMetamethod::LessEqual; + case TM_LEN: + return HostMetamethod::Length; + case TM_CONCAT: + return HostMetamethod::Concat; + default: + CODEGEN_ASSERT(!"invalid tag method for host"); + break; + } + + return HostMetamethod::Add; +} + void kill(IrFunction& function, IrInst& inst) { CODEGEN_ASSERT(inst.useCount == 0); diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index b3d07491..248f0cd3 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -14,114 +14,13 @@ #include #include -LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) -LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) namespace Luau { namespace CodeGen { -NativeState::NativeState() - : NativeState(nullptr, nullptr) -{ -} - -NativeState::NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext) - : codeAllocator{size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext} -{ -} - -NativeState::~NativeState() = default; - -void initFunctions(NativeState& data) -{ - static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); - memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table)); - - data.context.luaV_lessthan = luaV_lessthan; - data.context.luaV_lessequal = luaV_lessequal; - data.context.luaV_equalval = luaV_equalval; - data.context.luaV_doarith = luaV_doarith; - - data.context.luaV_doarithadd = luaV_doarithimpl; - data.context.luaV_doarithsub = luaV_doarithimpl; - data.context.luaV_doarithmul = luaV_doarithimpl; - data.context.luaV_doarithdiv = luaV_doarithimpl; - data.context.luaV_doarithidiv = luaV_doarithimpl; - data.context.luaV_doarithmod = luaV_doarithimpl; - data.context.luaV_doarithpow = luaV_doarithimpl; - data.context.luaV_doarithunm = luaV_doarithimpl; - - data.context.luaV_dolen = luaV_dolen; - data.context.luaV_gettable = luaV_gettable; - data.context.luaV_settable = luaV_settable; - data.context.luaV_getimport = luaV_getimport; - data.context.luaV_concat = luaV_concat; - - data.context.luaH_getn = luaH_getn; - data.context.luaH_new = luaH_new; - data.context.luaH_clone = luaH_clone; - data.context.luaH_resizearray = luaH_resizearray; - data.context.luaH_setnum = luaH_setnum; - - data.context.luaC_barriertable = luaC_barriertable; - data.context.luaC_barrierf = luaC_barrierf; - data.context.luaC_barrierback = luaC_barrierback; - data.context.luaC_step = luaC_step; - - data.context.luaF_close = luaF_close; - data.context.luaF_findupval = luaF_findupval; - data.context.luaF_newLclosure = luaF_newLclosure; - - data.context.luaT_gettm = luaT_gettm; - data.context.luaT_objtypenamestr = luaT_objtypenamestr; - - data.context.libm_exp = exp; - data.context.libm_pow = pow; - data.context.libm_fmod = fmod; - data.context.libm_log = log; - data.context.libm_log2 = log2; - data.context.libm_log10 = log10; - data.context.libm_ldexp = ldexp; - data.context.libm_round = round; - data.context.libm_frexp = frexp; - data.context.libm_modf = modf; - - data.context.libm_asin = asin; - data.context.libm_sin = sin; - data.context.libm_sinh = sinh; - data.context.libm_acos = acos; - data.context.libm_cos = cos; - data.context.libm_cosh = cosh; - data.context.libm_atan = atan; - data.context.libm_atan2 = atan2; - data.context.libm_tan = tan; - data.context.libm_tanh = tanh; - - data.context.forgLoopTableIter = forgLoopTableIter; - data.context.forgLoopNodeIter = forgLoopNodeIter; - data.context.forgLoopNonTableFallback = forgLoopNonTableFallback; - data.context.forgPrepXnextFallback = forgPrepXnextFallback; - data.context.callProlog = callProlog; - data.context.callEpilogC = callEpilogC; - - data.context.callFallback = callFallback; - - data.context.executeGETGLOBAL = executeGETGLOBAL; - data.context.executeSETGLOBAL = executeSETGLOBAL; - data.context.executeGETTABLEKS = executeGETTABLEKS; - data.context.executeSETTABLEKS = executeSETTABLEKS; - - data.context.executeNAMECALL = executeNAMECALL; - data.context.executeFORGPREP = executeFORGPREP; - data.context.executeGETVARARGSMultRet = executeGETVARARGSMultRet; - data.context.executeGETVARARGSConst = executeGETVARARGSConst; - data.context.executeDUPCLOSURE = executeDUPCLOSURE; - data.context.executePREPVARARGS = executePREPVARARGS; - data.context.executeSETLIST = executeSETLIST; -} - void initFunctions(NativeContext& context) { static_assert(sizeof(context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); @@ -194,6 +93,9 @@ void initFunctions(NativeContext& context) context.callProlog = callProlog; context.callEpilogC = callEpilogC; + if (FFlag::LuauCodegenUserdataAlloc) + context.newUserdata = newUserdata; + context.callFallback = callFallback; context.executeGETGLOBAL = executeGETGLOBAL; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 2edfc270..be73815d 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -94,6 +94,7 @@ struct NativeContext void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr; + Udata* (*newUserdata)(lua_State* L, size_t s, int tag) = nullptr; Closure* (*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr; @@ -116,22 +117,6 @@ struct NativeContext using GateFn = int (*)(lua_State*, Proto*, uintptr_t, NativeContext*); -struct NativeState -{ - NativeState(); - NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext); - ~NativeState(); - - CodeAllocator codeAllocator; - std::unique_ptr unwindBuilder; - - uint8_t* gateData = nullptr; - size_t gateDataSize = 0; - - NativeContext context; -}; - -void initFunctions(NativeState& data); void initFunctions(NativeContext& context); } // namespace CodeGen diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 9135a9ed..4ff49570 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -16,9 +16,12 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) +LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAGVARIABLE(LuauCodegenFixSplitStoreConstMismatch, false) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) namespace Luau { @@ -200,6 +203,11 @@ struct ConstPropState checkBufferLenCache.clear(); } + void invalidateUserdataData() + { + useradataTagCache.clear(); + } + void invalidateHeap() { for (int i = 0; i <= maxReg; ++i) @@ -417,6 +425,9 @@ struct ConstPropState invalidateValuePropagation(); invalidateHeapTableData(); invalidateHeapBufferData(); + + if (FFlag::LuauCodegenUserdataOps) + invalidateUserdataData(); } IrFunction& function; @@ -446,6 +457,9 @@ struct ConstPropState std::vector checkArraySizeCache; // Additionally, fallback block argument might be different std::vector checkBufferLenCache; // Additionally, fallback block argument might be different + + // Userdata tag cache can point to both NEW_USERDATA and CHECK_USERDATA_TAG instructions + std::vector useradataTagCache; // Additionally, fallback block argument might be different }; static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) @@ -1061,6 +1075,37 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.checkBufferLenCache.push_back(index); break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + for (uint32_t prevIdx : state.useradataTagCache) + { + IrInst& prev = function.instructions[prevIdx]; + + if (prev.cmd == IrCmd::CHECK_USERDATA_TAG) + { + if (prev.a != inst.a || prev.b != inst.b) + continue; + } + else if (FFlag::LuauCodegenUserdataAlloc && prev.cmd == IrCmd::NEW_USERDATA) + { + if (inst.a.kind != IrOpKind::Inst || prevIdx != inst.a.index || prev.b != inst.b) + continue; + } + + if (FFlag::DebugLuauAbortingChecks) + replace(function, inst.c, build.undef()); + else + kill(function, inst); + + return; // Break out from both the loop and the switch + } + + if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit) + state.useradataTagCache.push_back(index); + break; + } case IrCmd::BUFFER_READI8: case IrCmd::BUFFER_READU8: case IrCmd::BUFFER_WRITEI8: @@ -1228,6 +1273,12 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::TRY_CALL_FASTGETTM: break; + case IrCmd::NEW_USERDATA: + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit) + state.useradataTagCache.push_back(index); + break; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: state.substituteOrRecord(inst, index); @@ -1512,6 +1563,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite state.invalidateHeapTableData(); state.invalidateHeapBufferData(); + if (FFlag::LuauCodegenUserdataOps) + state.invalidateUserdataData(); + // Blocks in a chain are guaranteed to follow each other // We force that by giving all blocks the same sorting key, but consecutive chain keys block->sortkey = startSortkey; diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index 6c1d6aff..d18b75c5 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -10,6 +10,7 @@ #include "lobject.h" LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores5, false) +LUAU_FASTFLAG(LuauCodegenUserdataOps) // TODO: optimization can be improved by knowing which registers are live in at each VM exit @@ -595,6 +596,11 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, case IrCmd::CHECK_BUFFER_LEN: state.checkLiveIns(inst.d); break; + case IrCmd::CHECK_USERDATA_TAG: + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + state.checkLiveIns(inst.c); + break; case IrCmd::JUMP: // Ideally, we would be able to remove stores to registers that are not live out from a block diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 19526fa9..4842b9a1 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -4219,7 +4219,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c for (AstExprFunction* expr : functions) compiler.compileFunction(expr, 0); - AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), + AstExprFunction main(root->location, /*attributes=*/AstArray({nullptr, 0}), /*generics= */ AstArray(), + /*genericPacks= */ AstArray(), /* self= */ nullptr, AstArray(), /* vararg= */ true, /* varargLocation= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main, mainFlags); diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index 693e0f87..5fba9fa3 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -195,7 +195,7 @@ static Error parseJson(const std::string& contents, Action action) } else if (lexer.current().type == Lexeme::QuotedString) { - std::string value(lexer.current().data, lexer.current().length); + std::string value(lexer.current().data, lexer.current().getLength()); next(lexer); if (Error err = action(keys, value)) @@ -232,7 +232,7 @@ static Error parseJson(const std::string& contents, Action action) } else if (lexer.current().type == Lexeme::QuotedString) { - std::string key(lexer.current().data, lexer.current().length); + std::string key(lexer.current().data, lexer.current().getLength()); next(lexer); keys.push_back(key); @@ -250,7 +250,7 @@ static Error parseJson(const std::string& contents, Action action) lexer.current().type == Lexeme::ReservedFalse) { std::string value = lexer.current().type == Lexeme::QuotedString - ? std::string(lexer.current().data, lexer.current().length) + ? std::string(lexer.current().data, lexer.current().getLength()) : (lexer.current().type == Lexeme::ReservedTrue ? "true" : "false"); next(lexer); diff --git a/VM/include/lua.h b/VM/include/lua.h index 4876b933..4ee9306e 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -324,6 +324,10 @@ typedef void (*lua_Destructor)(lua_State* L, void* userdata); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, lua_Destructor dtor); LUA_API lua_Destructor lua_getuserdatadtor(lua_State* L, int tag); +// alternative access for metatables already registered with luaL_newmetatable +LUA_API void lua_setuserdatametatable(lua_State* L, int tag, int idx); +LUA_API void lua_getuserdatametatable(lua_State* L, int tag); + LUA_API void lua_setlightuserdataname(lua_State* L, int tag, const char* name); LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 58c767f1..87f85af8 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1427,6 +1427,33 @@ lua_Destructor lua_getuserdatadtor(lua_State* L, int tag) return L->global->udatagc[tag]; } +void lua_setuserdatametatable(lua_State* L, int tag, int idx) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + api_check(L, !L->global->udatamt[tag]); // reassignment not supported + StkId o = index2addr(L, idx); + api_check(L, ttistable(o)); + L->global->udatamt[tag] = hvalue(o); + L->top--; +} + +void lua_getuserdatametatable(lua_State* L, int tag) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + luaC_threadbarrier(L); + + if (Table* h = L->global->udatamt[tag]) + { + sethvalue(L, L->top, h); + } + else + { + setnilvalue(L->top); + } + + api_incr_top(L); +} + void lua_setlightuserdataname(lua_State* L, int tag, const char* name) { api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index dbc1dd10..6b7a9aa0 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -210,7 +210,10 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) for (i = 0; i < LUA_T_COUNT; i++) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) + { g->udatagc[i] = NULL; + g->udatamt[i] = NULL; + } for (i = 0; i < LUA_LUTAG_LIMIT; i++) g->lightuserdataname[i] = NULL; for (i = 0; i < LUA_MEMORY_CATEGORIES; i++) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 35e66471..f8caa69b 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -217,6 +217,7 @@ typedef struct global_State lua_ExecutionCallbacks ecb; void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory + Table* udatamt[LUA_LUTAG_LIMIT]; // metatables for tagged userdata TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 7ced52cf..9c1fca9e 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -342,6 +342,209 @@ void setupVectorHelpers(lua_State* L) lua_pop(L, 1); } +Vec2* lua_vec2_push(lua_State* L) +{ + Vec2* data = (Vec2*)lua_newuserdatatagged(L, sizeof(Vec2), kTagVec2); + + lua_getuserdatametatable(L, kTagVec2); + lua_setmetatable(L, -2); + + return data; +} + +Vec2* lua_vec2_get(lua_State* L, int idx) +{ + Vec2* a = (Vec2*)lua_touserdatatagged(L, idx, kTagVec2); + + if (a) + return a; + + luaL_typeerror(L, idx, "vec2"); +} + +static int lua_vec2(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double y = luaL_checknumber(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = float(x); + data->y = float(y); + + return 1; +} + +static int lua_vec2_dot(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + lua_pushnumber(L, a->x * b->x + a->y * b->y); + return 1; +} + +static int lua_vec2_min(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = a->x < b->x ? a->x : b->x; + data->y = a->y < b->y ? a->y : b->y; + + return 1; +} + +static int lua_vec2_index(lua_State* L) +{ + Vec2* v = lua_vec2_get(L, 1); + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "X") == 0) + { + lua_pushnumber(L, v->x); + return 1; + } + + if (strcmp(name, "Y") == 0) + { + lua_pushnumber(L, v->y); + return 1; + } + + if (strcmp(name, "Magnitude") == 0) + { + lua_pushnumber(L, sqrtf(v->x * v->x + v->y * v->y)); + return 1; + } + + if (strcmp(name, "Unit") == 0) + { + float invSqrt = 1.0f / sqrtf(v->x * v->x + v->y * v->y); + + Vec2* data = lua_vec2_push(L); + + data->x = v->x * invSqrt; + data->y = v->y * invSqrt; + return 1; + } + + luaL_error(L, "%s is not a valid member of vector", name); +} + +static int lua_vec2_namecall(lua_State* L) +{ + if (const char* str = lua_namecallatom(L, nullptr)) + { + if (strcmp(str, "Dot") == 0) + return lua_vec2_dot(L); + + if (strcmp(str, "Min") == 0) + return lua_vec2_min(L); + } + + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); +} + +void setupUserdataHelpers(lua_State* L) +{ + // create metatable with all the metamethods + luaL_newmetatable(L, "vec2"); + luaL_getmetatable(L, "vec2"); + lua_pushvalue(L, -1); + lua_setuserdatametatable(L, kTagVec2, -1); + + lua_pushcfunction(L, lua_vec2_index, nullptr); + lua_setfield(L, -2, "__index"); + + lua_pushcfunction(L, lua_vec2_namecall, nullptr); + lua_setfield(L, -2, "__namecall"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x + b->x; + data->y = a->y + b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__add"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x - b->x; + data->y = a->y - b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__sub"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x * b->x; + data->y = a->y * b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__mul"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x / b->x; + data->y = a->y / b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__div"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* data = lua_vec2_push(L); + + data->x = -a->x; + data->y = -a->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__unm"); + + lua_setreadonly(L, -1, true); + + // ctor + lua_pushcfunction(L, lua_vec2, "vec2"); + lua_setglobal(L, "vec2"); + + lua_pop(L, 1); +} + static void setupNativeHelpers(lua_State* L) { lua_pushcclosurek( @@ -1828,16 +2031,36 @@ TEST_CASE("UserdataApi") luaL_newmetatable(L, "udata2"); void* ud5 = lua_newuserdata(L, 0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata1"); + luaL_getmetatable(L, "udata1"); lua_setmetatable(L, -2); void* ud6 = lua_newuserdata(L, 0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata2"); + luaL_getmetatable(L, "udata2"); lua_setmetatable(L, -2); CHECK(luaL_checkudata(L, -2, "udata1") == ud5); CHECK(luaL_checkudata(L, -1, "udata2") == ud6); + // tagged user data with fast metatable access + luaL_newmetatable(L, "udata3"); + luaL_getmetatable(L, "udata3"); + lua_setuserdatametatable(L, 50, -1); + + luaL_newmetatable(L, "udata4"); + luaL_getmetatable(L, "udata4"); + lua_setuserdatametatable(L, 51, -1); + + void* ud7 = lua_newuserdatatagged(L, 16, 50); + lua_getuserdatametatable(L, 50); + lua_setmetatable(L, -2); + + void* ud8 = lua_newuserdatatagged(L, 16, 51); + lua_getuserdatametatable(L, 51); + lua_setmetatable(L, -2); + + CHECK(luaL_checkudata(L, -2, "udata3") == ud7); + CHECK(luaL_checkudata(L, -1, "udata4") == ud8); + globalState.reset(); CHECK(dtorhits == 42); @@ -1911,7 +2134,6 @@ TEST_CASE("Iter") } const int kInt64Tag = 1; -static int gInt64MT = -1; static int64_t getInt64(lua_State* L, int idx) { @@ -1928,7 +2150,7 @@ static void pushInt64(lua_State* L, int64_t value) { void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag); - lua_getref(L, gInt64MT); + luaL_getmetatable(L, "int64"); lua_setmetatable(L, -2); *static_cast(p) = value; @@ -1938,8 +2160,7 @@ TEST_CASE("Userdata") { runConformance("userdata.lua", [](lua_State* L) { // create metatable with all the metamethods - lua_newtable(L); - gInt64MT = lua_ref(L, -1); + luaL_newmetatable(L, "int64"); // __index lua_pushcfunction( @@ -2164,6 +2385,86 @@ TEST_CASE("NativeTypeAnnotations") }); } +TEST_CASE("NativeUserdata") +{ + lua_CompileOptions copts = defaultOptions(); + Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + + SUBCASE("NoIrHooks") + { + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + SUBCASE("IrHooks") + { + nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + nativeOpts.hooks.vectorAccess = vectorAccess; + nativeOpts.hooks.vectorNamecall = vectorNamecall; + + nativeOpts.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + nativeOpts.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + nativeOpts.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + nativeOpts.hooks.userdataAccess = userdataAccess; + nativeOpts.hooks.userdataMetamethod = userdataMetamethod; + nativeOpts.hooks.userdataNamecall = userdataNamecall; + + nativeOpts.userdataTypes = kUserdataRunTypes; + + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + + runConformance( + "native_userdata.lua", + [](lua_State* L) { + Luau::CodeGen::setUserdataRemapper(L, kUserdataRunTypes, [](void* context, const char* str, size_t len) -> uint8_t { + const char** types = (const char**)context; + + uint8_t index = 0; + + std::string_view sv{str, len}; + + for (; *types; ++types) + { + if (sv == *types) + return index; + + index++; + } + + return 0xff; + }); + + setupVectorHelpers(L); + setupUserdataHelpers(L); + }, + nullptr, nullptr, &copts, false, &nativeOpts); +} + [[nodiscard]] static std::string makeHugeFunctionSource() { std::string source; diff --git a/tests/ConformanceIrHooks.h b/tests/ConformanceIrHooks.h index d4050863..ab5b86d4 100644 --- a/tests/ConformanceIrHooks.h +++ b/tests/ConformanceIrHooks.h @@ -5,14 +5,44 @@ static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr}; +constexpr uint8_t kUserdataExtra = 0; +constexpr uint8_t kUserdataColor = 1; +constexpr uint8_t kUserdataVec2 = 2; +constexpr uint8_t kUserdataMat3 = 3; + +// Userdata tags can be different from userdata bytecode type indices +constexpr uint8_t kTagVec2 = 12; + +struct Vec2 +{ + float x; + float y; +}; + +inline bool compareMemberName(const char* member, size_t memberLength, const char* str) +{ + return memberLength == strlen(str) && strcmp(member, str) == 0; +} + +inline uint8_t typeToUserdataIndex(uint8_t type) +{ + // Underflow will push the type into a value that is not comparable to any kUserdata* constants + return type - LBC_TYPE_TAGGED_USERDATA_BASE; +} + +inline uint8_t userdataIndexToType(uint8_t userdataIndex) +{ + return LBC_TYPE_TAGGED_USERDATA_BASE + userdataIndex; +} + inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength) { using namespace Luau::CodeGen; - if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) + if (compareMemberName(member, memberLength, "Magnitude")) return LBC_TYPE_NUMBER; - if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) + if (compareMemberName(member, memberLength, "Unit")) return LBC_TYPE_VECTOR; return LBC_TYPE_ANY; @@ -22,7 +52,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si { using namespace Luau::CodeGen; - if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) + if (compareMemberName(member, memberLength, "Magnitude")) { IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); @@ -42,7 +72,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si return true; } - if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) + if (compareMemberName(member, memberLength, "Unit")) { IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); @@ -72,10 +102,10 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si inline uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength) { - if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0) + if (compareMemberName(member, memberLength, "Dot")) return LBC_TYPE_NUMBER; - if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0) + if (compareMemberName(member, memberLength, "Cross")) return LBC_TYPE_VECTOR; return LBC_TYPE_ANY; @@ -86,7 +116,7 @@ inline bool vectorNamecall( { using namespace Luau::CodeGen; - if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0 && params == 2 && results <= 1) + if (compareMemberName(member, memberLength, "Dot") && params == 2 && results <= 1) { build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); @@ -114,7 +144,7 @@ inline bool vectorNamecall( return true; } - if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0 && params == 2 && results <= 1) + if (compareMemberName(member, memberLength, "Cross") && params == 2 && results <= 1) { build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); @@ -151,3 +181,362 @@ inline bool vectorNamecall( return false; } + +inline uint8_t userdataAccessBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + if (compareMemberName(member, memberLength, "R")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "G")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "B")) + return LBC_TYPE_NUMBER; + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Y")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Magnitude")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Unit")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + if (compareMemberName(member, memberLength, "Row1")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row2")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row3")) + return LBC_TYPE_VECTOR; + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataAccess( + Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Y")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Magnitude")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Unit")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), xr, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), yr, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} + +inline uint8_t userdataMetamethodBytecodeType(uint8_t lhsTy, uint8_t rhsTy, Luau::CodeGen::HostMetamethod method) +{ + switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + case Luau::CodeGen::HostMetamethod::Sub: + case Luau::CodeGen::HostMetamethod::Mul: + case Luau::CodeGen::HostMetamethod::Div: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 || typeToUserdataIndex(rhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + default: + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataMetamethod(Luau::CodeGen::IrBuilder& build, uint8_t lhsTy, uint8_t rhsTy, int resultReg, Luau::CodeGen::IrOp lhs, + Luau::CodeGen::IrOp rhs, Luau::CodeGen::HostMetamethod method, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::ADD_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::ADD_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Mul: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MUL_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::UNM_NUM, x); + IrOp my = build.inst(IrCmd::UNM_NUM, y); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + default: + break; + } + + return false; +} + +inline uint8_t userdataNamecallBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Min")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataNamecall(Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, + int params, int results, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp sum = build.inst(IrCmd::ADD_NUM, xx, yy); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + if (compareMemberName(member, memberLength, "Min")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MIN_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MIN_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(argResReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TUSERDATA)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 5d7fedd8..ecdb522c 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -24,6 +24,8 @@ LUAU_FASTFLAG(LuauCompileTempTypeInfo) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauCompileUserdataInfo) LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) { @@ -34,6 +36,13 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = options.compilationOptions.hooks.vectorAccess = vectorAccess; options.compilationOptions.hooks.vectorNamecall = vectorNamecall; + options.compilationOptions.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + options.compilationOptions.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + options.compilationOptions.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + options.compilationOptions.hooks.userdataAccess = userdataAccess; + options.compilationOptions.hooks.userdataMetamethod = userdataMetamethod; + options.compilationOptions.hooks.userdataNamecall = userdataNamecall; + // For IR, we don't care about assembly, but we want a stable target options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV; @@ -1690,4 +1699,352 @@ end )"); } +TEST_CASE("CustomUserdataPropertyAccess") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(v: vec2) + return v.X + v.Y +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %6, 12i, exit(0) + %8 = BUFFER_READF32 %6, 0i, tuserdata + %15 = BUFFER_READF32 %6, 4i, tuserdata + %24 = ADD_NUM %8, %15 + STORE_DOUBLE R1, %24 + STORE_TAG R1, tnumber + INTERRUPT 5u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataPropertyAccess2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3) + return a.Row1 * a.Row2 +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_GETTABLEKS 0u, R2, R0, K0 + FALLBACK_GETTABLEKS 2u, R3, R0, K1 + CHECK_TAG R2, tvector, exit(4) + CHECK_TAG R3, tvector, exit(4) + %14 = LOAD_TVALUE R2 + %15 = LOAD_TVALUE R3 + %16 = MUL_VEC %14, %15 + %17 = TAG_VECTOR %16 + STORE_TVALUE R1, %17 + INTERRUPT 5u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataNamecall1") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, + {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2) + return a:Dot(b) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MUL_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MUL_NUM %19, %20 + %22 = ADD_NUM %18, %21 + STORE_DOUBLE R2, %22 + STORE_TAG R2, tnumber + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)"); +} + +TEST_CASE("CustomUserdataNamecall2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, + {FFlag::LuauCodegenUserdataOps, true}, {FFlag::LuauCodegenUserdataAlloc, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2) + return a:Min(b) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MIN_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MIN_NUM %19, %20 + CHECK_GC + %23 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %23, 0i, %18, tuserdata + BUFFER_WRITEF32 %23, 4i, %21, tuserdata + STORE_POINTER R2, %23 + STORE_TAG R2, tuserdata + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3, b: mat3) + return a * b +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: mat3 [argument] +; R1: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_ARITH R2, R0, R1, 10i + INTERRUPT 1u + RETURN R2, 1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3) + return -a +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_ARITH R1, R0, R0, 15i + INTERRUPT 1u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow3") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: sequence) + return #a +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: userdata [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_LEN R1, R0 + INTERRUPT 1u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataMetamethod") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}, + {FFlag::LuauCodegenUserdataAlloc, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2, c: vec2) + return -c + a * b +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1, $arg2) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +; R2: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + CHECK_TAG R2, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %10 = LOAD_POINTER R2 + CHECK_USERDATA_TAG %10, 12i, exit(0) + %12 = BUFFER_READF32 %10, 0i, tuserdata + %13 = BUFFER_READF32 %10, 4i, tuserdata + %14 = UNM_NUM %12 + %15 = UNM_NUM %13 + CHECK_GC + %17 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %17, 0i, %14, tuserdata + BUFFER_WRITEF32 %17, 4i, %15, tuserdata + STORE_POINTER R4, %17 + STORE_TAG R4, tuserdata + %26 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %26, 12i, exit(1) + %28 = LOAD_POINTER R1 + CHECK_USERDATA_TAG %28, 12i, exit(1) + %30 = BUFFER_READF32 %26, 0i, tuserdata + %31 = BUFFER_READF32 %28, 0i, tuserdata + %32 = MUL_NUM %30, %31 + %33 = BUFFER_READF32 %26, 4i, tuserdata + %34 = BUFFER_READF32 %28, 4i, tuserdata + %35 = MUL_NUM %33, %34 + %37 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %37, 0i, %32, tuserdata + BUFFER_WRITEF32 %37, 4i, %35, tuserdata + STORE_POINTER R5, %37 + STORE_TAG R5, tuserdata + %50 = BUFFER_READF32 %17, 0i, tuserdata + %51 = BUFFER_READF32 %37, 0i, tuserdata + %52 = ADD_NUM %50, %51 + %53 = BUFFER_READF32 %17, 4i, tuserdata + %54 = BUFFER_READF32 %37, 4i, tuserdata + %55 = ADD_NUM %53, %54 + %57 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %57, 0i, %52, tuserdata + BUFFER_WRITEF32 %57, 4i, %55, tuserdata + STORE_POINTER R3, %57 + STORE_TAG R3, tuserdata + INTERRUPT 3u + RETURN R3, 1i +)"); +} + TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 78d1389a..e0716e4c 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -192,13 +192,13 @@ TEST_CASE("string_interpolation_double_brace") auto brokenInterpBegin = lexer.next(); CHECK_EQ(brokenInterpBegin.type, Lexeme::BrokenInterpDoubleBrace); - CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.length), std::string("foo")); + CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.getLength()), std::string("foo")); CHECK_EQ(lexer.next().type, Lexeme::Name); auto interpEnd = lexer.next(); CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); - CHECK_EQ(std::string(interpEnd.data, interpEnd.length), std::string("}bar")); + CHECK_EQ(std::string(interpEnd.data, interpEnd.getLength()), std::string("}bar")); } TEST_CASE("string_interpolation_double_but_unmatched_brace") diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index e51fb0df..81a84722 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -15,6 +15,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauAttributeSyntax); + #define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \ do \ { \ @@ -68,6 +70,7 @@ struct NonStrictTypeCheckerFixture : Fixture { ScopedFastFlag flags[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; LoadDefinitionFileResult res = loadDefinition(definitions); LUAU_ASSERT(res.success); @@ -78,6 +81,7 @@ struct NonStrictTypeCheckerFixture : Fixture { ScopedFastFlag flags[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; LoadDefinitionFileResult res = loadDefinition(definitions); LUAU_ASSERT(res.success); @@ -85,21 +89,21 @@ struct NonStrictTypeCheckerFixture : Fixture } std::string definitions = R"BUILTIN_SRC( -declare function @checked abs(n: number): number -declare function @checked lower(s: string): string +@checked declare function abs(n: number): number +@checked declare function lower(s: string): string declare function cond() : boolean -declare function @checked contrived(n : Not) : number +@checked declare function contrived(n : Not) : number -- interesting types of things that we would like to mark as checked -declare function @checked onlyNums(...: number) : number -declare function @checked mixedArgs(x: string, ...: number) : number -declare function @checked optionalArg(x: string?) : number +@checked declare function onlyNums(...: number) : number +@checked declare function mixedArgs(x: string, ...: number) : number +@checked declare function optionalArg(x: string?) : number declare foo: { bar: @checked (number) -> number, } -declare function @checked optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number -declare function @checked optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number +@checked declare function optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number +@checked declare function optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number type DateTypeArg = { year: number, @@ -115,7 +119,7 @@ declare os : { time: @checked (time: DateTypeArg?) -> number } -declare function @checked require(target : any) : any +@checked declare function require(target : any) : any )BUILTIN_SRC"; }; @@ -558,6 +562,10 @@ local E = require(script.Parent.A) TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_shouldnt_warn_on_valid_buffer_use") { + ScopedFastFlag flags[] = { + {FFlag::LuauAttributeSyntax, true}, + }; + loadDefinition(R"( declare buffer: { create: @checked (size: number) -> buffer, diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 6b4bcf22..8b2cc6ba 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeLengthLimit); LUAU_FASTINT(LuauParseErrorLimit); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauAttributeSyntax); LUAU_FASTFLAG(LuauLeadingBarAndAmpersand); namespace @@ -3051,9 +3052,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_top_level_checked_fn") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; std::string src = R"BUILTIN_SRC( -declare function @checked abs(n: number): number +@checked declare function abs(n: number): number )BUILTIN_SRC"; ParseResult pr = tryParse(src, opts); @@ -3063,13 +3065,14 @@ declare function @checked abs(n: number): number AstStat* root = *(pr.root->body.data); auto func = root->as(); LUAU_ASSERT(func); - LUAU_ASSERT(func->checkedFunction); + LUAU_ASSERT(func->isCheckedFunction()); } TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; const std::string src = R"BUILTIN_SRC( declare math : { @@ -3090,13 +3093,14 @@ TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member") auto prop = *tbl->props.data; auto func = prop.type->as(); LUAU_ASSERT(func); - LUAU_ASSERT(func->checkedFunction); + LUAU_ASSERT(func->isCheckedFunction()); } TEST_CASE_FIXTURE(Fixture, "parse_checked_outside_decl_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; ParseResult pr = tryParse(R"( local @checked = 3 @@ -3110,10 +3114,11 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_in_and_out_of_decl_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; auto pr = tryParse(R"( local @checked = 3 - declare function @checked abs(n: number): number + @checked declare function abs(n: number): number )", opts); LUAU_ASSERT(pr.errors.size() == 2); @@ -3125,9 +3130,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_as_function_name_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; auto pr = tryParse(R"( - function @checked(x: number) : number + @checked function(x: number) : number end )", opts); @@ -3138,6 +3144,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; auto pr = tryParse(R"( local @blah = 3 @@ -3190,4 +3197,300 @@ TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } +void checkAttribute(const AstAttr* attr, const AstAttr::Type type, const Location& location) +{ + CHECK_EQ(attr->type, type); + CHECK_EQ(attr->location, location); +} + +void checkFirstErrorForAttributes(const std::vector& errors, const size_t minSize, const Location& location, const std::string& message) +{ + LUAU_ASSERT(minSize >= 1); + + CHECK_GE(errors.size(), minSize); + CHECK_EQ(errors[0].getLocation(), location); + CHECK_EQ(errors[0].getMessage(), message); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_stat") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + AstStatBlock* stat = parse(R"( +@checked +function hello(x, y) + return x + y +end)"); + + LUAU_ASSERT(stat != nullptr); + + AstStatFunction* statFun = stat->body.data[0]->as(); + LUAU_ASSERT(statFun != nullptr); + + AstArray attributes = statFun->func->attributes; + + CHECK_EQ(attributes.size, 1); + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_local_function_stat") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + AstStatBlock* stat = parse(R"( + @checked +local function hello(x, y) + return x + y +end)"); + + LUAU_ASSERT(stat != nullptr); + + AstStatLocalFunction* statFun = stat->body.data[0]->as(); + LUAU_ASSERT(statFun != nullptr); + + AstArray attributes = statFun->func->attributes; + + CHECK_EQ(attributes.size, 1); + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 4), Position(1, 12))); +} + +TEST_CASE_FIXTURE(Fixture, "empty_attribute_name_is_not_allowed") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult result = tryParse(R"( +@ +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(1, 0), Position(1, 1)), "Attribute name is missing"); +} + +TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_stat") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult pr1 = tryParse(R"( +@checked +if a<0 then a = 0 end)"); + checkFirstErrorForAttributes(pr1.errors, 1, Location(Position(2, 0), Position(2, 2)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'if' intead"); + + ParseResult pr2 = tryParse(R"( +local i = 1 +@checked +while a[i] do + print(a[i]) + i = i + 1 +end)"); + checkFirstErrorForAttributes(pr2.errors, 1, Location(Position(3, 0), Position(3, 5)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'while' intead"); + + ParseResult pr3 = tryParse(R"( +@checked +do + local a2 = 2*a + local d = sqrt(b^2 - 4*a*c) + x1 = (-b + d)/a2 + x2 = (-b - d)/a2 +end)"); + checkFirstErrorForAttributes(pr3.errors, 1, Location(Position(2, 0), Position(2, 2)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'do' intead"); + + ParseResult pr4 = tryParse(R"( +@checked +for i=1,10 do print(i) end +)"); + checkFirstErrorForAttributes(pr4.errors, 1, Location(Position(2, 0), Position(2, 3)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'for' intead"); + + ParseResult pr5 = tryParse(R"( +@checked +repeat + line = io.read() +until line ~= "" +)"); + checkFirstErrorForAttributes(pr5.errors, 1, Location(Position(2, 0), Position(2, 6)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'repeat' intead"); + + + ParseResult pr6 = tryParse(R"( +@checked +local x = 10 +)"); + checkFirstErrorForAttributes( + pr6.errors, 1, Location(Position(2, 6), Position(2, 7)), "Expected 'function' after local declaration with attribute, but got 'x' intead"); + + ParseResult pr7 = tryParse(R"( +local i = 1 +while a[i] do + if a[i] == v then @checked break end + i = i + 1 +end +)"); + checkFirstErrorForAttributes(pr7.errors, 1, Location(Position(3, 31), Position(3, 36)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'break' intead"); + + + ParseResult pr8 = tryParse(R"( +function foo1 () @checked return 'a' end +)"); + checkFirstErrorForAttributes(pr8.errors, 1, Location(Position(1, 26), Position(1, 32)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'return' intead"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_type_declaration") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + std::string src = R"( +@checked declare function abs(n: number): number +)"; + + ParseResult pr = tryParse(src, opts); + CHECK_EQ(pr.errors.size(), 0); + + LUAU_ASSERT(pr.root->body.size == 1); + + AstStat* root = *(pr.root->body.data); + + auto func = root->as(); + LUAU_ASSERT(func != nullptr); + + CHECK(func->isCheckedFunction()); + + AstArray attributes = func->attributes; + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attributes_on_function_type_declaration_in_table") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + std::string src = R"( +declare bit32: { + band: @checked (...number) -> number +})"; + + ParseResult pr = tryParse(src, opts); + CHECK_EQ(pr.errors.size(), 0); + + LUAU_ASSERT(pr.root->body.size == 1); + + AstStat* root = *(pr.root->body.data); + + AstStatDeclareGlobal* glob = root->as(); + LUAU_ASSERT(glob); + + auto tbl = glob->type->as(); + LUAU_ASSERT(tbl); + + LUAU_ASSERT(tbl->props.size == 1); + AstTableProp prop = tbl->props.data[0]; + + AstTypeFunction* func = prop.type->as(); + LUAU_ASSERT(func); + + AstArray attributes = func->attributes; + + CHECK_EQ(attributes.size, 1); + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(2, 10), Position(2, 18))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_type_declarations") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + ParseResult pr1 = tryParse(R"( +@checked declare foo: number + )", + opts); + + checkFirstErrorForAttributes( + pr1.errors, 1, Location(Position(1, 17), Position(1, 20)), "Expected a function type declaration after attribute, but got 'foo' intead"); + + ParseResult pr2 = tryParse(R"( +@checked declare class Foo + prop: number + function method(self, foo: number): string +end)", + opts); + + checkFirstErrorForAttributes( + pr2.errors, 1, Location(Position(1, 17), Position(1, 22)), "Expected a function type declaration after attribute, but got 'class' intead"); + + ParseResult pr3 = tryParse(R"( +declare bit32: { + band: @checked number +})", + opts); + + checkFirstErrorForAttributes( + pr3.errors, 1, Location(Position(2, 19), Position(2, 25)), "Expected '(' when parsing function parameters, got 'number'"); +} + +TEST_CASE_FIXTURE(Fixture, "attributes_cannot_be_duplicated") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult result = tryParse(R"( +@checked + @checked +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 12)), "Cannot duplicate attribute '@checked'"); +} + +TEST_CASE_FIXTURE(Fixture, "unsupported_attributes_are_not_allowed") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult result = tryParse(R"( +@checked + @cool_attribute +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 19)), "Invalid attribute '@cool_attribute'"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + parse(R"(type A = | "Hello" | "World")"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + parse(R"(type A = & { string } & { number })"); +} + +TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); + matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); +} + + TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index e2b3f9b7..d7cb225a 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -13,6 +13,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauSharedSelf); +LUAU_FASTFLAG(LuauAttributeSyntax); TEST_SUITE_BEGIN("ToString"); @@ -1010,10 +1011,11 @@ TEST_CASE_FIXTURE(Fixture, "checked_fn_toString") { ScopedFastFlag flags[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; auto _result = loadDefinition(R"( -declare function @checked abs(n: number) : number +@checked declare function abs(n: number) : number )"); auto result = check(Mode::Nonstrict, R"( diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 1a7ef973..ce6988aa 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -701,7 +701,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") REQUIRE(ut); REQUIRE(ut->options.size() == 2); - CHECK_EQ(builtinTypes->nilType, ut->options[0]); + CHECK_EQ(builtinTypes->nilType, follow(ut->options[0])); CHECK_EQ(*builtinTypes->numberType, *ut->options[1]); } else @@ -1179,4 +1179,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_preserves_error_suppression") CHECK("any" == toString(requireTypeAtPosition({3, 25}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "tryDispatchIterableFunction_under_constrained_loop_should_not_assert") +{ + CheckResult result = check(R"( +local function foo(Instance) + for _, Child in next, Instance:GetChildren() do + end +end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2c6136a4..516a761b 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3153,7 +3153,7 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Value of type '{ x: number? }?' could be nil", toString(result.errors[0])); + CHECK_EQ("Type 'nil' does not have key 'x'", toString(result.errors[0])); else CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); CHECK_EQ("boolean", toString(requireType("u"))); @@ -4439,7 +4439,13 @@ TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + } + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer") diff --git a/tests/conformance/native_userdata.lua b/tests/conformance/native_userdata.lua new file mode 100644 index 00000000..b1b2a103 --- /dev/null +++ b/tests/conformance/native_userdata.lua @@ -0,0 +1,42 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing userdata') + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + +local function realmad(a: vec2, b: vec2, c: vec2): vec2 + return -c + a * b; +end + +local function dm(s: vec2, t: vec2, u: vec2) + local x = s:Dot(t) + assert(x == 13) + + local t = u:Min(s) + assert(t.X == 5) + assert(t.Y == 4) +end + +local s: vec2 = vec2(5, 4) +local t: vec2 = vec2(1, 2) +local u: vec2 = vec2(10, 20) + +local x: vec2 = realmad(s, t, u) + +assert(x.X == -5) +assert(x.Y == -12) + +dm(s, t, u) + +local function mu(v: vec2) + assert(v.Magnitude == 2) + assert(v.Unit.X == 0) + assert(v.Unit.Y == 1) +end + +mu(vec2(0, 2)) + +return 'OK' diff --git a/tools/faillist.txt b/tools/faillist.txt index 7a214a32..b2677bf4 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -4,7 +4,6 @@ AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg AutocompleteTest.autocomplete_string_singletons AutocompleteTest.do_wrong_compatible_nonself_calls AutocompleteTest.string_singleton_as_table_key -AutocompleteTest.string_singleton_in_if_statement2 AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_in_table @@ -33,6 +32,15 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument2 BuiltinTests.table_freeze_is_generic BuiltinTests.tonumber_returns_optional_number_type +ControlFlowAnalysis.for_record_do_if_not_x_break +ControlFlowAnalysis.for_record_do_if_not_x_continue +ControlFlowAnalysis.if_not_x_break_elif_not_y_break +ControlFlowAnalysis.if_not_x_break_elif_not_y_continue +ControlFlowAnalysis.if_not_x_break_elif_rand_break_elif_not_y_break +ControlFlowAnalysis.if_not_x_continue_elif_not_y_continue +ControlFlowAnalysis.if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough +ControlFlowAnalysis.if_not_x_continue_elif_rand_continue_elif_not_y_continue +ControlFlowAnalysis.if_not_x_return_elif_not_y_break DefinitionTests.class_definition_overload_metamethods Differ.metatable_metamissing_left Differ.metatable_metamissing_right @@ -46,7 +54,6 @@ FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.better_mismatch_error_messages GenericsTests.bound_tables_do_not_clone_original_fields -GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_always_instantiate_generic_intersection_types GenericsTests.do_not_infer_generic_functions GenericsTests.dont_substitute_bound_types @@ -135,6 +142,7 @@ RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x RefinementTest.globals_can_be_narrowed_too RefinementTest.isa_type_refinement_must_be_known_ahead_of_time +RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage RefinementTest.refine_a_property_of_some_global @@ -278,7 +286,9 @@ TypeInferAnyError.can_subscript_any TypeInferAnyError.for_in_loop_iterator_is_any TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferAnyError.for_in_loop_iterator_is_any_pack +TypeInferAnyError.for_in_loop_iterator_returns_any TypeInferAnyError.for_in_loop_iterator_returns_any2 +TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferClasses.callable_classes TypeInferClasses.cannot_unify_class_instance_with_primitive TypeInferClasses.class_type_mismatch_with_name_conflict @@ -337,6 +347,7 @@ TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function +TypeInferFunctions.unifier_should_not_bind_free_types TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.dcr_xpath_candidates @@ -363,7 +374,6 @@ TypeInferModules.require TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory -TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.promise_type_error_too_complex TypeInferOperators.add_type_family_works TypeInferOperators.cli_38355_recursive_union