From af15c3cf17ed211f7f4a257cbc5bccd298c2e9e0 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 26 Apr 2024 11:14:13 -0700 Subject: [PATCH 01/20] CodeGen: Fix a typo in X64 (dis)assembler (#1238) --- CodeGen/src/AssemblyBuilderX64.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index bed7e0e3..f999d753 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -826,7 +826,7 @@ void AssemblyBuilderX64::vcvtss2sd(OperandX64 dst, OperandX64 src1, OperandX64 s else CODEGEN_ASSERT(src2.memSize == SizeX64::dword); - placeAvx("vcvtsd2ss", dst, src1, src2, 0x5a, false, AVX_0F, AVX_F3); + placeAvx("vcvtss2sd", dst, src1, src2, 0x5a, false, AVX_0F, AVX_F3); } void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode) From f5303b3dd722eca5dd7a7021b23b91cb0e64fc46 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Mon, 29 Apr 2024 05:19:01 -0700 Subject: [PATCH 02/20] Make table.concat faster (#1243) table.concat is idiomatic and should be the fastest way to concatenate all table array elements together, but apparently you can beat it by using `string.format`, `string.rep` and `table.unpack`: ```lua string.format(string.rep("%*", #t), table.unpack(t)) ``` ... this just won't do, so we should fix table.concat performance. The deficit comes from two places: - rawgeti overhead followed by other stack accesses, all to extract a string from what is almost always an in-bounds array lookup - addlstring overhead in case separator is empty (extra function calls) This change fixes this by using a fast path for in-bounds array lookup for a string. Note that `table.concat` also supports numbers (these need to be converted to strings which is a little cumbersome and has innate overhead), and out-of-bounds accesses*. In these cases we fall back to the old implementation. 
To trigger out-of-bounds accesses, you need to skip the past-array-end element (which is nil per array invariant), but this is achievable because table.concat supports offset+length arguments. This should almost never come up in practice but the per-element branches et al are fairly cheap compared to the eventual string copy/alloc anyway. This change makes table.concat ~2x faster when the separator is empty; the table.concat benchmark shows +40% gains but it uses a variety of string separators of different lengths so it doesn't get the full benefit from this change. --------- Co-authored-by: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com> --- VM/src/ltablib.cpp | 36 ++++++++++++++++++++++++------------ tests/conformance/tables.lua | 28 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 27c08f11..a57d6cf7 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -12,6 +12,7 @@ #include "lvm.h" LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastCrossTableMove, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFasterConcat, false) static int foreachi(lua_State* L) { @@ -282,31 +283,42 @@ static int tmove(lua_State* L) return 1; } -static void addfield(lua_State* L, luaL_Strbuf* b, int i) +static void addfield(lua_State* L, luaL_Strbuf* b, int i, Table* t) { - int tt = lua_rawgeti(L, 1, i); - if (tt != LUA_TSTRING && tt != LUA_TNUMBER) - luaL_error(L, "invalid value (%s) at index %d in table for 'concat'", luaL_typename(L, -1), i); - luaL_addvalue(b); + if (DFFlag::LuauFasterConcat && t && unsigned(i - 1) < unsigned(t->sizearray) && ttisstring(&t->array[i - 1])) + { + TString* ts = tsvalue(&t->array[i - 1]); + luaL_addlstring(b, getstr(ts), ts->len); + } + else + { + int tt = lua_rawgeti(L, 1, i); + if (tt != LUA_TSTRING && tt != LUA_TNUMBER) + luaL_error(L, "invalid value (%s) at index %d in table for 'concat'", luaL_typename(L, -1), i); + luaL_addvalue(b); + } } static int 
tconcat(lua_State* L) { - luaL_Strbuf b; size_t lsep; - int i, last; const char* sep = luaL_optlstring(L, 2, "", &lsep); luaL_checktype(L, 1, LUA_TTABLE); - i = luaL_optinteger(L, 3, 1); - last = luaL_opt(L, luaL_checkinteger, 4, lua_objlen(L, 1)); + int i = luaL_optinteger(L, 3, 1); + int last = luaL_opt(L, luaL_checkinteger, 4, lua_objlen(L, 1)); + + Table* t = DFFlag::LuauFasterConcat ? hvalue(L->base) : NULL; + + luaL_Strbuf b; luaL_buffinit(L, &b); for (; i < last; i++) { - addfield(L, &b, i); - luaL_addlstring(&b, sep, lsep); + addfield(L, &b, i, t); + if (!DFFlag::LuauFasterConcat || lsep != 0) + luaL_addlstring(&b, sep, lsep); } if (i == last) // add last value (if interval was not empty) - addfield(L, &b, i); + addfield(L, &b, i, t); luaL_pushresult(&b); return 1; } diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 03b46396..75163fd1 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -412,6 +412,34 @@ do assert(table.find({[(1)] = true}, true) == 1) end +-- test table.concat +do + -- regular usage + assert(table.concat({}) == "") + assert(table.concat({}, ",") == "") + assert(table.concat({"a", "b", "c"}, ",") == "a,b,c") + assert(table.concat({"a", "b", "c"}, ",", 2) == "b,c") + assert(table.concat({"a", "b", "c"}, ",", 1, 2) == "a,b") + + -- hash elements + local t = {} + t[123] = "a" + t[124] = "b" + + assert(table.concat(t) == "") + assert(table.concat(t, ",", 123, 124) == "a,b") + assert(table.concat(t, ",", 123, 123) == "a") + + -- numeric values + assert(table.concat({1, 2, 3}, ",") == "1,2,3") + assert(table.concat({"a", 2, "c"}, ",") == "a,2,c") + + -- error cases + assert(pcall(table.concat, "") == false) + assert(pcall(table.concat, t, false) == false) + assert(pcall(table.concat, t, ",", 1, 100) == false) +end + -- test indexing with strings that have zeroes embedded in them do local t = {} From 7edd58afede8950bde2042f9cfeed210242105ca Mon Sep 17 00:00:00 2001 From: vegorov-rbx 
<75688451+vegorov-rbx@users.noreply.github.com> Date: Thu, 2 May 2024 08:33:47 -0700 Subject: [PATCH 03/20] Add benchmarks for native compilation with type info enabled (#1244) --- .github/workflows/benchmark.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 7a11fbe1..8e1bf983 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -77,10 +77,12 @@ jobs: valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-t1-codegen | tee -a compile-output.txt valgrind --tool=callgrind ./luau-compile --null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt + valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/regex.lua 2>&1 | filter regex-O2-t1-codegen | tee -a compile-output.txt - name: Checkout benchmark results uses: actions/checkout@v3 From 8a64cb8b73996bd69c2734c607acd4b7d092358a Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 3 May 2024 13:17:51 -0700 Subject: [PATCH 04/20] Sync to upstream/release/624 (#1245) # What's changed? * Optimize table.maxn. 
This function is now 5-14x faster * Reserve Luau stack space for error message. ## New Solver * Globals can be type-stated, but only if they are already in scope * Fix a stack overflow that could occur when normalizing certain kinds of recursive unions of intersections (of unions of intersections...) * Fix an assertion failure that would trigger when the __iter metamethod has a bad signature ## Native Codegen * Type propagation and temporary register type hints * Direct vector property access should only happen for names of right length * BytecodeAnalysis will only predict that some of the vector value fields are numbers --- ## Internal Contributors Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Aviral Goel Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Normalize.h | 1 + Analysis/src/ConstraintGenerator.cpp | 10 +- Analysis/src/ConstraintSolver.cpp | 14 +- Analysis/src/Normalize.cpp | 121 ++++-- Analysis/src/TypeChecker2.cpp | 4 +- CLI/Compile.cpp | 1 + CMakeLists.txt | 1 + CodeGen/include/Luau/CodeGen.h | 4 + CodeGen/src/BytecodeAnalysis.cpp | 32 +- CodeGen/src/CodeGen.cpp | 32 ++ CodeGen/src/CodeGenAssembly.cpp | 110 +++++- CodeGen/src/IrTranslation.cpp | 7 +- Compiler/src/Builtins.cpp | 2 +- Compiler/src/Compiler.cpp | 59 ++- Compiler/src/Types.cpp | 551 ++++++++++++++++++++++++++- Compiler/src/Types.h | 20 +- VM/src/ldebug.cpp | 8 + VM/src/lmem.cpp | 14 +- VM/src/lstate.cpp | 1 + VM/src/lstate.h | 3 +- VM/src/ltablib.cpp | 43 ++- tests/IrLowering.test.cpp | 484 ++++++++++++++++++++++- tests/Normalize.test.cpp | 30 +- tests/TypeInfer.classes.test.cpp | 33 ++ tests/TypeInfer.tables.test.cpp | 16 + tests/TypeInfer.test.cpp | 39 +- tests/TypeInfer.typestates.test.cpp | 29 ++ tests/conformance/tables.lua | 4 + tools/faillist.txt | 4 +- tools/stackdbg.py | 94 +++++ 30 files changed, 1656 insertions(+), 115 deletions(-) create mode 100644 tools/stackdbg.py diff --git a/Analysis/include/Luau/Normalize.h 
b/Analysis/include/Luau/Normalize.h index 35e0c7a1..6d75568e 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -395,6 +395,7 @@ public: TypeId negate(TypeId there); void subtractPrimitive(NormalizedType& here, TypeId ty); void subtractSingleton(NormalizedType& here, TypeId ty); + NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect, bool useDeprecated = false); // ------- Normalizing intersections TypeId intersectionOfTops(TypeId here, TypeId there); diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 16e2014a..c559a256 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -1900,10 +1900,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* globa return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; } else - { - reportError(global->location, UnknownSymbol{global->name.value, UnknownSymbol::Binding}); return Inference{builtinTypes->errorRecoveryType()}; - } } Inference ConstraintGenerator::checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) @@ -2453,7 +2450,12 @@ ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePt { std::optional annotatedTy = scope->lookup(Symbol{global->name}); if (annotatedTy) - return {annotatedTy, arena->addType(BlockedType{})}; + { + DefId def = dfg->getDef(global); + TypeId assignedTy = arena->addType(BlockedType{}); + rootScope->lvalueTypes[def] = assignedTy; + return {annotatedTy, assignedTy}; + } else return {annotatedTy, std::nullopt}; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 6a9dd031..ff56a37d 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1619,9 +1619,7 @@ std::pair> ConstraintSolver::tryDispatchSetIndexer( { if 
(tt->indexer) { - if (isBlocked(tt->indexer->indexType)) - return {block(tt->indexer->indexType, constraint), std::nullopt}; - else if (isBlocked(tt->indexer->indexResultType)) + if (isBlocked(tt->indexer->indexResultType)) return {block(tt->indexer->indexResultType, constraint), std::nullopt}; unify(constraint, indexType, tt->indexer->indexType); @@ -2014,10 +2012,14 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (std::optional instantiatedNextFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, nextFn)) { const FunctionType* nextFn = get(*instantiatedNextFn); - LUAU_ASSERT(nextFn); - const TypePackId nextRetPack = nextFn->retTypes; - pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, nextRetPack, /* resultIsLValue=*/true}); + // If nextFn is nullptr, then the iterator function has an improper signature. + if (nextFn) + { + const TypePackId nextRetPack = nextFn->retTypes; + pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, nextRetPack, /* resultIsLValue=*/true}); + } + return true; } else diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 848c8684..a124be66 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -20,6 +20,7 @@ LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false) LUAU_FASTFLAGVARIABLE(LuauFixNormalizeCaching, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false); LUAU_FASTFLAGVARIABLE(LuauFixCyclicUnionsOfIntersections, false); +LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false); // This could theoretically be 2000 on amd64, but x86 requires this. 
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); @@ -36,6 +37,11 @@ static bool fixCyclicUnionsOfIntersections() return FFlag::LuauFixCyclicUnionsOfIntersections || FFlag::DebugLuauDeferredConstraintResolution; } +static bool fixReduceStackPressure() +{ + return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution; +} + namespace Luau { @@ -45,6 +51,14 @@ static bool normalizeAwayUninhabitableTables() return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::DebugLuauDeferredConstraintResolution; } +static bool shouldEarlyExit(NormalizationResult res) +{ + // if res is hit limits, return control flow + if (res == NormalizationResult::HitLimits || res == NormalizationResult::False) + return true; + return false; +} + TypeIds::TypeIds(std::initializer_list tys) { for (TypeId ty : tys) @@ -1729,6 +1743,27 @@ bool Normalizer::withinResourceLimits() return true; } +NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect, bool useDeprecated) +{ + + std::optional negated; + if (useDeprecated) + { + const NormalizedType* normal = DEPRECATED_normalize(toNegate); + negated = negateNormal(*normal); + } + else + { + std::shared_ptr normal = normalize(toNegate); + negated = negateNormal(*normal); + } + + if (!negated) + return NormalizationResult::False; + intersectNormals(intersect, *negated); + return NormalizationResult::True; +} + // See above for an explaination of `ignoreSmallerTyvars`. 
NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes, int ignoreSmallerTyvars) { @@ -2541,8 +2576,8 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there state = tttv->state; TypeLevel level = max(httv->level, tttv->level); - TableType result{state, level}; + std::unique_ptr result = nullptr; bool hereSubThere = true; bool thereSubHere = true; @@ -2563,8 +2598,18 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (tprop.readTy.has_value()) { // if the intersection of the read types of a property is uninhabited, the whole table is `never`. - if (normalizeAwayUninhabitableTables() && NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy)) - return {builtinTypes->neverType}; + if (fixReduceStackPressure()) + { + if (normalizeAwayUninhabitableTables() && + NormalizationResult::True != isIntersectionInhabited(*hprop.readTy, *tprop.readTy)) + return {builtinTypes->neverType}; + } + else + { + if (normalizeAwayUninhabitableTables() && + NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy)) + return {builtinTypes->neverType}; + } TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; prop.readTy = ty; @@ -2614,14 +2659,21 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there // TODO: string indexers if (prop.readTy || prop.writeTy) - result.props[name] = prop; + { + if (!result.get()) + result = std::make_unique(TableType{state, level}); + result->props[name] = prop; + } } for (const auto& [name, tprop] : tttv->props) { if (httv->props.count(name) == 0) { - result.props[name] = tprop; + if (!result.get()) + result = std::make_unique(TableType{state, level}); + + result->props[name] = tprop; hereSubThere = false; } } @@ -2631,18 +2683,24 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there // TODO: What should intersection of 
indexes be? TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType); TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType); - result.indexer = {index, indexResult}; + if (!result.get()) + result = std::make_unique(TableType{state, level}); + result->indexer = {index, indexResult}; hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult); thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult); } else if (httv->indexer) { - result.indexer = httv->indexer; + if (!result.get()) + result = std::make_unique(TableType{state, level}); + result->indexer = httv->indexer; thereSubHere = false; } else if (tttv->indexer) { - result.indexer = tttv->indexer; + if (!result.get()) + result = std::make_unique(TableType{state, level}); + result->indexer = tttv->indexer; hereSubThere = false; } @@ -2652,7 +2710,12 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (thereSubHere) table = ttable; else - table = arena->addType(std::move(result)); + { + if (result.get()) + table = arena->addType(std::move(*result)); + else + table = arena->addType(TableType{state, level}); + } if (tmtable && hmtable) { @@ -3150,19 +3213,15 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { if (fixNormalizeCaching()) { - std::shared_ptr normal = normalize(t); - std::optional negated = negateNormal(*normal); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); + NormalizationResult res = intersectNormalWithNegationTy(t, here); + if (shouldEarlyExit(res)) + return res; } else { - const NormalizedType* normal = DEPRECATED_normalize(t); - std::optional negated = negateNormal(*normal); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); + NormalizationResult res = intersectNormalWithNegationTy(t, here, /* 
useDeprecated */ true); + if (shouldEarlyExit(res)) + return res; } } else if (const UnionType* itv = get(t)) @@ -3171,11 +3230,9 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { for (TypeId part : itv->options) { - std::shared_ptr normalPart = normalize(part); - std::optional negated = negateNormal(*normalPart); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); + NormalizationResult res = intersectNormalWithNegationTy(part, here); + if (shouldEarlyExit(res)) + return res; } } else @@ -3184,22 +3241,18 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { for (TypeId part : itv->options) { - std::shared_ptr normalPart = normalize(part); - std::optional negated = negateNormal(*normalPart); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); + NormalizationResult res = intersectNormalWithNegationTy(part, here); + if (shouldEarlyExit(res)) + return res; } } else { for (TypeId part : itv->options) { - const NormalizedType* normalPart = DEPRECATED_normalize(part); - std::optional negated = negateNormal(*normalPart); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); + NormalizationResult res = intersectNormalWithNegationTy(part, here, /* useDeprecated */ true); + if (shouldEarlyExit(res)) + return res; } } } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a888564e..faa5ffdb 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1280,7 +1280,9 @@ struct TypeChecker2 void visit(AstExprGlobal* expr) { - // TODO! 
+ NotNull scope = stack.back(); + if (!scope->lookup(expr->name)) + reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); } void visit(AstExprVarargs* expr) diff --git a/CLI/Compile.cpp b/CLI/Compile.cpp index 44a6ef77..dd6b14ab 100644 --- a/CLI/Compile.cpp +++ b/CLI/Compile.cpp @@ -317,6 +317,7 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A { options.includeAssembly = format != CompileFormat::CodegenIr; options.includeIr = format != CompileFormat::CodegenAsm; + options.includeIrTypes = format != CompileFormat::CodegenAsm; options.includeOutlinedCode = format == CompileFormat::CodegenVerbose; } diff --git a/CMakeLists.txt b/CMakeLists.txt index 985cda1c..5b7e551e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -229,6 +229,7 @@ if(LUAU_BUILD_TESTS) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) + target_compile_definitions(Luau.Conformance PRIVATE DOCTEST_CONFIG_DOUBLE_STRINGIFY) target_include_directories(Luau.Conformance PRIVATE extern) target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen Luau.VM) if(CMAKE_SYSTEM_NAME MATCHES "Android|iOS") diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 9765035b..9b56034f 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -40,8 +40,12 @@ enum class CodeGenCompilationResult CodeGenAssemblerFinalizationFailure = 7, // Failure during assembler finalization CodeGenLoweringFailure = 8, // Lowering failed AllocationFailed = 9, // Native codegen failed due to an allocation error + + Count = 10, }; +std::string toString(const CodeGenCompilationResult& result); + struct ProtoCompilationFailure { CodeGenCompilationResult result = CodeGenCompilationResult::Success; diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index 
7c39f5fc..e3ce9166 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -6,6 +6,9 @@ #include "Luau/IrUtils.h" #include "lobject.h" +#include "lstate.h" + +#include #include @@ -13,6 +16,7 @@ LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow) LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) +LUAU_FASTFLAGVARIABLE(LuauCodegenVectorMispredictFix, false) namespace Luau { @@ -771,10 +775,30 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = LBC_TYPE_ANY; - // Assuming that vector component is being indexed - // TODO: check what key is used - if (bcType.a == LBC_TYPE_VECTOR) - regTags[ra] = LBC_TYPE_NUMBER; + if (FFlag::LuauCodegenVectorMispredictFix) + { + if (bcType.a == LBC_TYPE_VECTOR) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + if (str->len == 1) + { + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; + + if (ch == 'x' || ch == 'y' || ch == 'z') + regTags[ra] = LBC_TYPE_NUMBER; + } + } + } + else + { + // Assuming that vector component is being indexed + // TODO: check what key is used + if (bcType.a == LBC_TYPE_VECTOR) + regTags[ra] = LBC_TYPE_NUMBER; + } bcType.result = regTags[ra]; break; diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 9ef9980a..a5f6721e 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -65,6 +65,38 @@ namespace Luau namespace CodeGen { +std::string toString(const CodeGenCompilationResult& result) +{ + switch (result) + { + case CodeGenCompilationResult::Success: + return "Success"; + case CodeGenCompilationResult::NothingToCompile: + return "NothingToCompile"; + case CodeGenCompilationResult::NotNativeModule: + return 
"NotNativeModule"; + case CodeGenCompilationResult::CodeGenNotInitialized: + return "CodeGenNotInitialized"; + case CodeGenCompilationResult::CodeGenOverflowInstructionLimit: + return "CodeGenOverflowInstructionLimit"; + case CodeGenCompilationResult::CodeGenOverflowBlockLimit: + return "CodeGenOverflowBlockLimit"; + case CodeGenCompilationResult::CodeGenOverflowBlockInstructionLimit: + return "CodeGenOverflowBlockInstructionLimit"; + case CodeGenCompilationResult::CodeGenAssemblerFinalizationFailure: + return "CodeGenAssemblerFinalizationFailure"; + case CodeGenCompilationResult::CodeGenLoweringFailure: + return "CodeGenLoweringFailure"; + case CodeGenCompilationResult::AllocationFailed: + return "AllocationFailed"; + case CodeGenCompilationResult::Count: + return "Count"; + } + + CODEGEN_ASSERT(false); + return ""; +} + static const Instruction kCodeEntryInsn = LOP_NATIVECALL; void* gPerfLogContext = nullptr; diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index 96c73ce2..8324b7cc 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -13,12 +13,55 @@ #include "lapi.h" LUAU_FASTFLAG(LuauCodegenTypeInfo) +LUAU_FASTFLAGVARIABLE(LuauCodegenIrTypeNames, false) namespace Luau { namespace CodeGen { +static const LocVar* tryFindLocal(const Proto* proto, int reg, int pcpos) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenIrTypeNames); + + for (int i = 0; i < proto->sizelocvars; i++) + { + const LocVar& local = proto->locvars[i]; + + if (reg == local.reg && pcpos >= local.startpc && pcpos < local.endpc) + return &local; + } + + return nullptr; +} + +const char* tryFindLocalName(const Proto* proto, int reg, int pcpos) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenIrTypeNames); + + const LocVar* var = tryFindLocal(proto, reg, pcpos); + + if (var && var->varname) + return getstr(var->varname); + + return nullptr; +} + +const char* tryFindUpvalueName(const Proto* proto, int upval) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenIrTypeNames); 
+ + if (proto->upvalues) + { + CODEGEN_ASSERT(upval < proto->sizeupvalues); + + if (proto->upvalues[upval]) + return getstr(proto->upvalues[upval]); + } + + return nullptr; +} + template static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) { @@ -29,12 +72,22 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) for (int i = 0; i < proto->numparams; i++) { - LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr; - - if (var && var->varname) - build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname)); + if (FFlag::LuauCodegenIrTypeNames) + { + if (const char* name = tryFindLocalName(proto, i, 0)) + build.logAppend("%s%s", i == 0 ? "" : ", ", name); + else + build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i); + } else - build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i); + { + LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr; + + if (var && var->varname) + build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname)); + else + build.logAppend("%s$arg%d", i == 0 ? 
"" : ", ", i); + } } if (proto->numparams != 0 && proto->is_vararg) @@ -59,21 +112,58 @@ static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function) { uint8_t ty = typeInfo.argumentTypes[i]; - if (ty != LBC_TYPE_ANY) - build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName(ty)); + if (FFlag::LuauCodegenIrTypeNames) + { + if (ty != LBC_TYPE_ANY) + { + if (const char* name = tryFindLocalName(function.proto, int(i), 0)) + build.logAppend("; R%d: %s [argument '%s']\n", int(i), getBytecodeTypeName(ty), name); + else + build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName(ty)); + } + } + else + { + if (ty != LBC_TYPE_ANY) + build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName(ty)); + } } for (size_t i = 0; i < typeInfo.upvalueTypes.size(); i++) { uint8_t ty = typeInfo.upvalueTypes[i]; - if (ty != LBC_TYPE_ANY) - build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName(ty)); + if (FFlag::LuauCodegenIrTypeNames) + { + if (ty != LBC_TYPE_ANY) + { + if (const char* name = tryFindUpvalueName(function.proto, int(i))) + build.logAppend("; U%d: %s ['%s']\n", int(i), getBytecodeTypeName(ty), name); + else + build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName(ty)); + } + } + else + { + if (ty != LBC_TYPE_ANY) + build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName(ty)); + } } for (const BytecodeRegTypeInfo& el : typeInfo.regTypes) { - build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc); + if (FFlag::LuauCodegenIrTypeNames) + { + // Using last active position as the PC because 'startpc' for type info is before local is initialized + if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1)) + build.logAppend("; R%d: %s from %d to %d [local '%s']\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc, name); + else + build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc); 
+ } + else + { + build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc); + } } } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 20150f9a..84e3b639 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -13,6 +13,7 @@ #include "ltm.h" LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenFixVectorFields, false) namespace Luau { @@ -1197,19 +1198,19 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) TString* str = gco2ts(build.function.proto->k[aux].value.gc); const char* field = getstr(str); - if (*field == 'X' || *field == 'x') + if ((!FFlag::LuauCodegenFixVectorFields || str->len == 1) && (*field == 'X' || *field == 'x')) { IrOp value = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(rb), build.constInt(0)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); } - else if (*field == 'Y' || *field == 'y') + else if ((!FFlag::LuauCodegenFixVectorFields || str->len == 1) && (*field == 'Y' || *field == 'y')) { IrOp value = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(rb), build.constInt(4)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); } - else if (*field == 'Z' || *field == 'z') + else if ((!FFlag::LuauCodegenFixVectorFields || str->len == 1) && (*field == 'Z' || *field == 'z')) { IrOp value = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(rb), build.constInt(8)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 2b09b7e0..c576e3a4 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -454,7 +454,7 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_BUFFER_WRITEF32: case LBF_BUFFER_WRITEF64: return {3, 0, BuiltinInfo::Flag_NoneSafe}; - }; + 
} LUAU_UNREACHABLE(); } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index df096d3a..d5cd78a5 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -29,6 +29,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAGVARIABLE(LuauCompileRepeatUntilSkippedLocals, false) LUAU_FASTFLAG(LuauCompileTypeInfo) LUAU_FASTFLAGVARIABLE(LuauTypeInfoLookupImprovement, false) +LUAU_FASTFLAGVARIABLE(LuauCompileTempTypeInfo, false) namespace Luau { @@ -108,6 +109,8 @@ struct Compiler , builtins(nullptr) , functionTypes(nullptr) , localTypes(nullptr) + , exprTypes(nullptr) + , builtinTypes(options.vectorType) { // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays localStack.reserve(16); @@ -916,6 +919,9 @@ struct Compiler bytecode.emitABC(LOP_NAMECALL, regs, selfreg, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); + + if (FFlag::LuauCompileTempTypeInfo) + hintTemporaryExprRegType(fi->expr, selfreg, LBC_TYPE_TABLE, /* instLength */ 2); } else if (bfid >= 0) { @@ -1570,6 +1576,9 @@ struct Compiler uint8_t rl = compileExprAuto(expr->left, rs); bytecode.emitABC(getBinaryOpArith(expr->op, /* k= */ true), target, rl, uint8_t(rc)); + + if (FFlag::LuauCompileTempTypeInfo) + hintTemporaryExprRegType(expr->left, rl, LBC_TYPE_NUMBER, /* instLength */ 1); } else { @@ -1583,6 +1592,9 @@ struct Compiler LuauOpcode op = (expr->op == AstExprBinary::Sub) ? 
LOP_SUBRK : LOP_DIVRK; bytecode.emitABC(op, target, uint8_t(lc), uint8_t(rr)); + + if (FFlag::LuauCompileTempTypeInfo) + hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1); return; } } @@ -1591,6 +1603,12 @@ struct Compiler uint8_t rr = compileExprAuto(expr->right, rs); bytecode.emitABC(getBinaryOpArith(expr->op), target, rl, rr); + + if (FFlag::LuauCompileTempTypeInfo) + { + hintTemporaryExprRegType(expr->left, rl, LBC_TYPE_NUMBER, /* instLength */ 1); + hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1); + } } } break; @@ -2030,6 +2048,9 @@ struct Compiler bytecode.emitABC(LOP_GETTABLEKS, target, reg, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); + + if (FFlag::LuauCompileTempTypeInfo) + hintTemporaryExprRegType(expr->expr, reg, LBC_TYPE_TABLE, /* instLength */ 2); } void compileExprIndexExpr(AstExprIndexExpr* expr, uint8_t target) @@ -3410,6 +3431,14 @@ struct Compiler uint8_t rr = compileExprAuto(stat->value, rs); bytecode.emitABC(getBinaryOpArith(stat->op), target, target, rr); + + if (FFlag::LuauCompileTempTypeInfo) + { + if (var.kind != LValue::Kind_Local) + hintTemporaryRegType(stat->var, target, LBC_TYPE_NUMBER, /* instLength */ 1); + + hintTemporaryExprRegType(stat->value, rr, LBC_TYPE_NUMBER, /* instLength */ 1); + } } } break; @@ -3794,6 +3823,27 @@ struct Compiler return !node->is() && !node->is(); } + void hintTemporaryRegType(AstExpr* expr, int reg, LuauBytecodeType expectedType, int instLength) + { + LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); + + // If we know the type of a temporary and it's not the type that would be expected by codegen, provide a hint + if (LuauBytecodeType* ty = exprTypes.find(expr)) + { + if (*ty != expectedType) + bytecode.pushLocalTypeInfo(*ty, reg, bytecode.getDebugPC() - instLength, bytecode.getDebugPC()); + } + } + + void hintTemporaryExprRegType(AstExpr* expr, int reg, LuauBytecodeType expectedType, int instLength) + { + 
LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); + + // If we allocated a temporary register for the operation argument, try hinting its type + if (!getExprLocal(expr)) + hintTemporaryRegType(expr, reg, expectedType, instLength); + } + struct FenvVisitor : AstVisitor { bool& getfenvUsed; @@ -4046,6 +4096,9 @@ struct Compiler DenseHashMap builtins; DenseHashMap functionTypes; DenseHashMap localTypes; + DenseHashMap exprTypes; + + BuiltinTypes builtinTypes; const DenseHashMap* builtinsFold = nullptr; bool builtinsFoldMathK = false; @@ -4141,12 +4194,14 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c if (FFlag::LuauCompileTypeInfo) { if (options.typeInfoLevel >= 1) - buildTypeMap(compiler.functionTypes, compiler.localTypes, root, options.vectorType); + buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.builtinTypes, + compiler.builtins, compiler.globals); } else { if (functionVisitor.hasTypes) - buildTypeMap(compiler.functionTypes, compiler.localTypes, root, options.vectorType); + buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.builtinTypes, + compiler.builtins, compiler.globals); } for (AstExprFunction* expr : functions) diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index d05a7ba2..eaa2d8be 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -4,6 +4,7 @@ #include "Luau/BytecodeBuilder.h" LUAU_FASTFLAG(LuauCompileTypeInfo) +LUAU_FASTFLAG(LuauCompileTempTypeInfo) namespace Luau { @@ -37,10 +38,10 @@ static LuauBytecodeType getPrimitiveType(AstName name) return LBC_TYPE_INVALID; } -static LuauBytecodeType getType(AstType* ty, const AstArray& generics, const DenseHashMap& typeAliases, - bool resolveAliases, const char* vectorType) +static LuauBytecodeType getType(const AstType* ty, const AstArray& generics, + const DenseHashMap& typeAliases, bool resolveAliases, const char* 
vectorType) { - if (AstTypeReference* ref = ty->as()) + if (const AstTypeReference* ref = ty->as()) { if (ref->prefix) return LBC_TYPE_ANY; @@ -66,15 +67,15 @@ static LuauBytecodeType getType(AstType* ty, const AstArray& gen // not primitive or alias or generic => host-provided, we assume userdata for now return LBC_TYPE_USERDATA; } - else if (AstTypeTable* table = ty->as()) + else if (const AstTypeTable* table = ty->as()) { return LBC_TYPE_TABLE; } - else if (AstTypeFunction* func = ty->as()) + else if (const AstTypeFunction* func = ty->as()) { return LBC_TYPE_FUNCTION; } - else if (AstTypeUnion* un = ty->as()) + else if (const AstTypeUnion* un = ty->as()) { bool optional = false; LuauBytecodeType type = LBC_TYPE_INVALID; @@ -104,7 +105,7 @@ static LuauBytecodeType getType(AstType* ty, const AstArray& gen return LuauBytecodeType(type | (optional && (type != LBC_TYPE_ANY) ? LBC_TYPE_OPTIONAL_BIT : 0)); } - else if (AstTypeIntersection* inter = ty->as()) + else if (const AstTypeIntersection* inter = ty->as()) { return LBC_TYPE_ANY; } @@ -144,21 +145,44 @@ static std::string getFunctionType(const AstExprFunction* func, const DenseHashM return typeInfo; } +static bool isMatchingGlobal(const DenseHashMap& globals, AstExpr* node, const char* name) +{ + LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); + + if (AstExprGlobal* expr = node->as()) + return Compile::getGlobalState(globals, expr->name) == Compile::Global::Default && expr->name == name; + + return false; +} + struct TypeMapVisitor : AstVisitor { DenseHashMap& functionTypes; DenseHashMap& localTypes; + DenseHashMap& exprTypes; const char* vectorType; + const BuiltinTypes& builtinTypes; + const DenseHashMap& builtinCalls; + const DenseHashMap& globals; DenseHashMap typeAliases; std::vector> typeAliasStack; + DenseHashMap resolvedLocals; + DenseHashMap resolvedExprs; - TypeMapVisitor( - DenseHashMap& functionTypes, DenseHashMap& localTypes, const char* vectorType) + TypeMapVisitor(DenseHashMap& functionTypes, 
DenseHashMap& localTypes, + DenseHashMap& exprTypes, const char* vectorType, const BuiltinTypes& builtinTypes, + const DenseHashMap& builtinCalls, const DenseHashMap& globals) : functionTypes(functionTypes) , localTypes(localTypes) + , exprTypes(exprTypes) , vectorType(vectorType) + , builtinTypes(builtinTypes) + , builtinCalls(builtinCalls) + , globals(globals) , typeAliases(AstName()) + , resolvedLocals(nullptr) + , resolvedExprs(nullptr) { } @@ -189,6 +213,64 @@ struct TypeMapVisitor : AstVisitor } } + const AstType* resolveAliases(const AstType* ty) + { + LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); + + if (const AstTypeReference* ref = ty->as()) + { + if (ref->prefix) + return ty; + + if (AstStatTypeAlias* const* alias = typeAliases.find(ref->name); alias && *alias) + return (*alias)->type; + } + + return ty; + } + + const AstTableIndexer* tryGetTableIndexer(AstExpr* expr) + { + LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); + + if (const AstType** typePtr = resolvedExprs.find(expr)) + { + if (const AstTypeTable* tableTy = (*typePtr)->as()) + return tableTy->indexer; + } + + return nullptr; + } + + LuauBytecodeType recordResolvedType(AstExpr* expr, const AstType* ty) + { + LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); + + ty = resolveAliases(ty); + + resolvedExprs[expr] = ty; + + LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType); + exprTypes[expr] = bty; + return bty; + } + + LuauBytecodeType recordResolvedType(AstLocal* local, const AstType* ty) + { + LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); + + ty = resolveAliases(ty); + + resolvedLocals[local] = ty; + + LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType); + + if (bty != LBC_TYPE_ANY) + localTypes[local] = bty; + + return bty; + } + bool visit(AstStatBlock* node) override { size_t aliasStackTop = pushTypeAliases(node); @@ -216,6 +298,60 @@ struct TypeMapVisitor : AstVisitor return false; } + // for...in statement can contain 
type annotations on locals (we might even infer some for ipairs/pairs/generalized iteration) + bool visit(AstStatForIn* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + for (AstExpr* expr : node->values) + expr->visit(this); + + // This is similar to how Compiler matches builtin iteration, but we also handle generalized iteration case + if (node->vars.size == 2 && node->values.size == 1) + { + if (AstExprCall* call = node->values.data[0]->as(); call && call->args.size == 1) + { + AstExpr* func = call->func; + AstExpr* arg = call->args.data[0]; + + if (isMatchingGlobal(globals, func, "ipairs")) + { + if (const AstTableIndexer* indexer = tryGetTableIndexer(arg)) + { + recordResolvedType(node->vars.data[0], &builtinTypes.numberType); + recordResolvedType(node->vars.data[1], indexer->resultType); + } + } + else if (isMatchingGlobal(globals, func, "pairs")) + { + if (const AstTableIndexer* indexer = tryGetTableIndexer(arg)) + { + recordResolvedType(node->vars.data[0], indexer->indexType); + recordResolvedType(node->vars.data[1], indexer->resultType); + } + } + } + else if (const AstTableIndexer* indexer = tryGetTableIndexer(node->values.data[0])) + { + recordResolvedType(node->vars.data[0], indexer->indexType); + recordResolvedType(node->vars.data[1], indexer->resultType); + } + } + + for (size_t i = 0; i < node->vars.size; i++) + { + AstLocal* var = node->vars.data[i]; + + if (AstType* annotation = var->annotation) + recordResolvedType(var, annotation); + } + + node->body->visit(this); + + return false; + } + bool visit(AstExprFunction* node) override { std::string type = getFunctionType(node, typeAliases, vectorType); @@ -223,32 +359,405 @@ struct TypeMapVisitor : AstVisitor if (!type.empty()) functionTypes[node] = std::move(type); - return true; + return true; // Let generic visitor step into all expressions } bool visit(AstExprLocal* node) override { - if (FFlag::LuauCompileTypeInfo) + if (FFlag::LuauCompileTempTypeInfo) { - AstLocal* 
local = node->local; - - if (AstType* annotation = local->annotation) + if (FFlag::LuauCompileTypeInfo) { - LuauBytecodeType ty = getType(annotation, {}, typeAliases, /* resolveAliases= */ true, vectorType); + AstLocal* local = node->local; - if (ty != LBC_TYPE_ANY) - localTypes[local] = ty; + if (AstType* annotation = local->annotation) + { + LuauBytecodeType ty = recordResolvedType(node, annotation); + + if (ty != LBC_TYPE_ANY) + localTypes[local] = ty; + } + else if (const AstType** typePtr = resolvedLocals.find(local)) + { + localTypes[local] = recordResolvedType(node, *typePtr); + } + } + + return false; + } + else + { + if (FFlag::LuauCompileTypeInfo) + { + AstLocal* local = node->local; + + if (AstType* annotation = local->annotation) + { + LuauBytecodeType ty = getType(annotation, {}, typeAliases, /* resolveAliases= */ true, vectorType); + + if (ty != LBC_TYPE_ANY) + localTypes[local] = ty; + } + } + + return true; + } + } + + bool visit(AstStatLocal* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + for (AstExpr* expr : node->values) + expr->visit(this); + + for (size_t i = 0; i < node->vars.size; i++) + { + AstLocal* var = node->vars.data[i]; + + // Propagate from the value that's being assigned + // This simple propagation doesn't handle type packs in tail position + if (var->annotation == nullptr) + { + if (i < node->values.size) + { + if (const AstType** typePtr = resolvedExprs.find(node->values.data[i])) + resolvedLocals[var] = *typePtr; + } } } - return true; + return false; } + + bool visit(AstExprIndexExpr* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + node->expr->visit(this); + node->index->visit(this); + + if (const AstTableIndexer* indexer = tryGetTableIndexer(node->expr)) + recordResolvedType(node, indexer->resultType); + + return false; + } + + bool visit(AstExprIndexName* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + node->expr->visit(this); + + if (const 
AstType** typePtr = resolvedExprs.find(node->expr)) + { + if (const AstTypeTable* tableTy = (*typePtr)->as()) + { + for (const AstTableProp& prop : tableTy->props) + { + if (prop.name == node->index) + { + recordResolvedType(node, prop.type); + return false; + } + } + } + } + + if (LuauBytecodeType* typeBcPtr = exprTypes.find(node->expr)) + { + if (*typeBcPtr == LBC_TYPE_VECTOR) + { + if (node->index == "X" || node->index == "Y" || node->index == "Z") + recordResolvedType(node, &builtinTypes.numberType); + } + } + + return false; + } + + bool visit(AstExprUnary* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + node->expr->visit(this); + + switch (node->op) + { + case AstExprUnary::Not: + recordResolvedType(node, &builtinTypes.booleanType); + break; + case AstExprUnary::Minus: + { + const AstType** typePtr = resolvedExprs.find(node->expr); + LuauBytecodeType* bcTypePtr = exprTypes.find(node->expr); + + if (!typePtr || !bcTypePtr) + return false; + + if (*bcTypePtr == LBC_TYPE_VECTOR) + recordResolvedType(node, *typePtr); + else if (*bcTypePtr == LBC_TYPE_NUMBER) + recordResolvedType(node, *typePtr); + + break; + } + case AstExprUnary::Len: + recordResolvedType(node, &builtinTypes.numberType); + break; + } + + return false; + } + + bool visit(AstExprBinary* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + node->left->visit(this); + node->right->visit(this); + + // Comparisons result in a boolean + if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq || node->op == AstExprBinary::CompareLt || + node->op == AstExprBinary::CompareLe || node->op == AstExprBinary::CompareGt || node->op == AstExprBinary::CompareGe) + { + recordResolvedType(node, &builtinTypes.booleanType); + return false; + } + + if (node->op == AstExprBinary::Concat || node->op == AstExprBinary::And || node->op == AstExprBinary::Or) + return false; + + const AstType** leftTypePtr = resolvedExprs.find(node->left); + 
LuauBytecodeType* leftBcTypePtr = exprTypes.find(node->left); + + if (!leftTypePtr || !leftBcTypePtr) + return false; + + const AstType** rightTypePtr = resolvedExprs.find(node->right); + LuauBytecodeType* rightBcTypePtr = exprTypes.find(node->right); + + if (!rightTypePtr || !rightBcTypePtr) + return false; + + if (*leftBcTypePtr == LBC_TYPE_VECTOR) + recordResolvedType(node, *leftTypePtr); + else if (*rightBcTypePtr == LBC_TYPE_VECTOR) + recordResolvedType(node, *rightTypePtr); + else if (*leftBcTypePtr == LBC_TYPE_NUMBER && *rightBcTypePtr == LBC_TYPE_NUMBER) + recordResolvedType(node, *leftTypePtr); + + return false; + } + + bool visit(AstExprGroup* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + node->expr->visit(this); + + if (const AstType** typePtr = resolvedExprs.find(node->expr)) + recordResolvedType(node, *typePtr); + + return false; + } + + bool visit(AstExprTypeAssertion* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + node->expr->visit(this); + + recordResolvedType(node, node->annotation); + + return false; + } + + bool visit(AstExprConstantBool* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + recordResolvedType(node, &builtinTypes.booleanType); + + return false; + } + + bool visit(AstExprConstantNumber* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + recordResolvedType(node, &builtinTypes.numberType); + + return false; + } + + bool visit(AstExprConstantString* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + recordResolvedType(node, &builtinTypes.stringType); + + return false; + } + + bool visit(AstExprInterpString* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + recordResolvedType(node, &builtinTypes.stringType); + + return false; + } + + bool visit(AstExprIfElse* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + node->condition->visit(this); + 
node->trueExpr->visit(this); + node->falseExpr->visit(this); + + const AstType** trueTypePtr = resolvedExprs.find(node->trueExpr); + LuauBytecodeType* trueBcTypePtr = exprTypes.find(node->trueExpr); + LuauBytecodeType* falseBcTypePtr = exprTypes.find(node->falseExpr); + + // Optimistic check that both expressions are of the same kind, as AstType* cannot be compared + if (trueTypePtr && trueBcTypePtr && falseBcTypePtr && *trueBcTypePtr == *falseBcTypePtr) + recordResolvedType(node, *trueTypePtr); + + return false; + } + + bool visit(AstExprCall* node) override + { + if (!FFlag::LuauCompileTempTypeInfo) + return true; + + if (const int* bfid = builtinCalls.find(node)) + { + switch (LuauBuiltinFunction(*bfid)) + { + case LBF_NONE: + case LBF_ASSERT: + case LBF_RAWSET: + case LBF_RAWGET: + case LBF_TABLE_INSERT: + case LBF_TABLE_UNPACK: + case LBF_SELECT_VARARG: + case LBF_GETMETATABLE: + case LBF_SETMETATABLE: + case LBF_BUFFER_WRITEU8: + case LBF_BUFFER_WRITEU16: + case LBF_BUFFER_WRITEU32: + case LBF_BUFFER_WRITEF32: + case LBF_BUFFER_WRITEF64: + break; + case LBF_MATH_ABS: + case LBF_MATH_ACOS: + case LBF_MATH_ASIN: + case LBF_MATH_ATAN2: + case LBF_MATH_ATAN: + case LBF_MATH_CEIL: + case LBF_MATH_COSH: + case LBF_MATH_COS: + case LBF_MATH_DEG: + case LBF_MATH_EXP: + case LBF_MATH_FLOOR: + case LBF_MATH_FMOD: + case LBF_MATH_FREXP: + case LBF_MATH_LDEXP: + case LBF_MATH_LOG10: + case LBF_MATH_LOG: + case LBF_MATH_MAX: + case LBF_MATH_MIN: + case LBF_MATH_MODF: + case LBF_MATH_POW: + case LBF_MATH_RAD: + case LBF_MATH_SINH: + case LBF_MATH_SIN: + case LBF_MATH_SQRT: + case LBF_MATH_TANH: + case LBF_MATH_TAN: + case LBF_BIT32_ARSHIFT: + case LBF_BIT32_BAND: + case LBF_BIT32_BNOT: + case LBF_BIT32_BOR: + case LBF_BIT32_BXOR: + case LBF_BIT32_BTEST: + case LBF_BIT32_EXTRACT: + case LBF_BIT32_LROTATE: + case LBF_BIT32_LSHIFT: + case LBF_BIT32_REPLACE: + case LBF_BIT32_RROTATE: + case LBF_BIT32_RSHIFT: + case LBF_STRING_BYTE: + case LBF_STRING_LEN: + case LBF_MATH_CLAMP: 
+ case LBF_MATH_SIGN: + case LBF_MATH_ROUND: + case LBF_BIT32_COUNTLZ: + case LBF_BIT32_COUNTRZ: + case LBF_RAWLEN: + case LBF_BIT32_EXTRACTK: + case LBF_TONUMBER: + case LBF_BIT32_BYTESWAP: + case LBF_BUFFER_READI8: + case LBF_BUFFER_READU8: + case LBF_BUFFER_READI16: + case LBF_BUFFER_READU16: + case LBF_BUFFER_READI32: + case LBF_BUFFER_READU32: + case LBF_BUFFER_READF32: + case LBF_BUFFER_READF64: + recordResolvedType(node, &builtinTypes.numberType); + break; + + case LBF_TYPE: + case LBF_STRING_CHAR: + case LBF_TYPEOF: + case LBF_STRING_SUB: + case LBF_TOSTRING: + recordResolvedType(node, &builtinTypes.stringType); + break; + + case LBF_RAWEQUAL: + recordResolvedType(node, &builtinTypes.booleanType); + break; + + case LBF_VECTOR: + recordResolvedType(node, &builtinTypes.vectorType); + break; + } + } + + return true; // Let generic visitor step into all expressions + } + + // AstExpr classes that are not covered: + // * AstExprConstantNil is not resolved to 'nil' because that doesn't help codegen operations and often used as an initializer before real value + // * AstExprGlobal is not supported as we don't have info on globals + // * AstExprVarargs cannot be resolved to a testable type + // * AstExprTable cannot be reconstructed into a specific AstTypeTable and table annotations don't really help codegen + // * AstExprCall is very complex (especially if builtins and registered globals are included), will be extended in the future }; -void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, AstNode* root, - const char* vectorType) +void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, + DenseHashMap& exprTypes, AstNode* root, const char* vectorType, const BuiltinTypes& builtinTypes, + const DenseHashMap& builtinCalls, const DenseHashMap& globals) { - TypeMapVisitor visitor(functionTypes, localTypes, vectorType); + TypeMapVisitor visitor(functionTypes, localTypes, exprTypes, vectorType, builtinTypes, builtinCalls, globals); 
root->visit(&visitor); } diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h index de11fde9..b1aff8a2 100644 --- a/Compiler/src/Types.h +++ b/Compiler/src/Types.h @@ -4,13 +4,29 @@ #include "Luau/Ast.h" #include "Luau/Bytecode.h" #include "Luau/DenseHash.h" +#include "ValueTracking.h" #include namespace Luau { -void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, AstNode* root, - const char* vectorType); +struct BuiltinTypes +{ + BuiltinTypes(const char* vectorType) + : vectorType{{}, std::nullopt, AstName{vectorType}, std::nullopt, {}} + { + } + + // AstName use here will not match the AstNameTable, but the way we use them here always forces a full string compare + AstTypeReference booleanType{{}, std::nullopt, AstName{"boolean"}, std::nullopt, {}}; + AstTypeReference numberType{{}, std::nullopt, AstName{"number"}, std::nullopt, {}}; + AstTypeReference stringType{{}, std::nullopt, AstName{"string"}, std::nullopt, {}}; + AstTypeReference vectorType; +}; + +void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, + DenseHashMap& exprTypes, AstNode* root, const char* vectorType, const BuiltinTypes& builtinTypes, + const DenseHashMap& builtinCalls, const DenseHashMap& globals); } // namespace Luau diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 7122b035..0e792366 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauPushErrorStackCheck, false) + static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -330,12 +332,18 @@ l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...)
vsnprintf(result, sizeof(result), fmt, argp); va_end(argp); + if (FFlag::LuauPushErrorStackCheck) + lua_rawcheckstack(L, 1); + pusherror(L, result); luaD_throw(L, LUA_ERRRUN); } void luaG_pusherror(lua_State* L, const char* error) { + if (FFlag::LuauPushErrorStackCheck) + lua_rawcheckstack(L, 1); + pusherror(L, error); } diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 3de18cf9..f6cc07c9 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -53,6 +53,10 @@ * for each block size there's a page free list that contains pages that have at least one free block * (global_State::freegcopages). This free list is used to make sure object allocation is O(1). * + * When LUAU_ASSERTENABLED is enabled, all non-GCO pages are also linked in a list (global_State::allpages). + * Because this list is not strictly required for runtime operations, it is only tracked for the purposes of + * debugging. While overhead of linking those pages together is very small, unnecessary operations are avoided. + * * Compared to GCOs, regular allocations have two important differences: they can be freed in isolation, * and they don't start with a GC header. 
Because of this, each allocation is prefixed with block metadata, * which contains the pointer to the page for allocated blocks, and the pointer to the next free block @@ -190,6 +194,12 @@ const SizeClassConfig kSizeClassConfig; #define metadata(block) (*(void**)(block)) #define freegcolink(block) (*(void**)((char*)block + kGCOLinkOffset)) +#if defined(LUAU_ASSERTENABLED) +#define debugpageset(x) (x) +#else +#define debugpageset(x) NULL +#endif + struct lua_Page { // list of pages with free blocks @@ -336,7 +346,7 @@ static void* newblock(lua_State* L, int sizeClass) // slow path: no page in the freelist, allocate a new one if (!page) - page = newclasspage(L, g->freepages, NULL, sizeClass, true); + page = newclasspage(L, g->freepages, debugpageset(&g->allpages), sizeClass, true); LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); @@ -457,7 +467,7 @@ static void freeblock(lua_State* L, int sizeClass, void* block) // if it's the last block in the page, we don't need the page if (page->busyBlocks == 0) - freeclasspage(L, g->freepages, NULL, page, sizeClass); + freeclasspage(L, g->freepages, debugpageset(&g->allpages), page, sizeClass); } static void freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 858f61a3..dbc1dd10 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -204,6 +204,7 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->freepages[i] = NULL; g->freegcopages[i] = NULL; } + g->allpages = NULL; g->allgcopages = NULL; g->sweepgcopage = NULL; for (i = 0; i < LUA_T_COUNT; i++) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 97546511..21d7071c 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -188,7 +188,8 @@ typedef struct global_State struct lua_Page* freepages[LUA_SIZECLASSES]; // free page linked list for each size class for non-collectable objects struct lua_Page* freegcopages[LUA_SIZECLASSES]; // free page linked list for each size class 
for collectable objects - struct lua_Page* allgcopages; // page linked list with all pages for all classes + struct lua_Page* allpages; // page linked list with all pages for all non-collectable object classes (available with LUAU_ASSERTENABLED) + struct lua_Page* allgcopages; // page linked list with all pages for all collectable object classes struct lua_Page* sweepgcopage; // position of the sweep in `allgcopages' size_t memcatbytes[LUA_MEMORY_CATEGORIES]; // total amount of memory used by each memory category diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index a57d6cf7..545c1d2d 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -12,6 +12,7 @@ #include "lvm.h" LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastCrossTableMove, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastTableMaxn, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFasterConcat, false) static int foreachi(lua_State* L) @@ -55,17 +56,45 @@ static int maxn(lua_State* L) { double max = 0; luaL_checktype(L, 1, LUA_TTABLE); - lua_pushnil(L); // first key - while (lua_next(L, 1)) + + if (DFFlag::LuauFastTableMaxn) { - lua_pop(L, 1); // remove value - if (lua_type(L, -1) == LUA_TNUMBER) + Table* t = hvalue(L->base); + + for (int i = 0; i < t->sizearray; i++) { - double v = lua_tonumber(L, -1); - if (v > max) - max = v; + if (!ttisnil(&t->array[i])) + max = i + 1; + } + + for (int i = 0; i < sizenode(t); i++) + { + LuaNode* n = gnode(t, i); + + if (!ttisnil(gval(n)) && ttisnumber(gkey(n))) + { + double v = nvalue(gkey(n)); + + if (v > max) + max = v; + } } } + else + { + lua_pushnil(L); // first key + while (lua_next(L, 1)) + { + lua_pop(L, 1); // remove value + if (lua_type(L, -1) == LUA_TNUMBER) + { + double v = lua_tonumber(L, -1); + if (v > max) + max = v; + } + } + } + lua_pushnumber(L, max); return 1; } diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 479329b4..0c0d6378 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -18,8 +18,12 @@ 
LUAU_FASTFLAG(LuauCompileTypeInfo) LUAU_FASTFLAG(LuauLoadTypeInfo) LUAU_FASTFLAG(LuauCodegenTypeInfo) LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) +LUAU_FASTFLAG(LuauCodegenIrTypeNames) +LUAU_FASTFLAG(LuauCompileTempTypeInfo) +LUAU_FASTFLAG(LuauCodegenFixVectorFields) +LUAU_FASTFLAG(LuauCodegenVectorMispredictFix) -static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false) +static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) { Luau::CodeGen::AssemblyOptions options; @@ -47,7 +51,7 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = Luau::CompileOptions copts = {}; copts.optimizationLevel = 2; - copts.debugLevel = 1; + copts.debugLevel = debugLevel; copts.typeInfoLevel = 1; copts.vectorCtor = "vector"; copts.vectorType = "vector"; @@ -66,6 +70,20 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = return ""; } +static std::string getCodegenHeader(const char* source) +{ + std::string assembly = getCodegenAssembly(source, /* includeIrTypes */ true, /* debugLevel */ 2); + + auto bytecodeStart = assembly.find("bb_bytecode_0:"); + + if (bytecodeStart == std::string::npos) + bytecodeStart = assembly.find("bb_0:"); + + REQUIRE(bytecodeStart != std::string::npos); + + return assembly.substr(0, bytecodeStart); +} + TEST_SUITE_BEGIN("IrLowering"); TEST_CASE("VectorReciprocal") @@ -451,6 +469,50 @@ bb_bytecode_1: )"); } +TEST_CASE("VectorRandomProp") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenFixVectorFields{FFlag::LuauCodegenFixVectorFields, true}; + ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vector) + return a.XX + a.YY + a.ZZ +end +)"), + R"( +; function foo($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP 
bb_bytecode_1 +bb_bytecode_1: + FALLBACK_GETTABLEKS 0u, R3, R0, K0 + FALLBACK_GETTABLEKS 2u, R4, R0, K1 + CHECK_TAG R3, tnumber, bb_fallback_3 + CHECK_TAG R4, tnumber, bb_fallback_3 + %14 = LOAD_DOUBLE R3 + %16 = ADD_NUM %14, R4 + STORE_DOUBLE R2, %16 + STORE_TAG R2, tnumber + JUMP bb_4 +bb_4: + CHECK_TAG R0, tvector, exit(5) + FALLBACK_GETTABLEKS 5u, R3, R0, K2 + CHECK_TAG R2, tnumber, bb_fallback_5 + CHECK_TAG R3, tnumber, bb_fallback_5 + %30 = LOAD_DOUBLE R2 + %32 = ADD_NUM %30, R3 + STORE_DOUBLE R1, %32 + STORE_TAG R1, tnumber + JUMP bb_6 +bb_6: + INTERRUPT 8u + RETURN R1, 1i +)"); +} + TEST_CASE("UserDataGetIndex") { ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; @@ -860,4 +922,422 @@ bb_bytecode_0: )"); } +TEST_CASE("ResolveTablePathTypes") +{ + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, + {FFlag::LuauCompileTempTypeInfo, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(arr: {Vertex}, i) + local v = arr[i] + + return v.pos.Y +end +)", + /* includeIrTypes */ true, /* debugLevel */ 2), + R"( +; function foo(arr, i) line 4 +; R0: table [argument 'arr'] +; R2: table from 0 to 6 [local 'v'] +; R4: vector from 3 to 5 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + CHECK_TAG R1, tnumber, bb_fallback_3 + %8 = LOAD_POINTER R0 + %9 = LOAD_DOUBLE R1 + %10 = TRY_NUM_TO_INDEX %9, bb_fallback_3 + %11 = SUB_INT %10, 1i + CHECK_ARRAY_SIZE %8, %11, bb_fallback_3 + CHECK_NO_METATABLE %8, bb_fallback_3 + %14 = GET_ARR_ADDR %8, %11 + %15 = LOAD_TVALUE %14 + STORE_TVALUE R2, %15 + JUMP bb_4 +bb_4: + CHECK_TAG R2, ttable, exit(1) + %23 = LOAD_POINTER R2 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 + 
CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %26 = LOAD_TVALUE %24, 0i + STORE_TVALUE R4, %26 + JUMP bb_6 +bb_6: + CHECK_TAG R4, tvector, exit(3) + %33 = LOAD_FLOAT R4, 4i + STORE_DOUBLE R3, %33 + STORE_TAG R3, tnumber + INTERRUPT 5u + RETURN R3, 1i +)"); +} + +TEST_CASE("ResolvableSimpleMath") +{ + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, + {FFlag::LuauCompileTempTypeInfo, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +type Vertex = { p: vector, uv: vector, n: vector, t: vector, b: vector, h: number } +local mesh: { vertices: {Vertex}, indices: {number} } = ... + +local function compute() + for i = 1,#mesh.indices,3 do + local a = mesh.vertices[mesh.indices[i]] + local b = mesh.vertices[mesh.indices[i + 1]] + local c = mesh.vertices[mesh.indices[i + 2]] + + local vba = b.p - a.p + local vca = c.p - a.p + + local uvba = b.uv - a.uv + local uvca = c.uv - a.uv + + local r = 1.0 / (uvba.X * uvca.Y - uvca.X * uvba.Y); + + local sdir = (uvca.Y * vba - uvba.Y * vca) * r + + a.t += sdir + end +end +)"), + R"( +; function compute() line 5 +; U0: table ['mesh'] +; R2: number from 0 to 78 [local 'i'] +; R3: table from 7 to 78 [local 'a'] +; R4: table from 15 to 78 [local 'b'] +; R5: table from 24 to 78 [local 'c'] +; R6: vector from 33 to 78 [local 'vba'] +; R7: vector from 37 to 38 +; R7: vector from 38 to 78 [local 'vca'] +; R8: vector from 37 to 38 +; R8: vector from 42 to 43 +; R8: vector from 43 to 78 [local 'uvba'] +; R9: vector from 42 to 43 +; R9: vector from 47 to 48 +; R9: vector from 48 to 78 [local 'uvca'] +; R10: vector from 47 to 48 +; R10: vector from 52 to 53 +; R10: number from 53 to 78 [local 'r'] +; R11: vector from 52 to 53 +; R11: vector from 65 to 78 [local 'sdir'] +; R12: vector from 72 to 73 +; R12: vector from 75 to 76 +; R13: vector 
from 71 to 72 +; R14: vector from 71 to 72 +)"); +} + +TEST_CASE("ResolveVectorNamecalls") +{ + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, + {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(arr: {Vertex}, i) + return arr[i].normal:Dot(vector(0.707, 0, 0.707)) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 4 +; R0: table [argument] +; R2: vector from 4 to 6 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + CHECK_TAG R1, tnumber, bb_fallback_3 + %8 = LOAD_POINTER R0 + %9 = LOAD_DOUBLE R1 + %10 = TRY_NUM_TO_INDEX %9, bb_fallback_3 + %11 = SUB_INT %10, 1i + CHECK_ARRAY_SIZE %8, %11, bb_fallback_3 + CHECK_NO_METATABLE %8, bb_fallback_3 + %14 = GET_ARR_ADDR %8, %11 + %15 = LOAD_TVALUE %14 + STORE_TVALUE R3, %15 + JUMP bb_4 +bb_4: + CHECK_TAG R3, ttable, bb_fallback_5 + %23 = LOAD_POINTER R3 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 + CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %26 = LOAD_TVALUE %24, 0i + STORE_TVALUE R2, %26 + JUMP bb_6 +bb_6: + %31 = LOAD_TVALUE K1, 0i, tvector + STORE_TVALUE R4, %31 + CHECK_TAG R2, tvector, exit(4) + FALLBACK_NAMECALL 4u, R2, R2, K2 + INTERRUPT 6u + SET_SAVEDPC 7u + CALL R2, 2i, -1i + INTERRUPT 7u + RETURN R2, -1i +)"); +} + +TEST_CASE("ImmediateTypeAnnotationHelp") +{ + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, + {FFlag::LuauCompileTempTypeInfo, true}}; + + CHECK_EQ("\n" + 
getCodegenAssembly(R"( +local function foo(arr, i) + return (arr[i] :: vector) / 5 +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R3: vector from 1 to 2 +bb_bytecode_0: + CHECK_TAG R0, ttable, bb_fallback_1 + CHECK_TAG R1, tnumber, bb_fallback_1 + %4 = LOAD_POINTER R0 + %5 = LOAD_DOUBLE R1 + %6 = TRY_NUM_TO_INDEX %5, bb_fallback_1 + %7 = SUB_INT %6, 1i + CHECK_ARRAY_SIZE %4, %7, bb_fallback_1 + CHECK_NO_METATABLE %4, bb_fallback_1 + %10 = GET_ARR_ADDR %4, %7 + %11 = LOAD_TVALUE %10 + STORE_TVALUE R3, %11 + JUMP bb_2 +bb_2: + CHECK_TAG R3, tvector, exit(1) + %19 = LOAD_TVALUE R3 + %20 = NUM_TO_VEC 5 + %21 = DIV_VEC %19, %20 + %22 = TAG_VECTOR %21 + STORE_TVALUE R2, %22 + INTERRUPT 2u + RETURN R2, 1i +)"); +} + +TEST_CASE("UnaryTypeResolve") +{ + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, + {FFlag::LuauCompileTempTypeInfo, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +local function foo(a, b: vector, c) + local d = not a + local e = -b + local f = #c + return (if d then e else vector(f, 2, 3)).X +end +)"), + R"( +; function foo(a, b, c) line 2 +; R1: vector [argument 'b'] +; R3: boolean from 0 to 16 [local 'd'] +; R4: vector from 1 to 16 [local 'e'] +; R5: number from 2 to 16 [local 'f'] +; R7: vector from 13 to 15 +)"); +} + +TEST_CASE("ForInManualAnnotation") +{ + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, + {FFlag::LuauCompileTempTypeInfo, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {Vertex}) + local sum = 0 + for k, v: 
Vertex in ipairs(a) do + sum += v.pos.X + end + return sum +end +)", + /* includeIrTypes */ true, /* debugLevel */ 2), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 14 [local 'sum'] +; R5: number from 5 to 11 [local 'k'] +; R6: table from 5 to 11 [local 'v'] +; R8: vector from 8 to 10 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + JUMP bb_4 +bb_4: + JUMP bb_bytecode_1 +bb_bytecode_1: + STORE_DOUBLE R1, 0 + STORE_TAG R1, tnumber + CHECK_SAFE_ENV exit(1) + JUMP_EQ_TAG K1, tnil, bb_fallback_6, bb_5 +bb_5: + %9 = LOAD_TVALUE K1 + STORE_TVALUE R2, %9 + JUMP bb_7 +bb_7: + %15 = LOAD_TVALUE R0 + STORE_TVALUE R3, %15 + INTERRUPT 4u + SET_SAVEDPC 5u + CALL R2, 1i, 3i + CHECK_SAFE_ENV exit(5) + CHECK_TAG R3, ttable, bb_fallback_8 + CHECK_TAG R4, tnumber, bb_fallback_8 + JUMP_CMP_NUM R4, 0, not_eq, bb_fallback_8, bb_9 +bb_9: + STORE_TAG R2, tnil + STORE_POINTER R4, 0i + STORE_EXTRA R4, 128i + STORE_TAG R4, tlightuserdata + JUMP bb_bytecode_3 +bb_bytecode_2: + CHECK_TAG R6, ttable, exit(6) + %35 = LOAD_POINTER R6 + %36 = GET_SLOT_NODE_ADDR %35, 6u, K2 + CHECK_SLOT_MATCH %36, K2, bb_fallback_10 + %38 = LOAD_TVALUE %36, 0i + STORE_TVALUE R8, %38 + JUMP bb_11 +bb_11: + CHECK_TAG R8, tvector, exit(8) + %45 = LOAD_FLOAT R8, 0i + STORE_DOUBLE R7, %45 + STORE_TAG R7, tnumber + CHECK_TAG R1, tnumber, exit(10) + %52 = LOAD_DOUBLE R1 + %54 = ADD_NUM %52, %45 + STORE_DOUBLE R1, %54 + JUMP bb_bytecode_3 +bb_bytecode_3: + INTERRUPT 11u + CHECK_TAG R2, tnil, bb_fallback_13 + %60 = LOAD_POINTER R3 + %61 = LOAD_INT R4 + %62 = GET_ARR_ADDR %60, %61 + CHECK_ARRAY_SIZE %60, %61, bb_12 + %64 = LOAD_TAG %62 + JUMP_EQ_TAG %64, tnil, bb_12, bb_14 +bb_14: + %66 = ADD_INT %61, 1i + STORE_INT R4, %66 + %68 = INT_TO_NUM %66 + STORE_DOUBLE R5, %68 + STORE_TAG R5, tnumber + %71 = LOAD_TVALUE %62 + STORE_TVALUE R6, %71 + JUMP bb_bytecode_2 +bb_12: + INTERRUPT 13u + RETURN R1, 1i +)"); +} + +TEST_CASE("ForInAutoAnnotationIpairs") +{ + ScopedFastFlag 
sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, + {FFlag::LuauCompileTempTypeInfo, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {Vertex}) + local sum = 0 + for k, v in ipairs(a) do + local n = v.pos.X + sum += n + end + return sum +end +)"), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 14 [local 'sum'] +; R5: number from 5 to 11 [local 'k'] +; R6: table from 5 to 11 [local 'v'] +; R7: number from 6 to 11 [local 'n'] +; R8: vector from 8 to 10 +)"); +} + +TEST_CASE("ForInAutoAnnotationPairs") +{ + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, + {FFlag::LuauCompileTempTypeInfo, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {[string]: Vertex}) + local sum = 0 + for k, v in pairs(a) do + local n = v.pos.X + sum += n + end + return sum +end +)"), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 14 [local 'sum'] +; R5: string from 5 to 11 [local 'k'] +; R6: table from 5 to 11 [local 'v'] +; R7: number from 6 to 11 [local 'n'] +; R8: vector from 8 to 10 +)"); +} + +TEST_CASE("ForInAutoAnnotationGeneric") +{ + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, + {FFlag::LuauCompileTempTypeInfo, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +type 
Vertex = {pos: vector, normal: vector} + +local function foo(a: {Vertex}) + local sum = 0 + for k, v in a do + local n = v.pos.X + sum += n + end + return sum +end +)"), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 13 [local 'sum'] +; R5: number from 4 to 10 [local 'k'] +; R6: table from 4 to 10 [local 'v'] +; R7: number from 5 to 10 [local 'n'] +; R8: vector from 7 to 9 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 1a9ffd65..36398289 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -14,7 +14,7 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauFixNormalizeCaching) LUAU_FASTFLAG(LuauNormalizeNotUnknownIntersection) LUAU_FASTFLAG(LuauFixCyclicUnionsOfIntersections); - +LUAU_FASTINT(LuauTypeInferRecursionLimit) using namespace Luau; namespace @@ -962,4 +962,32 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersect_with_not_unknown") CHECK("never" == toString(normalizer.typeFromNormal(*normalized.get()))); } +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_stack_overflow_1") +{ + ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 165}; + this->unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + TypeId t1 = arena.addType(TableType{}); + TypeId t2 = arena.addType(TableType{}); + TypeId t3 = arena.addType(IntersectionType{{t1, t2}}); + asMutable(t1)->ty.get_if()->props = {{"foo", Property::readonly(t2)}}; + asMutable(t2)->ty.get_if()->props = {{"foo", Property::readonly(t1)}}; + + std::shared_ptr normalized = normalizer.normalize(t3); + CHECK(normalized); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_stack_overflow_2") +{ + ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 165}; + this->unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + TypeId t1 = arena.addType(TableType{}); + TypeId t2 = arena.addType(TableType{}); + TypeId t3 = arena.addType(IntersectionType{{t1, t2}}); + 
asMutable(t1)->ty.get_if()->props = {{"foo", Property::readonly(t3)}}; + asMutable(t2)->ty.get_if()->props = {{"foo", Property::readonly(t1)}}; + + std::shared_ptr normalized = normalizer.normalize(t3); + CHECK(normalized); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 6112bd02..8f8aef84 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -17,6 +17,39 @@ LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls); TEST_SUITE_BEGIN("TypeInferClasses"); +TEST_CASE_FIXTURE(ClassFixture, "Luau.Analyze.CLI_crashes_on_this_test") +{ + CheckResult result = check(R"( + local CircularQueue = {} +CircularQueue.__index = CircularQueue + +function CircularQueue:new() + local newCircularQueue = { + head = nil, + } + setmetatable(newCircularQueue, CircularQueue) + + return newCircularQueue +end + +function CircularQueue:push() + local newListNode + + if self.head then + newListNode = { + prevNode = self.head.prevNode, + nextNode = self.head, + } + newListNode.prevNode.nextNode = newListNode + newListNode.nextNode.prevNode = newListNode + end +end + +return CircularQueue + + )"); +} + TEST_CASE_FIXTURE(ClassFixture, "call_method_of_a_class") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 4cc07fba..4446bbc9 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4423,4 +4423,20 @@ TEST_CASE_FIXTURE(Fixture, "setindexer_multiple_tables_intersection") CHECK("({ [string]: number } & { [thread]: boolean }, boolean | number) -> ()" == toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop") +{ + CheckResult result = check(R"( + local function f(t) + local res = {} + + for k, a in t do + res[k] = f(a) + res[k] = a + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp 
b/tests/TypeInfer.test.cpp index 140f462a..9ea9539f 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -782,7 +782,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_heap_use_after_free_error") end )"); - LUAU_REQUIRE_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") @@ -1540,19 +1543,33 @@ TEST_CASE_FIXTURE(Fixture, "typeof_cannot_refine_builtin_alias") )"); } -/* - * We had an issue where we tripped the canMutate() check when binding one - * blocked type to another. - */ -TEST_CASE_FIXTURE(Fixture, "delay_setIndexer_constraint_if_the_indexers_type_is_blocked") +TEST_CASE_FIXTURE(BuiltinsFixture, "bad_iter_metamethod") { - (void) check(R"( - local SG = GetService(true) - local lines: { [string]: typeof(SG.ScreenGui) } = {} - lines[deadline] = nil -- This line + CheckResult result = check(R"( + function iter(): unknown + return nil + end + + local a = {__iter = iter} + setmetatable(a, a) + + for i in a do + end )"); - // As long as type inference doesn't trip an assert or crash, we're good! 
+ if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CannotCallNonFunction* ccnf = get(result.errors[0]); + REQUIRE(ccnf); + + CHECK("unknown" == toString(ccnf->ty)); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_SUITE_END(); diff --git a/tests/TypeInfer.typestates.test.cpp b/tests/TypeInfer.typestates.test.cpp index dbb9815d..3116022b 100644 --- a/tests/TypeInfer.typestates.test.cpp +++ b/tests/TypeInfer.typestates.test.cpp @@ -490,5 +490,34 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typestates_do_not_apply_to_the_initial_local CHECK("number" == toString(requireTypeAtPosition({5, 14}), {true})); } +TEST_CASE_FIXTURE(Fixture, "typestate_globals") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + loadDefinition(R"( + declare foo: string | number + declare function f(x: string): () + )"); + + CheckResult result = check(R"( + foo = "a" + f(foo) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "typestate_unknown_global") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + CheckResult result = check(R"( + x = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0])); +} TEST_SUITE_END(); diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 75163fd1..3f1efd8e 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -306,10 +306,14 @@ end assert(table.maxn{} == 0) +assert(table.maxn{[-100] = 1} == 0) assert(table.maxn{["1000"] = true} == 0) assert(table.maxn{["1000"] = true, [24.5] = 3} == 24.5) assert(table.maxn{[1000] = true} == 1000) assert(table.maxn{[10] = true, [100*math.pi] = print} == 100*math.pi) +a = {[10] = 1, [20] = 2} +a[20] = nil +assert(table.maxn(a) == 10) -- int overflow diff --git a/tools/faillist.txt b/tools/faillist.txt index 469e3a84..db3eeba5 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -37,7 +37,6 @@ 
 Differ.metatable_metamissing_left
 Differ.metatable_metamissing_right
 Differ.metatable_metanormal
 Differ.negation
-FrontendTest.accumulate_cached_errors_in_consistent_order
 FrontendTest.environments
 FrontendTest.imported_table_modification_2
 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded
@@ -182,6 +181,7 @@ TableTests.infer_array
 TableTests.infer_indexer_from_array_like_table
 TableTests.infer_indexer_from_its_variable_type_and_unifiable
 TableTests.inferred_return_type_of_free_table
+TableTests.insert_a_and_f_of_a_into_table_res_in_a_loop
 TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound
 TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound
 TableTests.length_operator_union
@@ -269,7 +269,6 @@ TypeInfer.dont_report_type_errors_within_an_AstExprError
 TypeInfer.dont_report_type_errors_within_an_AstStatError
 TypeInfer.globals
 TypeInfer.globals2
-TypeInfer.globals_are_banned_in_strict_mode
 TypeInfer.infer_through_group_expr
 TypeInfer.no_stack_overflow_from_isoptional
 TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter
@@ -366,7 +365,6 @@ TypeInferLoops.properly_infer_iteratee_is_a_free_table
 TypeInferLoops.repeat_loop
 TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free
 TypeInferLoops.while_loop
-TypeInferModules.custom_require_global
 TypeInferModules.do_not_modify_imported_types_5
 TypeInferModules.require
 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2
diff --git a/tools/stackdbg.py b/tools/stackdbg.py
new file mode 100644
index 00000000..de656c60
--- /dev/null
+++ b/tools/stackdbg.py
@@ -0,0 +1,94 @@
+#!/usr/bin/python3
+"""
+To use this command, simply run the command:
+`command script import /path/to/your/game-engine/Client/Luau/tools/stackdbg.py`
+in the `lldb` interpreter. You can also add it to your .lldbinit file to have it be
+automatically imported.
+ +If using vscode, you can add the above command to your launch.json under `preRunCommands` for the appropriate target. For example: +{ + "name": "Luau.UnitTest", + "type": "lldb", + "request": "launch", + "program": "${workspaceFolder}/build/ninja/common-tests/noopt/Luau/Luau.UnitTest", + "preRunCommands": [ + "command script import ${workspaceFolder}/Client/Luau/tools/stackdbg.py" + ], +} + +Once this is loaded, +`(lldb) help stack` +or +`(lldb) stack -h +or +`(lldb) stack --help + +can get you started +""" + +import lldb +import functools +import argparse +import shlex + +# Dumps the collected frame data +def dump(collected): + for (frame_name, size_in_kb, live_size_kb, variables) in collected: + print(f'{frame_name}, locals: {size_in_kb}kb, fp-sp: {live_size_kb}kb') + for (var_name, var_size, variable_obj) in variables: + print(f' {var_name}, {var_size} bytes') + +def dbg_stack_pressure(frame, frames_to_show = 5, sort_frames = False, vars_to_show = 5, sort_vars = True): + totalKb = 0 + collect = [] + for f in frame.thread: + frame_name = f.GetFunctionName() + variables = [ (v.GetName(), v.GetByteSize(), v) for v in f.get_locals() ] + if sort_vars: + variables.sort(key = lambda x: x[1], reverse = True) + size_in_kb = functools.reduce(lambda x,y : x + y[1], variables, 0) / 1024 + + fp = f.GetFP() + sp = f.GetSP() + live_size_kb = round((fp - sp) / 1024, 2) + + size_in_kb = round(size_in_kb, 2) + totalKb += size_in_kb + collect.append((frame_name, size_in_kb, live_size_kb, variables[:vars_to_show])) + if sort_frames: + collect.sort(key = lambda x: x[1], reverse = True) + + print("******************** Report Stack Usage ********************") + totalMb = round(totalKb / 1024, 2) + print(f'{len(frame.thread)} stack frames used {totalMb}MB') + dump(collect[:frames_to_show]) + +def stack(debugger, command, result, internal_dict): + """ + usage: [-h] [-f FRAMES] [-fd] [-v VARS] [-vd] + + optional arguments: + -h, --help show this help message and exit + -f FRAMES, 
--frames FRAMES + How many stack frames to display + -fd, --sort_frames Sort frames + -v VARS, --vars VARS How many variables per frame to display + -vd, --sort_vars Sort frames + """ + + frame = debugger.GetSelectedTarget().GetProcess().GetSelectedThread().GetSelectedFrame() + args = shlex.split(command) + argparser = argparse.ArgumentParser(allow_abbrev = True) + argparser.add_argument("-f", "--frames", required=False, help="How many stack frames to display", default=5, type=int) + argparser.add_argument("-fd", "--sort_frames", required=False, help="Sort frames in descending order of stack usage", action="store_true", default=False) + argparser.add_argument("-v", "--vars", required=False, help="How many variables per frame to display", default=5, type=int) + argparser.add_argument("-vd", "--sort_vars", required=False, help="Sort locals in descending order of stack usage ", action="store_true", default=False) + + args = argparser.parse_args(args) + dbg_stack_pressure(frame, frames_to_show=args.frames, sort_frames=args.sort_frames, vars_to_show=args.vars, sort_vars=args.sort_vars) + +# Initialization code to add commands +def __lldb_init_module(debugger, internal_dict): + debugger.HandleCommand('command script add -f stackdbg.stack stack') + print("The 'stack' python command has been installed and is ready for use.") + From 905a37b9283290f101d54f50d5d06cf6d73fe9ab Mon Sep 17 00:00:00 2001 From: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com> Date: Wed, 8 May 2024 13:35:12 -0700 Subject: [PATCH 05/20] Update native code generation note in the security guarantees (#1250) --- SECURITY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SECURITY.md b/SECURITY.md index 48a6ccc4..ca3f5923 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,6 +1,6 @@ # Security Guarantees -Luau provides a safe sandbox that scripts can not escape from, short of vulnerabilities in custom C functions exposed by the host. 
This includes the virtual machine and builtin libraries. Notably this currently does *not* include the work-in-progress native code generation facilities. +Luau provides a safe sandbox that scripts can not escape from, short of vulnerabilities in custom C functions exposed by the host. This includes the virtual machine, builtin libraries and native code generation facilities. Any source code can not result in memory safety errors or crashes during its compilation or execution. Violations of memory safety are considered vulnerabilities. From a775bbc6fc9f13ad1e353b2acabcc4965e74885a Mon Sep 17 00:00:00 2001 From: Bjorn Date: Fri, 10 May 2024 03:36:37 -0700 Subject: [PATCH 06/20] Fix confusing warning when CMake version is too low (#1251) Experienced brief moment of panic when this warning said I had clang 3 installed. --- fuzz/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fuzz/CMakeLists.txt b/fuzz/CMakeLists.txt index c18fbba5..be40b811 100644 --- a/fuzz/CMakeLists.txt +++ b/fuzz/CMakeLists.txt @@ -1,6 +1,6 @@ # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details if(${CMAKE_VERSION} VERSION_LESS "3.26") - message(WARNING "Building the Luau fuzzer requires Clang version 3.26 of higher.") + message(WARNING "Building the Luau fuzzer requires CMake version 3.26 or higher.") return() endif() From 2a80f5e1d1311390cd5f69beca628c007fe35d7e Mon Sep 17 00:00:00 2001 From: Vighnesh-V Date: Fri, 10 May 2024 11:21:45 -0700 Subject: [PATCH 07/20] Sync to upstream/release/625 (#1252) # What's changed? 
* Fix warning issued when Cmake version is too low (contributed by OSS community) ## New Solver * Fix an issue with inhabitance testing of tables with cyclic properties * Preserve error suppression during type unification * Overhaul type reference counting in the constraint solver * Other miscellaneous constraint ordering fixes ## Native Codegen * Fix incorrect assertion check in loadBytecodeTypeInfo --- ## Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --------- Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Aviral Goel Co-authored-by: David Cope Co-authored-by: Lily Brown Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Constraint.h | 4 +- Analysis/include/Luau/ConstraintSolver.h | 18 ++ Analysis/include/Luau/Normalize.h | 13 +- Analysis/include/Luau/Set.h | 8 +- Analysis/include/Luau/Unifier2.h | 5 + Analysis/src/Constraint.cpp | 102 +++++++- Analysis/src/ConstraintSolver.cpp | 77 +++++- Analysis/src/Frontend.cpp | 48 ++-- Analysis/src/Normalize.cpp | 193 ++++++--------- Analysis/src/Set.cpp | 5 - Analysis/src/ToString.cpp | 1 - Analysis/src/TypeFamily.cpp | 154 +++++++++++- Analysis/src/TypeInfer.cpp | 23 +- Analysis/src/Unifier.cpp | 167 ++++--------- Analysis/src/Unifier2.cpp | 147 +++++++++-- Ast/src/Parser.cpp | 26 +- CLI/Repl.cpp | 10 +- CodeGen/include/Luau/BytecodeAnalysis.h | 3 +- CodeGen/include/Luau/CodeGen.h | 42 +++- CodeGen/include/Luau/IrBuilder.h | 10 +- CodeGen/src/BytecodeAnalysis.cpp | 49 +++- CodeGen/src/CodeGen.cpp | 42 +++- CodeGen/src/CodeGenAssembly.cpp | 6 +- CodeGen/src/CodeGenContext.cpp | 22 +- CodeGen/src/CodeGenContext.h | 4 +- CodeGen/src/IrBuilder.cpp | 26 +- CodeGen/src/IrTranslation.cpp | 30 ++- CodeGen/src/IrTranslation.h | 2 +- CodeGen/src/lcodegen.cpp | 3 +- Sources.cmake | 2 +- tests/Autocomplete.test.cpp | 6 +- 
tests/Conformance.test.cpp | 86 ++++++- tests/ConformanceIrHooks.h | 151 ++++++++++++ tests/IrBuilder.test.cpp | 6 + tests/IrLowering.test.cpp | 299 ++++++++++++++++++++++- tests/Normalize.test.cpp | 2 - tests/Parser.test.cpp | 3 - tests/Set.test.cpp | 4 - tests/SharedCodeAllocator.test.cpp | 7 +- tests/ToString.test.cpp | 74 ++++-- tests/TypeInfer.functions.test.cpp | 18 +- tests/TypeInfer.loops.test.cpp | 25 +- tests/TypeInfer.tables.test.cpp | 37 ++- tests/conformance/vector.lua | 29 +++ tools/faillist.txt | 14 +- 45 files changed, 1536 insertions(+), 467 deletions(-) delete mode 100644 Analysis/src/Set.cpp create mode 100644 tests/ConformanceIrHooks.h diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index d52ae6e0..ec281ae3 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -284,11 +284,13 @@ struct Constraint std::vector> dependencies; - DenseHashSet getFreeTypes() const; + DenseHashSet getMaybeMutatedFreeTypes() const; }; using ConstraintPtr = std::unique_ptr; +bool isReferenceCountedType(const TypeId typ); + inline Constraint& asMutable(const Constraint& c) { return const_cast(c); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index bb1fe2d8..031da67b 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -242,6 +242,24 @@ struct ConstraintSolver void reportError(TypeErrorData&& data, const Location& location); void reportError(TypeError e); + /** + * Shifts the count of references from `source` to `target`. This should be paired + * with any instance of binding a free type in order to maintain accurate refcounts. + * If `target` is not a free type, this is a noop. 
+ * @param source the free type which is being bound + * @param target the type which the free type is being bound to + */ + void shiftReferences(TypeId source, TypeId target); + + /** + * Generalizes the given free type if the reference counting allows it. + * @param the scope to generalize in + * @param type the free type we want to generalize + * @returns a non-free type that generalizes the argument, or `std::nullopt` if one + * does not exist + */ + std::optional generalizeFreeType(NotNull scope, TypeId type); + /** * Checks the existing set of constraints to see if there exist any that contain * the provided free type, indicating that it is not yet ready to be replaced by diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 6d75568e..b21e470c 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -307,6 +307,9 @@ struct NormalizedType /// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString() bool isSubtypeOfString() const; + /// Returns true if the type is a subtype of boolean(it could be a singleton). Behaves like Type::isBoolean() + bool isSubtypeOfBooleans() const; + /// Returns true if this type should result in error suppressing behavior. 
bool shouldSuppressErrors() const; @@ -360,7 +363,6 @@ public: Normalizer& operator=(Normalizer&) = delete; // If this returns null, the typechecker should emit a "too complex" error - const NormalizedType* DEPRECATED_normalize(TypeId ty); std::shared_ptr normalize(TypeId ty); void clearNormal(NormalizedType& norm); @@ -395,7 +397,7 @@ public: TypeId negate(TypeId there); void subtractPrimitive(NormalizedType& here, TypeId ty); void subtractSingleton(NormalizedType& here, TypeId ty); - NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect, bool useDeprecated = false); + NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect); // ------- Normalizing intersections TypeId intersectionOfTops(TypeId here, TypeId there); @@ -404,8 +406,8 @@ public: void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); - std::optional intersectionOfTables(TypeId here, TypeId there); - void intersectTablesWithTable(TypeIds& heres, TypeId there); + std::optional intersectionOfTables(TypeId here, TypeId there, Set& seenSet); + void intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes); void intersectTables(TypeIds& heres, const TypeIds& theres); std::optional intersectionOfFunctions(TypeId here, TypeId there); void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); @@ -413,7 +415,7 @@ public: NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set& seenSetTypes); NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes); - NormalizationResult normalizeIntersections(const std::vector& intersections, 
NormalizedType& outType); + NormalizationResult normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet); // Check for inhabitance NormalizationResult isInhabited(TypeId ty); @@ -423,6 +425,7 @@ public: // Check for intersections being inhabited NormalizationResult isIntersectionInhabited(TypeId left, TypeId right); + NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet); // -------- Convert back from a normalized type to a type TypeId typeFromNormal(const NormalizedType& norm); diff --git a/Analysis/include/Luau/Set.h b/Analysis/include/Luau/Set.h index 2fea2e6a..274375cf 100644 --- a/Analysis/include/Luau/Set.h +++ b/Analysis/include/Luau/Set.h @@ -4,7 +4,6 @@ #include "Luau/Common.h" #include "Luau/DenseHash.h" -LUAU_FASTFLAG(LuauFixSetIter) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau @@ -143,11 +142,8 @@ public: : impl(impl_) , end(end_) { - if (FFlag::LuauFixSetIter || FFlag::DebugLuauDeferredConstraintResolution) - { - while (impl != end && impl->second == false) - ++impl; - } + while (impl != end && impl->second == false) + ++impl; } const T& operator*() const diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index a7d64312..130c0c3c 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -78,6 +78,11 @@ struct Unifier2 bool unify(TableType* subTable, const TableType* superTable); bool unify(const MetatableType* subMetatable, const MetatableType* superMetatable); + bool unify(const AnyType* subAny, const FunctionType* superFn); + bool unify(const FunctionType* subFn, const AnyType* superAny); + bool unify(const AnyType* subAny, const TableType* superTable); + bool unify(const TableType* subTable, const AnyType* superAny); + // TODO think about this one carefully. 
We don't do unions or intersections of type packs bool unify(TypePackId subTp, TypePackId superTp); diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 4f35b58f..4d1c35e0 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -13,12 +13,12 @@ Constraint::Constraint(NotNull scope, const Location& location, Constrain { } -struct FreeTypeCollector : TypeOnceVisitor +struct ReferenceCountInitializer : TypeOnceVisitor { DenseHashSet* result; - FreeTypeCollector(DenseHashSet* result) + ReferenceCountInitializer(DenseHashSet* result) : result(result) { } @@ -29,6 +29,18 @@ struct FreeTypeCollector : TypeOnceVisitor return false; } + bool visit(TypeId ty, const BlockedType&) override + { + result->insert(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + result->insert(ty); + return false; + } + bool visit(TypeId ty, const ClassType&) override { // ClassTypes never contain free types. @@ -36,26 +48,92 @@ struct FreeTypeCollector : TypeOnceVisitor } }; -DenseHashSet Constraint::getFreeTypes() const +bool isReferenceCountedType(const TypeId typ) +{ + // n.b. this should match whatever `ReferenceCountInitializer` includes. + return get(typ) || get(typ) || get(typ); +} + +DenseHashSet Constraint::getMaybeMutatedFreeTypes() const { DenseHashSet types{{}}; - FreeTypeCollector ftc{&types}; + ReferenceCountInitializer rci{&types}; - if (auto sc = get(*this)) + if (auto ec = get(*this)) { - ftc.traverse(sc->subType); - ftc.traverse(sc->superType); + rci.traverse(ec->resultType); + // `EqualityConstraints` should not mutate `assignmentType`. 
+ } + else if (auto sc = get(*this)) + { + rci.traverse(sc->subType); + rci.traverse(sc->superType); } else if (auto psc = get(*this)) { - ftc.traverse(psc->subPack); - ftc.traverse(psc->superPack); + rci.traverse(psc->subPack); + rci.traverse(psc->superPack); + } + else if (auto gc = get(*this)) + { + rci.traverse(gc->generalizedType); + // `GeneralizationConstraints` should not mutate `sourceType` or `interiorTypes`. + } + else if (auto itc = get(*this)) + { + rci.traverse(itc->variables); + // `IterableConstraints` should not mutate `iterator`. + } + else if (auto nc = get(*this)) + { + rci.traverse(nc->namedType); + } + else if (auto taec = get(*this)) + { + rci.traverse(taec->target); } else if (auto ptc = get(*this)) { - // we need to take into account primitive type constraints to prevent type families from reducing on - // primitive whose types we have not yet selected to be singleton or not. - ftc.traverse(ptc->freeType); + rci.traverse(ptc->freeType); + } + else if (auto hpc = get(*this)) + { + rci.traverse(hpc->resultType); + // `HasPropConstraints` should not mutate `subjectType`. + } + else if (auto spc = get(*this)) + { + rci.traverse(spc->resultType); + // `SetPropConstraints` should not mutate `subjectType` or `propType`. + // TODO: is this true? it "unifies" with `propType`, so maybe mutates that one too? + } + else if (auto hic = get(*this)) + { + rci.traverse(hic->resultType); + // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`. + } + else if (auto sic = get(*this)) + { + rci.traverse(sic->propType); + // `SetIndexerConstraints` should not mutate `subjectType` or `indexType`. + } + else if (auto uc = get(*this)) + { + rci.traverse(uc->resultPack); + // `UnpackConstraint` should not mutate `sourcePack`. + } + else if (auto u1c = get(*this)) + { + rci.traverse(u1c->resultType); + // `Unpack1Constraint` should not mutate `sourceType`. 
+ } + else if (auto rc = get(*this)) + { + rci.traverse(rc->ty); + } + else if (auto rpc = get(*this)) + { + rci.traverse(rpc->tp); } return types; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index ff56a37d..cdb13b4a 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -27,6 +27,7 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings, false); LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500); @@ -251,6 +252,15 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) auto it = cs->blockedConstraints.find(c); int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); printf("\t%d\t%s\n", blockCount, toString(*c, opts).c_str()); + + if (FFlag::DebugLuauLogSolverIncludeDependencies) + { + for (NotNull dep : c->dependencies) + { + if (std::find(cs->unsolvedConstraints.begin(), cs->unsolvedConstraints.end(), dep) != cs->unsolvedConstraints.end()) + printf("\t\t|\t%s\n", toString(*dep, opts).c_str()); + } + } } } @@ -305,7 +315,7 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNullgetFreeTypes()) + for (auto ty : c->getMaybeMutatedFreeTypes()) { // increment the reference count for `ty` auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); @@ -394,7 +404,7 @@ void ConstraintSolver::run() unsolvedConstraints.erase(unsolvedConstraints.begin() + i); // decrement the referenced free types for this constraint if we dispatched successfully! - for (auto ty : c->getFreeTypes()) + for (auto ty : c->getMaybeMutatedFreeTypes()) { // this is a little weird, but because we're only counting free types in subtyping constraints, // some constraints (like unpack) might actually produce _more_ references to a free type. 
@@ -720,8 +730,6 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull(c.target)); + shiftReferences(c.target, result); emplaceType(asMutable(c.target), result); unblock(c.target, constraint->location); }; @@ -1190,6 +1199,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullargs.data[j]->annotation && get(follow(lambdaArgTys[j]))) { + shiftReferences(lambdaArgTys[j], expectedLambdaArgTys[j]); emplaceType(asMutable(lambdaArgTys[j]), expectedLambdaArgTys[j]); } } @@ -1242,6 +1252,7 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNulllowerBound; + shiftReferences(c.freeType, bindTo); emplaceType(asMutable(c.freeType), bindTo); return true; @@ -1551,7 +1562,11 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (0 == results.size()) emplaceType(asMutable(resultType), builtinTypes->errorType); else if (1 == results.size()) - emplaceType(asMutable(resultType), *results.begin()); + { + TypeId firstResult = *results.begin(); + shiftReferences(resultType, firstResult); + emplaceType(asMutable(resultType), firstResult); + } else emplaceType(asMutable(resultType), std::vector(results.begin(), results.end())); @@ -1716,7 +1731,10 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, --lt->blockCount; if (0 == lt->blockCount) + { + shiftReferences(ty, lt->domain); emplaceType(asMutable(ty), lt->domain); + } }; if (auto ut = get(resultTy)) @@ -1732,6 +1750,7 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, // constitute any meaningful constraint, so we replace it // with a free type. 
TypeId f = freshType(arena, builtinTypes, constraint->scope); + shiftReferences(resultTy, f); emplaceType(asMutable(resultTy), f); } else @@ -1798,7 +1817,10 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNullblockCount; if (0 == lt->blockCount) + { + shiftReferences(resultTy, lt->domain); emplaceType(asMutable(resultTy), lt->domain); + } } else if (get(resultTy) || get(resultTy)) { @@ -1977,7 +1999,10 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl LUAU_ASSERT(0 <= lt->blockCount); if (0 == lt->blockCount) + { + shiftReferences(ty, lt->domain); emplaceType(asMutable(ty), lt->domain); + } } } } @@ -2395,10 +2420,15 @@ void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId LUAU_ASSERT(freeScope); - emplaceType(asMutable(blockedTy), arena->freshType(freeScope)); + TypeId freeType = arena->freshType(freeScope); + shiftReferences(blockedTy, freeType); + emplaceType(asMutable(blockedTy), freeType); } else + { + shiftReferences(blockedTy, resultTy); emplaceType(asMutable(blockedTy), resultTy); + } } bool ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) @@ -2700,10 +2730,43 @@ void ConstraintSolver::reportError(TypeError e) errors.back().moduleName = currentModuleName; } +void ConstraintSolver::shiftReferences(TypeId source, TypeId target) +{ + target = follow(target); + + // if the target isn't a reference counted type, there's nothing to do. + // this stops us from keeping unnecessary counts for e.g. primitive types. + if (!isReferenceCountedType(target)) + return; + + auto sourceRefs = unresolvedConstraints.find(source); + if (!sourceRefs) + return; + + // we read out the count before proceeding to avoid hash invalidation issues. 
+ size_t count = *sourceRefs; + + auto [targetRefs, _] = unresolvedConstraints.try_insert(target, 0); + targetRefs += count; +} + +std::optional ConstraintSolver::generalizeFreeType(NotNull scope, TypeId type) +{ + if (get(type)) + { + auto refCount = unresolvedConstraints.find(type); + if (!refCount || *refCount > 1) + return {}; + } + + Unifier2 u2{NotNull{arena}, builtinTypes, scope, NotNull{&iceReporter}}; + return u2.generalize(type); +} + bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) { if (auto refCount = unresolvedConstraints.find(ty)) - return *refCount > 0; + return *refCount > 1; return false; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 55cff7f6..5261c211 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1297,6 +1297,30 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vectortype = sourceModule.type; result->upperBoundContributors = std::move(cs.upperBoundContributors); + if (result->timeout || result->cancelled) + { + // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending + // types + ScopePtr moduleScope = result->getModuleScope(); + moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); + + for (auto& [name, ty] : result->declaredGlobals) + ty = builtinTypes->errorRecoveryType(); + + for (auto& [name, tf] : result->exportedTypeBindings) + tf.type = builtinTypes->errorRecoveryType(); + } + else + { + if (mode == Mode::Nonstrict) + Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get()); + else + Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get()); + } + + unfreeze(result->interfaceTypes); + result->clonePublicInterface(builtinTypes, *iceHandler); + if (FFlag::DebugLuauForbidInternalTypes) { InternalTypeFinder finder; @@ -1325,30 +1349,6 
@@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vectortimeout || result->cancelled) - { - // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending - // types - ScopePtr moduleScope = result->getModuleScope(); - moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); - - for (auto& [name, ty] : result->declaredGlobals) - ty = builtinTypes->errorRecoveryType(); - - for (auto& [name, tf] : result->exportedTypeBindings) - tf.type = builtinTypes->errorRecoveryType(); - } - else - { - if (mode == Mode::Nonstrict) - Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get()); - else - Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get()); - } - - unfreeze(result->interfaceTypes); - result->clonePublicInterface(builtinTypes, *iceHandler); - // It would be nice if we could freeze the arenas before doing type // checking, but we'll have to do some work to get there. // diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index a124be66..3c63a7fd 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -17,21 +17,16 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false) -LUAU_FASTFLAGVARIABLE(LuauFixNormalizeCaching, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false); LUAU_FASTFLAGVARIABLE(LuauFixCyclicUnionsOfIntersections, false); LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false); +LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false); // This could theoretically be 2000 on amd64, but x86 requires this. 
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -static bool fixNormalizeCaching() -{ - return FFlag::LuauFixNormalizeCaching || FFlag::DebugLuauDeferredConstraintResolution; -} - static bool fixCyclicUnionsOfIntersections() { return FFlag::LuauFixCyclicUnionsOfIntersections || FFlag::DebugLuauDeferredConstraintResolution; @@ -42,6 +37,11 @@ static bool fixReduceStackPressure() return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution; } +static bool fixCyclicTablesBlowingStack() +{ + return FFlag::LuauFixCyclicTablesBlowingStack || FFlag::DebugLuauDeferredConstraintResolution; +} + namespace Luau { @@ -353,6 +353,12 @@ bool NormalizedType::isSubtypeOfString() const !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars(); } +bool NormalizedType::isSubtypeOfBooleans() const +{ + return hasBooleans() && !hasTops() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasStrings() && !hasThreads() && + !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars(); +} + bool NormalizedType::shouldSuppressErrors() const { return hasErrors() || get(tops); @@ -561,22 +567,21 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set& seen) return isInhabited(mtv->metatable, seen); } - if (fixNormalizeCaching()) - { - std::shared_ptr norm = normalize(ty); - return isInhabited(norm.get(), seen); - } - else - { - const NormalizedType* norm = DEPRECATED_normalize(ty); - return isInhabited(norm, seen); - } + std::shared_ptr norm = normalize(ty); + return isInhabited(norm.get(), seen); } NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right) +{ + Set seen{nullptr}; + return isIntersectionInhabited(left, right, seen); +} + +NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet) { left = follow(left); right = follow(right); + // We're 
asking if intersection is inahbited between left and right but we've already seen them .... if (cacheInhabitance) { @@ -584,12 +589,8 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ return *result ? NormalizationResult::True : NormalizationResult::False; } - Set seen{nullptr}; - seen.insert(left); - seen.insert(right); - NormalizedType norm{builtinTypes}; - NormalizationResult res = normalizeIntersections({left, right}, norm); + NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet); if (res != NormalizationResult::True) { if (cacheInhabitance && res == NormalizationResult::False) @@ -598,7 +599,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ return res; } - NormalizationResult result = isInhabited(&norm, seen); + NormalizationResult result = isInhabited(&norm, seenSet); if (cacheInhabitance && result == NormalizationResult::True) cachedIsInhabitedIntersection[{left, right}] = true; @@ -870,31 +871,6 @@ Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, Not { } -const NormalizedType* Normalizer::DEPRECATED_normalize(TypeId ty) -{ - if (!arena) - sharedState->iceHandler->ice("Normalizing types outside a module"); - - auto found = cachedNormals.find(ty); - if (found != cachedNormals.end()) - return found->second.get(); - - NormalizedType norm{builtinTypes}; - Set seenSetTypes{nullptr}; - NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes); - if (res != NormalizationResult::True) - return nullptr; - if (norm.isUnknown()) - { - clearNormal(norm); - norm.tops = builtinTypes->unknownType; - } - std::shared_ptr shared = std::make_shared(std::move(norm)); - const NormalizedType* result = shared.get(); - cachedNormals[ty] = std::move(shared); - return result; -} - static bool isCacheable(TypeId ty, Set& seen); static bool isCacheable(TypePackId tp, Set& seen) @@ -949,9 +925,6 @@ static bool isCacheable(TypeId ty, Set& seen) static bool 
isCacheable(TypeId ty) { - if (!fixNormalizeCaching()) - return true; - Set seen{nullptr}; return isCacheable(ty, seen); } @@ -985,7 +958,7 @@ std::shared_ptr Normalizer::normalize(TypeId ty) return shared; } -NormalizationResult Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType) +NormalizationResult Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet) { if (!arena) sharedState->iceHandler->ice("Normalizing types outside a module"); @@ -995,7 +968,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector Set seenSetTypes{nullptr}; for (auto ty : intersections) { - NormalizationResult res = intersectNormalWithTy(norm, ty, seenSetTypes); + NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet); if (res != NormalizationResult::True) return res; } @@ -1743,20 +1716,13 @@ bool Normalizer::withinResourceLimits() return true; } -NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect, bool useDeprecated) +NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect) { std::optional negated; - if (useDeprecated) - { - const NormalizedType* normal = DEPRECATED_normalize(toNegate); - negated = negateNormal(*normal); - } - else - { - std::shared_ptr normal = normalize(toNegate); - negated = negateNormal(*normal); - } + + std::shared_ptr normal = normalize(toNegate); + negated = negateNormal(*normal); if (!negated) return NormalizationResult::False; @@ -1911,16 +1877,8 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t { std::optional tn; - if (fixNormalizeCaching()) - { - std::shared_ptr thereNormal = normalize(ntv->ty); - tn = negateNormal(*thereNormal); - } - else - { - const NormalizedType* thereNormal = DEPRECATED_normalize(ntv->ty); - tn = negateNormal(*thereNormal); - } + std::shared_ptr thereNormal = 
normalize(ntv->ty); + tn = negateNormal(*thereNormal); if (!tn) return NormalizationResult::False; @@ -2519,7 +2477,7 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T return arena->addTypePack({}); } -std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there) +std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there, Set& seenSet) { if (here == there) return here; @@ -2600,8 +2558,33 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there // if the intersection of the read types of a property is uninhabited, the whole table is `never`. if (fixReduceStackPressure()) { - if (normalizeAwayUninhabitableTables() && - NormalizationResult::True != isIntersectionInhabited(*hprop.readTy, *tprop.readTy)) + // We've seen these table prop elements before and we're about to ask if their intersection + // is inhabited + if (fixCyclicTablesBlowingStack()) + { + if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) + { + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); + return {builtinTypes->neverType}; + } + else + { + seenSet.insert(*hprop.readTy); + seenSet.insert(*tprop.readTy); + } + } + + NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenSet); + + // Cleanup + if (fixCyclicTablesBlowingStack()) + { + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); + } + + if (normalizeAwayUninhabitableTables() && NormalizationResult::True != res) return {builtinTypes->neverType}; } else @@ -2720,7 +2703,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (tmtable && hmtable) { // NOTE: this assumes metatables are ivariant - if (std::optional mtable = intersectionOfTables(hmtable, tmtable)) + if (std::optional mtable = intersectionOfTables(hmtable, tmtable, seenSet)) { if (table == htable && *mtable == hmtable) return here; @@ -2750,12 +2733,12 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId 
there return table; } -void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there) +void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes) { TypeIds tmp; for (TypeId here : heres) { - if (std::optional inter = intersectionOfTables(here, there)) + if (std::optional inter = intersectionOfTables(here, there, seenSetTypes)) tmp.insert(*inter); } heres.retain(tmp); @@ -2769,7 +2752,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres) { for (TypeId there : theres) { - if (std::optional inter = intersectionOfTables(here, there)) + Set seenSetTypes{nullptr}; + if (std::optional inter = intersectionOfTables(here, there, seenSetTypes)) tmp.insert(*inter); } } @@ -3137,7 +3121,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { TypeIds tables = std::move(here.tables); clearNormal(here); - intersectTablesWithTable(tables, there); + intersectTablesWithTable(tables, there, seenSetTypes); here.tables = std::move(tables); } else if (get(there)) @@ -3211,50 +3195,17 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type subtractSingleton(here, follow(ntv->ty)); else if (get(t)) { - if (fixNormalizeCaching()) - { - NormalizationResult res = intersectNormalWithNegationTy(t, here); - if (shouldEarlyExit(res)) - return res; - } - else - { - NormalizationResult res = intersectNormalWithNegationTy(t, here, /* useDeprecated */ true); - if (shouldEarlyExit(res)) - return res; - } + NormalizationResult res = intersectNormalWithNegationTy(t, here); + if (shouldEarlyExit(res)) + return res; } else if (const UnionType* itv = get(t)) { - if (fixNormalizeCaching()) + for (TypeId part : itv->options) { - for (TypeId part : itv->options) - { - NormalizationResult res = intersectNormalWithNegationTy(part, here); - if (shouldEarlyExit(res)) - return res; - } - } - else - { - if (fixNormalizeCaching()) - { - for (TypeId part : itv->options) - { - NormalizationResult 
res = intersectNormalWithNegationTy(part, here); - if (shouldEarlyExit(res)) - return res; - } - } - else - { - for (TypeId part : itv->options) - { - NormalizationResult res = intersectNormalWithNegationTy(part, here, /* useDeprecated */ true); - if (shouldEarlyExit(res)) - return res; - } - } + NormalizationResult res = intersectNormalWithNegationTy(part, here); + if (shouldEarlyExit(res)) + return res; } } else if (get(t)) diff --git a/Analysis/src/Set.cpp b/Analysis/src/Set.cpp deleted file mode 100644 index 1819e28a..00000000 --- a/Analysis/src/Set.cpp +++ /dev/null @@ -1,5 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - -#include "Luau/Common.h" - -LUAU_FASTFLAGVARIABLE(LuauFixSetIter, false) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index cb6b2f4a..e3ee2252 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -20,7 +20,6 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAGVARIABLE(LuauToStringiteTypesSingleLine, false) /* * Enables increasing levels of verbosity for Luau type names when stringifying. diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index a685c216..7fac35c9 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -500,6 +500,15 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, NotNullsolver) || get(operandTy)) return {std::nullopt, false, {operandTy}, {}}; + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. 
+ if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy); + if (!maybeGeneralized) + return {std::nullopt, false, {operandTy}, {}}; + operandTy = *maybeGeneralized; + } + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); NormalizationResult inhabited = ctx->normalizer->isInhabited(normTy.get()); @@ -576,6 +585,15 @@ TypeFamilyReductionResult unmFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {operandTy}, {}}; + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy); + if (!maybeGeneralized) + return {std::nullopt, false, {operandTy}, {}}; + operandTy = *maybeGeneralized; + } + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. @@ -674,6 +692,21 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, NotNull< else if (isPending(rhsTy, ctx->solver)) return {std::nullopt, false, {rhsTy}, {}}; + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + // TODO: Normalization needs to remove cyclic type families from a `NormalizedType`. 
std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); @@ -895,6 +928,21 @@ TypeFamilyReductionResult concatFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {rhsTy}, {}}; + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); @@ -982,13 +1030,27 @@ TypeFamilyReductionResult andFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) return {std::nullopt, false, {rhsTy}, {}}; + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is truthy. 
SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->falsyType); SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result); @@ -1025,6 +1087,21 @@ TypeFamilyReductionResult orFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {rhsTy}, {}}; + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + // Or evalutes to the LHS type if the LHS is truthy, and the RHS type if LHS is falsy. SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->truthyType); SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result); @@ -1088,6 +1165,21 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, Not lhsTy = follow(lhsTy); rhsTy = follow(rhsTy); + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. 
+ if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + // check to see if both operand types are resolved enough, and wait to reduce if not std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); @@ -1196,6 +1288,21 @@ TypeFamilyReductionResult eqFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {rhsTy}, {}}; + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); NormalizationResult lhsInhabited = ctx->normalizer->isInhabited(normLhsTy.get()); @@ -1223,10 +1330,25 @@ TypeFamilyReductionResult eqFamilyFn(TypeId instance, NotNullnormalizer->isIntersectionInhabited(lhsTy, rhsTy); - if (!mmType && intersectInhabited == NormalizationResult::True) - return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay! - else if (!mmType) + if (!mmType) + { + if (intersectInhabited == NormalizationResult::True) + return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay! 
+ + // we might be in a case where we still want to accept the comparison... + if (intersectInhabited == NormalizationResult::False) + { + // if they're both subtypes of `string` but have no common intersection, the comparison is allowed but always `false`. + if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) + return {ctx->builtins->falseType, false, {}, {}}; + + // if they're both subtypes of `boolean` but have no common intersection, the comparison is allowed but always `false`. + if (normLhsTy->isSubtypeOfBooleans() && normRhsTy->isSubtypeOfBooleans()) + return {ctx->builtins->falseType, false, {}, {}}; + } + return {std::nullopt, true, {}, {}}; // if it's not, then this family is irreducible! + } mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) @@ -1303,6 +1425,21 @@ TypeFamilyReductionResult refineFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {discriminantTy}, {}}; + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional targetMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, targetTy); + std::optional discriminantMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, discriminantTy); + + if (!targetMaybeGeneralized) + return {std::nullopt, false, {targetTy}, {}}; + else if (!discriminantMaybeGeneralized) + return {std::nullopt, false, {discriminantTy}, {}}; + + targetTy = *targetMaybeGeneralized; + discriminantTy = *discriminantMaybeGeneralized; + } + // we need a more complex check for blocking on the discriminant in particular FindRefinementBlockers frb; frb.traverse(discriminantTy); @@ -1358,6 +1495,15 @@ TypeFamilyReductionResult singletonFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {type}, {}}; + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. 
+ if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, type); + if (!maybeGeneralized) + return {std::nullopt, false, {type}, {}}; + type = *maybeGeneralized; + } + TypeId followed = type; // we want to follow through a negation here as well. if (auto negation = get(followed)) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 1c21ecb2..eed3c715 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -39,7 +39,6 @@ LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false) LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) -LUAU_FASTFLAG(LuauFixNormalizeCaching) namespace Luau { @@ -2649,24 +2648,12 @@ static std::optional areEqComparable(NotNull arena, NotNulladdType(IntersectionType{{a, b}}); - std::shared_ptr n = normalizer->normalize(c); - if (!n) - return std::nullopt; + TypeId c = arena->addType(IntersectionType{{a, b}}); + std::shared_ptr n = normalizer->normalize(c); + if (!n) + return std::nullopt; - nr = normalizer->isInhabited(n.get()); - } - else - { - TypeId c = arena->addType(IntersectionType{{a, b}}); - const NormalizedType* n = normalizer->DEPRECATED_normalize(c); - if (!n) - return std::nullopt; - - nr = normalizer->isInhabited(n); - } + nr = normalizer->isInhabited(n.get()); switch (nr) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 3dc274a9..484e45d0 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -23,7 +23,6 @@ LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false) -LUAU_FASTFLAG(LuauFixNormalizeCaching) namespace Luau { @@ -580,28 +579,14 @@ void Unifier::tryUnify_(TypeId subTy, 
TypeId superTy, bool isFunctionCall, bool { if (normalize) { - if (FFlag::LuauFixNormalizeCaching) - { - // TODO: there are probably cheaper ways to check if any <: T. - std::shared_ptr superNorm = normalizer->normalize(superTy); + // TODO: there are probably cheaper ways to check if any <: T. + std::shared_ptr superNorm = normalizer->normalize(superTy); - if (!superNorm) - return reportError(location, NormalizationTooComplex{}); + if (!superNorm) + return reportError(location, NormalizationTooComplex{}); - if (!log.get(superNorm->tops)) - failure = true; - } - else - { - // TODO: there are probably cheaper ways to check if any <: T. - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - - if (!superNorm) - return reportError(location, NormalizationTooComplex{}); - - if (!log.get(superNorm->tops)) - failure = true; - } + if (!log.get(superNorm->tops)) + failure = true; } else failure = true; @@ -962,30 +947,15 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp // We deal with this by type normalization. Unifier innerState = makeChildUnifier(); - if (FFlag::LuauFixNormalizeCaching) - { - std::shared_ptr subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - innerState.tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. 
For example:", *failedOption); - else - innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - } + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, NormalizationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + innerState.tryUnifyNormalizedTypes( + subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else - { - const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - innerState.tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); - else - innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - } + innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); if (!innerState.failure) log.concat(std::move(innerState.log)); @@ -999,30 +969,15 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp // It is possible that T <: A | B even though T subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (!subNorm || !superNorm) - reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. 
For example:", *failedOption); - else - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - } + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + reportError(location, NormalizationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + tryUnifyNormalizedTypes( + subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else - { - const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (!subNorm || !superNorm) - reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); - else - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - } + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); } else if (!found) { @@ -1125,24 +1080,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* // It is possible that A & B <: T even though A subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); - else - reportError(location, NormalizationTooComplex{}); - } + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else 
- { - const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); - else - reportError(location, NormalizationTooComplex{}); - } + reportError(location, NormalizationTooComplex{}); return; } @@ -1192,24 +1135,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* // for example string? & number? <: nil. // We deal with this by type normalization. - if (FFlag::LuauFixNormalizeCaching) - { - std::shared_ptr subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); - else - reportError(location, NormalizationTooComplex{}); - } + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - { - const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); - else - reportError(location, NormalizationTooComplex{}); - } + reportError(location, NormalizationTooComplex{}); } else if (!found) { @@ -2712,32 +2643,16 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) if (!log.get(subTy) && !log.get(superTy)) ice("tryUnifyNegations superTy or subTy must be a negation type"); - if (FFlag::LuauFixNormalizeCaching) - { - std::shared_ptr subNorm = normalizer->normalize(subTy); - std::shared_ptr 
superNorm = normalizer->normalize(superTy); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, NormalizationTooComplex{}); - // T DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); - - // T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 34fc6ee9..1e90c0e8 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -204,25 +204,21 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy) auto subAny = get(subTy); auto superAny = get(superTy); - if (subAny && superAny) - return true; - else if (subAny && superFn) - { - // If `any` is the subtype, then we can propagate that inward. - bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack); - bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes); - return argResult && retResult; - } - else if (subFn && superAny) - { - // If `any` is the supertype, then we can propagate that inward. 
- bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes); - bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack); - return argResult && retResult; - } auto subTable = getMutable(subTy); auto superTable = get(superTy); + + if (subAny && superAny) + return true; + else if (subAny && superFn) + return unify(subAny, superFn); + else if (subFn && superAny) + return unify(subFn, superAny); + else if (subAny && superTable) + return unify(subAny, superTable); + else if (subTable && superAny) + return unify(subTable, superAny); + if (subTable && superTable) { // `boundTo` works like a bound type, and therefore we'd replace it @@ -451,7 +447,16 @@ bool Unifier2::unify(TableType* subTable, const TableType* superTable) * an indexer, we therefore conclude that the unsealed table has the * same indexer. */ - subTable->indexer = *superTable->indexer; + + TypeId indexType = superTable->indexer->indexType; + if (TypeId* subst = genericSubstitutions.find(indexType)) + indexType = *subst; + + TypeId indexResultType = superTable->indexer->indexResultType; + if (TypeId* subst = genericSubstitutions.find(indexResultType)) + indexResultType = *subst; + + subTable->indexer = TableIndexer{indexType, indexResultType}; } return result; @@ -462,6 +467,62 @@ bool Unifier2::unify(const MetatableType* subMetatable, const MetatableType* sup return unify(subMetatable->metatable, superMetatable->metatable) && unify(subMetatable->table, superMetatable->table); } +bool Unifier2::unify(const AnyType* subAny, const FunctionType* superFn) +{ + // If `any` is the subtype, then we can propagate that inward. + bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack); + bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes); + return argResult && retResult; +} + +bool Unifier2::unify(const FunctionType* subFn, const AnyType* superAny) +{ + // If `any` is the supertype, then we can propagate that inward. 
+ bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes); + bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack); + return argResult && retResult; +} + +bool Unifier2::unify(const AnyType* subAny, const TableType* superTable) +{ + for (const auto& [propName, prop]: superTable->props) + { + if (prop.readTy) + unify(builtinTypes->anyType, *prop.readTy); + + if (prop.writeTy) + unify(*prop.writeTy, builtinTypes->anyType); + } + + if (superTable->indexer) + { + unify(builtinTypes->anyType, superTable->indexer->indexType); + unify(builtinTypes->anyType, superTable->indexer->indexResultType); + } + + return true; +} + +bool Unifier2::unify(const TableType* subTable, const AnyType* superAny) +{ + for (const auto& [propName, prop]: subTable->props) + { + if (prop.readTy) + unify(*prop.readTy, builtinTypes->anyType); + + if (prop.writeTy) + unify(builtinTypes->anyType, *prop.writeTy); + } + + if (subTable->indexer) + { + unify(subTable->indexer->indexType, builtinTypes->anyType); + unify(subTable->indexer->indexResultType, builtinTypes->anyType); + } + + return true; +} + // FIXME? This should probably return an ErrorVec or an optional // rather than a boolean to signal an occurs check failure. bool Unifier2::unify(TypePackId subTp, TypePackId superTp) @@ -596,6 +657,43 @@ struct FreeTypeSearcher : TypeVisitor } } + DenseHashSet seenPositive{nullptr}; + DenseHashSet seenNegative{nullptr}; + + bool seenWithPolarity(const void* ty) + { + switch (polarity) + { + case Positive: + { + if (seenPositive.contains(ty)) + return true; + + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; + + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; + + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } + } + + return false; + } + // The keys in these maps are either TypeIds or TypePackIds. 
It's safe to // mix them because we only use these pointers as unique keys. We never // indirect them. @@ -604,12 +702,18 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty) override { + if (seenWithPolarity(ty)) + return false; + LUAU_ASSERT(ty); return true; } bool visit(TypeId ty, const FreeType& ft) override { + if (seenWithPolarity(ty)) + return false; + if (!subsumes(scope, ft.scope)) return true; @@ -632,6 +736,9 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const TableType& tt) override { + if (seenWithPolarity(ty)) + return false; + if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) { switch (polarity) @@ -675,6 +782,9 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const FunctionType& ft) override { + if (seenWithPolarity(ty)) + return false; + flip(); traverse(ft.argTypes); flip(); @@ -691,6 +801,9 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypePackId tp, const FreeTypePack& ftp) override { + if (seenWithPolarity(tp)) + return false; + if (!subsumes(scope, ftp.scope)) return true; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index a7363552..8bbdf307 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -17,7 +17,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. 
LUAU_FASTFLAG(LuauCheckedFunctionSyntax) -LUAU_FASTFLAGVARIABLE(LuauReadWritePropertySyntax, false) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) namespace Luau @@ -1340,22 +1339,19 @@ AstType* Parser::parseTableType(bool inDeclarationContext) AstTableAccess access = AstTableAccess::ReadWrite; std::optional accessLocation; - if (FFlag::LuauReadWritePropertySyntax || FFlag::DebugLuauDeferredConstraintResolution) + if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':') { - if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':') + if (AstName(lexer.current().name) == "read") { - if (AstName(lexer.current().name) == "read") - { - accessLocation = lexer.current().location; - access = AstTableAccess::Read; - lexer.next(); - } - else if (AstName(lexer.current().name) == "write") - { - accessLocation = lexer.current().location; - access = AstTableAccess::Write; - lexer.next(); - } + accessLocation = lexer.current().location; + access = AstTableAccess::Read; + lexer.next(); + } + else if (AstName(lexer.current().name) == "write") + { + accessLocation = lexer.current().location; + access = AstTableAccess::Write; + lexer.next(); } } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index d92dcd3e..501707e1 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -144,7 +144,10 @@ static int lua_require(lua_State* L) if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { if (codegen) - Luau::CodeGen::compile(ML, -1); + { + Luau::CodeGen::CompilationOptions nativeOptions; + Luau::CodeGen::compile(ML, -1, nativeOptions); + } if (coverageActive()) coverageTrack(ML, -1); @@ -602,7 +605,10 @@ static bool runFile(const char* name, lua_State* GL, bool repl) if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { if (codegen) - Luau::CodeGen::compile(L, -1); + { + Luau::CodeGen::CompilationOptions nativeOptions; + Luau::CodeGen::compile(L, -1, nativeOptions); + } if 
(coverageActive()) coverageTrack(L, -1); diff --git a/CodeGen/include/Luau/BytecodeAnalysis.h b/CodeGen/include/Luau/BytecodeAnalysis.h index 11af90da..edfa6e3d 100644 --- a/CodeGen/include/Luau/BytecodeAnalysis.h +++ b/CodeGen/include/Luau/BytecodeAnalysis.h @@ -13,10 +13,11 @@ namespace CodeGen { struct IrFunction; +struct HostIrHooks; void loadBytecodeTypeInfo(IrFunction& function); void buildBytecodeBlocks(IrFunction& function, const std::vector& jumpTargets); -void analyzeBytecodeTypes(IrFunction& function); +void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 9b56034f..43993231 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -66,6 +66,39 @@ struct CompilationResult } }; +struct IrBuilder; + +using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); +using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostVectorNamecallHandler = bool (*)( + IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); + +struct HostIrHooks +{ + // Suggest result type of a vector field access + HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr; + + // Suggest result type of a vector function namecall + HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr; + + // Handle vector value field access + // 'sourceReg' is guaranteed to be a vector + // Guards should take a VM exit to 'pcpos' + HostVectorAccessHandler vectorAccess = nullptr; + + // Handle namecalled performed on a vector value + // 'sourceReg' (self argument) is guaranteed to be a vector + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostVectorNamecallHandler 
vectorNamecall = nullptr; +}; + +struct CompilationOptions +{ + unsigned int flags = 0; + HostIrHooks hooks; +}; + struct CompilationStats { size_t bytecodeSizeBytes = 0; @@ -118,8 +151,11 @@ void setNativeExecutionEnabled(lua_State* L, bool enabled); using ModuleId = std::array; // Builds target function and all inner functions -CompilationResult compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); -CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); +CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr); +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr); + +CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr); +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr); using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); @@ -164,7 +200,7 @@ struct AssemblyOptions Target target = Host; - unsigned int flags = 0; + CompilationOptions compilationOptions; bool outputBinary = false; diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 6c975e85..8ad75fbe 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -16,11 +16,11 @@ namespace Luau namespace CodeGen { -struct AssemblyOptions; +struct HostIrHooks; struct IrBuilder { - IrBuilder(); + IrBuilder(const HostIrHooks& hostHooks); void buildFunctionIr(Proto* proto); @@ -64,13 +64,17 @@ struct IrBuilder IrOp vmExit(uint32_t pcpos); + const HostIrHooks& hostHooks; + bool inTerminatedBlock = false; bool interruptRequested = false; bool activeFastcallFallback = false; IrOp fastcallFallbackReturn; - int fastcallSkipTarget = -1; 
+ + // Force builder to skip source commands + int cmdSkipTarget = -1; IrFunction function; diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index e3ce9166..a2f67ebb 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -2,6 +2,7 @@ #include "Luau/BytecodeAnalysis.h" #include "Luau/BytecodeUtils.h" +#include "Luau/CodeGen.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" @@ -17,6 +18,8 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo loa LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) LUAU_FASTFLAGVARIABLE(LuauCodegenVectorMispredictFix, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false) namespace Luau { @@ -95,7 +98,10 @@ void loadBytecodeTypeInfo(IrFunction& function) uint32_t upvalCount = readVarInt(data, offset); uint32_t localCount = readVarInt(data, offset); - CODEGEN_ASSERT(upvalCount == unsigned(proto->nups)); + if (!FFlag::LuauCodegenLoadTypeUpvalCheck) + { + CODEGEN_ASSERT(upvalCount == unsigned(proto->nups)); + } if (typeSize != 0) { @@ -114,6 +120,11 @@ void loadBytecodeTypeInfo(IrFunction& function) if (upvalCount != 0) { + if (FFlag::LuauCodegenLoadTypeUpvalCheck) + { + CODEGEN_ASSERT(upvalCount == unsigned(proto->nups)); + } + typeInfo.upvalueTypes.resize(upvalCount); uint8_t* types = (uint8_t*)data + offset; @@ -611,7 +622,7 @@ void buildBytecodeBlocks(IrFunction& function, const std::vector& jumpT } } -void analyzeBytecodeTypes(IrFunction& function) +void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) { Proto* proto = function.proto; CODEGEN_ASSERT(proto); @@ -662,6 +673,8 @@ void analyzeBytecodeTypes(IrFunction& function) for (int i = proto->numparams; i < proto->maxstacksize; ++i) regTags[i] = LBC_TYPE_ANY; + LuauBytecodeType knownNextCallResult = LBC_TYPE_ANY; + for (int i = 
block.startpc; i <= block.finishpc;) { const Instruction* pc = &proto->code[i]; @@ -790,6 +803,9 @@ void analyzeBytecodeTypes(IrFunction& function) if (ch == 'x' || ch == 'y' || ch == 'z') regTags[ra] = LBC_TYPE_NUMBER; } + + if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); } } else @@ -1161,6 +1177,34 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra + 1] = bcType.a; bcType.result = LBC_TYPE_FUNCTION; + + if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + } + } + break; + } + case LOP_CALL: + { + if (FFlag::LuauCodegenAnalyzeHostVectorOps) + { + int ra = LUAU_INSN_A(*pc); + + if (knownNextCallResult != LBC_TYPE_ANY) + { + bcType.result = knownNextCallResult; + + knownNextCallResult = LBC_TYPE_ANY; + + regTags[ra] = bcType.result; + } + + if (FFlag::LuauCodegenTypeInfo) + refineRegType(bcTypeInfo, ra, i, bcType.result); } break; } @@ -1199,7 +1243,6 @@ void analyzeBytecodeTypes(IrFunction& function) } case LOP_GETGLOBAL: case LOP_SETGLOBAL: - case LOP_CALL: case LOP_RETURN: case LOP_JUMP: case LOP_JUMPBACK: diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index a5f6721e..3938ab12 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -201,12 +201,12 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) } template -static std::optional createNativeFunction( - AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result) +static std::optional createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, + const 
HostIrHooks& hooks, CodeGenCompilationResult& result) { CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - IrBuilder ir; + IrBuilder ir(hooks); ir.buildFunctionIr(proto); unsigned instCount = unsigned(ir.function.instructions.size()); @@ -476,7 +476,7 @@ void setNativeExecutionEnabled(lua_State* L, bool enabled) } } -static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +static CompilationResult compile_OLD(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { CompilationResult compilationResult; @@ -485,7 +485,7 @@ static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, Proto* root = clvalue(func)->l.p; - if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) { compilationResult.result = CodeGenCompilationResult::NotNativeModule; return compilationResult; @@ -500,7 +500,7 @@ static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, } std::vector protos; - gatherFunctions(protos, root, flags); + gatherFunctions(protos, root, options.flags); // Skip protos that have been compiled during previous invocations of CodeGen::compile protos.erase(std::remove_if(protos.begin(), protos.end(), @@ -541,7 +541,7 @@ static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, { CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success; - if (std::optional np = createNativeFunction(build, helpers, p, totalIrInstCount, protoResult)) + if (std::optional np = createNativeFunction(build, helpers, p, totalIrInstCount, options.hooks, protoResult)) results.push_back(*np); else compilationResult.protoFailures.push_back({protoResult, p->debugname ? 
getstr(p->debugname) : "", p->linedefined}); @@ -618,13 +618,15 @@ static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) { + Luau::CodeGen::CompilationOptions options{flags}; + if (FFlag::LuauCodegenContext) { - return compile_NEW(L, idx, flags, stats); + return compile_NEW(L, idx, options, stats); } else { - return compile_OLD(L, idx, flags, stats); + return compile_OLD(L, idx, options, stats); } } @@ -632,7 +634,27 @@ CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsig { CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return compile_NEW(moduleId, L, idx, flags, stats); + Luau::CodeGen::CompilationOptions options{flags}; + return compile_NEW(moduleId, L, idx, options, stats); +} + +CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) +{ + if (FFlag::LuauCodegenContext) + { + return compile_NEW(L, idx, options, stats); + } + else + { + return compile_OLD(L, idx, options, stats); + } +} + +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + + return compile_NEW(moduleId, L, idx, options, stats); } void setPerfLog(void* context, PerfLogFn logFn) diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index 8324b7cc..e9402426 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -183,11 +183,11 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A { Proto* root = clvalue(func)->l.p; - if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + if ((options.compilationOptions.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) return std::string(); std::vector protos; - 
gatherFunctions(protos, root, options.flags); + gatherFunctions(protos, root, options.compilationOptions.flags); protos.erase(std::remove_if(protos.begin(), protos.end(), [](Proto* p) { @@ -215,7 +215,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A for (Proto* p : protos) { - IrBuilder ir; + IrBuilder ir(options.compilationOptions.hooks); ir.buildFunctionIr(p); unsigned asmSize = build.getCodeSize(); unsigned asmCount = build.getInstructionCount(); diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index d9e3c4b3..cb542036 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -478,12 +478,12 @@ void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext) } template -[[nodiscard]] static NativeProtoExecDataPtr createNativeFunction( - AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result) +[[nodiscard]] static NativeProtoExecDataPtr createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, + uint32_t& totalIrInstCount, const HostIrHooks& hooks, CodeGenCompilationResult& result) { CODEGEN_ASSERT(FFlag::LuauCodegenContext); - IrBuilder ir; + IrBuilder ir(hooks); ir.buildFunctionIr(proto); unsigned instCount = unsigned(ir.function.instructions.size()); @@ -505,7 +505,7 @@ template } [[nodiscard]] static CompilationResult compileInternal( - const std::optional& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) + const std::optional& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(lua_isLfunction(L, idx)); @@ -513,7 +513,7 @@ template Proto* root = clvalue(func)->l.p; - if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 
0) return CompilationResult{CodeGenCompilationResult::NotNativeModule}; BaseCodeGenContext* codeGenContext = getCodeGenContext(L); @@ -521,7 +521,7 @@ template return CompilationResult{CodeGenCompilationResult::CodeGenNotInitialized}; std::vector protos; - gatherFunctions(protos, root, flags); + gatherFunctions(protos, root, options.flags); // Skip protos that have been compiled during previous invocations of CodeGen::compile protos.erase(std::remove_if(protos.begin(), protos.end(), @@ -572,7 +572,7 @@ template { CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success; - NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, protoResult); + NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, options.hooks, protoResult); if (nativeExecData != nullptr) { nativeProtos.push_back(std::move(nativeExecData)); @@ -639,18 +639,18 @@ template return compilationResult; } -CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return compileInternal(moduleId, L, idx, flags, stats); + return compileInternal(moduleId, L, idx, options, stats); } -CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +CompilationResult compile_NEW(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return compileInternal({}, L, idx, flags, stats); + return compileInternal({}, L, idx, options, stats); } [[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L) diff --git a/CodeGen/src/CodeGenContext.h b/CodeGen/src/CodeGenContext.h index ca338da5..c47121bc 100644 --- a/CodeGen/src/CodeGenContext.h +++ 
b/CodeGen/src/CodeGenContext.h @@ -107,8 +107,8 @@ void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationC // destroyed via lua_close. void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext); -CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats); -CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats); +CompilationResult compile_NEW(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats); +CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats); // Returns true if native execution is currently enabled for this VM [[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 7d285aaf..76d015e9 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) +LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) namespace Luau { @@ -23,8 +24,9 @@ namespace CodeGen constexpr unsigned kNoAssociatedBlockIndex = ~0u; -IrBuilder::IrBuilder() - : constantMap({IrConstKind::Tag, ~0ull}) +IrBuilder::IrBuilder(const HostIrHooks& hostHooks) + : hostHooks(hostHooks) + , constantMap({IrConstKind::Tag, ~0ull}) { } static bool hasTypedParameters_DEPRECATED(Proto* proto) @@ -230,7 +232,7 @@ void IrBuilder::buildFunctionIr(Proto* proto) rebuildBytecodeBasicBlocks(proto); // Infer register tags in bytecode - analyzeBytecodeTypes(function); + analyzeBytecodeTypes(function, hostHooks); function.bcMapping.resize(proto->sizecode, {~0u, ~0u}); @@ -283,10 +285,10 @@ void IrBuilder::buildFunctionIr(Proto* proto) translateInst(op, pc, i); - if (fastcallSkipTarget != -1) + if 
(cmdSkipTarget != -1) { - nexti = fastcallSkipTarget; - fastcallSkipTarget = -1; + nexti = cmdSkipTarget; + cmdSkipTarget = -1; } } @@ -613,7 +615,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstCapture(*this, pc, i); break; case LOP_NAMECALL: - translateInstNamecall(*this, pc, i); + if (FFlag::LuauCodegenAnalyzeHostVectorOps) + { + if (translateInstNamecall(*this, pc, i)) + cmdSkipTarget = i + 3; + } + else + { + translateInstNamecall(*this, pc, i); + } break; case LOP_PREPVARARGS: inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc))); @@ -654,7 +664,7 @@ void IrBuilder::handleFastcallFallback(IrOp fallbackOrUndef, const Instruction* } else { - fastcallSkipTarget = i + skip + 2; + cmdSkipTarget = i + skip + 2; } } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 84e3b639..291f618b 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -3,6 +3,7 @@ #include "Luau/Bytecode.h" #include "Luau/BytecodeUtils.h" +#include "Luau/CodeGen.h" #include "Luau/IrBuilder.h" #include "Luau/IrUtils.h" @@ -14,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false) LUAU_FASTFLAGVARIABLE(LuauCodegenFixVectorFields, false) +LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) namespace Luau { @@ -1218,6 +1220,10 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) } else { + if (FFlag::LuauCodegenAnalyzeHostVectorOps && build.hostHooks.vectorAccess && + build.hostHooks.vectorAccess(build, field, str->len, ra, rb, pcpos)) + return; + build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); } @@ -1376,7 +1382,7 @@ void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos) } } -void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) +bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) { int ra = 
LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -1388,8 +1394,24 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) { build.loadAndCheckTag(build.vmReg(rb), LUA_TVECTOR, build.vmExit(pcpos)); + if (FFlag::LuauCodegenAnalyzeHostVectorOps && build.hostHooks.vectorNamecall) + { + Instruction call = pc[2]; + CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + int callra = LUAU_INSN_A(call); + int nparams = LUAU_INSN_B(call) - 1; + int nresults = LUAU_INSN_C(call) - 1; + + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.vectorNamecall(build, field, str->len, callra, rb, nparams, nresults, pcpos)) + return true; + } + build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); - return; + return false; } if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) @@ -1397,7 +1419,7 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos)); build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); - return; + return false; } IrOp next = build.blockAtInst(pcpos + getOpLength(LOP_NAMECALL)); @@ -1451,6 +1473,8 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) build.inst(IrCmd::JUMP, next); build.beginBlock(next); + + return false; } void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c) diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index b1f1e28b..5eb01450 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -61,7 +61,7 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstConcat(IrBuilder& build, const 
Instruction* pc, int pcpos); void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos); -void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); +bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); void translateInstNewClosure(IrBuilder& build, const Instruction* pc, int pcpos); diff --git a/CodeGen/src/lcodegen.cpp b/CodeGen/src/lcodegen.cpp index 0795cd48..1ad685a1 100644 --- a/CodeGen/src/lcodegen.cpp +++ b/CodeGen/src/lcodegen.cpp @@ -17,5 +17,6 @@ void luau_codegen_create(lua_State* L) void luau_codegen_compile(lua_State* L, int idx) { - Luau::CodeGen::compile(L, idx); + Luau::CodeGen::CompilationOptions options; + Luau::CodeGen::compile(L, idx, options); } diff --git a/Sources.cmake b/Sources.cmake index 6adbf283..79fad0e4 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -266,7 +266,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Refinement.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp - Analysis/src/Set.cpp Analysis/src/Simplify.cpp Analysis/src/Substitution.cpp Analysis/src/Subtyping.cpp @@ -494,6 +493,7 @@ if(TARGET Luau.Conformance) target_sources(Luau.Conformance PRIVATE tests/RegisterCallbacks.h tests/RegisterCallbacks.cpp + tests/ConformanceIrHooks.h tests/Conformance.test.cpp tests/IrLowering.test.cpp tests/SharedCodeAllocator.test.cpp diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index d0d4e9be..c220f30b 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3272,9 +3272,9 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement") // https://github.com/Roblox/luau/issues/858 TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement2") { - ScopedFastFlag sff[]{ - {FFlag::DebugLuauDeferredConstraintResolution, true}, - }; + // don't run 
this when the DCR flag isn't set + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; check(R"( --!strict diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 9333cb19..bd57a140 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -16,6 +16,7 @@ #include "doctest.h" #include "ScopedFlags.h" +#include "ConformanceIrHooks.h" #include #include @@ -48,6 +49,13 @@ static lua_CompileOptions defaultOptions() return copts; } +static Luau::CodeGen::CompilationOptions defaultCodegenOptions() +{ + Luau::CodeGen::CompilationOptions opts = {}; + opts.flags = Luau::CodeGen::CodeGen_ColdFunctions; + return opts; +} + static int lua_collectgarbage(lua_State* L) { static const char* const opts[] = {"stop", "restart", "collect", "count", "isrunning", "step", "setgoal", "setstepmul", "setstepsize", nullptr}; @@ -118,6 +126,15 @@ static int lua_vector_dot(lua_State* L) return 1; } +static int lua_vector_cross(lua_State* L) +{ + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); + + lua_pushvector(L, a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]); + return 1; +} + static int lua_vector_index(lua_State* L) { const float* v = luaL_checkvector(L, 1); @@ -129,6 +146,14 @@ static int lua_vector_index(lua_State* L) return 1; } + if (strcmp(name, "Unit") == 0) + { + float invSqrt = 1.0f / sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]); + + lua_pushvector(L, v[0] * invSqrt, v[1] * invSqrt, v[2] * invSqrt); + return 1; + } + if (strcmp(name, "Dot") == 0) { lua_pushcfunction(L, lua_vector_dot, "Dot"); @@ -144,6 +169,9 @@ static int lua_vector_namecall(lua_State* L) { if (strcmp(str, "Dot") == 0) return lua_vector_dot(L); + + if (strcmp(str, "Cross") == 0) + return lua_vector_cross(L); } luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); @@ -157,7 +185,8 @@ int lua_silence(lua_State* L) using StateRef = std::unique_ptr; static StateRef 
runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, - lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr, bool skipCodegen = false) + lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr, bool skipCodegen = false, + Luau::CodeGen::CompilationOptions* codegenOptions = nullptr) { #ifdef LUAU_CONFORMANCE_SOURCE_DIR std::string path = LUAU_CONFORMANCE_SOURCE_DIR; @@ -238,7 +267,11 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n free(bytecode); if (result == 0 && codegen && !skipCodegen && luau_codegen_supported()) - Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions); + { + Luau::CodeGen::CompilationOptions nativeOpts = codegenOptions ? *codegenOptions : defaultCodegenOptions(); + + Luau::CodeGen::compile(L, -1, nativeOpts); + } int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; @@ -533,12 +566,51 @@ TEST_CASE("Pack") TEST_CASE("Vector") { + lua_CompileOptions copts = defaultOptions(); + Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); + + SUBCASE("NoIrHooks") + { + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + SUBCASE("IrHooks") + { + nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + nativeOpts.hooks.vectorAccess = vectorAccess; + nativeOpts.hooks.vectorNamecall = vectorNamecall; + + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + runConformance( "vector.lua", [](lua_State* L) { setupVectorHelpers(L); }, - nullptr, nullptr, nullptr); + nullptr, nullptr, &copts, false, &nativeOpts); } static void populateRTTI(lua_State* L, 
Luau::TypeId type) @@ -2141,7 +2213,10 @@ TEST_CASE("HugeFunction") REQUIRE(result == 0); if (codegen && luau_codegen_supported()) - Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions); + { + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; + Luau::CodeGen::compile(L, -1, nativeOptions); + } int status = lua_resume(L, nullptr, 0); REQUIRE(status == 0); @@ -2263,8 +2338,9 @@ TEST_CASE("IrInstructionLimit") REQUIRE(result == 0); + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; Luau::CodeGen::CompilationStats nativeStats = {}; - Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions, &nativeStats); + Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, nativeOptions, &nativeStats); // Limit is not hit immediately, so with some functions compiled it should be a success CHECK(nativeResult.result == Luau::CodeGen::CodeGenCompilationResult::Success); diff --git a/tests/ConformanceIrHooks.h b/tests/ConformanceIrHooks.h new file mode 100644 index 00000000..135fe9da --- /dev/null +++ b/tests/ConformanceIrHooks.h @@ -0,0 +1,151 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrBuilder.h" + +inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength) +{ + using namespace Luau::CodeGen; + + if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) + return LBC_TYPE_NUMBER; + + if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) + return LBC_TYPE_VECTOR; + + return LBC_TYPE_ANY; +} + +inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos) +{ + using namespace Luau::CodeGen; + + if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) + { + 
IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + + return true; + } + + if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(resultReg), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TVECTOR)); + + return true; + } + + return false; +} + +inline uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength) +{ + if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0) + return LBC_TYPE_NUMBER; + + if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0) + return LBC_TYPE_VECTOR; + + 
return LBC_TYPE_ANY; +} + +inline bool vectorNamecall( + Luau::CodeGen::IrBuilder& build, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos) +{ + using namespace Luau::CodeGen; + + if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0 && params == 2 && results <= 1) + { + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(0)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(4)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(8)); + IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0 && params == 2 && results <= 1) + { + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(0)); + + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), 
build.constInt(4)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(4)); + + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(8)); + + IrOp y1z2 = build.inst(IrCmd::MUL_NUM, y1, z2); + IrOp z1y2 = build.inst(IrCmd::MUL_NUM, z1, y2); + IrOp xr = build.inst(IrCmd::SUB_NUM, y1z2, z1y2); + + IrOp z1x2 = build.inst(IrCmd::MUL_NUM, z1, x2); + IrOp x1z2 = build.inst(IrCmd::MUL_NUM, x1, z2); + IrOp yr = build.inst(IrCmd::SUB_NUM, z1x2, x1z2); + + IrOp x1y2 = build.inst(IrCmd::MUL_NUM, x1, y2); + IrOp y1x2 = build.inst(IrCmd::MUL_NUM, y1, x2); + IrOp zr = build.inst(IrCmd::SUB_NUM, x1y2, y1x2); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(argResReg), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TVECTOR)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + return false; +} diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 2f198e65..4f7725e6 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -21,6 +21,11 @@ using namespace Luau::CodeGen; class IrBuilderFixture { public: + IrBuilderFixture() + : build(hooks) + { + } + void constantFold() { for (IrBlock& block : build.function.blocks) @@ -109,6 +114,7 @@ public: computeCfgDominanceTreeChildren(build.function); } + HostIrHooks hooks; IrBuilder build; // Luau.VM headers are not accessible diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 0c0d6378..131ec4d1 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -6,9 +6,11 @@ #include "Luau/CodeGen.h" #include "Luau/Compiler.h" #include "Luau/Parser.h" +#include "Luau/IrBuilder.h" #include "doctest.h" #include "ScopedFlags.h" 
+#include "ConformanceIrHooks.h" #include @@ -22,11 +24,17 @@ LUAU_FASTFLAG(LuauCodegenIrTypeNames) LUAU_FASTFLAG(LuauCompileTempTypeInfo) LUAU_FASTFLAG(LuauCodegenFixVectorFields) LUAU_FASTFLAG(LuauCodegenVectorMispredictFix) +LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) { Luau::CodeGen::AssemblyOptions options; + options.compilationOptions.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + options.compilationOptions.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + options.compilationOptions.hooks.vectorAccess = vectorAccess; + options.compilationOptions.hooks.vectorNamecall = vectorNamecall; + // For IR, we don't care about assembly, but we want a stable target options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV; @@ -513,6 +521,277 @@ bb_6: )"); } +TEST_CASE("VectorCustomAccess") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true}; + ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function vec3magn(a: vector) + return a.Magnitude * 2 +end +)"), + R"( +; function vec3magn($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_FLOAT R0, 0i + %7 = LOAD_FLOAT R0, 4i + %8 = LOAD_FLOAT R0, 8i + %9 = MUL_NUM %6, %6 + %10 = MUL_NUM %7, %7 + %11 = MUL_NUM %8, %8 + %12 = ADD_NUM %9, %10 + %13 = ADD_NUM %12, %11 + %14 = SQRT_NUM %13 + %20 = MUL_NUM %14, 2 + STORE_DOUBLE R1, %20 + STORE_TAG R1, tnumber + INTERRUPT 3u + RETURN R1, 1i +)"); +} + +TEST_CASE("VectorCustomNamecall") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag 
LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; + ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function vec3dot(a: vector, b: vector) + return (a:Dot(b)) +end +)"), + R"( +; function vec3dot($arg0, $arg1) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %12 = LOAD_FLOAT R0, 0i + %13 = LOAD_FLOAT R4, 0i + %14 = MUL_NUM %12, %13 + %15 = LOAD_FLOAT R0, 4i + %16 = LOAD_FLOAT R4, 4i + %17 = MUL_NUM %15, %16 + %18 = LOAD_FLOAT R0, 8i + %19 = LOAD_FLOAT R4, 8i + %20 = MUL_NUM %18, %19 + %21 = ADD_NUM %14, %17 + %22 = ADD_NUM %21, %20 + STORE_DOUBLE R2, %22 + STORE_TAG R2, tnumber + INTERRUPT 4u + RETURN R2, 1i +)"); +} + +TEST_CASE("VectorCustomAccessChain") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true}; + ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; + ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vector, b: vector) + return a.Unit * b.Magnitude +end +)"), + R"( +; function foo($arg0, $arg1) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_FLOAT R0, 0i + %9 = LOAD_FLOAT R0, 4i + %10 = LOAD_FLOAT R0, 8i + %11 = MUL_NUM %8, %8 + %12 = MUL_NUM %9, %9 + %13 = MUL_NUM %10, %10 + %14 = ADD_NUM %11, %12 + %15 = ADD_NUM %14, %13 + %16 = SQRT_NUM %15 + %17 = DIV_NUM 1, %16 + %18 = MUL_NUM %8, %17 + %19 = MUL_NUM %9, %17 + %20 = MUL_NUM %10, %17 + STORE_VECTOR R3, %18, %19, %20 + STORE_TAG R3, tvector + %25 = LOAD_FLOAT 
R1, 0i + %26 = LOAD_FLOAT R1, 4i + %27 = LOAD_FLOAT R1, 8i + %28 = MUL_NUM %25, %25 + %29 = MUL_NUM %26, %26 + %30 = MUL_NUM %27, %27 + %31 = ADD_NUM %28, %29 + %32 = ADD_NUM %31, %30 + %33 = SQRT_NUM %32 + %40 = LOAD_TVALUE R3 + %42 = NUM_TO_VEC %33 + %43 = MUL_VEC %40, %42 + %44 = TAG_VECTOR %43 + STORE_TVALUE R2, %44 + INTERRUPT 5u + RETURN R2, 1i +)"); +} + +TEST_CASE("VectorCustomNamecallChain") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true}; + ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; + ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(n: vector, b: vector, t: vector) + return n:Cross(t):Dot(b) + 1 +end +)"), + R"( +; function foo($arg0, $arg1, $arg2) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + CHECK_TAG R2, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_TVALUE R2 + STORE_TVALUE R6, %8 + %14 = LOAD_FLOAT R0, 0i + %15 = LOAD_FLOAT R6, 0i + %16 = LOAD_FLOAT R0, 4i + %17 = LOAD_FLOAT R6, 4i + %18 = LOAD_FLOAT R0, 8i + %19 = LOAD_FLOAT R6, 8i + %20 = MUL_NUM %16, %19 + %21 = MUL_NUM %18, %17 + %22 = SUB_NUM %20, %21 + %23 = MUL_NUM %18, %15 + %24 = MUL_NUM %14, %19 + %25 = SUB_NUM %23, %24 + %26 = MUL_NUM %14, %17 + %27 = MUL_NUM %16, %15 + %28 = SUB_NUM %26, %27 + STORE_VECTOR R4, %22, %25, %28 + STORE_TAG R4, tvector + %31 = LOAD_TVALUE R1 + STORE_TVALUE R6, %31 + %37 = LOAD_FLOAT R4, 0i + %38 = LOAD_FLOAT R6, 0i + %39 = MUL_NUM %37, %38 + %40 = LOAD_FLOAT R4, 4i + %41 = LOAD_FLOAT R6, 4i + %42 = MUL_NUM %40, %41 + %43 = LOAD_FLOAT R4, 8i + %44 = LOAD_FLOAT R6, 8i + %45 = MUL_NUM %43, %44 + %46 = ADD_NUM %39, %42 + %47 = ADD_NUM %46, %45 + %53 = ADD_NUM %47, 1 + STORE_DOUBLE R3, %53 + 
STORE_TAG R3, tnumber + INTERRUPT 9u + RETURN R3, 1i +)"); +} + +TEST_CASE("VectorCustomNamecallChain2") +{ + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCompileTempTypeInfo, true}, + {FFlag::LuauCodegenVectorMispredictFix, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +type Vertex = {n: vector, b: vector} + +local function foo(v: Vertex, t: vector) + return v.n:Cross(t):Dot(v.b) + 1 +end +)"), + R"( +; function foo($arg0, $arg1) line 4 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_POINTER R0 + %9 = GET_SLOT_NODE_ADDR %8, 0u, K1 + CHECK_SLOT_MATCH %9, K1, bb_fallback_3 + %11 = LOAD_TVALUE %9, 0i + STORE_TVALUE R3, %11 + JUMP bb_4 +bb_4: + %16 = LOAD_TVALUE R1 + STORE_TVALUE R5, %16 + CHECK_TAG R3, tvector, exit(3) + CHECK_TAG R5, tvector, exit(3) + %22 = LOAD_FLOAT R3, 0i + %23 = LOAD_FLOAT R5, 0i + %24 = LOAD_FLOAT R3, 4i + %25 = LOAD_FLOAT R5, 4i + %26 = LOAD_FLOAT R3, 8i + %27 = LOAD_FLOAT R5, 8i + %28 = MUL_NUM %24, %27 + %29 = MUL_NUM %26, %25 + %30 = SUB_NUM %28, %29 + %31 = MUL_NUM %26, %23 + %32 = MUL_NUM %22, %27 + %33 = SUB_NUM %31, %32 + %34 = MUL_NUM %22, %25 + %35 = MUL_NUM %24, %23 + %36 = SUB_NUM %34, %35 + STORE_VECTOR R3, %30, %33, %36 + CHECK_TAG R0, ttable, exit(6) + %41 = LOAD_POINTER R0 + %42 = GET_SLOT_NODE_ADDR %41, 6u, K3 + CHECK_SLOT_MATCH %42, K3, bb_fallback_5 + %44 = LOAD_TVALUE %42, 0i + STORE_TVALUE R5, %44 + JUMP bb_6 +bb_6: + CHECK_TAG R3, tvector, exit(8) + CHECK_TAG R5, tvector, exit(8) + %53 = LOAD_FLOAT R3, 0i + %54 = LOAD_FLOAT R5, 0i + %55 = MUL_NUM %53, %54 + %56 = LOAD_FLOAT R3, 4i + %57 = LOAD_FLOAT R5, 4i + %58 = MUL_NUM %56, %57 + %59 = 
LOAD_FLOAT R3, 8i + %60 = LOAD_FLOAT R5, 8i + %61 = MUL_NUM %59, %60 + %62 = ADD_NUM %55, %58 + %63 = ADD_NUM %62, %61 + %69 = ADD_NUM %63, 1 + STORE_DOUBLE R2, %69 + STORE_TAG R2, tnumber + INTERRUPT 12u + RETURN R2, 1i +)"); +} + TEST_CASE("UserDataGetIndex") { ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; @@ -1040,7 +1319,7 @@ TEST_CASE("ResolveVectorNamecalls") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}}; + {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1083,10 +1362,20 @@ bb_6: %31 = LOAD_TVALUE K1, 0i, tvector STORE_TVALUE R4, %31 CHECK_TAG R2, tvector, exit(4) - FALLBACK_NAMECALL 4u, R2, R2, K2 - INTERRUPT 6u - SET_SAVEDPC 7u - CALL R2, 2i, -1i + %37 = LOAD_FLOAT R2, 0i + %38 = LOAD_FLOAT R4, 0i + %39 = MUL_NUM %37, %38 + %40 = LOAD_FLOAT R2, 4i + %41 = LOAD_FLOAT R4, 4i + %42 = MUL_NUM %40, %41 + %43 = LOAD_FLOAT R2, 8i + %44 = LOAD_FLOAT R4, 8i + %45 = MUL_NUM %43, %44 + %46 = ADD_NUM %39, %42 + %47 = ADD_NUM %46, %45 + STORE_DOUBLE R2, %47 + STORE_TAG R2, tnumber + ADJUST_STACK_TO_REG R2, 1i INTERRUPT 7u RETURN R2, -1i )"); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 36398289..2eb8ca91 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -11,7 +11,6 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauFixNormalizeCaching) LUAU_FASTFLAG(LuauNormalizeNotUnknownIntersection) LUAU_FASTFLAG(LuauFixCyclicUnionsOfIntersections); 
LUAU_FASTINT(LuauTypeInferRecursionLimit) @@ -428,7 +427,6 @@ struct NormalizeFixture : Fixture UnifierSharedState unifierState{&iceHandler}; Normalizer normalizer{&arena, builtinTypes, NotNull{&unifierState}}; Scope globalScope{builtinTypes->anyTypePack}; - ScopedFastFlag fixNormalizeCaching{FFlag::LuauFixNormalizeCaching, true}; NormalizeFixture() { diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 10331408..e1163a1b 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -17,7 +17,6 @@ LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeLengthLimit); LUAU_FASTINT(LuauParseErrorLimit); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauReadWritePropertySyntax); namespace { @@ -3156,8 +3155,6 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name") TEST_CASE_FIXTURE(Fixture, "read_write_table_properties") { - ScopedFastFlag sff{FFlag::LuauReadWritePropertySyntax, true}; - auto pr = tryParse(R"( type A = {read x: number} type B = {write x: number} diff --git a/tests/Set.test.cpp b/tests/Set.test.cpp index 94de4f01..b3824bf1 100644 --- a/tests/Set.test.cpp +++ b/tests/Set.test.cpp @@ -7,8 +7,6 @@ #include #include -LUAU_FASTFLAG(LuauFixSetIter); - TEST_SUITE_BEGIN("SetTests"); TEST_CASE("empty_set_size_0") @@ -107,8 +105,6 @@ TEST_CASE("iterate_over_set_skips_erased_elements") TEST_CASE("iterate_over_set_skips_first_element_if_it_is_erased") { - ScopedFastFlag sff{FFlag::LuauFixSetIter, true}; - /* * As of this writing, in the following set, the key "y" happens to occur * before "x" in the underlying DenseHashSet. 
This is important because it diff --git a/tests/SharedCodeAllocator.test.cpp b/tests/SharedCodeAllocator.test.cpp index 0b142930..30bf1de2 100644 --- a/tests/SharedCodeAllocator.test.cpp +++ b/tests/SharedCodeAllocator.test.cpp @@ -438,10 +438,13 @@ TEST_CASE("SharedAllocation") const ModuleId moduleId = {0x01}; + CompilationOptions options; + options.flags = CodeGen_ColdFunctions; + CompilationStats nativeStats1 = {}; CompilationStats nativeStats2 = {}; - const CompilationResult codeGenResult1 = Luau::CodeGen::compile(moduleId, L1.get(), -1, CodeGen_ColdFunctions, &nativeStats1); - const CompilationResult codeGenResult2 = Luau::CodeGen::compile(moduleId, L2.get(), -1, CodeGen_ColdFunctions, &nativeStats2); + const CompilationResult codeGenResult1 = Luau::CodeGen::compile(moduleId, L1.get(), -1, options, &nativeStats1); + const CompilationResult codeGenResult2 = Luau::CodeGen::compile(moduleId, L2.get(), -1, options, &nativeStats2); REQUIRE(codeGenResult1.result == CodeGenCompilationResult::Success); REQUIRE(codeGenResult2.result == CodeGenCompilationResult::Success); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 4789a810..7308d7da 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -354,21 +354,41 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") function f2(f) return f or f1 end function f3(f) return f or f2 end )"); - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions o; - o.exhaustive = false; - if (FFlag::DebugLuauDeferredConstraintResolution) { - o.maxTypeLength = 30; + LUAU_REQUIRE_ERROR_COUNT(3, result); + auto err = get(result.errors[0]); + LUAU_ASSERT(err); + CHECK("(...any) -> ()" == toString(err->recommendedReturn)); + REQUIRE(1 == err->recommendedArgs.size()); + CHECK("unknown" == toString(err->recommendedArgs[0].second)); + err = get(result.errors[1]); + LUAU_ASSERT(err); + // FIXME: this recommendation could be better + CHECK("(a) -> or ()>" == toString(err->recommendedReturn)); 
+ REQUIRE(1 == err->recommendedArgs.size()); + CHECK("unknown" == toString(err->recommendedArgs[0].second)); + err = get(result.errors[2]); + LUAU_ASSERT(err); + // FIXME: this recommendation could be better + CHECK("(a) -> or(b) -> or ()>>" == toString(err->recommendedReturn)); + REQUIRE(1 == err->recommendedArgs.size()); + CHECK("unknown" == toString(err->recommendedArgs[0].second)); + + ToStringOptions o; + o.exhaustive = false; + o.maxTypeLength = 20; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> or ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> or(a... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> or(b... 
*TRUNCATED*"); } else { + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = false; o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); @@ -385,20 +405,42 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") function f2(f) return f or f1 end function f3(f) return f or f2 end )"); - LUAU_REQUIRE_NO_ERRORS(result); - ToStringOptions o; - o.exhaustive = true; if (FFlag::DebugLuauDeferredConstraintResolution) { - o.maxTypeLength = 30; + LUAU_REQUIRE_ERROR_COUNT(3, result); + auto err = get(result.errors[0]); + LUAU_ASSERT(err); + CHECK("(...any) -> ()" == toString(err->recommendedReturn)); + REQUIRE(1 == err->recommendedArgs.size()); + CHECK("unknown" == toString(err->recommendedArgs[0].second)); + err = get(result.errors[1]); + LUAU_ASSERT(err); + // FIXME: this recommendation could be better + CHECK("(a) -> or ()>" == toString(err->recommendedReturn)); + REQUIRE(1 == err->recommendedArgs.size()); + CHECK("unknown" == toString(err->recommendedArgs[0].second)); + err = get(result.errors[2]); + LUAU_ASSERT(err); + // FIXME: this recommendation could be better + CHECK("(a) -> or(b) -> or ()>>" == toString(err->recommendedReturn)); + REQUIRE(1 == err->recommendedArgs.size()); + CHECK("unknown" == toString(err->recommendedArgs[0].second)); + + ToStringOptions o; + o.exhaustive = true; + o.maxTypeLength = 20; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> or ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> or(a... 
*TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> or(b... *TRUNCATED*"); } else { + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 4fb3d58b..bfb17c78 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2351,8 +2351,9 @@ end LUAU_REQUIRE_ERRORS(result); auto err = get(result.errors.back()); LUAU_ASSERT(err); - CHECK("false | number" == toString(err->recommendedReturn)); - CHECK(err->recommendedArgs.size() == 0); + CHECK("number" == toString(err->recommendedReturn)); + REQUIRE(1 == err->recommendedArgs.size()); + CHECK("number" == toString(err->recommendedArgs[0].second)); } TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type") @@ -2673,4 +2674,17 @@ TEST_CASE_FIXTURE(Fixture, "captured_local_is_assigned_a_function") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "error_suppression_propagates_through_function_calls") +{ + CheckResult result = check(R"( + function first(x: any) + return pairs(x)(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(any) -> (any?, any)" == toString(requireType("first"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index d1716f5d..8bbd3f92 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -1010,7 +1010,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_add_an_indexer") +TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_retroactively_add_an_indexer") { CheckResult result = check(R"( --!strict @@ -1025,7 +1025,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, 
"pairs_should_not_add_an_indexer") )"); if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(2, result); + { + // We regress a little here: The old solver would typecheck the first + // access to prices.wwwww on a table that had no indexer, and the second + // on a table that does. + LUAU_REQUIRE_ERROR_COUNT(0, result); + } else LUAU_REQUIRE_ERROR_COUNT(1, result); } @@ -1114,4 +1119,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "forin_metatable_iter_mm") CHECK_EQ("number", toString(requireTypeAtPosition({6, 21}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_preserves_error_suppression") +{ + CheckResult result = check(R"( + function first(x: any) + for k, v in pairs(x) do + print(k, v) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("any" == toString(requireTypeAtPosition({3, 22}))); + CHECK("any" == toString(requireTypeAtPosition({3, 25}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 4446bbc9..307084d5 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -21,7 +21,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls); LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering); LUAU_FASTFLAG(DebugLuauSharedSelf); -LUAU_FASTFLAG(LuauReadWritePropertySyntax); LUAU_FASTFLAG(LuauMetatableInstantiationCloneCheck); LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError) @@ -2729,7 +2728,9 @@ TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") TEST_CASE_FIXTURE(Fixture, "should_not_unblock_table_type_twice") { - ScopedFastFlag sff(FFlag::DebugLuauDeferredConstraintResolution, true); + // don't run this when the DCR flag isn't set + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; check(R"( local timer = peek(timerQueue) @@ -4014,7 +4015,6 @@ TEST_CASE_FIXTURE(Fixture, "identify_all_problematic_table_fields") TEST_CASE_FIXTURE(Fixture, 
"read_and_write_only_table_properties_are_unsupported") { ScopedFastFlag sff[] = { - {FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, false}, }; @@ -4040,8 +4040,6 @@ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported") { - ScopedFastFlag sff{FFlag::LuauReadWritePropertySyntax, true}; - CheckResult result = check(R"( type T = {read [string]: number} type U = {write [string]: boolean} @@ -4155,7 +4153,9 @@ TEST_CASE_FIXTURE(Fixture, "write_annotations_are_unsupported_even_with_the_new_ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported") { - ScopedFastFlag sff[] = {{FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, false}}; + ScopedFastFlag sff[] = { + {FFlag::DebugLuauDeferredConstraintResolution, false} + }; CheckResult result = check(R"( type W = {read x: number} @@ -4179,7 +4179,9 @@ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported") { - ScopedFastFlag sff[] = {{FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, false}}; + ScopedFastFlag sff[] = { + {FFlag::DebugLuauDeferredConstraintResolution, false} + }; CheckResult result = check(R"( type T = {read [string]: number} @@ -4199,7 +4201,9 @@ TEST_CASE_FIXTURE(Fixture, "table_writes_introduce_write_properties") if (!FFlag::DebugLuauDeferredConstraintResolution) return; - ScopedFastFlag sff[] = {{FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, true}}; + ScopedFastFlag sff[] = { + {FFlag::DebugLuauDeferredConstraintResolution, true} + }; CheckResult result = check(R"( function oc(player, speaker) @@ -4439,4 +4443,21 @@ TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop") 
LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer") +{ + CheckResult result = check(R"( + --!strict + + local a = {} + ipairs(a) + )"); + + // The old solver erroneously leaves a free type dangling here. The new + // solver does better. + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK("{unknown}" == toString(requireType("a"), {true})); + else + CHECK("{a}" == toString(requireType("a"), {true})); +} + TEST_SUITE_END(); diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 9be88f69..7e4a9a3e 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -4,6 +4,12 @@ print('testing vectors') -- detect vector size local vector_size = if pcall(function() return vector(0, 0, 0).w end) then 4 else 3 +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + -- equality assert(vector(1, 2, 3) == vector(1, 2, 3)) assert(vector(0, 1, 2) == vector(-0, 1, 2)) @@ -92,9 +98,29 @@ assert(nanv ~= nanv); -- __index assert(vector(1, 2, 2).Magnitude == 3) assert(vector(0, 0, 0)['Dot'](vector(1, 2, 4), vector(5, 6, 7)) == 45) +assert(vector(2, 0, 0).Unit == vector(1, 0, 0)) -- __namecall assert(vector(1, 2, 4):Dot(vector(5, 6, 7)) == 45) +assert(ecall(function() vector(1, 2, 4):Dot() end) == "missing argument #2 (vector expected)") +assert(ecall(function() vector(1, 2, 4):Dot("a") end) == "invalid argument #2 (vector expected, got string)") + +local function doDot1(a: vector, b) + return a:Dot(b) +end + +local function doDot2(a: vector, b) + return (a:Dot(b)) +end + +local v124 = vector(1, 2, 4) + +assert(doDot1(v124, vector(5, 6, 7)) == 45) +assert(doDot2(v124, vector(5, 6, 7)) == 45) +assert(ecall(function() doDot1(v124, "a") end) == "invalid argument #2 (vector expected, got string)") +assert(ecall(function() doDot2(v124, "a") end) == "invalid argument #2 (vector expected, got string)") 
+assert(select("#", doDot1(v124, vector(5, 6, 7))) == 1) +assert(select("#", doDot2(v124, vector(5, 6, 7))) == 1) -- can't use vector with NaN components as table key assert(pcall(function() local t = {} t[vector(0/0, 2, 3)] = 1 end) == false) @@ -102,6 +128,9 @@ assert(pcall(function() local t = {} t[vector(1, 0/0, 3)] = 1 end) == false) assert(pcall(function() local t = {} t[vector(1, 2, 0/0)] = 1 end) == false) assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == false) +assert(vector(1, 0, 0):Cross(vector(0, 1, 0)) == vector(0, 0, 1)) +assert(vector(0, 1, 0):Cross(vector(1, 0, 0)) == vector(0, 0, -1)) + -- make sure we cover both builtin and C impl assert(vector(1, 2, 4) == vector("1", "2", "4")) diff --git a/tools/faillist.txt b/tools/faillist.txt index db3eeba5..c0c12bc3 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -4,6 +4,7 @@ AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg AutocompleteTest.autocomplete_string_singletons AutocompleteTest.do_wrong_compatible_nonself_calls AutocompleteTest.string_singleton_as_table_key +AutocompleteTest.string_singleton_in_if_statement2 AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_in_table @@ -133,9 +134,11 @@ RefinementTest.call_an_incompatible_function_after_using_typeguard RefinementTest.dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x -RefinementTest.function_call_with_colon_after_refining_not_to_be_nil +RefinementTest.free_type_is_equal_to_an_lvalue RefinementTest.globals_can_be_narrowed_too RefinementTest.isa_type_refinement_must_be_known_ahead_of_time +RefinementTest.luau_polyfill_isindexkey_refine_conjunction +RefinementTest.luau_polyfill_isindexkey_refine_conjunction_variant RefinementTest.not_t_or_some_prop_of_t 
RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage RefinementTest.refine_a_property_of_some_global @@ -154,6 +157,7 @@ TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode TableTests.array_factory_function +TableTests.cannot_augment_sealed_table TableTests.casting_tables_with_props_into_table_with_indexer2 TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_unsealed_tables_with_props_into_table_with_indexer @@ -177,6 +181,7 @@ TableTests.generalize_table_argument TableTests.generic_table_instantiation_potential_regression TableTests.indexer_on_sealed_table_must_unify_with_free_table TableTests.indexers_get_quantified_too +TableTests.inequality_operators_imply_exactly_matching_types TableTests.infer_array TableTests.infer_indexer_from_array_like_table TableTests.infer_indexer_from_its_variable_type_and_unifiable @@ -206,6 +211,7 @@ TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table TableTests.recursive_metatable_type_call +TableTests.refined_thing_can_be_an_array TableTests.right_table_missing_key2 TableTests.scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type @@ -214,6 +220,7 @@ TableTests.setmetatable_has_a_side_effect TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables +TableTests.should_not_unblock_table_type_twice TableTests.table_call_metamethod_basic TableTests.table_call_metamethod_must_be_callable TableTests.table_param_width_subtyping_2 @@ -236,6 +243,8 @@ ToString.named_metatable_toStringNamedFunction 
ToString.no_parentheses_around_cyclic_function_type_in_intersection ToString.pick_distinct_names_for_mixed_explicit_and_implicit_generics ToString.primitive +ToString.quit_stringifying_type_when_length_is_exceeded +ToString.stringifying_type_is_still_capped_when_exhaustive ToString.toStringDetailed2 ToString.toStringErrorPack TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType @@ -332,8 +341,10 @@ TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.other_things_are_not_related_to_function TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible_2 +TypeInferFunctions.regex_benchmark_string_format_minimization TypeInferFunctions.report_exiting_without_return_nonstrict TypeInferFunctions.return_type_by_overload +TypeInferFunctions.tf_suggest_return_type TypeInferFunctions.too_few_arguments_variadic TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic2 @@ -406,6 +417,7 @@ TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened +TypeSingletons.singletons_stick_around_under_assignment TypeSingletons.table_properties_type_error_escapes TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeStatesTest.typestates_preserve_error_suppression_properties From fe0a8194722e3c508c30b8f266a215a99d52dafe Mon Sep 17 00:00:00 2001 From: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com> Date: Thu, 16 May 2024 16:02:03 -0700 Subject: [PATCH 08/20] Sync to upstream/release/626 (#1258) ### New Type Solver * Fixed crash in numeric binary operation type families * Results of an indexing operation are now comparable to `nil` without a false 
positive error * Fixed a crash when a type that failed normalization was accessed * Iterating on a free value now implies that it is iterable --- ### Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: James McNellis Co-authored-by: Vighnesh Vijay --- Analysis/include/Luau/ConstraintGenerator.h | 3 +- Analysis/include/Luau/Generalization.h | 13 + Analysis/src/ConstraintGenerator.cpp | 201 ++------ Analysis/src/ConstraintSolver.cpp | 53 +- Analysis/src/DataFlowGraph.cpp | 3 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 3 +- Analysis/src/Generalization.cpp | 526 ++++++++++++++++++++ Analysis/src/TypeChecker2.cpp | 22 +- Analysis/src/TypeFamily.cpp | 2 +- Analysis/src/Unifier2.cpp | 309 ------------ Ast/src/Lexer.cpp | 16 +- Ast/src/Parser.cpp | 6 +- CodeGen/include/Luau/CodeGen.h | 6 + CodeGen/src/BytecodeAnalysis.cpp | 2 - CodeGen/src/CodeBlockUnwind.cpp | 16 +- CodeGen/src/CodeGen.cpp | 435 +--------------- CodeGen/src/CodeGenAssembly.cpp | 4 +- CodeGen/src/CodeGenContext.cpp | 80 +-- Sources.cmake | 3 + fuzz/proto.cpp | 2 +- tests/CodeAllocator.test.cpp | 4 +- tests/Generalization.test.cpp | 119 +++++ tests/NonStrictTypeChecker.test.cpp | 4 - tests/Parser.test.cpp | 7 - tests/SharedCodeAllocator.test.cpp | 12 - tests/ToString.test.cpp | 2 - tests/TypeFamily.test.cpp | 15 + tests/TypeInfer.functions.test.cpp | 36 ++ tests/TypeInfer.tables.test.cpp | 19 + tests/Unifier2.test.cpp | 63 --- tests/main.cpp | 4 +- tools/faillist.txt | 2 - 32 files changed, 867 insertions(+), 1125 deletions(-) create mode 100644 Analysis/include/Luau/Generalization.h create mode 100644 Analysis/src/Generalization.cpp create mode 100644 tests/Generalization.test.cpp diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 6cb4b6d6..ed5e17e2 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -373,7 +373,8 @@ private: */ std::vector> 
getExpectedCallTypesForFunctionOverloads(const TypeId fnType); - TypeId createFamilyInstance(TypeFamilyInstanceType instance, const ScopePtr& scope, Location location); + TypeId createTypeFamilyInstance( + const TypeFamily& family, std::vector typeArguments, std::vector packArguments, const ScopePtr& scope, Location location); }; /** Borrow a vector of pointers from a vector of owning pointers to constraints. diff --git a/Analysis/include/Luau/Generalization.h b/Analysis/include/Luau/Generalization.h new file mode 100644 index 00000000..bf196f3e --- /dev/null +++ b/Analysis/include/Luau/Generalization.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Scope.h" +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ + +std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, TypeId ty); + +} diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index c559a256..cbd027bb 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -414,7 +414,7 @@ void ConstraintGenerator::computeRefinement(const ScopePtr& scope, Location loca discriminantTy = arena->addType(NegationType{discriminantTy}); if (eq) - discriminantTy = arena->addTypeFamily(kBuiltinTypeFamilies.singletonFamily, {discriminantTy}); + discriminantTy = createTypeFamilyInstance(kBuiltinTypeFamilies.singletonFamily, {discriminantTy}, {}, scope, location); for (const RefinementKey* key = proposition->key; key; key = key->parent) { @@ -526,13 +526,7 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat { if (mustDeferIntersection(ty) || mustDeferIntersection(dt)) { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.refineFamily}, - {ty, dt}, - {}, - }, - scope, location); + TypeId resultType = 
createTypeFamilyInstance(kBuiltinTypeFamilies.refineFamily, {ty, dt}, {}, scope, location); ty = resultType; } @@ -2009,35 +2003,17 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprUnary* unary) { case AstExprUnary::Op::Not: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.notFamily}, - {operandType}, - {}, - }, - scope, unary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.notFamily, {operandType}, {}, scope, unary->location); return Inference{resultType, refinementArena.negation(refinement)}; } case AstExprUnary::Op::Len: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.lenFamily}, - {operandType}, - {}, - }, - scope, unary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.lenFamily, {operandType}, {}, scope, unary->location); return Inference{resultType, refinementArena.negation(refinement)}; } case AstExprUnary::Op::Minus: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unmFamily}, - {operandType}, - {}, - }, - scope, unary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.unmFamily, {operandType}, {}, scope, unary->location); return Inference{resultType, refinementArena.negation(refinement)}; } default: // msvc can't prove that this is exhaustive. 
@@ -2053,168 +2029,96 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar { case AstExprBinary::Op::Add: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.addFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.addFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Sub: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.subFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.subFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Mul: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.mulFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.mulFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Div: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.divFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.divFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::FloorDiv: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.idivFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.idivFamily, {leftType, 
rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Pow: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.powFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.powFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Mod: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.modFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.modFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Concat: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.concatFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.concatFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::And: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.andFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.andFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Or: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.orFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.orFamily, {leftType, rightType}, 
{}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareLt: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.ltFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.ltFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareGe: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.ltFamily}, - {rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)` - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.ltFamily, + {rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)` + {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareLe: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.leFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.leFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareGt: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.leFamily}, - {rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)` - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.leFamily, + {rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)` + {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareEq: case 
AstExprBinary::Op::CompareNe: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.eqFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + DefId leftDef = dfg->getDef(binary->left); + DefId rightDef = dfg->getDef(binary->right); + bool leftSubscripted = containsSubscriptedDefinition(leftDef); + bool rightSubscripted = containsSubscriptedDefinition(rightDef); + + if (leftSubscripted && rightSubscripted) + { + // we cannot add nil in this case because then we will blindly accept comparisons that we should not. + } + else if (leftSubscripted) + leftType = makeUnion(scope, binary->location, leftType, builtinTypes->nilType); + else if (rightSubscripted) + rightType = makeUnion(scope, binary->location, rightType, builtinTypes->nilType); + + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.eqFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Op__Count: @@ -3290,26 +3194,14 @@ void ConstraintGenerator::reportCodeTooComplex(Location location) TypeId ConstraintGenerator::makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, - {lhs, rhs}, - {}, - }, - scope, location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.unionFamily, {lhs, rhs}, {}, scope, location); return resultType; } TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.intersectFamily}, - {lhs, rhs}, - {}, - }, - scope, location); + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.intersectFamily, {lhs, rhs}, {}, scope, location); return resultType; } @@ -3387,13 +3279,7 @@ void 
ConstraintGenerator::fillInInferredBindings(const ScopePtr& globalScope, As scope->bindings[symbol] = Binding{tys.front(), location}; else { - TypeId ty = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, - std::move(tys), - {}, - }, - globalScope, location); + TypeId ty = createTypeFamilyInstance(kBuiltinTypeFamilies.unionFamily, std::move(tys), {}, globalScope, location); scope->bindings[symbol] = Binding{ty, location}; } @@ -3463,9 +3349,10 @@ std::vector> ConstraintGenerator::getExpectedCallTypesForF return expectedTypes; } -TypeId ConstraintGenerator::createFamilyInstance(TypeFamilyInstanceType instance, const ScopePtr& scope, Location location) +TypeId ConstraintGenerator::createTypeFamilyInstance( + const TypeFamily& family, std::vector typeArguments, std::vector packArguments, const ScopePtr& scope, Location location) { - TypeId result = arena->addType(std::move(instance)); + TypeId result = arena->addTypeFamily(family, typeArguments, packArguments); addConstraint(scope, location, ReduceConstraint{result}); return result; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index cdb13b4a..e35ddf0e 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -5,6 +5,7 @@ #include "Luau/Common.h" #include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" +#include "Luau/Generalization.h" #include "Luau/Instantiation.h" #include "Luau/Instantiation2.h" #include "Luau/Location.h" @@ -577,9 +578,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull generalized; - Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}}; - - std::optional generalizedTy = u2.generalize(c.sourceType); + std::optional generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, c.sourceType); if (generalizedTy) generalized = QuantifierResult{*generalizedTy}; // FIXME insertedGenerics and 
insertedGenericPacks else @@ -609,7 +608,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullscope, ty); unblock(ty, constraint->location); } @@ -682,7 +681,16 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull(nextTy)) - return block_(nextTy); + { + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); + getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; + + pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{nextTy, tableTy}); + pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, arena->addTypePack({keyTy, valueTy}), /*resultIsLValue=*/true}); + return true; + } if (get(nextTy)) { @@ -1924,24 +1932,19 @@ bool ConstraintSolver::tryDispatch(const EqualityConstraint& c, NotNull constraint, bool force) { - auto block_ = [&](auto&& t) { - if (force) - { - // TODO: I believe it is the case that, if we are asked to force - // this constraint, then we can do nothing but fail. I'd like to - // find a code sample that gets here. - LUAU_ASSERT(false); - } - else - block(t, constraint); - return false; - }; - - // We may have to block here if we don't know what the iteratee type is, - // if it's a free table, if we don't know it has a metatable, and so on. 
iteratorTy = follow(iteratorTy); + if (get(iteratorTy)) - return block_(iteratorTy); + { + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); + getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; + + pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{iteratorTy, tableTy}); + pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, arena->addTypePack({keyTy, valueTy}), /*resultIsLValue=*/true}); + return true; + } auto unpack = [&](TypeId ty) { TypePackId variadic = arena->addTypePack(VariadicTypePack{ty}); @@ -2752,15 +2755,15 @@ void ConstraintSolver::shiftReferences(TypeId source, TypeId target) std::optional ConstraintSolver::generalizeFreeType(NotNull scope, TypeId type) { - if (get(type)) + TypeId t = follow(type); + if (get(t)) { - auto refCount = unresolvedConstraints.find(type); + auto refCount = unresolvedConstraints.find(t); if (!refCount || *refCount > 1) return {}; } - Unifier2 u2{NotNull{arena}, builtinTypes, scope, NotNull{&iceReporter}}; - return u2.generalize(type); + return generalize(NotNull{arena}, builtinTypes, scope, type); } bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 33b41698..0a0a64d3 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -763,7 +763,8 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) for (AstExpr* arg : c->args) visitExpr(scope, arg); - return {defArena->freshCell(), nullptr}; + // calls should be treated as subscripted. 
+ return {defArena->freshCell(/* subscripted */ true), nullptr}; } DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 4fe7c4b7..78b76a78 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,7 +2,6 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false); -LUAU_FASTFLAG(LuauCheckedFunctionSyntax); namespace Luau { @@ -452,7 +451,7 @@ std::string getBuiltinDefinitionSource() std::string result = kBuiltinDefinitionLuaSrc; // Annotates each non generic function as checked - if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauCheckedFunctionSyntax) + if (FFlag::LuauCheckedEmbeddedDefinitions2) result = kBuiltinDefinitionLuaSrcChecked; return result; diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp new file mode 100644 index 00000000..081ea153 --- /dev/null +++ b/Analysis/src/Generalization.cpp @@ -0,0 +1,526 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Generalization.h" + +#include "Luau/Scope.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypePack.h" +#include "Luau/VisitType.h" + +namespace Luau +{ + +struct MutatingGeneralizer : TypeOnceVisitor +{ + NotNull builtinTypes; + + NotNull scope; + DenseHashMap positiveTypes; + DenseHashMap negativeTypes; + std::vector generics; + std::vector genericPacks; + + bool isWithinFunction = false; + + MutatingGeneralizer(NotNull builtinTypes, NotNull scope, DenseHashMap positiveTypes, + DenseHashMap negativeTypes) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , builtinTypes(builtinTypes) + , scope(scope) + , positiveTypes(std::move(positiveTypes)) + , negativeTypes(std::move(negativeTypes)) + { + } + + static void 
replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) + { + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + seen.insert(haystack); + + if (UnionType* ut = getMutable(haystack)) + { + for (auto iter = ut->options.begin(); iter != ut->options.end();) + { + // FIXME: I bet this function has reentrancy problems + TypeId option = follow(*iter); + + if (option == needle && get(replacement)) + { + iter = ut->options.erase(iter); + continue; + } + + if (option == needle) + { + *iter = replacement; + iter++; + continue; + } + + // advance the iterator, nothing after this can use it. + iter++; + + if (seen.find(option)) + continue; + seen.insert(option); + + if (get(option)) + replace(seen, option, needle, haystack); + else if (get(option)) + replace(seen, option, needle, haystack); + } + + if (ut->options.size() == 1) + { + TypeId onlyType = ut->options[0]; + LUAU_ASSERT(onlyType != haystack); + emplaceType(asMutable(haystack), onlyType); + } + + return; + } + + if (IntersectionType* it = getMutable(needle)) + { + for (auto iter = it->parts.begin(); iter != it->parts.end();) + { + // FIXME: I bet this function has reentrancy problems + TypeId part = follow(*iter); + + if (part == needle && get(replacement)) + { + iter = it->parts.erase(iter); + continue; + } + + if (part == needle) + { + *iter = replacement; + iter++; + continue; + } + + // advance the iterator, nothing after this can use it. 
+ iter++; + + if (seen.find(part)) + continue; + seen.insert(part); + + if (get(part)) + replace(seen, part, needle, haystack); + else if (get(part)) + replace(seen, part, needle, haystack); + } + + if (it->parts.size() == 1) + { + TypeId onlyType = it->parts[0]; + LUAU_ASSERT(onlyType != needle); + emplaceType(asMutable(needle), onlyType); + } + + return; + } + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + const bool oldValue = isWithinFunction; + + isWithinFunction = true; + + traverse(ft.argTypes); + traverse(ft.retTypes); + + isWithinFunction = oldValue; + + return false; + } + + bool visit(TypeId ty, const FreeType&) override + { + const FreeType* ft = get(ty); + LUAU_ASSERT(ft); + + traverse(ft->lowerBound); + traverse(ft->upperBound); + + // It is possible for the above traverse() calls to cause ty to be + // transmuted. We must reacquire ft if this happens. + ty = follow(ty); + ft = get(ty); + if (!ft) + return false; + + const size_t positiveCount = getCount(positiveTypes, ty); + const size_t negativeCount = getCount(negativeTypes, ty); + + if (!positiveCount && !negativeCount) + return false; + + const bool hasLowerBound = !get(follow(ft->lowerBound)); + const bool hasUpperBound = !get(follow(ft->upperBound)); + + DenseHashSet seen{nullptr}; + seen.insert(ty); + + if (!hasLowerBound && !hasUpperBound) + { + if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + + // It is possible that this free type has other free types in its upper + // or lower bounds. If this is the case, we must replace those + // references with never (for the lower bound) or unknown (for the upper + // bound). + // + // If we do not do this, we get tautological bounds like a <: a <: unknown. 
+ else if (positiveCount && !hasUpperBound) + { + TypeId lb = follow(ft->lowerBound); + if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == ty) + lowerFree->upperBound = builtinTypes->unknownType; + else + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, lb, ty, builtinTypes->unknownType); + } + + if (lb != ty) + emplaceType(asMutable(ty), lb); + else if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + // if the lower bound is the type in question, we don't actually have a lower bound. + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + else + { + TypeId ub = follow(ft->upperBound); + if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) + upperFree->lowerBound = builtinTypes->neverType; + else + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, ub, ty, builtinTypes->neverType); + } + + if (ub != ty) + emplaceType(asMutable(ty), ub); + else if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + // if the upper bound is the type in question, we don't actually have an upper bound. + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + + return false; + } + + size_t getCount(const DenseHashMap& map, const void* ty) + { + if (const size_t* count = map.find(ty)) + return *count; + else + return 0; + } + + bool visit(TypeId ty, const TableType&) override + { + const size_t positiveCount = getCount(positiveTypes, ty); + const size_t negativeCount = getCount(negativeTypes, ty); + + // FIXME: Free tables should probably just be replaced by upper bounds on free types. 
+ // + // eg never <: 'a <: {x: number} & {z: boolean} + + if (!positiveCount && !negativeCount) + return true; + + TableType* tt = getMutable(ty); + LUAU_ASSERT(tt); + + tt->state = TableState::Sealed; + + return true; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (!subsumes(scope, ftp.scope)) + return true; + + tp = follow(tp); + + const size_t positiveCount = getCount(positiveTypes, tp); + const size_t negativeCount = getCount(negativeTypes, tp); + + if (1 == positiveCount + negativeCount) + emplaceTypePack(asMutable(tp), builtinTypes->unknownTypePack); + else + { + emplaceTypePack(asMutable(tp), scope); + genericPacks.push_back(tp); + } + + return true; + } +}; + +struct FreeTypeSearcher : TypeVisitor +{ + NotNull scope; + + explicit FreeTypeSearcher(NotNull scope) + : TypeVisitor(/*skipBoundTypes*/ true) + , scope(scope) + { + } + + enum Polarity + { + Positive, + Negative, + Both, + }; + + Polarity polarity = Positive; + + void flip() + { + switch (polarity) + { + case Positive: + polarity = Negative; + break; + case Negative: + polarity = Positive; + break; + case Both: + break; + } + } + + DenseHashSet seenPositive{nullptr}; + DenseHashSet seenNegative{nullptr}; + + bool seenWithPolarity(const void* ty) + { + switch (polarity) + { + case Positive: + { + if (seenPositive.contains(ty)) + return true; + + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; + + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; + + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } + } + + return false; + } + + // The keys in these maps are either TypeIds or TypePackIds. It's safe to + // mix them because we only use these pointers as unique keys. We never + // indirect them. 
+ DenseHashMap negativeTypes{0}; + DenseHashMap positiveTypes{0}; + + bool visit(TypeId ty) override + { + if (seenWithPolarity(ty)) + return false; + + LUAU_ASSERT(ty); + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + if (seenWithPolarity(ty)) + return false; + + if (!subsumes(scope, ft.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + + return true; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (seenWithPolarity(ty)) + return false; + + if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) + { + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + } + + for (const auto& [_name, prop] : tt.props) + { + if (prop.isReadOnly()) + traverse(*prop.readTy); + else + { + LUAU_ASSERT(prop.isShared()); + + Polarity p = polarity; + polarity = Both; + traverse(prop.type()); + polarity = p; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + traverse(tt.indexer->indexResultType); + } + + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (seenWithPolarity(ty)) + return false; + + flip(); + traverse(ft.argTypes); + flip(); + + traverse(ft.retTypes); + + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (seenWithPolarity(tp)) + return false; + + if (!subsumes(scope, ftp.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[tp]++; + break; + case Negative: + negativeTypes[tp]++; + break; + case Both: + positiveTypes[tp]++; + negativeTypes[tp]++; + break; + } + + return true; + } +}; 
+ + +std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, TypeId ty) +{ + ty = follow(ty); + + if (ty->owningArena != arena || ty->persistent) + return ty; + + if (const FunctionType* ft = get(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty())) + return ty; + + FreeTypeSearcher fts{scope}; + fts.traverse(ty); + + MutatingGeneralizer gen{builtinTypes, scope, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; + + gen.traverse(ty); + + /* MutatingGeneralizer mutates types in place, so it is possible that ty has + * been transmuted to a BoundType. We must follow it again and verify that + * we are allowed to mutate it before we attach generics to it. + */ + ty = follow(ty); + + if (ty->owningArena != arena || ty->persistent) + return ty; + + FunctionType* ftv = getMutable(ty); + if (ftv) + { + ftv->generics = std::move(gen.generics); + ftv->genericPacks = std::move(gen.genericPacks); + } + + return ty; +} + +} // namespace Luau diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index faa5ffdb..d0d37127 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1591,7 +1591,6 @@ struct TypeChecker2 functionDeclStack.push_back(inferredFnTy); std::shared_ptr normalizedFnTy = normalizer.normalize(inferredFnTy); - const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); if (!normalizedFnTy) { reportError(CodeTooComplex{}, fn->location); @@ -1686,16 +1685,23 @@ struct TypeChecker2 if (fn->returnAnnotation) visit(*fn->returnAnnotation); + // If the function type has a family annotation, we need to see if we can suggest an annotation - TypeFamilyReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}}; - for (TypeId retTy : inferredFtv->retTypes) + if (normalizedFnTy) { - if (get(follow(retTy))) + const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); + LUAU_ASSERT(inferredFtv); + + 
TypeFamilyReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}}; + for (TypeId retTy : inferredFtv->retTypes) { - TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy); - if (result.shouldRecommendAnnotation) - reportError( - ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, fn->location); + if (get(follow(retTy))) + { + TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy); + if (result.shouldRecommendAnnotation) + reportError(ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, + fn->location); + } } } diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 7fac35c9..e336a5cd 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -1507,7 +1507,7 @@ TypeFamilyReductionResult singletonFamilyFn(TypeId instance, NotNull(followed)) - followed = follow(negation->ty); + followed = follow(negation->ty); // if we have a singleton type or `nil`, which is its own singleton type... 
if (get(followed) || isNil(followed)) diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 1e90c0e8..c8db5335 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -825,315 +825,6 @@ struct FreeTypeSearcher : TypeVisitor } }; -struct MutatingGeneralizer : TypeOnceVisitor -{ - NotNull builtinTypes; - - NotNull scope; - DenseHashMap positiveTypes; - DenseHashMap negativeTypes; - std::vector generics; - std::vector genericPacks; - - bool isWithinFunction = false; - - MutatingGeneralizer(NotNull builtinTypes, NotNull scope, DenseHashMap positiveTypes, - DenseHashMap negativeTypes) - : TypeOnceVisitor(/* skipBoundTypes */ true) - , builtinTypes(builtinTypes) - , scope(scope) - , positiveTypes(std::move(positiveTypes)) - , negativeTypes(std::move(negativeTypes)) - { - } - - static void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) - { - haystack = follow(haystack); - - if (seen.find(haystack)) - return; - seen.insert(haystack); - - if (UnionType* ut = getMutable(haystack)) - { - for (auto iter = ut->options.begin(); iter != ut->options.end();) - { - // FIXME: I bet this function has reentrancy problems - TypeId option = follow(*iter); - - if (option == needle && get(replacement)) - { - iter = ut->options.erase(iter); - continue; - } - - if (option == needle) - { - *iter = replacement; - iter++; - continue; - } - - // advance the iterator, nothing after this can use it. 
- iter++; - - if (seen.find(option)) - continue; - seen.insert(option); - - if (get(option)) - replace(seen, option, needle, haystack); - else if (get(option)) - replace(seen, option, needle, haystack); - } - - if (ut->options.size() == 1) - { - TypeId onlyType = ut->options[0]; - LUAU_ASSERT(onlyType != haystack); - emplaceType(asMutable(haystack), onlyType); - } - - return; - } - - if (IntersectionType* it = getMutable(needle)) - { - for (auto iter = it->parts.begin(); iter != it->parts.end();) - { - // FIXME: I bet this function has reentrancy problems - TypeId part = follow(*iter); - - if (part == needle && get(replacement)) - { - iter = it->parts.erase(iter); - continue; - } - - if (part == needle) - { - *iter = replacement; - iter++; - continue; - } - - // advance the iterator, nothing after this can use it. - iter++; - - if (seen.find(part)) - continue; - seen.insert(part); - - if (get(part)) - replace(seen, part, needle, haystack); - else if (get(part)) - replace(seen, part, needle, haystack); - } - - if (it->parts.size() == 1) - { - TypeId onlyType = it->parts[0]; - LUAU_ASSERT(onlyType != needle); - emplaceType(asMutable(needle), onlyType); - } - - return; - } - } - - bool visit(TypeId ty, const FunctionType& ft) override - { - const bool oldValue = isWithinFunction; - - isWithinFunction = true; - - traverse(ft.argTypes); - traverse(ft.retTypes); - - isWithinFunction = oldValue; - - return false; - } - - bool visit(TypeId ty, const FreeType&) override - { - const FreeType* ft = get(ty); - LUAU_ASSERT(ft); - - traverse(ft->lowerBound); - traverse(ft->upperBound); - - // It is possible for the above traverse() calls to cause ty to be - // transmuted. We must reacquire ft if this happens. 
- ty = follow(ty); - ft = get(ty); - if (!ft) - return false; - - const size_t positiveCount = getCount(positiveTypes, ty); - const size_t negativeCount = getCount(negativeTypes, ty); - - if (!positiveCount && !negativeCount) - return false; - - const bool hasLowerBound = !get(follow(ft->lowerBound)); - const bool hasUpperBound = !get(follow(ft->upperBound)); - - DenseHashSet seen{nullptr}; - seen.insert(ty); - - if (!hasLowerBound && !hasUpperBound) - { - if (!isWithinFunction || (positiveCount + negativeCount == 1)) - emplaceType(asMutable(ty), builtinTypes->unknownType); - else - { - emplaceType(asMutable(ty), scope); - generics.push_back(ty); - } - } - - // It is possible that this free type has other free types in its upper - // or lower bounds. If this is the case, we must replace those - // references with never (for the lower bound) or unknown (for the upper - // bound). - // - // If we do not do this, we get tautological bounds like a <: a <: unknown. - else if (positiveCount && !hasUpperBound) - { - TypeId lb = follow(ft->lowerBound); - if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == ty) - lowerFree->upperBound = builtinTypes->unknownType; - else - { - DenseHashSet replaceSeen{nullptr}; - replace(replaceSeen, lb, ty, builtinTypes->unknownType); - } - - if (lb != ty) - emplaceType(asMutable(ty), lb); - else if (!isWithinFunction || (positiveCount + negativeCount == 1)) - emplaceType(asMutable(ty), builtinTypes->unknownType); - else - { - // if the lower bound is the type in question, we don't actually have a lower bound. 
- emplaceType(asMutable(ty), scope); - generics.push_back(ty); - } - } - else - { - TypeId ub = follow(ft->upperBound); - if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) - upperFree->lowerBound = builtinTypes->neverType; - else - { - DenseHashSet replaceSeen{nullptr}; - replace(replaceSeen, ub, ty, builtinTypes->neverType); - } - - if (ub != ty) - emplaceType(asMutable(ty), ub); - else if (!isWithinFunction || (positiveCount + negativeCount == 1)) - emplaceType(asMutable(ty), builtinTypes->unknownType); - else - { - // if the upper bound is the type in question, we don't actually have an upper bound. - emplaceType(asMutable(ty), scope); - generics.push_back(ty); - } - } - - return false; - } - - size_t getCount(const DenseHashMap& map, const void* ty) - { - if (const size_t* count = map.find(ty)) - return *count; - else - return 0; - } - - bool visit(TypeId ty, const TableType&) override - { - const size_t positiveCount = getCount(positiveTypes, ty); - const size_t negativeCount = getCount(negativeTypes, ty); - - // FIXME: Free tables should probably just be replaced by upper bounds on free types. 
- // - // eg never <: 'a <: {x: number} & {z: boolean} - - if (!positiveCount && !negativeCount) - return true; - - TableType* tt = getMutable(ty); - LUAU_ASSERT(tt); - - tt->state = TableState::Sealed; - - return true; - } - - bool visit(TypePackId tp, const FreeTypePack& ftp) override - { - if (!subsumes(scope, ftp.scope)) - return true; - - tp = follow(tp); - - const size_t positiveCount = getCount(positiveTypes, tp); - const size_t negativeCount = getCount(negativeTypes, tp); - - if (1 == positiveCount + negativeCount) - emplaceTypePack(asMutable(tp), builtinTypes->unknownTypePack); - else - { - emplaceTypePack(asMutable(tp), scope); - genericPacks.push_back(tp); - } - - return true; - } -}; - -std::optional Unifier2::generalize(TypeId ty) -{ - ty = follow(ty); - - if (ty->owningArena != arena || ty->persistent) - return ty; - - if (const FunctionType* ft = get(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty())) - return ty; - - FreeTypeSearcher fts{scope}; - fts.traverse(ty); - - MutatingGeneralizer gen{builtinTypes, scope, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; - - gen.traverse(ty); - - /* MutatingGeneralizer mutates types in place, so it is possible that ty has - * been transmuted to a BoundType. We must follow it again and verify that - * we are allowed to mutate it before we attach generics to it. 
- */ - ty = follow(ty); - - if (ty->owningArena != arena || ty->persistent) - return ty; - - FunctionType* ftv = getMutable(ty); - if (ftv) - { - ftv->generics = std::move(gen.generics); - ftv->genericPacks = std::move(gen.genericPacks); - } - - return ty; -} - TypeId Unifier2::mkUnion(TypeId left, TypeId right) { left = follow(left); diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 96653a56..71577459 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -8,7 +8,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) -LUAU_FASTFLAGVARIABLE(LuauCheckedFunctionSyntax, false) namespace Luau { @@ -995,17 +994,14 @@ Lexeme Lexer::readNext() } case '@': { - if (FFlag::LuauCheckedFunctionSyntax) - { - // We're trying to lex the token @checked - LUAU_ASSERT(peekch() == '@'); + // We're trying to lex the token @checked + LUAU_ASSERT(peekch() == '@'); - std::pair maybeChecked = readName(); - if (maybeChecked.second != Lexeme::ReservedChecked) - return Lexeme(Location(start, position()), Lexeme::Error); + std::pair maybeChecked = readName(); + if (maybeChecked.second != Lexeme::ReservedChecked) + return Lexeme(Location(start, position()), Lexeme::Error); - return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value); - } + return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value); } default: if (isDigit(peekch())) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 8bbdf307..e26df1fa 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -16,7 +16,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. 
-LUAU_FASTFLAG(LuauCheckedFunctionSyntax) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) namespace Luau @@ -838,7 +837,7 @@ AstStat* Parser::parseDeclaration(const Location& start) { nextLexeme(); bool checkedFunction = false; - if (FFlag::LuauCheckedFunctionSyntax && lexer.current().type == Lexeme::ReservedChecked) + if (lexer.current().type == Lexeme::ReservedChecked) { checkedFunction = true; nextLexeme(); @@ -1731,9 +1730,8 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) { return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}}; } - else if (FFlag::LuauCheckedFunctionSyntax && inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked) + else if (inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked) { - LUAU_ASSERT(FFlag::LuauCheckedFunctionSyntax); nextLexeme(); return parseFunctionType(allowPack, /* isCheckedFunction */ true); } diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 43993231..19a9b3c9 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -12,6 +12,12 @@ struct lua_State; +#if defined(__x86_64__) || defined(_M_X64) +#define CODEGEN_TARGET_X64 +#elif defined(__aarch64__) || defined(_M_ARM64) +#define CODEGEN_TARGET_A64 +#endif + namespace Luau { namespace CodeGen diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index a2f67ebb..900093d1 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -11,8 +11,6 @@ #include -#include - LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow) LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index ca1a489e..b8876054 100644 --- 
a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -7,7 +7,7 @@ #include #include -#if defined(_WIN32) && defined(_M_X64) +#if defined(_WIN32) && defined(CODEGEN_TARGET_X64) #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN @@ -26,7 +26,7 @@ extern "C" void __deregister_frame(const void*) __attribute__((weak)); extern "C" void __unw_add_dynamic_fde() __attribute__((weak)); #endif -#if defined(__APPLE__) && defined(__aarch64__) +#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64) #include #include #include @@ -48,7 +48,7 @@ namespace Luau namespace CodeGen { -#if defined(__APPLE__) && defined(__aarch64__) +#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64) static int findDynamicUnwindSections(uintptr_t addr, unw_dynamic_unwind_sections_t* info) { // Define a minimal mach header for JIT'd code. @@ -109,7 +109,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz char* unwindData = (char*)block; unwind->finalize(unwindData, unwindSize, block, blockSize); -#if defined(_WIN32) && defined(_M_X64) +#if defined(_WIN32) && defined(CODEGEN_TARGET_X64) #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(unwind->getFunctionCount()), uintptr_t(block))) @@ -126,7 +126,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz visitFdeEntries(unwindData, __register_frame); #endif -#if defined(__APPLE__) && defined(__aarch64__) +#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64) // Starting from macOS 14, we need to register unwind section callback to state that our ABI doesn't require pointer authentication // This might conflict with other JITs that do the same; unfortunately this is the best we can do for now. 
static unw_add_find_dynamic_unwind_sections_t unw_add_find_dynamic_unwind_sections = @@ -141,7 +141,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz void destroyBlockUnwindInfo(void* context, void* unwindData) { -#if defined(_WIN32) && defined(_M_X64) +#if defined(_WIN32) && defined(CODEGEN_TARGET_X64) #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData)) @@ -161,12 +161,12 @@ void destroyBlockUnwindInfo(void* context, void* unwindData) bool isUnwindSupported() { -#if defined(_WIN32) && defined(_M_X64) +#if defined(_WIN32) && defined(CODEGEN_TARGET_X64) return true; #elif defined(__ANDROID__) // Current unwind information is not compatible with Android return false; -#elif defined(__APPLE__) && defined(__aarch64__) +#elif defined(__APPLE__) && defined(CODEGEN_TARGET_A64) char ver[256]; size_t verLength = sizeof(ver); // libunwind on macOS 12 and earlier (which maps to osrelease 21) assumes JIT frames use pointer authentication without a way to override that diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 3938ab12..5d6f1fb5 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -27,7 +27,7 @@ #include #include -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) #ifdef _MSC_VER #include // __cpuid #else @@ -35,7 +35,7 @@ #endif #endif -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) #ifdef __APPLE__ #include #endif @@ -58,8 +58,6 @@ LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockLimit, 32'768) // 32 K // Current value is based on some member variables being limited to 16 bits LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockInstructionLimit, 65'536) // 64 K -LUAU_FASTFLAG(LuauCodegenContext) - namespace Luau { namespace CodeGen @@ -97,180 +95,9 @@ std::string toString(const CodeGenCompilationResult& result) return ""; } -static const Instruction kCodeEntryInsn = LOP_NATIVECALL; - void* 
gPerfLogContext = nullptr; PerfLogFn gPerfLogFn = nullptr; -struct OldNativeProto -{ - Proto* p; - void* execdata; - uintptr_t exectarget; -}; - -// Additional data attached to Proto::execdata -// Guaranteed to be aligned to 16 bytes -struct ExtraExecData -{ - size_t execDataSize; - size_t codeSize; -}; - -static int alignTo(int value, int align) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - CODEGEN_ASSERT(align > 0 && (align & (align - 1)) == 0); - return (value + (align - 1)) & ~(align - 1); -} - -// Returns the size of execdata required to store all code offsets and ExtraExecData structure at proper alignment -// Always a multiple of 4 bytes -static int calculateExecDataSize(Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - int size = proto->sizecode * sizeof(uint32_t); - - size = alignTo(size, 16); - size += sizeof(ExtraExecData); - - return size; -} - -// Returns pointer to the ExtraExecData inside the Proto::execdata -// Even though 'execdata' is a field in Proto, we require it to support cases where it's not attached to Proto during construction -ExtraExecData* getExtraExecData(Proto* proto, void* execdata) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - int size = proto->sizecode * sizeof(uint32_t); - - size = alignTo(size, 16); - - return reinterpret_cast(reinterpret_cast(execdata) + size); -} - -static OldNativeProto createOldNativeProto(Proto* proto, const IrBuilder& ir) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - int execDataSize = calculateExecDataSize(proto); - CODEGEN_ASSERT(execDataSize % 4 == 0); - - uint32_t* execData = new uint32_t[execDataSize / 4]; - uint32_t instTarget = ir.function.entryLocation; - - for (int i = 0; i < proto->sizecode; i++) - { - CODEGEN_ASSERT(ir.function.bcMapping[i].asmLocation >= instTarget); - - execData[i] = ir.function.bcMapping[i].asmLocation - instTarget; - } - - // Set first instruction offset to 0 so that entering this function still executes any generated entry code. 
- execData[0] = 0; - - ExtraExecData* extra = getExtraExecData(proto, execData); - memset(extra, 0, sizeof(ExtraExecData)); - - extra->execDataSize = execDataSize; - - // entry target will be relocated when assembly is finalized - return {proto, execData, instTarget}; -} - -static void destroyExecData(void* execdata) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - delete[] static_cast(execdata); -} - -static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - CODEGEN_ASSERT(p->source); - - const char* source = getstr(p->source); - source = (source[0] == '=' || source[0] == '@') ? source + 1 : "[string]"; - - char name[256]; - snprintf(name, sizeof(name), " %s:%d %s", source, p->linedefined, p->debugname ? getstr(p->debugname) : ""); - - if (gPerfLogFn) - gPerfLogFn(gPerfLogContext, addr, size, name); -} - -template -static std::optional createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, - const HostIrHooks& hooks, CodeGenCompilationResult& result) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - IrBuilder ir(hooks); - ir.buildFunctionIr(proto); - - unsigned instCount = unsigned(ir.function.instructions.size()); - - if (totalIrInstCount + instCount >= unsigned(FInt::CodegenHeuristicsInstructionLimit.value)) - { - result = CodeGenCompilationResult::CodeGenOverflowInstructionLimit; - return std::nullopt; - } - totalIrInstCount += instCount; - - if (!lowerFunction(ir, build, helpers, proto, {}, /* stats */ nullptr, result)) - return std::nullopt; - - return createOldNativeProto(proto, ir); -} - -static NativeState* getNativeState(lua_State* L) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - return static_cast(L->global->ecb.context); -} - -static void onCloseState(lua_State* L) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - delete getNativeState(L); - L->global->ecb = lua_ExecutionCallbacks(); -} - -static void 
onDestroyFunction(lua_State* L, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - destroyExecData(proto->execdata); - proto->execdata = nullptr; - proto->exectarget = 0; - proto->codeentry = proto->code; -} - -static int onEnter(lua_State* L, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - NativeState* data = getNativeState(L); - - CODEGEN_ASSERT(proto->execdata); - CODEGEN_ASSERT(L->ci->savedpc >= proto->code && L->ci->savedpc < proto->code + proto->sizecode); - - uintptr_t target = proto->exectarget + static_cast(proto->execdata)[L->ci->savedpc - proto->code]; - - // Returns 1 to finish the function in the VM - return GateFn(data->context.gateEntry)(L, proto, target, &data->context); -} - -// used to disable native execution, unconditionally -static int onEnterDisabled(lua_State* L, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - return 1; -} void onDisable(lua_State* L, Proto* proto) { @@ -311,18 +138,7 @@ void onDisable(lua_State* L, Proto* proto) }); } -static size_t getMemorySize(lua_State* L, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - ExtraExecData* extra = getExtraExecData(proto, proto->execdata); - - // While execDataSize is exactly the size of the allocation we made and hold for 'execdata' field, the code size is approximate - // This is because code+data page is shared and owned by all Proto from a single module and each one can keep the whole region alive - // So individual Proto being freed by GC will not reflect memory use by native code correctly - return extra->execDataSize + extra->codeSize; -} - -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) unsigned int getCpuFeaturesA64() { unsigned int result = 0; @@ -358,7 +174,7 @@ bool isSupported() return false; #endif -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) int cpuinfo[4] = {}; #ifdef _MSC_VER __cpuid(cpuinfo, 1); @@ -373,287 +189,58 @@ bool isSupported() return false; return true; 
-#elif defined(__aarch64__) +#elif defined(CODEGEN_TARGET_A64) return true; #else return false; #endif } -static void create_OLD(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - CODEGEN_ASSERT(isSupported()); - - std::unique_ptr data = std::make_unique(allocationCallback, allocationCallbackContext); - -#if defined(_WIN32) - data->unwindBuilder = std::make_unique(); -#else - data->unwindBuilder = std::make_unique(); -#endif - - data->codeAllocator.context = data->unwindBuilder.get(); - data->codeAllocator.createBlockUnwindInfo = createBlockUnwindInfo; - data->codeAllocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; - - initFunctions(*data); - -#if defined(__x86_64__) || defined(_M_X64) - if (!X64::initHeaderFunctions(*data)) - return; -#elif defined(__aarch64__) - if (!A64::initHeaderFunctions(*data)) - return; -#endif - - if (gPerfLogFn) - gPerfLogFn(gPerfLogContext, uintptr_t(data->context.gateEntry), 4096, ""); - - lua_ExecutionCallbacks* ecb = &L->global->ecb; - - ecb->context = data.release(); - ecb->close = onCloseState; - ecb->destroy = onDestroyFunction; - ecb->enter = onEnter; - ecb->disable = onDisable; - ecb->getmemorysize = getMemorySize; -} - void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) { - if (FFlag::LuauCodegenContext) - { - create_NEW(L, allocationCallback, allocationCallbackContext); - } - else - { - create_OLD(L, allocationCallback, allocationCallbackContext); - } + create_NEW(L, allocationCallback, allocationCallbackContext); } void create(lua_State* L) { - if (FFlag::LuauCodegenContext) - { - create_NEW(L); - } - else - { - create(L, nullptr, nullptr); - } + create_NEW(L); } void create(lua_State* L, SharedCodeGenContext* codeGenContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - create_NEW(L, codeGenContext); } [[nodiscard]] bool isNativeExecutionEnabled(lua_State* L) { - if 
(FFlag::LuauCodegenContext) - { - return isNativeExecutionEnabled_NEW(L); - } - else - { - return getNativeState(L) ? (L->global->ecb.enter == onEnter) : false; - } + return isNativeExecutionEnabled_NEW(L); } void setNativeExecutionEnabled(lua_State* L, bool enabled) { - if (FFlag::LuauCodegenContext) - { - setNativeExecutionEnabled_NEW(L, enabled); - } - else - { - if (getNativeState(L)) - L->global->ecb.enter = enabled ? onEnter : onEnterDisabled; - } -} - -static CompilationResult compile_OLD(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) -{ - CompilationResult compilationResult; - - CODEGEN_ASSERT(lua_isLfunction(L, idx)); - const TValue* func = luaA_toobject(L, idx); - - Proto* root = clvalue(func)->l.p; - - if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) - { - compilationResult.result = CodeGenCompilationResult::NotNativeModule; - return compilationResult; - } - - // If initialization has failed, do not compile any functions - NativeState* data = getNativeState(L); - if (!data) - { - compilationResult.result = CodeGenCompilationResult::CodeGenNotInitialized; - return compilationResult; - } - - std::vector protos; - gatherFunctions(protos, root, options.flags); - - // Skip protos that have been compiled during previous invocations of CodeGen::compile - protos.erase(std::remove_if(protos.begin(), protos.end(), - [](Proto* p) { - return p == nullptr || p->execdata != nullptr; - }), - protos.end()); - - if (protos.empty()) - { - compilationResult.result = CodeGenCompilationResult::NothingToCompile; - return compilationResult; - } - - if (stats != nullptr) - stats->functionsTotal = uint32_t(protos.size()); - -#if defined(__aarch64__) - static unsigned int cpuFeatures = getCpuFeaturesA64(); - A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); -#else - X64::AssemblyBuilderX64 build(/* logText= */ false); -#endif - - ModuleHelpers helpers; -#if defined(__aarch64__) - 
A64::assembleHelpers(build, helpers); -#else - X64::assembleHelpers(build, helpers); -#endif - - std::vector results; - results.reserve(protos.size()); - - uint32_t totalIrInstCount = 0; - - for (Proto* p : protos) - { - CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success; - - if (std::optional np = createNativeFunction(build, helpers, p, totalIrInstCount, options.hooks, protoResult)) - results.push_back(*np); - else - compilationResult.protoFailures.push_back({protoResult, p->debugname ? getstr(p->debugname) : "", p->linedefined}); - } - - // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module - if (!build.finalize()) - { - for (OldNativeProto result : results) - destroyExecData(result.execdata); - - compilationResult.result = CodeGenCompilationResult::CodeGenAssemblerFinalizationFailure; - return compilationResult; - } - - // If no functions were assembled, we don't need to allocate/copy executable pages for helpers - if (results.empty()) - return compilationResult; - - uint8_t* nativeData = nullptr; - size_t sizeNativeData = 0; - uint8_t* codeStart = nullptr; - if (!data->codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), - int(build.code.size() * sizeof(build.code[0])), nativeData, sizeNativeData, codeStart)) - { - for (OldNativeProto result : results) - destroyExecData(result.execdata); - - compilationResult.result = CodeGenCompilationResult::AllocationFailed; - return compilationResult; - } - - if (gPerfLogFn && results.size() > 0) - gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); - - for (size_t i = 0; i < results.size(); ++i) - { - uint32_t begin = uint32_t(results[i].exectarget); - uint32_t end = i + 1 < results.size() ? 
uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); - CODEGEN_ASSERT(begin < end); - - if (gPerfLogFn) - logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); - - ExtraExecData* extra = getExtraExecData(results[i].p, results[i].execdata); - extra->codeSize = end - begin; - } - - for (const OldNativeProto& result : results) - { - // the memory is now managed by VM and will be freed via onDestroyFunction - result.p->execdata = result.execdata; - result.p->exectarget = uintptr_t(codeStart) + result.exectarget; - result.p->codeentry = &kCodeEntryInsn; - } - - if (stats != nullptr) - { - for (const OldNativeProto& result : results) - { - stats->bytecodeSizeBytes += result.p->sizecode * sizeof(Instruction); - - // Account for the native -> bytecode instruction offsets mapping: - stats->nativeMetadataSizeBytes += result.p->sizecode * sizeof(uint32_t); - } - - stats->functionsCompiled += uint32_t(results.size()); - stats->nativeCodeSizeBytes += build.code.size(); - stats->nativeDataSizeBytes += build.data.size(); - } - - return compilationResult; + setNativeExecutionEnabled_NEW(L, enabled); } CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) { Luau::CodeGen::CompilationOptions options{flags}; - if (FFlag::LuauCodegenContext) - { - return compile_NEW(L, idx, options, stats); - } - else - { - return compile_OLD(L, idx, options, stats); - } + return compile_NEW(L, idx, options, stats); } CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - Luau::CodeGen::CompilationOptions options{flags}; return compile_NEW(moduleId, L, idx, options, stats); } CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { - if (FFlag::LuauCodegenContext) - { - return compile_NEW(L, idx, options, stats); - } - else - { - return 
compile_OLD(L, idx, options, stats); - } + return compile_NEW(L, idx, options, stats); } CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return compile_NEW(moduleId, L, idx, options, stats); } diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index e9402426..ce3a57bd 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -279,7 +279,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A return build.text; } -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) unsigned int getCpuFeaturesA64(); #endif @@ -292,7 +292,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options, Lowering { case AssemblyOptions::Host: { -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) static unsigned int cpuFeatures = getCpuFeaturesA64(); A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, cpuFeatures); #else diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index cb542036..cdffb123 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -12,8 +12,6 @@ #include "lapi.h" - -LUAU_FASTFLAGVARIABLE(LuauCodegenContext, false) LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false) LUAU_FASTINT(LuauCodeGenBlockSize) @@ -34,7 +32,6 @@ unsigned int getCpuFeaturesA64(); static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(p->source); const char* source = getstr(p->source); @@ -50,8 +47,6 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) static void logPerfFunctions( const std::vector& moduleProtos, const uint8_t* nativeModuleBaseAddress, const std::vector& nativeProtos) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - if (gPerfLogFn == nullptr) return; @@ -83,8 +78,6 @@ 
static void logPerfFunctions( template [[nodiscard]] static uint32_t bindNativeProtos(const std::vector& moduleProtos, NativeProtosVector& nativeProtos) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - uint32_t protosBound = 0; auto protoIt = moduleProtos.begin(); @@ -125,7 +118,6 @@ template BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) : codeAllocator{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext} { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(isSupported()); #if defined(_WIN32) @@ -143,12 +135,10 @@ BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, Al [[nodiscard]] bool BaseCodeGenContext::initHeaderFunctions() { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) if (!X64::initHeaderFunctions(*this)) return false; -#elif defined(__aarch64__) +#elif defined(CODEGEN_TARGET_A64) if (!A64::initHeaderFunctions(*this)) return false; #endif @@ -164,13 +154,10 @@ StandaloneCodeGenContext::StandaloneCodeGenContext( size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) : BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext} { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); } [[nodiscard]] std::optional StandaloneCodeGenContext::tryBindExistingModule(const ModuleId&, const std::vector&) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - // The StandaloneCodeGenContext does not support sharing of native code return {}; } @@ -178,8 +165,6 @@ StandaloneCodeGenContext::StandaloneCodeGenContext( [[nodiscard]] ModuleBindResult StandaloneCodeGenContext::bindModule(const std::optional&, const std::vector& moduleProtos, std::vector nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) { - 
CODEGEN_ASSERT(FFlag::LuauCodegenContext); - uint8_t* nativeData = nullptr; size_t sizeNativeData = 0; uint8_t* codeStart = nullptr; @@ -205,8 +190,6 @@ StandaloneCodeGenContext::StandaloneCodeGenContext( void StandaloneCodeGenContext::onCloseState() noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - // The StandaloneCodeGenContext is owned by the one VM that owns it, so when // that VM is destroyed, we destroy *this as well: delete this; @@ -214,8 +197,6 @@ void StandaloneCodeGenContext::onCloseState() noexcept void StandaloneCodeGenContext::onDestroyFunction(void* execdata) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - destroyNativeProtoExecData(static_cast(execdata)); } @@ -225,14 +206,11 @@ SharedCodeGenContext::SharedCodeGenContext( : BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext} , sharedAllocator{&codeAllocator} { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); } [[nodiscard]] std::optional SharedCodeGenContext::tryBindExistingModule( const ModuleId& moduleId, const std::vector& moduleProtos) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - NativeModuleRef nativeModule = sharedAllocator.tryGetNativeModule(moduleId); if (nativeModule.empty()) { @@ -249,8 +227,6 @@ SharedCodeGenContext::SharedCodeGenContext( [[nodiscard]] ModuleBindResult SharedCodeGenContext::bindModule(const std::optional& moduleId, const std::vector& moduleProtos, std::vector nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - const std::pair insertionResult = [&]() -> std::pair { if (moduleId.has_value()) { @@ -279,8 +255,6 @@ SharedCodeGenContext::SharedCodeGenContext( void SharedCodeGenContext::onCloseState() noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - // The lifetime of the SharedCodeGenContext is managed separately from the // VMs that use it. When a VM is destroyed, we don't need to do anything // here. 
@@ -288,23 +262,17 @@ void SharedCodeGenContext::onCloseState() noexcept void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - getNativeProtoExecDataHeader(static_cast(execdata)).nativeModule->release(); } [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext() { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return createSharedCodeGenContext(size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr); } [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return createSharedCodeGenContext( size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext); } @@ -312,8 +280,6 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext( size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - UniqueSharedCodeGenContext codeGenContext{new SharedCodeGenContext{blockSize, maxTotalSize, nullptr, nullptr}}; if (!codeGenContext->initHeaderFunctions()) @@ -324,38 +290,28 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - delete codeGenContext; } void SharedCodeGenContextDeleter::operator()(const SharedCodeGenContext* codeGenContext) const noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - destroySharedCodeGenContext(codeGenContext); } [[nodiscard]] static BaseCodeGenContext* getCodeGenContext(lua_State* L) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return static_cast(L->global->ecb.context); } static void 
onCloseState(lua_State* L) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - getCodeGenContext(L)->onCloseState(); L->global->ecb = lua_ExecutionCallbacks{}; } static void onDestroyFunction(lua_State* L, Proto* proto) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - getCodeGenContext(L)->onDestroyFunction(proto->execdata); proto->execdata = nullptr; proto->exectarget = 0; @@ -364,8 +320,6 @@ static void onDestroyFunction(lua_State* L, Proto* proto) noexcept static int onEnter(lua_State* L, Proto* proto) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - BaseCodeGenContext* codeGenContext = getCodeGenContext(L); CODEGEN_ASSERT(proto->execdata); @@ -379,8 +333,6 @@ static int onEnter(lua_State* L, Proto* proto) static int onEnterDisabled(lua_State* L, Proto* proto) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return 1; } @@ -389,8 +341,6 @@ void onDisable(lua_State* L, Proto* proto); static size_t getMemorySize(lua_State* L, Proto* proto) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - const NativeProtoExecDataHeader& execDataHeader = getNativeProtoExecDataHeader(static_cast(proto->execdata)); const size_t execDataSize = sizeof(NativeProtoExecDataHeader) + execDataHeader.bytecodeInstructionCount * sizeof(Instruction); @@ -403,7 +353,6 @@ static size_t getMemorySize(lua_State* L, Proto* proto) static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeGenContext) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(!FFlag::LuauCodegenCheckNullContext || codeGenContext != nullptr); lua_ExecutionCallbacks* ecb = &L->global->ecb; @@ -418,22 +367,16 @@ static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeG void create_NEW(lua_State* L) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr); } void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* 
allocationCallbackContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext); } void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - std::unique_ptr codeGenContext = std::make_unique(blockSize, maxTotalSize, allocationCallback, allocationCallbackContext); @@ -445,15 +388,11 @@ void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationC void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - initializeExecutionCallbacks(L, codeGenContext); } [[nodiscard]] static NativeProtoExecDataPtr createNativeProtoExecData(Proto* proto, const IrBuilder& ir) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - NativeProtoExecDataPtr nativeExecData = createNativeProtoExecData(proto->sizecode); uint32_t instTarget = ir.function.entryLocation; @@ -481,8 +420,6 @@ template [[nodiscard]] static NativeProtoExecDataPtr createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, const HostIrHooks& hooks, CodeGenCompilationResult& result) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - IrBuilder ir(hooks); ir.buildFunctionIr(proto); @@ -507,7 +444,6 @@ template [[nodiscard]] static CompilationResult compileInternal( const std::optional& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); @@ -547,7 +483,7 @@ template } } -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) static unsigned int cpuFeatures = getCpuFeaturesA64(); A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); #else @@ -555,7 
+491,7 @@ template #endif ModuleHelpers helpers; -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) A64::assembleHelpers(build, helpers); #else X64::assembleHelpers(build, helpers); @@ -641,29 +577,21 @@ template CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return compileInternal(moduleId, L, idx, options, stats); } CompilationResult compile_NEW(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return compileInternal({}, L, idx, options, stats); } [[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return getCodeGenContext(L) != nullptr && L->global->ecb.enter == onEnter; } void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - if (getCodeGenContext(L) != nullptr) L->global->ecb.enter = enabled ? 
onEnter : onEnterDisabled; } diff --git a/Sources.cmake b/Sources.cmake index 79fad0e4..4c5504b6 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -181,6 +181,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h Analysis/include/Luau/Frontend.h + Analysis/include/Luau/Generalization.h Analysis/include/Luau/GlobalTypes.h Analysis/include/Luau/InsertionOrderedMap.h Analysis/include/Luau/Instantiation.h @@ -251,6 +252,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp + Analysis/src/Generalization.cpp Analysis/src/GlobalTypes.cpp Analysis/src/Instantiation.cpp Analysis/src/Instantiation2.cpp @@ -420,6 +422,7 @@ if(TARGET Luau.UnitTest) tests/Fixture.cpp tests/Fixture.h tests/Frontend.test.cpp + tests/Generalization.test.cpp tests/InsertionOrderedMap.test.cpp tests/Instantiation2.test.cpp tests/IostreamOptional.h diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index ba6fb4c8..d06189a4 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -379,7 +379,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) if (luau_load(globalState, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) { Luau::CodeGen::AssemblyOptions options; - options.flags = Luau::CodeGen::CodeGen_ColdFunctions; + options.compilationOptions.flags = Luau::CodeGen::CodeGen_ColdFunctions; options.outputBinary = true; options.target = kFuzzCodegenTarget; Luau::CodeGen::getAssembly(globalState, -1, options); diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 9d65a5a7..2ac0e5fd 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -253,7 +253,7 @@ TEST_CASE("Dwarf2UnwindCodesA64") CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); } -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) #if defined(_WIN32) // Windows x64 ABI @@ -774,7 +774,7 @@ 
TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") #endif -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) TEST_CASE("GeneratedCodeExecutionA64") { diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp new file mode 100644 index 00000000..8268dde6 --- /dev/null +++ b/tests/Generalization.test.cpp @@ -0,0 +1,119 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Generalization.h" +#include "Luau/Scope.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/Error.h" + +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + +TEST_SUITE_BEGIN("Generalization"); + +struct GeneralizationFixture +{ + TypeArena arena; + BuiltinTypes builtinTypes; + Scope scope{builtinTypes.anyTypePack}; + ToStringOptions opts; + + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + std::pair freshType() + { + FreeType ft{&scope, builtinTypes.neverType, builtinTypes.unknownType}; + + TypeId ty = arena.addType(ft); + FreeType* ftv = getMutable(ty); + REQUIRE(ftv != nullptr); + + return {ty, ftv}; + } + + std::string toString(TypeId ty) + { + return ::Luau::toString(ty, opts); + } + + std::string toString(TypePackId ty) + { + return ::Luau::toString(ty, opts); + } + + std::optional generalize(TypeId ty) + { + return ::Luau::generalize(NotNull{&arena}, NotNull{&builtinTypes}, NotNull{&scope}, ty); + } +}; + +TEST_CASE_FIXTURE(GeneralizationFixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type") +{ + auto [t1, ft1] = freshType(); + auto [t2, ft2] = freshType(); + + // t2 <: t1 <: unknown + // unknown <: t2 <: t1 + + ft1->lowerBound = t2; + ft2->upperBound = t1; + ft2->lowerBound = builtinTypes.unknownType; + + auto t2generalized = generalize(t2); + REQUIRE(t2generalized); + + CHECK(follow(t1) == 
follow(t2)); + + auto t1generalized = generalize(t1); + REQUIRE(t1generalized); + + CHECK(builtinTypes.unknownType == follow(t1)); + CHECK(builtinTypes.unknownType == follow(t2)); +} + +// Same as generalize_a_type_that_is_bounded_by_another_generalizable_type +// except that we generalize the types in the opposite order +TEST_CASE_FIXTURE(GeneralizationFixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type_in_reverse_order") +{ + auto [t1, ft1] = freshType(); + auto [t2, ft2] = freshType(); + + // t2 <: t1 <: unknown + // unknown <: t2 <: t1 + + ft1->lowerBound = t2; + ft2->upperBound = t1; + ft2->lowerBound = builtinTypes.unknownType; + + auto t1generalized = generalize(t1); + REQUIRE(t1generalized); + + CHECK(follow(t1) == follow(t2)); + + auto t2generalized = generalize(t2); + REQUIRE(t2generalized); + + CHECK(builtinTypes.unknownType == follow(t1)); + CHECK(builtinTypes.unknownType == follow(t2)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "dont_traverse_into_class_types_when_generalizing") +{ + auto [propTy, _] = freshType(); + + TypeId cursedClass = arena.addType(ClassType{"Cursed", {{"oh_no", Property::readonly(propTy)}}, std::nullopt, std::nullopt, {}, {}, ""}); + + auto genClass = generalize(cursedClass); + REQUIRE(genClass); + + auto genPropTy = get(*genClass)->props.at("oh_no").readTy; + CHECK(is(*genPropTy)); +} + +TEST_SUITE_END(); diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index d85e46ee..806dac62 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -15,8 +15,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauCheckedFunctionSyntax); - #define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \ do \ { \ @@ -69,7 +67,6 @@ struct NonStrictTypeCheckerFixture : Fixture CheckResult checkNonStrict(const std::string& code) { ScopedFastFlag flags[] = { - {FFlag::LuauCheckedFunctionSyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, true}, }; 
LoadDefinitionFileResult res = loadDefinition(definitions); @@ -80,7 +77,6 @@ struct NonStrictTypeCheckerFixture : Fixture CheckResult checkNonStrictModule(const std::string& moduleName) { ScopedFastFlag flags[] = { - {FFlag::LuauCheckedFunctionSyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, true}, }; LoadDefinitionFileResult res = loadDefinition(definitions); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index e1163a1b..b178f539 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -11,7 +11,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauCheckedFunctionSyntax); LUAU_FASTFLAG(LuauLexerLookaheadRemembersBraceType); LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeLengthLimit); @@ -3051,7 +3050,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_top_level_checked_fn") { ParseOptions opts; opts.allowDeclarationSyntax = true; - ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true}; std::string src = R"BUILTIN_SRC( declare function @checked abs(n: number): number @@ -3071,7 +3069,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member") { ParseOptions opts; opts.allowDeclarationSyntax = true; - ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true}; const std::string src = R"BUILTIN_SRC( declare math : { @@ -3099,7 +3096,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_outside_decl_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; - ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true}; ParseResult pr = tryParse(R"( local @checked = 3 @@ -3113,7 +3109,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_in_and_out_of_decl_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; - ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true}; auto pr = tryParse(R"( local @checked = 3 @@ -3129,7 +3124,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_as_function_name_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; - ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true}; auto pr = 
tryParse(R"( function @checked(x: number) : number @@ -3143,7 +3137,6 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name") { ParseOptions opts; opts.allowDeclarationSyntax = true; - ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true}; auto pr = tryParse(R"( local @blah = 3 diff --git a/tests/SharedCodeAllocator.test.cpp b/tests/SharedCodeAllocator.test.cpp index 30bf1de2..bba8daad 100644 --- a/tests/SharedCodeAllocator.test.cpp +++ b/tests/SharedCodeAllocator.test.cpp @@ -15,8 +15,6 @@ #pragma GCC diagnostic ignored "-Wself-assign-overloaded" #endif -LUAU_FASTFLAG(LuauCodegenContext) - using namespace Luau::CodeGen; @@ -32,8 +30,6 @@ TEST_CASE("NativeModuleRefRefcounting") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -250,8 +246,6 @@ TEST_CASE("NativeProtoRefcounting") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -303,8 +297,6 @@ TEST_CASE("NativeProtoState") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -364,8 +356,6 @@ TEST_CASE("AnonymousModuleLifetime") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -413,8 +403,6 @@ TEST_CASE("SharedAllocation") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - UniqueSharedCodeGenContext sharedCodeGenContext = createSharedCodeGenContext(); std::unique_ptr L1{luaL_newstate(), 
lua_close}; diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 7308d7da..b2c5f623 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -12,7 +12,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauCheckedFunctionSyntax); LUAU_FASTFLAG(DebugLuauSharedSelf); TEST_SUITE_BEGIN("ToString"); @@ -1007,7 +1006,6 @@ Type 'string' could not be converted into 'number' in an invariant context)"; TEST_CASE_FIXTURE(Fixture, "checked_fn_toString") { ScopedFastFlag flags[] = { - {FFlag::LuauCheckedFunctionSyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, true}, }; diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index c5b3e053..14385054 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -701,4 +701,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_oss_crash_gh1161") CHECK(get(result.errors[0])); } +TEST_CASE_FIXTURE(FamilyFixture, "fuzzer_numeric_binop_doesnt_assert_on_generalizeFreeType") +{ + CheckResult result = check(R"( +Module 'l0': +local _ = (67108864)(_ >= _).insert +do end +do end +_(...,_(_,_(_()),_())) +(67108864)()() +_(_ ~= _ // _,l0)(_(_({n0,})),_(_),_) +_(setmetatable(_,{[...]=_,})) + +)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index bfb17c78..e424ddca 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2687,4 +2687,40 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "error_suppression_propagates_through_functio CHECK("(any) -> (any?, any)" == toString(requireType("first"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzzer_normalizer_out_of_resources") +{ + // This luau code should finish typechecking, not segfault upon dereferencing + // the normalized type + CheckResult result = check(R"( + Module 'l0': +local _ = true,...,_ +if ... 
then +while _:_(_._G) do +do end +_ = _ and _ +_ = 0 and {# _,} +local _ = "CCCCCCCCCCCCCCCCCCCCCCCCCCC" +local l0 = require(module0) +end +local function l0() +end +elseif _ then +l0 = _ +end +do end +while _ do +_ = if _ then _ elseif _ then _,if _ then _ else _ +_ = _() +do end +do end +if _ then +end +end +_ = _,{} + + )"); + +} + + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 307084d5..d828ff65 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4460,4 +4460,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer") CHECK("{a}" == toString(requireType("a"), {true})); } +TEST_CASE_FIXTURE(BuiltinsFixture, "index_results_compare_to_nil") +{ + CheckResult result = check(R"( + --!strict + + function foo(tbl: {number}) + if tbl[2] == nil then + print("foo") + end + + if tbl[3] ~= nil then + print("bar") + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/Unifier2.test.cpp b/tests/Unifier2.test.cpp index dcec34d1..8efb2870 100644 --- a/tests/Unifier2.test.cpp +++ b/tests/Unifier2.test.cpp @@ -132,67 +132,4 @@ TEST_CASE_FIXTURE(Unifier2Fixture, "unify_binds_free_supertype_tail_pack") CHECK("(number <: 'a)" == toString(freeAndFree)); } -TEST_CASE_FIXTURE(Unifier2Fixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type") -{ - auto [t1, ft1] = freshType(); - auto [t2, ft2] = freshType(); - - // t2 <: t1 <: unknown - // unknown <: t2 <: t1 - - ft1->lowerBound = t2; - ft2->upperBound = t1; - ft2->lowerBound = builtinTypes.unknownType; - - auto t2generalized = u2.generalize(t2); - REQUIRE(t2generalized); - - CHECK(follow(t1) == follow(t2)); - - auto t1generalized = u2.generalize(t1); - REQUIRE(t1generalized); - - CHECK(builtinTypes.unknownType == follow(t1)); - CHECK(builtinTypes.unknownType == follow(t2)); -} - -// Same as generalize_a_type_that_is_bounded_by_another_generalizable_type -// except that 
we generalize the types in the opposite order -TEST_CASE_FIXTURE(Unifier2Fixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type_in_reverse_order") -{ - auto [t1, ft1] = freshType(); - auto [t2, ft2] = freshType(); - - // t2 <: t1 <: unknown - // unknown <: t2 <: t1 - - ft1->lowerBound = t2; - ft2->upperBound = t1; - ft2->lowerBound = builtinTypes.unknownType; - - auto t1generalized = u2.generalize(t1); - REQUIRE(t1generalized); - - CHECK(follow(t1) == follow(t2)); - - auto t2generalized = u2.generalize(t2); - REQUIRE(t2generalized); - - CHECK(builtinTypes.unknownType == follow(t1)); - CHECK(builtinTypes.unknownType == follow(t2)); -} - -TEST_CASE_FIXTURE(Unifier2Fixture, "dont_traverse_into_class_types_when_generalizing") -{ - auto [propTy, _] = freshType(); - - TypeId cursedClass = arena.addType(ClassType{"Cursed", {{"oh_no", Property::readonly(propTy)}}, std::nullopt, std::nullopt, {}, {}, ""}); - - auto genClass = u2.generalize(cursedClass); - REQUIRE(genClass); - - auto genPropTy = get(*genClass)->props.at("oh_no").readTy; - CHECK(is(*genPropTy)); -} - TEST_SUITE_END(); diff --git a/tests/main.cpp b/tests/main.cpp index 5d1ee6a6..4de391b6 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -18,7 +18,7 @@ #include // IsDebuggerPresent #endif -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) #include #endif @@ -330,7 +330,7 @@ static void setFastFlags(const std::vector& flags) // This function performs system/architecture specific initialization prior to running tests. static void initSystem() { -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) // Some unit tests make use of denormalized numbers. So flags to flush to zero or treat denormals as zero // must be disabled for expected behavior. 
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF); diff --git a/tools/faillist.txt b/tools/faillist.txt index c0c12bc3..2450eeb1 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -186,7 +186,6 @@ TableTests.infer_array TableTests.infer_indexer_from_array_like_table TableTests.infer_indexer_from_its_variable_type_and_unifiable TableTests.inferred_return_type_of_free_table -TableTests.insert_a_and_f_of_a_into_table_res_in_a_loop TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.length_operator_union @@ -363,7 +362,6 @@ TypeInferLoops.for_in_loop_on_non_function TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_loop TypeInferLoops.ipairs_produces_integral_indices -TypeInferLoops.iterate_over_free_table TypeInferLoops.iterate_over_properties TypeInferLoops.iteration_regression_issue_69967_alt TypeInferLoops.loop_iter_metamethod_nil From c73ecd8e08c488acd22db9f04c8935471d170e37 Mon Sep 17 00:00:00 2001 From: birds3345 <31601136+birds3345@users.noreply.github.com> Date: Tue, 21 May 2024 16:58:33 -0400 Subject: [PATCH 09/20] Fix typo in a comment (#1255) ens -> ends --- Compiler/src/BytecodeBuilder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 5386a528..6c76b671 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1666,7 +1666,7 @@ void BytecodeBuilder::validateVariadic() const if (LUAU_INSN_B(insn) == 0) { - // consumer instruction ens a variadic sequence + // consumer instruction ends a variadic sequence LUAU_ASSERT(variadicSeq); variadicSeq = false; } From 0dbe1a5022a877e77f965c57a7c3a9728d744b19 Mon Sep 17 00:00:00 2001 From: birds3345 <31601136+birds3345@users.noreply.github.com> Date: Wed, 22 May 2024 16:07:15 -0400 Subject: [PATCH 10/20] add cmake folder to .gitignore (#1246) In the readme 
file under the building section, it specifies that you should run the command `mkdir cmake && cd cmake`; however, the folder is not currently being ignored. --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 528ab204..8de6d91d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ /build/ /build[.-]*/ +/cmake/ +/cmake[.-]*/ /coverage/ /.vs/ /.vscode/ From c8fe77c268cb7887d4bddcbc6e04183a289ec9af Mon Sep 17 00:00:00 2001 From: aaron Date: Sun, 26 May 2024 13:09:09 -0400 Subject: [PATCH 11/20] Sync to upstream/release/627 (#1266) ### What's new? * Removed new `table.move` optimization because of correctness problems. ### New Type Solver * Improved error messages for type families to describe what's wrong in more detail, and ideally without using the term `type family` at all. * Change `boolean` and `string` singletons in type checking to report errors to the user when they've gotten an impossible type (indicating a type error from their context). * Split debugging flags for type family reduction (`DebugLuauLogTypeFamilies`) from general solver logging (`DebugLuauLogSolver`). * Improve type simplification to support patterns like `(number | string) | (string | number)` becoming `number | string`. ### Native Code Generation * Use templated `luaV_doarith` to speedup vector operation fallbacks. * Various small changes to better support arm64 on Windows. 
### Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Andy Friesen Co-authored-by: James McNellis Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --------- Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Vighnesh Co-authored-by: Aviral Goel Co-authored-by: David Cope Co-authored-by: Lily Brown Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Scope.h | 8 + Analysis/src/BuiltinDefinitions.cpp | 203 +++++++-------------- Analysis/src/Error.cpp | 109 +++++++++++ Analysis/src/Normalize.cpp | 13 +- Analysis/src/Simplify.cpp | 4 + Analysis/src/TypeChecker2.cpp | 14 +- Analysis/src/TypeFamily.cpp | 28 +-- CodeGen/include/Luau/CodeGen.h | 11 +- CodeGen/include/Luau/UnwindBuilder.h | 7 +- CodeGen/include/Luau/UnwindBuilderDwarf2.h | 5 +- CodeGen/include/Luau/UnwindBuilderWin.h | 5 +- CodeGen/src/CodeBlockUnwind.cpp | 6 +- CodeGen/src/CodeGen.cpp | 58 ------ CodeGen/src/CodeGenA64.cpp | 2 +- CodeGen/src/CodeGenContext.cpp | 40 ++-- CodeGen/src/CodeGenContext.h | 28 --- CodeGen/src/CodeGenX64.cpp | 2 +- CodeGen/src/EmitCommonX64.cpp | 43 ++++- CodeGen/src/IrLoweringA64.cpp | 45 ++++- CodeGen/src/NativeState.cpp | 20 ++ CodeGen/src/NativeState.h | 8 + CodeGen/src/UnwindBuilderDwarf2.cpp | 19 +- CodeGen/src/UnwindBuilderWin.cpp | 15 +- VM/src/lvm.h | 4 + VM/src/lvmexecute.cpp | 155 ++++++++++++++-- VM/src/lvmutils.cpp | 146 +++++++++++++++ tests/CodeAllocator.test.cpp | 6 +- tests/Error.test.cpp | 42 +++++ tests/Simplify.test.cpp | 8 + tests/TypeFamily.test.cpp | 12 +- tests/TypeInfer.functions.test.cpp | 8 +- tests/TypeInfer.tables.test.cpp | 27 +++ tools/faillist.txt | 2 - 33 files changed, 771 insertions(+), 332 deletions(-) diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 5f1630d5..0e6eff56 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -102,4 +102,12 @@ bool subsumesStrict(Scope* left, Scope* right); // outermost-possible scope. 
bool subsumes(Scope* left, Scope* right); +inline Scope* max(Scope* left, Scope* right) +{ + if (subsumes(left, right)) + return right; + else + return left; +} + } // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index f9ce87e0..a9c519fe 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -24,7 +24,6 @@ */ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauMakeStringMethodsChecked, false); namespace Luau { @@ -773,153 +772,87 @@ TypeId makeStringMetatable(NotNull builtinTypes) const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); - if (FFlag::LuauMakeStringMethodsChecked) - { - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; - formatFTV.isCheckedFunction = true; - const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); + FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; + formatFTV.magicFunction = &magicFunctionFormat; + formatFTV.isCheckedFunction = true; + const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); + const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); - const TypeId replArgType = arena->addType( - UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), - makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}}); - const TypeId gsubFunc = - makeFunction(*arena, stringType, {}, {}, {stringType, 
replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); - const TypeId gmatchFunc = makeFunction( - *arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + const TypeId replArgType = + arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}}); + const TypeId gsubFunc = + makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); + const TypeId gmatchFunc = + makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); - FunctionType matchFuncTy{ - arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}; - matchFuncTy.isCheckedFunction = true; - const TypeId matchFunc = arena->addType(matchFuncTy); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + FunctionType matchFuncTy{ + arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}; + matchFuncTy.isCheckedFunction = true; + const TypeId matchFunc = arena->addType(matchFuncTy); + attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, 
optionalNumber}, stringVariadicList})}; - findFuncTy.isCheckedFunction = true; - const TypeId findFunc = arena->addType(findFuncTy); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}; + findFuncTy.isCheckedFunction = true; + const TypeId findFunc = arena->addType(findFuncTy); + attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - // string.byte : string -> number? -> number? -> ...number - FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; - stringDotByte.isCheckedFunction = true; + // string.byte : string -> number? -> number? -> ...number + FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; + stringDotByte.isCheckedFunction = true; - // string.char : .... number -> string - FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})}; - stringDotChar.isCheckedFunction = true; + // string.char : .... number -> string + FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})}; + stringDotChar.isCheckedFunction = true; - // string.unpack : string -> string -> number? -> ...any - FunctionType stringDotUnpack{ - arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - variadicTailPack, - }; - stringDotUnpack.isCheckedFunction = true; + // string.unpack : string -> string -> number? 
-> ...any + FunctionType stringDotUnpack{ + arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + variadicTailPack, + }; + stringDotUnpack.isCheckedFunction = true; - TableType::Props stringLib = { - {"byte", {arena->addType(stringDotByte)}}, - {"char", {arena->addType(stringDotChar)}}, - {"find", {findFunc}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, - {"lower", {stringToStringType}}, - {"match", {matchFunc}}, - {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})}, - /* checked */ true)}}, - {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, variadicTailPack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, - {"unpack", {arena->addType(stringDotUnpack)}}, - }; - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + TableType::Props stringLib = { + {"byte", {arena->addType(stringDotByte)}}, + {"char", {arena->addType(stringDotChar)}}, + {"find", {findFunc}}, + {"format", {formatFn}}, // FIXME + {"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"lower", {stringToStringType}}, + {"match", {matchFunc}}, + {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}}, + {"reverse", {stringToStringType}}, + {"sub", 
{makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}}, + {"upper", {stringToStringType}}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, + {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})}, + /* checked */ true)}}, + {"pack", {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType}, variadicTailPack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"unpack", {arena->addType(stringDotUnpack)}}, + }; - TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - if (TableType* ttv = getMutable(tableType)) - ttv->name = "typeof(string)"; + TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); - return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); - } - else - { - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; - const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); + if (TableType* ttv = getMutable(tableType)) + ttv->name = "typeof(string)"; - const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); - - const TypeId replArgType = arena->addType( - UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), - makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); - const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, 
numberType}); - const TypeId gmatchFunc = - makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); - - const TypeId matchFunc = arena->addType(FunctionType{ - arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - - const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - - TableType::Props stringLib = { - {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {findFunc}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {matchFunc}}, - {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, - {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, variadicTailPack}), - oneStringPack, 
- })}}, - {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"unpack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - variadicTailPack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - - TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); - - if (TableType* ttv = getMutable(tableType)) - ttv->name = "typeof(string)"; - - return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); - } + return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } static std::optional> magicFunctionSelect( diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 98b15b77..2087e3d3 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,11 +7,13 @@ #include "Luau/NotNull.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include "Luau/TypeFamily.h" #include #include #include #include +#include LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10) @@ -61,6 +63,23 @@ static std::string wrongNumberOfArgsString( namespace Luau { +// this list of binary operator type families is used for better stringification of type families errors +static const std::unordered_map kBinaryOps{ + {"add", "+"}, {"sub", "-"}, {"mul", "*"}, {"div", "/"}, {"idiv", "//"}, {"pow", "^"}, {"mod", "%"}, {"concat", ".."}, {"and", "and"}, + {"or", "or"}, {"lt", "< or >="}, {"le", "<= or >"}, {"eq", "== or ~="} +}; + +// this list of unary operator type families is used for better stringification of type families errors +static const std::unordered_map kUnaryOps{ + {"unm", "-"}, {"len", "#"}, {"not", "not"} +}; + +// this list of type families will receive a special error indicating that the user should file a bug on the GitHub repository +// putting a type family in this list indicates 
that it is expected to _always_ reduce +static const std::unordered_set kUnreachableTypeFamilies{ + "refine", "singleton", "union", "intersect" +}; + struct ErrorConverter { FileResolver* fileResolver = nullptr; @@ -565,6 +584,96 @@ struct ErrorConverter std::string operator()(const UninhabitedTypeFamily& e) const { + auto tfit = get(e.ty); + LUAU_ASSERT(tfit); // Luau analysis has actually done something wrong if this type is not a type family. + if (!tfit) + return "Unexpected type " + Luau::toString(e.ty) + " flagged as an uninhabited type family."; + + // unary operators + if (auto unaryString = kUnaryOps.find(tfit->family->name); unaryString != kUnaryOps.end()) + { + std::string result = "Operator '" + std::string(unaryString->second) + "' could not be applied to "; + + if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty()) + { + result += "operand of type " + Luau::toString(tfit->typeArguments[0]); + + if (tfit->family->name != "not") + result += "; there is no corresponding overload for __" + tfit->family->name; + } + else + { + // if it's not the expected case, we ought to add a specialization later, but this is a sane default. + result += "operands of types "; + + bool isFirst = true; + for (auto arg : tfit->typeArguments) + { + if (!isFirst) + result += ", "; + + result += Luau::toString(arg); + isFirst = false; + } + + for (auto packArg : tfit->packArguments) + result += ", " + Luau::toString(packArg); + } + + return result; + } + + // binary operators + if (auto binaryString = kBinaryOps.find(tfit->family->name); binaryString != kBinaryOps.end()) + { + std::string result = "Operator '" + std::string(binaryString->second) + "' could not be applied to operands of types "; + + if (tfit->typeArguments.size() == 2 && tfit->packArguments.empty()) + { + // this is the expected case. 
+ result += Luau::toString(tfit->typeArguments[0]) + " and " + Luau::toString(tfit->typeArguments[1]); + } + else + { + // if it's not the expected case, we ought to add a specialization later, but this is a sane default. + + bool isFirst = true; + for (auto arg : tfit->typeArguments) + { + if (!isFirst) + result += ", "; + + result += Luau::toString(arg); + isFirst = false; + } + + for (auto packArg : tfit->packArguments) + result += ", " + Luau::toString(packArg); + } + + result += "; there is no corresponding overload for __" + tfit->family->name; + + return result; + } + + // miscellaneous + + if ("keyof" == tfit->family->name || "rawkeyof" == tfit->family->name) + { + if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty()) + return "Type '" + toString(tfit->typeArguments[0]) + "' does not have keys, so '" + Luau::toString(e.ty) + "' is invalid"; + else + return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; + } + + if (kUnreachableTypeFamilies.count(tfit->family->name)) + { + return "Type family instance " + Luau::toString(e.ty) + " is uninhabited\n" + + "This is likely to be a bug, please report it at https://github.com/luau-lang/luau/issues"; + } + + // Everything should be specialized above to report a more descriptive error that hopefully does not mention "type families" explicitly. + // If we produce this message, it's an indication that we've missed a specialization and it should be fixed! 
return "Type family instance " + Luau::toString(e.ty) + " is uninhabited"; } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 3c63a7fd..5b14fd5f 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -2534,6 +2534,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there state = tttv->state; TypeLevel level = max(httv->level, tttv->level); + Scope* scope = max(httv->scope, tttv->scope); std::unique_ptr result = nullptr; bool hereSubThere = true; @@ -2644,7 +2645,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (prop.readTy || prop.writeTy) { if (!result.get()) - result = std::make_unique(TableType{state, level}); + result = std::make_unique(TableType{state, level, scope}); result->props[name] = prop; } } @@ -2654,7 +2655,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (httv->props.count(name) == 0) { if (!result.get()) - result = std::make_unique(TableType{state, level}); + result = std::make_unique(TableType{state, level, scope}); result->props[name] = tprop; hereSubThere = false; @@ -2667,7 +2668,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType); TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType); if (!result.get()) - result = std::make_unique(TableType{state, level}); + result = std::make_unique(TableType{state, level, scope}); result->indexer = {index, indexResult}; hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult); thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult); @@ -2675,14 +2676,14 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (httv->indexer) { if (!result.get()) - result = std::make_unique(TableType{state, level}); + result = 
std::make_unique(TableType{state, level, scope}); result->indexer = httv->indexer; thereSubHere = false; } else if (tttv->indexer) { if (!result.get()) - result = std::make_unique(TableType{state, level}); + result = std::make_unique(TableType{state, level, scope}); result->indexer = tttv->indexer; hereSubThere = false; } @@ -2697,7 +2698,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (result.get()) table = arena->addType(std::move(*result)); else - table = arena->addType(TableType{state, level}); + table = arena->addType(TableType{state, level, scope}); } if (tmtable && hmtable) diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index d29546a2..ca78d54d 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -1255,6 +1255,10 @@ TypeId TypeSimplifier::union_(TypeId left, TypeId right) case Relation::Coincident: case Relation::Superset: return left; + case Relation::Subset: + newParts.insert(right); + changed = true; + break; default: newParts.insert(part); newParts.insert(right); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index d0d37127..37e0f039 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1242,13 +1242,14 @@ struct TypeChecker2 void visit(AstExprConstantBool* expr) { -#if defined(LUAU_ENABLE_ASSERT) + // booleans use specialized inference logic for singleton types, which can lead to real type errors here. + const TypeId bestType = expr->value ? 
builtinTypes->trueType : builtinTypes->falseType; const TypeId inferredType = lookupType(expr); const SubtypingResult r = subtyping->isSubtype(bestType, inferredType); - LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType)); -#endif + if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType)) + reportError(TypeMismatch{inferredType, bestType}, expr->location); } void visit(AstExprConstantNumber* expr) @@ -1264,13 +1265,14 @@ struct TypeChecker2 void visit(AstExprConstantString* expr) { -#if defined(LUAU_ENABLE_ASSERT) + // strings use specialized inference logic for singleton types, which can lead to real type errors here. + const TypeId bestType = module->internalTypes.addType(SingletonType{StringSingleton{std::string{expr->value.data, expr->value.size}}}); const TypeId inferredType = lookupType(expr); const SubtypingResult r = subtyping->isSubtype(bestType, inferredType); - LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType)); -#endif + if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType)) + reportError(TypeMismatch{inferredType, bestType}, expr->location); } void visit(AstExprLocal* expr) diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index e336a5cd..a8d7d2f7 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -37,7 +37,7 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0 // when this value is set to a negative value, guessing will be totally disabled. 
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); -LUAU_FASTFLAG(DebugLuauLogSolver); +LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false); namespace Luau { @@ -184,7 +184,7 @@ struct FamilyReducer if (subject->owningArena != ctx.arena.get()) ctx.ice->ice("Attempting to modify a type family instance from another arena", location); - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s -> %s\n", toString(subject, {true}).c_str(), toString(replacement, {true}).c_str()); asMutable(subject)->ty.template emplace>(replacement); @@ -206,7 +206,7 @@ struct FamilyReducer if (reduction.uninhabited || force) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s is uninhabited\n", toString(subject, {true}).c_str()); if constexpr (std::is_same_v) @@ -216,7 +216,7 @@ struct FamilyReducer } else if (!reduction.uninhabited && !force) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s is irreducible; blocked on %zu types, %zu packs\n", toString(subject, {true}).c_str(), reduction.blockedTypes.size(), reduction.blockedPacks.size()); @@ -243,7 +243,7 @@ struct FamilyReducer if (skip == SkipTestResult::Irreducible) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); irreducible.insert(subject); @@ -251,7 +251,7 @@ struct FamilyReducer } else if (skip == SkipTestResult::Defer) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); if constexpr (std::is_same_v) @@ -269,7 +269,7 @@ struct FamilyReducer if (skip == SkipTestResult::Irreducible) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), 
toString(p, {true}).c_str()); irreducible.insert(subject); @@ -277,7 +277,7 @@ struct FamilyReducer } else if (skip == SkipTestResult::Defer) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); if constexpr (std::is_same_v) @@ -297,7 +297,7 @@ struct FamilyReducer { if (shouldGuess.contains(subject)) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Flagged %s for reduction with guesser.\n", toString(subject, {true}).c_str()); TypeFamilyReductionGuesser guesser{ctx.arena, ctx.builtins, ctx.normalizer}; @@ -305,14 +305,14 @@ struct FamilyReducer if (guessed) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Selected %s as the guessed result type.\n", toString(*guessed, {true}).c_str()); replace(subject, *guessed); return true; } - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Failed to produce a guess for the result of %s.\n", toString(subject, {true}).c_str()); } @@ -328,7 +328,7 @@ struct FamilyReducer if (irreducible.contains(subject)) return; - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Trying to reduce %s\n", toString(subject, {true}).c_str()); if (const TypeFamilyInstanceType* tfit = get(subject)) @@ -337,7 +337,7 @@ struct FamilyReducer if (!testParameters(subject, tfit) && testCyclic != SkipTestResult::CyclicTypeFamily) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Irreducible due to irreducible/pending and a non-cyclic family\n"); return; @@ -361,7 +361,7 @@ struct FamilyReducer if (irreducible.contains(subject)) return; - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Trying to reduce %s\n", toString(subject, {true}).c_str()); if (const TypeFamilyInstanceTypePack* tfit = get(subject)) diff --git a/CodeGen/include/Luau/CodeGen.h 
b/CodeGen/include/Luau/CodeGen.h index 19a9b3c9..ac444b7b 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -144,8 +144,17 @@ using UniqueSharedCodeGenContext = std::unique_ptr gpr, const std::vector& simd) = 0; - virtual size_t getSize() const = 0; - virtual size_t getFunctionCount() const = 0; + virtual size_t getUnwindInfoSize(size_t blockSize) const = 0; // This will place the unwinding data at the target address and might update values of some fields - virtual void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const = 0; + virtual size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const = 0; }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 741aaed2..1b634dec 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -33,10 +33,9 @@ public: void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, const std::vector& simd) override; - size_t getSize() const override; - size_t getFunctionCount() const override; + size_t getUnwindInfoSize(size_t blockSize = 0) const override; - void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override; + size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const override; private: size_t beginOffset = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index 3a7e1b5a..bc43b94a 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -53,10 +53,9 @@ public: void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, const std::vector& simd) override; - size_t getSize() const override; - size_t getFunctionCount() const override; + size_t getUnwindInfoSize(size_t 
blockSize = 0) const override; - void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override; + size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const override; private: size_t beginOffset = 0; diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index b8876054..cb2d693a 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -102,17 +102,17 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz UnwindBuilder* unwind = (UnwindBuilder*)context; // All unwinding related data is placed together at the start of the block - size_t unwindSize = unwind->getSize(); + size_t unwindSize = unwind->getUnwindInfoSize(blockSize); unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment CODEGEN_ASSERT(blockSize >= unwindSize); char* unwindData = (char*)block; - unwind->finalize(unwindData, unwindSize, block, blockSize); + [[maybe_unused]] size_t functionCount = unwind->finalize(unwindData, unwindSize, block, blockSize); #if defined(_WIN32) && defined(CODEGEN_TARGET_X64) #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) - if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(unwind->getFunctionCount()), uintptr_t(block))) + if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(functionCount), uintptr_t(block))) { CODEGEN_ASSERT(!"Failed to allocate function table"); return nullptr; diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 5d6f1fb5..694a9f7e 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -95,10 +95,6 @@ std::string toString(const CodeGenCompilationResult& result) return ""; } -void* gPerfLogContext = nullptr; -PerfLogFn gPerfLogFn = nullptr; - - void onDisable(lua_State* L, Proto* proto) { // do nothing if proto already uses bytecode @@ -196,59 +192,5 @@ bool isSupported() #endif } -void 
create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) -{ - create_NEW(L, allocationCallback, allocationCallbackContext); -} - -void create(lua_State* L) -{ - create_NEW(L); -} - -void create(lua_State* L, SharedCodeGenContext* codeGenContext) -{ - create_NEW(L, codeGenContext); -} - -[[nodiscard]] bool isNativeExecutionEnabled(lua_State* L) -{ - return isNativeExecutionEnabled_NEW(L); -} - -void setNativeExecutionEnabled(lua_State* L, bool enabled) -{ - setNativeExecutionEnabled_NEW(L, enabled); -} - -CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) -{ - Luau::CodeGen::CompilationOptions options{flags}; - - return compile_NEW(L, idx, options, stats); -} - -CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) -{ - Luau::CodeGen::CompilationOptions options{flags}; - return compile_NEW(moduleId, L, idx, options, stats); -} - -CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) -{ - return compile_NEW(L, idx, options, stats); -} - -CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) -{ - return compile_NEW(moduleId, L, idx, options, stats); -} - -void setPerfLog(void* context, PerfLogFn logFn) -{ - gPerfLogContext = context; - gPerfLogFn = logFn; -} - } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index a18278c9..05ac9013 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -253,7 +253,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde // Our entry function is special, it spans the whole remaining code area unwind.startFunction(); unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24, x25}); - 
unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); + unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFunction); return locations; } diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index cdffb123..a94388f6 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -25,11 +25,17 @@ namespace CodeGen static const Instruction kCodeEntryInsn = LOP_NATIVECALL; // From CodeGen.cpp -extern void* gPerfLogContext; -extern PerfLogFn gPerfLogFn; +static void* gPerfLogContext = nullptr; +static PerfLogFn gPerfLogFn = nullptr; unsigned int getCpuFeaturesA64(); +void setPerfLog(void* context, PerfLogFn logFn) +{ + gPerfLogContext = context; + gPerfLogFn = logFn; +} + static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) { CODEGEN_ASSERT(p->source); @@ -365,17 +371,17 @@ static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeG ecb->getmemorysize = getMemorySize; } -void create_NEW(lua_State* L) +void create(lua_State* L) { - return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr); + return create(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr); } -void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) +void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) { - return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext); + return create(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext); } -void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) +void create(lua_State* L, size_t blockSize, size_t 
maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) { std::unique_ptr codeGenContext = std::make_unique(blockSize, maxTotalSize, allocationCallback, allocationCallbackContext); @@ -386,7 +392,7 @@ void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationC initializeExecutionCallbacks(L, codeGenContext.release()); } -void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext) +void create(lua_State* L, SharedCodeGenContext* codeGenContext) { initializeExecutionCallbacks(L, codeGenContext); } @@ -575,22 +581,32 @@ template return compilationResult; } -CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { return compileInternal(moduleId, L, idx, options, stats); } -CompilationResult compile_NEW(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) +CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { return compileInternal({}, L, idx, options, stats); } -[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L) +CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +{ + return compileInternal({}, L, idx, CompilationOptions{flags}, stats); +} + +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +{ + return compileInternal(moduleId, L, idx, CompilationOptions{flags}, stats); +} + +[[nodiscard]] bool isNativeExecutionEnabled(lua_State* L) { return getCodeGenContext(L) != nullptr && L->global->ecb.enter == onEnter; } -void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled) +void setNativeExecutionEnabled(lua_State* L, bool enabled) { if (getCodeGenContext(L) != nullptr) 
L->global->ecb.enter = enabled ? onEnter : onEnterDisabled; diff --git a/CodeGen/src/CodeGenContext.h b/CodeGen/src/CodeGenContext.h index c47121bc..516a7064 100644 --- a/CodeGen/src/CodeGenContext.h +++ b/CodeGen/src/CodeGenContext.h @@ -88,33 +88,5 @@ private: SharedCodeAllocator sharedAllocator; }; - -// The following will become the public interface, and can be moved into -// CodeGen.h after the shared allocator work is complete. When the old -// implementation is removed, the _NEW suffix can be dropped from these -// functions. - -// Initializes native code-gen on the provided Luau VM, using a VM-specific -// code-gen context and either the default allocator parameters or custom -// allocator parameters. -void create_NEW(lua_State* L); -void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext); -void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext); - -// Initializes native code-gen on the provided Luau VM, using the provided -// SharedCodeGenContext. Note that after this function is called, the -// SharedCodeGenContext must not be destroyed until after the Luau VM L is -// destroyed via lua_close. 
-void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext); - -CompilationResult compile_NEW(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats); -CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats); - -// Returns true if native execution is currently enabled for this VM -[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L); - -// Enables or disables native excution for this VM -void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled); - } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 5e450c9a..7f4a9e0c 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -181,7 +181,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde build.ret(); // Our entry function is special, it spans the whole remaining code area - unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); + unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFunction); return locations; } diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index c8d1e75a..50f2208b 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -14,6 +14,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauCodegenSplitDoarith, false) + namespace Luau { namespace CodeGen @@ -155,8 +157,45 @@ void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, Ope callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); callWrap.addArgument(SizeX64::qword, b); callWrap.addArgument(SizeX64::qword, c); - callWrap.addArgument(SizeX64::dword, tm); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); + + if (FFlag::LuauCodegenSplitDoarith) + { + switch (tm) + { + case TM_ADD: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithadd)]); + 
break; + case TM_SUB: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithsub)]); + break; + case TM_MUL: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmul)]); + break; + case TM_DIV: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithdiv)]); + break; + case TM_IDIV: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithidiv)]); + break; + case TM_MOD: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmod)]); + break; + case TM_POW: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithpow)]); + break; + case TM_UNM: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithunm)]); + break; + default: + CODEGEN_ASSERT(!"Invalid doarith helper operation tag"); + break; + } + } + else + { + callWrap.addArgument(SizeX64::dword, tm); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); + } emitUpdateBase(build); } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index f35a15fa..c8cc07f4 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -12,6 +12,7 @@ #include "lgc.h" LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) +LUAU_FASTFLAG(LuauCodegenSplitDoarith) namespace Luau { @@ -1242,9 +1243,47 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) else build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); - build.mov(w4, TMS(intOp(inst.d))); - build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith))); - build.blr(x5); + if (FFlag::LuauCodegenSplitDoarith) + { + switch (TMS(intOp(inst.d))) + { + case TM_ADD: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithadd))); + break; + case TM_SUB: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithsub))); + break; + case TM_MUL: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, 
luaV_doarithmul))); + break; + case TM_DIV: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithdiv))); + break; + case TM_IDIV: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithidiv))); + break; + case TM_MOD: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmod))); + break; + case TM_POW: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithpow))); + break; + case TM_UNM: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithunm))); + break; + default: + CODEGEN_ASSERT(!"Invalid doarith helper operation tag"); + break; + } + + build.blr(x4); + } + else + { + build.mov(w4, TMS(intOp(inst.d))); + build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith))); + build.blr(x5); + } emitUpdateBase(build); break; diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 5f6df4b6..b3d07491 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -43,6 +43,16 @@ void initFunctions(NativeState& data) data.context.luaV_lessequal = luaV_lessequal; data.context.luaV_equalval = luaV_equalval; data.context.luaV_doarith = luaV_doarith; + + data.context.luaV_doarithadd = luaV_doarithimpl; + data.context.luaV_doarithsub = luaV_doarithimpl; + data.context.luaV_doarithmul = luaV_doarithimpl; + data.context.luaV_doarithdiv = luaV_doarithimpl; + data.context.luaV_doarithidiv = luaV_doarithimpl; + data.context.luaV_doarithmod = luaV_doarithimpl; + data.context.luaV_doarithpow = luaV_doarithimpl; + data.context.luaV_doarithunm = luaV_doarithimpl; + data.context.luaV_dolen = luaV_dolen; data.context.luaV_gettable = luaV_gettable; data.context.luaV_settable = luaV_settable; @@ -121,6 +131,16 @@ void initFunctions(NativeContext& context) context.luaV_lessequal = luaV_lessequal; context.luaV_equalval = luaV_equalval; context.luaV_doarith = luaV_doarith; + + context.luaV_doarithadd = luaV_doarithimpl; + context.luaV_doarithsub = 
luaV_doarithimpl; + context.luaV_doarithmul = luaV_doarithimpl; + context.luaV_doarithdiv = luaV_doarithimpl; + context.luaV_doarithidiv = luaV_doarithimpl; + context.luaV_doarithmod = luaV_doarithimpl; + context.luaV_doarithpow = luaV_doarithimpl; + context.luaV_doarithunm = luaV_doarithimpl; + context.luaV_dolen = luaV_dolen; context.luaV_gettable = luaV_gettable; context.luaV_settable = luaV_settable; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 3e7c85e9..2edfc270 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -34,6 +34,14 @@ struct NativeContext int (*luaV_lessequal)(lua_State* L, const TValue* l, const TValue* r) = nullptr; int (*luaV_equalval)(lua_State* L, const TValue* t1, const TValue* t2) = nullptr; void (*luaV_doarith)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op) = nullptr; + void (*luaV_doarithadd)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithsub)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithmul)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithdiv)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithidiv)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithmod)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithpow)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithunm)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; void (*luaV_dolen)(lua_State* L, StkId ra, const TValue* rb) = nullptr; void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index 
b1522e7b..2f090b52 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -202,7 +202,7 @@ void UnwindBuilderDwarf2::finishInfo() // Terminate section pos = writeu32(pos, 0); - CODEGEN_ASSERT(getSize() <= kRawDataLimit); + CODEGEN_ASSERT(getUnwindInfoSize() <= kRawDataLimit); } void UnwindBuilderDwarf2::prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) @@ -271,19 +271,14 @@ void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, CODEGEN_ASSERT(prologueOffset == prologueSize); } -size_t UnwindBuilderDwarf2::getSize() const +size_t UnwindBuilderDwarf2::getUnwindInfoSize(size_t blockSize) const { return size_t(pos - rawData); } -size_t UnwindBuilderDwarf2::getFunctionCount() const +size_t UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const { - return unwindFunctions.size(); -} - -void UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const -{ - memcpy(target, rawData, getSize()); + memcpy(target, rawData, getUnwindInfoSize()); for (const UnwindFunctionDwarf2& func : unwindFunctions) { @@ -291,11 +286,13 @@ void UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddres writeu64(fdeEntry + kFdeInitialLocationOffset, uintptr_t(funcAddress) + offset + func.beginOffset); - if (func.endOffset == kFullBlockFuncton) - writeu64(fdeEntry + kFdeAddressRangeOffset, funcSize - offset); + if (func.endOffset == kFullBlockFunction) + writeu64(fdeEntry + kFdeAddressRangeOffset, blockSize - offset); else writeu64(fdeEntry + kFdeAddressRangeOffset, func.endOffset - func.beginOffset); } + + return unwindFunctions.size(); } } // namespace CodeGen diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 498470bd..2bcc0321 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -194,17 +194,12 @@ void 
UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bo this->prologSize = prologueSize; } -size_t UnwindBuilderWin::getSize() const +size_t UnwindBuilderWin::getUnwindInfoSize(size_t blockSize) const { return sizeof(UnwindFunctionWin) * unwindFunctions.size() + size_t(rawDataPos - rawData); } -size_t UnwindBuilderWin::getFunctionCount() const -{ - return unwindFunctions.size(); -} - -void UnwindBuilderWin::finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const +size_t UnwindBuilderWin::finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const { // Copy adjusted function information for (UnwindFunctionWin func : unwindFunctions) @@ -213,8 +208,8 @@ void UnwindBuilderWin::finalize(char* target, size_t offset, void* funcAddress, func.beginOffset += uint32_t(offset); // Whole block is a part of a 'single function' - if (func.endOffset == kFullBlockFuncton) - func.endOffset = uint32_t(funcSize); + if (func.endOffset == kFullBlockFunction) + func.endOffset = uint32_t(blockSize); else func.endOffset += uint32_t(offset); @@ -226,6 +221,8 @@ void UnwindBuilderWin::finalize(char* target, size_t offset, void* funcAddress, // Copy unwind codes memcpy(target, rawData, size_t(rawDataPos - rawData)); + + return unwindFunctions.size(); } } // namespace CodeGen diff --git a/VM/src/lvm.h b/VM/src/lvm.h index 5ec7bc16..96bc37f3 100644 --- a/VM/src/lvm.h +++ b/VM/src/lvm.h @@ -16,6 +16,10 @@ LUAI_FUNC int luaV_lessthan(lua_State* L, const TValue* l, const TValue* r); LUAI_FUNC int luaV_lessequal(lua_State* L, const TValue* l, const TValue* r); LUAI_FUNC int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2); LUAI_FUNC void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op); + +template +void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); + LUAI_FUNC void luaV_dolen(lua_State* L, StkId ra, const TValue* rb); LUAI_FUNC const TValue* 
luaV_tonumber(const TValue* obj, TValue* n); LUAI_FUNC const float* luaV_tovector(const TValue* obj); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 74e30c94..4ac21db3 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauVmSplitDoarith, false) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -1487,7 +1489,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_ADD)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_ADD)); + } VM_NEXT(); } } @@ -1533,7 +1542,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_SUB)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_SUB)); + } VM_NEXT(); } } @@ -1594,7 +1610,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MUL)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MUL)); + } VM_NEXT(); } } @@ -1655,7 +1678,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_DIV)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_DIV)); + } VM_NEXT(); } } @@ -1703,7 +1733,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_IDIV)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_IDIV)); + } VM_NEXT(); } } @@ -1727,7 
+1764,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MOD)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MOD)); + } VM_NEXT(); } } @@ -1748,7 +1792,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_POW)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_POW)); + } VM_NEXT(); } } @@ -1769,7 +1820,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_ADD)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_ADD)); + } VM_NEXT(); } } @@ -1790,7 +1848,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_SUB)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_SUB)); + } VM_NEXT(); } } @@ -1835,7 +1900,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MUL)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MUL)); + } VM_NEXT(); } } @@ -1881,7 +1953,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_DIV)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_DIV)); + } VM_NEXT(); } } @@ -1928,7 +2007,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_IDIV)); + if (FFlag::LuauVmSplitDoarith) + { 
+ VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_IDIV)); + } VM_NEXT(); } } @@ -1952,7 +2038,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MOD)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MOD)); + } VM_NEXT(); } } @@ -1979,7 +2072,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_POW)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_POW)); + } VM_NEXT(); } } @@ -2092,7 +2192,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rb, TM_UNM)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rb)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, rb, rb, TM_UNM)); + } VM_NEXT(); } } @@ -2711,7 +2818,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_SUB)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, kv, rc)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_SUB)); + } VM_NEXT(); } } @@ -2739,7 +2853,14 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_DIV)); + if (FFlag::LuauVmSplitDoarith) + { + VM_PROTECT(luaV_doarithimpl(L, ra, kv, rc)); + } + else + { + VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_DIV)); + } VM_NEXT(); } } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 4db8bba7..6ee542b0 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -373,6 +373,152 @@ void luaV_concat(lua_State* L, int total, int last) } while (total > 1); // repeat until only 1 result left } +template +void luaV_doarithimpl(lua_State* L, StkId ra, const 
TValue* rb, const TValue* rc) +{ + TValue tempb, tempc; + const TValue *b, *c; + + // vector operations that we support: + // v+v v-v -v (add/sub/neg) + // v*v s*v v*s (mul) + // v/v s/v v/s (div) + // v//v s//v v//s (floor div) + const float* vb = ttisvector(rb) ? vvalue(rb) : nullptr; + const float* vc = ttisvector(rc) ? vvalue(rc) : nullptr; + + if (vb && vc) + { + switch (op) + { + case TM_ADD: + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); + return; + case TM_SUB: + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); + return; + case TM_MUL: + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); + return; + case TM_DIV: + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); + return; + case TM_IDIV: + setvvalue(ra, float(luai_numidiv(vb[0], vc[0])), float(luai_numidiv(vb[1], vc[1])), float(luai_numidiv(vb[2], vc[2])), + float(luai_numidiv(vb[3], vc[3]))); + return; + case TM_UNM: + setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); + return; + default: + break; + } + } + else if (vb) + { + c = ttisnumber(rc) ? rc : luaV_tonumber(rc, &tempc); + + if (c) + { + float nc = cast_to(float, nvalue(c)); + + switch (op) + { + case TM_MUL: + setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc, vb[3] * nc); + return; + case TM_DIV: + setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc, vb[3] / nc); + return; + case TM_IDIV: + setvvalue(ra, float(luai_numidiv(vb[0], nc)), float(luai_numidiv(vb[1], nc)), float(luai_numidiv(vb[2], nc)), + float(luai_numidiv(vb[3], nc))); + return; + default: + break; + } + } + } + else if (vc) + { + b = ttisnumber(rb) ? 
rb : luaV_tonumber(rb, &tempb); + + if (b) + { + float nb = cast_to(float, nvalue(b)); + + switch (op) + { + case TM_MUL: + setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2], nb * vc[3]); + return; + case TM_DIV: + setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2], nb / vc[3]); + return; + case TM_IDIV: + setvvalue(ra, float(luai_numidiv(nb, vc[0])), float(luai_numidiv(nb, vc[1])), float(luai_numidiv(nb, vc[2])), + float(luai_numidiv(nb, vc[3]))); + return; + default: + break; + } + } + } + + if ((b = luaV_tonumber(rb, &tempb)) != NULL && (c = luaV_tonumber(rc, &tempc)) != NULL) + { + double nb = nvalue(b), nc = nvalue(c); + + switch (op) + { + case TM_ADD: + setnvalue(ra, luai_numadd(nb, nc)); + break; + case TM_SUB: + setnvalue(ra, luai_numsub(nb, nc)); + break; + case TM_MUL: + setnvalue(ra, luai_nummul(nb, nc)); + break; + case TM_DIV: + setnvalue(ra, luai_numdiv(nb, nc)); + break; + case TM_IDIV: + setnvalue(ra, luai_numidiv(nb, nc)); + break; + case TM_MOD: + setnvalue(ra, luai_nummod(nb, nc)); + break; + case TM_POW: + setnvalue(ra, luai_numpow(nb, nc)); + break; + case TM_UNM: + setnvalue(ra, luai_numunm(nb)); + break; + default: + LUAU_ASSERT(0); + break; + } + } + else + { + if (!call_binTM(L, rb, rc, ra, op)) + { + luaG_aritherror(L, rb, rc, op); + } + } +} + +// instantiate private template implementation for external callers +template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); +template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); +template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); +template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); +template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); +template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); +template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const 
TValue* rc); +template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); + void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op) { TValue tempb, tempc; diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 2ac0e5fd..21228d6b 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -191,7 +191,7 @@ TEST_CASE("WindowsUnwindCodesX64") unwind.finishInfo(); std::vector data; - data.resize(unwind.getSize()); + data.resize(unwind.getUnwindInfoSize()); unwind.finalize(data.data(), 0, nullptr, 0); std::vector expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x17, 0x0a, 0x05, 0x17, 0x82, 0x13, @@ -215,7 +215,7 @@ TEST_CASE("Dwarf2UnwindCodesX64") unwind.finishInfo(); std::vector data; - data.resize(unwind.getSize()); + data.resize(unwind.getUnwindInfoSize()); unwind.finalize(data.data(), 0, nullptr, 0); std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x90, 0x01, 0x00, @@ -241,7 +241,7 @@ TEST_CASE("Dwarf2UnwindCodesA64") unwind.finishInfo(); std::vector data; - data.resize(unwind.getSize()); + data.resize(unwind.getUnwindInfoSize()); unwind.finalize(data.data(), 0, nullptr, 0); std::vector expected{0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x1e, 0x0c, 0x1f, 0x00, 0x2c, 0x00, 0x00, diff --git a/tests/Error.test.cpp b/tests/Error.test.cpp index 677e3217..8dfcbde0 100644 --- a/tests/Error.test.cpp +++ b/tests/Error.test.cpp @@ -6,6 +6,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + TEST_SUITE_BEGIN("ErrorTests"); TEST_CASE("TypeError_code_should_return_nonzero_code") @@ -34,4 +36,44 @@ local x: Account = 5 CHECK_EQ("Type 'number' could not be converted into 'Account'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(BuiltinsFixture, "binary_op_type_family_errors") +{ + 
frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( + --!strict + local x = 1 + "foo" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Operator '+' could not be applied to operands of types number and string; there is no corresponding overload for __add", toString(result.errors[0])); + else + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "unary_op_type_family_errors") +{ + frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( + --!strict + local x = -"foo" + )"); + + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Operator '-' could not be applied to operand of type string; there is no corresponding overload for __unm", toString(result.errors[0])); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[1])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + } +} + TEST_SUITE_END(); diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp index ddddbe67..b938b5f8 100644 --- a/tests/Simplify.test.cpp +++ b/tests/Simplify.test.cpp @@ -214,6 +214,14 @@ TEST_CASE_FIXTURE(SimplifyFixture, "any_and_indeterminate_types") CHECK(errorTy == anyLhsPending->options[1]); } +TEST_CASE_FIXTURE(SimplifyFixture, "union_where_lhs_elements_are_a_subset_of_the_rhs") +{ + TypeId lhs = union_(numberTy, stringTy); + TypeId rhs = union_(stringTy, numberTy); + + CHECK("number | string" == toString(union_(lhs, rhs))); +} + TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_indeterminate_types") { CHECK(freeTy == intersect(unknownTy, freeTy)); diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index 14385054..88dfbf47 100644 --- a/tests/TypeFamily.test.cpp +++ 
b/tests/TypeFamily.test.cpp @@ -391,8 +391,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_family_errors_if_it_has_nontable_ // FIXME(CLI-95289): we should actually only report the type family being uninhabited error at its first use, I think? LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK(toString(result.errors[0]) == "Type family instance keyof is uninhabited"); - CHECK(toString(result.errors[1]) == "Type family instance keyof is uninhabited"); + CHECK(toString(result.errors[0]) == "Type 'MyObject | boolean' does not have keys, so 'keyof' is invalid"); + CHECK(toString(result.errors[1]) == "Type 'MyObject | boolean' does not have keys, so 'keyof' is invalid"); } TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_family_string_indexer") @@ -517,8 +517,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "rawkeyof_type_family_errors_if_it_has_nontab // FIXME(CLI-95289): we should actually only report the type family being uninhabited error at its first use, I think? LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK(toString(result.errors[0]) == "Type family instance rawkeyof is uninhabited"); - CHECK(toString(result.errors[1]) == "Type family instance rawkeyof is uninhabited"); + CHECK(toString(result.errors[0]) == "Type 'MyObject | boolean' does not have keys, so 'rawkeyof' is invalid"); + CHECK(toString(result.errors[1]) == "Type 'MyObject | boolean' does not have keys, so 'rawkeyof' is invalid"); } TEST_CASE_FIXTURE(BuiltinsFixture, "rawkeyof_type_family_common_subset_if_union_of_differing_tables") @@ -590,8 +590,8 @@ TEST_CASE_FIXTURE(ClassFixture, "keyof_type_family_errors_if_it_has_nonclass_par // FIXME(CLI-95289): we should actually only report the type family being uninhabited error at its first use, I think? 
LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK(toString(result.errors[0]) == "Type family instance keyof is uninhabited"); - CHECK(toString(result.errors[1]) == "Type family instance keyof is uninhabited"); + CHECK(toString(result.errors[0]) == "Type 'BaseClass | boolean' does not have keys, so 'keyof' is invalid"); + CHECK(toString(result.errors[1]) == "Type 'BaseClass | boolean' does not have keys, so 'keyof' is invalid"); } TEST_CASE_FIXTURE(ClassFixture, "keyof_type_family_common_subset_if_union_of_differing_classes") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index e424ddca..34178fd9 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2298,10 +2298,10 @@ end if (FFlag::DebugLuauDeferredConstraintResolution) { LUAU_REQUIRE_ERROR_COUNT(4, result); - CHECK(toString(result.errors[0]) == "Type family instance sub is uninhabited"); - CHECK(toString(result.errors[1]) == "Type family instance sub is uninhabited"); - CHECK(toString(result.errors[2]) == "Type family instance sub is uninhabited"); - CHECK(toString(result.errors[3]) == "Type family instance sub is uninhabited"); + CHECK(toString(result.errors[0]) == "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[1]) == "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[2]) == "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[3]) == "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); } else { diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index d828ff65..bd0a4144 100644 --- a/tests/TypeInfer.tables.test.cpp +++ 
b/tests/TypeInfer.tables.test.cpp @@ -4479,4 +4479,31 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "index_results_compare_to_nil") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzzer_normalization_preserves_tbl_scopes") +{ + CheckResult result = check(R"( +Module 'l0': +do end + +Module 'l1': +local _ = {n0=nil,} +if if nil then _ then +if nil and (_)._ ~= (_)._ then +do end +while _ do +_ = _ +do end +end +end +do end +end +local l0 +while _ do +_ = nil +(_[_])._ %= `{# _}{bit32.extract(# _,1)}` +end + +)"); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index 2450eeb1..6939df54 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -219,7 +219,6 @@ TableTests.setmetatable_has_a_side_effect TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables -TableTests.should_not_unblock_table_type_twice TableTests.table_call_metamethod_basic TableTests.table_call_metamethod_must_be_callable TableTests.table_param_width_subtyping_2 @@ -288,7 +287,6 @@ TypeInfer.unify_nearly_identical_recursive_types TypeInferAnyError.can_subscript_any TypeInferAnyError.for_in_loop_iterator_is_error TypeInferAnyError.for_in_loop_iterator_is_error2 -TypeInferAnyError.metatable_of_any_can_be_a_table TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferClasses.callable_classes TypeInferClasses.cannot_unify_class_instance_with_primitive From daf79328fc85bae9781239271ff3184153484363 Mon Sep 17 00:00:00 2001 From: Alexander McCord <11488393+alexmccord@users.noreply.github.com> Date: Fri, 31 May 2024 12:18:18 -0700 Subject: [PATCH 12/20] Sync to upstream/release/628 (#1278) ### What's new? 
* Remove a case of unsound `table.move` optimization * Add Luau stack slot reservations that were missing in REPL (fixes #1273) ### New Type Solver * Assignments have been completely reworked to fix a case of cyclic constraint dependency * When indexing, if the fresh type's upper bound already contains a compatible indexer, do not add another upper bound * Distribute type arguments over all type families sans `eq`, `keyof`, `rawkeyof`, and other internal type families * Fix a case where `buffers` component weren't read in two places (fixes #1267) * Fix a case where things that constitutes a strong ref were slightly incorrect * Fix a case where constraint dependencies weren't setup wrt `for ... in` statement ### Native Codegen * Fix an optimization that splits TValue store only when its value and its tag are compatible * Implement a system to plug additional type information for custom host userdata types --- ### Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --------- Co-authored-by: Aaron Weiss Co-authored-by: Andy Friesen Co-authored-by: Vighnesh Co-authored-by: Aviral Goel Co-authored-by: David Cope Co-authored-by: Lily Brown Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Constraint.h | 77 +-- Analysis/include/Luau/ConstraintGenerator.h | 17 +- Analysis/include/Luau/ConstraintSolver.h | 20 +- Analysis/include/Luau/Simplify.h | 3 + Analysis/include/Luau/TypeFamily.h | 21 +- Analysis/include/Luau/TypeUtils.h | 3 + Analysis/src/Constraint.cpp | 42 +- Analysis/src/ConstraintGenerator.cpp | 415 ++++--------- Analysis/src/ConstraintSolver.cpp | 610 +++++++++++++------- Analysis/src/Simplify.cpp | 11 + Analysis/src/Subtyping.cpp | 1 + Analysis/src/ToString.cpp | 17 +- Analysis/src/TypeChecker2.cpp | 27 + Analysis/src/TypeFamily.cpp | 257 +++++---- Analysis/src/TypeUtils.cpp | 53 ++ CLI/Repl.cpp | 6 + 
CodeGen/include/Luau/CodeGen.h | 11 +- CodeGen/include/Luau/IrDump.h | 6 +- CodeGen/include/Luau/IrUtils.h | 4 + CodeGen/src/BytecodeAnalysis.cpp | 100 +--- CodeGen/src/CodeGenAssembly.cpp | 136 +++-- CodeGen/src/CodeGenContext.cpp | 24 + CodeGen/src/CodeGenContext.h | 3 + CodeGen/src/CodeGenLower.h | 7 +- CodeGen/src/IrBuilder.cpp | 32 +- CodeGen/src/IrDump.cpp | 79 ++- CodeGen/src/IrTranslation.cpp | 7 +- CodeGen/src/IrUtils.cpp | 10 + CodeGen/src/OptimizeConstProp.cpp | 50 +- Common/include/Luau/Bytecode.h | 8 +- Compiler/include/Luau/BytecodeBuilder.h | 14 + Compiler/include/Luau/Compiler.h | 3 + Compiler/include/luacode.h | 3 + Compiler/src/BytecodeBuilder.cpp | 146 ++++- Compiler/src/Compiler.cpp | 40 +- Compiler/src/Types.cpp | 47 +- Compiler/src/Types.h | 6 +- VM/src/lstate.h | 1 + VM/src/ltablib.cpp | 63 -- VM/src/lvmload.cpp | 104 +++- tests/Compiler.test.cpp | 91 ++- tests/Conformance.test.cpp | 9 +- tests/ConformanceIrHooks.h | 2 + tests/IrBuilder.test.cpp | 58 +- tests/IrLowering.test.cpp | 125 +++- tests/NonStrictTypeChecker.test.cpp | 18 + tests/Repl.test.cpp | 18 + tests/Subtyping.test.cpp | 1 + tests/ToString.test.cpp | 5 +- tests/TypeFamily.test.cpp | 117 +++- tests/TypeInfer.anyerror.test.cpp | 52 +- tests/TypeInfer.functions.test.cpp | 2 +- tests/TypeInfer.loops.test.cpp | 44 ++ tests/TypeInfer.operators.test.cpp | 6 +- tests/TypeInfer.primitives.test.cpp | 10 + tests/TypeInfer.singletons.test.cpp | 14 + tests/TypeInfer.tables.test.cpp | 39 +- tests/TypeInfer.typestates.test.cpp | 1 + tests/TypeInfer.unionTypes.test.cpp | 8 +- tests/conformance/move.lua | 37 +- tests/conformance/native.lua | 29 + tools/faillist.txt | 29 +- 62 files changed, 1986 insertions(+), 1213 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index ec281ae3..77810516 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -179,23 +179,6 @@ struct HasPropConstraint bool suppressSimplification 
= false; }; -// result ~ setProp subjectType ["prop", "prop2", ...] propType -// -// If the subject is a table or table-like thing that already has the named -// property chain, we unify propType with that existing property type. -// -// If the subject is a free table, we augment it in place. -// -// If the subject is an unsealed table, result is an augmented table that -// includes that new prop. -struct SetPropConstraint -{ - TypeId resultType; - TypeId subjectType; - std::vector path; - TypeId propType; -}; - // resultType ~ hasIndexer subjectType indexType // // If the subject type is a table or table-like thing that supports indexing, @@ -209,16 +192,37 @@ struct HasIndexerConstraint TypeId indexType; }; -// result ~ setIndexer subjectType indexType propType -// -// If the subject is a table or table-like thing that already has an indexer, -// unify its indexType and propType with those from this constraint. -// -// If the table is a free or unsealed table, we augment it with a new indexer. -struct SetIndexerConstraint +struct AssignConstraint { - TypeId subjectType; + TypeId lhsType; + TypeId rhsType; +}; + +// assign lhsType propName rhsType +// +// Assign a value of type rhsType into the named property of lhsType. + +struct AssignPropConstraint +{ + TypeId lhsType; + std::string propName; + TypeId rhsType; + + /// The canonical write type of the property. It is _solely_ used to + /// populate astTypes during constraint resolution. Nothing should ever + /// block on it. + TypeId propType; +}; + +struct AssignIndexConstraint +{ + TypeId lhsType; TypeId indexType; + TypeId rhsType; + + /// The canonical write type of the property. It is _solely_ used to + /// populate astTypes during constraint resolution. Nothing should ever + /// block on it. TypeId propType; }; @@ -230,25 +234,6 @@ struct UnpackConstraint { TypePackId resultPack; TypePackId sourcePack; - - // UnpackConstraint is sometimes used to resolve the types of assignments. 
- // When this is the case, any LocalTypes in resultPack can have their - // domains extended by the corresponding type from sourcePack. - bool resultIsLValue = false; -}; - -// resultType ~ unpack sourceType -// -// The same as UnpackConstraint, but specialized for a pair of types as opposed to packs. -struct Unpack1Constraint -{ - TypeId resultType; - TypeId sourceType; - - // UnpackConstraint is sometimes used to resolve the types of assignments. - // When this is the case, any LocalTypes in resultPack can have their - // domains extended by the corresponding type from sourcePack. - bool resultIsLValue = false; }; // ty ~ reduce ty @@ -268,8 +253,8 @@ struct ReducePackConstraint }; using ConstraintV = Variant; + TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, HasIndexerConstraint, + AssignConstraint, AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index ed5e17e2..3e1861ea 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -254,18 +254,11 @@ private: Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); - struct LValueBounds - { - std::optional annotationTy; - std::optional assignedTy; - }; - - LValueBounds checkLValue(const ScopePtr& scope, AstExpr* expr); - LValueBounds checkLValue(const ScopePtr& scope, AstExprLocal* local); - LValueBounds checkLValue(const ScopePtr& scope, AstExprGlobal* global); - LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexName* indexName); - LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr); - LValueBounds updateProperty(const ScopePtr& 
scope, AstExpr* expr); + void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId rhsType); + void visitLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId rhsType); struct FunctionSignature { diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 031da67b..58361dde 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -134,7 +134,6 @@ struct ConstraintSolver bool tryDispatch(const FunctionCheckConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); - bool tryDispatch(const SetPropConstraint& c, NotNull constraint); bool tryDispatchHasIndexer( int& recursionDepth, NotNull constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set& seen); @@ -142,11 +141,13 @@ struct ConstraintSolver std::pair> tryDispatchSetIndexer( NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds); - bool tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force); - bool tryDispatchUnpack1(NotNull constraint, TypeId resultType, TypeId sourceType, bool resultIsLValue); + bool tryDispatch(const AssignConstraint& c, NotNull constraint); + bool tryDispatch(const AssignPropConstraint& c, NotNull constraint); + bool tryDispatch(const AssignIndexConstraint& c, NotNull constraint); + + bool tryDispatchUnpack1(NotNull constraint, TypeId resultType, TypeId sourceType); bool tryDispatch(const UnpackConstraint& c, NotNull constraint); - bool tryDispatch(const Unpack1Constraint& c, NotNull constraint); bool tryDispatch(const ReduceConstraint& c, NotNull 
constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); @@ -165,6 +166,17 @@ struct ConstraintSolver std::pair, std::optional> lookupTableProp(NotNull constraint, TypeId subjectType, const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification, DenseHashSet& seen); + /** + * Generate constraints to unpack the types of srcTypes and assign each + * value to the corresponding LocalType in destTypes. + * + * @param destTypes A finite TypePack comprised of LocalTypes. + * @param srcTypes A TypePack that represents rvalues to be assigned. + * @returns The underlying UnpackConstraint. There's a bit of code in + * iteration that needs to pass blocks on to this constraint. + */ + NotNull unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull constraint); + void block(NotNull target, NotNull constraint); /** * Block a constraint on the resolution of a Type. diff --git a/Analysis/include/Luau/Simplify.h b/Analysis/include/Luau/Simplify.h index 10f27d4e..5b363e96 100644 --- a/Analysis/include/Luau/Simplify.h +++ b/Analysis/include/Luau/Simplify.h @@ -5,6 +5,7 @@ #include "Luau/DenseHash.h" #include "Luau/NotNull.h" #include "Luau/TypeFwd.h" +#include namespace Luau { @@ -19,6 +20,8 @@ struct SimplifyResult }; SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts); + SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); enum class Relation diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index fa418e17..5b72a370 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -6,7 +6,6 @@ #include "Luau/NotNull.h" #include "Luau/TypeCheckLimits.h" #include "Luau/TypeFwd.h" -#include "Luau/Variant.h" #include #include @@ -19,22 +18,6 @@ 
struct TypeArena; struct TxnLog; class Normalizer; -struct TypeFamilyQueue -{ - NotNull> queuedTys; - NotNull> queuedTps; - - void add(TypeId instanceTy); - void add(TypePackId instanceTp); - - template - void add(const std::vector& ts) - { - for (const T& t : ts) - enqueue(t); - } -}; - struct TypeFamilyContext { NotNull arena; @@ -99,8 +82,8 @@ struct TypeFamilyReductionResult }; template -using ReducerFunction = std::function( - T, NotNull, const std::vector&, const std::vector&, NotNull)>; +using ReducerFunction = std::function(T, const std::vector&, const std::vector&, + NotNull)>; /// Represents a type function that may be applied to map a series of types and /// type packs to a single output type. diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 81c8a5ca..c8ee99e9 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -55,6 +55,9 @@ struct InConditionalContext using ScopePtr = std::shared_ptr; +std::optional findTableProperty( + NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); + std::optional findMetatableEntry( NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); std::optional findTablePropertyRespectingMeta( diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 4d1c35e0..bd31beff 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -56,6 +56,11 @@ bool isReferenceCountedType(const TypeId typ) DenseHashSet Constraint::getMaybeMutatedFreeTypes() const { + // For the purpose of this function and reference counting in general, we are only considering + // mutations that affect the _bounds_ of the free type, and not something that may bind the free + // type itself to a new type. As such, `ReduceConstraint` and `GeneralizationConstraint` have no + // contribution to the output set here. 
+ DenseHashSet types{{}}; ReferenceCountInitializer rci{&types}; @@ -74,11 +79,6 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const rci.traverse(psc->subPack); rci.traverse(psc->superPack); } - else if (auto gc = get(*this)) - { - rci.traverse(gc->generalizedType); - // `GeneralizationConstraints` should not mutate `sourceType` or `interiorTypes`. - } else if (auto itc = get(*this)) { rci.traverse(itc->variables); @@ -101,36 +101,32 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const rci.traverse(hpc->resultType); // `HasPropConstraints` should not mutate `subjectType`. } - else if (auto spc = get(*this)) - { - rci.traverse(spc->resultType); - // `SetPropConstraints` should not mutate `subjectType` or `propType`. - // TODO: is this true? it "unifies" with `propType`, so maybe mutates that one too? - } else if (auto hic = get(*this)) { rci.traverse(hic->resultType); // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`. } - else if (auto sic = get(*this)) + else if (auto ac = get(*this)) { - rci.traverse(sic->propType); - // `SetIndexerConstraints` should not mutate `subjectType` or `indexType`. + rci.traverse(ac->lhsType); + rci.traverse(ac->rhsType); + } + else if (auto apc = get(*this)) + { + rci.traverse(apc->lhsType); + rci.traverse(apc->rhsType); + } + else if (auto aic = get(*this)) + { + rci.traverse(aic->lhsType); + rci.traverse(aic->indexType); + rci.traverse(aic->rhsType); } else if (auto uc = get(*this)) { rci.traverse(uc->resultPack); // `UnpackConstraint` should not mutate `sourcePack`. } - else if (auto u1c = get(*this)) - { - rci.traverse(u1c->resultType); - // `Unpack1Constraint` should not mutate `sourceType`. 
- } - else if (auto rc = get(*this)) - { - rci.traverse(rc->ty); - } else if (auto rpc = get(*this)) { rci.traverse(rpc->tp); diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index cbd027bb..12648eb0 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -310,7 +310,7 @@ std::optional ConstraintGenerator::lookup(const ScopePtr& scope, Locatio std::optional ty = lookup(scope, location, operand, /*prototype*/ false); if (!ty) { - ty = arena->addType(BlockedType{}); + ty = arena->addType(LocalType{builtinTypes->neverType}); rootScope->lvalueTypes[operand] = *ty; } @@ -739,12 +739,28 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat if (hasAnnotation) { + for (size_t i = 0; i < statLocal->vars.size; ++i) + addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], annotatedTypes[i]}); + TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); - addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(std::move(assignees)), annotatedPack, /*resultIsLValue*/ true}); addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); } else - addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(std::move(assignees)), rvaluePack, /*resultIsLValue*/ true}); + { + std::vector valueTypes; + valueTypes.reserve(statLocal->vars.size); + + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(valueTypes), rvaluePack}); + + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + getMutable(valueTypes[i])->setOwner(uc); + addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], valueTypes[i]}); + } + } if (statLocal->vars.size == 1 && statLocal->values.size == 1 && firstValueType && scope.get() == rootScope && 
!hasAnnotation) { @@ -837,7 +853,6 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFor* for_) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forIn) { ScopePtr loopScope = childScope(forIn, scope); - TypePackId iterator = checkPack(scope, forIn->values).tp; std::vector variableTypes; @@ -862,10 +877,17 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI } TypePackId variablePack = arena->addTypePack(std::move(variableTypes)); - addConstraint( + auto iterable = addConstraint( loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astForInNextTypes}); + Checkpoint start = checkpoint(this); visit(loopScope, forIn->body); + Checkpoint end = checkpoint(this); + + // This iter constraint must dispatch first. + forEachConstraint(start, end, this, [&iterable](const ConstraintPtr& runLater) { + runLater->dependencies.push_back(iterable); + }); return ControlFlow::None; } @@ -957,67 +979,63 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. // With or without self - TypeId generalizedType = arena->addType(BlockedType{}); Checkpoint start = checkpoint(this); FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); bool sigFullyDefined = !hasFreeType(sig.signature); + checkFunctionBody(sig.bodyScope, function->func); + Checkpoint end = checkpoint(this); + + TypeId generalizedType = arena->addType(BlockedType{}); if (sigFullyDefined) emplaceType(asMutable(generalizedType), sig.signature); + else + { + const ScopePtr& constraintScope = sig.signatureScope ? 
sig.signatureScope : sig.bodyScope; - DenseHashSet excludeList{nullptr}; + NotNull c = addConstraint(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); + getMutable(generalizedType)->setOwner(c); + + Constraint* previous = nullptr; + forEachConstraint(start, end, this, [&c, &previous](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } + }); + } DefId def = dfg->getDef(function->name); std::optional existingFunctionTy = follow(lookup(scope, function->name->location, def)); - if (get(existingFunctionTy) && sigFullyDefined) - emplaceType(asMutable(*existingFunctionTy), sig.signature); - if (AstExprLocal* localName = function->name->as()) { - if (existingFunctionTy) - { - addConstraint(scope, function->name->location, SubtypeConstraint{generalizedType, *existingFunctionTy}); - - Symbol sym{localName->local}; - scope->bindings[sym].typeId = generalizedType; - } - else - scope->bindings[localName->local] = Binding{generalizedType, localName->location}; + visitLValue(scope, localName, generalizedType); scope->bindings[localName->local] = Binding{sig.signature, localName->location}; scope->lvalueTypes[def] = sig.signature; - scope->rvalueRefinements[def] = sig.signature; } else if (AstExprGlobal* globalName = function->name->as()) { if (!existingFunctionTy) ice->ice("prepopulateGlobalScope did not populate a global name", globalName->location); - if (!sigFullyDefined) - generalizedType = *existingFunctionTy; + // Sketchy: We're specifically looking for BlockedTypes that were + // initially created by ConstraintGenerator::prepopulateGlobalScope. 
+ if (auto bt = get(*existingFunctionTy); bt && nullptr == bt->getOwner()) + emplaceType(asMutable(*existingFunctionTy), generalizedType); scope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; scope->lvalueTypes[def] = sig.signature; - scope->rvalueRefinements[def] = sig.signature; } else if (AstExprIndexName* indexName = function->name->as()) { - Checkpoint check1 = checkpoint(this); - auto [_, lvalueType] = checkLValue(scope, indexName); - Checkpoint check2 = checkpoint(this); - - forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) { - excludeList.insert(c.get()); - }); - - // TODO figure out how to populate the location field of the table Property. - - if (lvalueType && *lvalueType != generalizedType) - { - LUAU_ASSERT(get(lvalueType)); - emplaceType(asMutable(*lvalueType), generalizedType); - } + visitLValue(scope, indexName, generalizedType); } else if (AstExprError* err = function->name->as()) { @@ -1029,48 +1047,6 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f scope->rvalueRefinements[def] = generalizedType; - checkFunctionBody(sig.bodyScope, function->func); - Checkpoint end = checkpoint(this); - - if (!sigFullyDefined) - { - NotNull constraintScope{sig.signatureScope ? 
sig.signatureScope.get() : sig.bodyScope.get()}; - std::unique_ptr c = - std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); - - Constraint* previous = nullptr; - forEachConstraint(start, end, this, [&c, &excludeList, &previous](const ConstraintPtr& constraint) { - if (!excludeList.contains(constraint.get())) - c->dependencies.push_back(NotNull{constraint.get()}); - - if (auto psc = get(*constraint); psc && psc->returns) - { - if (previous) - constraint->dependencies.push_back(NotNull{previous}); - - previous = constraint.get(); - } - }); - - - // We need to check if the blocked type has no owner here because - // if a function is defined twice anywhere in the program like: - // `function f() end` and then later like `function f() end` - // Then there will be exactly one definition in the scope for it because it's a global - // (this is the same as writing f = function() end) - // Therefore, when we visit() the multiple different expression of this global variable - // They will all be aliased to the same blocked type, which means we can create multiple constraints - // for the same blocked type. 
- if (auto blocked = getMutable(generalizedType); blocked && !blocked->getOwner()) - blocked->setOwner(addConstraint(scope, std::move(c))); - } - - if (BlockedType* bt = getMutable(follow(existingFunctionTy)); bt && !bt->getOwner()) - { - auto uc = addConstraint(scope, function->name->location, Unpack1Constraint{*existingFunctionTy, generalizedType}); - bt->setOwner(uc); - } - return ControlFlow::None; } @@ -1124,38 +1100,20 @@ static void bindFreeType(TypeId a, TypeId b) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* assign) { - std::vector upperBounds; - upperBounds.reserve(assign->vars.size); - - std::vector typeStates; - typeStates.reserve(assign->vars.size); - - Checkpoint lvalueBeginCheckpoint = checkpoint(this); - - for (AstExpr* lvalue : assign->vars) - { - auto [upperBound, typeState] = checkLValue(scope, lvalue); - upperBounds.push_back(upperBound.value_or(builtinTypes->unknownType)); - typeStates.push_back(typeState.value_or(builtinTypes->unknownType)); - } - - Checkpoint lvalueEndCheckpoint = checkpoint(this); - TypePackId resultPack = checkPack(scope, assign->values).tp; - auto uc = addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(typeStates), resultPack, /*resultIsLValue*/ true}); - forEachConstraint(lvalueBeginCheckpoint, lvalueEndCheckpoint, this, [uc](const ConstraintPtr& constraint) { - uc->dependencies.push_back(NotNull{constraint.get()}); - }); - auto psc = addConstraint(scope, assign->location, PackSubtypeConstraint{resultPack, arena->addTypePack(std::move(upperBounds))}); - psc->dependencies.push_back(uc); + std::vector valueTypes; + valueTypes.reserve(assign->vars.size); - for (TypeId assignee : typeStates) + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(valueTypes), resultPack}); + + for (size_t i = 0; i < assign->vars.size; ++i) { - auto 
blocked = getMutable(assignee); - - if (blocked && !blocked->getOwner()) - blocked->setOwner(uc); + getMutable(valueTypes[i])->setOwner(uc); + visitLValue(scope, assign->vars.data[i], valueTypes[i]); } return ControlFlow::None; @@ -1166,24 +1124,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAss AstExprBinary binop = AstExprBinary{assign->location, assign->op, assign->var, assign->value}; TypeId resultTy = check(scope, &binop).ty; - auto [upperBound, typeState] = checkLValue(scope, assign->var); - - Constraint* sc = nullptr; - if (upperBound) - sc = addConstraint(scope, assign->location, SubtypeConstraint{resultTy, *upperBound}); - - if (typeState) - { - NotNull uc = addConstraint(scope, assign->location, Unpack1Constraint{*typeState, resultTy, /*resultIsLValue=*/true}); - if (auto blocked = getMutable(*typeState); blocked && !blocked->getOwner()) - blocked->setOwner(uc); - - if (sc) - uc->dependencies.push_back(NotNull{sc}); - } - - DefId def = dfg->getDef(assign->var); - scope->lvalueTypes[def] = resultTy; + visitLValue(scope, assign->var, resultTy); return ControlFlow::None; } @@ -1897,7 +1838,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* globa return Inference{builtinTypes->errorRecoveryType()}; } -Inference ConstraintGenerator::checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) +Inference ConstraintGenerator::checkIndexName( + const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) { TypeId obj = check(scope, indexee).ty; TypeId result = arena->addType(BlockedType{}); @@ -2272,26 +2214,25 @@ std::tuple ConstraintGenerator::checkBinary( } } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExpr* expr) +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType) { - if (auto local 
= expr->as()) - return checkLValue(scope, local); - else if (auto global = expr->as()) - return checkLValue(scope, global); - else if (auto indexName = expr->as()) - return checkLValue(scope, indexName); - else if (auto indexExpr = expr->as()) - return checkLValue(scope, indexExpr); - else if (auto error = expr->as()) + if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) { - check(scope, error); - return {builtinTypes->errorRecoveryType(), builtinTypes->errorRecoveryType()}; + // Nothing? } else - ice->ice("checkLValue is inexhaustive"); + ice->ice("Unexpected lvalue expression", expr->location); } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprLocal* local) +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType) { std::optional annotatedTy = scope->lookup(local->local); LUAU_ASSERT(annotatedTy); @@ -2332,186 +2273,53 @@ ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePt scope->lvalueTypes[defId] = *ty; } - // TODO: Need to clip this, but this requires more code to be reworked first before we can clip this. 
- std::optional assignedTy = arena->addType(BlockedType{}); - - auto unpackC = addConstraint(scope, local->location, Unpack1Constraint{*ty, *assignedTy, /*resultIsLValue*/ true}); - - if (auto blocked = get(*ty)) - { - if (blocked->getOwner()) - unpackC->dependencies.push_back(NotNull{blocked->getOwner()}); - else if (auto blocked = getMutable(*ty)) - blocked->setOwner(unpackC); - } - recordInferredBinding(local->local, *ty); - return {annotatedTy, assignedTy}; + if (annotatedTy) + addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy}); + addConstraint(scope, local->location, AssignConstraint{*ty, rhsType}); } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprGlobal* global) +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) { std::optional annotatedTy = scope->lookup(Symbol{global->name}); if (annotatedTy) { DefId def = dfg->getDef(global); - TypeId assignedTy = arena->addType(BlockedType{}); - rootScope->lvalueTypes[def] = assignedTy; - return {annotatedTy, assignedTy}; + rootScope->lvalueTypes[def] = rhsType; + + addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); + addConstraint(scope, global->location, AssignConstraint{*annotatedTy, rhsType}); } - else - return {annotatedTy, std::nullopt}; } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprIndexName* indexName) +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* expr, TypeId rhsType) { - return updateProperty(scope, indexName); + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy}); } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr) +void 
ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) { - return updateProperty(scope, indexExpr); -} - -/** - * This function is mostly about identifying properties that are being inserted into unsealed tables. - * - * If expr has the form name.a.b.c - */ -ConstraintGenerator::LValueBounds ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr) -{ - // There are a bunch of cases where we realize that this is not the kind of - // assignment that potentially changes the shape of a table. When we - // encounter them, we call this to fall back and do the "usual thing." - auto fallback = [&]() -> LValueBounds { - TypeId resTy = check(scope, expr).ty; - return {resTy, std::nullopt}; - }; - - LUAU_ASSERT(expr->is() || expr->is()); - - if (auto indexExpr = expr->as(); indexExpr && !indexExpr->index->is()) + if (auto constantString = expr->index->as()) { - // An indexer is only interesting in an lvalue-ey way if it is at the - // tail of an expression. - // - // If the indexer is not at the tail, then we are not interested in - // augmenting the lhs data structure with a new indexer. Constraint - // generation can treat it as an ordinary lvalue. - // - // eg - // - // a.b.c[1] = 44 -- lvalue - // a.b[4].c = 2 -- rvalue + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist. 
+ std::string propName{constantString->value.data, constantString->value.size}; + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy}); - TypeId subjectType = check(scope, indexExpr->expr).ty; - TypeId indexType = check(scope, indexExpr->index).ty; - TypeId assignedTy = arena->addType(BlockedType{}); - auto sic = addConstraint(scope, expr->location, SetIndexerConstraint{subjectType, indexType, assignedTy}); - getMutable(assignedTy)->setOwner(sic); - - module->astTypes[expr] = assignedTy; - - return {assignedTy, assignedTy}; + return; } - Symbol sym; - const Def* def = nullptr; - std::vector segments; - std::vector exprs; - - AstExpr* e = expr; - while (e) - { - if (auto global = e->as()) - { - sym = global->name; - def = dfg->getDef(global); - break; - } - else if (auto local = e->as()) - { - sym = local->local; - def = dfg->getDef(local); - break; - } - else if (auto indexName = e->as()) - { - segments.push_back(indexName->index.value); - exprs.push_back(e); - e = indexName->expr; - } - else if (auto indexExpr = e->as()) - { - if (auto strIndex = indexExpr->index->as()) - { - // We need to populate astTypes for the index value. 
- check(scope, indexExpr->index); - - segments.push_back(std::string(strIndex->value.data, strIndex->value.size)); - exprs.push_back(e); - e = indexExpr->expr; - } - else - { - return fallback(); - } - } - else - { - return fallback(); - } - } - - LUAU_ASSERT(!segments.empty()); - - std::reverse(begin(segments), end(segments)); - std::reverse(begin(exprs), end(exprs)); - - LUAU_ASSERT(def); - std::optional> lookupResult = scope->lookupEx(NotNull{def}); - if (!lookupResult) - return fallback(); - - const auto [subjectType, subjectScope] = *lookupResult; - - std::vector segmentStrings(begin(segments), end(segments)); - - TypeId updatedType = arena->addType(BlockedType{}); - TypeId assignedTy = arena->addType(BlockedType{}); - auto setC = addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), assignedTy}); - getMutable(updatedType)->setOwner(setC); - - TypeId prevSegmentTy = updatedType; - for (size_t i = 0; i < segments.size(); ++i) - { - TypeId segmentTy = arena->addType(BlockedType{}); - module->astTypes[exprs[i]] = segmentTy; - ValueContext ctx = i == segments.size() - 1 ? ValueContext::LValue : ValueContext::RValue; - auto hasC = addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i], ctx, inConditional(typeContext)}); - getMutable(segmentTy)->setOwner(hasC); - setC->dependencies.push_back(hasC); - prevSegmentTy = segmentTy; - } - - module->astTypes[expr] = prevSegmentTy; - module->astTypes[e] = updatedType; - - if (!subjectType->persistent) - { - subjectScope->bindings[sym].typeId = updatedType; - - // This can fail if the user is erroneously trying to augment a builtin - // table like os or string. 
- if (auto key = dfg->getRefinementKey(e)) - { - subjectScope->lvalueTypes[key->def] = updatedType; - subjectScope->rvalueRefinements[key->def] = updatedType; - } - } - - return {assignedTy, assignedTy}; + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId indexTy = check(scope, expr->index).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + addConstraint(scope, expr->location, AssignIndexConstraint{lhsTy, indexTy, rhsType, propTy}); } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) @@ -3194,6 +3002,11 @@ void ConstraintGenerator::reportCodeTooComplex(Location location) TypeId ConstraintGenerator::makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) { + if (get(follow(lhs))) + return rhs; + if (get(follow(rhs))) + return lhs; + TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.unionFamily, {lhs, rhs}, {}, scope, location); return resultType; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index e35ddf0e..b0f27911 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -206,6 +206,12 @@ static std::pair, std::vector> saturateArguments saturatedPackArguments.push_back(builtinTypes->errorRecoveryTypePack()); } + for (TypeId& arg : saturatedTypeArguments) + arg = follow(arg); + + for (TypePackId& pack : saturatedPackArguments) + pack = follow(pack); + // At this point, these two conditions should be true. If they aren't we // will run into access violations. LUAU_ASSERT(saturatedTypeArguments.size() == fn.typeParams.size()); @@ -407,11 +413,17 @@ void ConstraintSolver::run() // decrement the referenced free types for this constraint if we dispatched successfully! 
for (auto ty : c->getMaybeMutatedFreeTypes()) { - // this is a little weird, but because we're only counting free types in subtyping constraints, - // some constraints (like unpack) might actually produce _more_ references to a free type. size_t& refCount = unresolvedConstraints[ty]; if (refCount > 0) refCount -= 1; + + // We have two constraints that are designed to wait for the + // refCount on a free type to be equal to 1: the + // PrimitiveTypeConstraint and ReduceConstraint. We + // therefore wake any constraint waiting for a free type's + // refcount to be 1 or 0. + if (refCount <= 1) + unblock(ty, Location{}); } if (logger) @@ -518,15 +530,15 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); - else if (auto spc = get(*constraint)) - success = tryDispatch(*spc, constraint); else if (auto spc = get(*constraint)) success = tryDispatch(*spc, constraint); - else if (auto spc = get(*constraint)) - success = tryDispatch(*spc, constraint, force); - else if (auto uc = get(*constraint)) + else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); - else if (auto uc = get(*constraint)) + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); + else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); else if (auto rc = get(*constraint)) success = tryDispatch(*rc, constraint, force); @@ -688,7 +700,18 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull(tableTy)->indexer = TableIndexer{keyTy, valueTy}; pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{nextTy, tableTy}); - pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, arena->addTypePack({keyTy, valueTy}), /*resultIsLValue=*/true}); + + auto it = 
begin(c.variables); + auto endIt = end(c.variables); + + if (it != endIt) + { + pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy}); + ++it; + } + if (it != endIt) + pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); + return true; } @@ -915,7 +938,17 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // Type function application will happily give us the exact same type if // there are e.g. generic saturatedTypeArguments that go unused. const TableType* tfTable = getTableType(tf->type); - bool needsClone = follow(tf->type) == target || (tfTable != nullptr && tfTable == getTableType(target)); + + //clang-format off + bool needsClone = + follow(tf->type) == target || + (tfTable != nullptr && tfTable == getTableType(target)) || + std::any_of(typeArguments.begin(), typeArguments.end(), [&](const auto& other) { + return other == target; + } + ); + //clang-format on + // Only tables have the properties we're trying to set. TableType* ttv = getMutableTableType(target); @@ -1291,158 +1324,6 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(ty); - return ttv && ttv->state == TableState::Unsealed; -} - -/** - * Given a path into a set of nested unsealed tables `ty`, insert a new property `replaceTy` as the leaf-most property. - * - * Fails and does nothing if every table along the way is not unsealed. - * - * Mutates the innermost table type in-place. - */ -static void updateTheTableType( - NotNull builtinTypes, NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) -{ - if (path.empty()) - return; - - // First walk the path and ensure that it's unsealed tables all the way - // to the end. 
- { - TypeId t = ty; - for (size_t i = 0; i < path.size() - 1; ++i) - { - if (!isUnsealedTable(t)) - return; - - const TableType* tbl = get(t); - auto it = tbl->props.find(path[i]); - if (it == tbl->props.end()) - return; - - t = follow(it->second.type()); - } - - // The last path segment should not be a property of the table at all. - // We are not changing property types. We are only admitting this one - // new property to be appended. - if (!isUnsealedTable(t)) - return; - const TableType* tbl = get(t); - if (0 != tbl->props.count(path.back())) - return; - } - - TypeId t = ty; - ErrorVec dummy; - - for (size_t i = 0; i < path.size() - 1; ++i) - { - t = follow(t); - auto propTy = findTablePropertyRespectingMeta(builtinTypes, dummy, t, path[i], ValueContext::LValue, Location{}); - dummy.clear(); - - if (!propTy) - return; - - t = *propTy; - } - - const std::string& lastSegment = path.back(); - - t = follow(t); - TableType* tt = getMutable(t); - if (auto mt = get(t)) - tt = getMutable(mt->table); - - if (!tt) - return; - - tt->props[lastSegment].setType(replaceTy); -} - -bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint) -{ - TypeId subjectType = follow(c.subjectType); - const TypeId propType = follow(c.propType); - - if (isBlocked(subjectType)) - return block(subjectType, constraint); - - std::optional existingPropType = subjectType; - - LUAU_ASSERT(!c.path.empty()); - if (c.path.empty()) - return false; - - for (size_t i = 0; i < c.path.size(); ++i) - { - const std::string& segment = c.path[i]; - if (!existingPropType) - break; - - ValueContext ctx = i == c.path.size() - 1 ? 
ValueContext::LValue : ValueContext::RValue; - - auto [blocked, result] = lookupTableProp(constraint, *existingPropType, segment, ctx); - if (!blocked.empty()) - { - for (TypeId blocked : blocked) - block(blocked, constraint); - return false; - } - - existingPropType = result; - } - - auto bind = [&](TypeId a, TypeId b) { - bindBlockedType(a, b, subjectType, constraint); - }; - - if (existingPropType) - { - unify(constraint, propType, *existingPropType); - unify(constraint, *existingPropType, propType); - bind(c.resultType, c.subjectType); - unblock(c.resultType, constraint->location); - return true; - } - - const TypeId originalSubjectType = subjectType; - - if (auto mt = get(subjectType)) - subjectType = follow(mt->table); - - if (get(subjectType)) - return false; - else if (auto ttv = getMutable(subjectType)) - { - if (ttv->state == TableState::Free) - { - LUAU_ASSERT(!subjectType->persistent); - - ttv->props[c.path[0]] = Property{propType}; - bind(c.resultType, subjectType); - unblock(c.resultType, constraint->location); - return true; - } - else if (ttv->state == TableState::Unsealed) - { - LUAU_ASSERT(!subjectType->persistent); - - updateTheTableType(builtinTypes, NotNull{arena}, subjectType, c.path, propType); - } - } - - bind(c.resultType, originalSubjectType); - unblock(c.resultType, constraint->location); - return true; -} - bool ConstraintSolver::tryDispatchHasIndexer( int& recursionDepth, NotNull constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set& seen) { @@ -1460,6 +1341,13 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (auto ft = get(subjectType)) { + if (auto tbl = get(follow(ft->upperBound)); tbl && tbl->indexer) + { + unify(constraint, indexType, tbl->indexer->indexType); + bindBlockedType(resultType, tbl->indexer->indexResultType, subjectType, constraint); + return true; + } + FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType}; emplaceType(asMutable(resultType), freeResult); @@ -1708,33 
+1596,20 @@ std::pair> ConstraintSolver::tryDispatchSetIndexer( return {true, std::nullopt}; } -bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const AssignConstraint& c, NotNull constraint) { - TypeId subjectType = follow(c.subjectType); - if (isBlocked(subjectType)) - return block(subjectType, constraint); + const TypeId lhsTy = follow(c.lhsType); + const TypeId rhsTy = follow(c.rhsType); - auto [dispatched, resultTy] = tryDispatchSetIndexer(constraint, subjectType, c.indexType, c.propType, /*expandFreeTypeBounds=*/true); - if (dispatched) - { - bindBlockedType(c.propType, resultTy.value_or(builtinTypes->errorRecoveryType()), subjectType, constraint); - unblock(c.propType, constraint->location); - } - - return dispatched; -} - -bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, TypeId resultTy, TypeId srcTy, bool resultIsLValue) -{ - resultTy = follow(resultTy); - LUAU_ASSERT(canMutate(resultTy, constraint)); + if (!get(lhsTy) && isBlocked(lhsTy)) + return block(lhsTy, constraint); auto tryExpand = [&](TypeId ty) { LocalType* lt = getMutable(ty); - if (!lt || !resultIsLValue) + if (!lt) return; - lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, srcTy).result; + lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, rhsTy).result; LUAU_ASSERT(lt->blockCount > 0); --lt->blockCount; @@ -1745,11 +1620,289 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, } }; - if (auto ut = get(resultTy)) - std::for_each(begin(ut), end(ut), tryExpand); - else if (get(resultTy)) - tryExpand(resultTy); - else if (get(resultTy)) + if (auto ut = get(lhsTy)) + { + // FIXME: I suspect there's a bug here where lhsTy is a union that contains no LocalTypes. 
+ for (TypeId t : ut) + tryExpand(t); + } + else if (get(lhsTy)) + tryExpand(lhsTy); + else + unify(constraint, rhsTy, lhsTy); + + unblock(lhsTy, constraint->location); + + return true; +} + +bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull constraint) +{ + TypeId lhsType = follow(c.lhsType); + const std::string& propName = c.propName; + const TypeId rhsType = follow(c.rhsType); + + if (isBlocked(lhsType)) + return block(lhsType, constraint); + + // 1. lhsType is a class that already has the prop + // 2. lhsType is a table that already has the prop (or a union or + // intersection that has the prop in aggregate) + // 3. lhsType has a metatable that already has the prop + // 4. lhsType is an unsealed table that does not have the prop, but has a + // string indexer + // 5. lhsType is an unsealed table that does not have the prop or a string + // indexer + + // Important: In every codepath through this function, the type `c.propType` + // must be bound to something, even if it's just the errorType. + + if (auto lhsClass = get(lhsType)) + { + const Property* prop = lookupClassProp(lhsClass, propName); + if (!prop || !prop->writeTy.has_value()) + return true; + + emplaceType(asMutable(c.propType), *prop->writeTy); + unify(constraint, rhsType, *prop->writeTy); + return true; + } + + if (auto lhsFree = getMutable(lhsType)) + { + if (get(lhsFree->upperBound) || get(lhsFree->upperBound)) + lhsType = lhsFree->upperBound; + else + { + TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope}); + TableType* upperTable = getMutable(newUpperBound); + LUAU_ASSERT(upperTable); + + upperTable->props[c.propName] = rhsType; + + // Food for thought: Could we block if simplification encounters a blocked type? 
+ lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFree->upperBound, newUpperBound).result; + + emplaceType(asMutable(c.propType), rhsType); + return true; + } + } + + // Handle the case that lhsType is a table that already has the property or + // a matching indexer. This also handles unions and intersections. + const auto [blocked, maybeTy] = lookupTableProp(constraint, lhsType, propName, ValueContext::LValue); + if (!blocked.empty()) + { + for (TypeId t : blocked) + block(t, constraint); + return false; + } + + if (maybeTy) + { + const TypeId propTy = *maybeTy; + emplaceType(asMutable(c.propType), propTy); + unify(constraint, rhsType, propTy); + return true; + } + + if (auto lhsMeta = get(lhsType)) + lhsType = follow(lhsMeta->table); + + // Handle the case where the lhs type is a table that does not have the + // named property. It could be a table with a string indexer, or an unsealed + // or free table that can grow. + if (auto lhsTable = getMutable(lhsType)) + { + if (auto it = lhsTable->props.find(propName); it != lhsTable->props.end()) + { + Property& prop = it->second; + + if (prop.writeTy.has_value()) + { + emplaceType(asMutable(c.propType), *prop.writeTy); + unify(constraint, rhsType, *prop.writeTy); + return true; + } + else + { + LUAU_ASSERT(prop.isReadOnly()); + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + { + prop.writeTy = prop.readTy; + emplaceType(asMutable(c.propType), *prop.writeTy); + unify(constraint, rhsType, *prop.writeTy); + return true; + } + else + { + emplaceType(asMutable(c.propType), builtinTypes->errorType); + return true; + } + } + } + + if (lhsTable->indexer && maybeString(lhsTable->indexer->indexType)) + { + emplaceType(asMutable(c.propType), rhsType); + unify(constraint, rhsType, lhsTable->indexer->indexResultType); + return true; + } + + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + { + emplaceType(asMutable(c.propType), rhsType); + 
lhsTable->props[propName] = Property::rw(rhsType); + return true; + } + } + + emplaceType(asMutable(c.propType), builtinTypes->errorType); + + return true; +} + +bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull constraint) +{ + const TypeId lhsType = follow(c.lhsType); + const TypeId indexType = follow(c.indexType); + const TypeId rhsType = follow(c.rhsType); + + if (isBlocked(lhsType)) + return block(lhsType, constraint); + + // 0. lhsType could be an intersection or union. + // 1. lhsType is a class with an indexer + // 2. lhsType is a table with an indexer, or it has a metatable that has an indexer + // 3. lhsType is a free or unsealed table and can grow an indexer + + // Important: In every codepath through this function, the type `c.propType` + // must be bound to something, even if it's just the errorType. + + auto tableStuff = [&](TableType* lhsTable) -> std::optional { + if (lhsTable->indexer) + { + unify(constraint, indexType, lhsTable->indexer->indexType); + unify(constraint, rhsType, lhsTable->indexer->indexResultType); + emplaceType(asMutable(c.propType), lhsTable->indexer->indexResultType); + return true; + } + + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + { + lhsTable->indexer = TableIndexer{indexType, rhsType}; + emplaceType(asMutable(c.propType), rhsType); + return true; + } + + return {}; + }; + + if (auto lhsFree = getMutable(lhsType)) + { + if (auto lhsTable = getMutable(lhsFree->upperBound)) + { + if (auto res = tableStuff(lhsTable)) + return *res; + } + + TypeId newUpperBound = + arena->addType(TableType{/*props*/ {}, TableIndexer{indexType, rhsType}, TypeLevel{}, constraint->scope, TableState::Free}); + const TableType* newTable = get(newUpperBound); + LUAU_ASSERT(newTable); + + unify(constraint, lhsType, newUpperBound); + + LUAU_ASSERT(newTable->indexer); + emplaceType(asMutable(c.propType), newTable->indexer->indexResultType); + return true; + } + + if (auto lhsTable = 
getMutable(lhsType)) + { + std::optional res = tableStuff(lhsTable); + if (res.has_value()) + return *res; + } + + if (auto lhsClass = get(lhsType)) + { + while (true) + { + if (lhsClass->indexer) + { + unify(constraint, indexType, lhsClass->indexer->indexType); + unify(constraint, rhsType, lhsClass->indexer->indexResultType); + emplaceType(asMutable(c.propType), lhsClass->indexer->indexResultType); + return true; + } + + if (lhsClass->parent) + lhsClass = get(lhsClass->parent); + else + break; + } + return true; + } + + if (auto lhsIntersection = getMutable(lhsType)) + { + std::set parts; + + for (TypeId t : lhsIntersection) + { + if (auto tbl = getMutable(follow(t))) + { + if (tbl->indexer) + { + unify(constraint, indexType, tbl->indexer->indexType); + parts.insert(tbl->indexer->indexResultType); + } + + if (tbl->state == TableState::Unsealed || tbl->state == TableState::Free) + { + tbl->indexer = TableIndexer{indexType, rhsType}; + parts.insert(rhsType); + } + } + else if (auto cls = get(follow(t))) + { + while (true) + { + if (cls->indexer) + { + unify(constraint, indexType, cls->indexer->indexType); + parts.insert(cls->indexer->indexResultType); + break; + } + + if (cls->parent) + cls = get(cls->parent); + else + break; + } + } + } + + TypeId res = simplifyIntersection(builtinTypes, arena, std::move(parts)).result; + + unify(constraint, rhsType, res); + } + + // Other types do not support index assignment. 
+ emplaceType(asMutable(c.propType), builtinTypes->errorType); + + return true; +} + +bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, TypeId resultTy, TypeId srcTy) +{ + resultTy = follow(resultTy); + LUAU_ASSERT(canMutate(resultTy, constraint)); + + LUAU_ASSERT(get(resultTy)); + + if (get(resultTy)) { if (follow(srcTy) == resultTy) { @@ -1765,10 +1918,7 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, bindBlockedType(resultTy, srcTy, srcTy, constraint); } else - { - LUAU_ASSERT(resultIsLValue); unify(constraint, srcTy, resultTy); - } unblock(resultTy, constraint->location); return true; @@ -1804,7 +1954,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull(resultTy); c.resultIsLValue && lt) - { - lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, builtinTypes->nilType).result; - LUAU_ASSERT(0 <= lt->blockCount); - --lt->blockCount; - - if (0 == lt->blockCount) - { - shiftReferences(resultTy, lt->domain); - emplaceType(asMutable(resultTy), lt->domain); - } - } - else if (get(resultTy) || get(resultTy)) + if (get(resultTy) || get(resultTy)) { emplaceType(asMutable(resultTy), builtinTypes->nilType); unblock(resultTy, constraint->location); @@ -1842,11 +1980,6 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) -{ - return tryDispatchUnpack1(constraint, c.resultType, c.sourceType, c.resultIsLValue); -} - bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force) { TypeId ty = follow(c.ty); @@ -1942,13 +2075,23 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{iteratorTy, tableTy}); - pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, arena->addTypePack({keyTy, valueTy}), /*resultIsLValue=*/true}); + + auto it = begin(c.variables); + auto 
endIt = end(c.variables); + if (it != endIt) + { + pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy}); + ++it; + } + if (it != endIt) + pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); + return true; } auto unpack = [&](TypeId ty) { - TypePackId variadic = arena->addTypePack(VariadicTypePack{ty}); - pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, variadic, /* resultIsLValue */ true}); + for (TypeId varTy : c.variables) + pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, ty}); }; if (get(iteratorTy)) @@ -2043,10 +2186,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl // If nextFn is nullptr, then the iterator function has an improper signature. if (nextFn) - { - const TypePackId nextRetPack = nextFn->retTypes; - pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, nextRetPack, /* resultIsLValue=*/true}); - } + unpackAndAssign(c.variables, nextFn->retTypes, constraint); return true; } @@ -2119,12 +2259,37 @@ bool ConstraintSolver::tryDispatchIterableFunction( modifiedNextRetHead.push_back(*it); TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail()); - auto psc = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, modifiedNextRetPack, /* resultIsLValue */ true}); - inheritBlocks(constraint, psc); + + auto unpackConstraint = unpackAndAssign(c.variables, modifiedNextRetPack, constraint); + + inheritBlocks(constraint, unpackConstraint); return true; } +NotNull ConstraintSolver::unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull constraint) +{ + std::vector unpackedTys; + for (TypeId _ty : destTypes) + { + (void) _ty; + unpackedTys.push_back(arena->addType(BlockedType{})); + } + + TypePackId unpackedTp = arena->addTypePack(TypePack{unpackedTys}); + auto unpackConstraint = 
pushConstraint(constraint->scope, constraint->location, UnpackConstraint{unpackedTp, srcTypes}); + + size_t i = 0; + for (TypeId varTy : destTypes) + { + pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, unpackedTys[i]}); + getMutable(unpackedTys[i])->setOwner(unpackConstraint); + ++i; + } + + return unpackConstraint; +} + std::pair, std::optional> ConstraintSolver::lookupTableProp(NotNull constraint, TypeId subjectType, const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification) { @@ -2759,8 +2924,13 @@ std::optional ConstraintSolver::generalizeFreeType(NotNull scope, if (get(t)) { auto refCount = unresolvedConstraints.find(t); - if (!refCount || *refCount > 1) + if (refCount && *refCount > 0) return {}; + + // if no reference count is present, then that means the only constraints referring to + // this free type need only for it to be generalized. in principle, this means we could + // have actually never generated the free type in the first place, but we couldn't know + // that until all constraint generation is complete. 
} return generalize(NotNull{arena}, builtinTypes, scope, type); @@ -2769,7 +2939,7 @@ std::optional ConstraintSolver::generalizeFreeType(NotNull scope, bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) { if (auto refCount = unresolvedConstraints.find(ty)) - return *refCount > 1; + return *refCount > 0; return false; } diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index ca78d54d..dae7b2d2 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -1368,6 +1368,17 @@ SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull< return SimplifyResult{res, std::move(s.blockedTypes)}; } +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + TypeSimplifier s{builtinTypes, arena}; + + TypeId res = s.intersectFromParts(std::move(parts)); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) { LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index f2d51b31..040c3fc6 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -1438,6 +1438,7 @@ SubtypingResult Subtyping::isCovariantWith( result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->strings)); result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->tables)); result.andAlso(isCovariantWith(env, subNorm->threads, superNorm->threads)); + result.andAlso(isCovariantWith(env, subNorm->buffers, superNorm->buffers)); result.andAlso(isCovariantWith(env, subNorm->tables, superNorm->tables)); result.andAlso(isCovariantWith(env, subNorm->functions, superNorm->functions)); // isCovariantWith(subNorm->tyvars, superNorm->tyvars); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index e3ee2252..4e81a870 100644 --- a/Analysis/src/ToString.cpp +++ 
b/Analysis/src/ToString.cpp @@ -1787,23 +1787,18 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\" ctx=" + std::to_string(int(c.context)); } - else if constexpr (std::is_same_v) - { - const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]"; - return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType); - } else if constexpr (std::is_same_v) { return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType); } - else if constexpr (std::is_same_v) - { - return "setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType); - } + else if constexpr (std::is_same_v) + return "assign " + tos(c.lhsType) + " " + tos(c.rhsType); + else if constexpr (std::is_same_v) + return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType); + else if constexpr (std::is_same_v) + return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType); else if constexpr (std::is_same_v) return tos(c.resultPack) + " ~ ...unpack " + tos(c.sourcePack); - else if constexpr (std::is_same_v) - return tos(c.resultType) + " ~ unpack " + tos(c.sourceType); else if constexpr (std::is_same_v) return "reduce " + tos(c.ty); else if constexpr (std::is_same_v) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 37e0f039..5ffeb951 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1561,6 +1561,18 @@ struct TypeChecker2 else reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); } + else if (auto mt = get(exprType)) + { + const TableType* tt = get(follow(mt->table)); + LUAU_ASSERT(tt); + if (tt->indexer) + testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location); + else + { + // 
TODO: Maybe the metatable has a suitable indexer? + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } + } else if (auto cls = get(exprType)) { if (cls->indexer) @@ -1581,6 +1593,19 @@ struct TypeChecker2 reportError(OptionalValueAccess{exprType}, indexExpr->location); } } + else if (auto exprIntersection = get(exprType)) + { + for (TypeId part : exprIntersection) + { + (void)part; + } + } + else if (get(exprType) || isErrorSuppressing(indexExpr->location, exprType)) + { + // Nothing + } + else + reportError(NotATable{exprType}, indexExpr->location); } void visit(AstExprFunction* fn) @@ -2720,6 +2745,8 @@ struct TypeChecker2 fetch(builtinTypes->stringType); if (normValid) fetch(norm->threads); + if (normValid) + fetch(norm->buffers); if (normValid) { diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index a8d7d2f7..3a0483a6 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -11,12 +11,10 @@ #include "Luau/OverloadResolution.h" #include "Luau/Set.h" #include "Luau/Simplify.h" -#include "Luau/Substitution.h" #include "Luau/Subtyping.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" -#include "Luau/TypeCheckLimits.h" #include "Luau/TypeFamilyReductionGuesser.h" #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" @@ -346,9 +344,8 @@ struct FamilyReducer if (tryGuessing(subject)) return; - TypeFamilyQueue queue{NotNull{&queuedTys}, NotNull{&queuedTps}}; TypeFamilyReductionResult result = - tfit->family->reducer(subject, NotNull{&queue}, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); handleFamilyReduction(subject, result); } } @@ -372,9 +369,8 @@ struct FamilyReducer if (tryGuessing(subject)) return; - TypeFamilyQueue queue{NotNull{&queuedTys}, NotNull{&queuedTps}}; TypeFamilyReductionResult result = - tfit->family->reducer(subject, 
NotNull{&queue}, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); handleFamilyReduction(subject, result); } } @@ -449,24 +445,89 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati std::move(collector.cyclicInstance), location, ctx, force); } -void TypeFamilyQueue::add(TypeId instanceTy) -{ - LUAU_ASSERT(get(instanceTy)); - queuedTys->push_back(instanceTy); -} - -void TypeFamilyQueue::add(TypePackId instanceTp) -{ - LUAU_ASSERT(get(instanceTp)); - queuedTps->push_back(instanceTp); -} - bool isPending(TypeId ty, ConstraintSolver* solver) { return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); } -TypeFamilyReductionResult notFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +template +static std::optional> tryDistributeTypeFamilyApp(F f, TypeId instance, + const std::vector& typeParams, const std::vector& packParams, NotNull ctx, Args&& ...args) +{ + // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) + bool uninhabited = false; + std::vector blockedTypes; + std::vector results; + size_t cartesianProductSize = 1; + + const UnionType* firstUnion = nullptr; + size_t unionIndex; + + std::vector arguments = typeParams; + for (size_t i = 0; i < arguments.size(); ++i) + { + const UnionType* ut = get(follow(arguments[i])); + if (!ut) + continue; + + // We want to find the first union type in the set of arguments to distribute that one and only that one union. + // The function `f` we have is recursive, so `arguments[unionIndex]` will be updated in-place for each option in + // the union we've found in this context, so that index will no longer be a union type. Any other arguments at + // index + 1 or after will instead be distributed, if those are a union, which will be subjected to the same rules. 
+ if (!firstUnion && ut) + { + firstUnion = ut; + unionIndex = i; + } + + cartesianProductSize *= std::distance(begin(ut), end(ut)); + + // TODO: We'd like to report that the type family application is too complex here. + if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= cartesianProductSize) + return {{std::nullopt, true, {}, {}}}; + } + + if (!firstUnion) + { + // If we couldn't find any union type argument, we're not distributing. + return std::nullopt; + } + + for (TypeId option : firstUnion) + { + arguments[unionIndex] = option; + + TypeFamilyReductionResult result = f(instance, arguments, packParams, ctx, args...); + blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); + uninhabited |= result.uninhabited; + + if (result.uninhabited || !result.result) + break; + else + results.push_back(*result.result); + } + + if (uninhabited || !blockedTypes.empty()) + return {{std::nullopt, uninhabited, blockedTypes, {}}}; + + if (!results.empty()) + { + if (results.size() == 1) + return {{results[0], false, {}, {}}}; + + TypeId resultTy = ctx->arena->addType(TypeFamilyInstanceType{ + NotNull{&kBuiltinTypeFamilies.unionFamily}, + std::move(results), + {}, + }); + + return {{resultTy, false, {}, {}}}; + } + + return std::nullopt; +} + +TypeFamilyReductionResult notFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) @@ -477,14 +538,20 @@ TypeFamilyReductionResult notFamilyFn(TypeId instance, NotNullbuiltins->neverType, false, {}, {}}; + if (isPending(ty, ctx->solver)) return {std::nullopt, false, {ty}, {}}; + if (auto result = tryDistributeTypeFamilyApp(notFamilyFn, instance, typeParams, packParams, ctx)) + return *result; + // `not` operates on anything and returns a `boolean` always. 
return {ctx->builtins->booleanType, false, {}, {}}; } -TypeFamilyReductionResult lenFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult lenFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) @@ -495,6 +562,9 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, NotNullbuiltins->neverType, false, {}, {}}; + // check to see if the operand type is resolved enough, and wait to reduce if not // the use of `typeFromNormal` later necessitates blocking on local types. if (isPending(operandTy, ctx->solver) || get(operandTy)) @@ -533,6 +603,9 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, NotNullhasTopTable() || get(normalizedOperand)) return {ctx->builtins->numberType, false, {}, {}}; + if (auto result = tryDistributeTypeFamilyApp(notFamilyFn, instance, typeParams, packParams, ctx)) + return *result; + // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. 
ErrorVec dummy; @@ -570,7 +643,7 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, NotNullbuiltins->numberType, false, {}, {}}; } -TypeFamilyReductionResult unmFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult unmFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) @@ -581,6 +654,9 @@ TypeFamilyReductionResult unmFamilyFn(TypeId instance, NotNullbuiltins->neverType, false, {}, {}}; + // check to see if the operand type is resolved enough, and wait to reduce if not if (isPending(operandTy, ctx->solver)) return {std::nullopt, false, {operandTy}, {}}; @@ -612,6 +688,9 @@ TypeFamilyReductionResult unmFamilyFn(TypeId instance, NotNullisExactlyNumber()) return {ctx->builtins->numberType, false, {}, {}}; + if (auto result = tryDistributeTypeFamilyApp(notFamilyFn, instance, typeParams, packParams, ctx)) + return *result; + // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. 
ErrorVec dummy; @@ -664,7 +743,7 @@ NotNull TypeFamilyContext::pushConstraint(ConstraintV&& c) return newConstraint; } -TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx, const std::string metamethod) { if (typeParams.size() != 2 || !packParams.empty()) @@ -723,67 +802,8 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, NotNull< if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) return {ctx->builtins->numberType, false, {}, {}}; - // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) - std::vector results; - bool uninhabited = false; - std::vector blockedTypes; - std::vector arguments = typeParams; - auto distributeFamilyApp = [&](const UnionType* ut, size_t argumentIndex) { - // Returning true here means we completed the loop without any problems. - for (TypeId option : ut) - { - arguments[argumentIndex] = option; - - TypeFamilyReductionResult result = numericBinopFamilyFn(instance, queue, arguments, packParams, ctx, metamethod); - blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); - uninhabited |= result.uninhabited; - - if (result.uninhabited) - return false; - else if (!result.result) - return false; - else - results.push_back(*result.result); - } - - return true; - }; - - const UnionType* lhsUnion = get(lhsTy); - const UnionType* rhsUnion = get(rhsTy); - if (lhsUnion || rhsUnion) - { - // TODO: We'd like to report that the type family application is too complex here. - size_t lhsUnionSize = lhsUnion ? std::distance(begin(lhsUnion), end(lhsUnion)) : 1; - size_t rhsUnionSize = rhsUnion ? 
std::distance(begin(rhsUnion), end(rhsUnion)) : 1; - if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= lhsUnionSize * rhsUnionSize) - return {std::nullopt, true, {}, {}}; - - if (lhsUnion && !distributeFamilyApp(lhsUnion, 0)) - return {std::nullopt, uninhabited, std::move(blockedTypes), {}}; - - if (rhsUnion && !distributeFamilyApp(rhsUnion, 1)) - return {std::nullopt, uninhabited, std::move(blockedTypes), {}}; - - if (results.empty()) - { - // If this happens, it means `distributeFamilyApp` has improperly returned `true` even - // though there exists no arm of the union that is inhabited or have a reduced type. - ctx->ice->ice("`distributeFamilyApp` failed to add any types to the results vector?"); - } - - if (results.size() == 1) - return {results[0], false, {}, {}}; - - TypeId resultTy = ctx->arena->addType(TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, - std::move(results), - {}, - }); - - queue->add(resultTy); - return {resultTy, false, {}, {}}; - } + if (auto result = tryDistributeTypeFamilyApp(numericBinopFamilyFn, instance, typeParams, packParams, ctx, metamethod)) + return *result; // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. 
@@ -826,7 +846,7 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, NotNull< return {extracted.head.front(), false, {}, {}}; } -TypeFamilyReductionResult addFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult addFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -835,10 +855,10 @@ TypeFamilyReductionResult addFamilyFn(TypeId instance, NotNull subFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult subFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -847,10 +867,10 @@ TypeFamilyReductionResult subFamilyFn(TypeId instance, NotNull mulFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult mulFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -859,10 +879,10 @@ TypeFamilyReductionResult mulFamilyFn(TypeId instance, NotNull divFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult divFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -871,10 +891,10 @@ TypeFamilyReductionResult divFamilyFn(TypeId instance, NotNull idivFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult idivFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -883,10 +903,10 @@ TypeFamilyReductionResult idivFamilyFn(TypeId instance, NotNull powFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult 
powFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -895,10 +915,10 @@ TypeFamilyReductionResult powFamilyFn(TypeId instance, NotNull modFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult modFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -907,10 +927,10 @@ TypeFamilyReductionResult modFamilyFn(TypeId instance, NotNull concatFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult concatFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -922,6 +942,10 @@ TypeFamilyReductionResult concatFamilyFn(TypeId instance, NotNullbuiltins->neverType, false, {}, {}}; + // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) return {std::nullopt, false, {lhsTy}, {}}; @@ -962,6 +986,9 @@ TypeFamilyReductionResult concatFamilyFn(TypeId instance, NotNullisSubtypeOfString() || normLhsTy->isExactlyNumber()) && (normRhsTy->isSubtypeOfString() || normRhsTy->isExactlyNumber())) return {ctx->builtins->stringType, false, {}, {}}; + if (auto result = tryDistributeTypeFamilyApp(concatFamilyFn, instance, typeParams, packParams, ctx)) + return *result; + // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. 
ErrorVec dummy; @@ -1011,7 +1038,7 @@ TypeFamilyReductionResult concatFamilyFn(TypeId instance, NotNullbuiltins->stringType, false, {}, {}}; } -TypeFamilyReductionResult andFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult andFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -1062,7 +1089,7 @@ TypeFamilyReductionResult andFamilyFn(TypeId instance, NotNull orFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult orFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -1113,7 +1140,7 @@ TypeFamilyReductionResult orFamilyFn(TypeId instance, NotNull comparisonFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx, const std::string metamethod) { @@ -1126,6 +1153,9 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, Not TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + if (lhsTy == instance || rhsTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + if (isPending(lhsTy, ctx->solver)) return {std::nullopt, false, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) @@ -1207,6 +1237,9 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, Not if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) return {ctx->builtins->booleanType, false, {}, {}}; + if (auto result = tryDistributeTypeFamilyApp(comparisonFamilyFn, instance, typeParams, packParams, ctx, metamethod)) + return *result; + // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just 
eat the errors. ErrorVec dummy; @@ -1246,7 +1279,7 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, Not return {ctx->builtins->booleanType, false, {}, {}}; } -TypeFamilyReductionResult ltFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult ltFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -1255,10 +1288,10 @@ TypeFamilyReductionResult ltFamilyFn(TypeId instance, NotNull leFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult leFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -1267,10 +1300,10 @@ TypeFamilyReductionResult leFamilyFn(TypeId instance, NotNull eqFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult eqFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -1407,7 +1440,7 @@ struct FindRefinementBlockers : TypeOnceVisitor }; -TypeFamilyReductionResult refineFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult refineFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) @@ -1480,7 +1513,7 @@ TypeFamilyReductionResult refineFamilyFn(TypeId instance, NotNull singletonFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult singletonFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) @@ -1517,7 +1550,7 @@ TypeFamilyReductionResult singletonFamilyFn(TypeId instance, NotNullbuiltins->unknownType, false, {}, 
{}}; } -TypeFamilyReductionResult unionFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult unionFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (!packParams.empty()) @@ -1578,7 +1611,7 @@ TypeFamilyReductionResult unionFamilyFn(TypeId instance, NotNull intersectFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult intersectFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (!packParams.empty()) @@ -1802,7 +1835,7 @@ TypeFamilyReductionResult keyofFamilyImpl( return {ctx->arena->addType(UnionType{singletons}), false, {}, {}}; } -TypeFamilyReductionResult keyofFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult keyofFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) @@ -1814,7 +1847,7 @@ TypeFamilyReductionResult keyofFamilyFn(TypeId instance, NotNull rawkeyofFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, +TypeFamilyReductionResult rawkeyofFamilyFn(TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 588b1da1..c2512ddc 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -38,6 +38,59 @@ bool occursCheck(TypeId needle, TypeId haystack) return false; } +// FIXME: Property is quite large. +// +// Returning it on the stack like this isn't great. We'd like to just return a +// const Property*, but we mint a property of type any if the subject type is +// any. 
+std::optional findTableProperty(NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) +{ + if (get(ty)) + return Property::rw(ty); + + if (const TableType* tableType = getTableType(ty)) + { + const auto& it = tableType->props.find(name); + if (it != tableType->props.end()) + return it->second; + } + + std::optional mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location); + int count = 0; + while (mtIndex) + { + TypeId index = follow(*mtIndex); + + if (count >= 100) + return std::nullopt; + + ++count; + + if (const auto& itt = getTableType(index)) + { + const auto& fit = itt->props.find(name); + if (fit != itt->props.end()) + return fit->second.type(); + } + else if (const auto& itf = get(index)) + { + std::optional r = first(follow(itf->retTypes)); + if (!r) + return builtinTypes->nilType; + else + return *r; + } + else if (get(index)) + return builtinTypes->anyType; + else + errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. 
Got " + toString(index)}}); + + mtIndex = findMetatableEntry(builtinTypes, errors, *mtIndex, "__index", location); + } + + return std::nullopt; +} + std::optional findMetatableEntry( NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) { diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 501707e1..814d7c8c 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -256,12 +256,16 @@ void setupState(lua_State* L) void setupArguments(lua_State* L, int argc, char** argv) { + lua_checkstack(L, argc); + for (int i = 0; i < argc; ++i) lua_pushstring(L, argv[i]); } std::string runCode(lua_State* L, const std::string& source) { + lua_checkstack(L, LUA_MINSTACK); + std::string bytecode = Luau::compile(source, copts()); if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) @@ -432,6 +436,8 @@ static void completeIndexer(lua_State* L, const std::string& editBuffer, const A std::string_view lookup = editBuffer; bool completeOnlyFunctions = false; + lua_checkstack(L, LUA_MINSTACK); + // Push the global variable table to begin the search lua_pushvalue(L, LUA_GLOBALSINDEX); diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index ac444b7b..171e9197 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -92,7 +92,7 @@ struct HostIrHooks // Guards should take a VM exit to 'pcpos' HostVectorAccessHandler vectorAccess = nullptr; - // Handle namecalled performed on a vector value + // Handle namecall performed on a vector value // 'sourceReg' (self argument) is guaranteed to be a vector // All other arguments can be of any type // Guards should take a VM exit to 'pcpos' @@ -103,6 +103,9 @@ struct CompilationOptions { unsigned int flags = 0; HostIrHooks hooks; + + // null-terminated array of userdata types names that might have custom lowering + const char* const* userdataTypes = nullptr; }; struct CompilationStats @@ -163,6 +166,12 @@ void create(lua_State* L, SharedCodeGenContext* 
codeGenContext); // Enable or disable native execution according to `enabled` argument void setNativeExecutionEnabled(lua_State* L, bool enabled); +// Given a name, this function must return the index of the type which matches the type array used all CompilationOptions and AssemblyOptions +// If the type is unknown, 0xff has to be returned +using UserdataRemapperCallback = uint8_t(void* context, const char* name, size_t nameLength); + +void setUserdataRemapper(lua_State* L, void* context, UserdataRemapperCallback cb); + using ModuleId = std::array; // Builds target function and all inner functions diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index dcca3c7b..d989a6c7 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -31,9 +31,11 @@ void toString(IrToStringContext& ctx, IrOp op); void toString(std::string& result, IrConst constant); -const char* getBytecodeTypeName(uint8_t type); +const char* getBytecodeTypeName_DEPRECATED(uint8_t type); +const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes); -void toString(std::string& result, const BytecodeTypes& bcTypes); +void toString_DEPRECATED(std::string& result, const BytecodeTypes& bcTypes); +void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes); void toStringDetailed( IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, IncludeUseInfo includeUseInfo); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 0c8495e8..55b86822 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -241,6 +241,10 @@ IrValueKind getCmdValueKind(IrCmd cmd); bool isGCO(uint8_t tag); +// Optional bit has to be cleared at call site, otherwise, this will return 'false' for 'userdata?' 
+bool isUserdataBytecodeType(uint8_t ty); +bool isCustomUserdataBytecodeType(uint8_t ty); + // Manually add or remove use of an operand void addUse(IrFunction& function, IrOp op); void removeUse(IrFunction& function, IrOp op); diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index 900093d1..aed8c763 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -14,8 +14,6 @@ LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow) LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately -LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) -LUAU_FASTFLAGVARIABLE(LuauCodegenVectorMispredictFix, false) LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false) LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false) @@ -68,21 +66,13 @@ void loadBytecodeTypeInfo(IrFunction& function) Proto* proto = function.proto; - if (FFlag::LuauTypeInfoLookupImprovement) - { - if (!proto) - return; - } - else - { - if (!proto || !proto->typeinfo) - return; - } + if (!proto) + return; BytecodeTypeInfo& typeInfo = function.bcTypeInfo; // If there is no typeinfo, we generate default values for arguments and upvalues - if (FFlag::LuauTypeInfoLookupImprovement && !proto->typeinfo) + if (!proto->typeinfo) { typeInfo.argumentTypes.resize(proto->numparams, LBC_TYPE_ANY); typeInfo.upvalueTypes.resize(proto->nups, LBC_TYPE_ANY); @@ -150,8 +140,6 @@ void loadBytecodeTypeInfo(IrFunction& function) static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo) { - CODEGEN_ASSERT(FFlag::LuauTypeInfoLookupImprovement); - // Sort by register first, then by end PC std::sort(typeInfo.regTypes.begin(), typeInfo.regTypes.end(), [](const BytecodeRegTypeInfo& a, const BytecodeRegTypeInfo& b) { if (a.reg != b.reg) @@ -186,39 +174,26 @@ static BytecodeRegTypeInfo* findRegType(BytecodeTypeInfo& info, uint8_t reg, 
int { CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); - if (FFlag::LuauTypeInfoLookupImprovement) - { - auto b = info.regTypes.begin() + info.regTypeOffsets[reg]; - auto e = info.regTypes.begin() + info.regTypeOffsets[reg + 1]; - - // Doen't have info - if (b == e) - return nullptr; - - // No info after the last live range - if (pc >= (e - 1)->endpc) - return nullptr; - - for (auto it = b; it != e; ++it) - { - CODEGEN_ASSERT(it->reg == reg); - - if (pc >= it->startpc && pc < it->endpc) - return &*it; - } + auto b = info.regTypes.begin() + info.regTypeOffsets[reg]; + auto e = info.regTypes.begin() + info.regTypeOffsets[reg + 1]; + // Doen't have info + if (b == e) return nullptr; - } - else - { - for (BytecodeRegTypeInfo& el : info.regTypes) - { - if (reg == el.reg && pc >= el.startpc && pc < el.endpc) - return ⪙ - } + // No info after the last live range + if (pc >= (e - 1)->endpc) return nullptr; + + for (auto it = b; it != e; ++it) + { + CODEGEN_ASSERT(it->reg == reg); + + if (pc >= it->startpc && pc < it->endpc) + return &*it; } + + return nullptr; } static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t ty) @@ -233,7 +208,7 @@ static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t t if (regType->type == LBC_TYPE_ANY) regType->type = ty; } - else if (FFlag::LuauTypeInfoLookupImprovement && reg < info.argumentTypes.size()) + else if (reg < info.argumentTypes.size()) { if (info.argumentTypes[reg] == LBC_TYPE_ANY) info.argumentTypes[reg] = ty; @@ -627,8 +602,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) BytecodeTypeInfo& bcTypeInfo = function.bcTypeInfo; - if (FFlag::LuauTypeInfoLookupImprovement) - prepareRegTypeInfoLookups(bcTypeInfo); + prepareRegTypeInfoLookups(bcTypeInfo); // Setup our current knowledge of type tags based on arguments uint8_t regTags[256]; @@ -786,32 +760,22 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = 
LBC_TYPE_ANY; - if (FFlag::LuauCodegenVectorMispredictFix) + if (bcType.a == LBC_TYPE_VECTOR) { - if (bcType.a == LBC_TYPE_VECTOR) + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + if (str->len == 1) { - TString* str = gco2ts(function.proto->k[kc].value.gc); - const char* field = getstr(str); + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; - if (str->len == 1) - { - // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" - char ch = field[0] | ' '; - - if (ch == 'x' || ch == 'y' || ch == 'z') - regTags[ra] = LBC_TYPE_NUMBER; - } - - if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) - regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); + if (ch == 'x' || ch == 'y' || ch == 'z') + regTags[ra] = LBC_TYPE_NUMBER; } - } - else - { - // Assuming that vector component is being indexed - // TODO: check what key is used - if (bcType.a == LBC_TYPE_VECTOR) - regTags[ra] = LBC_TYPE_NUMBER; + + if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); } bcType.result = regTags[ra]; diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index ce3a57bd..269bf8dc 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -13,7 +13,7 @@ #include "lapi.h" LUAU_FASTFLAG(LuauCodegenTypeInfo) -LUAU_FASTFLAGVARIABLE(LuauCodegenIrTypeNames, false) +LUAU_FASTFLAG(LuauLoadUserdataInfo) namespace Luau { @@ -22,8 +22,6 @@ namespace CodeGen static const LocVar* tryFindLocal(const Proto* proto, int reg, int pcpos) { - CODEGEN_ASSERT(FFlag::LuauCodegenIrTypeNames); - for (int i = 0; i < proto->sizelocvars; i++) { const LocVar& local = proto->locvars[i]; @@ 
-37,8 +35,6 @@ static const LocVar* tryFindLocal(const Proto* proto, int reg, int pcpos) const char* tryFindLocalName(const Proto* proto, int reg, int pcpos) { - CODEGEN_ASSERT(FFlag::LuauCodegenIrTypeNames); - const LocVar* var = tryFindLocal(proto, reg, pcpos); if (var && var->varname) @@ -49,8 +45,6 @@ const char* tryFindLocalName(const Proto* proto, int reg, int pcpos) const char* tryFindUpvalueName(const Proto* proto, int upval) { - CODEGEN_ASSERT(FFlag::LuauCodegenIrTypeNames); - if (proto->upvalues) { CODEGEN_ASSERT(upval < proto->sizeupvalues); @@ -72,22 +66,10 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) for (int i = 0; i < proto->numparams; i++) { - if (FFlag::LuauCodegenIrTypeNames) - { - if (const char* name = tryFindLocalName(proto, i, 0)) - build.logAppend("%s%s", i == 0 ? "" : ", ", name); - else - build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i); - } + if (const char* name = tryFindLocalName(proto, i, 0)) + build.logAppend("%s%s", i == 0 ? "" : ", ", name); else - { - LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr; - - if (var && var->varname) - build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname)); - else - build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i); - } + build.logAppend("%s$arg%d", i == 0 ? 
"" : ", ", i); } if (proto->numparams != 0 && proto->is_vararg) @@ -102,9 +84,10 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) } template -static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function) +static void logFunctionTypes_DEPRECATED(AssemblyBuilder& build, const IrFunction& function) { CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); + CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo); const BytecodeTypeInfo& typeInfo = function.bcTypeInfo; @@ -112,20 +95,12 @@ static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function) { uint8_t ty = typeInfo.argumentTypes[i]; - if (FFlag::LuauCodegenIrTypeNames) + if (ty != LBC_TYPE_ANY) { - if (ty != LBC_TYPE_ANY) - { - if (const char* name = tryFindLocalName(function.proto, int(i), 0)) - build.logAppend("; R%d: %s [argument '%s']\n", int(i), getBytecodeTypeName(ty), name); - else - build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName(ty)); - } - } - else - { - if (ty != LBC_TYPE_ANY) - build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName(ty)); + if (const char* name = tryFindLocalName(function.proto, int(i), 0)) + build.logAppend("; R%d: %s [argument '%s']\n", int(i), getBytecodeTypeName_DEPRECATED(ty), name); + else + build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName_DEPRECATED(ty)); } } @@ -133,38 +108,76 @@ static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function) { uint8_t ty = typeInfo.upvalueTypes[i]; - if (FFlag::LuauCodegenIrTypeNames) + if (ty != LBC_TYPE_ANY) { - if (ty != LBC_TYPE_ANY) - { - if (const char* name = tryFindUpvalueName(function.proto, int(i))) - build.logAppend("; U%d: %s ['%s']\n", int(i), getBytecodeTypeName(ty), name); - else - build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName(ty)); - } - } - else - { - if (ty != LBC_TYPE_ANY) - build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName(ty)); + if (const char* name = tryFindUpvalueName(function.proto, 
int(i))) + build.logAppend("; U%d: %s ['%s']\n", int(i), getBytecodeTypeName_DEPRECATED(ty), name); + else + build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName_DEPRECATED(ty)); } } for (const BytecodeRegTypeInfo& el : typeInfo.regTypes) { - if (FFlag::LuauCodegenIrTypeNames) - { - // Using last active position as the PC because 'startpc' for type info is before local is initialized - if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1)) - build.logAppend("; R%d: %s from %d to %d [local '%s']\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc, name); - else - build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc); - } + // Using last active position as the PC because 'startpc' for type info is before local is initialized + if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1)) + build.logAppend("; R%d: %s from %d to %d [local '%s']\n", el.reg, getBytecodeTypeName_DEPRECATED(el.type), el.startpc, el.endpc, name); else + build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName_DEPRECATED(el.type), el.startpc, el.endpc); + } +} + +template +static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function, const char* const* userdataTypes) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); + CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo); + + const BytecodeTypeInfo& typeInfo = function.bcTypeInfo; + + for (size_t i = 0; i < typeInfo.argumentTypes.size(); i++) + { + uint8_t ty = typeInfo.argumentTypes[i]; + + const char* type = getBytecodeTypeName(ty, userdataTypes); + const char* optional = (ty & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" 
: ""; + + if (ty != LBC_TYPE_ANY) { - build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc); + if (const char* name = tryFindLocalName(function.proto, int(i), 0)) + build.logAppend("; R%d: %s%s [argument '%s']\n", int(i), type, optional, name); + else + build.logAppend("; R%d: %s%s [argument]\n", int(i), type, optional); } } + + for (size_t i = 0; i < typeInfo.upvalueTypes.size(); i++) + { + uint8_t ty = typeInfo.upvalueTypes[i]; + + const char* type = getBytecodeTypeName(ty, userdataTypes); + const char* optional = (ty & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : ""; + + if (ty != LBC_TYPE_ANY) + { + if (const char* name = tryFindUpvalueName(function.proto, int(i))) + build.logAppend("; U%d: %s%s ['%s']\n", int(i), type, optional, name); + else + build.logAppend("; U%d: %s%s\n", int(i), type, optional); + } + } + + for (const BytecodeRegTypeInfo& el : typeInfo.regTypes) + { + const char* type = getBytecodeTypeName(el.type, userdataTypes); + const char* optional = (el.type & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" 
: ""; + + // Using last active position as the PC because 'startpc' for type info is before local is initialized + if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1)) + build.logAppend("; R%d: %s%s from %d to %d [local '%s']\n", el.reg, type, optional, el.startpc, el.endpc, name); + else + build.logAppend("; R%d: %s%s from %d to %d\n", el.reg, type, optional, el.startpc, el.endpc); + } } unsigned getInstructionCount(const Instruction* insns, const unsigned size) @@ -224,7 +237,12 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A logFunctionHeader(build, p); if (FFlag::LuauCodegenTypeInfo && options.includeIrTypes) - logFunctionTypes(build, ir.function); + { + if (FFlag::LuauLoadUserdataInfo) + logFunctionTypes(build, ir.function, options.compilationOptions.userdataTypes); + else + logFunctionTypes_DEPRECATED(build, ir.function); + } CodeGenCompilationResult result = CodeGenCompilationResult::Success; diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index a94388f6..7788d099 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -612,5 +612,29 @@ void setNativeExecutionEnabled(lua_State* L, bool enabled) L->global->ecb.enter = enabled ? onEnter : onEnterDisabled; } +static uint8_t userdataRemapperWrap(lua_State* L, const char* str, size_t len) +{ + if (BaseCodeGenContext* codegenCtx = getCodeGenContext(L)) + { + uint8_t index = codegenCtx->userdataRemapper(codegenCtx->userdataRemappingContext, str, len); + + if (index < (LBC_TYPE_TAGGED_USERDATA_END - LBC_TYPE_TAGGED_USERDATA_BASE)) + return LBC_TYPE_TAGGED_USERDATA_BASE + index; + } + + return LBC_TYPE_USERDATA; +} + +void setUserdataRemapper(lua_State* L, void* context, UserdataRemapperCallback cb) +{ + if (BaseCodeGenContext* codegenCtx = getCodeGenContext(L)) + { + codegenCtx->userdataRemappingContext = context; + codegenCtx->userdataRemapper = cb; + + L->global->ecb.gettypemapping = cb ? 
userdataRemapperWrap : nullptr; + } +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenContext.h b/CodeGen/src/CodeGenContext.h index 516a7064..43099a9b 100644 --- a/CodeGen/src/CodeGenContext.h +++ b/CodeGen/src/CodeGenContext.h @@ -50,6 +50,9 @@ public: uint8_t* gateData = nullptr; size_t gateDataSize = 0; + void* userdataRemappingContext = nullptr; + UserdataRemapperCallback* userdataRemapper = nullptr; + NativeContext context; }; diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index efd1034d..6015ef10 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -28,6 +28,7 @@ LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) +LUAU_FASTFLAG(LuauLoadUserdataInfo) namespace Luau { @@ -149,7 +150,11 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& if (bcTypes.result != LBC_TYPE_ANY || bcTypes.a != LBC_TYPE_ANY || bcTypes.b != LBC_TYPE_ANY || bcTypes.c != LBC_TYPE_ANY) { - toString(ctx.result, bcTypes); + if (FFlag::LuauLoadUserdataInfo) + toString(ctx.result, bcTypes, options.compilationOptions.userdataTypes); + else + toString_DEPRECATED(ctx.result, bcTypes); + build.logAppend("\n"); } } diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 76d015e9..723d35c4 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -14,8 +14,8 @@ #include LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used -LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) +LUAU_FASTFLAG(LuauLoadUserdataInfo) namespace Luau { @@ -119,20 +119,13 @@ static bool hasTypedParameters(const BytecodeTypeInfo& typeInfo) { CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - if (FFlag::LuauTypeInfoLookupImprovement) + for (auto el : 
typeInfo.argumentTypes) { - for (auto el : typeInfo.argumentTypes) - { - if (el != LBC_TYPE_ANY) - return true; - } + if (el != LBC_TYPE_ANY) + return true; + } - return false; - } - else - { - return !typeInfo.argumentTypes.empty(); - } + return false; } static void buildArgumentTypeChecks(IrBuilder& build) @@ -197,6 +190,19 @@ static void buildArgumentTypeChecks(IrBuilder& build) case LBC_TYPE_BUFFER: build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc)); break; + default: + if (FFlag::LuauLoadUserdataInfo) + { + if (tag >= LBC_TYPE_TAGGED_USERDATA_BASE && tag < LBC_TYPE_TAGGED_USERDATA_END) + { + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc)); + } + else + { + CODEGEN_ASSERT(!"unknown argument type tag"); + } + } + break; } if (optional) diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 48a50ecb..c47a0b8f 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -7,6 +7,8 @@ #include +LUAU_FASTFLAG(LuauLoadUserdataInfo) + namespace Luau { namespace CodeGen @@ -480,8 +482,10 @@ void toString(std::string& result, IrConst constant) } } -const char* getBytecodeTypeName(uint8_t type) +const char* getBytecodeTypeName_DEPRECATED(uint8_t type) { + CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo); + switch (type & ~LBC_TYPE_OPTIONAL_BIT) { case LBC_TYPE_NIL: @@ -512,13 +516,78 @@ const char* getBytecodeTypeName(uint8_t type) return nullptr; } -void toString(std::string& result, const BytecodeTypes& bcTypes) +const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes) { + CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo); + + // Optional bit should be handled externally + type = type & ~LBC_TYPE_OPTIONAL_BIT; + + if (type >= LBC_TYPE_TAGGED_USERDATA_BASE && type < LBC_TYPE_TAGGED_USERDATA_END) + { + if (userdataTypes) + return userdataTypes[type - LBC_TYPE_TAGGED_USERDATA_BASE]; + + return "userdata"; + } + + switch (type) + { + case 
LBC_TYPE_NIL: + return "nil"; + case LBC_TYPE_BOOLEAN: + return "boolean"; + case LBC_TYPE_NUMBER: + return "number"; + case LBC_TYPE_STRING: + return "string"; + case LBC_TYPE_TABLE: + return "table"; + case LBC_TYPE_FUNCTION: + return "function"; + case LBC_TYPE_THREAD: + return "thread"; + case LBC_TYPE_USERDATA: + return "userdata"; + case LBC_TYPE_VECTOR: + return "vector"; + case LBC_TYPE_BUFFER: + return "buffer"; + case LBC_TYPE_ANY: + return "any"; + } + + CODEGEN_ASSERT(!"Unhandled type in getBytecodeTypeName"); + return nullptr; +} + +void toString_DEPRECATED(std::string& result, const BytecodeTypes& bcTypes) +{ + CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo); + if (bcTypes.c != LBC_TYPE_ANY) - append(result, "%s <- %s, %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b), - getBytecodeTypeName(bcTypes.c)); + append(result, "%s <- %s, %s, %s", getBytecodeTypeName_DEPRECATED(bcTypes.result), getBytecodeTypeName_DEPRECATED(bcTypes.a), + getBytecodeTypeName_DEPRECATED(bcTypes.b), getBytecodeTypeName_DEPRECATED(bcTypes.c)); else - append(result, "%s <- %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b)); + append(result, "%s <- %s, %s", getBytecodeTypeName_DEPRECATED(bcTypes.result), getBytecodeTypeName_DEPRECATED(bcTypes.a), + getBytecodeTypeName_DEPRECATED(bcTypes.b)); +} + +void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes) +{ + CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo); + + append(result, "%s%s", getBytecodeTypeName(bcTypes.result, userdataTypes), (bcTypes.result & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : ""); + append(result, " <- "); + append(result, "%s%s", getBytecodeTypeName(bcTypes.a, userdataTypes), (bcTypes.a & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : ""); + append(result, ", "); + append(result, "%s%s", getBytecodeTypeName(bcTypes.b, userdataTypes), (bcTypes.b & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" 
: ""); + + if (bcTypes.c != LBC_TYPE_ANY) + { + append(result, ", "); + append(result, "%s%s", getBytecodeTypeName(bcTypes.c, userdataTypes), (bcTypes.c & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : ""); + } } static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 291f618b..93073a92 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -14,7 +14,6 @@ #include "ltm.h" LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false) -LUAU_FASTFLAGVARIABLE(LuauCodegenFixVectorFields, false) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) namespace Luau @@ -1200,19 +1199,19 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) TString* str = gco2ts(build.function.proto->k[aux].value.gc); const char* field = getstr(str); - if ((!FFlag::LuauCodegenFixVectorFields || str->len == 1) && (*field == 'X' || *field == 'x')) + if (str->len == 1 && (*field == 'X' || *field == 'x')) { IrOp value = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(rb), build.constInt(0)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); } - else if ((!FFlag::LuauCodegenFixVectorFields || str->len == 1) && (*field == 'Y' || *field == 'y')) + else if (str->len == 1 && (*field == 'Y' || *field == 'y')) { IrOp value = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(rb), build.constInt(4)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); } - else if ((!FFlag::LuauCodegenFixVectorFields || str->len == 1) && (*field == 'Z' || *field == 'z')) + else if (str->len == 1 && (*field == 'Z' || *field == 'z')) { IrOp value = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(rb), build.constInt(8)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 
caa6b178..afc6ba5a 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -252,6 +252,16 @@ bool isGCO(uint8_t tag) return tag >= LUA_TSTRING; } +bool isUserdataBytecodeType(uint8_t ty) +{ + return ty == LBC_TYPE_USERDATA || isCustomUserdataBytecodeType(ty); +} + +bool isCustomUserdataBytecodeType(uint8_t ty) +{ + return ty >= LBC_TYPE_TAGGED_USERDATA_BASE && ty < LBC_TYPE_TAGGED_USERDATA_END; +} + void kill(IrFunction& function, IrInst& inst) { CODEGEN_ASSERT(inst.useCount == 0); diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index eae0baa3..9135a9ed 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -18,7 +18,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) -LUAU_FASTFLAGVARIABLE(LuauCodegenLoadPropCheckRegLinkInTv, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenFixSplitStoreConstMismatch, false) namespace Luau { @@ -739,7 +739,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& // If we know the tag, we can try extracting the value from a register used by LOAD_TVALUE // To do that, we have to ensure that the register link of the source value is still valid - if (tag != 0xff && (!FFlag::LuauCodegenLoadPropCheckRegLinkInTv || state.tryGetRegLink(inst.b) != nullptr)) + if (tag != 0xff && state.tryGetRegLink(inst.b) != nullptr) { if (IrInst* arg = function.asInstOp(inst.b); arg && arg->cmd == IrCmd::LOAD_TVALUE && arg->a.kind == IrOpKind::VmReg) { @@ -750,18 +750,48 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } - // If we have constant tag and value, replace TValue store with tag/value pair store - if (tag != 0xff && value.kind != IrOpKind::None && (tag == LUA_TBOOLEAN || tag == LUA_TNUMBER || isGCO(tag))) + if 
(FFlag::LuauCodegenFixSplitStoreConstMismatch) { - replace(function, block, index, {IrCmd::STORE_SPLIT_TVALUE, inst.a, build.constTag(tag), value, inst.c}); + // If we have constant tag and value, replace TValue store with tag/value pair store + bool canSplitTvalueStore = false; - // Value can be propagated to future loads of the same register - if (inst.a.kind == IrOpKind::VmReg && activeLoadValue != kInvalidInstIdx) - state.valueMap[state.versionedVmRegLoad(activeLoadCmd, inst.a)] = activeLoadValue; + if (tag == LUA_TBOOLEAN && + (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int))) + canSplitTvalueStore = true; + else if (tag == LUA_TNUMBER && + (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) + canSplitTvalueStore = true; + else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst) + canSplitTvalueStore = true; + + if (canSplitTvalueStore) + { + replace(function, block, index, {IrCmd::STORE_SPLIT_TVALUE, inst.a, build.constTag(tag), value, inst.c}); + + // Value can be propagated to future loads of the same register + if (inst.a.kind == IrOpKind::VmReg && activeLoadValue != kInvalidInstIdx) + state.valueMap[state.versionedVmRegLoad(activeLoadCmd, inst.a)] = activeLoadValue; + } + else if (inst.a.kind == IrOpKind::VmReg) + { + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); + } } - else if (inst.a.kind == IrOpKind::VmReg) + else { - state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); + // If we have constant tag and value, replace TValue store with tag/value pair store + if (tag != 0xff && value.kind != IrOpKind::None && (tag == LUA_TBOOLEAN || tag == LUA_TNUMBER || isGCO(tag))) + { + replace(function, block, index, {IrCmd::STORE_SPLIT_TVALUE, inst.a, build.constTag(tag), value, inst.c}); + + // Value can be propagated to future loads of the same register + if (inst.a.kind == IrOpKind::VmReg && 
activeLoadValue != kInvalidInstIdx) + state.valueMap[state.versionedVmRegLoad(activeLoadCmd, inst.a)] = activeLoadValue; + } + else if (inst.a.kind == IrOpKind::VmReg) + { + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); + } } } break; diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 7012d820..2ae54c67 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -438,7 +438,9 @@ enum LuauBytecodeTag LBC_VERSION_TARGET = 5, // Type encoding version LBC_TYPE_VERSION_DEPRECATED = 1, - LBC_TYPE_VERSION = 2, + LBC_TYPE_VERSION_MIN = 1, + LBC_TYPE_VERSION_MAX = 3, + LBC_TYPE_VERSION_TARGET = 3, // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, @@ -465,6 +467,10 @@ enum LuauBytecodeType LBC_TYPE_BUFFER, LBC_TYPE_ANY = 15, + + LBC_TYPE_TAGGED_USERDATA_BASE = 64, + LBC_TYPE_TAGGED_USERDATA_END = 64 + 32, + LBC_TYPE_OPTIONAL_BIT = 1 << 7, LBC_TYPE_INVALID = 256, diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 7f0115bb..59d30d62 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -79,6 +79,9 @@ public: void pushLocalTypeInfo(LuauBytecodeType type, uint8_t reg, uint32_t startpc, uint32_t endpc); void pushUpvalTypeInfo(LuauBytecodeType type); + uint32_t addUserdataType(const char* name); + void useUserdataType(uint32_t index); + void setDebugFunctionName(StringRef name); void setDebugFunctionLineDefined(int line); void setDebugLine(int line); @@ -229,6 +232,13 @@ private: LuauBytecodeType type; }; + struct UserdataType + { + std::string name; + uint32_t nameRef = 0; + bool used = false; + }; + struct Jump { uint32_t source; @@ -277,6 +287,8 @@ private: std::vector typedLocals; std::vector typedUpvals; + std::vector userdataTypes; + DenseHashMap stringTable; std::vector debugStrings; @@ -308,6 +320,8 @@ private: int32_t addConstant(const ConstantKey& key, const Constant& value); 
unsigned int addStringTableEntry(StringRef value); + + const char* tryGetUserdataTypeName(LuauBytecodeType type) const; }; } // namespace Luau diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 698a50c4..119e0aa2 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -46,6 +46,9 @@ struct CompileOptions // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these const char* const* mutableGlobals = nullptr; + + // null-terminated array of userdata types that will be included in the type information + const char* const* userdataTypes = nullptr; }; class CompileError : public std::exception diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h index a470319d..1d200817 100644 --- a/Compiler/include/luacode.h +++ b/Compiler/include/luacode.h @@ -42,6 +42,9 @@ struct lua_CompileOptions // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these const char* const* mutableGlobals; + + // null-terminated array of userdata types that will be included in the type information + const char* const* userdataTypes = nullptr; }; // compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. 
use free() to destroy diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 6c76b671..59aee1e7 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -7,9 +7,8 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileNoJumpLineRetarget, false) -LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) LUAU_FASTFLAGVARIABLE(LuauCompileTypeInfo, false) +LUAU_FASTFLAG(LuauCompileUserdataInfo) namespace Luau { @@ -335,6 +334,18 @@ unsigned int BytecodeBuilder::addStringTableEntry(StringRef value) return index; } +const char* BytecodeBuilder::tryGetUserdataTypeName(LuauBytecodeType type) const +{ + LUAU_ASSERT(FFlag::LuauCompileUserdataInfo); + + unsigned index = unsigned((type & ~LBC_TYPE_OPTIONAL_BIT) - LBC_TYPE_TAGGED_USERDATA_BASE); + + if (index < userdataTypes.size()) + return userdataTypes[index].name.c_str(); + + return nullptr; +} + int32_t BytecodeBuilder::addConstantNil() { Constant c = {Constant::Type_Nil}; @@ -567,6 +578,25 @@ void BytecodeBuilder::pushUpvalTypeInfo(LuauBytecodeType type) typedUpvals.push_back(upval); } +uint32_t BytecodeBuilder::addUserdataType(const char* name) +{ + LUAU_ASSERT(FFlag::LuauCompileUserdataInfo); + + UserdataType ty; + + ty.name = name; + + userdataTypes.push_back(std::move(ty)); + return uint32_t(userdataTypes.size() - 1); +} + +void BytecodeBuilder::useUserdataType(uint32_t index) +{ + LUAU_ASSERT(FFlag::LuauCompileUserdataInfo); + + userdataTypes[index].used = true; +} + void BytecodeBuilder::setDebugFunctionName(StringRef name) { unsigned int index = addStringTableEntry(name); @@ -648,6 +678,15 @@ void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); + if (FFlag::LuauCompileUserdataInfo) + { + for (auto& ty : userdataTypes) + { + if (ty.used) + ty.nameRef = addStringTableEntry(StringRef({ty.name.c_str(), ty.name.length()})); + } + } + // preallocate space for bytecode blob size_t capacity = 16; @@ -666,10 +705,24 @@ void BytecodeBuilder::finalize() 
bytecode = char(version); uint8_t typesversion = getTypeEncodingVersion(); + LUAU_ASSERT(typesversion >= LBC_TYPE_VERSION_MIN && typesversion <= LBC_TYPE_VERSION_MAX); writeByte(bytecode, typesversion); writeStringTable(bytecode); + if (FFlag::LuauCompileTypeInfo && FFlag::LuauCompileUserdataInfo) + { + // Write the mapping between used type name indices and their name + for (uint32_t i = 0; i < uint32_t(userdataTypes.size()); i++) + { + writeByte(bytecode, i + 1); + writeVarInt(bytecode, userdataTypes[i].nameRef); + } + + // 0 marks the end of the mapping + writeByte(bytecode, 0); + } + writeVarInt(bytecode, uint32_t(functions.size())); for (const Function& func : functions) @@ -1036,11 +1089,6 @@ void BytecodeBuilder::foldJumps() if (LUAU_INSN_OP(jumpInsn) == LOP_JUMP && LUAU_INSN_OP(targetInsn) == LOP_RETURN) { insns[jumpLabel] = targetInsn; - - if (!FFlag::LuauCompileNoJumpLineRetarget) - { - lines[jumpLabel] = lines[targetLabel]; - } } else if (int16_t(offset) == offset) { @@ -1198,7 +1246,10 @@ uint8_t BytecodeBuilder::getVersion() uint8_t BytecodeBuilder::getTypeEncodingVersion() { - return FFlag::LuauCompileTypeInfo ? LBC_TYPE_VERSION : LBC_TYPE_VERSION_DEPRECATED; + if (FFlag::LuauCompileTypeInfo && FFlag::LuauCompileUserdataInfo) + return LBC_TYPE_VERSION_TARGET; + + return FFlag::LuauCompileTypeInfo ? 
2 : LBC_TYPE_VERSION_DEPRECATED; } #ifdef LUAU_ASSERTENABLED @@ -2275,7 +2326,7 @@ std::string BytecodeBuilder::dumpCurrentFunction(std::vector& dumpinstoffs) { const DebugLocal& l = debugLocals[i]; - if (FFlag::LuauCompileRepeatUntilSkippedLocals && l.startpc == l.endpc) + if (l.startpc == l.endpc) { LUAU_ASSERT(l.startpc < lines.size()); @@ -2301,35 +2352,74 @@ std::string BytecodeBuilder::dumpCurrentFunction(std::vector& dumpinstoffs) { const std::string& typeinfo = functions.back().typeinfo; - // Arguments start from third byte in function typeinfo string - for (uint8_t i = 2; i < typeinfo.size(); ++i) + if (FFlag::LuauCompileUserdataInfo) { - uint8_t et = typeinfo[i]; + // Arguments start from third byte in function typeinfo string + for (uint8_t i = 2; i < typeinfo.size(); ++i) + { + uint8_t et = typeinfo[i]; - const char* base = getBaseTypeString(et); - const char* optional = (et & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + const char* userdata = tryGetUserdataTypeName(LuauBytecodeType(et)); + const char* name = userdata ? userdata : getBaseTypeString(et); + const char* optional = (et & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; - formatAppend(result, "R%d: %s%s [argument]\n", i - 2, base, optional); + formatAppend(result, "R%d: %s%s [argument]\n", i - 2, name, optional); + } + + for (size_t i = 0; i < typedUpvals.size(); ++i) + { + const TypedUpval& l = typedUpvals[i]; + + const char* userdata = tryGetUserdataTypeName(l.type); + const char* name = userdata ? userdata : getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + + formatAppend(result, "U%d: %s%s\n", int(i), name, optional); + } + + for (size_t i = 0; i < typedLocals.size(); ++i) + { + const TypedLocal& l = typedLocals[i]; + + const char* userdata = tryGetUserdataTypeName(l.type); + const char* name = userdata ? userdata : getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" 
: ""; + + formatAppend(result, "R%d: %s%s from %d to %d\n", l.reg, name, optional, l.startpc, l.endpc); + } } - - for (size_t i = 0; i < typedUpvals.size(); ++i) + else { - const TypedUpval& l = typedUpvals[i]; + // Arguments start from third byte in function typeinfo string + for (uint8_t i = 2; i < typeinfo.size(); ++i) + { + uint8_t et = typeinfo[i]; - const char* base = getBaseTypeString(l.type); - const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + const char* base = getBaseTypeString(et); + const char* optional = (et & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; - formatAppend(result, "U%d: %s%s\n", int(i), base, optional); - } + formatAppend(result, "R%d: %s%s [argument]\n", i - 2, base, optional); + } - for (size_t i = 0; i < typedLocals.size(); ++i) - { - const TypedLocal& l = typedLocals[i]; + for (size_t i = 0; i < typedUpvals.size(); ++i) + { + const TypedUpval& l = typedUpvals[i]; - const char* base = getBaseTypeString(l.type); - const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + const char* base = getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; - formatAppend(result, "R%d: %s%s from %d to %d\n", l.reg, base, optional, l.startpc, l.endpc); + formatAppend(result, "U%d: %s%s\n", int(i), base, optional); + } + + for (size_t i = 0; i < typedLocals.size(); ++i) + { + const TypedLocal& l = typedLocals[i]; + + const char* base = getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" 
: ""; + + formatAppend(result, "R%d: %s%s from %d to %d\n", l.reg, base, optional, l.startpc, l.endpc); + } } } } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index d5cd78a5..19526fa9 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,10 +26,9 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileRepeatUntilSkippedLocals, false) LUAU_FASTFLAG(LuauCompileTypeInfo) -LUAU_FASTFLAGVARIABLE(LuauTypeInfoLookupImprovement, false) LUAU_FASTFLAGVARIABLE(LuauCompileTempTypeInfo, false) +LUAU_FASTFLAGVARIABLE(LuauCompileUserdataInfo, false) namespace Luau { @@ -107,6 +106,7 @@ struct Compiler , locstants(nullptr) , tableShapes(nullptr) , builtins(nullptr) + , userdataTypes(AstName()) , functionTypes(nullptr) , localTypes(nullptr) , exprTypes(nullptr) @@ -677,10 +677,7 @@ struct Compiler // if the argument is a local that isn't mutated, we will simply reuse the existing register if (int reg = le ? 
getExprLocalReg(le) : -1; reg >= 0 && (!lv || !lv->written)) { - if (FFlag::LuauTypeInfoLookupImprovement) - args.push_back({var, uint8_t(reg), {Constant::Type_Unknown}, kDefaultAllocPc}); - else - args.push_back({var, uint8_t(reg)}); + args.push_back({var, uint8_t(reg), {Constant::Type_Unknown}, kDefaultAllocPc}); } else { @@ -2771,16 +2768,14 @@ struct Compiler { validateContinueUntil(loops.back().continueUsed, stat->condition, body, i + 1); continueValidated = true; - - if (FFlag::LuauCompileRepeatUntilSkippedLocals) - conditionLocals = localStack.size(); + conditionLocals = localStack.size(); } } // if continue was used, some locals might not have had their initialization completed // the lifetime of these locals has to end before the condition is executed // because referencing skipped locals is not possible from the condition, this earlier closure doesn't affect upvalues - if (FFlag::LuauCompileRepeatUntilSkippedLocals && continueValidated) + if (continueValidated) { // if continueValidated is set, it means we have visited at least one body node and size > 0 setDebugLineEnd(body->body.data[body->body.size - 1]); @@ -4094,6 +4089,7 @@ struct Compiler DenseHashMap locstants; DenseHashMap tableShapes; DenseHashMap builtins; + DenseHashMap userdataTypes; DenseHashMap functionTypes; DenseHashMap localTypes; DenseHashMap exprTypes; @@ -4190,18 +4186,34 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c Compiler::FunctionVisitor functionVisitor(&compiler, functions); root->visit(&functionVisitor); + if (FFlag::LuauCompileUserdataInfo) + { + if (const char* const* ptr = options.userdataTypes) + { + for (; *ptr; ++ptr) + { + // Type will only resolve to an AstName if it is actually mentioned in the source + if (AstName name = names.get(*ptr); name.value) + compiler.userdataTypes[name] = bytecode.addUserdataType(name.value); + } + + if (uintptr_t(ptr - options.userdataTypes) > (LBC_TYPE_TAGGED_USERDATA_END - 
LBC_TYPE_TAGGED_USERDATA_BASE)) + CompileError::raise(root->location, "Exceeded userdata type limit in the compilation options"); + } + } + // computes type information for all functions based on type annotations if (FFlag::LuauCompileTypeInfo) { if (options.typeInfoLevel >= 1) - buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.builtinTypes, - compiler.builtins, compiler.globals); + buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.userdataTypes, + compiler.builtinTypes, compiler.builtins, compiler.globals, bytecode); } else { if (functionVisitor.hasTypes) - buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.builtinTypes, - compiler.builtins, compiler.globals); + buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.userdataTypes, + compiler.builtinTypes, compiler.builtins, compiler.globals, bytecode); } for (AstExprFunction* expr : functions) diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index eaa2d8be..4454114c 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -5,6 +5,7 @@ LUAU_FASTFLAG(LuauCompileTypeInfo) LUAU_FASTFLAG(LuauCompileTempTypeInfo) +LUAU_FASTFLAG(LuauCompileUserdataInfo) namespace Luau { @@ -39,7 +40,8 @@ static LuauBytecodeType getPrimitiveType(AstName name) } static LuauBytecodeType getType(const AstType* ty, const AstArray& generics, - const DenseHashMap& typeAliases, bool resolveAliases, const char* vectorType) + const DenseHashMap& typeAliases, bool resolveAliases, const char* vectorType, + const DenseHashMap& userdataTypes, BytecodeBuilder& bytecode) { if (const AstTypeReference* ref = ty->as()) { @@ -50,7 +52,7 @@ static LuauBytecodeType getType(const AstType* ty, const AstArraytype, (*alias)->generics, typeAliases, /* resolveAliases= */ false, vectorType); + return 
getType((*alias)->type, (*alias)->generics, typeAliases, /* resolveAliases= */ false, vectorType, userdataTypes, bytecode); else return LBC_TYPE_ANY; } @@ -64,6 +66,15 @@ static LuauBytecodeType getType(const AstType* ty, const AstArrayname); prim != LBC_TYPE_INVALID) return prim; + if (FFlag::LuauCompileUserdataInfo) + { + if (const uint8_t* userdataIndex = userdataTypes.find(ref->name)) + { + bytecode.useUserdataType(*userdataIndex); + return LuauBytecodeType(LBC_TYPE_TAGGED_USERDATA_BASE + *userdataIndex); + } + } + // not primitive or alias or generic => host-provided, we assume userdata for now return LBC_TYPE_USERDATA; } @@ -82,7 +93,7 @@ static LuauBytecodeType getType(const AstType* ty, const AstArraytypes) { - LuauBytecodeType et = getType(ty, generics, typeAliases, resolveAliases, vectorType); + LuauBytecodeType et = getType(ty, generics, typeAliases, resolveAliases, vectorType, userdataTypes, bytecode); if (et == LBC_TYPE_NIL) { @@ -113,7 +124,8 @@ static LuauBytecodeType getType(const AstType* ty, const AstArray& typeAliases, const char* vectorType) +static std::string getFunctionType(const AstExprFunction* func, const DenseHashMap& typeAliases, const char* vectorType, + const DenseHashMap& userdataTypes, BytecodeBuilder& bytecode) { bool self = func->self != 0; @@ -130,7 +142,8 @@ static std::string getFunctionType(const AstExprFunction* func, const DenseHashM for (AstLocal* arg : func->args) { LuauBytecodeType ty = - arg->annotation ? getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true, vectorType) : LBC_TYPE_ANY; + arg->annotation ? 
getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode) + : LBC_TYPE_ANY; if (ty != LBC_TYPE_ANY) haveNonAnyParam = true; @@ -161,9 +174,11 @@ struct TypeMapVisitor : AstVisitor DenseHashMap& localTypes; DenseHashMap& exprTypes; const char* vectorType; + const DenseHashMap& userdataTypes; const BuiltinTypes& builtinTypes; const DenseHashMap& builtinCalls; const DenseHashMap& globals; + BytecodeBuilder& bytecode; DenseHashMap typeAliases; std::vector> typeAliasStack; @@ -171,15 +186,18 @@ struct TypeMapVisitor : AstVisitor DenseHashMap resolvedExprs; TypeMapVisitor(DenseHashMap& functionTypes, DenseHashMap& localTypes, - DenseHashMap& exprTypes, const char* vectorType, const BuiltinTypes& builtinTypes, - const DenseHashMap& builtinCalls, const DenseHashMap& globals) + DenseHashMap& exprTypes, const char* vectorType, const DenseHashMap& userdataTypes, + const BuiltinTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + BytecodeBuilder& bytecode) : functionTypes(functionTypes) , localTypes(localTypes) , exprTypes(exprTypes) , vectorType(vectorType) + , userdataTypes(userdataTypes) , builtinTypes(builtinTypes) , builtinCalls(builtinCalls) , globals(globals) + , bytecode(bytecode) , typeAliases(AstName()) , resolvedLocals(nullptr) , resolvedExprs(nullptr) @@ -250,7 +268,7 @@ struct TypeMapVisitor : AstVisitor resolvedExprs[expr] = ty; - LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType); + LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode); exprTypes[expr] = bty; return bty; } @@ -263,7 +281,7 @@ struct TypeMapVisitor : AstVisitor resolvedLocals[local] = ty; - LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType); + LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode); if 
(bty != LBC_TYPE_ANY) localTypes[local] = bty; @@ -354,7 +372,7 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprFunction* node) override { - std::string type = getFunctionType(node, typeAliases, vectorType); + std::string type = getFunctionType(node, typeAliases, vectorType, userdataTypes, bytecode); if (!type.empty()) functionTypes[node] = std::move(type); @@ -393,7 +411,7 @@ struct TypeMapVisitor : AstVisitor if (AstType* annotation = local->annotation) { - LuauBytecodeType ty = getType(annotation, {}, typeAliases, /* resolveAliases= */ true, vectorType); + LuauBytecodeType ty = getType(annotation, {}, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode); if (ty != LBC_TYPE_ANY) localTypes[local] = ty; @@ -754,10 +772,11 @@ struct TypeMapVisitor : AstVisitor }; void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, - DenseHashMap& exprTypes, AstNode* root, const char* vectorType, const BuiltinTypes& builtinTypes, - const DenseHashMap& builtinCalls, const DenseHashMap& globals) + DenseHashMap& exprTypes, AstNode* root, const char* vectorType, const DenseHashMap& userdataTypes, + const BuiltinTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + BytecodeBuilder& bytecode) { - TypeMapVisitor visitor(functionTypes, localTypes, exprTypes, vectorType, builtinTypes, builtinCalls, globals); + TypeMapVisitor visitor(functionTypes, localTypes, exprTypes, vectorType, userdataTypes, builtinTypes, builtinCalls, globals, bytecode); root->visit(&visitor); } diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h index b1aff8a2..bd12ea77 100644 --- a/Compiler/src/Types.h +++ b/Compiler/src/Types.h @@ -10,6 +10,7 @@ namespace Luau { +class BytecodeBuilder; struct BuiltinTypes { @@ -26,7 +27,8 @@ struct BuiltinTypes }; void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, - DenseHashMap& exprTypes, AstNode* root, const char* vectorType, const BuiltinTypes& builtinTypes, - const 
DenseHashMap& builtinCalls, const DenseHashMap& globals); + DenseHashMap& exprTypes, AstNode* root, const char* vectorType, const DenseHashMap& userdataTypes, + const BuiltinTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + BytecodeBuilder& bytecode); } // namespace Luau diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 21d7071c..35e66471 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -156,6 +156,7 @@ struct lua_ExecutionCallbacks int (*enter)(lua_State* L, Proto* proto); // called when function is about to start/resume (when execdata is present), return 0 to exit VM void (*disable)(lua_State* L, Proto* proto); // called when function has to be switched from native to bytecode in the debugger size_t (*getmemorysize)(lua_State* L, Proto* proto); // called to request the size of memory associated with native part of the Proto + uint8_t (*gettypemapping)(lua_State* L, const char* str, size_t len); // called to get the userdata type index }; /* diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 545c1d2d..3e14d4ad 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -11,7 +11,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastCrossTableMove, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastTableMaxn, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFasterConcat, false) @@ -145,68 +144,6 @@ static void moveelements(lua_State* L, int srct, int dstt, int f, int e, int t) luaC_barrierfast(L, dst); } - else if (DFFlag::LuauFastCrossTableMove && dst != src) - { - // compute the array slice we have to copy over - int slicestart = f < 1 ? 0 : (f > src->sizearray ? src->sizearray : f - 1); - int sliceend = e < 1 ? 0 : (e > src->sizearray ? src->sizearray : e); - LUAU_ASSERT(slicestart <= sliceend); - - int slicecount = sliceend - slicestart; - - if (slicecount > 0) - { - // array slice starting from INT_MIN is impossible, so we don't have to worry about int overflow - int dstslicestart = f < 1 ? 
-f + 1 : 0; - - // copy over the slice - for (int i = 0; i < slicecount; ++i) - { - lua_rawgeti(L, srct, slicestart + i + 1); - lua_rawseti(L, dstt, dstslicestart + t + i); - } - } - - // copy the remaining elements that could be in the hash part - int hashpartsize = sizenode(src); - - // select the strategy with the least amount of steps - if (n <= hashpartsize) - { - for (int i = 0; i < n; ++i) - { - // skip array slice elements that were already copied over - if (cast_to(unsigned int, f + i - 1) < cast_to(unsigned int, src->sizearray)) - continue; - - lua_rawgeti(L, srct, f + i); - lua_rawseti(L, dstt, t + i); - } - } - else - { - // source and destination tables are different, so we can iterate over source hash part directly - int i = hashpartsize; - - while (i--) - { - LuaNode* node = gnode(src, i); - if (ttisnumber(gkey(node))) - { - double n = nvalue(gkey(node)); - - int k; - luai_num2int(k, n); - - if (luai_numeq(cast_num(k), n) && k >= f && k <= e) - { - lua_rawgeti(L, srct, k); - lua_rawseti(L, dstt, t - f + k); - } - } - } - } - } else { if (t > e || t <= f || dst != src) diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index f13c0f21..ed564bba 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -14,6 +14,7 @@ #include LUAU_FASTFLAG(LuauLoadTypeInfo) +LUAU_FASTFLAGVARIABLE(LuauLoadUserdataInfo, false) // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template @@ -187,6 +188,65 @@ static void resolveImportSafe(lua_State* L, Table* env, TValue* k, uint32_t id) } } +static void remapUserdataTypes(char* data, size_t size, uint8_t* userdataRemapping, uint32_t count) +{ + LUAU_ASSERT(FFlag::LuauLoadUserdataInfo); + + size_t offset = 0; + + uint32_t typeSize = readVarInt(data, size, offset); + uint32_t upvalCount = readVarInt(data, size, offset); + uint32_t localCount = readVarInt(data, size, offset); + + if (typeSize != 0) + { + uint8_t* types = (uint8_t*)data + offset; + + // Skip two bytes of function type 
introduction + for (uint32_t i = 2; i < typeSize; i++) + { + uint32_t index = uint32_t(types[i] - LBC_TYPE_TAGGED_USERDATA_BASE); + + if (index < count) + types[i] = userdataRemapping[index]; + } + + offset += typeSize; + } + + if (upvalCount != 0) + { + uint8_t* types = (uint8_t*)data + offset; + + for (uint32_t i = 0; i < upvalCount; i++) + { + uint32_t index = uint32_t(types[i] - LBC_TYPE_TAGGED_USERDATA_BASE); + + if (index < count) + types[i] = userdataRemapping[index]; + } + + offset += upvalCount; + } + + if (localCount != 0) + { + for (uint32_t i = 0; i < localCount; i++) + { + uint32_t index = uint32_t(data[offset] - LBC_TYPE_TAGGED_USERDATA_BASE); + + if (index < count) + data[offset] = userdataRemapping[index]; + + offset += 2; + readVarInt(data, size, offset); + readVarInt(data, size, offset); + } + } + + LUAU_ASSERT(offset == size); +} + int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env) { size_t offset = 0; @@ -227,6 +287,18 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size if (version >= 4) { typesversion = read(data, size, offset); + + if (FFlag::LuauLoadUserdataInfo) + { + if (typesversion < LBC_TYPE_VERSION_MIN || typesversion > LBC_TYPE_VERSION_MAX) + { + char chunkbuf[LUA_IDSIZE]; + const char* chunkid = luaO_chunkid(chunkbuf, sizeof(chunkbuf), chunkname, strlen(chunkname)); + lua_pushfstring(L, "%s: bytecode type version mismatch (expected [%d..%d], got %d)", chunkid, LBC_TYPE_VERSION_MIN, + LBC_TYPE_VERSION_MAX, typesversion); + return 1; + } + } } // string table @@ -241,6 +313,31 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size offset += length; } + // userdata type remapping table + // for unknown userdata types, the entry will remap to common 'userdata' type + const uint32_t userdataTypeLimit = LBC_TYPE_TAGGED_USERDATA_END - LBC_TYPE_TAGGED_USERDATA_BASE; + uint8_t userdataRemapping[userdataTypeLimit]; + + if 
(FFlag::LuauLoadUserdataInfo && typesversion == 3) + { + memset(userdataRemapping, LBC_TYPE_USERDATA, userdataTypeLimit); + + uint8_t index = read(data, size, offset); + + while (index != 0) + { + TString* name = readString(strings, data, size, offset); + + if (uint32_t(index - 1) < userdataTypeLimit) + { + if (auto cb = L->global->ecb.gettypemapping) + userdataRemapping[index - 1] = cb(L, getstr(name), name->len); + } + + index = read(data, size, offset); + } + } + // proto table unsigned int protoCount = readVarInt(data, size, offset); TempBuffer protos(L, protoCount); @@ -299,7 +396,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size offset += typesize; } - else if (typesversion == 2) + else if (typesversion == 2 || (FFlag::LuauLoadUserdataInfo && typesversion == 3)) { uint32_t typesize = readVarInt(data, size, offset); @@ -311,6 +408,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->sizetypeinfo = typesize; memcpy(p->typeinfo, types, typesize); offset += typesize; + + if (FFlag::LuauLoadUserdataInfo && typesversion == 3) + { + remapUserdataTypes((char*)(uint8_t*)p->typeinfo, p->sizetypeinfo, userdataRemapping, userdataTypeLimit); + } } } } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index e8927837..6255d73f 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -22,8 +22,9 @@ LUAU_FASTINT(LuauCompileLoopUnrollThreshold) LUAU_FASTINT(LuauCompileLoopUnrollThresholdMaxBoost) LUAU_FASTINT(LuauRecursionLimit) -LUAU_FASTFLAG(LuauCompileNoJumpLineRetarget) -LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) +LUAU_FASTFLAG(LuauCompileTypeInfo) +LUAU_FASTFLAG(LuauCompileTempTypeInfo) +LUAU_FASTFLAG(LuauCompileUserdataInfo) using namespace Luau; @@ -2106,8 +2107,6 @@ RETURN R0 0 TEST_CASE("LoopContinueEarlyCleanup") { - ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; - // locals after a potential 'continue' are 
not accessible inside the condition and can be closed at the end of a block CHECK_EQ("\n" + compileFunction(R"( local y @@ -2788,8 +2787,6 @@ end TEST_CASE("DebugLineInfoWhile") { - ScopedFastFlag luauCompileNoJumpLineRetarget{FFlag::LuauCompileNoJumpLineRetarget, true}; - Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -3136,8 +3133,6 @@ local 8: reg 3, start pc 35 line 21, end pc 35 line 21 TEST_CASE("DebugLocals2") { - ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; - const char* source = R"( function foo(x) repeat @@ -3167,9 +3162,6 @@ local 2: reg 0, start pc 0 line 4, end pc 2 line 6 TEST_CASE("DebugLocals3") { - ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; - ScopedFastFlag luauCompileNoJumpLineRetarget{FFlag::LuauCompileNoJumpLineRetarget, true}; - const char* source = R"( function foo(x) repeat @@ -3203,6 +3195,7 @@ local 4: reg 0, start pc 0 line 4, end pc 5 line 8 8: RETURN R0 0 )"); } + TEST_CASE("DebugRemarks") { Luau::BytecodeBuilder bcb; @@ -3230,6 +3223,80 @@ RETURN R0 0 )"); } +TEST_CASE("DebugTypes") +{ + ScopedFastFlag luauCompileTypeInfo{FFlag::LuauCompileTypeInfo, true}; + ScopedFastFlag luauCompileTempTypeInfo{FFlag::LuauCompileTempTypeInfo, true}; + ScopedFastFlag luauCompileUserdataInfo{FFlag::LuauCompileUserdataInfo, true}; + + const char* source = R"( +local up: number = 2 + +function foo(e: vector, f: mat3, g: sequence) + local h = e * e + + for i=1,3 do + print(i) + end + + print(e * f) + print(g) + print(h) + + up += a + return a +end +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Types); + bcb.setDumpSource(source); + + Luau::CompileOptions options; + options.vectorCtor = "vector"; + options.vectorType = "vector"; + + options.typeInfoLevel = 1; + + static 
const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + options.userdataTypes = kUserdataCompileTypes; + + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +R0: vector [argument] +R1: mat3 [argument] +R2: userdata [argument] +U0: number +R6: any from 1 to 9 +R3: vector from 0 to 30 +MUL R3 R0 R0 +LOADN R6 1 +LOADN R4 3 +LOADN R5 1 +FORNPREP R4 L1 +L0: GETIMPORT R7 1 [print] +MOVE R8 R6 +CALL R7 1 0 +FORNLOOP R4 L0 +L1: GETIMPORT R4 1 [print] +MUL R5 R0 R1 +CALL R4 1 0 +GETIMPORT R4 1 [print] +MOVE R5 R2 +CALL R4 1 0 +GETIMPORT R4 1 [print] +MOVE R5 R3 +CALL R4 1 0 +GETUPVAL R4 0 +GETIMPORT R5 3 [a] +ADD R4 R4 R5 +SETUPVAL R4 0 +GETIMPORT R4 3 [a] +RETURN R4 1 +)"); +} + TEST_CASE("SourceRemarks") { const char* source = R"( @@ -4158,8 +4225,6 @@ RETURN R0 0 TEST_CASE("Coverage") { - ScopedFastFlag luauCompileNoJumpLineRetarget{FFlag::LuauCompileNoJumpLineRetarget, true}; - // basic statement coverage CHECK_EQ("\n" + compileFunction0Coverage(R"( print(1) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index bd57a140..7ced52cf 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -33,8 +33,7 @@ void luaC_validate(lua_State* L); LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) -LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) -LUAU_DYNAMIC_FASTFLAG(LuauFastCrossTableMove) +LUAU_FASTFLAG(LuauCodegenFixSplitStoreConstMismatch) static lua_CompileOptions defaultOptions() { @@ -443,8 +442,6 @@ TEST_CASE("Sort") TEST_CASE("Move") { - ScopedFastFlag luauFastCrossTableMove{DFFlag::LuauFastCrossTableMove, true}; - runConformance("move.lua"); } @@ -717,8 +714,6 @@ TEST_CASE("Debugger") static bool singlestep = false; static int stephits = 0; - ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; - SUBCASE("") { singlestep = false; @@ -2140,6 +2135,8 @@ TEST_CASE("Native") if (!codegen || 
!luau_codegen_supported()) return; + ScopedFastFlag luauCodegenFixSplitStoreConstMismatch{FFlag::LuauCodegenFixSplitStoreConstMismatch, true}; + SUBCASE("Checked") { FFlag::DebugLuauAbortingChecks.value = true; diff --git a/tests/ConformanceIrHooks.h b/tests/ConformanceIrHooks.h index 135fe9da..d4050863 100644 --- a/tests/ConformanceIrHooks.h +++ b/tests/ConformanceIrHooks.h @@ -3,6 +3,8 @@ #include "Luau/IrBuilder.h" +static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr}; + inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength) { using namespace Luau::CodeGen; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 4f7725e6..da7cd9b1 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -14,7 +14,7 @@ LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(DebugLuauAbortingChecks) -LUAU_FASTFLAG(LuauCodegenLoadPropCheckRegLinkInTv) +LUAU_FASTFLAG(LuauCodegenFixSplitStoreConstMismatch) using namespace Luau::CodeGen; @@ -2658,6 +2658,60 @@ bb_0: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore1") +{ + ScopedFastFlag luauCodegenFixSplitStoreConstMismatch{FFlag::LuauCodegenFixSplitStoreConstMismatch, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, build.vmReg(0), build.constTag(ttable), build.vmExit(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_INT R0, 1i + CHECK_TAG R0, ttable, exit(1) + %2 = LOAD_TVALUE R0 + STORE_TVALUE R1, %2 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, 
"DoNotProduceInvalidSplitStore2") +{ + ScopedFastFlag luauCodegenFixSplitStoreConstMismatch{FFlag::LuauCodegenFixSplitStoreConstMismatch, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, build.vmReg(0), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_INT R0, 1i + CHECK_TAG R0, tnumber, exit(1) + %2 = LOAD_TVALUE R0 + STORE_TVALUE R1, %2 + RETURN R1, 1i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Analysis"); @@ -3475,8 +3529,6 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "TaggedValuePropagationIntoTvalueChecksRegisterVersion") { - ScopedFastFlag luauCodegenLoadPropCheckRegLinkInTv{FFlag::LuauCodegenLoadPropCheckRegLinkInTv, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 131ec4d1..5d7fedd8 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -13,18 +13,17 @@ #include "ConformanceIrHooks.h" #include +#include LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow) LUAU_FASTFLAG(LuauCompileTypeInfo) LUAU_FASTFLAG(LuauLoadTypeInfo) LUAU_FASTFLAG(LuauCodegenTypeInfo) -LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) -LUAU_FASTFLAG(LuauCodegenIrTypeNames) LUAU_FASTFLAG(LuauCompileTempTypeInfo) -LUAU_FASTFLAG(LuauCodegenFixVectorFields) -LUAU_FASTFLAG(LuauCodegenVectorMispredictFix) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) +LUAU_FASTFLAG(LuauCompileUserdataInfo) +LUAU_FASTFLAG(LuauLoadUserdataInfo) static std::string 
getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) { @@ -64,6 +63,9 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = copts.vectorCtor = "vector"; copts.vectorType = "vector"; + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + Luau::BytecodeBuilder bcb; Luau::compileOrThrow(bcb, result, names, copts); @@ -71,6 +73,33 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); + // Runtime mapping is specifically created to NOT match the compilation mapping + options.compilationOptions.userdataTypes = kUserdataRunTypes; + + if (Luau::CodeGen::isSupported()) + { + // Type remapper requires the codegen runtime + Luau::CodeGen::create(L); + + Luau::CodeGen::setUserdataRemapper(L, kUserdataRunTypes, [](void* context, const char* str, size_t len) -> uint8_t { + const char** types = (const char**)context; + + uint8_t index = 0; + + std::string_view sv{str, len}; + + for (; *types; ++types) + { + if (sv == *types) + return index; + + index++; + } + + return 0xff; + }); + } + if (luau_load(L, "name", bytecode.data(), bytecode.size(), 0) == 0) return Luau::CodeGen::getAssembly(L, -1, options, nullptr); @@ -480,8 +509,6 @@ bb_bytecode_1: TEST_CASE("VectorRandomProp") { ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - ScopedFastFlag luauCodegenFixVectorFields{FFlag::LuauCodegenFixVectorFields, true}; - ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: vector) @@ -524,7 +551,6 @@ bb_6: TEST_CASE("VectorCustomAccess") { ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - ScopedFastFlag 
luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true}; ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( @@ -600,7 +626,6 @@ bb_bytecode_1: TEST_CASE("VectorCustomAccessChain") { ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true}; ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; @@ -655,7 +680,6 @@ bb_bytecode_1: TEST_CASE("VectorCustomNamecallChain") { ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true}; ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; @@ -717,8 +741,8 @@ bb_bytecode_1: TEST_CASE("VectorCustomNamecallChain2") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCompileTempTypeInfo, true}, - {FFlag::LuauCodegenVectorMispredictFix, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, + {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {n: vector, b: vector} @@ -1048,7 +1072,7 @@ bb_bytecode_1: TEST_CASE("LoadAndMoveTypePropagation") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, 
true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(n) @@ -1116,7 +1140,7 @@ bb_bytecode_4: TEST_CASE("ArgumentTypeRefinement") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(x, y) @@ -1155,7 +1179,7 @@ bb_bytecode_0: TEST_CASE("InlineFunctionType") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function inl(v: vector, s: number) @@ -1204,8 +1228,7 @@ bb_bytecode_0: TEST_CASE("ResolveTablePathTypes") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1260,8 +1283,7 @@ bb_6: TEST_CASE("ResolvableSimpleMath") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}}; + 
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = { p: vector, uv: vector, n: vector, t: vector, b: vector, h: number } @@ -1318,8 +1340,8 @@ end TEST_CASE("ResolveVectorNamecalls") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, + {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1384,8 +1406,7 @@ bb_6: TEST_CASE("ImmediateTypeAnnotationHelp") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(arr, i) @@ -1424,8 +1445,7 @@ bb_2: TEST_CASE("UnaryTypeResolve") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( local function foo(a, b: vector, c) @@ -1448,8 
+1468,7 @@ end TEST_CASE("ForInManualAnnotation") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1545,8 +1564,7 @@ bb_12: TEST_CASE("ForInAutoAnnotationIpairs") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = {pos: vector, normal: vector} @@ -1574,8 +1592,7 @@ end TEST_CASE("ForInAutoAnnotationPairs") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}}; + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = {pos: vector, normal: vector} @@ -1603,8 +1620,7 @@ end TEST_CASE("ForInAutoAnnotationGeneric") { ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true}, - {FFlag::LuauCompileTempTypeInfo, true}}; + 
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = {pos: vector, normal: vector} @@ -1629,4 +1645,49 @@ end )"); } +// Temporary test, when we don't compile new typeinfo, but support loading it +TEST_CASE("CustomUserdataTypesTemp") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, false}, + {FFlag::LuauLoadUserdataInfo, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +local function foo(v: vec2, x: mat3) + return v.X * x +end +)"), + R"( +; function foo(v, x) line 2 +; R0: userdata [argument 'v'] +; R1: userdata [argument 'x'] +)"); +} + +TEST_CASE("CustomUserdataTypes") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +local function foo(v: vec2, x: mat3) + return v.X * x +end +)"), + R"( +; function foo(v, x) line 2 +; R0: vec2 [argument 'v'] +; R1: mat3 [argument 'x'] +)"); +} + TEST_SUITE_END(); diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index 806dac62..e51fb0df 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -556,4 +556,22 @@ local E = require(script.Parent.A) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_shouldnt_warn_on_valid_buffer_use") +{ + 
loadDefinition(R"( +declare buffer: { + create: @checked (size: number) -> buffer, + readi8: @checked (b: buffer, offset: number) -> number, + writef64: @checked (b: buffer, offset: number, value: number) -> (), +} +)"); + + CheckResult result = checkNonStrict(R"( +local b = buffer.create(100) +buffer.writef64(b, 0, 5) +buffer.readi8(b, 0) +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index c22d464e..3eceea17 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -420,4 +420,22 @@ print(NewProxyOne.HelloICauseACrash) )"); } +TEST_CASE_FIXTURE(ReplFixture, "InteractiveStackReserve1") +{ + // Reset stack reservation + lua_resume(L, nullptr, 0); + + runCode(L, R"( +local t = {} +)"); +} + +TEST_CASE_FIXTURE(ReplFixture, "InteractiveStackReserve2") +{ + // Reset stack reservation + lua_resume(L, nullptr, 0); + + getCompletionSet("a"); +} + TEST_SUITE_END(); diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index d8f115ae..afc22d0f 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -915,6 +915,7 @@ TEST_IS_SUBTYPE(numberToNumberType, negate(builtinTypes->classType)); TEST_IS_NOT_SUBTYPE(numberToNumberType, negate(builtinTypes->functionType)); // Negated supertypes: Primitives and singletons +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->stringType)); TEST_IS_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->numberType)); TEST_IS_SUBTYPE(str("foo"), meet(builtinTypes->stringType, negate(str("bar")))); TEST_IS_NOT_SUBTYPE(builtinTypes->trueType, negate(builtinTypes->booleanType)); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b2c5f623..e2b3f9b7 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -782,7 +782,10 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TypeId ty = requireType("map"); const FunctionType* ftv = get(follow(ty)); - CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", 
toStringNamedFunction("map", *ftv)); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("map(arr: {a}, fn: (a) -> (b, ...unknown)): {b}", toStringNamedFunction("map", *ftv)); + else + CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); } TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index 88dfbf47..c66f0227 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -3,7 +3,6 @@ #include "Luau/ConstraintSolver.h" #include "Luau/NotNull.h" -#include "Luau/TxnLog.h" #include "Luau/Type.h" #include "ClassFixture.h" @@ -14,6 +13,7 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_DYNAMIC_FASTINT(LuauTypeFamilyApplicationCartesianProductLimit) struct FamilyFixture : Fixture { @@ -24,7 +24,7 @@ struct FamilyFixture : Fixture { swapFamily = TypeFamily{/* name */ "Swap", /* reducer */ - [](TypeId instance, NotNull queue, const std::vector& tys, const std::vector& tps, + [](TypeId instance, const std::vector& tys, const std::vector& tps, NotNull ctx) -> TypeFamilyReductionResult { LUAU_ASSERT(tys.size() == 1); TypeId param = follow(tys.at(0)); @@ -716,4 +716,117 @@ _(setmetatable(_,{[...]=_,})) )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cyclic_concat_family_at_work") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type T = concat + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "exceeded_distributivity_limits") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + ScopedFastInt sfi{DFInt::LuauTypeFamilyApplicationCartesianProductLimit, 10}; + + loadDefinition(R"( + declare class A + function __mul(self, rhs: unknown): A + end + + declare class B + function __mul(self, rhs: unknown): B + end + + declare class C + function __mul(self, rhs: 
unknown): C + end + + declare class D + function __mul(self, rhs: unknown): D + end + )"); + + CheckResult result = check(R"( + type T = mul + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "didnt_quite_exceed_distributivity_limits") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + // We duplicate the test here because we want to make sure the test failed + // due to exceeding the limits specifically, rather than any possible reasons. + ScopedFastInt sfi{DFInt::LuauTypeFamilyApplicationCartesianProductLimit, 20}; + + loadDefinition(R"( + declare class A + function __mul(self, rhs: unknown): A + end + + declare class B + function __mul(self, rhs: unknown): B + end + + declare class C + function __mul(self, rhs: unknown): C + end + + declare class D + function __mul(self, rhs: unknown): D + end + )"); + + CheckResult result = check(R"( + type T = mul + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_equivalence_with_distributivity") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + loadDefinition(R"( + declare class A + function __mul(self, rhs: unknown): A + end + + declare class B + function __mul(self, rhs: unknown): B + end + + declare class C + function __mul(self, rhs: unknown): C + end + + declare class D + function __mul(self, rhs: unknown): D + end + )"); + + CheckResult result = check(R"( + type T = mul + type U = mul | mul | mul | mul + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "A | B"); + CHECK(toString(requireTypeAlias("U")) == "A | A | B | B"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 8d14f56b..b305d97d 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -32,15 +32,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") 
LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Bug: We do not simplify at the right time - CHECK_EQ("any?", toString(requireType("a"))); - } - else - { - CHECK_EQ(builtinTypes->anyType, requireType("a")); - } + CHECK(builtinTypes->anyType == requireType("a")); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") @@ -58,15 +50,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Bug: We do not simplify at the right time - CHECK_EQ("any?", toString(requireType("a"))); - } - else - { - CHECK_EQ("any", toString(requireType("a"))); - } + CHECK("any" == toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") @@ -82,15 +66,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Bug: We do not simplify at the right time - CHECK_EQ("any?", toString(requireType("a"))); - } - else - { - CHECK_EQ("any", toString(requireType("a"))); - } + CHECK("any" == toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") @@ -104,17 +80,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") end )"); - LUAU_REQUIRE_NO_ERRORS(result); - - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Bug: We do not simplify at the right time - CHECK_EQ("any?", toString(requireType("a"))); - } - else - { - CHECK_EQ("any", toString(requireType("a"))); - } + CHECK("any" == toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any_pack") @@ -130,15 +96,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any_pack") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Bug: We do not simplify at the right time - CHECK_EQ("any?", toString(requireType("a"))); - } - else - { - CHECK_EQ("any", 
toString(requireType("a"))); - } + CHECK("any" == toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 34178fd9..48d130dd 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1582,7 +1582,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_th if (!result.errors.empty()) { for (const auto& e : result.errors) - printf("%s %s: %s\n", e.moduleName.c_str(), toString(e.location).c_str(), toString(e).c_str()); + MESSAGE(e.moduleName << " " << toString(e.location) << ": " << toString(e)); } } diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 8bbd3f92..1a7ef973 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -433,9 +433,53 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "varlist_declared_by_for_in_loop_should_be_fr end )"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto err = get(result.errors[0]); + CHECK(err != nullptr); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iter_constraint_before_loop_body") +{ + CheckResult result = check(R"( + local T = { + fields = {}, + } + + function f() + for u, v in pairs(T.fields) do + T.fields[u] = nil + end + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "rbxl_place_file_crash_for_wrong_constraints") +{ + CheckResult result = check(R"( +local VehicleParameters = { + -- These are default values in the case the package structure is broken + StrutSpringStiffnessFront = 28000, +} + +local function updateFromConfiguration() + for property, value in pairs(VehicleParameters) do + VehicleParameters[property] = value + end +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + + TEST_CASE_FIXTURE(BuiltinsFixture, "properly_infer_iteratee_is_a_free_table") { 
// In this case, we cannot know the element type of the table {}. It could be anything. diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 56548608..fac86150 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -732,7 +732,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + // This infers a type for `t` of `{unknown}`, and so it makes sense that `t[1].test` would error. + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_ERROR_COUNT(1, result); + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 640e693b..37f891cb 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -101,4 +101,14 @@ TEST_CASE("singleton_types") CHECK(result.errors.empty()); } +TEST_CASE_FIXTURE(BuiltinsFixture, "property_of_buffers") +{ + CheckResult result = check(R"( + local b = buffer.create(100) + print(b.foo) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 485a18c6..ebf1fde4 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -44,6 +44,20 @@ TEST_CASE_FIXTURE(Fixture, "string_singletons") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "string_singleton_function_call") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local x = "a" + function f(x: "a") end + f(x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index bd0a4144..2c6136a4 100644 --- 
a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2462,10 +2462,7 @@ local x: {number} | number | string local y = #x )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(2, result); - else - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); } TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") @@ -2973,7 +2970,7 @@ c = b const TableType* ttv = get(*ty); REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); + CHECK(0 == ttv->instantiatedTypeParams.size()); } TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") @@ -4355,19 +4352,6 @@ TEST_CASE_FIXTURE(Fixture, "mymovie_read_write_tables_bug_2") LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setindexer_always_transmute") -{ - ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; - - CheckResult result = check(R"( - function f(x) - (5)[5] = x - end - )"); - - CHECK_EQ("(*error-type*) -> ()", toString(requireType("f"))); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "instantiated_metatable_frozen_table_clone_mutation") { ScopedFastFlag luauMetatableInstantiationCloneCheck{FFlag::LuauMetatableInstantiationCloneCheck, true}; @@ -4412,6 +4396,21 @@ TEST_CASE_FIXTURE(Fixture, "setprop_on_a_mutating_local_in_both_loops_and_functi LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "cant_index_this") +{ + CheckResult result = check(R"( + local a: number = 9 + a[18] = "tomfoolery" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + NotATable* notATable = get(result.errors[0]); + REQUIRE(notATable); + + CHECK("number" == toString(notATable->ty)); +} + TEST_CASE_FIXTURE(Fixture, "setindexer_multiple_tables_intersection") { ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; @@ -4423,8 +4422,8 @@ TEST_CASE_FIXTURE(Fixture, "setindexer_multiple_tables_intersection") end )"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK("({ [string]: number } & 
{ [thread]: boolean }, boolean | number) -> ()" == toString(requireType("f"))); + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK("({ [string]: number } & { [thread]: boolean }, never) -> ()" == toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop") diff --git a/tests/TypeInfer.typestates.test.cpp b/tests/TypeInfer.typestates.test.cpp index 3116022b..19117447 100644 --- a/tests/TypeInfer.typestates.test.cpp +++ b/tests/TypeInfer.typestates.test.cpp @@ -406,6 +406,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "prototyped_recursive_functions_but_has_futur )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("((() -> ()) | number)?" == toString(requireType("f"))); } diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 5f4d2a0e..539b8592 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -606,13 +606,7 @@ TEST_CASE_FIXTURE(Fixture, "indexing_into_a_cyclic_union_doesnt_crash") end )"); - // The old solver has a bug: It doesn't consider this goofy thing to be a - // table. It's not really important. What's important is that we don't - // crash, hang, or ICE. 
- if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_NO_ERRORS(result); - else - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") diff --git a/tests/conformance/move.lua b/tests/conformance/move.lua index 9518219f..bb613157 100644 --- a/tests/conformance/move.lua +++ b/tests/conformance/move.lua @@ -65,30 +65,6 @@ do a = table.move({[minI] = 100}, minI, minI, maxI) eqT(a, {[minI] = 100, [maxI] = 100}) - -- moving small amount of elements (array/hash) using a wide range - a = {} - table.move({1, 2, 3, 4, 5}, -100000000, 100000000, -100000000, a) - eqT(a, {1, 2, 3, 4, 5}) - - a = {} - table.move({1, 2}, -100000000, 100000000, 0, a) - eqT(a, {[100000001] = 1, [100000002] = 2}) - - -- hash part copy - a = {} - table.move({[-1000000] = 1, [-100] = 2, [100] = 3, [100000] = 4}, -100000000, 100000000, 0, a) - eqT(a, {[99000000] = 1, [99999900] = 2, [100000100] = 3, [100100000] = 4}) - - -- precise hash part bounds - a = {} - table.move({[-100000000 - 1] = -1, [-100000000] = 1, [-100] = 2, [100] = 3, [100000000] = 4, [100000000 + 1] = -1}, -100000000, 100000000, 0, a) - eqT(a, {[0] = 1, [99999900] = 2, [100000100] = 3, [200000000] = 4}) - - -- no integer undeflow in corner hash part case - a = {} - table.move({[minI] = 100, [-100] = 2}, minI, minI + 100000000, minI, a) - eqT(a, {[minI] = 100}) - -- hash part skips array slice a = {} table.move({[-1] = 1, [0] = 2, [1] = 3, [2] = 4}, -1, 3, 1, a) @@ -97,6 +73,19 @@ do a = {} table.move({[-1] = 1, [0] = 2, [1] = 3, [2] = 4, [10] = 5, [100] = 6, [1000] = 7}, -1, 3, 1, a) eqT(a, {[1] = 1, [2] = 2, [3] = 3, [4] = 4}) + + -- moving ranges containing nil values into tables with values + a = {1, 2, 3, 4, 5} + table.move({10}, 1, 3, 2, a) + eqT(a, {1, 10, nil, nil, 5}) + + a = {1, 2, 3, 4, 5} + table.move({10}, -1, 1, 2, a) + eqT(a, {1, nil, nil, 10, 5}) + + a = {[-1000] = 1, [1000] = 2, [1] = 3} + table.move({10}, -1000, 1000, 
-1000, a) + eqT(a, {10}) end checkerror("too many", table.move, {}, 0, maxI, 1) diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua index 094e6b83..03845013 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.lua @@ -208,6 +208,35 @@ end assert(pcall(fuzzfail21) == false) +local function fuzzfail22(...) + local _ = {false,},true,...,l0 + while _ do + _ = true,{unpack(0,_),},l0 + _.n126 = nil + _ = {not _,_=not _,n0=_,_,n0=not _,},_ < _ + return _ > _ + end + return `""` +end + +assert(pcall(fuzzfail22) == false) + +local function fuzzfail23(...) + local _ = {false,},_,...,l0 + while _ do + _ = true,{unpack(_),},l0 + _ = {{[_]=nil,_=not _,_,true,_=nil,},not _,not _,_,bxor=- _,} + do end + break + end + do end + local _ = _,true + do end + local _ = _,true +end + +assert(pcall(fuzzfail23) == false) + local function arraySizeInv1() local t = {1, 2, nil, nil, nil, nil, nil, nil, nil, true} diff --git a/tools/faillist.txt b/tools/faillist.txt index 6939df54..7a214a32 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -71,7 +71,6 @@ GenericsTests.no_stack_overflow_from_quantifying GenericsTests.properties_can_be_instantiated_polytypes GenericsTests.quantify_functions_even_if_they_have_an_explicit_generic GenericsTests.self_recursive_instantiated_param -IntersectionTypes.CLI-44817 IntersectionTypes.error_detailed_intersection_all IntersectionTypes.error_detailed_intersection_part IntersectionTypes.intersect_bool_and_false @@ -134,11 +133,8 @@ RefinementTest.call_an_incompatible_function_after_using_typeguard RefinementTest.dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x -RefinementTest.free_type_is_equal_to_an_lvalue RefinementTest.globals_can_be_narrowed_too RefinementTest.isa_type_refinement_must_be_known_ahead_of_time -RefinementTest.luau_polyfill_isindexkey_refine_conjunction 
-RefinementTest.luau_polyfill_isindexkey_refine_conjunction_variant RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage RefinementTest.refine_a_property_of_some_global @@ -157,7 +153,6 @@ TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode TableTests.array_factory_function -TableTests.cannot_augment_sealed_table TableTests.casting_tables_with_props_into_table_with_indexer2 TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_unsealed_tables_with_props_into_table_with_indexer @@ -181,20 +176,18 @@ TableTests.generalize_table_argument TableTests.generic_table_instantiation_potential_regression TableTests.indexer_on_sealed_table_must_unify_with_free_table TableTests.indexers_get_quantified_too -TableTests.inequality_operators_imply_exactly_matching_types -TableTests.infer_array TableTests.infer_indexer_from_array_like_table TableTests.infer_indexer_from_its_variable_type_and_unifiable TableTests.inferred_return_type_of_free_table TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound -TableTests.length_operator_union TableTests.less_exponential_blowup_please TableTests.meta_add TableTests.meta_add_inferred TableTests.metatable_mismatch_should_fail TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys +TableTests.nil_assign_doesnt_hit_indexer TableTests.ok_to_provide_a_subtype_during_construction TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr TableTests.okay_to_add_property_to_unsealed_tables_by_assignment @@ -202,7 +195,6 @@ TableTests.okay_to_add_property_to_unsealed_tables_by_function_call 
TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.open_table_unification_2 TableTests.parameter_was_set_an_indexer_and_bounded_by_another_parameter -TableTests.parameter_was_set_an_indexer_and_bounded_by_string TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 TableTests.persistent_sealed_table_is_immutable @@ -210,7 +202,6 @@ TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table TableTests.recursive_metatable_type_call -TableTests.refined_thing_can_be_an_array TableTests.right_table_missing_key2 TableTests.scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type @@ -228,12 +219,10 @@ TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors2 TableTests.table_unification_4 TableTests.table_unifies_into_map -TableTests.table_writes_introduce_write_properties TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type -TableTests.wrong_assign_does_hit_indexer ToDot.function ToString.exhaustive_toString_of_cyclic_table ToString.free_types @@ -274,6 +263,7 @@ TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_ice_when_failing_the_occurs_check TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError +TypeInfer.follow_on_new_types_in_substitution TypeInfer.globals TypeInfer.globals2 TypeInfer.infer_through_group_expr @@ -285,9 +275,10 @@ TypeInfer.type_infer_recursion_limit_no_ice 
TypeInfer.type_infer_recursion_limit_normalizer TypeInfer.unify_nearly_identical_recursive_types TypeInferAnyError.can_subscript_any -TypeInferAnyError.for_in_loop_iterator_is_error -TypeInferAnyError.for_in_loop_iterator_is_error2 -TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any +TypeInferAnyError.for_in_loop_iterator_is_any +TypeInferAnyError.for_in_loop_iterator_is_any2 +TypeInferAnyError.for_in_loop_iterator_is_any_pack +TypeInferAnyError.for_in_loop_iterator_returns_any2 TypeInferClasses.callable_classes TypeInferClasses.cannot_unify_class_instance_with_primitive TypeInferClasses.class_type_mismatch_with_name_conflict @@ -317,10 +308,8 @@ TypeInferFunctions.function_does_not_return_enough_values TypeInferFunctions.function_exprs_are_generalized_at_signature_scope_not_enclosing TypeInferFunctions.function_is_supertype_of_concrete_functions TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer -TypeInferFunctions.fuzzer_missing_follow_in_ast_stat_fun TypeInferFunctions.generic_packs_are_not_variadic TypeInferFunctions.higher_order_function_2 -TypeInferFunctions.higher_order_function_3 TypeInferFunctions.higher_order_function_4 TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors @@ -338,7 +327,6 @@ TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.other_things_are_not_related_to_function TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible_2 -TypeInferFunctions.regex_benchmark_string_format_minimization TypeInferFunctions.report_exiting_without_return_nonstrict TypeInferFunctions.return_type_by_overload TypeInferFunctions.tf_suggest_return_type @@ -370,9 +358,7 @@ TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.loop_typecheck_crash_on_empty_optional 
TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.repeat_loop -TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferLoops.while_loop -TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.require TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon @@ -396,7 +382,6 @@ TypeInferOperators.typecheck_unary_len_error TypeInferOperators.typecheck_unary_minus_error TypeInferOperators.UnknownGlobalCompoundAssign TypeInferPrimitives.CheckMethodsOfNumber -TypeInferPrimitives.string_index TypeInferUnknownNever.assign_to_local_which_is_never TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_never TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_sorta_never @@ -414,6 +399,7 @@ TypeSingletons.error_detailed_tagged_union_mismatch_string TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.singletons_stick_around_under_assignment +TypeSingletons.string_singleton_function_call TypeSingletons.table_properties_type_error_escapes TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeStatesTest.typestates_preserve_error_suppression_properties @@ -423,7 +409,6 @@ UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.less_greedy_unification_with_union_types UnionTypes.optional_arguments_table -UnionTypes.optional_length_error UnionTypes.optional_union_functions UnionTypes.optional_union_members UnionTypes.optional_union_methods From 041b8ee4e71de42e686cd787d7fe28e7036209eb Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Tue, 4 Jun 2024 22:53:01 +0200 Subject: [PATCH 13/20] Fix edge case in 'findBindingAtPosition' when looking up global binding at start of file (#1254) The 'findBindingAtPosition' AstQuery function can be used to lookup a local or 
global binding. Inside of this function is a check to "Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope". However, this check is incorrect when we are looking up a global binding at the start of a file. Consider a complete file with the contents: ```lua local x = stri|ng.char(1) ``` and we pass the location of the marker `|` as the position to the find binding position. We will pick up the global binding of the definition `string` coming from a builtin source (either defined via C++ code or a definitions file and loaded into the global scope). The global binding `string` will have a zero position: `0,0,0,0`. However, the `findBindingLocalStatement` check works by looking up the AstAncestry at the binding's defined begin position *in the current source module*. This will then incorrectly return the local statement for `local x`, as that is at the start of the source code. Then in turn, we assume we are in the `local abc = abc` case, and end up skipping over the correct binding. We fix this by checking if the binding is at the global position. If so, we early exit because it is impossible for a global binding to be defined in a local statement. --- Analysis/src/AstQuery.cpp | 6 ++++++ tests/AstQuery.test.cpp | 14 ++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index cebb226a..928e5dfb 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -12,6 +12,7 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAGVARIABLE(LuauFixBindingForGlobalPos, false); namespace Luau { @@ -332,6 +333,11 @@ std::optional findExpectedTypeAtPosition(const Module& module, const Sou static std::optional findBindingLocalStatement(const SourceModule& source, const Binding& binding) { + // Bindings coming from global sources (e.g., definition files) have a zero position. 
+ // They cannot be defined from a local statement + if (FFlag::LuauFixBindingForGlobalPos && binding.location == Location{{0, 0}, {0, 0}}) + return std::nullopt; + std::vector nodes = findAstAncestryOfPosition(source, binding.location.begin); auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) { return node->is(); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 769637a5..c53fe731 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -6,6 +6,8 @@ #include "doctest.h" #include "Fixture.h" +LUAU_FASTFLAG(LuauFixBindingForGlobalPos); + using namespace Luau; struct DocumentationSymbolFixture : BuiltinsFixture @@ -331,4 +333,16 @@ TEST_CASE_FIXTURE(Fixture, "find_expr_ancestry") CHECK(ancestry.back()->is()); } +TEST_CASE_FIXTURE(BuiltinsFixture, "find_binding_at_position_global_start_of_file") +{ + ScopedFastFlag sff{FFlag::LuauFixBindingForGlobalPos, true}; + check("local x = string.char(1)"); + const Position pos(0, 12); + + std::optional binding = findBindingAtPosition(*getMainModule(), *getMainSourceModule(), pos); + + REQUIRE(binding); + CHECK_EQ(binding->location, Location{Position{0, 0}, Position{0, 0}}); +} + TEST_SUITE_END(); From 43bf7c4e051b0d49dbb2bd3cbb2471d235da55db Mon Sep 17 00:00:00 2001 From: Jack <85714123+jackdotink@users.noreply.github.com> Date: Wed, 5 Jun 2024 09:52:30 -0500 Subject: [PATCH 14/20] implement leading bar and ampersand in types (#1286) Implements the [Leading `|` and `&` in types](https://rfcs.luau-lang.org/syntax-leading-bar-and-ampersand.html) RFC. The changes to the parser are exactly as described in the RFC. 
--------- Co-authored-by: Alexander McCord <11488393+alexmccord@users.noreply.github.com> --- Ast/src/Parser.cpp | 34 +++++++++++++++++++++++++++++----- tests/Parser.test.cpp | 23 +++++++++++++++++++++++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index e26df1fa..5ca480e8 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -17,6 +17,7 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand, false) namespace Luau { @@ -1523,7 +1524,11 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray parts(scratchType); - parts.push_back(type); + + if (!FFlag::LuauLeadingBarAndAmpersand || type != nullptr) + { + parts.push_back(type); + } incrementRecursionCounter("type annotation"); @@ -1623,15 +1628,34 @@ AstTypeOrPack Parser::parseTypeOrPack() AstType* Parser::parseType(bool inDeclarationContext) { unsigned int oldRecursionCount = recursionCounter; - // recursion counter is incremented in parseSimpleType + // recursion counter is incremented in parseSimpleType and/or parseTypeSuffix Location begin = lexer.current().location; - AstType* type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type; + if (FFlag::LuauLeadingBarAndAmpersand) + { + AstType* type = nullptr; - recursionCounter = oldRecursionCount; + Lexeme::Type c = lexer.current().type; + if (c != '|' && c != '&') + { + type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type; + recursionCounter = oldRecursionCount; + } - return parseTypeSuffix(type, begin); + AstType* typeWithSuffix = parseTypeSuffix(type, begin); + recursionCounter = oldRecursionCount; + + return typeWithSuffix; + } + else + { + AstType* type = 
parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type; + + recursionCounter = oldRecursionCount; + + return parseTypeSuffix(type, begin); + } } // Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index b178f539..6b4bcf22 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeLengthLimit); LUAU_FASTINT(LuauParseErrorLimit); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauLeadingBarAndAmpersand); namespace { @@ -3167,4 +3168,26 @@ TEST_CASE_FIXTURE(Fixture, "read_write_table_properties") LUAU_ASSERT(pr.errors.size() == 0); } +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + parse(R"(type A = | "Hello" | "World")"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + parse(R"(type A = & { string } & { number })"); +} + +TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); + matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); +} + TEST_SUITE_END(); From 23b872620394c04592a0eae81491b4d2dcf5dfab Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 7 Jun 2024 11:56:00 -0500 Subject: [PATCH 15/20] Fix incorrect comment in lgc.h (#1288) The comment gave an incorrect (reversed) version of the invariant, which could be confusing for people who haven't read the full description in 
lgc.cpp. Unfortunately this change is difficult to flag. Fixes #1282. --- VM/src/lgc.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index ba433c67..010d7e86 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -23,11 +23,10 @@ #define GCSsweep 4 /* -** macro to tell when main invariant (white objects cannot point to black -** ones) must be kept. During a collection, the sweep -** phase may break the invariant, as objects turned white may point to -** still-black objects. The invariant is restored when sweep ends and -** all objects are white again. +** The main invariant of the garbage collector, while marking objects, +** is that a black object can never point to a white one. This invariant +** is not being enforced during a sweep phase, and is restored when sweep +** ends. */ #define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic) From 81b2cc7dbe928f78c30b2a2c5629a84e014a750e Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 7 Jun 2024 12:05:50 -0500 Subject: [PATCH 16/20] tests: Adjust conformance tests to account for array invariant (#1289) These were written before compiler optimizations and array invariant. It is now impossible for t[1] to be stored in the hash part, as this would violate the array invariant that says that elements 1..#t are stored in the array. For ipairs, it doesn't traverse the hash part anymore now, so we adjust the code to make sure no elements outside of the 1..#t slice are covered. For table.find, we can use find-with-offset to still access the hash part. Fixes #1283. 
--- tests/conformance/basic.lua | 7 +++---- tests/conformance/tables.lua | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 17f4497a..98f8000e 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -726,16 +726,15 @@ assert((function() return sum end)() == 15) --- the reason why this test is interesting is that the table created here has arraysize=0 and a single hash element with key = 1.0 --- ipairs must iterate through that +-- ipairs will not iterate through hash part assert((function() - local arr = { [1] = 42 } + local arr = { [1] = 1, [42] = 42, x = 10 } local sum = 0 for i,v in ipairs(arr) do sum = sum + v end return sum -end)() == 42) +end)() == 1) -- the reason why this test is interesting is it ensures we do correct mutability analysis for locals local function chainTest(n) diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 3f1efd8e..c739f555 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -412,8 +412,8 @@ do assert(table.find({false, true}, true) == 2) - -- make sure table.find checks the hash portion as well by constructing a table literal that forces the value into the hash part - assert(table.find({[(1)] = true}, true) == 1) + -- make sure table.find checks the hash portion as well + assert(table.find({[(2)] = true}, true, 2) == 2) end -- test table.concat From 0fa6a51c914b94c47019cdf5b7c6cfc9358855fb Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 7 Jun 2024 10:51:12 -0700 Subject: [PATCH 17/20] Sync to upstream/release/629 (#1290) ### What's new * Implemented parsing logic for attributes * Added `lua_setuserdatametatable` and `lua_getuserdatametatable` C API methods for a faster userdata metatable fetch compared to `luaL_getmetatable`. Note that metatable reference has to still be pinned in memory! 
### New Solver * Further improvement to the assignment inference logic * Fix many bugs surrounding constraint dispatch order ### Native Codegen * Add IR lowering hooks for custom host userdata types * Add IR to create new tagged userdata objects * Remove outdated NativeState --- ### Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Aviral Goel Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Constraint.h | 22 +- Analysis/include/Luau/ConstraintGenerator.h | 4 + Analysis/include/Luau/ConstraintSolver.h | 14 +- Analysis/include/Luau/Type.h | 30 +- Analysis/include/Luau/Unifier2.h | 1 - Analysis/include/Luau/VisitType.h | 9 - Analysis/src/Clone.cpp | 5 - Analysis/src/Constraint.cpp | 11 +- Analysis/src/ConstraintGenerator.cpp | 211 ++++++++-- Analysis/src/ConstraintSolver.cpp | 193 ++++------ Analysis/src/EmbeddedBuiltinDefinitions.cpp | 9 +- Analysis/src/Frontend.cpp | 6 - Analysis/src/Normalize.cpp | 12 +- Analysis/src/Substitution.cpp | 2 - Analysis/src/ToDot.cpp | 8 - Analysis/src/ToString.cpp | 43 +-- Analysis/src/Transpiler.cpp | 6 +- Analysis/src/Type.cpp | 5 + Analysis/src/TypeAttach.cpp | 4 - Analysis/src/TypeFamily.cpp | 10 +- Analysis/src/Unifier2.cpp | 6 - Ast/include/Luau/Ast.h | 60 ++- Ast/include/Luau/Lexer.h | 14 +- Ast/include/Luau/Parser.h | 35 +- Ast/src/Ast.cpp | 66 +++- Ast/src/Lexer.cpp | 32 +- Ast/src/Parser.cpp | 212 ++++++++-- CodeGen/include/Luau/CodeGen.h | 55 +++ CodeGen/include/Luau/IrData.h | 12 + CodeGen/include/Luau/IrUtils.h | 5 + CodeGen/src/BytecodeAnalysis.cpp | 137 ++++++- CodeGen/src/CodeGenA64.cpp | 33 -- CodeGen/src/CodeGenA64.h | 2 - CodeGen/src/CodeGenContext.cpp | 4 +- CodeGen/src/CodeGenUtils.cpp | 15 + CodeGen/src/CodeGenUtils.h | 2 + CodeGen/src/CodeGenX64.cpp | 33 -- CodeGen/src/CodeGenX64.h | 2 - CodeGen/src/EmitCommonA64.h | 2 - CodeGen/src/EmitCommonX64.h | 1 - CodeGen/src/IrDump.cpp | 4 + 
CodeGen/src/IrLoweringA64.cpp | 132 +++++-- CodeGen/src/IrLoweringA64.h | 2 +- CodeGen/src/IrLoweringX64.cpp | 72 +++- CodeGen/src/IrLoweringX64.h | 2 +- CodeGen/src/IrTranslation.cpp | 85 +++- CodeGen/src/IrUtils.cpp | 40 ++ CodeGen/src/NativeState.cpp | 106 +---- CodeGen/src/NativeState.h | 17 +- CodeGen/src/OptimizeConstProp.cpp | 54 +++ CodeGen/src/OptimizeDeadStore.cpp | 6 + Compiler/src/Compiler.cpp | 3 +- Config/src/Config.cpp | 6 +- VM/include/lua.h | 4 + VM/src/lapi.cpp | 27 ++ VM/src/lstate.cpp | 3 + VM/src/lstate.h | 1 + tests/Conformance.test.cpp | 313 ++++++++++++++- tests/ConformanceIrHooks.h | 405 +++++++++++++++++++- tests/IrLowering.test.cpp | 357 +++++++++++++++++ tests/Lexer.test.cpp | 4 +- tests/NonStrictTypeChecker.test.cpp | 26 +- tests/Parser.test.cpp | 313 ++++++++++++++- tests/ToString.test.cpp | 4 +- tests/TypeInfer.loops.test.cpp | 12 +- tests/TypeInfer.tables.test.cpp | 10 +- tests/conformance/native_userdata.lua | 42 ++ tools/faillist.txt | 16 +- 68 files changed, 2712 insertions(+), 687 deletions(-) create mode 100644 tests/conformance/native_userdata.lua diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 77810516..3f3ad641 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -57,7 +57,7 @@ struct GeneralizationConstraint struct IterableConstraint { TypePackId iterator; - TypePackId variables; + std::vector variables; const AstNode* nextAstFragment; DenseHashMap* astForInNextTypes; @@ -192,13 +192,7 @@ struct HasIndexerConstraint TypeId indexType; }; -struct AssignConstraint -{ - TypeId lhsType; - TypeId rhsType; -}; - -// assign lhsType propName rhsType +// assignProp lhsType propName rhsType // // Assign a value of type rhsType into the named property of lhsType. @@ -212,6 +206,12 @@ struct AssignPropConstraint /// populate astTypes during constraint resolution. Nothing should ever /// block on it. 
TypeId propType; + + // When we generate constraints, we increment the remaining prop count on + // the table if we are able. This flag informs the solver as to whether or + // not it should in turn decrement the prop count when this constraint is + // dispatched. + bool decrementPropCount = false; }; struct AssignIndexConstraint @@ -226,13 +226,13 @@ struct AssignIndexConstraint TypeId propType; }; -// resultType ~ unpack sourceTypePack +// resultTypes ~ unpack sourceTypePack // // Similar to PackSubtypeConstraint, but with one important difference: If the // sourcePack is blocked, this constraint blocks. struct UnpackConstraint { - TypePackId resultPack; + std::vector resultPack; TypePackId sourcePack; }; @@ -254,7 +254,7 @@ struct ReducePackConstraint using ConstraintV = Variant; + AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 3e1861ea..b540b82f 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -118,6 +118,8 @@ struct ConstraintGenerator std::function prepareModuleScope; std::vector requireCycles; + DenseHashMap> localTypes{nullptr}; + DcrLogger* logger; ConstraintGenerator(ModulePtr module, NotNull normalizer, NotNull moduleResolver, NotNull builtinTypes, @@ -354,6 +356,8 @@ private: */ void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); + bool recordPropertyAssignment(TypeId ty); + // Record the fact that a particular local has a particular type in at least // one of its states. 
void recordInferredBinding(AstLocal* local, TypeId ty); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 58361dde..902dd15d 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -142,7 +142,6 @@ struct ConstraintSolver std::pair> tryDispatchSetIndexer( NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds); - bool tryDispatch(const AssignConstraint& c, NotNull constraint); bool tryDispatch(const AssignPropConstraint& c, NotNull constraint); bool tryDispatch(const AssignIndexConstraint& c, NotNull constraint); @@ -158,8 +157,7 @@ struct ConstraintSolver bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); // for a, ... in next_function, t, ... do - bool tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); + bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force); std::pair, std::optional> lookupTableProp(NotNull constraint, TypeId subjectType, const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false); @@ -168,14 +166,18 @@ struct ConstraintSolver /** * Generate constraints to unpack the types of srcTypes and assign each - * value to the corresponding LocalType in destTypes. + * value to the corresponding BlockedType in destTypes. * - * @param destTypes A finite TypePack comprised of LocalTypes. + * This function also overwrites the owners of each BlockedType. This is + * okay because this function is only used to decompose IterableConstraint + * into an UnpackConstraint. + * + * @param destTypes A vector of types comprised of BlockedTypes. * @param srcTypes A TypePack that represents rvalues to be assigned. 
* @returns The underlying UnpackConstraint. There's a bit of code in * iteration that needs to pass blocks on to this constraint. */ - NotNull unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull constraint); + NotNull unpackAndAssign(const std::vector destTypes, TypePackId srcTypes, NotNull constraint); void block(NotNull target, NotNull constraint); /** diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 47161886..6105ede3 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -86,24 +86,6 @@ struct FreeType TypeId upperBound = nullptr; }; -/** A type that tracks the domain of a local variable. - * - * We consider each local's domain to be the union of all types assigned to it. - * We accomplish this with LocalType. Each time we dispatch an assignment to a - * local, we accumulate this union and decrement blockCount. - * - * When blockCount reaches 0, we can consider the LocalType to be "fully baked" - * and replace it with the union we've built. - */ -struct LocalType -{ - TypeId domain; - int blockCount = 0; - - // Used for debugging - std::string name; -}; - struct GenericType { // By default, generics are global, with a synthetic name @@ -148,6 +130,7 @@ struct BlockedType Constraint* getOwner() const; void setOwner(Constraint* newOwner); + void replaceOwner(Constraint* newOwner); private: // The constraint that is intended to unblock this type. Other constraints @@ -471,6 +454,11 @@ struct TableType // Methods of this table that have an untyped self will use the same shared self type. std::optional selfTy; + + // We track the number of as-yet-unadded properties to unsealed tables. + // Some constraints will use this information to decide whether or not they + // are able to dispatch. + size_t remainingProps = 0; }; // Represents a metatable attached to a table type. Somewhat analogous to a bound type. 
@@ -669,9 +657,9 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; struct Type final { diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index 130c0c3c..bbf3a63a 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -69,7 +69,6 @@ struct Unifier2 */ bool unify(TypeId subTy, TypeId superTy); bool unifyFreeWithType(TypeId subTy, TypeId superTy); - bool unify(const LocalType* subTy, TypeId superFn); bool unify(TypeId subTy, const FunctionType* superFn); bool unify(const UnionType* subUnion, TypeId superTy); bool unify(TypeId subTy, const UnionType* superUnion); diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index 40dccbd2..ff0656d6 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -100,10 +100,6 @@ struct GenericTypeVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const LocalType& ftv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const GenericType& gtv) { return visit(ty); @@ -248,11 +244,6 @@ struct GenericTypeVisitor else visit(ty, *ftv); } - else if (auto lt = get(ty)) - { - if (visit(ty, *lt)) - traverse(lt->domain); - } else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index a96e5866..371ace2e 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -271,11 +271,6 @@ private: t->upperBound = shallowClone(t->upperBound); } - void cloneChildren(LocalType* t) - { - t->domain = shallowClone(t->domain); - } - void cloneChildren(GenericType* t) { // TOOD: clone upper bounds. 
diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index bd31beff..7b3377cb 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -81,7 +81,8 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const } else if (auto itc = get(*this)) { - rci.traverse(itc->variables); + for (TypeId ty : itc->variables) + rci.traverse(ty); // `IterableConstraints` should not mutate `iterator`. } else if (auto nc = get(*this)) @@ -106,11 +107,6 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const rci.traverse(hic->resultType); // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`. } - else if (auto ac = get(*this)) - { - rci.traverse(ac->lhsType); - rci.traverse(ac->rhsType); - } else if (auto apc = get(*this)) { rci.traverse(apc->lhsType); @@ -124,7 +120,8 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const } else if (auto uc = get(*this)) { - rci.traverse(uc->resultPack); + for (TypeId ty : uc->resultPack) + rci.traverse(ty); // `UnpackConstraint` should not mutate `sourcePack`. } else if (auto rpc = get(*this)) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 12648eb0..9d825408 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -28,6 +28,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau { @@ -246,6 +247,17 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) if (logger) logger->captureGenerationModule(module); + + for (const auto& [ty, domain] : localTypes) + { + // FIXME: This isn't the most efficient thing. 
+ TypeId domainTy = builtinTypes->neverType; + for (TypeId d : domain) + domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + + LUAU_ASSERT(get(ty)); + asMutable(ty)->ty.emplace(domainTy); + } } TypeId ConstraintGenerator::freshType(const ScopePtr& scope) @@ -310,7 +322,8 @@ std::optional ConstraintGenerator::lookup(const ScopePtr& scope, Locatio std::optional ty = lookup(scope, location, operand, /*prototype*/ false); if (!ty) { - ty = arena->addType(LocalType{builtinTypes->neverType}); + ty = arena->addType(BlockedType{}); + localTypes[*ty] = {}; rootScope->lvalueTypes[operand] = *ty; } @@ -703,7 +716,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat { const Location location = local->location; - TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->name.value}); + TypeId assignee = arena->addType(BlockedType{}); + localTypes[assignee] = {}; assignees.push_back(assignee); @@ -740,7 +754,12 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat if (hasAnnotation) { for (size_t i = 0; i < statLocal->vars.size; ++i) - addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], annotatedTypes[i]}); + { + LUAU_ASSERT(get(assignees[i])); + std::vector* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->push_back(annotatedTypes[i]); + } TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); @@ -750,15 +769,30 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat std::vector valueTypes; valueTypes.reserve(statLocal->vars.size); - for (size_t i = 0; i < statLocal->vars.size; ++i) - valueTypes.push_back(arena->addType(BlockedType{})); + auto [head, tail] = flatten(rvaluePack); - auto uc = addConstraint(scope, statLocal->location, 
UnpackConstraint{arena->addTypePack(valueTypes), rvaluePack}); + if (head.size() >= statLocal->vars.size) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + + for (TypeId t: valueTypes) + getMutable(t)->setOwner(uc); + } for (size_t i = 0; i < statLocal->vars.size; ++i) { - getMutable(valueTypes[i])->setOwner(uc); - addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], valueTypes[i]}); + LUAU_ASSERT(get(assignees[i])); + std::vector* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->push_back(valueTypes[i]); } } @@ -860,25 +894,34 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI for (AstLocal* var : forIn->vars) { - TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, var->name.value}); + TypeId assignee = arena->addType(BlockedType{}); variableTypes.push_back(assignee); + TypeId loopVar = arena->addType(BlockedType{}); + localTypes[loopVar].push_back(assignee); + if (var->annotation) { TypeId annotationTy = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false); loopScope->bindings[var] = Binding{annotationTy, var->location}; - addConstraint(scope, var->location, SubtypeConstraint{assignee, annotationTy}); + addConstraint(scope, var->location, SubtypeConstraint{loopVar, annotationTy}); } else - loopScope->bindings[var] = Binding{assignee, var->location}; + loopScope->bindings[var] = Binding{loopVar, var->location}; DefId def = dfg->getDef(var); - loopScope->lvalueTypes[def] = assignee; + loopScope->lvalueTypes[def] = loopVar; } - TypePackId variablePack = arena->addTypePack(std::move(variableTypes)); auto iterable = addConstraint( - loopScope, getLocation(forIn->values), 
IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astForInNextTypes}); + loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes}); + + for (TypeId var: variableTypes) + { + auto bt = getMutable(var); + LUAU_ASSERT(bt); + bt->setOwner(iterable); + } Checkpoint start = checkpoint(this); visit(loopScope, forIn->body); @@ -1105,14 +1148,31 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass std::vector valueTypes; valueTypes.reserve(assign->vars.size); - for (size_t i = 0; i < assign->vars.size; ++i) - valueTypes.push_back(arena->addType(BlockedType{})); + auto [head, tail] = flatten(resultPack); + if (head.size() >= assign->vars.size) + { + // If the resultPack is definitely long enough for each variable, we can + // skip the UnpackConstraint and use the result types directly. - auto uc = addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(valueTypes), resultPack}); + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + // We're not sure how many types are produced by the right-side + // expressions. We'll use an UnpackConstraint to defer this until + // later. 
+ for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, assign->location, UnpackConstraint{valueTypes, resultPack}); + + for (TypeId t: valueTypes) + getMutable(t)->setOwner(uc); + } for (size_t i = 0; i < assign->vars.size; ++i) { - getMutable(valueTypes[i])->setOwner(uc); visitLValue(scope, assign->vars.data[i], valueTypes[i]); } @@ -1393,7 +1453,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunc TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); FunctionType* ftv = getMutable(fnType); - ftv->isCheckedFunction = global->checkedFunction; + ftv->isCheckedFunction = FFlag::LuauAttributeSyntax ? global->isCheckedFunction() : false; ftv->argNames.reserve(global->paramNames.size); for (const auto& el : global->paramNames) @@ -1599,9 +1659,8 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* mt = arena->addType(BlockedType{}); unpackedTypes.emplace_back(mt); - TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes)); - auto c = addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail}); + auto c = addConstraint(scope, call->location, UnpackConstraint{unpackedTypes, *argTail}); getMutable(mt)->setOwner(c); if (auto b = getMutable(target); b && b->getOwner() == nullptr) b->setOwner(c); @@ -1842,7 +1901,37 @@ Inference ConstraintGenerator::checkIndexName( const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) { TypeId obj = check(scope, indexee).ty; - TypeId result = arena->addType(BlockedType{}); + TypeId result = nullptr; + + // We optimize away the HasProp constraint in simple cases so that we can + // reason about updates to unsealed tables more accurately. 
+ + const TableType* tt = getTableType(obj); + + // This is a little bit iffy but I *believe* it is okay because, if the + // local's domain is going to be extended at all, it will be someplace after + // the current lexical position within the script. + if (!tt) + { + if (auto localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) + tt = getTableType(localDomain->front()); + } + + if (tt) + { + auto it = tt->props.find(index); + if (it != tt->props.end() && it->second.readTy.has_value()) + result = *it->second.readTy; + } + + if (!result) + { + result = arena->addType(BlockedType{}); + + auto c = addConstraint( + scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}); + getMutable(result)->setOwner(c); + } if (key) { @@ -1852,10 +1941,6 @@ Inference ConstraintGenerator::checkIndexName( scope->rvalueRefinements[key->def] = result; } - auto c = - addConstraint(scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}); - getMutable(result)->setOwner(c); - if (key) return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)}; else @@ -2242,18 +2327,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (ty) { - if (auto lt = getMutable(*ty)) - ++lt->blockCount; - else if (auto ut = getMutable(*ty)) - { - for (TypeId optTy : ut->options) - if (auto lt = getMutable(optTy)) - ++lt->blockCount; - } + std::vector* localDomain = localTypes.find(*ty); + if (localDomain) + localDomain->push_back(rhsType); } else { - ty = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->local->name.value}); + ty = arena->addType(BlockedType{}); + localTypes[*ty].push_back(rhsType); if (annotatedTy) { @@ -2277,7 +2358,9 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (annotatedTy) addConstraint(scope, 
local->location, SubtypeConstraint{rhsType, *annotatedTy}); - addConstraint(scope, local->location, AssignConstraint{*ty, rhsType}); + + if (auto localDomain = localTypes.find(*ty)) + localDomain->push_back(rhsType); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) @@ -2289,7 +2372,6 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob rootScope->lvalueTypes[def] = rhsType; addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); - addConstraint(scope, global->location, AssignConstraint{*annotatedTy, rhsType}); } } @@ -2298,7 +2380,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* e TypeId lhsTy = check(scope, expr->expr).ty; TypeId propTy = arena->addType(BlockedType{}); module->astTypes[expr] = propTy; - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy}); + + bool incremented = recordPropertyAssignment(lhsTy); + + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented}); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) @@ -2310,7 +2395,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e module->astTypes[expr] = propTy; module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist. 
std::string propName{constantString->value.data, constantString->value.size}; - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy}); + + bool incremented = recordPropertyAssignment(lhsTy); + + addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, incremented}); return; } @@ -2775,7 +2863,7 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool // TODO: FunctionType needs a pointer to the scope so that we know // how to quantify/instantiate it. FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; - ftv.isCheckedFunction = fn->checkedFunction; + ftv.isCheckedFunction = FFlag::LuauAttributeSyntax ? fn->isCheckedFunction() : false; // This replicates the behavior of the appropriate FunctionType // constructors. @@ -2977,8 +3065,7 @@ Inference ConstraintGenerator::flattenPack(const ScopePtr& scope, Location locat return Inference{*f, refinement}; TypeId typeResult = arena->addType(BlockedType{}); - TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); - auto c = addConstraint(scope, location, UnpackConstraint{resultPack, tp}); + auto c = addConstraint(scope, location, UnpackConstraint{{typeResult}, tp}); getMutable(typeResult)->setOwner(c); return Inference{typeResult, refinement}; @@ -3075,6 +3162,46 @@ void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, As program->visit(&gp); } +bool ConstraintGenerator::recordPropertyAssignment(TypeId ty) +{ + DenseHashSet seen{nullptr}; + VecDeque queue; + + queue.push_back(ty); + + bool incremented = false; + + while (!queue.empty()) + { + const TypeId t = follow(queue.front()); + queue.pop_front(); + + if (seen.find(t)) + continue; + seen.insert(t); + + if (auto tt = getMutable(t); tt && tt->state == TableState::Unsealed) + { + tt->remainingProps += 1; + incremented = true; + } + else if (auto mt = get(t)) + 
queue.push_back(mt->table); + else if (auto localDomain = localTypes.find(t)) + { + for (TypeId domainTy : *localDomain) + queue.push_back(domainTy); + } + else if (auto ut = get(t)) + { + for (TypeId part : ut) + queue.push_back(part); + } + } + + return incremented; +} + void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty) { if (InferredBinding* ib = inferredBindings.find(local)) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index b0f27911..07fc26fb 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -532,8 +532,6 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*hpc, constraint); else if (auto spc = get(*constraint)) success = tryDispatch(*spc, constraint); - else if (auto uc = get(*constraint)) - success = tryDispatch(*uc, constraint); else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); else if (auto uc = get(*constraint)) @@ -686,7 +684,8 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullanyTypePack, c.variables); + for (TypeId ty : c.variables) + unify(constraint, builtinTypes->errorRecoveryType(), ty); return true; } @@ -696,21 +695,35 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope); TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); - TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); - getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; + TypeId tableTy = arena->addType(TableType{ + TableType::Props{}, + TableIndexer{keyTy, valueTy}, + TypeLevel{}, + constraint->scope, + TableState::Free + }); - pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{nextTy, tableTy}); + unify(constraint, nextTy, tableTy); auto it = begin(c.variables); auto endIt = end(c.variables); if (it != endIt) { - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, 
keyTy}); + bindBlockedType(*it, keyTy, keyTy, constraint); ++it; } if (it != endIt) - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); + { + bindBlockedType(*it, valueTy, valueTy, constraint); + ++it; + } + + while (it != endIt) + { + bindBlockedType(*it, builtinTypes->nilType, builtinTypes->nilType, constraint); + ++it; + } return true; } @@ -721,11 +734,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull= 2) tableTy = iterator.head[1]; - TypeId firstIndexTy = builtinTypes->nilType; - if (iterator.head.size() >= 3) - firstIndexTy = iterator.head[2]; - - return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force); + return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force); } else @@ -1310,6 +1319,14 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(subjectType) || get(subjectType)) return block(subjectType, constraint); + if (const TableType* subjectTable = getTableType(subjectType)) + { + if (subjectTable->state == TableState::Unsealed && subjectTable->remainingProps > 0 && subjectTable->props.count(c.prop) == 0) + { + return block(subjectType, constraint); + } + } + auto [blocked, result] = lookupTableProp(constraint, subjectType, c.prop, c.context, c.inConditional, c.suppressSimplification); if (!blocked.empty()) { @@ -1517,7 +1534,10 @@ bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull seen{nullptr}; - return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); + bool ok = tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); + if (ok) + unblock(c.resultType, constraint->location); + return ok; } std::pair> ConstraintSolver::tryDispatchSetIndexer( @@ -1596,46 +1616,6 @@ std::pair> ConstraintSolver::tryDispatchSetIndexer( return {true, std::nullopt}; } -bool ConstraintSolver::tryDispatch(const AssignConstraint& c, NotNull 
constraint) -{ - const TypeId lhsTy = follow(c.lhsType); - const TypeId rhsTy = follow(c.rhsType); - - if (!get(lhsTy) && isBlocked(lhsTy)) - return block(lhsTy, constraint); - - auto tryExpand = [&](TypeId ty) { - LocalType* lt = getMutable(ty); - if (!lt) - return; - - lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, rhsTy).result; - LUAU_ASSERT(lt->blockCount > 0); - --lt->blockCount; - - if (0 == lt->blockCount) - { - shiftReferences(ty, lt->domain); - emplaceType(asMutable(ty), lt->domain); - } - }; - - if (auto ut = get(lhsTy)) - { - // FIXME: I suspect there's a bug here where lhsTy is a union that contains no LocalTypes. - for (TypeId t : ut) - tryExpand(t); - } - else if (get(lhsTy)) - tryExpand(lhsTy); - else - unify(constraint, rhsTy, lhsTy); - - unblock(lhsTy, constraint->location); - - return true; -} - bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull constraint) { TypeId lhsType = follow(c.lhsType); @@ -1753,6 +1733,14 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull(asMutable(c.propType), rhsType); lhsTable->props[propName] = Property::rw(rhsType); + + if (lhsTable->state == TableState::Unsealed && c.decrementPropCount) + { + LUAU_ASSERT(lhsTable->remainingProps > 0); + lhsTable->remainingProps -= 1; + unblock(lhsType, constraint->location); + } + return true; } } @@ -1927,24 +1915,14 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) { TypePackId sourcePack = follow(c.sourcePack); - TypePackId resultPack = follow(c.resultPack); if (isBlocked(sourcePack)) return block(sourcePack, constraint); - if (isBlocked(resultPack)) - { - LUAU_ASSERT(canMutate(resultPack, constraint)); - LUAU_ASSERT(resultPack != sourcePack); - emplaceTypePack(asMutable(resultPack), sourcePack); - unblock(resultPack, constraint->location); - return true; - } + TypePack srcPack = extendTypePack(*arena, builtinTypes, 
sourcePack, c.resultPack.size()); - TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); - - auto resultIter = begin(resultPack); - auto resultEnd = end(resultPack); + auto resultIter = begin(c.resultPack); + auto resultEnd = end(c.resultPack); size_t i = 0; while (resultIter != resultEnd) @@ -2080,18 +2058,22 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl auto endIt = end(c.variables); if (it != endIt) { - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy}); + bindBlockedType(*it, keyTy, keyTy, constraint); ++it; } if (it != endIt) - pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); + bindBlockedType(*it, valueTy, valueTy, constraint); return true; } auto unpack = [&](TypeId ty) { - for (TypeId varTy : c.variables) - pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, ty}); + for (TypeId varTy: c.variables) + { + LUAU_ASSERT(get(varTy)); + LUAU_ASSERT(varTy != ty); + bindBlockedType(varTy, ty, ty, constraint); + } }; if (get(iteratorTy)) @@ -2129,27 +2111,18 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (iteratorTable->indexer) { - TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); - unify(constraint, c.variables, expectedVariablePack); + std::vector expectedVariables{iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}; + while (c.variables.size() >= expectedVariables.size()) + expectedVariables.push_back(builtinTypes->errorRecoveryType()); - auto [variableTys, variablesTail] = flatten(c.variables); - - // the local types for the indexer _should_ be all set after unification - for (TypeId ty : variableTys) + for (size_t i = 0; i < c.variables.size(); ++i) { - if (auto lt = getMutable(ty)) - { - LUAU_ASSERT(lt->blockCount > 0); - --lt->blockCount; + 
LUAU_ASSERT(c.variables[i] != expectedVariables[i]); - LUAU_ASSERT(0 <= lt->blockCount); + unify(constraint, c.variables[i], expectedVariables[i]); - if (0 == lt->blockCount) - { - shiftReferences(ty, lt->domain); - emplaceType(asMutable(ty), lt->domain); - } - } + bindBlockedType(c.variables[i], expectedVariables[i], expectedVariables[i], constraint); + unblock(c.variables[i], constraint->location); } } else @@ -2213,26 +2186,16 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl else if (auto primitiveTy = get(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table) unpack(builtinTypes->unknownType); else + { unpack(builtinTypes->errorType); + } return true; } bool ConstraintSolver::tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force) + TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force) { - // We need to know whether or not this type is nil or not. - // If we don't know, block and reschedule ourselves. - firstIndexTy = follow(firstIndexTy); - if (get(firstIndexTy)) - { - if (force) - LUAU_ASSERT(false); - else - block(firstIndexTy, constraint); - return false; - } - const FunctionType* nextFn = get(nextTy); // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. 
LUAU_ASSERT(nextFn); @@ -2267,27 +2230,18 @@ bool ConstraintSolver::tryDispatchIterableFunction( return true; } -NotNull ConstraintSolver::unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull constraint) +NotNull ConstraintSolver::unpackAndAssign(const std::vector destTypes, TypePackId srcTypes, NotNull constraint) { - std::vector unpackedTys; - for (TypeId _ty : destTypes) + auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes}); + + for (TypeId t: destTypes) { - (void) _ty; - unpackedTys.push_back(arena->addType(BlockedType{})); + BlockedType* bt = getMutable(t); + LUAU_ASSERT(bt); + bt->replaceOwner(c); } - TypePackId unpackedTp = arena->addTypePack(TypePack{unpackedTys}); - auto unpackConstraint = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{unpackedTp, srcTypes}); - - size_t i = 0; - for (TypeId varTy : destTypes) - { - pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, unpackedTys[i]}); - getMutable(unpackedTys[i])->setOwner(unpackConstraint); - ++i; - } - - return unpackConstraint; + return c; } std::pair, std::optional> ConstraintSolver::lookupTableProp(NotNull constraint, TypeId subjectType, @@ -2808,9 +2762,6 @@ bool ConstraintSolver::isBlocked(TypeId ty) { ty = follow(ty); - if (auto lt = get(ty)) - return lt->blockCount > 0; - if (auto tfit = get(ty)) return uninhabitedTypeFamilies.contains(ty) == false; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 78b76a78..91d8006a 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,6 +2,7 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false); +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau { @@ -319,9 +320,9 @@ declare os: { clock: () -> number, } -declare function @checked require(target: any): any +@checked declare function 
require(target: any): any -declare function @checked getfenv(target: any): { [string]: any } +@checked declare function getfenv(target: any): { [string]: any } declare _G: any declare _VERSION: string @@ -363,7 +364,7 @@ declare function select(i: string | number, ...: A...): ...any -- (nil, string). declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) -declare function @checked newproxy(mt: boolean?): any +@checked declare function newproxy(mt: boolean?): any declare coroutine: { create: (f: (A...) -> R...) -> thread, @@ -451,7 +452,7 @@ std::string getBuiltinDefinitionSource() std::string result = kBuiltinDefinitionLuaSrc; // Annotates each non generic function as checked - if (FFlag::LuauCheckedEmbeddedDefinitions2) + if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauAttributeSyntax) result = kBuiltinDefinitionLuaSrcChecked; return result; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5261c211..7823f3d4 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1196,12 +1196,6 @@ struct InternalTypeFinder : TypeOnceVisitor return false; } - bool visit(TypeId, const LocalType&) override - { - LUAU_ASSERT(false); - return false; - } - bool visit(TypePackId, const BlockedTypePack&) override { LUAU_ASSERT(false); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 5b14fd5f..7ce50284 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1815,12 +1815,6 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t if (!isCacheable(there)) here.isCacheable = false; } - else if (auto lt = get(there)) - { - // FIXME? This is somewhat questionable. - // Maybe we should assert because this should never happen? 
- unionNormalWithTy(here, lt->domain, seenSetTypes, ignoreSmallerTyvars); - } else if (get(there)) unionFunctionsWithFunction(here.functions, there); else if (get(there) || get(there)) @@ -3095,7 +3089,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type return NormalizationResult::True; } else if (get(there) || get(there) || get(there) || get(there) || - get(there) || get(there)) + get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; @@ -3104,10 +3098,6 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type here.isCacheable = false; return intersectNormals(here, thereNorm); } - else if (auto lt = get(there)) - { - return intersectNormalWithTy(here, lt->domain, seenSetTypes); - } NormalizedTyvars tyvars = std::move(here.tyvars); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index bc899798..ea9c3178 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -24,8 +24,6 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a // We decline to copy them. if constexpr (std::is_same_v) return ty; - else if constexpr (std::is_same_v) - return ty; else if constexpr (std::is_same_v) { // This should never happen, but visit() cannot see it. 
diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 9093b38a..17b595b1 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -262,14 +262,6 @@ void StateDot::visitChildren(TypeId ty, int index) visitChild(t.upperBound, index, "[upperBound]"); } } - else if constexpr (std::is_same_v) - { - formatAppend(result, "LocalType"); - finishNodeLabel(ty); - finishNode(); - - visitChild(t.domain, 1, "[domain]"); - } else if constexpr (std::is_same_v) { formatAppend(result, "AnyType %d", index); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 4e81a870..dca041a2 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -100,16 +100,6 @@ struct FindCyclicTypes final : TypeVisitor return false; } - bool visit(TypeId ty, const LocalType& lt) override - { - if (!visited.insert(ty)) - return false; - - traverse(lt.domain); - - return false; - } - bool visit(TypeId ty, const TableType& ttv) override { if (!visited.insert(ty)) @@ -525,21 +515,6 @@ struct TypeStringifier } } - void operator()(TypeId ty, const LocalType& lt) - { - state.emit("l-"); - state.emit(lt.name); - if (FInt::DebugLuauVerboseTypeNames >= 1) - { - state.emit("["); - state.emit(lt.blockCount); - state.emit("]"); - } - state.emit("=["); - stringify(lt.domain); - state.emit("]"); - } - void operator()(TypeId, const BoundType& btv) { stringify(btv.boundTo); @@ -1724,6 +1699,18 @@ std::string generateName(size_t i) return n; } +std::string toStringVector(const std::vector& types, ToStringOptions& opts) +{ + std::string s; + for (TypeId ty : types) + { + if (!s.empty()) + s += ", "; + s += toString(ty, opts); + } + return s; +} + std::string toString(const Constraint& constraint, ToStringOptions& opts) { auto go = [&opts](auto&& c) -> std::string { @@ -1754,7 +1741,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) else if constexpr (std::is_same_v) { std::string iteratorStr = tos(c.iterator); - std::string variableStr = 
tos(c.variables); + std::string variableStr = toStringVector(c.variables, opts); return variableStr + " ~ iterate " + iteratorStr; } @@ -1791,14 +1778,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType); } - else if constexpr (std::is_same_v) - return "assign " + tos(c.lhsType) + " " + tos(c.rhsType); else if constexpr (std::is_same_v) return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType); else if constexpr (std::is_same_v) return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType); else if constexpr (std::is_same_v) - return tos(c.resultPack) + " ~ ...unpack " + tos(c.sourcePack); + return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack); else if constexpr (std::is_same_v) return "reduce " + tos(c.ty); else if constexpr (std::is_same_v) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 85b8849f..d78bf157 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -1182,11 +1182,11 @@ std::string toString(AstNode* node) Printer printer(writer); printer.writeTypes = true; - if (auto statNode = dynamic_cast(node)) + if (auto statNode = node->asStat()) printer.visualize(*statNode); - else if (auto exprNode = dynamic_cast(node)) + else if (auto exprNode = node->asExpr()) printer.visualize(*exprNode); - else if (auto typeNode = dynamic_cast(node)) + else if (auto typeNode = node->asType()) printer.visualizeTypeAnnotation(*typeNode); return writer.str(); diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index b7a54e3d..71cac6fd 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -561,6 +561,11 @@ void BlockedType::setOwner(Constraint* newOwner) owner = newOwner; } +void BlockedType::replaceOwner(Constraint* newOwner) +{ + owner = newOwner; +} + PendingExpansionType::PendingExpansionType( std::optional prefix, 
AstName name, std::vector typeArguments, std::vector packArguments) : prefix(prefix) diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f1fe83ee..c0294fc9 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -338,10 +338,6 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName("free"), std::nullopt, Location()); } - AstType* operator()(const LocalType& lt) - { - return Luau::visit(*this, lt.domain->ty); - } AstType* operator()(const UnionType& uv) { AstArray unionTypes; diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 3a0483a6..89de1912 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -447,7 +447,7 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati bool isPending(TypeId ty, ConstraintSolver* solver) { - return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); + return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); } template @@ -567,7 +567,7 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, const std::vector // check to see if the operand type is resolved enough, and wait to reduce if not // the use of `typeFromNormal` later necessitates blocking on local types. - if (isPending(operandTy, ctx->solver) || get(operandTy)) + if (isPending(operandTy, ctx->solver)) return {std::nullopt, false, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. 
@@ -1427,12 +1427,6 @@ struct FindRefinementBlockers : TypeOnceVisitor return false; } - bool visit(TypeId ty, const LocalType&) override - { - found.insert(ty); - return false; - } - bool visit(TypeId ty, const ClassType&) override { return false; diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index c8db5335..f46c3372 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -158,12 +158,6 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy) if (subFree || superFree) return true; - if (auto subLocal = getMutable(subTy)) - { - subLocal->domain = mkUnion(subLocal->domain, superTy); - expandedFreeTypes[subTy].push_back(superTy); - } - auto subFn = get(subTy); auto superFn = get(superTy); if (subFn && superFn) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 993116d6..e8479e09 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -60,6 +60,8 @@ class AstStat; class AstStatBlock; class AstExpr; class AstTypePack; +class AstAttr; +class AstExprTable; struct AstLocal { @@ -172,6 +174,10 @@ public: { return nullptr; } + virtual AstAttr* asAttr() + { + return nullptr; + } template bool is() const @@ -193,6 +199,28 @@ public: Location location; }; +class AstAttr : public AstNode +{ +public: + LUAU_RTTI(AstAttr) + + enum Type + { + Checked, + }; + + AstAttr(const Location& location, Type type); + + AstAttr* asAttr() override + { + return this; + } + + void visit(AstVisitor* visitor) override; + + Type type; +}; + class AstExpr : public AstNode { public: @@ -384,13 +412,15 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, + AstExprFunction(const 
Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, + const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname, + const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, const std::optional& argLocation = std::nullopt); void visit(AstVisitor* visitor) override; + AstArray attributes; AstArray generics; AstArray genericPacks; AstLocal* self; @@ -810,20 +840,22 @@ public: const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes); - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes, bool checkedFunction); + AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, + const AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, + const AstArray& paramNames, const AstTypeList& retTypes); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstName name; AstArray generics; AstArray genericPacks; AstTypeList params; AstArray paramNames; AstTypeList retTypes; - bool checkedFunction; }; struct AstDeclaredClassProp @@ -936,17 +968,20 @@ public: AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes, bool checkedFunction); + AstTypeFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, const 
AstTypeList& argTypes, const AstArray>& argNames, + const AstTypeList& returnTypes); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstArray generics; AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; - bool checkedFunction; }; class AstTypeTypeof : public AstType @@ -1105,6 +1140,11 @@ public: return true; } + virtual bool visit(class AstAttr* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExpr* node) { return visit(static_cast(node)); diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index e111030d..f6ac28ad 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -87,6 +87,8 @@ struct Lexeme Comment, BlockComment, + Attribute, + BrokenString, BrokenComment, BrokenUnicode, @@ -115,14 +117,20 @@ struct Lexeme ReservedTrue, ReservedUntil, ReservedWhile, - ReservedChecked, Reserved_END }; Type type; Location location; + + // Field declared here, before the union, to ensure that Lexeme size is 32 bytes. +private: + // length is used to extract a slice from the input buffer. + // This field is only valid for certain lexeme types which don't duplicate portions of input + // but instead store a pointer to a location in the input buffer and the length of lexeme. 
unsigned int length; +public: union { const char* data; // String, Number, Comment @@ -135,9 +143,13 @@ struct Lexeme Lexeme(const Location& location, Type type, const char* data, size_t size); Lexeme(const Location& location, Type type, const char* name); + unsigned int getLength() const; + std::string toString() const; }; +static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes."); + class AstNameTable { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index e97df66b..c1fd43ea 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -82,8 +82,8 @@ private: // if exp then block {elseif exp then block} [else block] end | // for Name `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | - // function funcname funcbody | - // local function Name funcbody | + // [attributes] function funcname funcbody | + // [attributes] local function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* parseStat(); @@ -114,11 +114,25 @@ private: AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); // function funcname funcbody - AstStat* parseFunctionStat(); + LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray& attributes = {nullptr, 0}); + + std::pair validateAttribute(const char* attributeName, const TempVector& attributes); + + // attribute ::= '@' NAME + void parseAttribute(TempVector& attribute); + + // attributes ::= {attribute} + AstArray parseAttributes(); + + // attributes local function Name funcbody + // attributes function funcname funcbody + // attributes `declare function' Name`(' [parlist] `)' [`:` Type] + // declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' + AstStat* parseAttributeStat(); // local function Name funcbody | // local namelist [`=' explist] - AstStat* parseLocal(); + AstStat* parseLocal(const AstArray& attributes); // return [explist] AstStat* 
parseReturn(); @@ -130,7 +144,7 @@ private: // `declare global' Name: Type | // `declare function' Name`(' [parlist] `)' [`:` Type] - AstStat* parseDeclaration(const Location& start); + AstStat* parseDeclaration(const Location& start, const AstArray& attributes); // varlist `=' explist AstStat* parseAssignment(AstExpr* initial); @@ -143,7 +157,7 @@ private: // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray& attributes); // explist ::= {exp `,'} exp void parseExprList(TempVector& result); @@ -176,10 +190,10 @@ private: AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional accessLocation); - AstTypeOrPack parseFunctionType(bool allowPack, bool isCheckedFunction = false); - AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation, - bool isCheckedFunction = false); + AstTypeOrPack parseFunctionType(bool allowPack, const AstArray& attributes); + AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray& attributes, AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, + AstTypePack* varargAnnotation); AstType* parseTableType(bool inDeclarationContext = false); AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false); @@ -393,6 +407,7 @@ private: std::vector matchRecoveryStopOnToken; + std::vector scratchAttr; std::vector scratchStat; std::vector> scratchString; std::vector scratchExpr; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index bb82e0be..4c956307 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -3,6 +3,7 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau 
{ @@ -16,6 +17,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list) list.tailType->visit(visitor); } +AstAttr::AstAttr(const Location& location, Type type) + : AstNode(ClassIndex(), location) + , type(type) +{ +} + +void AstAttr::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + int gAstRttiIndex = 0; AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) @@ -161,11 +173,12 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, - const std::optional& argLocation) +AstExprFunction::AstExprFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, + AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional& returnAnnotation, + AstTypePack* varargAnnotation, const std::optional& argLocation) : AstExpr(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , self(self) @@ -696,27 +709,27 @@ AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const A const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) + , attributes() , name(name) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) , retTypes(retTypes) - , checkedFunction(false) { } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const 
AstArray& paramNames, - const AstTypeList& retTypes, bool checkedFunction) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, + const AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, + const AstArray& paramNames, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) + , attributes(attributes) , name(name) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) , retTypes(retTypes) - , checkedFunction(checkedFunction) { } @@ -729,6 +742,19 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor) } } +bool AstStatDeclareFunction::isCheckedFunction() const +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props, AstTableIndexer* indexer) : AstStat(ClassIndex(), location) @@ -820,25 +846,26 @@ void AstTypeTable::visit(AstVisitor* visitor) AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) : AstType(ClassIndex(), location) + , attributes() , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) , argNames(argNames) , returnTypes(returnTypes) - , checkedFunction(false) { LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes, bool checkedFunction) +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& 
genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, + const AstTypeList& returnTypes) : AstType(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) , argNames(argNames) , returnTypes(returnTypes) - , checkedFunction(checkedFunction) { LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); } @@ -852,6 +879,19 @@ void AstTypeFunction::visit(AstVisitor* visitor) } } +bool AstTypeFunction::isCheckedFunction() const +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr) : AstType(ClassIndex(), location) , expr(expr) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 71577459..8e9b3be9 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) +LUAU_FASTFLAGVARIABLE(LuauAttributeSyntax, false) namespace Luau { @@ -102,11 +103,19 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name) , length(0) , name(name) { - LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); + LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); +} + +unsigned int Lexeme::getLength() const +{ + LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); + + return length; } static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", - "repeat", "return", "then", "true", "until", "while", "@checked"}; + "repeat", 
"return", "then", "true", "until", "while"}; std::string Lexeme::toString() const { @@ -191,6 +200,10 @@ std::string Lexeme::toString() const case Comment: return "comment"; + case Attribute: + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + return name ? format("'%s'", name) : "attribute"; + case BrokenString: return "malformed string"; @@ -278,7 +291,7 @@ std::pair AstNameTable::getOrAddWithType(const char* name nameData[length] = 0; const_cast(entry).value = AstName(nameData); - const_cast(entry).type = Lexeme::Name; + const_cast(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name); return std::make_pair(entry.value, entry.type); } @@ -994,14 +1007,11 @@ Lexeme Lexer::readNext() } case '@': { - // We're trying to lex the token @checked - LUAU_ASSERT(peekch() == '@'); - - std::pair maybeChecked = readName(); - if (maybeChecked.second != Lexeme::ReservedChecked) - return Lexeme(Location(start, position()), Lexeme::Error); - - return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value); + if (FFlag::LuauAttributeSyntax) + { + std::pair attribute = readName(); + return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value); + } } default: if (isDigit(peekch())) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 5ca480e8..d80878d5 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -17,11 +17,20 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. 
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAG(LuauAttributeSyntax) LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand, false) namespace Luau { +struct AttributeEntry +{ + const char* name; + AstAttr::Type type; +}; + +AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {nullptr, AstAttr::Type::Checked}}; + ParseError::ParseError(const Location& location, const std::string& message) : location(location) , message(message) @@ -280,7 +289,9 @@ AstStatBlock* Parser::parseBlockNoScope() // for binding `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | // function funcname funcbody | +// attributes function funcname funcbody | // local function Name funcbody | +// local attributes function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* Parser::parseStat() @@ -299,13 +310,16 @@ AstStat* Parser::parseStat() case Lexeme::ReservedRepeat: return parseRepeat(); case Lexeme::ReservedFunction: - return parseFunctionStat(); + return parseFunctionStat(AstArray({nullptr, 0})); case Lexeme::ReservedLocal: - return parseLocal(); + return parseLocal(AstArray({nullptr, 0})); case Lexeme::ReservedReturn: return parseReturn(); case Lexeme::ReservedBreak: return parseBreak(); + case Lexeme::Attribute: + if (FFlag::LuauAttributeSyntax) + return parseAttributeStat(); default:; } @@ -343,7 +357,7 @@ AstStat* Parser::parseStat() if (options.allowDeclarationSyntax) { if (ident == "declare") - return parseDeclaration(expr->location); + return parseDeclaration(expr->location, AstArray({nullptr, 0})); } // skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop) @@ -652,7 +666,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug } // function funcname funcbody -AstStat* Parser::parseFunctionStat() +AstStat* Parser::parseFunctionStat(const AstArray& attributes) { Location start = 
lexer.current().location; @@ -665,16 +679,125 @@ AstStat* Parser::parseFunctionStat() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; + AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; return allocator.alloc(Location(start, body->location), expr, body); } + +std::pair Parser::validateAttribute(const char* attributeName, const TempVector& attributes) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + AstAttr::Type type; + + // check if the attribute name is valid + + bool found = false; + + for (int i = 0; kAttributeEntries[i].name; ++i) + { + found = !strcmp(attributeName, kAttributeEntries[i].name); + if (found) + { + type = kAttributeEntries[i].type; + break; + } + } + + if (!found) + { + if (strlen(attributeName) == 1) + report(lexer.current().location, "Attribute name is missing"); + else + report(lexer.current().location, "Invalid attribute '%s'", attributeName); + } + else + { + // check that attribute is not duplicated + for (const AstAttr* attr : attributes) + { + if (attr->type == type) + { + report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName); + } + } + } + + return {found, type}; +} + +// attribute ::= '@' NAME +void Parser::parseAttribute(TempVector& attributes) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute); + + Location loc = lexer.current().location; + + const char* name = lexer.current().name; + const auto [found, type] = validateAttribute(name, attributes); + + nextLexeme(); + + if (found) + attributes.push_back(allocator.alloc(loc, type)); +} + +// attributes ::= {attribute} +AstArray Parser::parseAttributes() +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + Lexeme::Type type = lexer.current().type; + + LUAU_ASSERT(type == Lexeme::Attribute); + + TempVector 
attributes(scratchAttr); + + while (lexer.current().type == Lexeme::Attribute) + parseAttribute(attributes); + + return copy(attributes); +} + +// attributes local function Name funcbody +// attributes function funcname funcbody +// attributes `declare function' Name`(' [parlist] `)' [`:` Type] +// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' +AstStat* Parser::parseAttributeStat() +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + AstArray attributes = Parser::parseAttributes(); + + Lexeme::Type type = lexer.current().type; + + switch (type) + { + case Lexeme::Type::ReservedFunction: + return parseFunctionStat(attributes); + case Lexeme::Type::ReservedLocal: + return parseLocal(attributes); + case Lexeme::Type::Name: + if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data)) + { + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); + return parseDeclaration(expr->location, attributes); + } + default: + return reportStatError(lexer.current().location, {}, {}, + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s intead", + lexer.current().toString().c_str()); + } +} + // local function Name funcbody | // local bindinglist [`=' explist] -AstStat* Parser::parseLocal() +AstStat* Parser::parseLocal(const AstArray& attributes) { Location start = lexer.current().location; @@ -694,7 +817,7 @@ AstStat* Parser::parseLocal() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name); + auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes); matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -704,6 +827,12 @@ AstStat* Parser::parseLocal() } else { + if (FFlag::LuauAttributeSyntax && attributes.size != 0) + { + return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s intead", + 
lexer.current().toString().c_str()); + } + matchRecoveryStopOnToken['=']++; TempVector names(scratchBinding); @@ -831,18 +960,17 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() return AstDeclaredClassProp{fnName.name, fnType, true}; } -AstStat* Parser::parseDeclaration(const Location& start) +AstStat* Parser::parseDeclaration(const Location& start, const AstArray& attributes) { // `declare` token is already parsed at this point + + if (FFlag::LuauAttributeSyntax && (attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction)) + return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s intead", + lexer.current().toString().c_str()); + if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - bool checkedFunction = false; - if (lexer.current().type == Lexeme::ReservedChecked) - { - checkedFunction = true; - nextLexeme(); - } Name globalName = parseName("global function name"); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); @@ -880,8 +1008,8 @@ AstStat* Parser::parseDeclaration(const Location& start) if (vararg && !varargAnnotation) return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); - return allocator.alloc(Location(start, end), globalName.name, generics, genericPacks, - AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes, checkedFunction); + return allocator.alloc(Location(start, end), attributes, globalName.name, generics, genericPacks, + AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); } else if (AstName(lexer.current().name) == "class") { @@ -1035,7 +1163,7 @@ std::pair> Parser::prepareFunctionArguments(const // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // parlist ::= bindinglist [`,' `...'] | `...' 
std::pair Parser::parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray& attributes) { Location start = matchFunction.location; @@ -1087,7 +1215,7 @@ std::pair Parser::parseFunctionBody( bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); body->hasEnd = hasEnd; - return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, + return {allocator.alloc(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body, functionStack.size(), debugname, typelist, varargAnnotation, argLocation), funLocal}; } @@ -1296,7 +1424,7 @@ std::pair Parser::parseReturnType() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); + AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation); return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } @@ -1435,7 +1563,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) // ReturnType ::= Type | `(' TypeList `)' // FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) +AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray& attributes) { incrementRecursionCounter("type annotation"); @@ -1483,11 +1611,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) AstArray> paramNames = copy(names); - return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation, isCheckedFunction), {}}; + return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, 
paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation, bool isCheckedFunction) +AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray& attributes, AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, + AstTypePack* varargAnnotation) { incrementRecursionCounter("type annotation"); @@ -1512,7 +1641,7 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray( - Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList, isCheckedFunction); + Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList); } // Type ::= @@ -1666,7 +1795,21 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) Location start = lexer.current().location; - if (lexer.current().type == Lexeme::ReservedNil) + AstArray attributes{nullptr, 0}; + + if (lexer.current().type == Lexeme::Attribute) + { + if (!inDeclarationContext || !FFlag::LuauAttributeSyntax) + { + return {reportTypeError(start, {}, "attributes are not allowed in declaration context")}; + } + else + { + attributes = Parser::parseAttributes(); + return parseFunctionType(allowPack, attributes); + } + } + else if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); return {allocator.alloc(start, std::nullopt, nameNil, std::nullopt, start), {}}; @@ -1754,14 +1897,9 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) { return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}}; } - else if (inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked) - { - nextLexeme(); - return parseFunctionType(allowPack, /* isCheckedFunction */ true); - } else if (lexer.current().type == '(' || lexer.current().type == '<') { - return 
parseFunctionType(allowPack); + return parseFunctionType(allowPack, AstArray({nullptr, 0})); } else if (lexer.current().type == Lexeme::ReservedFunction) { @@ -2259,7 +2397,7 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr, AstArray({nullptr, 0})).first; } else if (lexer.current().type == Lexeme::Number) { @@ -2689,7 +2827,7 @@ std::optional> Parser::parseCharArray() LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::InterpStringSimple); - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { @@ -2734,7 +2872,7 @@ AstExpr* Parser::parseInterpString() endLocation = currentLexeme.location; - scratchData.assign(currentLexeme.data, currentLexeme.length); + scratchData.assign(currentLexeme.data, currentLexeme.getLength()); if (!Lexer::fixupQuotedString(scratchData)) { @@ -2807,7 +2945,7 @@ AstExpr* Parser::parseNumber() { Location start = lexer.current().location; - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al if (scratchData.find('_') != std::string::npos) @@ -3162,11 +3300,11 @@ void Parser::nextLexeme() return; // Comments starting with ! 
are called "hot comments" and contain directives for type checking / linting / compiling - if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!') { const char* text = lexeme.data; - unsigned int end = lexeme.length; + unsigned int end = lexeme.getLength(); while (end > 0 && isSpace(text[end - 1])) --end; diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 171e9197..7dd05660 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -73,12 +73,39 @@ struct CompilationResult }; struct IrBuilder; +struct IrOp; using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); using HostVectorNamecallHandler = bool (*)( IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); +enum class HostMetamethod +{ + Add, + Sub, + Mul, + Div, + Idiv, + Mod, + Pow, + Minus, + Equal, + LessThan, + LessEqual, + Length, + Concat, +}; + +using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength); +using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method); +using HostUserdataAccessHandler = bool (*)( + IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostUserdataMetamethodHandler = bool (*)( + IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos); +using HostUserdataNamecallHandler = bool (*)( + IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); + struct 
HostIrHooks { // Suggest result type of a vector field access @@ -97,6 +124,34 @@ struct HostIrHooks // All other arguments can be of any type // Guards should take a VM exit to 'pcpos' HostVectorNamecallHandler vectorNamecall = nullptr; + + // Suggest result type of a userdata field access + HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr; + + // Suggest result type of a metamethod call + HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr; + + // Suggest result type of a userdata namecall + HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr; + + // Handle userdata value field access + // 'sourceReg' is guaranteed to be a userdata, but tag has to be checked + // Write to 'resultReg' might invalidate 'sourceReg' + // Guards should take a VM exit to 'pcpos' + HostUserdataAccessHandler userdataAccess = nullptr; + + // Handle metamethod operation on a userdata value + // 'lhs' and 'rhs' operands can be VM registers of constants + // Operand types have to be checked and userdata operand tags have to be checked + // Write to 'resultReg' might invalidate source operands + // Guards should take a VM exit to 'pcpos' + HostUserdataMetamethodHandler userdataMetamethod = nullptr; + + // Handle namecall performed on a userdata value + // 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostUserdataNamecallHandler userdataNamecall = nullptr; }; struct CompilationOptions diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index b00fffab..60af706f 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -290,6 +290,11 @@ enum class IrCmd : uint8_t // C: block TRY_CALL_FASTGETTM, + // Create new tagged userdata + // A: int (size) + // B: int (tag) + NEW_USERDATA, + // Convert integer into a double number // A: int INT_TO_NUM, @@ -460,6 
+465,13 @@ enum class IrCmd : uint8_t // When undef is specified instead of a block, execution is aborted on check failure CHECK_BUFFER_LEN, + // Guard against userdata tag mismatch + // A: pointer (userdata) + // B: int (tag) + // C: block/vmexit/undef + // When undef is specified instead of a block, execution is aborted on check failure + CHECK_USERDATA_TAG, + // Special operations // Check interrupt handler diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 55b86822..8486921e 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -11,6 +11,7 @@ namespace CodeGen { struct IrBuilder; +enum class HostMetamethod; inline bool isJumpD(LuauOpcode op) { @@ -129,6 +130,7 @@ inline bool isNonTerminatingJump(IrCmd cmd) case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_BUFFER_LEN: + case IrCmd::CHECK_USERDATA_TAG: return true; default: break; @@ -182,6 +184,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::DUP_TABLE: case IrCmd::TRY_NUM_TO_INDEX: case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::NEW_USERDATA: case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: case IrCmd::NUM_TO_INT: @@ -245,6 +248,8 @@ bool isGCO(uint8_t tag); bool isUserdataBytecodeType(uint8_t ty); bool isCustomUserdataBytecodeType(uint8_t ty); +HostMetamethod tmToHostMetamethod(int tm); + // Manually add or remove use of an operand void addUse(IrFunction& function, IrOp op); void removeUse(IrFunction& function, IrOp op); diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index aed8c763..fc8eb900 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo loa LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false) LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false) 
+LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false) namespace Luau { @@ -546,6 +547,49 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types) } } +static HostMetamethod opcodeToHostMetamethod(LuauOpcode op) +{ + switch (op) + { + case LOP_ADD: + return HostMetamethod::Add; + case LOP_SUB: + return HostMetamethod::Sub; + case LOP_MUL: + return HostMetamethod::Mul; + case LOP_DIV: + return HostMetamethod::Div; + case LOP_IDIV: + return HostMetamethod::Idiv; + case LOP_MOD: + return HostMetamethod::Mod; + case LOP_POW: + return HostMetamethod::Pow; + case LOP_ADDK: + return HostMetamethod::Add; + case LOP_SUBK: + return HostMetamethod::Sub; + case LOP_MULK: + return HostMetamethod::Mul; + case LOP_DIVK: + return HostMetamethod::Div; + case LOP_IDIVK: + return HostMetamethod::Idiv; + case LOP_MODK: + return HostMetamethod::Mod; + case LOP_POWK: + return HostMetamethod::Pow; + case LOP_SUBRK: + return HostMetamethod::Sub; + case LOP_DIVRK: + return HostMetamethod::Div; + default: + CODEGEN_ASSERT(!"opcode is not assigned to a host metamethod"); + } + + return HostMetamethod::Add; +} + void buildBytecodeBlocks(IrFunction& function, const std::vector& jumpTargets) { Proto* proto = function.proto; @@ -760,22 +804,50 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_ANY; - if (bcType.a == LBC_TYPE_VECTOR) + if (FFlag::LuauCodegenUserdataOps) { TString* str = gco2ts(function.proto->k[kc].value.gc); const char* field = getstr(str); - if (str->len == 1) + if (bcType.a == LBC_TYPE_VECTOR) { - // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" - char ch = field[0] | ' '; + if (str->len == 1) + { + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; - if (ch == 'x' || ch == 'y' || ch == 'z') - regTags[ra] = LBC_TYPE_NUMBER; + if (ch == 'x' || ch == 'y' || ch == 'z') + 
regTags[ra] = LBC_TYPE_NUMBER; + } + + if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); } + else if (isCustomUserdataBytecodeType(bcType.a)) + { + if (regTags[ra] == LBC_TYPE_ANY && hostHooks.userdataAccessBytecodeType) + regTags[ra] = hostHooks.userdataAccessBytecodeType(bcType.a, field, str->len); + } + } + else + { + if (bcType.a == LBC_TYPE_VECTOR) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); - if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) - regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); + if (str->len == 1) + { + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; + + if (ch == 'x' || ch == 'y' || ch == 'z') + regTags[ra] = LBC_TYPE_NUMBER; + } + + if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); + } } bcType.result = regTags[ra]; @@ -812,6 +884,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -841,6 +916,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if 
(FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -859,6 +939,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -879,6 +962,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -908,6 +994,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -926,6 +1017,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a 
== LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -945,6 +1039,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -972,6 +1069,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -1000,6 +1102,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && isCustomUserdataBytecodeType(bcType.a)) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, LBC_TYPE_ANY, HostMetamethod::Minus); bcType.result = regTags[ra]; break; @@ -1140,12 +1244,25 @@ void analyzeBytecodeTypes(IrFunction& 
function, const HostIrHooks& hostHooks) bcType.result = LBC_TYPE_FUNCTION; - if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + if (FFlag::LuauCodegenUserdataOps) { TString* str = gco2ts(function.proto->k[kc].value.gc); const char* field = getstr(str); - knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len)); + } + else + { + if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + } } } break; diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 05ac9013..06f64955 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -258,39 +258,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde return locations; } -bool initHeaderFunctions(NativeState& data) -{ - AssemblyBuilderA64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); - - unwind.startInfo(UnwindBuilder::A64); - - EntryLocations entryLocations = buildEntryFunction(build, unwind); - - build.finalize(); - - unwind.finishInfo(); - - CODEGEN_ASSERT(build.data.empty()); - - uint8_t* codeStart = nullptr; - if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), - int(build.code.size() * 
sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart)) - { - CODEGEN_ASSERT(!"Failed to create entry function"); - return false; - } - - // Set the offset at the begining so that functions in new blocks will not overlay the locations - // specified by the unwind information of the entry function - unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - - data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); - data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); - - return true; -} - bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) { AssemblyBuilderA64 build(/* logText= */ false); diff --git a/CodeGen/src/CodeGenA64.h b/CodeGen/src/CodeGenA64.h index 24fedd9a..2633f5ba 100644 --- a/CodeGen/src/CodeGenA64.h +++ b/CodeGen/src/CodeGenA64.h @@ -7,7 +7,6 @@ namespace CodeGen { class BaseCodeGenContext; -struct NativeState; struct ModuleHelpers; namespace A64 @@ -15,7 +14,6 @@ namespace A64 class AssemblyBuilderA64; -bool initHeaderFunctions(NativeState& data); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index 7788d099..ae9e41f1 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -14,8 +14,8 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false) -LUAU_FASTINT(LuauCodeGenBlockSize) -LUAU_FASTINT(LuauCodeGenMaxTotalSize) +LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) +LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) namespace Luau { diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 973829ca..ad231e76 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -14,6 +14,7 @@ #include "lstate.h" #include "lstring.h" #include "ltable.h" +#include "ludata.h" #include @@ -219,6 +220,20 @@ void 
callEpilogC(lua_State* L, int nresults, int n) L->top = (nresults == LUA_MULTRET) ? res : cip->top; } +Udata* newUserdata(lua_State* L, size_t s, int tag) +{ + Udata* u = luaU_newudata(L, s, tag); + + if (Table* h = L->global->udatamt[tag]) + { + u->metatable = h; + + luaC_objbarrier(L, u, h); + } + + return u; +} + // Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) { diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 515a81f0..15d4c95d 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -17,6 +17,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); void callEpilogC(lua_State* L, int nresults, int n); +Udata* newUserdata(lua_State* L, size_t s, int tag); + #define CALL_FALLBACK_YIELD 1 Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 7f4a9e0c..b8df3774 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -186,39 +186,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde return locations; } -bool initHeaderFunctions(NativeState& data) -{ - AssemblyBuilderX64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); - - unwind.startInfo(UnwindBuilder::X64); - - EntryLocations entryLocations = buildEntryFunction(build, unwind); - - build.finalize(); - - unwind.finishInfo(); - - CODEGEN_ASSERT(build.data.empty()); - - uint8_t* codeStart = nullptr; - if (!data.codeAllocator.allocate( - build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart)) - { - CODEGEN_ASSERT(!"Failed to create entry function"); - return false; - } - - // Set the offset at the 
begining so that functions in new blocks will not overlay the locations - // specified by the unwind information of the entry function - unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - - data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); - data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); - - return true; -} - bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) { AssemblyBuilderX64 build(/* logText= */ false); diff --git a/CodeGen/src/CodeGenX64.h b/CodeGen/src/CodeGenX64.h index eb6ab81c..ce360b23 100644 --- a/CodeGen/src/CodeGenX64.h +++ b/CodeGen/src/CodeGenX64.h @@ -7,7 +7,6 @@ namespace CodeGen { class BaseCodeGenContext; -struct NativeState; struct ModuleHelpers; namespace X64 @@ -15,7 +14,6 @@ namespace X64 class AssemblyBuilderX64; -bool initHeaderFunctions(NativeState& data); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index 894570d9..d61fd2a7 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -22,8 +22,6 @@ namespace Luau namespace CodeGen { -struct NativeState; - namespace A64 { diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index c29479e1..f88944e5 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -26,7 +26,6 @@ namespace CodeGen { enum class IrCondition : uint8_t; -struct NativeState; struct IrOp; namespace X64 diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index c47a0b8f..a82ee894 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -199,6 +199,8 @@ const char* getCmdName(IrCmd cmd) return "TRY_NUM_TO_INDEX"; case IrCmd::TRY_CALL_FASTGETTM: return "TRY_CALL_FASTGETTM"; + case IrCmd::NEW_USERDATA: + return "NEW_USERDATA"; case IrCmd::INT_TO_NUM: return "INT_TO_NUM"; case 
IrCmd::UINT_TO_NUM: @@ -257,6 +259,8 @@ const char* getCmdName(IrCmd cmd) return "CHECK_NODE_VALUE"; case IrCmd::CHECK_BUFFER_LEN: return "CHECK_BUFFER_LEN"; + case IrCmd::CHECK_USERDATA_TAG: + return "CHECK_USERDATA_TAG"; case IrCmd::INTERRUPT: return "INTERRUPT"; case IrCmd::CHECK_GC: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index c8cc07f4..ea83bb99 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -13,6 +13,9 @@ LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenSplitDoarith) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false) namespace Luau { @@ -1083,6 +1086,19 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) inst.regA64 = regs.takeReg(x0, index); break; } + case IrCmd::NEW_USERDATA: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + regs.spill(build, index); + build.mov(x0, rState); + build.mov(x1, intOp(inst.a)); + build.mov(x2, intOp(inst.b)); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, newUserdata))); + build.blr(x3); + inst.regA64 = regs.takeReg(x0, index); + break; + } case IrCmd::INT_TO_NUM: { inst.regA64 = regs.allocReg(KindA64::d, index); @@ -1677,6 +1693,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) finalizeTargetLabel(inst.d, fresh); break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + Label fresh; // used when guard aborts execution or jumps to a VM exit + Label& fail = getTargetLabel(inst.c, fresh); + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldrb(temp, mem(regOp(inst.a), offsetof(Udata, tag))); + + if (FFlag::LuauCodegenUserdataOpsFixA64) + build.cmp(temp, intOp(inst.b)); + else + build.cmp(temp, tagOp(inst.b)); + + build.b(ConditionA64::NotEqual, fail); + finalizeTargetLabel(inst.c, fresh); + break; + } 
case IrCmd::INTERRUPT: { regs.spill(build, index); @@ -2308,7 +2342,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI8: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrsb(inst.regA64, addr); break; @@ -2317,7 +2351,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READU8: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrb(inst.regA64, addr); break; @@ -2326,7 +2360,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI8: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.strb(temp, addr); break; @@ -2335,7 +2369,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI16: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrsh(inst.regA64, addr); break; @@ -2344,7 +2378,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READU16: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? 
LUA_TBUFFER : tagOp(inst.c)); build.ldrh(inst.regA64, addr); break; @@ -2353,7 +2387,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI16: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.strh(temp, addr); break; @@ -2362,7 +2396,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI32: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(inst.regA64, addr); break; @@ -2371,7 +2405,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI32: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.str(temp, addr); break; @@ -2381,7 +2415,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regA64 = regs.allocReg(KindA64::d, index); RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(temp, addr); build.fcvt(inst.regA64, temp); @@ -2392,7 +2426,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { RegisterA64 temp1 = tempDouble(inst.c); RegisterA64 temp2 = regs.allocTemp(KindA64::s); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? 
LUA_TBUFFER : tagOp(inst.d)); build.fcvt(temp2, temp1); build.str(temp2, addr); @@ -2402,7 +2436,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READF64: { inst.regA64 = regs.allocReg(KindA64::d, index); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(inst.regA64, addr); break; @@ -2411,7 +2445,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEF64: { RegisterA64 temp = tempDouble(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.str(temp, addr); break; @@ -2639,32 +2673,68 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset) } } -AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp) +AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag) { - if (indexOp.kind == IrOpKind::Inst) + if (FFlag::LuauCodegenUserdataOps) { - RegisterA64 temp = regs.allocTemp(KindA64::x); - build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw - return mem(temp, offsetof(Buffer, data)); - } - else if (indexOp.kind == IrOpKind::Constant) - { - // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled encoding - if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255) - return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data))); + CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER); + int dataOffset = tag == LUA_TBUFFER ? 
offsetof(Buffer, data) : offsetof(Udata, data); - // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset - if (intOp(indexOp) < 0) - return mem(regOp(bufferOp), offsetof(Buffer, data)); + if (indexOp.kind == IrOpKind::Inst) + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw + return mem(temp, dataOffset); + } + else if (indexOp.kind == IrOpKind::Constant) + { + // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled + // encoding + if (unsigned(intOp(indexOp)) + dataOffset <= 255) + return mem(regOp(bufferOp), int(intOp(indexOp) + dataOffset)); - RegisterA64 temp = regs.allocTemp(KindA64::x); - emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); - return mem(temp, offsetof(Buffer, data)); + // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset + if (intOp(indexOp) < 0) + return mem(regOp(bufferOp), dataOffset); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); + return mem(temp, dataOffset); + } + else + { + CODEGEN_ASSERT(!"Unsupported instruction form"); + return noreg; + } } else { - CODEGEN_ASSERT(!"Unsupported instruction form"); - return noreg; + if (indexOp.kind == IrOpKind::Inst) + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw + return mem(temp, offsetof(Buffer, data)); + } + else if (indexOp.kind == IrOpKind::Constant) + { + // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled + // encoding + if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255) + return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data))); + + // indexOp can only be 
negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset + if (intOp(indexOp) < 0) + return mem(regOp(bufferOp), offsetof(Buffer, data)); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); + return mem(temp, offsetof(Buffer, data)); + } + else + { + CODEGEN_ASSERT(!"Unsupported instruction form"); + return noreg; + } } } diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 5fb7f2b8..5f13f58e 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -44,7 +44,7 @@ struct IrLoweringA64 RegisterA64 tempInt(IrOp op); RegisterA64 tempUint(IrOp op); AddressA64 tempAddr(IrOp op, int offset); - AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp); + AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag); // May emit restore instructions RegisterA64 regOp(IrOp op); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 66609cb7..00768c70 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -15,6 +15,9 @@ #include "lstate.h" #include "lgc.h" +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) + namespace Luau { namespace CodeGen @@ -905,6 +908,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) inst.regX64 = regs.takeReg(rax, index); break; } + case IrCmd::NEW_USERDATA: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, intOp(inst.a)); + callWrap.addArgument(SizeX64::dword, intOp(inst.b)); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, newUserdata)]); + inst.regX64 = regs.takeReg(rax, index); + break; + } case IrCmd::INT_TO_NUM: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); @@ -1350,6 +1365,14 @@ void IrLoweringX64::lowerInst(IrInst& 
inst, uint32_t index, const IrBlock& next) } break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + build.cmp(byte[regOp(inst.a) + offsetof(Udata, tag)], intOp(inst.b)); + jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next); + break; + } case IrCmd::INTERRUPT: { unsigned pcpos = uintOp(inst.a); @@ -1895,71 +1918,71 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI8: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]); + build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_READU8: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]); + build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI8: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c))); - build.mov(byte[bufferAddrOp(inst.a, inst.b)], value); + build.mov(byte[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READI16: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); + build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_READU16: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); + build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? 
LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI16: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c))); - build.mov(word[bufferAddrOp(inst.a, inst.b)], value); + build.mov(word[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READI32: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); + build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI32: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c)); - build.mov(dword[bufferAddrOp(inst.a, inst.b)], value); + build.mov(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READF32: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); + build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEF32: - storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b)], inst.c); + storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], inst.c); break; case IrCmd::BUFFER_READF64: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b)]); + build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? 
LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEF64: @@ -1967,11 +1990,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { ScopedRegX64 tmp{regs, SizeX64::xmmword}; build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c))); - build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], tmp.reg); + build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], tmp.reg); } else if (inst.c.kind == IrOpKind::Inst) { - build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], regOp(inst.c)); + build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], regOp(inst.c)); } else { @@ -2190,12 +2213,25 @@ RegisterX64 IrLoweringX64::regOp(IrOp op) return inst.regX64; } -OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp) +OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag) { - if (indexOp.kind == IrOpKind::Inst) - return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data); - else if (indexOp.kind == IrOpKind::Constant) - return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data); + if (FFlag::LuauCodegenUserdataOps) + { + CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER); + int dataOffset = tag == LUA_TBUFFER ? 
offsetof(Buffer, data) : offsetof(Udata, data); + + if (indexOp.kind == IrOpKind::Inst) + return regOp(bufferOp) + qwordReg(regOp(indexOp)) + dataOffset; + else if (indexOp.kind == IrOpKind::Constant) + return regOp(bufferOp) + intOp(indexOp) + dataOffset; + } + else + { + if (indexOp.kind == IrOpKind::Inst) + return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data); + else if (indexOp.kind == IrOpKind::Constant) + return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data); + } CODEGEN_ASSERT(!"Unsupported instruction form"); return noreg; diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index 5fb7b0fa..8fb311ea 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -50,7 +50,7 @@ struct IrLoweringX64 OperandX64 memRegUintOp(IrOp op); OperandX64 memRegTagOp(IrOp op); RegisterX64 regOp(IrOp op); - OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp); + OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag); RegisterX64 vecOp(IrOp op, ScopedRegX64& tmp); IrConst constOp(IrOp op) const; diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 93073a92..5798f3e9 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) +LUAU_FASTFLAG(LuauCodegenUserdataOps) namespace Luau { @@ -444,6 +445,17 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, return; } + if (FFlag::LuauCodegenUserdataOps && (isUserdataBytecodeType(bcTypes.a) || isUserdataBytecodeType(bcTypes.b))) + { + if (build.hostHooks.userdataMetamethod && + build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, opb, opc, tmToHostMetamethod(tm), pcpos)) + return; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), opb, opc, build.constInt(tm)); + return; + } + IrOp 
fallback; // fast-path: number @@ -585,6 +597,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos) return; } + if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a)) + { + if (build.hostHooks.userdataMetamethod && + build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_UNM), pcpos)) + return; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM)); + return; + } + IrOp fallback; IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); @@ -606,8 +629,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos) FallbackStreamScope scope(build, fallback, next); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst( - IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM)); + + if (FFlag::LuauCodegenUserdataOps) + { + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM)); + } + else + { + build.inst( + IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM)); + } + build.inst(IrCmd::JUMP, next); } } @@ -619,6 +651,17 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos) int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); + if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a)) + { + if (build.hostHooks.userdataMetamethod && + build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_LEN), pcpos)) + return; + + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb)); + return; + } + IrOp fallback = build.block(IrBlockKind::Fallback); IrOp tb = 
build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); @@ -638,7 +681,12 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos) FallbackStreamScope scope(build, fallback, next); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc))); + + if (FFlag::LuauCodegenUserdataOps) + build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb)); + else + build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc))); + build.inst(IrCmd::JUMP, next); } @@ -1229,10 +1277,19 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) return; } - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); + if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataAccess) + { + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.userdataAccess(build, bcTypes.a, field, str->len, ra, rb, pcpos)) + return; + } + build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); return; } @@ -1267,7 +1324,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? 
isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); @@ -1413,10 +1470,26 @@ bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) return false; } - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) { build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos)); + if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataNamecall) + { + Instruction call = pc[2]; + CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + int callra = LUAU_INSN_A(call); + int nparams = LUAU_INSN_B(call) - 1; + int nresults = LUAU_INSN_C(call) - 1; + + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.userdataNamecall(build, bcTypes.a, field, str->len, callra, rb, nparams, nresults, pcpos)) + return true; + } + build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); return false; } diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index afc6ba5a..d1bfca45 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -99,6 +99,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::TRY_NUM_TO_INDEX: return IrValueKind::Int; case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::NEW_USERDATA: return IrValueKind::Pointer; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: @@ -135,6 +136,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_BUFFER_LEN: + case IrCmd::CHECK_USERDATA_TAG: case IrCmd::INTERRUPT: case IrCmd::CHECK_GC: case IrCmd::BARRIER_OBJ: @@ -262,6 +264,44 @@ bool isCustomUserdataBytecodeType(uint8_t ty) return ty >= LBC_TYPE_TAGGED_USERDATA_BASE && ty < 
LBC_TYPE_TAGGED_USERDATA_END; } +HostMetamethod tmToHostMetamethod(int tm) +{ + switch (TMS(tm)) + { + case TM_ADD: + return HostMetamethod::Add; + case TM_SUB: + return HostMetamethod::Sub; + case TM_MUL: + return HostMetamethod::Mul; + case TM_DIV: + return HostMetamethod::Div; + case TM_IDIV: + return HostMetamethod::Idiv; + case TM_MOD: + return HostMetamethod::Mod; + case TM_POW: + return HostMetamethod::Pow; + case TM_UNM: + return HostMetamethod::Minus; + case TM_EQ: + return HostMetamethod::Equal; + case TM_LT: + return HostMetamethod::LessThan; + case TM_LE: + return HostMetamethod::LessEqual; + case TM_LEN: + return HostMetamethod::Length; + case TM_CONCAT: + return HostMetamethod::Concat; + default: + CODEGEN_ASSERT(!"invalid tag method for host"); + break; + } + + return HostMetamethod::Add; +} + void kill(IrFunction& function, IrInst& inst) { CODEGEN_ASSERT(inst.useCount == 0); diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index b3d07491..248f0cd3 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -14,114 +14,13 @@ #include #include -LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) -LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) namespace Luau { namespace CodeGen { -NativeState::NativeState() - : NativeState(nullptr, nullptr) -{ -} - -NativeState::NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext) - : codeAllocator{size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext} -{ -} - -NativeState::~NativeState() = default; - -void initFunctions(NativeState& data) -{ - static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); - memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table)); - - data.context.luaV_lessthan = luaV_lessthan; - data.context.luaV_lessequal = 
luaV_lessequal; - data.context.luaV_equalval = luaV_equalval; - data.context.luaV_doarith = luaV_doarith; - - data.context.luaV_doarithadd = luaV_doarithimpl; - data.context.luaV_doarithsub = luaV_doarithimpl; - data.context.luaV_doarithmul = luaV_doarithimpl; - data.context.luaV_doarithdiv = luaV_doarithimpl; - data.context.luaV_doarithidiv = luaV_doarithimpl; - data.context.luaV_doarithmod = luaV_doarithimpl; - data.context.luaV_doarithpow = luaV_doarithimpl; - data.context.luaV_doarithunm = luaV_doarithimpl; - - data.context.luaV_dolen = luaV_dolen; - data.context.luaV_gettable = luaV_gettable; - data.context.luaV_settable = luaV_settable; - data.context.luaV_getimport = luaV_getimport; - data.context.luaV_concat = luaV_concat; - - data.context.luaH_getn = luaH_getn; - data.context.luaH_new = luaH_new; - data.context.luaH_clone = luaH_clone; - data.context.luaH_resizearray = luaH_resizearray; - data.context.luaH_setnum = luaH_setnum; - - data.context.luaC_barriertable = luaC_barriertable; - data.context.luaC_barrierf = luaC_barrierf; - data.context.luaC_barrierback = luaC_barrierback; - data.context.luaC_step = luaC_step; - - data.context.luaF_close = luaF_close; - data.context.luaF_findupval = luaF_findupval; - data.context.luaF_newLclosure = luaF_newLclosure; - - data.context.luaT_gettm = luaT_gettm; - data.context.luaT_objtypenamestr = luaT_objtypenamestr; - - data.context.libm_exp = exp; - data.context.libm_pow = pow; - data.context.libm_fmod = fmod; - data.context.libm_log = log; - data.context.libm_log2 = log2; - data.context.libm_log10 = log10; - data.context.libm_ldexp = ldexp; - data.context.libm_round = round; - data.context.libm_frexp = frexp; - data.context.libm_modf = modf; - - data.context.libm_asin = asin; - data.context.libm_sin = sin; - data.context.libm_sinh = sinh; - data.context.libm_acos = acos; - data.context.libm_cos = cos; - data.context.libm_cosh = cosh; - data.context.libm_atan = atan; - data.context.libm_atan2 = atan2; - 
data.context.libm_tan = tan; - data.context.libm_tanh = tanh; - - data.context.forgLoopTableIter = forgLoopTableIter; - data.context.forgLoopNodeIter = forgLoopNodeIter; - data.context.forgLoopNonTableFallback = forgLoopNonTableFallback; - data.context.forgPrepXnextFallback = forgPrepXnextFallback; - data.context.callProlog = callProlog; - data.context.callEpilogC = callEpilogC; - - data.context.callFallback = callFallback; - - data.context.executeGETGLOBAL = executeGETGLOBAL; - data.context.executeSETGLOBAL = executeSETGLOBAL; - data.context.executeGETTABLEKS = executeGETTABLEKS; - data.context.executeSETTABLEKS = executeSETTABLEKS; - - data.context.executeNAMECALL = executeNAMECALL; - data.context.executeFORGPREP = executeFORGPREP; - data.context.executeGETVARARGSMultRet = executeGETVARARGSMultRet; - data.context.executeGETVARARGSConst = executeGETVARARGSConst; - data.context.executeDUPCLOSURE = executeDUPCLOSURE; - data.context.executePREPVARARGS = executePREPVARARGS; - data.context.executeSETLIST = executeSETLIST; -} - void initFunctions(NativeContext& context) { static_assert(sizeof(context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); @@ -194,6 +93,9 @@ void initFunctions(NativeContext& context) context.callProlog = callProlog; context.callEpilogC = callEpilogC; + if (FFlag::LuauCodegenUserdataAlloc) + context.newUserdata = newUserdata; + context.callFallback = callFallback; context.executeGETGLOBAL = executeGETGLOBAL; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 2edfc270..be73815d 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -94,6 +94,7 @@ struct NativeContext void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr; + Udata* (*newUserdata)(lua_State* L, size_t s, int tag) = nullptr; Closure* 
(*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr; @@ -116,22 +117,6 @@ struct NativeContext using GateFn = int (*)(lua_State*, Proto*, uintptr_t, NativeContext*); -struct NativeState -{ - NativeState(); - NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext); - ~NativeState(); - - CodeAllocator codeAllocator; - std::unique_ptr unwindBuilder; - - uint8_t* gateData = nullptr; - size_t gateDataSize = 0; - - NativeContext context; -}; - -void initFunctions(NativeState& data); void initFunctions(NativeContext& context); } // namespace CodeGen diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 9135a9ed..4ff49570 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -16,9 +16,12 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) +LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAGVARIABLE(LuauCodegenFixSplitStoreConstMismatch, false) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) namespace Luau { @@ -200,6 +203,11 @@ struct ConstPropState checkBufferLenCache.clear(); } + void invalidateUserdataData() + { + useradataTagCache.clear(); + } + void invalidateHeap() { for (int i = 0; i <= maxReg; ++i) @@ -417,6 +425,9 @@ struct ConstPropState invalidateValuePropagation(); invalidateHeapTableData(); invalidateHeapBufferData(); + + if (FFlag::LuauCodegenUserdataOps) + invalidateUserdataData(); } IrFunction& function; @@ -446,6 +457,9 @@ struct ConstPropState std::vector checkArraySizeCache; // Additionally, fallback block argument might be different std::vector checkBufferLenCache; // Additionally, fallback block argument might be different + + // Userdata tag cache can point to both NEW_USERDATA and CHECK_USERDATA_TAG instructions + 
std::vector useradataTagCache; // Additionally, fallback block argument might be different }; static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) @@ -1061,6 +1075,37 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.checkBufferLenCache.push_back(index); break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + for (uint32_t prevIdx : state.useradataTagCache) + { + IrInst& prev = function.instructions[prevIdx]; + + if (prev.cmd == IrCmd::CHECK_USERDATA_TAG) + { + if (prev.a != inst.a || prev.b != inst.b) + continue; + } + else if (FFlag::LuauCodegenUserdataAlloc && prev.cmd == IrCmd::NEW_USERDATA) + { + if (inst.a.kind != IrOpKind::Inst || prevIdx != inst.a.index || prev.b != inst.b) + continue; + } + + if (FFlag::DebugLuauAbortingChecks) + replace(function, inst.c, build.undef()); + else + kill(function, inst); + + return; // Break out from both the loop and the switch + } + + if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit) + state.useradataTagCache.push_back(index); + break; + } case IrCmd::BUFFER_READI8: case IrCmd::BUFFER_READU8: case IrCmd::BUFFER_WRITEI8: @@ -1228,6 +1273,12 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::TRY_CALL_FASTGETTM: break; + case IrCmd::NEW_USERDATA: + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit) + state.useradataTagCache.push_back(index); + break; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: state.substituteOrRecord(inst, index); @@ -1512,6 +1563,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite state.invalidateHeapTableData(); state.invalidateHeapBufferData(); + if (FFlag::LuauCodegenUserdataOps) + state.invalidateUserdataData(); + // Blocks in a chain are guaranteed to follow each 
other // We force that by giving all blocks the same sorting key, but consecutive chain keys block->sortkey = startSortkey; diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index 6c1d6aff..d18b75c5 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -10,6 +10,7 @@ #include "lobject.h" LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores5, false) +LUAU_FASTFLAG(LuauCodegenUserdataOps) // TODO: optimization can be improved by knowing which registers are live in at each VM exit @@ -595,6 +596,11 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, case IrCmd::CHECK_BUFFER_LEN: state.checkLiveIns(inst.d); break; + case IrCmd::CHECK_USERDATA_TAG: + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + state.checkLiveIns(inst.c); + break; case IrCmd::JUMP: // Ideally, we would be able to remove stores to registers that are not live out from a block diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 19526fa9..4842b9a1 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -4219,7 +4219,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c for (AstExprFunction* expr : functions) compiler.compileFunction(expr, 0); - AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), + AstExprFunction main(root->location, /*attributes=*/AstArray({nullptr, 0}), /*generics= */ AstArray(), + /*genericPacks= */ AstArray(), /* self= */ nullptr, AstArray(), /* vararg= */ true, /* varargLocation= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main, mainFlags); diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index 693e0f87..5fba9fa3 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -195,7 +195,7 @@ static Error parseJson(const std::string& contents, Action action) } else if 
(lexer.current().type == Lexeme::QuotedString) { - std::string value(lexer.current().data, lexer.current().length); + std::string value(lexer.current().data, lexer.current().getLength()); next(lexer); if (Error err = action(keys, value)) @@ -232,7 +232,7 @@ static Error parseJson(const std::string& contents, Action action) } else if (lexer.current().type == Lexeme::QuotedString) { - std::string key(lexer.current().data, lexer.current().length); + std::string key(lexer.current().data, lexer.current().getLength()); next(lexer); keys.push_back(key); @@ -250,7 +250,7 @@ static Error parseJson(const std::string& contents, Action action) lexer.current().type == Lexeme::ReservedFalse) { std::string value = lexer.current().type == Lexeme::QuotedString - ? std::string(lexer.current().data, lexer.current().length) + ? std::string(lexer.current().data, lexer.current().getLength()) : (lexer.current().type == Lexeme::ReservedTrue ? "true" : "false"); next(lexer); diff --git a/VM/include/lua.h b/VM/include/lua.h index 4876b933..4ee9306e 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -324,6 +324,10 @@ typedef void (*lua_Destructor)(lua_State* L, void* userdata); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, lua_Destructor dtor); LUA_API lua_Destructor lua_getuserdatadtor(lua_State* L, int tag); +// alternative access for metatables already registered with luaL_newmetatable +LUA_API void lua_setuserdatametatable(lua_State* L, int tag, int idx); +LUA_API void lua_getuserdatametatable(lua_State* L, int tag); + LUA_API void lua_setlightuserdataname(lua_State* L, int tag, const char* name); LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 58c767f1..87f85af8 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1427,6 +1427,33 @@ lua_Destructor lua_getuserdatadtor(lua_State* L, int tag) return L->global->udatagc[tag]; } +void lua_setuserdatametatable(lua_State* L, int tag, int idx) +{ + 
api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + api_check(L, !L->global->udatamt[tag]); // reassignment not supported + StkId o = index2addr(L, idx); + api_check(L, ttistable(o)); + L->global->udatamt[tag] = hvalue(o); + L->top--; +} + +void lua_getuserdatametatable(lua_State* L, int tag) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + luaC_threadbarrier(L); + + if (Table* h = L->global->udatamt[tag]) + { + sethvalue(L, L->top, h); + } + else + { + setnilvalue(L->top); + } + + api_incr_top(L); +} + void lua_setlightuserdataname(lua_State* L, int tag, const char* name) { api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index dbc1dd10..6b7a9aa0 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -210,7 +210,10 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) for (i = 0; i < LUA_T_COUNT; i++) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) + { g->udatagc[i] = NULL; + g->udatamt[i] = NULL; + } for (i = 0; i < LUA_LUTAG_LIMIT; i++) g->lightuserdataname[i] = NULL; for (i = 0; i < LUA_MEMORY_CATEGORIES; i++) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 35e66471..f8caa69b 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -217,6 +217,7 @@ typedef struct global_State lua_ExecutionCallbacks ecb; void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory + Table* udatamt[LUA_LUTAG_LIMIT]; // metatables for tagged userdata TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 7ced52cf..9c1fca9e 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -342,6 +342,209 @@ void setupVectorHelpers(lua_State* L) lua_pop(L, 1); } +Vec2* lua_vec2_push(lua_State* L) +{ + Vec2* data = (Vec2*)lua_newuserdatatagged(L, sizeof(Vec2), kTagVec2); + + lua_getuserdatametatable(L, kTagVec2); + lua_setmetatable(L, -2); + + 
return data; +} + +Vec2* lua_vec2_get(lua_State* L, int idx) +{ + Vec2* a = (Vec2*)lua_touserdatatagged(L, idx, kTagVec2); + + if (a) + return a; + + luaL_typeerror(L, idx, "vec2"); +} + +static int lua_vec2(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double y = luaL_checknumber(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = float(x); + data->y = float(y); + + return 1; +} + +static int lua_vec2_dot(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + lua_pushnumber(L, a->x * b->x + a->y * b->y); + return 1; +} + +static int lua_vec2_min(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = a->x < b->x ? a->x : b->x; + data->y = a->y < b->y ? a->y : b->y; + + return 1; +} + +static int lua_vec2_index(lua_State* L) +{ + Vec2* v = lua_vec2_get(L, 1); + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "X") == 0) + { + lua_pushnumber(L, v->x); + return 1; + } + + if (strcmp(name, "Y") == 0) + { + lua_pushnumber(L, v->y); + return 1; + } + + if (strcmp(name, "Magnitude") == 0) + { + lua_pushnumber(L, sqrtf(v->x * v->x + v->y * v->y)); + return 1; + } + + if (strcmp(name, "Unit") == 0) + { + float invSqrt = 1.0f / sqrtf(v->x * v->x + v->y * v->y); + + Vec2* data = lua_vec2_push(L); + + data->x = v->x * invSqrt; + data->y = v->y * invSqrt; + return 1; + } + + luaL_error(L, "%s is not a valid member of vector", name); +} + +static int lua_vec2_namecall(lua_State* L) +{ + if (const char* str = lua_namecallatom(L, nullptr)) + { + if (strcmp(str, "Dot") == 0) + return lua_vec2_dot(L); + + if (strcmp(str, "Min") == 0) + return lua_vec2_min(L); + } + + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); +} + +void setupUserdataHelpers(lua_State* L) +{ + // create metatable with all the metamethods + luaL_newmetatable(L, "vec2"); + luaL_getmetatable(L, "vec2"); + lua_pushvalue(L, -1); + 
lua_setuserdatametatable(L, kTagVec2, -1); + + lua_pushcfunction(L, lua_vec2_index, nullptr); + lua_setfield(L, -2, "__index"); + + lua_pushcfunction(L, lua_vec2_namecall, nullptr); + lua_setfield(L, -2, "__namecall"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x + b->x; + data->y = a->y + b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__add"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x - b->x; + data->y = a->y - b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__sub"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x * b->x; + data->y = a->y * b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__mul"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x / b->x; + data->y = a->y / b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__div"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* data = lua_vec2_push(L); + + data->x = -a->x; + data->y = -a->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__unm"); + + lua_setreadonly(L, -1, true); + + // ctor + lua_pushcfunction(L, lua_vec2, "vec2"); + lua_setglobal(L, "vec2"); + + lua_pop(L, 1); +} + static void setupNativeHelpers(lua_State* L) { lua_pushcclosurek( @@ -1828,16 +2031,36 @@ TEST_CASE("UserdataApi") luaL_newmetatable(L, "udata2"); void* ud5 = lua_newuserdata(L, 0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata1"); + luaL_getmetatable(L, "udata1"); lua_setmetatable(L, -2); void* ud6 = lua_newuserdata(L, 
0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata2"); + luaL_getmetatable(L, "udata2"); lua_setmetatable(L, -2); CHECK(luaL_checkudata(L, -2, "udata1") == ud5); CHECK(luaL_checkudata(L, -1, "udata2") == ud6); + // tagged user data with fast metatable access + luaL_newmetatable(L, "udata3"); + luaL_getmetatable(L, "udata3"); + lua_setuserdatametatable(L, 50, -1); + + luaL_newmetatable(L, "udata4"); + luaL_getmetatable(L, "udata4"); + lua_setuserdatametatable(L, 51, -1); + + void* ud7 = lua_newuserdatatagged(L, 16, 50); + lua_getuserdatametatable(L, 50); + lua_setmetatable(L, -2); + + void* ud8 = lua_newuserdatatagged(L, 16, 51); + lua_getuserdatametatable(L, 51); + lua_setmetatable(L, -2); + + CHECK(luaL_checkudata(L, -2, "udata3") == ud7); + CHECK(luaL_checkudata(L, -1, "udata4") == ud8); + globalState.reset(); CHECK(dtorhits == 42); @@ -1911,7 +2134,6 @@ TEST_CASE("Iter") } const int kInt64Tag = 1; -static int gInt64MT = -1; static int64_t getInt64(lua_State* L, int idx) { @@ -1928,7 +2150,7 @@ static void pushInt64(lua_State* L, int64_t value) { void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag); - lua_getref(L, gInt64MT); + luaL_getmetatable(L, "int64"); lua_setmetatable(L, -2); *static_cast(p) = value; @@ -1938,8 +2160,7 @@ TEST_CASE("Userdata") { runConformance("userdata.lua", [](lua_State* L) { // create metatable with all the metamethods - lua_newtable(L); - gInt64MT = lua_ref(L, -1); + luaL_newmetatable(L, "int64"); // __index lua_pushcfunction( @@ -2164,6 +2385,86 @@ TEST_CASE("NativeTypeAnnotations") }); } +TEST_CASE("NativeUserdata") +{ + lua_CompileOptions copts = defaultOptions(); + Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + + SUBCASE("NoIrHooks") + { + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + 
copts.optimizationLevel = 2; + } + } + SUBCASE("IrHooks") + { + nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + nativeOpts.hooks.vectorAccess = vectorAccess; + nativeOpts.hooks.vectorNamecall = vectorNamecall; + + nativeOpts.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + nativeOpts.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + nativeOpts.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + nativeOpts.hooks.userdataAccess = userdataAccess; + nativeOpts.hooks.userdataMetamethod = userdataMetamethod; + nativeOpts.hooks.userdataNamecall = userdataNamecall; + + nativeOpts.userdataTypes = kUserdataRunTypes; + + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + + runConformance( + "native_userdata.lua", + [](lua_State* L) { + Luau::CodeGen::setUserdataRemapper(L, kUserdataRunTypes, [](void* context, const char* str, size_t len) -> uint8_t { + const char** types = (const char**)context; + + uint8_t index = 0; + + std::string_view sv{str, len}; + + for (; *types; ++types) + { + if (sv == *types) + return index; + + index++; + } + + return 0xff; + }); + + setupVectorHelpers(L); + setupUserdataHelpers(L); + }, + nullptr, nullptr, &copts, false, &nativeOpts); +} + [[nodiscard]] static std::string makeHugeFunctionSource() { std::string source; diff --git a/tests/ConformanceIrHooks.h b/tests/ConformanceIrHooks.h index d4050863..ab5b86d4 100644 --- a/tests/ConformanceIrHooks.h +++ b/tests/ConformanceIrHooks.h @@ -5,14 +5,44 @@ static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr}; +constexpr uint8_t kUserdataExtra = 0; +constexpr uint8_t kUserdataColor = 1; +constexpr uint8_t kUserdataVec2 = 2; +constexpr uint8_t kUserdataMat3 = 3; + +// Userdata tags can be different from 
userdata bytecode type indices +constexpr uint8_t kTagVec2 = 12; + +struct Vec2 +{ + float x; + float y; +}; + +inline bool compareMemberName(const char* member, size_t memberLength, const char* str) +{ + return memberLength == strlen(str) && strcmp(member, str) == 0; +} + +inline uint8_t typeToUserdataIndex(uint8_t type) +{ + // Underflow will push the type into a value that is not comparable to any kUserdata* constants + return type - LBC_TYPE_TAGGED_USERDATA_BASE; +} + +inline uint8_t userdataIndexToType(uint8_t userdataIndex) +{ + return LBC_TYPE_TAGGED_USERDATA_BASE + userdataIndex; +} + inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength) { using namespace Luau::CodeGen; - if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) + if (compareMemberName(member, memberLength, "Magnitude")) return LBC_TYPE_NUMBER; - if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) + if (compareMemberName(member, memberLength, "Unit")) return LBC_TYPE_VECTOR; return LBC_TYPE_ANY; @@ -22,7 +52,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si { using namespace Luau::CodeGen; - if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) + if (compareMemberName(member, memberLength, "Magnitude")) { IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); @@ -42,7 +72,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si return true; } - if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) + if (compareMemberName(member, memberLength, "Unit")) { IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); @@ -72,10 +102,10 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si inline 
uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength) { - if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0) + if (compareMemberName(member, memberLength, "Dot")) return LBC_TYPE_NUMBER; - if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0) + if (compareMemberName(member, memberLength, "Cross")) return LBC_TYPE_VECTOR; return LBC_TYPE_ANY; @@ -86,7 +116,7 @@ inline bool vectorNamecall( { using namespace Luau::CodeGen; - if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0 && params == 2 && results <= 1) + if (compareMemberName(member, memberLength, "Dot") && params == 2 && results <= 1) { build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); @@ -114,7 +144,7 @@ inline bool vectorNamecall( return true; } - if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0 && params == 2 && results <= 1) + if (compareMemberName(member, memberLength, "Cross") && params == 2 && results <= 1) { build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); @@ -151,3 +181,362 @@ inline bool vectorNamecall( return false; } + +inline uint8_t userdataAccessBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + if (compareMemberName(member, memberLength, "R")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "G")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "B")) + return LBC_TYPE_NUMBER; + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Y")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Magnitude")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Unit")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + if (compareMemberName(member, 
memberLength, "Row1")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row2")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row3")) + return LBC_TYPE_VECTOR; + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataAccess( + Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Y")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Magnitude")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, 
build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Unit")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), xr, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), yr, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} + +inline uint8_t userdataMetamethodBytecodeType(uint8_t lhsTy, uint8_t rhsTy, Luau::CodeGen::HostMetamethod method) +{ + 
switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + case Luau::CodeGen::HostMetamethod::Sub: + case Luau::CodeGen::HostMetamethod::Mul: + case Luau::CodeGen::HostMetamethod::Div: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 || typeToUserdataIndex(rhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + default: + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataMetamethod(Luau::CodeGen::IrBuilder& build, uint8_t lhsTy, uint8_t rhsTy, int resultReg, Luau::CodeGen::IrOp lhs, + Luau::CodeGen::IrOp rhs, Luau::CodeGen::HostMetamethod method, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::ADD_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::ADD_NUM, y1, y2); + + 
build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Mul: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MUL_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + 
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::UNM_NUM, x); + IrOp my = build.inst(IrCmd::UNM_NUM, y); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + default: + break; + } + + return false; +} + +inline uint8_t userdataNamecallBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Min")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + break; + } + + return 
LBC_TYPE_ANY; +} + +inline bool userdataNamecall(Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, + int params, int results, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp sum = build.inst(IrCmd::ADD_NUM, xx, yy); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + if (compareMemberName(member, memberLength, "Min")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + 
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MIN_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MIN_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(argResReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TUSERDATA)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 5d7fedd8..ecdb522c 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -24,6 +24,8 @@ LUAU_FASTFLAG(LuauCompileTempTypeInfo) 
LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauCompileUserdataInfo) LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) { @@ -34,6 +36,13 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = options.compilationOptions.hooks.vectorAccess = vectorAccess; options.compilationOptions.hooks.vectorNamecall = vectorNamecall; + options.compilationOptions.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + options.compilationOptions.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + options.compilationOptions.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + options.compilationOptions.hooks.userdataAccess = userdataAccess; + options.compilationOptions.hooks.userdataMetamethod = userdataMetamethod; + options.compilationOptions.hooks.userdataNamecall = userdataNamecall; + // For IR, we don't care about assembly, but we want a stable target options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV; @@ -1690,4 +1699,352 @@ end )"); } +TEST_CASE("CustomUserdataPropertyAccess") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(v: vec2) + return v.X + v.Y +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 
+bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %6, 12i, exit(0) + %8 = BUFFER_READF32 %6, 0i, tuserdata + %15 = BUFFER_READF32 %6, 4i, tuserdata + %24 = ADD_NUM %8, %15 + STORE_DOUBLE R1, %24 + STORE_TAG R1, tnumber + INTERRUPT 5u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataPropertyAccess2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3) + return a.Row1 * a.Row2 +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_GETTABLEKS 0u, R2, R0, K0 + FALLBACK_GETTABLEKS 2u, R3, R0, K1 + CHECK_TAG R2, tvector, exit(4) + CHECK_TAG R3, tvector, exit(4) + %14 = LOAD_TVALUE R2 + %15 = LOAD_TVALUE R3 + %16 = MUL_VEC %14, %15 + %17 = TAG_VECTOR %16 + STORE_TVALUE R1, %17 + INTERRUPT 5u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataNamecall1") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, + {FFlag::LuauCodegenUserdataOps, 
true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2) + return a:Dot(b) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MUL_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MUL_NUM %19, %20 + %22 = ADD_NUM %18, %21 + STORE_DOUBLE R2, %22 + STORE_TAG R2, tnumber + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)"); +} + +TEST_CASE("CustomUserdataNamecall2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, + {FFlag::LuauCodegenUserdataOps, true}, {FFlag::LuauCodegenUserdataAlloc, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2) + return a:Min(b) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, 
exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MIN_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MIN_NUM %19, %20 + CHECK_GC + %23 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %23, 0i, %18, tuserdata + BUFFER_WRITEF32 %23, 4i, %21, tuserdata + STORE_POINTER R2, %23 + STORE_TAG R2, tuserdata + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3, b: mat3) + return a * b +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: mat3 [argument] +; R1: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_ARITH R2, R0, R1, 10i + INTERRUPT 1u + RETURN R2, 1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, 
{FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3) + return -a +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_ARITH R1, R0, R0, 15i + INTERRUPT 1u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow3") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: sequence) + return #a +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: userdata [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + SET_SAVEDPC 1u + DO_LEN R1, R0 + INTERRUPT 1u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataMetamethod") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, + {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}, + {FFlag::LuauCodegenUserdataAlloc, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function 
foo(a: vec2, b: vec2, c: vec2) + return -c + a * b +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1, $arg2) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +; R2: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + CHECK_TAG R2, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %10 = LOAD_POINTER R2 + CHECK_USERDATA_TAG %10, 12i, exit(0) + %12 = BUFFER_READF32 %10, 0i, tuserdata + %13 = BUFFER_READF32 %10, 4i, tuserdata + %14 = UNM_NUM %12 + %15 = UNM_NUM %13 + CHECK_GC + %17 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %17, 0i, %14, tuserdata + BUFFER_WRITEF32 %17, 4i, %15, tuserdata + STORE_POINTER R4, %17 + STORE_TAG R4, tuserdata + %26 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %26, 12i, exit(1) + %28 = LOAD_POINTER R1 + CHECK_USERDATA_TAG %28, 12i, exit(1) + %30 = BUFFER_READF32 %26, 0i, tuserdata + %31 = BUFFER_READF32 %28, 0i, tuserdata + %32 = MUL_NUM %30, %31 + %33 = BUFFER_READF32 %26, 4i, tuserdata + %34 = BUFFER_READF32 %28, 4i, tuserdata + %35 = MUL_NUM %33, %34 + %37 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %37, 0i, %32, tuserdata + BUFFER_WRITEF32 %37, 4i, %35, tuserdata + STORE_POINTER R5, %37 + STORE_TAG R5, tuserdata + %50 = BUFFER_READF32 %17, 0i, tuserdata + %51 = BUFFER_READF32 %37, 0i, tuserdata + %52 = ADD_NUM %50, %51 + %53 = BUFFER_READF32 %17, 4i, tuserdata + %54 = BUFFER_READF32 %37, 4i, tuserdata + %55 = ADD_NUM %53, %54 + %57 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %57, 0i, %52, tuserdata + BUFFER_WRITEF32 %57, 4i, %55, tuserdata + STORE_POINTER R3, %57 + STORE_TAG R3, tuserdata + INTERRUPT 3u + RETURN R3, 1i +)"); +} + TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 78d1389a..e0716e4c 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -192,13 +192,13 @@ TEST_CASE("string_interpolation_double_brace") auto brokenInterpBegin = lexer.next(); CHECK_EQ(brokenInterpBegin.type, 
Lexeme::BrokenInterpDoubleBrace); - CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.length), std::string("foo")); + CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.getLength()), std::string("foo")); CHECK_EQ(lexer.next().type, Lexeme::Name); auto interpEnd = lexer.next(); CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); - CHECK_EQ(std::string(interpEnd.data, interpEnd.length), std::string("}bar")); + CHECK_EQ(std::string(interpEnd.data, interpEnd.getLength()), std::string("}bar")); } TEST_CASE("string_interpolation_double_but_unmatched_brace") diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index e51fb0df..81a84722 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -15,6 +15,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauAttributeSyntax); + #define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \ do \ { \ @@ -68,6 +70,7 @@ struct NonStrictTypeCheckerFixture : Fixture { ScopedFastFlag flags[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; LoadDefinitionFileResult res = loadDefinition(definitions); LUAU_ASSERT(res.success); @@ -78,6 +81,7 @@ struct NonStrictTypeCheckerFixture : Fixture { ScopedFastFlag flags[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; LoadDefinitionFileResult res = loadDefinition(definitions); LUAU_ASSERT(res.success); @@ -85,21 +89,21 @@ struct NonStrictTypeCheckerFixture : Fixture } std::string definitions = R"BUILTIN_SRC( -declare function @checked abs(n: number): number -declare function @checked lower(s: string): string +@checked declare function abs(n: number): number +@checked declare function lower(s: string): string declare function cond() : boolean -declare function @checked contrived(n : Not) : number +@checked declare function contrived(n : Not) : number -- interesting types of things that we would like to mark as 
checked -declare function @checked onlyNums(...: number) : number -declare function @checked mixedArgs(x: string, ...: number) : number -declare function @checked optionalArg(x: string?) : number +@checked declare function onlyNums(...: number) : number +@checked declare function mixedArgs(x: string, ...: number) : number +@checked declare function optionalArg(x: string?) : number declare foo: { bar: @checked (number) -> number, } -declare function @checked optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number -declare function @checked optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number +@checked declare function optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number +@checked declare function optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number type DateTypeArg = { year: number, @@ -115,7 +119,7 @@ declare os : { time: @checked (time: DateTypeArg?) -> number } -declare function @checked require(target : any) : any +@checked declare function require(target : any) : any )BUILTIN_SRC"; }; @@ -558,6 +562,10 @@ local E = require(script.Parent.A) TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_shouldnt_warn_on_valid_buffer_use") { + ScopedFastFlag flags[] = { + {FFlag::LuauAttributeSyntax, true}, + }; + loadDefinition(R"( declare buffer: { create: @checked (size: number) -> buffer, diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 6b4bcf22..8b2cc6ba 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeLengthLimit); LUAU_FASTINT(LuauParseErrorLimit); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauAttributeSyntax); LUAU_FASTFLAG(LuauLeadingBarAndAmpersand); namespace @@ -3051,9 +3052,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_top_level_checked_fn") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; std::string src = 
R"BUILTIN_SRC( -declare function @checked abs(n: number): number +@checked declare function abs(n: number): number )BUILTIN_SRC"; ParseResult pr = tryParse(src, opts); @@ -3063,13 +3065,14 @@ declare function @checked abs(n: number): number AstStat* root = *(pr.root->body.data); auto func = root->as(); LUAU_ASSERT(func); - LUAU_ASSERT(func->checkedFunction); + LUAU_ASSERT(func->isCheckedFunction()); } TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; const std::string src = R"BUILTIN_SRC( declare math : { @@ -3090,13 +3093,14 @@ TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member") auto prop = *tbl->props.data; auto func = prop.type->as(); LUAU_ASSERT(func); - LUAU_ASSERT(func->checkedFunction); + LUAU_ASSERT(func->isCheckedFunction()); } TEST_CASE_FIXTURE(Fixture, "parse_checked_outside_decl_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; ParseResult pr = tryParse(R"( local @checked = 3 @@ -3110,10 +3114,11 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_in_and_out_of_decl_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; auto pr = tryParse(R"( local @checked = 3 - declare function @checked abs(n: number): number + @checked declare function abs(n: number): number )", opts); LUAU_ASSERT(pr.errors.size() == 2); @@ -3125,9 +3130,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_as_function_name_fails") { ParseOptions opts; opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; auto pr = tryParse(R"( - function @checked(x: number) : number + @checked function(x: number) : number end )", opts); @@ -3138,6 +3144,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name") { ParseOptions opts; 
opts.allowDeclarationSyntax = true; + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; auto pr = tryParse(R"( local @blah = 3 @@ -3190,4 +3197,300 @@ TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } +void checkAttribute(const AstAttr* attr, const AstAttr::Type type, const Location& location) +{ + CHECK_EQ(attr->type, type); + CHECK_EQ(attr->location, location); +} + +void checkFirstErrorForAttributes(const std::vector& errors, const size_t minSize, const Location& location, const std::string& message) +{ + LUAU_ASSERT(minSize >= 1); + + CHECK_GE(errors.size(), minSize); + CHECK_EQ(errors[0].getLocation(), location); + CHECK_EQ(errors[0].getMessage(), message); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_stat") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + AstStatBlock* stat = parse(R"( +@checked +function hello(x, y) + return x + y +end)"); + + LUAU_ASSERT(stat != nullptr); + + AstStatFunction* statFun = stat->body.data[0]->as(); + LUAU_ASSERT(statFun != nullptr); + + AstArray attributes = statFun->func->attributes; + + CHECK_EQ(attributes.size, 1); + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_local_function_stat") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + AstStatBlock* stat = parse(R"( + @checked +local function hello(x, y) + return x + y +end)"); + + LUAU_ASSERT(stat != nullptr); + + AstStatLocalFunction* statFun = stat->body.data[0]->as(); + LUAU_ASSERT(statFun != nullptr); + + AstArray attributes = statFun->func->attributes; + + CHECK_EQ(attributes.size, 1); + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 4), Position(1, 12))); +} + 
+TEST_CASE_FIXTURE(Fixture, "empty_attribute_name_is_not_allowed") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult result = tryParse(R"( +@ +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(1, 0), Position(1, 1)), "Attribute name is missing"); +} + +TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_stat") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult pr1 = tryParse(R"( +@checked +if a<0 then a = 0 end)"); + checkFirstErrorForAttributes(pr1.errors, 1, Location(Position(2, 0), Position(2, 2)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'if' intead"); + + ParseResult pr2 = tryParse(R"( +local i = 1 +@checked +while a[i] do + print(a[i]) + i = i + 1 +end)"); + checkFirstErrorForAttributes(pr2.errors, 1, Location(Position(3, 0), Position(3, 5)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'while' intead"); + + ParseResult pr3 = tryParse(R"( +@checked +do + local a2 = 2*a + local d = sqrt(b^2 - 4*a*c) + x1 = (-b + d)/a2 + x2 = (-b - d)/a2 +end)"); + checkFirstErrorForAttributes(pr3.errors, 1, Location(Position(2, 0), Position(2, 2)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'do' intead"); + + ParseResult pr4 = tryParse(R"( +@checked +for i=1,10 do print(i) end +)"); + checkFirstErrorForAttributes(pr4.errors, 1, Location(Position(2, 0), Position(2, 3)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'for' intead"); + + ParseResult pr5 = tryParse(R"( +@checked +repeat + line = io.read() +until line ~= "" +)"); + checkFirstErrorForAttributes(pr5.errors, 1, Location(Position(2, 0), Position(2, 6)), + "Expected 'function', 
'local function', 'declare function' or a function type declaration after attribute, but got 'repeat' intead"); + + + ParseResult pr6 = tryParse(R"( +@checked +local x = 10 +)"); + checkFirstErrorForAttributes( + pr6.errors, 1, Location(Position(2, 6), Position(2, 7)), "Expected 'function' after local declaration with attribute, but got 'x' intead"); + + ParseResult pr7 = tryParse(R"( +local i = 1 +while a[i] do + if a[i] == v then @checked break end + i = i + 1 +end +)"); + checkFirstErrorForAttributes(pr7.errors, 1, Location(Position(3, 31), Position(3, 36)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'break' intead"); + + + ParseResult pr8 = tryParse(R"( +function foo1 () @checked return 'a' end +)"); + checkFirstErrorForAttributes(pr8.errors, 1, Location(Position(1, 26), Position(1, 32)), + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'return' intead"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_type_declaration") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + std::string src = R"( +@checked declare function abs(n: number): number +)"; + + ParseResult pr = tryParse(src, opts); + CHECK_EQ(pr.errors.size(), 0); + + LUAU_ASSERT(pr.root->body.size == 1); + + AstStat* root = *(pr.root->body.data); + + auto func = root->as(); + LUAU_ASSERT(func != nullptr); + + CHECK(func->isCheckedFunction()); + + AstArray attributes = func->attributes; + + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); +} + +TEST_CASE_FIXTURE(Fixture, "parse_attributes_on_function_type_declaration_in_table") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + std::string src = R"( +declare bit32: { + band: 
@checked (...number) -> number +})"; + + ParseResult pr = tryParse(src, opts); + CHECK_EQ(pr.errors.size(), 0); + + LUAU_ASSERT(pr.root->body.size == 1); + + AstStat* root = *(pr.root->body.data); + + AstStatDeclareGlobal* glob = root->as(); + LUAU_ASSERT(glob); + + auto tbl = glob->type->as(); + LUAU_ASSERT(tbl); + + LUAU_ASSERT(tbl->props.size == 1); + AstTableProp prop = tbl->props.data[0]; + + AstTypeFunction* func = prop.type->as(); + LUAU_ASSERT(func); + + AstArray attributes = func->attributes; + + CHECK_EQ(attributes.size, 1); + checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(2, 10), Position(2, 18))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_type_declarations") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseOptions opts; + opts.allowDeclarationSyntax = true; + + ParseResult pr1 = tryParse(R"( +@checked declare foo: number + )", + opts); + + checkFirstErrorForAttributes( + pr1.errors, 1, Location(Position(1, 17), Position(1, 20)), "Expected a function type declaration after attribute, but got 'foo' intead"); + + ParseResult pr2 = tryParse(R"( +@checked declare class Foo + prop: number + function method(self, foo: number): string +end)", + opts); + + checkFirstErrorForAttributes( + pr2.errors, 1, Location(Position(1, 17), Position(1, 22)), "Expected a function type declaration after attribute, but got 'class' intead"); + + ParseResult pr3 = tryParse(R"( +declare bit32: { + band: @checked number +})", + opts); + + checkFirstErrorForAttributes( + pr3.errors, 1, Location(Position(2, 19), Position(2, 25)), "Expected '(' when parsing function parameters, got 'number'"); +} + +TEST_CASE_FIXTURE(Fixture, "attributes_cannot_be_duplicated") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult result = tryParse(R"( +@checked + @checked +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, 
Location(Position(2, 4), Position(2, 12)), "Cannot duplicate attribute '@checked'"); +} + +TEST_CASE_FIXTURE(Fixture, "unsupported_attributes_are_not_allowed") +{ + ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; + + ParseResult result = tryParse(R"( +@checked + @cool_attribute +function hello(x, y) + return x + y +end)"); + + checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 19)), "Invalid attribute '@cool_attribute'"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + parse(R"(type A = | "Hello" | "World")"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + parse(R"(type A = & { string } & { number })"); +} + +TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + + matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); + matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); +} + + TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index e2b3f9b7..d7cb225a 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -13,6 +13,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauSharedSelf); +LUAU_FASTFLAG(LuauAttributeSyntax); TEST_SUITE_BEGIN("ToString"); @@ -1010,10 +1011,11 @@ TEST_CASE_FIXTURE(Fixture, "checked_fn_toString") { ScopedFastFlag flags[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; auto _result = loadDefinition(R"( 
-declare function @checked abs(n: number) : number +@checked declare function abs(n: number) : number )"); auto result = check(Mode::Nonstrict, R"( diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 1a7ef973..ce6988aa 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -701,7 +701,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") REQUIRE(ut); REQUIRE(ut->options.size() == 2); - CHECK_EQ(builtinTypes->nilType, ut->options[0]); + CHECK_EQ(builtinTypes->nilType, follow(ut->options[0])); CHECK_EQ(*builtinTypes->numberType, *ut->options[1]); } else @@ -1179,4 +1179,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_preserves_error_suppression") CHECK("any" == toString(requireTypeAtPosition({3, 25}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "tryDispatchIterableFunction_under_constrained_loop_should_not_assert") +{ + CheckResult result = check(R"( +local function foo(Instance) + for _, Child in next, Instance:GetChildren() do + end +end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2c6136a4..516a761b 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3153,7 +3153,7 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Value of type '{ x: number? }?' could be nil", toString(result.errors[0])); + CHECK_EQ("Type 'nil' does not have key 'x'", toString(result.errors[0])); else CHECK_EQ("Value of type '{| x: number? |}?' 
could be nil", toString(result.errors[0])); CHECK_EQ("boolean", toString(requireType("u"))); @@ -4439,7 +4439,13 @@ TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + } + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer") diff --git a/tests/conformance/native_userdata.lua b/tests/conformance/native_userdata.lua new file mode 100644 index 00000000..b1b2a103 --- /dev/null +++ b/tests/conformance/native_userdata.lua @@ -0,0 +1,42 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing userdata') + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + +local function realmad(a: vec2, b: vec2, c: vec2): vec2 + return -c + a * b; +end + +local function dm(s: vec2, t: vec2, u: vec2) + local x = s:Dot(t) + assert(x == 13) + + local t = u:Min(s) + assert(t.X == 5) + assert(t.Y == 4) +end + +local s: vec2 = vec2(5, 4) +local t: vec2 = vec2(1, 2) +local u: vec2 = vec2(10, 20) + +local x: vec2 = realmad(s, t, u) + +assert(x.X == -5) +assert(x.Y == -12) + +dm(s, t, u) + +local function mu(v: vec2) + assert(v.Magnitude == 2) + assert(v.Unit.X == 0) + assert(v.Unit.Y == 1) +end + +mu(vec2(0, 2)) + +return 'OK' diff --git a/tools/faillist.txt b/tools/faillist.txt index 7a214a32..b2677bf4 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -4,7 +4,6 @@ AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg AutocompleteTest.autocomplete_string_singletons AutocompleteTest.do_wrong_compatible_nonself_calls AutocompleteTest.string_singleton_as_table_key -AutocompleteTest.string_singleton_in_if_statement2 AutocompleteTest.suggest_table_keys 
AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_in_table @@ -33,6 +32,15 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument2 BuiltinTests.table_freeze_is_generic BuiltinTests.tonumber_returns_optional_number_type +ControlFlowAnalysis.for_record_do_if_not_x_break +ControlFlowAnalysis.for_record_do_if_not_x_continue +ControlFlowAnalysis.if_not_x_break_elif_not_y_break +ControlFlowAnalysis.if_not_x_break_elif_not_y_continue +ControlFlowAnalysis.if_not_x_break_elif_rand_break_elif_not_y_break +ControlFlowAnalysis.if_not_x_continue_elif_not_y_continue +ControlFlowAnalysis.if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough +ControlFlowAnalysis.if_not_x_continue_elif_rand_continue_elif_not_y_continue +ControlFlowAnalysis.if_not_x_return_elif_not_y_break DefinitionTests.class_definition_overload_metamethods Differ.metatable_metamissing_left Differ.metatable_metamissing_right @@ -46,7 +54,6 @@ FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.better_mismatch_error_messages GenericsTests.bound_tables_do_not_clone_original_fields -GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_always_instantiate_generic_intersection_types GenericsTests.do_not_infer_generic_functions GenericsTests.dont_substitute_bound_types @@ -135,6 +142,7 @@ RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x RefinementTest.globals_can_be_narrowed_too RefinementTest.isa_type_refinement_must_be_known_ahead_of_time +RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage RefinementTest.refine_a_property_of_some_global @@ -278,7 +286,9 @@ TypeInferAnyError.can_subscript_any TypeInferAnyError.for_in_loop_iterator_is_any 
TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferAnyError.for_in_loop_iterator_is_any_pack +TypeInferAnyError.for_in_loop_iterator_returns_any TypeInferAnyError.for_in_loop_iterator_returns_any2 +TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferClasses.callable_classes TypeInferClasses.cannot_unify_class_instance_with_primitive TypeInferClasses.class_type_mismatch_with_name_conflict @@ -337,6 +347,7 @@ TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function +TypeInferFunctions.unifier_should_not_bind_free_types TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.dcr_xpath_candidates @@ -363,7 +374,6 @@ TypeInferModules.require TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory -TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.promise_type_error_too_complex TypeInferOperators.add_type_family_works TypeInferOperators.cli_38355_recursive_union From 7d4033071abebe09971b410d362c00ffb3084afb Mon Sep 17 00:00:00 2001 From: Vighnesh-V Date: Fri, 14 Jun 2024 13:21:20 -0700 Subject: [PATCH 18/20] Sync to upstream/release/630 (#1295) ### What's new * A bug in exception handling in GCC(11/12/13) on MacOS prevents our test suite from running. * Parser now supports leading `|` or `&` when declaring `Union` and `Intersection` types (#1286) * We now support parsing of attributes on functions as described in the [rfc](https://github.com/luau-lang/rfcs/pull/30) * With this change, expressions such as `local x = @native function(x) return x+1 end` and `f(@native function(x) return x+1 end)` are now valid. 
* Added support for `@native` attribute - we can now force native compilation of individual functions if the `@native` attribute is specified before the `function` keyword (works for lambdas too). ### New Solver * Many fixes in the new solver for crashes and instability * Refinements now use simplification and not normalization in a specific case of two tables * Assume that compound assignments do not change the type of the left-side operand * Fix error that prevented Class Methods from being overloaded ### VM * Updated description of Garbage Collector invariant --- ### Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Aviral Goel Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --------- Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Aviral Goel Co-authored-by: David Cope Co-authored-by: Lily Brown Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/ConstraintGenerator.h | 2 +- Analysis/include/Luau/ConstraintSolver.h | 56 +-- Analysis/include/Luau/Generalization.h | 2 +- Analysis/include/Luau/Module.h | 6 + Analysis/include/Luau/Type.h | 2 +- Analysis/include/Luau/TypeFamily.h | 2 + Analysis/src/Constraint.cpp | 4 + Analysis/src/ConstraintGenerator.cpp | 74 ++-- Analysis/src/ConstraintSolver.cpp | 401 +++++++++----------- Analysis/src/Error.cpp | 13 + Analysis/src/Frontend.cpp | 1 + Analysis/src/Generalization.cpp | 397 ++++++++++++++++++- Analysis/src/Normalize.cpp | 21 +- Analysis/src/TableLiteralInference.cpp | 4 +- Analysis/src/TypeChecker2.cpp | 10 +- Analysis/src/TypeFamily.cpp | 217 +++++++++++ Analysis/src/TypeInfer.cpp | 7 +- Ast/include/Luau/Ast.h | 3 + Ast/include/Luau/Parser.h | 2 +- Ast/src/Ast.cpp | 13 + Ast/src/Parser.cpp | 37 +- CodeGen/include/Luau/IrBuilder.h | 1 + CodeGen/include/Luau/IrData.h | 6 +- CodeGen/include/Luau/IrVisitUseDef.h | 1 + CodeGen/src/BytecodeSummary.cpp | 7 +- 
CodeGen/src/CodeGenAssembly.cpp | 6 +- CodeGen/src/CodeGenContext.cpp | 8 +- CodeGen/src/CodeGenLower.h | 34 +- CodeGen/src/IrAnalysis.cpp | 20 + CodeGen/src/IrBuilder.cpp | 39 +- CodeGen/src/IrDump.cpp | 9 + CodeGen/src/IrRegAllocA64.cpp | 4 + CodeGen/src/IrRegAllocX64.cpp | 5 + CodeGen/src/IrUtils.cpp | 23 ++ CodeGen/src/IrValueLocationTracking.cpp | 1 + Common/include/Luau/Bytecode.h | 2 + Compiler/src/Compiler.cpp | 55 ++- tests/Conformance.test.cpp | 55 +++ tests/Generalization.test.cpp | 77 +++- tests/Normalize.test.cpp | 5 - tests/Parser.test.cpp | 70 +++- tests/ToString.test.cpp | 50 +-- tests/TypeFamily.test.cpp | 228 ++++++++++- tests/TypeInfer.aliases.test.cpp | 3 - tests/TypeInfer.loops.test.cpp | 8 +- tests/TypeInfer.operators.test.cpp | 41 +- tests/TypeInfer.singletons.test.cpp | 8 +- tests/TypeInfer.tables.test.cpp | 26 ++ tests/TypeInfer.test.cpp | 59 +++ tools/faillist.txt | 19 +- 50 files changed, 1699 insertions(+), 445 deletions(-) diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index b540b82f..28cfb5aa 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -118,7 +118,7 @@ struct ConstraintGenerator std::function prepareModuleScope; std::vector requireCycles; - DenseHashMap> localTypes{nullptr}; + DenseHashMap localTypes{nullptr}; DcrLogger* logger; diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 902dd15d..6e62a2e3 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -94,6 +94,10 @@ struct ConstraintSolver // Irreducible/uninhabited type families or type pack families. DenseHashSet uninhabitedTypeFamilies{{}}; + // The set of types that will definitely be unchanged by generalization. 
+ DenseHashSet generalizedTypes_{nullptr}; + const NotNull> generalizedTypes{&generalizedTypes_}; + // Recorded errors that take place within the solver. ErrorVec errors; @@ -103,6 +107,8 @@ struct ConstraintSolver DcrLogger* logger; TypeCheckLimits limits; + DenseHashMap typeFamiliesToFinalize{nullptr}; + explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger, TypeCheckLimits limits); @@ -116,8 +122,35 @@ struct ConstraintSolver **/ void run(); + + /** + * Attempts to perform one final reduction on type families after every constraint has been completed + * + **/ + void finalizeTypeFamilies(); + bool isDone(); +private: + /** + * Bind a type variable to another type. + * + * A constraint is required and will validate that blockedTy is owned by this + * constraint. This prevents one constraint from interfering with another's + * blocked types. + * + * Bind will also unblock the type variable for you. + */ + void bind(NotNull constraint, TypeId ty, TypeId boundTo); + void bind(NotNull constraint, TypePackId tp, TypePackId boundTo); + + template + void emplace(NotNull constraint, TypeId ty, Args&&... args); + + template + void emplace(NotNull constraint, TypePackId tp, Args&&... args); + +public: /** Attempt to dispatch a constraint. Returns true if it was successful. If * tryDispatch() returns false, the constraint remains in the unsolved set * and will be retried later. 
@@ -135,19 +168,14 @@ struct ConstraintSolver bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); + bool tryDispatchHasIndexer( int& recursionDepth, NotNull constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set& seen); bool tryDispatch(const HasIndexerConstraint& c, NotNull constraint); - std::pair> tryDispatchSetIndexer( - NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds); - bool tryDispatch(const AssignPropConstraint& c, NotNull constraint); bool tryDispatch(const AssignIndexConstraint& c, NotNull constraint); - - bool tryDispatchUnpack1(NotNull constraint, TypeId resultType, TypeId sourceType); bool tryDispatch(const UnpackConstraint& c, NotNull constraint); - bool tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); bool tryDispatch(const EqualityConstraint& c, NotNull constraint, bool force); @@ -298,22 +326,6 @@ struct ConstraintSolver template bool unify(NotNull constraint, TID subTy, TID superTy); -private: - /** - * Bind a BlockedType to another type while taking care not to bind it to - * itself in the case that resultTy == blockedTy. This can happen if we - * have a tautological constraint. When it does, we must instead bind - * blockedTy to a fresh type belonging to an appropriate scope. - * - * To determine which scope is appropriate, we also accept rootTy, which is - * to be the type that contains blockedTy. - * - * A constraint is required and will validate that blockedTy is owned by this - * constraint. This prevents one constraint from interfering with another's - * blocked types. - */ - void bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull constraint); - /** * Marks a constraint as being blocked on a type or type pack. 
The constraint * solver will not attempt to dispatch blocked constraints until their diff --git a/Analysis/include/Luau/Generalization.h b/Analysis/include/Luau/Generalization.h index bf196f3e..44d0db67 100644 --- a/Analysis/include/Luau/Generalization.h +++ b/Analysis/include/Luau/Generalization.h @@ -8,6 +8,6 @@ namespace Luau { -std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, TypeId ty); +std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> bakedTypes, TypeId ty); } diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 197c7f9c..152d8c65 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -102,6 +102,12 @@ struct Module DenseHashMap astResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; + // The computed result type of a compound assignment. (eg foo += 1) + // + // Type checking uses this to check that the result of such an operation is + // actually compatible with the left-side operand. + DenseHashMap astCompoundAssignResultTypes{nullptr}; + DenseHashMap>> upperBoundContributors{nullptr}; // Map AST nodes to the scope they create. 
Cannot be NotNull because diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 6105ede3..881dc646 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -658,7 +658,7 @@ struct NegationType using ErrorType = Unifiable::Error; using TypeVariant = - Unifiable::Variant; struct Type final diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index 5b72a370..fa23a6ba 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -179,6 +179,8 @@ struct BuiltinTypeFamilies TypeFamily keyofFamily; TypeFamily rawkeyofFamily; + TypeFamily indexFamily; + void addToScope(NotNull arena, NotNull scope) const; }; diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 7b3377cb..a62879fa 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -93,6 +93,10 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const { rci.traverse(taec->target); } + else if (auto fchc = get(*this)) + { + rci.traverse(fchc->argsPack); + } else if (auto ptc = get(*this)) { rci.traverse(ptc->freeType); diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 9d825408..b784f4aa 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -253,7 +253,11 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) // FIXME: This isn't the most efficient thing. 
TypeId domainTy = builtinTypes->neverType; for (TypeId d : domain) + { + if (d == ty) + continue; domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + } LUAU_ASSERT(get(ty)); asMutable(ty)->ty.emplace(domainTy); @@ -323,7 +327,7 @@ std::optional ConstraintGenerator::lookup(const ScopePtr& scope, Locatio if (!ty) { ty = arena->addType(BlockedType{}); - localTypes[*ty] = {}; + localTypes.try_insert(*ty, {}); rootScope->lvalueTypes[operand] = *ty; } @@ -717,7 +721,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat const Location location = local->location; TypeId assignee = arena->addType(BlockedType{}); - localTypes[assignee] = {}; + localTypes.try_insert(assignee, {}); assignees.push_back(assignee); @@ -756,9 +760,9 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat for (size_t i = 0; i < statLocal->vars.size; ++i) { LUAU_ASSERT(get(assignees[i])); - std::vector* localDomain = localTypes.find(assignees[i]); + TypeIds* localDomain = localTypes.find(assignees[i]); LUAU_ASSERT(localDomain); - localDomain->push_back(annotatedTypes[i]); + localDomain->insert(annotatedTypes[i]); } TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); @@ -790,9 +794,9 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat for (size_t i = 0; i < statLocal->vars.size; ++i) { LUAU_ASSERT(get(assignees[i])); - std::vector* localDomain = localTypes.find(assignees[i]); + TypeIds* localDomain = localTypes.find(assignees[i]); LUAU_ASSERT(localDomain); - localDomain->push_back(valueTypes[i]); + localDomain->insert(valueTypes[i]); } } @@ -898,7 +902,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI variableTypes.push_back(assignee); TypeId loopVar = arena->addType(BlockedType{}); - localTypes[loopVar].push_back(assignee); + localTypes[loopVar].insert(assignee); if (var->annotation) { @@ -1183,8 +1187,13 @@ ControlFlow 
ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAss { AstExprBinary binop = AstExprBinary{assign->location, assign->op, assign->var, assign->value}; TypeId resultTy = check(scope, &binop).ty; + module->astCompoundAssignResultTypes[assign] = resultTy; - visitLValue(scope, assign->var, resultTy); + TypeId lhsType = check(scope, assign->var).ty; + visitLValue(scope, assign->var, lhsType); + + follow(lhsType); + follow(resultTy); return ControlFlow::None; } @@ -1383,16 +1392,15 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas } } - if (ctv->props.count(propName) == 0) + TableType::Props& props = assignToMetatable ? metatable->props : ctv->props; + + if (props.count(propName) == 0) { - if (assignToMetatable) - metatable->props[propName] = {propTy}; - else - ctv->props[propName] = {propTy}; + props[propName] = {propTy}; } else { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type() : ctv->props[propName].type(); + TypeId currentTy = props[propName].type(); // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. @@ -1402,19 +1410,13 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas options.push_back(propTy); TypeId newItv = arena->addType(IntersectionType{std::move(options)}); - if (assignToMetatable) - metatable->props[propName] = {newItv}; - else - ctv->props[propName] = {newItv}; + props[propName] = {newItv}; } else if (get(currentTy)) { TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); - if (assignToMetatable) - metatable->props[propName] = {intersection}; - else - ctv->props[propName] = {intersection}; + props[propName] = {intersection}; } else { @@ -1913,8 +1915,8 @@ Inference ConstraintGenerator::checkIndexName( // the current lexical position within the script. 
if (!tt) { - if (auto localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) - tt = getTableType(localDomain->front()); + if (TypeIds* localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) + tt = getTableType(*localDomain->begin()); } if (tt) @@ -2327,14 +2329,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (ty) { - std::vector* localDomain = localTypes.find(*ty); + TypeIds* localDomain = localTypes.find(*ty); if (localDomain) - localDomain->push_back(rhsType); + localDomain->insert(rhsType); } else { ty = arena->addType(BlockedType{}); - localTypes[*ty].push_back(rhsType); + localTypes[*ty].insert(rhsType); if (annotatedTy) { @@ -2359,8 +2361,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local if (annotatedTy) addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy}); - if (auto localDomain = localTypes.find(*ty)) - localDomain->push_back(rhsType); + if (TypeIds* localDomain = localTypes.find(*ty)) + localDomain->insert(rhsType); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) @@ -2383,7 +2385,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* e bool incremented = recordPropertyAssignment(lhsTy); - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented}); + auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented}); + getMutable(propTy)->setOwner(apc); } void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) @@ -2398,7 +2401,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e bool incremented = recordPropertyAssignment(lhsTy); - addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, 
incremented}); + auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, incremented}); + getMutable(propTy)->setOwner(apc); return; } @@ -2407,7 +2411,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e TypeId indexTy = check(scope, expr->index).ty; TypeId propTy = arena->addType(BlockedType{}); module->astTypes[expr] = propTy; - addConstraint(scope, expr->location, AssignIndexConstraint{lhsTy, indexTy, rhsType, propTy}); + auto aic = addConstraint(scope, expr->location, AssignIndexConstraint{lhsTy, indexTy, rhsType, propTy}); + getMutable(propTy)->setOwner(aic); } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) @@ -2447,7 +2452,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, if (AstExprConstantString* key = item.key->as()) { - ttv->props[key->value.begin()] = {itemTy}; + std::string propName{key->value.data, key->value.size}; + ttv->props[propName] = {itemTy}; } else { @@ -3187,7 +3193,7 @@ bool ConstraintGenerator::recordPropertyAssignment(TypeId ty) } else if (auto mt = get(t)) queue.push_back(mt->table); - else if (auto localDomain = localTypes.find(t)) + else if (TypeIds* localDomain = localTypes.find(t)) { for (TypeId domainTy : *localDomain) queue.push_back(domainTy); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 07fc26fb..e59bc8a7 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1,9 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ConstraintSolver.h" #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" #include "Luau/Common.h" -#include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" #include "Luau/Generalization.h" #include "Luau/Instantiation.h" @@ -22,8 +22,8 @@ #include 
"Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" -#include "Luau/VecDeque.h" #include "Luau/VisitType.h" + #include #include @@ -67,7 +67,11 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const [[maybe_unused]] static bool canMutate(TypeId ty, NotNull constraint) { if (auto blocked = get(ty)) - return blocked->getOwner() == constraint; + { + Constraint* owner = blocked->getOwner(); + LUAU_ASSERT(owner); + return owner == constraint; + } return true; } @@ -76,7 +80,11 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const [[maybe_unused]] static bool canMutate(TypePackId tp, NotNull constraint) { if (auto blocked = get(tp)) - return blocked->owner == nullptr || blocked->owner == constraint; + { + Constraint* owner = blocked->owner; + LUAU_ASSERT(owner); + return owner == constraint; + } return true; } @@ -478,6 +486,12 @@ void ConstraintSolver::run() progress |= runSolverPass(true); } while (progress); + // After we have run all the constraints, type families should be generalized + // At this point, we can try to perform one final simplification to suss out + // whether type families are truly uninhabited or if they can reduce + + finalizeTypeFamilies(); + if (FFlag::DebugLuauLogSolver || FFlag::DebugLuauLogBindings) dumpBindings(rootScope, opts); @@ -487,6 +501,25 @@ void ConstraintSolver::run() } } +void ConstraintSolver::finalizeTypeFamilies() +{ + // At this point, we've generalized. 
Let's try to finish reducing as much as we can, we'll leave warning to the typechecker + for (auto [t, constraint] : typeFamiliesToFinalize) + { + TypeId ty = follow(t); + if (get(ty)) + { + FamilyGraphReductionResult result = + reduceFamilies(t, constraint->location, TypeFamilyContext{NotNull{this}, constraint->scope, NotNull{constraint}}, true); + + for (TypeId r : result.reducedTypes) + unblock(r, constraint->location); + for (TypePackId r : result.reducedPacks) + unblock(r, constraint->location); + } + } +} + bool ConstraintSolver::isDone() { return unsolvedConstraints.empty(); @@ -503,6 +536,56 @@ struct TypeAndLocation } // namespace +void ConstraintSolver::bind(NotNull constraint, TypeId ty, TypeId boundTo) +{ + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(canMutate(ty, constraint)); + + boundTo = follow(boundTo); + if (get(ty) && ty == boundTo) + return emplace(constraint, ty, constraint->scope, builtinTypes->neverType, builtinTypes->unknownType); + + shiftReferences(ty, boundTo); + emplaceType(asMutable(ty), boundTo); + unblock(ty, constraint->location); +} + +void ConstraintSolver::bind(NotNull constraint, TypePackId tp, TypePackId boundTo) +{ + LUAU_ASSERT(get(tp) || get(tp)); + LUAU_ASSERT(canMutate(tp, constraint)); + + boundTo = follow(boundTo); + LUAU_ASSERT(tp != boundTo); + + emplaceTypePack(asMutable(tp), boundTo); + unblock(tp, constraint->location); +} + +template +void ConstraintSolver::emplace(NotNull constraint, TypeId ty, Args&&... args) +{ + static_assert(!std::is_same_v, "cannot use `emplace`! use `bind`"); + + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(canMutate(ty, constraint)); + + emplaceType(asMutable(ty), std::forward(args)...); + unblock(ty, constraint->location); +} + +template +void ConstraintSolver::emplace(NotNull constraint, TypePackId tp, Args&&... args) +{ + static_assert(!std::is_same_v, "cannot use `emplace`! 
use `bind`"); + + LUAU_ASSERT(get(tp) || get(tp)); + LUAU_ASSERT(canMutate(tp, constraint)); + + emplaceTypePack(asMutable(tp), std::forward(args)...); + unblock(tp, constraint->location); +} + bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) { if (!force && isBlocked(constraint)) @@ -547,9 +630,6 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo else LUAU_ASSERT(false); - if (success) - unblock(constraint); - return success; } @@ -588,7 +668,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull generalized; - std::optional generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, c.sourceType); + std::optional generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, c.sourceType); if (generalizedTy) generalized = QuantifierResult{*generalizedTy}; // FIXME insertedGenerics and insertedGenericPacks else @@ -597,7 +677,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull(generalizedType)) - bindBlockedType(generalizedType, generalized->result, c.sourceType, constraint); + bind(constraint, generalizedType, generalized->result); else unify(constraint, generalizedType, generalized->result); @@ -610,17 +690,11 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNulllocation); - emplaceType(asMutable(c.generalizedType), builtinTypes->errorType); + bind(constraint, c.generalizedType, builtinTypes->errorRecoveryType()); } - unblock(c.generalizedType, constraint->location); - unblock(c.sourceType, constraint->location); - for (TypeId ty : c.interiorTypes) - { - generalize(NotNull{arena}, builtinTypes, constraint->scope, ty); - unblock(ty, constraint->location); - } + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty); return true; } @@ -710,18 +784,18 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullnilType, builtinTypes->nilType, constraint); + 
bind(constraint, *it, builtinTypes->nilType); ++it; } @@ -813,15 +887,14 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul const PendingExpansionType* petv = get(follow(c.target)); if (!petv) { - unblock(c.target, constraint->location); + unblock(c.target, constraint->location); // TODO: do we need this? any re-entrancy? return true; } auto bindResult = [this, &c, constraint](TypeId result) { LUAU_ASSERT(get(c.target)); shiftReferences(c.target, result); - emplaceType(asMutable(c.target), result); - unblock(c.target, constraint->location); + bind(constraint, c.target, result); }; std::optional tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value) @@ -1009,19 +1082,23 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) + { + emplaceTypePack(asMutable(c.result), builtinTypes->anyTypePack); + unblock(c.result, constraint->location); + return true; + } + // if we're calling an error type, the result is an error type, and that's that. 
if (get(fn)) { - emplaceTypePack(asMutable(c.result), builtinTypes->errorTypePack); - unblock(c.result, constraint->location); - + bind(constraint, c.result, builtinTypes->errorRecoveryTypePack()); return true; } if (get(fn)) { - emplaceTypePack(asMutable(c.result), builtinTypes->neverTypePack); - unblock(c.result, constraint->location); + bind(constraint, c.result, builtinTypes->neverTypePack); return true; } @@ -1078,7 +1155,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdTypePack(TypePack{std::move(argsHead), argsTail}); fn = follow(*callMm); - emplaceTypePack(asMutable(c.result), constraint->scope); + emplace(constraint, c.result, constraint->scope); } else { @@ -1095,14 +1172,21 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(asMutable(c.result), constraint->scope); + emplace(constraint, c.result, constraint->scope); } for (std::optional ty : c.discriminantTypes) { - if (!ty || !isBlocked(*ty)) + if (!ty) continue; + // If the discriminant type has been transmuted, we need to unblock them. + if (!isBlocked(*ty)) + { + unblock(*ty, constraint->location); + continue; + } + // We use `any` here because the discriminant type may be pointed at by both branches, // where the discriminant type is not negated, and the other where it is negated, i.e. 
// `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` @@ -1110,7 +1194,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType}; + emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); } OverloadResolver resolver{ @@ -1120,7 +1204,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}}; @@ -1150,12 +1233,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulllocation); - InstantiationQueuer queuer{constraint->scope, constraint->location, this}; queuer.traverse(overloadToUse); queuer.traverse(inferredTy); + unblock(c.result, constraint->location); + return true; } @@ -1250,7 +1333,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullargs.data[j]->annotation && get(follow(lambdaArgTys[j]))) { shiftReferences(lambdaArgTys[j], expectedLambdaArgTys[j]); - emplaceType(asMutable(lambdaArgTys[j]), expectedLambdaArgTys[j]); + bind(constraint, lambdaArgTys[j], expectedLambdaArgTys[j]); } } } @@ -1303,7 +1386,7 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNulllowerBound; shiftReferences(c.freeType, bindTo); - emplaceType(asMutable(c.freeType), bindTo); + bind(constraint, c.freeType, bindTo); return true; } @@ -1336,8 +1419,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullanyType), c.subjectType, constraint); - unblock(resultType, constraint->location); + bind(constraint, resultType, result.value_or(builtinTypes->anyType)); return true; } @@ -1361,12 +1443,12 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (auto tbl = get(follow(ft->upperBound)); tbl && tbl->indexer) { unify(constraint, indexType, tbl->indexer->indexType); - bindBlockedType(resultType, tbl->indexer->indexResultType, subjectType, 
constraint); + bind(constraint, resultType, tbl->indexer->indexResultType); return true; } FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType}; - emplaceType(asMutable(resultType), freeResult); + emplace(constraint, resultType, freeResult); TypeId upperBound = arena->addType(TableType{/* props */ {}, TableIndexer{indexType, resultType}, TypeLevel{}, TableState::Unsealed}); @@ -1380,7 +1462,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( { unify(constraint, indexType, indexer->indexType); - bindBlockedType(resultType, indexer->indexResultType, subjectType, constraint); + bind(constraint, resultType, indexer->indexResultType); return true; } else if (tt->state == TableState::Unsealed) @@ -1388,7 +1470,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( // FIXME this is greedy. FreeType freeResult{tt->scope, builtinTypes->neverType, builtinTypes->unknownType}; - emplaceType(asMutable(resultType), freeResult); + emplace(constraint, resultType, freeResult); tt->indexer = TableIndexer{indexType, resultType}; return true; @@ -1401,12 +1483,12 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (auto indexer = ct->indexer) { unify(constraint, indexType, indexer->indexType); - bindBlockedType(resultType, indexer->indexResultType, subjectType, constraint); + bind(constraint, resultType, indexer->indexResultType); return true; } else if (isString(indexType)) { - bindBlockedType(resultType, builtinTypes->unknownType, subjectType, constraint); + bind(constraint, resultType, builtinTypes->unknownType); return true; } } @@ -1441,11 +1523,11 @@ bool ConstraintSolver::tryDispatchHasIndexer( } if (0 == results.size()) - bindBlockedType(resultType, builtinTypes->errorType, subjectType, constraint); + bind(constraint, resultType, builtinTypes->errorType); else if (1 == results.size()) - bindBlockedType(resultType, *results.begin(), subjectType, constraint); + bind(constraint, resultType, *results.begin()); else - emplaceType(asMutable(resultType), 
std::vector(results.begin(), results.end())); + emplace(constraint, resultType, std::vector(results.begin(), results.end())); return true; } @@ -1473,20 +1555,20 @@ bool ConstraintSolver::tryDispatchHasIndexer( } if (0 == results.size()) - emplaceType(asMutable(resultType), builtinTypes->errorType); + bind(constraint, resultType, builtinTypes->errorType); else if (1 == results.size()) { TypeId firstResult = *results.begin(); shiftReferences(resultType, firstResult); - emplaceType(asMutable(resultType), firstResult); + bind(constraint, resultType, firstResult); } else - emplaceType(asMutable(resultType), std::vector(results.begin(), results.end())); + emplace(constraint, resultType, std::vector(results.begin(), results.end())); return true; } - bindBlockedType(resultType, builtinTypes->errorType, subjectType, constraint); + bind(constraint, resultType, builtinTypes->errorType); return true; } @@ -1534,86 +1616,7 @@ bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull seen{nullptr}; - bool ok = tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); - if (ok) - unblock(c.resultType, constraint->location); - return ok; -} - -std::pair> ConstraintSolver::tryDispatchSetIndexer( - NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds) -{ - if (isBlocked(subjectType)) - return {block(subjectType, constraint), std::nullopt}; - - if (auto tt = getMutable(subjectType)) - { - if (tt->indexer) - { - if (isBlocked(tt->indexer->indexResultType)) - return {block(tt->indexer->indexResultType, constraint), std::nullopt}; - - unify(constraint, indexType, tt->indexer->indexType); - return {true, tt->indexer->indexResultType}; - } - else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) - { - TypeId resultTy = freshType(arena, builtinTypes, constraint->scope.get()); - tt->indexer = TableIndexer{indexType, resultTy}; - return {true, resultTy}; - } - } - else 
if (auto ft = getMutable(subjectType); ft && expandFreeTypeBounds) - { - // Setting an indexer on some fresh type means we use that fresh type in a negative position. - // Therefore, we only care about the upper bound. - // - // We'll extend the upper bound if we could dispatch, but could not find a table type to update the indexer. - auto [dispatched, resultTy] = tryDispatchSetIndexer(constraint, ft->upperBound, indexType, propType, /*expandFreeTypeBounds=*/false); - if (dispatched && !resultTy) - { - // Despite that we haven't found a table type, adding a table type causes us to have one that we can /now/ find. - resultTy = freshType(arena, builtinTypes, constraint->scope.get()); - - TypeId tableTy = arena->addType(TableType{TableState::Sealed, TypeLevel{}, constraint->scope.get()}); - TableType* tt2 = getMutable(tableTy); - tt2->indexer = TableIndexer{indexType, *resultTy}; - - ft->upperBound = - simplifyIntersection(builtinTypes, arena, ft->upperBound, tableTy).result; // TODO: intersect type family or a constraint. 
- } - - return {dispatched, resultTy}; - } - else if (auto it = get(subjectType)) - { - bool dispatched = true; - std::vector results; - - for (TypeId part : it) - { - auto [dispatched2, found] = tryDispatchSetIndexer(constraint, part, indexType, propType, expandFreeTypeBounds); - dispatched &= dispatched2; - results.push_back(found.value_or(builtinTypes->errorRecoveryType())); - - if (!dispatched) - return {dispatched, std::nullopt}; - } - - TypeId resultTy = arena->addType(TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, - std::move(results), - {}, - }); - - pushConstraint(constraint->scope, constraint->location, ReduceConstraint{resultTy}); - - return {dispatched, resultTy}; - } - else if (is(subjectType)) - return {true, subjectType}; - - return {true, std::nullopt}; + return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); } bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull constraint) @@ -1643,7 +1646,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullwriteTy.has_value()) return true; - emplaceType(asMutable(c.propType), *prop->writeTy); + bind(constraint, c.propType, *prop->writeTy); unify(constraint, rhsType, *prop->writeTy); return true; } @@ -1663,7 +1666,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullupperBound = simplifyIntersection(builtinTypes, arena, lhsFree->upperBound, newUpperBound).result; - emplaceType(asMutable(c.propType), rhsType); + bind(constraint, c.propType, rhsType); return true; } } @@ -1681,7 +1684,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull(asMutable(c.propType), propTy); + bind(constraint, c.propType, propTy); unify(constraint, rhsType, propTy); return true; } @@ -1700,7 +1703,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull(asMutable(c.propType), *prop.writeTy); + bind(constraint, c.propType, *prop.writeTy); 
unify(constraint, rhsType, *prop.writeTy); return true; } @@ -1710,13 +1713,13 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullstate == TableState::Unsealed || lhsTable->state == TableState::Free) { prop.writeTy = prop.readTy; - emplaceType(asMutable(c.propType), *prop.writeTy); + bind(constraint, c.propType, *prop.writeTy); unify(constraint, rhsType, *prop.writeTy); return true; } else { - emplaceType(asMutable(c.propType), builtinTypes->errorType); + bind(constraint, c.propType, builtinTypes->errorType); return true; } } @@ -1724,28 +1727,27 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullindexer && maybeString(lhsTable->indexer->indexType)) { - emplaceType(asMutable(c.propType), rhsType); + bind(constraint, c.propType, rhsType); unify(constraint, rhsType, lhsTable->indexer->indexResultType); return true; } if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - emplaceType(asMutable(c.propType), rhsType); + bind(constraint, c.propType, rhsType); lhsTable->props[propName] = Property::rw(rhsType); if (lhsTable->state == TableState::Unsealed && c.decrementPropCount) { LUAU_ASSERT(lhsTable->remainingProps > 0); lhsTable->remainingProps -= 1; - unblock(lhsType, constraint->location); } return true; } } - emplaceType(asMutable(c.propType), builtinTypes->errorType); + bind(constraint, c.propType, builtinTypes->errorType); return true; } @@ -1772,14 +1774,14 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer->indexType); unify(constraint, rhsType, lhsTable->indexer->indexResultType); - emplaceType(asMutable(c.propType), lhsTable->indexer->indexResultType); + bind(constraint, c.propType, lhsTable->indexer->indexResultType); return true; } if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { lhsTable->indexer = TableIndexer{indexType, rhsType}; - emplaceType(asMutable(c.propType), rhsType); + bind(constraint, 
c.propType, rhsType); return true; } @@ -1802,7 +1804,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer); - emplaceType(asMutable(c.propType), newTable->indexer->indexResultType); + bind(constraint, c.propType, newTable->indexer->indexResultType); return true; } @@ -1821,7 +1823,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer->indexType); unify(constraint, rhsType, lhsClass->indexer->indexResultType); - emplaceType(asMutable(c.propType), lhsClass->indexer->indexResultType); + bind(constraint, c.propType, lhsClass->indexer->indexResultType); return true; } @@ -1878,40 +1880,11 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull(asMutable(c.propType), builtinTypes->errorType); + bind(constraint, c.propType, builtinTypes->errorType); return true; } -bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, TypeId resultTy, TypeId srcTy) -{ - resultTy = follow(resultTy); - LUAU_ASSERT(canMutate(resultTy, constraint)); - - LUAU_ASSERT(get(resultTy)); - - if (get(resultTy)) - { - if (follow(srcTy) == resultTy) - { - // It is sometimes the case that we find that a blocked type - // is only blocked on itself. This doesn't actually - // constitute any meaningful constraint, so we replace it - // with a free type. 
- TypeId f = freshType(arena, builtinTypes, constraint->scope); - shiftReferences(resultTy, f); - emplaceType(asMutable(resultTy), f); - } - else - bindBlockedType(resultTy, srcTy, srcTy, constraint); - } - else - unify(constraint, srcTy, resultTy); - - unblock(resultTy, constraint->location); - return true; -} - bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) { TypePackId sourcePack = follow(c.sourcePack); @@ -1932,7 +1905,29 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull(resultTy)); + LUAU_ASSERT(canMutate(resultTy, constraint)); + + if (get(resultTy)) + { + if (follow(srcTy) == resultTy) + { + // It is sometimes the case that we find that a blocked type + // is only blocked on itself. This doesn't actually + // constitute any meaningful constraint, so we replace it + // with a free type. + TypeId f = freshType(arena, builtinTypes, constraint->scope); + shiftReferences(resultTy, f); + emplaceType(asMutable(resultTy), f); + } + else + bind(constraint, resultTy, srcTy); + } + else + unify(constraint, srcTy, resultTy); + + unblock(resultTy, constraint->location); ++resultIter; ++i; @@ -1948,8 +1943,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull(resultTy) || get(resultTy)) { - emplaceType(asMutable(resultTy), builtinTypes->nilType); - unblock(resultTy, constraint->location); + bind(constraint, resultTy, builtinTypes->nilType); } ++resultIter; @@ -1972,6 +1966,11 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull(ty)) + typeFamiliesToFinalize[ty] = constraint; + if (force || reductionFinished) { // if we're completely dispatching this constraint, we want to record any uninhabited type families to unblock. 
@@ -2058,11 +2057,11 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl auto endIt = end(c.variables); if (it != endIt) { - bindBlockedType(*it, keyTy, keyTy, constraint); + bind(constraint, *it, keyTy); ++it; } if (it != endIt) - bindBlockedType(*it, valueTy, valueTy, constraint); + bind(constraint, *it, valueTy); return true; } @@ -2072,7 +2071,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl { LUAU_ASSERT(get(varTy)); LUAU_ASSERT(varTy != ty); - bindBlockedType(varTy, ty, ty, constraint); + bind(constraint, varTy, ty); } }; @@ -2121,8 +2120,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl unify(constraint, c.variables[i], expectedVariables[i]); - bindBlockedType(c.variables[i], expectedVariables[i], expectedVariables[i], constraint); - unblock(c.variables[i], constraint->location); + bind(constraint, c.variables[i], expectedVariables[i]); } } else @@ -2517,42 +2515,9 @@ bool ConstraintSolver::unify(NotNull constraint, TID subTy, TI return false; } - unblock(subTy, constraint->location); - unblock(superTy, constraint->location); - return true; } -void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull constraint) -{ - resultTy = follow(resultTy); - - LUAU_ASSERT(get(blockedTy) && canMutate(blockedTy, constraint)); - - if (blockedTy == resultTy) - { - rootTy = follow(rootTy); - Scope* freeScope = nullptr; - if (auto ft = get(rootTy)) - freeScope = ft->scope; - else if (auto tt = get(rootTy); tt && tt->state == TableState::Free) - freeScope = tt->scope; - else - iceReporter.ice("bindBlockedType couldn't find an appropriate scope for a fresh type!", constraint->location); - - LUAU_ASSERT(freeScope); - - TypeId freeType = arena->freshType(freeScope); - shiftReferences(blockedTy, freeType); - emplaceType(asMutable(blockedTy), freeType); - } - else - { - shiftReferences(blockedTy, resultTy); - 
emplaceType(asMutable(blockedTy), resultTy); - } -} - bool ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { // If a set is not present for the target, construct a new DenseHashSet for it, @@ -2884,7 +2849,7 @@ std::optional ConstraintSolver::generalizeFreeType(NotNull scope, // that until all constraint generation is complete. } - return generalize(NotNull{arena}, builtinTypes, scope, type); + return generalize(NotNull{arena}, builtinTypes, scope, generalizedTypes, type); } bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 2087e3d3..d356b1cc 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,6 +7,7 @@ #include "Luau/NotNull.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include "Luau/Type.h" #include "Luau/TypeFamily.h" #include @@ -666,6 +667,18 @@ struct ErrorConverter return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; } + if ("index" == tfit->family->name) + { + if (tfit->typeArguments.size() != 2) + return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; + + if (auto errType = get(tfit->typeArguments[1])) // Second argument to index<_,_> is not a type + return "Second argument to index<" + Luau::toString(tfit->typeArguments[0]) + ", _> is not a valid index type"; + else // Second argument to index<_,_> is not a property of the first argument + return "Property '" + Luau::toString(tfit->typeArguments[1]) + "' does not exist on type '" + Luau::toString(tfit->typeArguments[0]) + + "'"; + } + if (kUnreachableTypeFamilies.count(tfit->family->name)) { return "Type family instance " + Luau::toString(e.ty) + " is uninhabited\n" + diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 7823f3d4..618a9a9c 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1003,6 +1003,7 @@ void 
Frontend::checkBuildQueueItem(BuildQueueItem& item) module->astForInNextTypes.clear(); module->astResolvedTypes.clear(); module->astResolvedTypePacks.clear(); + module->astCompoundAssignResultTypes.clear(); module->astScopes.clear(); module->upperBoundContributors.clear(); diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index 081ea153..c2c44d96 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -4,6 +4,7 @@ #include "Luau/Scope.h" #include "Luau/Type.h" +#include "Luau/ToString.h" #include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include "Luau/VisitType.h" @@ -16,6 +17,7 @@ struct MutatingGeneralizer : TypeOnceVisitor NotNull builtinTypes; NotNull scope; + NotNull> cachedTypes; DenseHashMap positiveTypes; DenseHashMap negativeTypes; std::vector generics; @@ -23,11 +25,12 @@ struct MutatingGeneralizer : TypeOnceVisitor bool isWithinFunction = false; - MutatingGeneralizer(NotNull builtinTypes, NotNull scope, DenseHashMap positiveTypes, + MutatingGeneralizer(NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, DenseHashMap positiveTypes, DenseHashMap negativeTypes) : TypeOnceVisitor(/* skipBoundTypes */ true) , builtinTypes(builtinTypes) , scope(scope) + , cachedTypes(cachedTypes) , positiveTypes(std::move(positiveTypes)) , negativeTypes(std::move(negativeTypes)) { @@ -130,6 +133,9 @@ struct MutatingGeneralizer : TypeOnceVisitor bool visit(TypeId ty, const FunctionType& ft) override { + if (cachedTypes->contains(ty)) + return false; + const bool oldValue = isWithinFunction; isWithinFunction = true; @@ -144,6 +150,8 @@ struct MutatingGeneralizer : TypeOnceVisitor bool visit(TypeId ty, const FreeType&) override { + LUAU_ASSERT(!cachedTypes->contains(ty)); + const FreeType* ft = get(ty); LUAU_ASSERT(ft); @@ -244,6 +252,9 @@ struct MutatingGeneralizer : TypeOnceVisitor bool visit(TypeId ty, const TableType&) override { + if (cachedTypes->contains(ty)) + return false; + const size_t positiveCount 
= getCount(positiveTypes, ty); const size_t negativeCount = getCount(negativeTypes, ty); @@ -287,10 +298,12 @@ struct MutatingGeneralizer : TypeOnceVisitor struct FreeTypeSearcher : TypeVisitor { NotNull scope; + NotNull> cachedTypes; - explicit FreeTypeSearcher(NotNull scope) + explicit FreeTypeSearcher(NotNull scope, NotNull> cachedTypes) : TypeVisitor(/*skipBoundTypes*/ true) , scope(scope) + , cachedTypes(cachedTypes) { } @@ -363,7 +376,7 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty) override { - if (seenWithPolarity(ty)) + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) return false; LUAU_ASSERT(ty); @@ -372,7 +385,7 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const FreeType& ft) override { - if (seenWithPolarity(ty)) + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) return false; if (!subsumes(scope, ft.scope)) @@ -397,7 +410,7 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const TableType& tt) override { - if (seenWithPolarity(ty)) + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) return false; if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) @@ -443,7 +456,7 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const FunctionType& ft) override { - if (seenWithPolarity(ty)) + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) return false; flip(); @@ -486,8 +499,371 @@ struct FreeTypeSearcher : TypeVisitor } }; +// We keep a running set of types that will not change under generalization and +// only have outgoing references to types that are the same. We use this to +// short circuit generalization. It improves performance quite a lot. +// +// We do this by tracing through the type and searching for types that are +// uncacheable. If a type has a reference to an uncacheable type, it is itself +// uncacheable. +// +// If a type has no outbound references to uncacheable types, we add it to the +// cache. 
+struct TypeCacher : TypeOnceVisitor +{ + NotNull> cachedTypes; -std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, TypeId ty) + DenseHashSet uncacheable{nullptr}; + DenseHashSet uncacheablePacks{nullptr}; + + explicit TypeCacher(NotNull> cachedTypes) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , cachedTypes(cachedTypes) + {} + + void cache(TypeId ty) + { + cachedTypes->insert(ty); + } + + bool isCached(TypeId ty) const + { + return cachedTypes->contains(ty); + } + + void markUncacheable(TypeId ty) + { + uncacheable.insert(ty); + } + + void markUncacheable(TypePackId tp) + { + uncacheablePacks.insert(tp); + } + + bool isUncacheable(TypeId ty) const + { + return uncacheable.contains(ty); + } + + bool isUncacheable(TypePackId tp) const + { + return uncacheablePacks.contains(tp); + } + + bool visit(TypeId ty) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + // Free types are never cacheable. 
+ LUAU_ASSERT(!isCached(ty)); + + if (!isUncacheable(ty)) + { + traverse(ft.lowerBound); + traverse(ft.upperBound); + + markUncacheable(ty); + } + + return false; + } + + bool visit(TypeId ty, const GenericType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const PrimitiveType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const SingletonType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const BlockedType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + traverse(ft.argTypes); + traverse(ft.retTypes); + for (TypeId gen: ft.generics) + traverse(gen); + + bool uncacheable = false; + + if (isUncacheable(ft.argTypes)) + uncacheable = true; + + else if (isUncacheable(ft.retTypes)) + uncacheable = true; + + for (TypeId argTy: ft.argTypes) + { + if (isUncacheable(argTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId retTy: ft.retTypes) + { + if (isUncacheable(retTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId g: ft.generics) + { + if (isUncacheable(g)) + { + uncacheable = true; + break; + } + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + if (tt.boundTo) + { + traverse(*tt.boundTo); + if (isUncacheable(*tt.boundTo)) + { + markUncacheable(ty); + return false; + } + } + + bool uncacheable = false; + + // This logic runs immediately after generalization, so any remaining + // unsealed tables are assuredly not cacheable. They may yet have + // properties added to them. 
+ if (tt.state == TableState::Free || tt.state == TableState::Unsealed) + uncacheable = true; + + for (const auto& [_name, prop] : tt.props) + { + if (prop.readTy) + { + traverse(*prop.readTy); + + if (isUncacheable(*prop.readTy)) + uncacheable = true; + } + if (prop.writeTy && prop.writeTy != prop.readTy) + { + traverse(*prop.writeTy); + + if (isUncacheable(*prop.writeTy)) + uncacheable = true; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + if (isUncacheable(tt.indexer->indexType)) + uncacheable = true; + + traverse(tt.indexer->indexResultType); + if (isUncacheable(tt.indexer->indexResultType)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const ClassType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const AnyType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const UnionType& ut) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : ut.options) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const IntersectionType& it) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : it.parts) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const UnknownType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NeverType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NegationType& nt) override + { + if (!isCached(ty) && !isUncacheable(ty)) + { + traverse(nt.ty); + + if (isUncacheable(nt.ty)) + markUncacheable(ty); + else + cache(ty); + } + + return 
false; + } + + bool visit(TypeId ty, const TypeFamilyInstanceType& tfit) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + bool uncacheable = false; + + for (TypeId argTy : tfit.typeArguments) + { + traverse(argTy); + + if (isUncacheable(argTy)) + uncacheable = true; + } + + for (TypePackId argPack : tfit.packArguments) + { + traverse(argPack); + + if (isUncacheable(argPack)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypePackId tp, const FreeTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const VariadicTypePack& vtp) override + { + if (isUncacheable(tp)) + return false; + + traverse(vtp.ty); + + if (isUncacheable(vtp.ty)) + markUncacheable(tp); + + return false; + } + + bool visit(TypePackId tp, const BlockedTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override + { + markUncacheable(tp); + return false; + } +}; + +std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, TypeId ty) { ty = follow(ty); @@ -497,10 +873,10 @@ std::optional generalize(NotNull arena, NotNull if (const FunctionType* ft = get(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty())) return ty; - FreeTypeSearcher fts{scope}; + FreeTypeSearcher fts{scope, cachedTypes}; fts.traverse(ty); - MutatingGeneralizer gen{builtinTypes, scope, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; + MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; gen.traverse(ty); @@ -513,6 +889,9 @@ std::optional generalize(NotNull arena, NotNull if (ty->owningArena != arena || ty->persistent) return ty; + TypeCacher cacher{cachedTypes}; + cacher.traverse(ty); + FunctionType* ftv = getMutable(ty); if (ftv) { diff --git a/Analysis/src/Normalize.cpp 
b/Analysis/src/Normalize.cpp index 7ce50284..16fe9546 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false) LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false); -LUAU_FASTFLAGVARIABLE(LuauFixCyclicUnionsOfIntersections, false); LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false); LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false); @@ -27,11 +26,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -static bool fixCyclicUnionsOfIntersections() -{ - return FFlag::LuauFixCyclicUnionsOfIntersections || FFlag::DebugLuauDeferredConstraintResolution; -} - static bool fixReduceStackPressure() { return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution; @@ -1776,12 +1770,9 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t } else if (const IntersectionType* itv = get(there)) { - if (fixCyclicUnionsOfIntersections()) - { - if (seenSetTypes.count(there)) - return NormalizationResult::True; - seenSetTypes.insert(there); - } + if (seenSetTypes.count(there)) + return NormalizationResult::True; + seenSetTypes.insert(there); NormalizedType norm{builtinTypes}; norm.tops = builtinTypes->anyType; @@ -1790,14 +1781,12 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes); if (res != NormalizationResult::True) { - if (fixCyclicUnionsOfIntersections()) - seenSetTypes.erase(there); + seenSetTypes.erase(there); return res; } } - if (fixCyclicUnionsOfIntersections()) - seenSetTypes.erase(there); + seenSetTypes.erase(there); return unionNormals(here, norm); } diff --git a/Analysis/src/TableLiteralInference.cpp 
b/Analysis/src/TableLiteralInference.cpp index 414544b6..3514ff65 100644 --- a/Analysis/src/TableLiteralInference.cpp +++ b/Analysis/src/TableLiteralInference.cpp @@ -337,7 +337,9 @@ TypeId matchLiteralType(NotNull> astTypes, TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, expectedTableTy->indexer->indexResultType, *propTy, item.value, toBlock); - tableTy->indexer->indexResultType = matchedType; + // if the index result type is the prop type, we can replace it with the matched type here. + if (tableTy->indexer->indexResultType == *propTy) + tableTy->indexer->indexResultType = matchedType; } } else if (item.kind == AstExprTable::Item::General) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 5ffeb951..cc02bea6 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -446,7 +446,6 @@ struct TypeChecker2 .errors; if (!isErrorSuppressing(location, instance)) reportErrors(std::move(errors)); - return instance; } @@ -1108,10 +1107,13 @@ struct TypeChecker2 void visit(AstStatCompoundAssign* stat) { AstExprBinary fake{stat->location, stat->op, stat->var, stat->value}; - TypeId resultTy = visit(&fake, stat); + visit(&fake, stat); + + TypeId* resultTy = module->astCompoundAssignResultTypes.find(stat); + LUAU_ASSERT(resultTy); TypeId varTy = lookupType(stat->var); - testIsSubtype(resultTy, varTy, stat->location); + testIsSubtype(*resultTy, varTy, stat->location); } void visit(AstStatFunction* stat) @@ -1857,7 +1859,7 @@ struct TypeChecker2 bool isStringOperation = (normLeft ? normLeft->isSubtypeOfString() : isString(leftType)) && (normRight ? 
normRight->isSubtypeOfString() : isString(rightType)); - + leftType = follow(leftType); if (get(leftType) || get(leftType) || get(leftType)) return leftType; else if (get(rightType) || get(rightType) || get(rightType)) diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 89de1912..c65fde00 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -1490,6 +1490,18 @@ TypeFamilyReductionResult refineFamilyFn(TypeId instance, const std::vec if (get(follow(nt->ty))) return {targetTy, false, {}, {}}; + // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the + // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. + if (get(targetTy)) + { + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, targetTy, discriminantTy); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + return {result.result, false, {}, {}}; + } + + // In the general case, we'll still use normalization though. 
TypeId intersection = ctx->arena->addType(IntersectionType{{targetTy, discriminantTy}}); std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); std::shared_ptr normType = ctx->normalizer->normalize(targetTy); @@ -1853,6 +1865,208 @@ TypeFamilyReductionResult rawkeyofFamilyFn(TypeId instance, const std::v return keyofFamilyImpl(typeParams, packParams, ctx, /* isRaw */ true); } +/* Searches through table's or class's props/indexer to find the property of `ty` + If found, appends that property to `result` and returns true + Else, returns false */ +bool searchPropsAndIndexer( + TypeId ty, TableType::Props tblProps, std::optional tblIndexer, DenseHashSet& result, NotNull ctx) +{ + ty = follow(ty); + + // index into tbl's properties + if (auto stringSingleton = get(get(ty))) + { + if (tblProps.find(stringSingleton->value) != tblProps.end()) + { + TypeId propTy = follow(tblProps.at(stringSingleton->value).type()); + + // property is a union type -> we need to extend our reduction type + if (auto propUnionTy = get(propTy)) + { + for (TypeId option : propUnionTy->options) + result.insert(option); + } + else // property is a singular type or intersection type -> we can simply append + result.insert(propTy); + + return true; + } + } + + // index into tbl's indexer + if (tblIndexer) + { + if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, *ctx->ice)) + { + TypeId idxResultTy = follow(tblIndexer->indexResultType); + + // indexResultType is a union type -> we need to extend our reduction type + if (auto idxResUnionTy = get(idxResultTy)) + { + for (TypeId option : idxResUnionTy->options) + result.insert(option); + } + else // indexResultType is a singular type or intersection type -> we can simply append + result.insert(idxResultTy); + + return true; + } + } + + return false; +} + +/* Handles recursion / metamethods of tables/classes + `isRaw` parameter indicates whether or not we should follow __index metamethods + returns false if property 
of `ty` could not be found */ +bool tblIndexInto(TypeId indexer, TypeId indexee, DenseHashSet& result, NotNull ctx, bool isRaw) +{ + indexer = follow(indexer); + indexee = follow(indexee); + + // we have a table type to try indexing + if (auto tableTy = get(indexee)) + { + return searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx); + } + + // we have a metatable type to try indexing + if (auto metatableTy = get(indexee)) + { + if (auto tableTy = get(metatableTy->table)) + { + + // try finding all properties within the current scope of the table + if (searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx)) + return true; + } + + // if the code reached here, it means we weren't able to find all properties -> look into __index metamethod + if (!isRaw) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, indexee, "__index", Location{}); + if (mmType) + return tblIndexInto(indexer, *mmType, result, ctx, isRaw); + } + } + + return false; +} + +/* Vocabulary note: indexee refers to the type that contains the properties, + indexer refers to the type that is used to access indexee + Example: index => `Person` is the indexee and `"name"` is the indexer */ +TypeFamilyReductionResult indexFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("index type family: encountered a type family instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId indexeeTy = follow(typeParams.at(0)); + std::shared_ptr indexeeNormTy = ctx->normalizer->normalize(indexeeTy); + + // if the indexee failed to normalize, we can't reduce, but know nothing about inhabitance. 
+ if (!indexeeNormTy) + return {std::nullopt, false, {}, {}}; + + // if we don't have either just tables or just classes, we've got nothing to index into + if (indexeeNormTy->hasTables() == indexeeNormTy->hasClasses()) + return {std::nullopt, true, {}, {}}; + + // we're trying to reject any type that has not normalized to a table/class or a union of tables/classes. + if (indexeeNormTy->hasTops() || indexeeNormTy->hasBooleans() || indexeeNormTy->hasErrors() || indexeeNormTy->hasNils() || + indexeeNormTy->hasNumbers() || indexeeNormTy->hasStrings() || indexeeNormTy->hasThreads() || indexeeNormTy->hasBuffers() || + indexeeNormTy->hasFunctions() || indexeeNormTy->hasTyvars()) + return {std::nullopt, true, {}, {}}; + + TypeId indexerTy = follow(typeParams.at(1)); + std::shared_ptr indexerNormTy = ctx->normalizer->normalize(indexerTy); + + // if the indexer failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!indexerNormTy) + return {std::nullopt, false, {}, {}}; + + // we're trying to reject any type that is not a string singleton or primitive (string, number, boolean, thread, nil, function, table, or buffer) + if (indexerNormTy->hasTops() || indexerNormTy->hasErrors()) + return {std::nullopt, true, {}, {}}; + + // indexer can be a union —> break them down into a vector + const std::vector* typesToFind; + const std::vector singleType{indexerTy}; + if (auto unionTy = get(indexerTy)) + typesToFind = &unionTy->options; + else + typesToFind = &singleType; + + DenseHashSet properties{{}}; // vector of types that will be returned + bool isRaw = false; + + if (indexeeNormTy->hasClasses()) + { + LUAU_ASSERT(!indexeeNormTy->hasTables()); + + // at least one class is guaranteed to be in the iterator by .hasClasses() + for (auto classesIter = indexeeNormTy->classes.ordering.begin(); classesIter != indexeeNormTy->classes.ordering.end(); ++classesIter) + { + auto classTy = get(*classesIter); + if (!classTy) + { + LUAU_ASSERT(false); // this should not be 
possible according to normalization's spec + return {std::nullopt, true, {}, {}}; + } + + for (TypeId ty : *typesToFind) + { + // Search for all instances of indexer in class->props and class->indexer using `indexInto` + if (searchPropsAndIndexer(ty, classTy->props, classTy->indexer, properties, ctx)) + continue; // Indexer was found in this class, so we can move on to the next + + // If code reaches here, that means the property was not found -> check in the metatable's __index + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + std::optional mmType = findMetatableEntry(ctx->builtins, dummy, *classesIter, "__index", Location{}); + if (!mmType) // if a metatable does not exist, there is nowhere else to look + return {std::nullopt, true, {}, {}}; + + if (!tblIndexInto(ty, *mmType, properties, ctx, isRaw)) // if indexer is not in the metatable, we fail to reduce + return {std::nullopt, true, {}, {}}; + } + } + } + + if (indexeeNormTy->hasTables()) + { + LUAU_ASSERT(!indexeeNormTy->hasClasses()); + + // at least one table is guaranteed to be in the iterator by .hasTables() + for (auto tablesIter = indexeeNormTy->tables.begin(); tablesIter != indexeeNormTy->tables.end(); ++tablesIter) + { + for (TypeId ty : *typesToFind) + if (!tblIndexInto(ty, *tablesIter, properties, ctx, isRaw)) + return {std::nullopt, true, {}, {}}; + } + } + + // Call `follow()` on each element to resolve all Bound types before returning + std::transform(properties.begin(), properties.end(), properties.begin(), [](TypeId ty) { + return follow(ty); + }); + + // If the type being reduced to is a single type, no need to union + if (properties.size() == 1) + return {*properties.begin(), false, {}, {}}; + + return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), false, {}, {}}; + } + BuiltinTypeFamilies::BuiltinTypeFamilies() : notFamily{"not",
notFamilyFn} , lenFamily{"len", lenFamilyFn} @@ -1876,6 +2090,7 @@ BuiltinTypeFamilies::BuiltinTypeFamilies() , intersectFamily{"intersect", intersectFamilyFn} , keyofFamily{"keyof", keyofFamilyFn} , rawkeyofFamily{"rawkeyof", rawkeyofFamilyFn} + , indexFamily{"index", indexFamilyFn} { } @@ -1917,6 +2132,8 @@ void BuiltinTypeFamilies::addToScope(NotNull arena, NotNull sc scope->exportedTypeBindings[keyofFamily.name] = mkUnaryTypeFamily(&keyofFamily); scope->exportedTypeBindings[rawkeyofFamily.name] = mkUnaryTypeFamily(&rawkeyofFamily); + + scope->exportedTypeBindings[indexFamily.name] = mkBinaryTypeFamily(&indexFamily); } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index eed3c715..3050f09e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -37,7 +37,6 @@ LUAU_FASTFLAGVARIABLE(LuauMetatableInstantiationCloneCheck, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) -LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false) LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) namespace Luau @@ -667,7 +666,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (typealias->name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && typealias->name == "typeof")) + if (typealias->name == kParseNameError || typealias->name == "typeof") continue; auto& bindings = typealias->exported ? 
scope->exportedTypeBindings : scope->privateTypeBindings; @@ -1535,7 +1534,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty if (name == kParseNameError) return ControlFlow::None; - if (FFlag::LuauForbidAliasNamedTypeof && name == "typeof") + if (name == "typeof") { reportError(typealias.location, GenericError{"Type aliases cannot be named typeof"}); return ControlFlow::None; @@ -1656,7 +1655,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea // If the alias is missing a name, we can't do anything with it. Ignore it. // Also, typeof is not a valid type alias name. We will report an error for // this in check() - if (name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && name == "typeof")) + if (name == kParseNameError || name == "typeof") return; std::optional binding; diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index e8479e09..ab0d40e2 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -207,6 +207,7 @@ public: enum Type { Checked, + Native, }; AstAttr(const Location& location, Type type); @@ -420,6 +421,8 @@ public: void visit(AstVisitor* visitor) override; + bool hasNativeAttribute() const; + AstArray attributes; AstArray generics; AstArray genericPacks; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index c1fd43ea..5a945e26 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -234,7 +234,7 @@ private: // asexp -> simpleexp [`::' Type] AstExpr* parseAssertionExpr(); - // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp + // simpleexp -> NUMBER | STRING | NIL | true | false | ... 
| constructor | [attributes] FUNCTION body | primaryexp AstExpr* parseSimpleExpr(); // args ::= `(' [explist] `)' | tableconstructor | String diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 4c956307..14b79767 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" LUAU_FASTFLAG(LuauAttributeSyntax); +LUAU_FASTFLAG(LuauNativeAttribute); namespace Luau { @@ -214,6 +215,18 @@ void AstExprFunction::visit(AstVisitor* visitor) } } +bool AstExprFunction::hasNativeAttribute() const +{ + LUAU_ASSERT(FFlag::LuauNativeAttribute); + + for (const auto attribute : attributes) + { + if (attribute->type == AstAttr::Type::Native) + return true; + } + return false; +} + AstExprTable::AstExprTable(const Location& location, const AstArray& items) : AstExpr(ClassIndex(), location) , items(items) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index d80878d5..3a6625a5 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -18,7 +18,9 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // See docs/SyntaxChanges.md for an explanation. 
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(LuauAttributeSyntax) -LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand, false) +LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand2, false) +LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) +LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) namespace Luau { @@ -29,7 +31,7 @@ struct AttributeEntry AstAttr::Type type; }; -AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {nullptr, AstAttr::Type::Checked}}; +AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {"@native", AstAttr::Type::Native}, {nullptr, AstAttr::Type::Checked}}; ParseError::ParseError(const Location& location, const std::string& message) : location(location) @@ -703,6 +705,10 @@ std::pair Parser::validateAttribute(const char* attributeNa if (found) { type = kAttributeEntries[i].type; + + if (!FFlag::LuauNativeAttribute && type == AstAttr::Type::Native) + found = false; + break; } } @@ -772,7 +778,7 @@ AstStat* Parser::parseAttributeStat() { LUAU_ASSERT(FFlag::LuauAttributeSyntax); - AstArray attributes = Parser::parseAttributes(); + AstArray attributes = parseAttributes(); Lexeme::Type type = lexer.current().type; @@ -1654,7 +1660,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) { TempVector parts(scratchType); - if (!FFlag::LuauLeadingBarAndAmpersand || type != nullptr) + if (!FFlag::LuauLeadingBarAndAmpersand2 || type != nullptr) { parts.push_back(type); } @@ -1682,6 +1688,8 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) } else if (c == '?') { + LUAU_ASSERT(parts.size() >= 1); + Location loc = lexer.current().location; nextLexeme(); @@ -1714,7 +1722,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) } if (parts.size() == 1) - return type; + return FFlag::LuauLeadingBarAndAmpersand2 ? 
parts[0] : type; if (isUnion && isIntersection) { @@ -1761,7 +1769,7 @@ AstType* Parser::parseType(bool inDeclarationContext) Location begin = lexer.current().location; - if (FFlag::LuauLeadingBarAndAmpersand) + if (FFlag::LuauLeadingBarAndAmpersand2) { AstType* type = nullptr; @@ -2369,11 +2377,24 @@ static ConstantNumberParseResult parseDouble(double& result, const char* data) return ConstantNumberParseResult::Ok; } -// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp +// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp AstExpr* Parser::parseSimpleExpr() { Location start = lexer.current().location; + AstArray attributes{nullptr, 0}; + + if (FFlag::LuauAttributeSyntax && FFlag::LuauAttributeSyntaxFunExpr && lexer.current().type == Lexeme::Attribute) + { + attributes = parseAttributes(); + + if (lexer.current().type != Lexeme::ReservedFunction) + { + return reportExprError( + start, {}, "Expected 'function' declaration after attribute, but got %s intead", lexer.current().toString().c_str()); + } + } + if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); @@ -2397,7 +2418,7 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), nullptr, AstArray({nullptr, 0})).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr, attributes).first; } else if (lexer.current().type == Lexeme::Number) { diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 8ad75fbe..2077cce0 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -54,6 +54,7 @@ struct IrBuilder IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f); + IrOp inst(IrCmd cmd, IrOp a, IrOp b, 
IrOp c, IrOp d, IrOp e, IrOp f, IrOp g); IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence IrOp blockAtInst(uint32_t index); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 60af706f..d0e40ca3 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -31,7 +31,7 @@ enum // * Rn - VM stack register slot, n in 0..254 // * Kn - VM proto constant slot, n in 0..2^23-1 // * UPn - VM function upvalue slot, n in 0..199 -// * A, B, C, D, E are instruction arguments +// * A, B, C, D, E, F, G are instruction arguments enum class IrCmd : uint8_t { NOP, @@ -869,6 +869,7 @@ struct IrInst IrOp d; IrOp e; IrOp f; + IrOp g; uint32_t lastUse = 0; uint16_t useCount = 0; @@ -923,6 +924,7 @@ struct IrInstHash h = mix(h, key.d); h = mix(h, key.e); h = mix(h, key.f); + h = mix(h, key.g); // MurmurHash2 tail h ^= h >> 13; @@ -937,7 +939,7 @@ struct IrInstEq { bool operator()(const IrInst& a, const IrInst& b) const { - return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f; + return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f && a.g == b.g; } }; diff --git a/CodeGen/include/Luau/IrVisitUseDef.h b/CodeGen/include/Luau/IrVisitUseDef.h index 58c88661..32dd6c2a 100644 --- a/CodeGen/include/Luau/IrVisitUseDef.h +++ b/CodeGen/include/Luau/IrVisitUseDef.h @@ -228,6 +228,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i CODEGEN_ASSERT(inst.d.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.e.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.f.kind != IrOpKind::VmReg); + CODEGEN_ASSERT(inst.g.kind != IrOpKind::VmReg); break; } } diff --git a/CodeGen/src/BytecodeSummary.cpp b/CodeGen/src/BytecodeSummary.cpp index 0089f592..d0d71504 100644 --- a/CodeGen/src/BytecodeSummary.cpp +++ b/CodeGen/src/BytecodeSummary.cpp @@ -8,6 +8,8 @@ #include "lobject.h" #include 
"lstate.h" +LUAU_FASTFLAG(LuauNativeAttribute) + namespace Luau { namespace CodeGen @@ -56,7 +58,10 @@ std::vector summarizeBytecode(lua_State* L, int idx, un Proto* root = clvalue(func)->l.p; std::vector protos; - gatherFunctions(protos, root, CodeGen_ColdFunctions); + if (FFlag::LuauNativeAttribute) + gatherFunctions(protos, root, CodeGen_ColdFunctions, root->flags & LPF_NATIVE_FUNCTION); + else + gatherFunctions_DEPRECATED(protos, root, CodeGen_ColdFunctions); std::vector summaries; summaries.reserve(protos.size()); diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index 269bf8dc..121535be 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAG(LuauCodegenTypeInfo) LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { @@ -200,7 +201,10 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A return std::string(); std::vector protos; - gatherFunctions(protos, root, options.compilationOptions.flags); + if (FFlag::LuauNativeAttribute) + gatherFunctions(protos, root, options.compilationOptions.flags, root->flags & LPF_NATIVE_FUNCTION); + else + gatherFunctions_DEPRECATED(protos, root, options.compilationOptions.flags); protos.erase(std::remove_if(protos.begin(), protos.end(), [](Proto* p) { diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index ae9e41f1..67a2676e 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false) LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { @@ -455,7 +456,7 @@ template Proto* root = clvalue(func)->l.p; - if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + if ((options.flags & 
CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0 && (root->flags & LPF_NATIVE_FUNCTION) == 0) return CompilationResult{CodeGenCompilationResult::NotNativeModule}; BaseCodeGenContext* codeGenContext = getCodeGenContext(L); @@ -463,7 +464,10 @@ template return CompilationResult{CodeGenCompilationResult::CodeGenNotInitialized}; std::vector protos; - gatherFunctions(protos, root, options.flags); + if (FFlag::LuauNativeAttribute) + gatherFunctions(protos, root, options.flags, root->flags & LPF_NATIVE_FUNCTION); + else + gatherFunctions_DEPRECATED(protos, root, options.flags); // Skip protos that have been compiled during previous invocations of CodeGen::compile protos.erase(std::remove_if(protos.begin(), protos.end(), diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 6015ef10..4523d62b 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -29,13 +29,14 @@ LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { namespace CodeGen { -inline void gatherFunctions(std::vector& results, Proto* proto, unsigned int flags) +inline void gatherFunctions_DEPRECATED(std::vector& results, Proto* proto, unsigned int flags) { if (results.size() <= size_t(proto->bytecodeid)) results.resize(proto->bytecodeid + 1); @@ -50,7 +51,36 @@ inline void gatherFunctions(std::vector& results, Proto* proto, unsigned // Recursively traverse child protos even if we aren't compiling this one for (int i = 0; i < proto->sizep; i++) - gatherFunctions(results, proto->p[i], flags); + gatherFunctions_DEPRECATED(results, proto->p[i], flags); +} + +inline void gatherFunctionsHelper( + std::vector& results, Proto* proto, const unsigned int flags, const bool hasNativeFunctions, const bool root) +{ + if (results.size() <= size_t(proto->bytecodeid)) + 
results.resize(proto->bytecodeid + 1); + + // Skip protos that we've already compiled in this run: this happens because at -O2, inlined functions get their protos reused + if (results[proto->bytecodeid]) + return; + + // if native module, compile cold functions if requested + // if not native module, compile function if it has native attribute and is not root + bool shouldGather = hasNativeFunctions ? (!root && (proto->flags & LPF_NATIVE_FUNCTION) != 0) + : ((proto->flags & LPF_NATIVE_COLD) == 0 || (flags & CodeGen_ColdFunctions) != 0); + + if (shouldGather) + results[proto->bytecodeid] = proto; + + // Recursively traverse child protos even if we aren't compiling this one + for (int i = 0; i < proto->sizep; i++) + gatherFunctionsHelper(results, proto->p[i], flags, hasNativeFunctions, false); +} + +inline void gatherFunctions(std::vector& results, Proto* root, const unsigned int flags, const bool hasNativeFunctions = false) +{ + LUAU_ASSERT(FFlag::LuauNativeAttribute); + gatherFunctionsHelper(results, root, flags, hasNativeFunctions, true); } inline unsigned getInstructionCount(const std::vector& instructions, IrCmd cmd) diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 30ed42a0..f78823df 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauCodegenInstG, false) + namespace Luau { namespace CodeGen @@ -52,6 +54,9 @@ void updateUseCounts(IrFunction& function) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } @@ -95,6 +100,9 @@ void updateLastUseLocations(IrFunction& function, const std::vector& s checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } } @@ -128,6 +136,12 @@ uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t s if (inst.f.kind == IrOpKind::Inst && inst.f.index == targetInstIdx) return i; + + if (FFlag::LuauCodegenInstG) + 
{ + if (inst.g.kind == IrOpKind::Inst && inst.g.index == targetInstIdx) + return i; + } } // There must be a next use since there is the last use location @@ -165,6 +179,9 @@ std::pair getLiveInOutValueCount(IrFunction& function, IrBlo checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } return std::make_pair(liveIns, liveOuts); @@ -488,6 +505,9 @@ static void computeCfgBlockEdges(IrFunction& function) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 723d35c4..e62885eb 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauCodegenInstG) namespace Luau { @@ -741,6 +742,9 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) redirect(clone.e); redirect(clone.f); + if (FFlag::LuauCodegenInstG) + redirect(clone.g); + addUse(function, clone.a); addUse(function, clone.b); addUse(function, clone.c); @@ -748,11 +752,17 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) addUse(function, clone.e); addUse(function, clone.f); + if (FFlag::LuauCodegenInstG) + addUse(function, clone.g); + // Instructions that referenced the original will have to be adjusted to use the clone instRedir[index] = uint32_t(function.instructions.size()); // Reconstruct the fresh clone - inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f); + if (FFlag::LuauCodegenInstG) + inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f, clone.g); + else + inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f); } } @@ -850,8 +860,33 @@ IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp 
c, IrOp d, IrOp e) IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f) { + if (FFlag::LuauCodegenInstG) + { + return inst(cmd, a, b, c, d, e, f, {}); + } + else + { + uint32_t index = uint32_t(function.instructions.size()); + function.instructions.push_back({cmd, a, b, c, d, e, f}); + + CODEGEN_ASSERT(!inTerminatedBlock); + + if (isBlockTerminator(cmd)) + { + function.blocks[activeBlockIdx].finish = index; + inTerminatedBlock = true; + } + + return {IrOpKind::Inst, index}; + } +} + +IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, IrOp g) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenInstG); + uint32_t index = uint32_t(function.instructions.size()); - function.instructions.push_back({cmd, a, b, c, d, e, f}); + function.instructions.push_back({cmd, a, b, c, d, e, f, g}); CODEGEN_ASSERT(!inTerminatedBlock); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index a82ee894..5465d0a0 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauCodegenInstG) namespace Luau { @@ -417,6 +418,9 @@ void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index) checkOp(inst.d, ", "); checkOp(inst.e, ", "); checkOp(inst.f, ", "); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g, ", "); } void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) @@ -656,6 +660,8 @@ static RegisterSet getJumpTargetExtraLiveIn(IrToStringContext& ctx, const IrBloc op = inst.e; else if (inst.f.kind == IrOpKind::Block) op = inst.f; + else if (FFlag::LuauCodegenInstG && inst.g.kind == IrOpKind::Block) + op = inst.g; if (op.kind == IrOpKind::Block && op.index < ctx.cfg.in.size()) { @@ -940,6 +946,9 @@ std::string toDot(const IrFunction& function, bool includeInst) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } diff --git a/CodeGen/src/IrRegAllocA64.cpp 
b/CodeGen/src/IrRegAllocA64.cpp index 24b0b285..af63a2fc 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAGVARIABLE(DebugCodegenChaosA64, false) +LUAU_FASTFLAG(LuauCodegenInstG) namespace Luau { @@ -256,6 +257,9 @@ void IrRegAllocA64::freeLastUseRegs(const IrInst& inst, uint32_t index) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } void IrRegAllocA64::freeTempRegs() diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index 2b5da623..60326074 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -6,6 +6,8 @@ #include "EmitCommonX64.h" +LUAU_FASTFLAG(LuauCodegenInstG) + namespace Luau { namespace CodeGen @@ -181,6 +183,9 @@ void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t instIdx) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t instIdx) const diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index d1bfca45..2244c4d3 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAG(LuauCodegenInstG) + namespace Luau { namespace CodeGen @@ -315,12 +317,18 @@ void kill(IrFunction& function, IrInst& inst) removeUse(function, inst.e); removeUse(function, inst.f); + if (FFlag::LuauCodegenInstG) + removeUse(function, inst.g); + inst.a = {}; inst.b = {}; inst.c = {}; inst.d = {}; inst.e = {}; inst.f = {}; + + if (FFlag::LuauCodegenInstG) + inst.g = {}; } void kill(IrFunction& function, uint32_t start, uint32_t end) @@ -370,6 +378,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl addUse(function, replacement.e); addUse(function, replacement.f); + if (FFlag::LuauCodegenInstG) + addUse(function, replacement.g); + // An extra reference is added so block will not remove 
itself block.useCount++; @@ -392,6 +403,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl removeUse(function, inst.e); removeUse(function, inst.f); + if (FFlag::LuauCodegenInstG) + removeUse(function, inst.g); + // Inherit existing use count (last use is skipped as it will be defined later) replacement.useCount = inst.useCount; @@ -417,12 +431,18 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement) removeUse(function, inst.e); removeUse(function, inst.f); + if (FFlag::LuauCodegenInstG) + removeUse(function, inst.g); + inst.a = replacement; inst.b = {}; inst.c = {}; inst.d = {}; inst.e = {}; inst.f = {}; + + if (FFlag::LuauCodegenInstG) + inst.g = {}; } void applySubstitutions(IrFunction& function, IrOp& op) @@ -466,6 +486,9 @@ void applySubstitutions(IrFunction& function, IrInst& inst) applySubstitutions(function, inst.d); applySubstitutions(function, inst.e); applySubstitutions(function, inst.f); + + if (FFlag::LuauCodegenInstG) + applySubstitutions(function, inst.g); } bool compare(double a, double b, IrCondition cond) diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index 3dc72610..c6b2d044 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -146,6 +146,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) CODEGEN_ASSERT(inst.d.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.e.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.f.kind != IrOpKind::VmReg); + CODEGEN_ASSERT(inst.g.kind != IrOpKind::VmReg); break; } } diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 2ae54c67..85fef5aa 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -612,4 +612,6 @@ enum LuauProtoFlag LPF_NATIVE_MODULE = 1 << 0, // used to tag individual protos as not profitable to compile natively LPF_NATIVE_COLD = 1 << 1, + // used to tag main proto for modules that have at 
least one function with native attribute + LPF_NATIVE_FUNCTION = 1 << 2, }; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 4842b9a1..db86fbc6 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -30,6 +30,8 @@ LUAU_FASTFLAG(LuauCompileTypeInfo) LUAU_FASTFLAGVARIABLE(LuauCompileTempTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauCompileUserdataInfo, false) +LUAU_FASTFLAG(LuauNativeAttribute) + namespace Luau { @@ -195,7 +197,7 @@ struct Compiler return node->as(); } - uint32_t compileFunction(AstExprFunction* func, uint8_t protoflags) + uint32_t compileFunction(AstExprFunction* func, uint8_t& protoflags) { LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); @@ -297,6 +299,9 @@ struct Compiler if (func->functionDepth == 0 && !hasLoops) protoflags |= LPF_NATIVE_COLD; + if (FFlag::LuauNativeAttribute && func->hasNativeAttribute()) + protoflags |= LPF_NATIVE_FUNCTION; + bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size()), protoflags); Function& f = functions[func]; @@ -3863,13 +3868,12 @@ struct Compiler struct FunctionVisitor : AstVisitor { - Compiler* self; std::vector& functions; bool hasTypes = false; + bool hasNativeFunction = false; - FunctionVisitor(Compiler* self, std::vector& functions) - : self(self) - , functions(functions) + FunctionVisitor(std::vector& functions) + : functions(functions) { // preallocate the result; this works around std::vector's inefficient growth policy for small arrays functions.reserve(16); @@ -3885,6 +3889,9 @@ struct Compiler // this makes sure all functions that are used when compiling this one have been already added to the vector functions.push_back(node); + if (FFlag::LuauNativeAttribute && !hasNativeFunction && node->hasNativeAttribute()) + hasNativeFunction = true; + return false; } }; @@ -4117,6 +4124,14 @@ struct Compiler std::vector> interpStrings; }; +static void setCompileOptionsForNativeCompilation(CompileOptions& options) +{ + options.optimizationLevel = 2; 
// note: this might be removed in the future in favor of --!optimize + + if (FFlag::LuauCompileTypeInfo) + options.typeInfoLevel = 1; +} + void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions) { LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler"); @@ -4135,15 +4150,21 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c if (hc.header && hc.content == "native") { mainFlags |= LPF_NATIVE_MODULE; - options.optimizationLevel = 2; // note: this might be removed in the future in favor of --!optimize - - if (FFlag::LuauCompileTypeInfo) - options.typeInfoLevel = 1; + setCompileOptionsForNativeCompilation(options); } } AstStatBlock* root = parseResult.root; + // gathers all functions with the invariant that all function references are to functions earlier in the list + // for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo + std::vector functions; + Compiler::FunctionVisitor functionVisitor(functions); + root->visit(&functionVisitor); + + if (functionVisitor.hasNativeFunction) + setCompileOptionsForNativeCompilation(options); + Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables @@ -4180,12 +4201,6 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c predictTableShapes(compiler.tableShapes, root); } - // gathers all functions with the invariant that all function references are to functions earlier in the list - // for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo - std::vector functions; - Compiler::FunctionVisitor functionVisitor(&compiler, functions); - root->visit(&functionVisitor); - if (FFlag::LuauCompileUserdataInfo) { if (const char* const* ptr = options.userdataTypes) @@ 
-4217,7 +4232,15 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c } for (AstExprFunction* expr : functions) - compiler.compileFunction(expr, 0); + { + uint8_t protoflags = 0; + compiler.compileFunction(expr, protoflags); + + // If a function has native attribute and the whole module is not native, we set LPF_NATIVE_FUNCTION flag + // This ensures that LPF_NATIVE_MODULE and LPF_NATIVE_FUNCTION are exclusive. + if (FFlag::LuauNativeAttribute && (protoflags & LPF_NATIVE_FUNCTION) && !(mainFlags & LPF_NATIVE_MODULE)) + mainFlags |= LPF_NATIVE_FUNCTION; + } AstExprFunction main(root->location, /*attributes=*/AstArray({nullptr, 0}), /*generics= */ AstArray(), /*genericPacks= */ AstArray(), diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 9c1fca9e..516e02f4 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -34,6 +34,8 @@ void luaC_validate(lua_State* L); LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTFLAG(LuauCodegenFixSplitStoreConstMismatch) +LUAU_FASTFLAG(LuauAttributeSyntax) +LUAU_FASTFLAG(LuauNativeAttribute) static lua_CompileOptions defaultOptions() { @@ -2707,4 +2709,57 @@ end 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); } +TEST_CASE("NativeAttribute") +{ + if (!codegen || !luau_codegen_supported()) + return; + + ScopedFastFlag sffs[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauNativeAttribute, true}}; + + std::string source = R"R( + @native + local function sum(x, y) + local function sumHelper(z) + return (x+y+z) + end + return sumHelper + end + + local function sub(x, y) + @native + local function subHelper(z) + return (x+y-z) + end + return subHelper + end)R"; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + luau_codegen_create(L); + + luaL_openlibs(L); + luaL_sandbox(L); + 
luaL_sandboxthread(L); + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), nullptr, &bytecodeSize); + int result = luau_load(L, "=Code", bytecode, bytecodeSize, 0); + free(bytecode); + + REQUIRE(result == 0); + + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; + Luau::CodeGen::CompilationStats nativeStats = {}; + Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, nativeOptions, &nativeStats); + + CHECK(nativeResult.result == Luau::CodeGen::CodeGenCompilationResult::Success); + + CHECK(!nativeResult.hasErrors()); + REQUIRE(nativeResult.protoFailures.empty()); + + // We should be able to compile at least one of our functions + CHECK_EQ(nativeStats.functionsCompiled, 2); +} + TEST_SUITE_END(); diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp index 8268dde6..43bd7325 100644 --- a/tests/Generalization.test.cpp +++ b/tests/Generalization.test.cpp @@ -21,14 +21,18 @@ struct GeneralizationFixture { TypeArena arena; BuiltinTypes builtinTypes; - Scope scope{builtinTypes.anyTypePack}; + ScopePtr globalScope = std::make_shared(builtinTypes.anyTypePack); + ScopePtr scope = std::make_shared(globalScope); ToStringOptions opts; + DenseHashSet generalizedTypes_{nullptr}; + NotNull> generalizedTypes{&generalizedTypes_}; + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; std::pair freshType() { - FreeType ft{&scope, builtinTypes.neverType, builtinTypes.unknownType}; + FreeType ft{scope.get(), builtinTypes.neverType, builtinTypes.unknownType}; TypeId ty = arena.addType(ft); FreeType* ftv = getMutable(ty); @@ -49,7 +53,7 @@ struct GeneralizationFixture std::optional generalize(TypeId ty) { - return ::Luau::generalize(NotNull{&arena}, NotNull{&builtinTypes}, NotNull{&scope}, ty); + return ::Luau::generalize(NotNull{&arena}, NotNull{&builtinTypes}, NotNull{scope.get()}, generalizedTypes, ty); } }; @@ -116,4 +120,71 @@ 
TEST_CASE_FIXTURE(GeneralizationFixture, "dont_traverse_into_class_types_when_ge CHECK(is(*genPropTy)); } +TEST_CASE_FIXTURE(GeneralizationFixture, "cache_fully_generalized_types") +{ + CHECK(generalizedTypes->empty()); + + TypeId tinyTable = arena.addType(TableType{ + TableType::Props{{"one", builtinTypes.numberType}, {"two", builtinTypes.stringType}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + }); + + generalize(tinyTable); + + CHECK(generalizedTypes->contains(tinyTable)); + CHECK(generalizedTypes->contains(builtinTypes.numberType)); + CHECK(generalizedTypes->contains(builtinTypes.stringType)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "dont_cache_types_that_arent_done_yet") +{ + TypeId freeTy = arena.addType(FreeType{NotNull{globalScope.get()}, builtinTypes.neverType, builtinTypes.stringType}); + + TypeId fnTy = arena.addType(FunctionType{ + builtinTypes.emptyTypePack, + arena.addTypePack(TypePack{{builtinTypes.numberType}}) + }); + + TypeId tableTy = arena.addType(TableType{ + TableType::Props{{"one", builtinTypes.numberType}, {"two", freeTy}, {"three", fnTy}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + }); + + generalize(tableTy); + + CHECK(generalizedTypes->contains(fnTy)); + CHECK(generalizedTypes->contains(builtinTypes.numberType)); + CHECK(generalizedTypes->contains(builtinTypes.neverType)); + CHECK(generalizedTypes->contains(builtinTypes.stringType)); + CHECK(!generalizedTypes->contains(freeTy)); + CHECK(!generalizedTypes->contains(tableTy)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can_be_cached") +{ + TypeId selfTy = arena.addType(BlockedType{}); + + TypeId methodTy = arena.addType(FunctionType{ + arena.addTypePack({selfTy}), + arena.addTypePack({builtinTypes.numberType}), + }); + + asMutable(selfTy)->ty.emplace( + TableType::Props{{"count", builtinTypes.numberType}, {"method", methodTy}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + ); + + generalize(methodTy); + + 
CHECK(generalizedTypes->contains(methodTy)); + CHECK(generalizedTypes->contains(selfTy)); + CHECK(generalizedTypes->contains(builtinTypes.numberType)); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 2eb8ca91..e8a10e92 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -12,7 +12,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeNotUnknownIntersection) -LUAU_FASTFLAG(LuauFixCyclicUnionsOfIntersections); LUAU_FASTINT(LuauTypeInferRecursionLimit) using namespace Luau; @@ -799,8 +798,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union") TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union_of_intersection") { - ScopedFastFlag sff{FFlag::LuauFixCyclicUnionsOfIntersections, true}; - // t1 where t1 = (string & t1) | string TypeId boundTy = arena.addType(BlockedType{}); TypeId intersectTy = arena.addType(IntersectionType{{builtinTypes->stringType, boundTy}}); @@ -814,8 +811,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union_of_intersection") TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_intersection_of_unions") { - ScopedFastFlag sff{FFlag::LuauFixCyclicUnionsOfIntersections, true}; - // t1 where t1 = (string & t1) | string TypeId boundTy = arena.addType(BlockedType{}); TypeId unionTy = arena.addType(UnionType{{builtinTypes->stringType, boundTy}}); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 8b2cc6ba..fabb897f 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -17,7 +17,8 @@ LUAU_FASTINT(LuauTypeLengthLimit); LUAU_FASTINT(LuauParseErrorLimit); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauAttributeSyntax); -LUAU_FASTFLAG(LuauLeadingBarAndAmpersand); +LUAU_FASTFLAG(LuauLeadingBarAndAmpersand2); +LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr); namespace { @@ -3177,21 +3178,21 @@ TEST_CASE_FIXTURE(Fixture, "read_write_table_properties") TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully") { - 
ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; parse(R"(type A = | "Hello" | "World")"); } TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully") { - ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; parse(R"(type A = & { string } & { number })"); } TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") { - ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); @@ -3234,6 +3235,45 @@ end)"); checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8))); } +TEST_CASE_FIXTURE(Fixture, "parse_attribute_for_function_expression") +{ + ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauAttributeSyntaxFunExpr, true}}; + + AstStatBlock* stat1 = parse(R"( +local function invoker(f) + return f(1) +end + +invoker(@checked function(x) return (x + 2) end) +)"); + + LUAU_ASSERT(stat1 != nullptr); + + AstExprFunction* func1 = stat1->body.data[1]->as()->expr->as()->args.data[0]->as(); + LUAU_ASSERT(func1 != nullptr); + + AstArray attributes1 = func1->attributes; + + CHECK_EQ(attributes1.size, 1); + + checkAttribute(attributes1.data[0], AstAttr::Type::Checked, Location(Position(5, 8), Position(5, 16))); + + AstStatBlock* stat2 = parse(R"( +local f = @checked function(x) return (x + 2) end +)"); + + LUAU_ASSERT(stat2 != nullptr); + + AstExprFunction* func2 = stat2->body.data[0]->as()->values.data[0]->as(); + LUAU_ASSERT(func2 != nullptr); + + AstArray attributes2 = 
func2->attributes; + + CHECK_EQ(attributes2.size, 1); + + checkAttribute(attributes2.data[0], AstAttr::Type::Checked, Location(Position(1, 10), Position(1, 18))); +} + TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_local_function_stat") { ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; @@ -3342,6 +3382,22 @@ function foo1 () @checked return 'a' end "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'return' intead"); } +TEST_CASE_FIXTURE(Fixture, "dont_parse_attribute_on_argument_non_function") +{ + ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauAttributeSyntaxFunExpr, true}}; + + ParseResult pr = tryParse(R"( +local function invoker(f, y) + return f(y) +end + +invoker(function(x) return (x + 2) end, @checked 1) +)"); + + checkFirstErrorForAttributes( + pr.errors, 1, Location(Position(5, 40), Position(5, 48)), "Expected 'function' declaration after attribute, but got '1' intead"); +} + TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_type_declaration") { ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true}; @@ -3472,21 +3528,21 @@ end)"); TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully") { - ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; parse(R"(type A = | "Hello" | "World")"); } TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully") { - ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; parse(R"(type A = & { string } & { number })"); } TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") { - ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true}; + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; matchParseError("type A = & number | string | boolean", "Mixing union and intersection types 
is not allowed; consider wrapping in parentheses."); matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index d7cb225a..17faa2e7 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -356,32 +356,15 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") )"); if (FFlag::DebugLuauDeferredConstraintResolution) { - LUAU_REQUIRE_ERROR_COUNT(3, result); - auto err = get(result.errors[0]); - LUAU_ASSERT(err); - CHECK("(...any) -> ()" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); - err = get(result.errors[1]); - LUAU_ASSERT(err); - // FIXME: this recommendation could be better - CHECK("(a) -> or ()>" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); - err = get(result.errors[2]); - LUAU_ASSERT(err); - // FIXME: this recommendation could be better - CHECK("(a) -> or(b) -> or ()>>" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); + LUAU_REQUIRE_NO_ERRORS(result); ToStringOptions o; o.exhaustive = false; o.maxTypeLength = 20; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> or ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> or(a... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> or(b... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ())... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ())... 
*TRUNCATED*"); } else { @@ -408,32 +391,15 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") if (FFlag::DebugLuauDeferredConstraintResolution) { - LUAU_REQUIRE_ERROR_COUNT(3, result); - auto err = get(result.errors[0]); - LUAU_ASSERT(err); - CHECK("(...any) -> ()" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); - err = get(result.errors[1]); - LUAU_ASSERT(err); - // FIXME: this recommendation could be better - CHECK("(a) -> or ()>" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); - err = get(result.errors[2]); - LUAU_ASSERT(err); - // FIXME: this recommendation could be better - CHECK("(a) -> or(b) -> or ()>>" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("unknown" == toString(err->recommendedArgs[0].second)); + LUAU_REQUIRE_NO_ERRORS(result); ToStringOptions o; o.exhaustive = true; o.maxTypeLength = 20; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> or ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> or(a... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> or(b... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ())... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ())... 
*TRUNCATED*"); } else { diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index c66f0227..063ed39c 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -167,15 +167,13 @@ TEST_CASE_FIXTURE(FamilyFixture, "table_internal_families") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(toString(requireType("a")) == "{string}"); CHECK(toString(requireType("b")) == "{number}"); - CHECK(toString(requireType("c")) == "{Swap}"); - CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); + // FIXME: table types are constructing a trivial union here. + CHECK(toString(requireType("c")) == "{Swap}"); + CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); } TEST_CASE_FIXTURE(FamilyFixture, "function_internal_families") { - // This test is broken right now, but it's not because of type families. See - // CLI-71143. - if (!FFlag::DebugLuauDeferredConstraintResolution) return; @@ -829,4 +827,222 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_equivalence_with_distributivity") CHECK(toString(requireTypeAlias("U")) == "A | A | B | B"); } -TEST_SUITE_END(); +TEST_CASE_FIXTURE(BuiltinsFixture, "we_shouldnt_warn_that_a_reducible_type_family_is_uninhabited") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + +local Debounce = false +local Active = false + +local function Use(Mode) + + if Mode ~= nil then + + if Mode == false and Active == false then + return + else + Active = not Mode + end + + Debounce = false + end + Active = not Active + +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type IdxAType = index + type IdxBType = index> + + local function ok(idx: IdxAType): string return idx end + local function ok2(idx: IdxBType): string | 
number | boolean return idx end + local function err(idx: IdxAType): boolean return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("boolean", toString(tpm->wantedTp)); + CHECK_EQ("string", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_array") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local MyObject = {"hello", 1, true} + type IdxAType = index + + local function ok(idx: IdxAType): string | number | boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_generic_types") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function access(tbl: T & {}, key: K): index + return tbl[key] + end + + local subjects = { + english = "boring", + math = "fun" + } + + local key: "english" = "english" + local a: string = access(subjects, key) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_errors_w_bad_indexer") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type errType1 = index + type errType2 = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Property '\"d\"' does not exist on type 'MyObject'"); + CHECK(toString(result.errors[1]) == "Property 'boolean' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_errors_w_var_indexer") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + local key = "a" + + type errType1 = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == 
"Second argument to index is not a valid index type"); + CHECK(toString(result.errors[1]) == "Unknown type 'key'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_union_type_indexer") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + + type idxType = index + local function ok(idx: idxType): string | number return idx end + + type errType = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"a\" | \"d\"' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_union_type_indexee") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type MyObject2 = {a: number} + + type idxTypeA = index + local function ok(idx: idxTypeA): string | number return idx end + + type errType = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject | MyObject2'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_rfc_alternative_section") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string} + type MyObject2 = {a: string, b: number} + + local function edgeCase(param: MyObject) + type unknownType = index + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(ClassFixture, "index_type_family_works_on_classes") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type KeysOfMyObject = index + + local function ok(idx: KeysOfMyObject): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + 
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_index_metatables") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local exampleClass = { Foo = "text", Bar = true } + + local exampleClass2 = setmetatable({ Foo = 8 }, { __index = exampleClass }) + type exampleTy2 = index + local function ok(idx: exampleTy2): number return idx end + + local exampleClass3 = setmetatable({ Bar = 5 }, { __index = exampleClass }) + type exampleTy3 = index + local function ok2(idx: exampleTy3): string return idx end + + type exampleTy4 = index + local function ok3(idx: exampleTy4): string | number return idx end + + type errTy = index + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"Car\"' does not exist on type 'exampleClass2'"); +} + +TEST_SUITE_END(); \ No newline at end of file diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 06e698a8..54cf1cef 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -9,7 +9,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauSharedSelf); -LUAU_FASTFLAG(LuauForbidAliasNamedTypeof); TEST_SUITE_BEGIN("TypeAliases"); @@ -1065,8 +1064,6 @@ TEST_CASE_FIXTURE(Fixture, "table_types_record_the_property_locations") TEST_CASE_FIXTURE(Fixture, "typeof_is_not_a_valid_alias_name") { - ScopedFastFlag sff{FFlag::LuauForbidAliasNamedTypeof, true}; - CheckResult result = check(R"( type typeof = number )"); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index ce6988aa..a58fb638 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -696,13 +696,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") if (FFlag::DebugLuauDeferredConstraintResolution) { TypeId keyTy = requireType("key"); - - const UnionType* ut = get(keyTy); - REQUIRE(ut); - - REQUIRE(ut->options.size() == 2); - 
CHECK_EQ(builtinTypes->nilType, follow(ut->options[0])); - CHECK_EQ(*builtinTypes->numberType, *ut->options[1]); + CHECK("number?" == toString(keyTy)); } else CHECK_EQ(*builtinTypes->numberType, *requireType("key")); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index fac86150..b8bb9795 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -396,9 +396,17 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") s += 10 )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") @@ -423,6 +431,33 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable_with_changing_return_type") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + CheckResult result = check(R"( + --!strict + type T = { x: number } + local MT = {} + + function MT:__add(other): number + return 112 + end + + local t = setmetatable({x = 2}, MT) + local u = t + 3 + t += 3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + 
TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK("t" == toString(tm->wantedType)); + CHECK("number" == toString(tm->givenType)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_result_must_be_compatible_with_var") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index ebf1fde4..e089c7be 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -576,15 +576,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "singletons_stick_around_under_assignment") local foo = (nil :: any) :: Foo - print(foo.kind == "Bar") -- TypeError: Type "Foo" cannot be compared with "Bar" + print(foo.kind == "Bar") -- type of equality refines to `false` local kind = foo.kind - print(kind == "Bar") -- SHOULD BE: TypeError: Type "Foo" cannot be compared with "Bar" + print(kind == "Bar") -- type of equality refines to `false` )"); - // FIXME: Under the new solver, we get both the errors we expect, but they're - // duplicated because of how we are currently running type family reduction. 
if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(4, result); + LUAU_REQUIRE_NO_ERRORS(result); else LUAU_REQUIRE_ERROR_COUNT(1, result); } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 516a761b..2c9614d0 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4511,4 +4511,30 @@ end )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_literal_inference_assert") +{ + CheckResult result = check(R"( + local buttons = { + buttons = {}; + } + + buttons.Button = { + call = nil; + lightParts = nil; + litPropertyOverrides = nil; + model = nil; + pivot = nil; + unlitPropertyOverrides = nil; + } + buttons.Button.__index = buttons.Button + + local lightFuncs: { (self: types.Button, lit: boolean) -> nil } = { + ['\x00'] = function(self: types.Button, lit: boolean) + end; + } + )"); + + +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 9ea9539f..60903733 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping); +LUAU_FASTFLAG(LuauLeadingBarAndAmpersand2) LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); @@ -1572,4 +1573,62 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bad_iter_metamethod") } } +TEST_CASE_FIXTURE(Fixture, "leading_bar") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Bar = | number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar_question_mark") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Bar = |? 
+ )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got '?'" == toString(result.errors[0])); + CHECK("*error-type*?" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_ampersand") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Amp = & string + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("string" == toString(requireTypeAlias("Amp"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar_no_type") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Bar = | + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got " == toString(result.errors[0])); + CHECK("*error-type*" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_ampersand_no_type") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Amp = & + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got " == toString(result.errors[0])); + CHECK("*error-type*" == toString(requireTypeAlias("Amp"))); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index b2677bf4..834b24c0 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -32,16 +32,6 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument2 BuiltinTests.table_freeze_is_generic BuiltinTests.tonumber_returns_optional_number_type -ControlFlowAnalysis.for_record_do_if_not_x_break -ControlFlowAnalysis.for_record_do_if_not_x_continue -ControlFlowAnalysis.if_not_x_break_elif_not_y_break -ControlFlowAnalysis.if_not_x_break_elif_not_y_continue -ControlFlowAnalysis.if_not_x_break_elif_rand_break_elif_not_y_break -ControlFlowAnalysis.if_not_x_continue_elif_not_y_continue -ControlFlowAnalysis.if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough 
-ControlFlowAnalysis.if_not_x_continue_elif_rand_continue_elif_not_y_continue -ControlFlowAnalysis.if_not_x_return_elif_not_y_break -DefinitionTests.class_definition_overload_metamethods Differ.metatable_metamissing_left Differ.metatable_metamissing_right Differ.metatable_metanormal @@ -238,8 +228,6 @@ ToString.named_metatable_toStringNamedFunction ToString.no_parentheses_around_cyclic_function_type_in_intersection ToString.pick_distinct_names_for_mixed_explicit_and_implicit_generics ToString.primitive -ToString.quit_stringifying_type_when_length_is_exceeded -ToString.stringifying_type_is_still_capped_when_exhaustive ToString.toStringDetailed2 ToString.toStringErrorPack TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType @@ -262,7 +250,6 @@ TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeFamilyTests.add_family_at_work TypeFamilyTests.family_as_fn_arg TypeFamilyTests.internal_families_raise_errors -TypeFamilyTests.mul_family_with_union_of_multiplicatives_2 TypeFamilyTests.unsolvable_family TypeInfer.be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload TypeInfer.check_type_infer_recursion_count @@ -319,7 +306,6 @@ TypeInferFunctions.function_exprs_are_generalized_at_signature_scope_not_enclosi TypeInferFunctions.function_is_supertype_of_concrete_functions TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer TypeInferFunctions.generic_packs_are_not_variadic -TypeInferFunctions.higher_order_function_2 TypeInferFunctions.higher_order_function_4 TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors @@ -339,7 +325,7 @@ TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible_2 TypeInferFunctions.report_exiting_without_return_nonstrict TypeInferFunctions.return_type_by_overload 
-TypeInferFunctions.tf_suggest_return_type +TypeInferFunctions.simple_unannotated_mutual_recursion TypeInferFunctions.too_few_arguments_variadic TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic2 @@ -377,7 +363,6 @@ TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.promise_type_error_too_complex TypeInferOperators.add_type_family_works TypeInferOperators.cli_38355_recursive_union -TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.compound_assign_result_must_be_compatible_with_var TypeInferOperators.concat_op_on_free_lhs_and_string_rhs TypeInferOperators.concat_op_on_string_lhs_and_free_rhs @@ -408,8 +393,6 @@ TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened -TypeSingletons.singletons_stick_around_under_assignment -TypeSingletons.string_singleton_function_call TypeSingletons.table_properties_type_error_escapes TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeStatesTest.typestates_preserve_error_suppression_properties From caee04d82d014ed104dd63edec1710fb6ab5794c Mon Sep 17 00:00:00 2001 From: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com> Date: Thu, 20 Jun 2024 16:37:55 -0700 Subject: [PATCH 19/20] Sync to upstream/release/631 (#1299) ### What's new * Added lint warning for using redundant `@native` attributes on functions inside a `--!native` module * Improved typechecking speed in old solver for modules with large types ### New Solver * Fixed the length type function sealing the table prematurely * Fixed crashes caused by general table indexing expressions ### VM * Added support for a specialized 3-argument fast-call instruction to improve performance of `vector` constructor, `buffer` writes and a few `bit32` methods --- ### 
Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Andy Friesen Co-authored-by: Aviral Goel Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --------- Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Vighnesh Co-authored-by: Aviral Goel Co-authored-by: David Cope Co-authored-by: Lily Brown --- Analysis/include/Luau/ConstraintSolver.h | 2 +- Analysis/include/Luau/Generalization.h | 4 +- Analysis/include/Luau/Instantiation.h | 10 + Analysis/include/Luau/Substitution.h | 5 +- Analysis/include/Luau/TypeFamily.h | 4 +- Analysis/include/Luau/TypeInfer.h | 3 + Analysis/src/BuiltinDefinitions.cpp | 49 +++-- Analysis/src/ConstraintGenerator.cpp | 6 +- Analysis/src/ConstraintSolver.cpp | 39 ++-- Analysis/src/Error.cpp | 16 +- Analysis/src/Generalization.cpp | 75 +++---- Analysis/src/Instantiation.cpp | 54 ++++- Analysis/src/Linter.cpp | 70 +++++++ Analysis/src/Substitution.cpp | 48 ++++- Analysis/src/TypeChecker2.cpp | 28 ++- Analysis/src/TypeFamily.cpp | 100 +++++----- Analysis/src/TypeInfer.cpp | 30 ++- Analysis/src/Unifier.cpp | 3 +- Analysis/src/Unifier2.cpp | 48 ++--- CodeGen/include/Luau/IrData.h | 14 +- CodeGen/include/Luau/IrUtils.h | 1 + CodeGen/include/Luau/IrVisitUseDef.h | 68 +++++-- CodeGen/src/BytecodeAnalysis.cpp | 213 +++++++++----------- CodeGen/src/CodeGenAssembly.cpp | 5 +- CodeGen/src/CodeGenLower.h | 4 +- CodeGen/src/EmitBuiltinsX64.cpp | 30 +-- CodeGen/src/EmitBuiltinsX64.h | 2 +- CodeGen/src/IrBuilder.cpp | 114 ++--------- CodeGen/src/IrLoweringA64.cpp | 213 ++++++++++---------- CodeGen/src/IrLoweringX64.cpp | 48 +++-- CodeGen/src/IrTranslateBuiltins.cpp | 157 ++++++++------- CodeGen/src/IrTranslateBuiltins.h | 3 +- CodeGen/src/IrTranslation.cpp | 37 +++- CodeGen/src/IrTranslation.h | 3 +- CodeGen/src/IrValueLocationTracking.cpp | 8 +- CodeGen/src/OptimizeConstProp.cpp | 59 +++--- CodeGen/src/OptimizeDeadStore.cpp | 1 - Common/include/Luau/Bytecode.h | 12 +- 
Common/include/Luau/BytecodeUtils.h | 1 + Common/include/Luau/DenseHash.h | 8 +- Compiler/src/BytecodeBuilder.cpp | 22 +++ Compiler/src/Compiler.cpp | 58 +++++- Config/include/Luau/LinterConfig.h | 2 + VM/src/ldebug.cpp | 8 +- VM/src/lfunc.cpp | 18 +- VM/src/lgc.cpp | 14 +- VM/src/lmem.cpp | 48 ++--- VM/src/ltablib.cpp | 54 ++--- VM/src/lvmexecute.cpp | 62 +++++- VM/src/lvmload.cpp | 106 ++++------ tests/Compiler.test.cpp | 56 ++++++ tests/Error.test.cpp | 6 +- tests/Fixture.h | 18 ++ tests/Generalization.test.cpp | 23 +-- tests/IrBuilder.test.cpp | 104 ++++------ tests/IrLowering.test.cpp | 238 ++++++++--------------- tests/Linter.test.cpp | 31 +++ tests/TypeInfer.functions.test.cpp | 13 +- tests/TypeInfer.provisional.test.cpp | 41 ++++ tests/TypeInfer.tables.test.cpp | 50 +++-- tests/TypeInfer.test.cpp | 2 +- tests/conformance/bitwise.lua | 1 + tests/conformance/math.lua | 23 +++ 63 files changed, 1433 insertions(+), 1160 deletions(-) diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 6e62a2e3..925be04e 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -300,7 +300,7 @@ public: * @returns a non-free type that generalizes the argument, or `std::nullopt` if one * does not exist */ - std::optional generalizeFreeType(NotNull scope, TypeId type); + std::optional generalizeFreeType(NotNull scope, TypeId type, bool avoidSealingTables = false); /** * Checks the existing set of constraints to see if there exist any that contain diff --git a/Analysis/include/Luau/Generalization.h b/Analysis/include/Luau/Generalization.h index 44d0db67..04ac2df1 100644 --- a/Analysis/include/Luau/Generalization.h +++ b/Analysis/include/Luau/Generalization.h @@ -8,6 +8,6 @@ namespace Luau { -std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> bakedTypes, TypeId ty); - +std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, + NotNull> 
bakedTypes, TypeId ty, /* avoid sealing tables*/ bool avoidSealingTables = false); } diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h index 2122f0fa..58ba88ab 100644 --- a/Analysis/include/Luau/Instantiation.h +++ b/Analysis/include/Luau/Instantiation.h @@ -27,12 +27,16 @@ struct ReplaceGenerics : Substitution { } + void resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope, + const std::vector& generics, const std::vector& genericPacks); + NotNull builtinTypes; TypeLevel level; Scope* scope; std::vector generics; std::vector genericPacks; + bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; @@ -48,13 +52,19 @@ struct Instantiation : Substitution , builtinTypes(builtinTypes) , level(level) , scope(scope) + , reusableReplaceGenerics(log, arena, builtinTypes, level, scope, {}, {}) { } + void resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope); + NotNull builtinTypes; TypeLevel level; Scope* scope; + + ReplaceGenerics reusableReplaceGenerics; + bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 16e36e09..28ebc93d 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -134,7 +134,8 @@ struct Tarjan TarjanResult visitRoot(TypeId ty); TarjanResult visitRoot(TypePackId ty); - void clearTarjan(); + // Used to reuse the object for a new operation + void clearTarjan(const TxnLog* log); // Get/set the dirty bit for an index (grows the vector if needed) bool getDirty(int index); @@ -212,6 +213,8 @@ public: std::optional substitute(TypeId ty); std::optional substitute(TypePackId tp); + void resetState(const TxnLog* log, TypeArena* arena); + TypeId replace(TypeId ty); TypePackId 
replace(TypePackId tp); diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index fa23a6ba..7c68d815 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -82,8 +82,8 @@ struct TypeFamilyReductionResult }; template -using ReducerFunction = std::function(T, const std::vector&, const std::vector&, - NotNull)>; +using ReducerFunction = + std::function(T, const std::vector&, const std::vector&, NotNull)>; /// Represents a type function that may be applied to map a series of types and /// type packs to a single output type. diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 26a67c7a..340c1e72 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -4,6 +4,7 @@ #include "Luau/Anyification.h" #include "Luau/ControlFlow.h" #include "Luau/Error.h" +#include "Luau/Instantiation.h" #include "Luau/Module.h" #include "Luau/Predicate.h" #include "Luau/Substitution.h" @@ -362,6 +363,8 @@ public: UnifierSharedState unifierState; Normalizer normalizer; + Instantiation reusableInstantiation; + std::vector requireCycles; // Type inference limits diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index a9c519fe..2393bd2a 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -256,21 +256,44 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); + // getmetatable : ({ @metatable MT, {+ +} }) -> MT addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - // clang-format off - // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(globals, "setmetatable", - arena.addType( - FunctionType{ - {genericMT}, - {}, - arena.addTypePack(TypePack{{tabTy, genericMT}}), - 
arena.addTypePack(TypePack{{tableMetaMT}}) - } - ), "@luau" - ); - // clang-format on + if (FFlag::DebugLuauDeferredConstraintResolution) + { + TypeId genericT = arena.addType(GenericType{"T"}); + TypeId tMetaMT = arena.addType(MetatableType{genericT, genericMT}); + + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(globals, "setmetatable", + arena.addType( + FunctionType{ + {genericT, genericMT}, + {}, + arena.addTypePack(TypePack{{genericT, genericMT}}), + arena.addTypePack(TypePack{{tMetaMT}}) + } + ), "@luau" + ); + // clang-format on + } + else + { + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(globals, "setmetatable", + arena.addType( + FunctionType{ + {genericMT}, + {}, + arena.addTypePack(TypePack{{tabTy, genericMT}}), + arena.addTypePack(TypePack{{tableMetaMT}}) + } + ), "@luau" + ); + // clang-format on + } for (const auto& pair : globals.globalScope->bindings) { diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index b784f4aa..7d92d9ff 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -787,7 +787,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); - for (TypeId t: valueTypes) + for (TypeId t : valueTypes) getMutable(t)->setOwner(uc); } @@ -920,7 +920,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI auto iterable = addConstraint( loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes}); - for (TypeId var: variableTypes) + for (TypeId var : variableTypes) { auto bt = getMutable(var); LUAU_ASSERT(bt); @@ -1171,7 +1171,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass auto uc = addConstraint(scope, assign->location, 
UnpackConstraint{valueTypes, resultPack}); - for (TypeId t: valueTypes) + for (TypeId t : valueTypes) getMutable(t)->setOwner(uc); } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index e59bc8a7..8756ec44 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -694,7 +694,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullscope, generalizedTypes, ty); + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); return true; } @@ -769,13 +769,8 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope); TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); - TypeId tableTy = arena->addType(TableType{ - TableType::Props{}, - TableIndexer{keyTy, valueTy}, - TypeLevel{}, - constraint->scope, - TableState::Free - }); + TypeId tableTy = + arena->addType(TableType{TableType::Props{}, TableIndexer{keyTy, valueTy}, TypeLevel{}, constraint->scope, TableState::Free}); unify(constraint, nextTy, tableTy); @@ -1022,13 +1017,10 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul const TableType* tfTable = getTableType(tf->type); //clang-format off - bool needsClone = - follow(tf->type) == target || - (tfTable != nullptr && tfTable == getTableType(target)) || - std::any_of(typeArguments.begin(), typeArguments.end(), [&](const auto& other) { - return other == target; - } - ); + bool needsClone = follow(tf->type) == target || (tfTable != nullptr && tfTable == getTableType(target)) || + std::any_of(typeArguments.begin(), typeArguments.end(), [&](const auto& other) { + return other == target; + }); //clang-format on // Only tables have the properties we're trying to set. 
@@ -1446,6 +1438,8 @@ bool ConstraintSolver::tryDispatchHasIndexer( bind(constraint, resultType, tbl->indexer->indexResultType); return true; } + else if (auto mt = get(follow(ft->upperBound))) + return tryDispatchHasIndexer(recursionDepth, constraint, mt->table, indexType, resultType, seen); FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType}; emplace(constraint, resultType, freeResult); @@ -1461,11 +1455,11 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (auto indexer = tt->indexer) { unify(constraint, indexType, indexer->indexType); - bind(constraint, resultType, indexer->indexResultType); return true; } - else if (tt->state == TableState::Unsealed) + + if (tt->state == TableState::Unsealed) { // FIXME this is greedy. @@ -2067,7 +2061,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl } auto unpack = [&](TypeId ty) { - for (TypeId varTy: c.variables) + for (TypeId varTy : c.variables) { LUAU_ASSERT(get(varTy)); LUAU_ASSERT(varTy != ty); @@ -2228,11 +2222,12 @@ bool ConstraintSolver::tryDispatchIterableFunction( return true; } -NotNull ConstraintSolver::unpackAndAssign(const std::vector destTypes, TypePackId srcTypes, NotNull constraint) +NotNull ConstraintSolver::unpackAndAssign( + const std::vector destTypes, TypePackId srcTypes, NotNull constraint) { auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes}); - for (TypeId t: destTypes) + for (TypeId t : destTypes) { BlockedType* bt = getMutable(t); LUAU_ASSERT(bt); @@ -2834,7 +2829,7 @@ void ConstraintSolver::shiftReferences(TypeId source, TypeId target) targetRefs += count; } -std::optional ConstraintSolver::generalizeFreeType(NotNull scope, TypeId type) +std::optional ConstraintSolver::generalizeFreeType(NotNull scope, TypeId type, bool avoidSealingTables) { TypeId t = follow(type); if (get(t)) @@ -2849,7 +2844,7 @@ std::optional ConstraintSolver::generalizeFreeType(NotNull scope, // that until 
all constraint generation is complete. } - return generalize(NotNull{arena}, builtinTypes, scope, generalizedTypes, type); + return generalize(NotNull{arena}, builtinTypes, scope, generalizedTypes, type, avoidSealingTables); } bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index d356b1cc..cb8ef20d 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -65,21 +65,15 @@ namespace Luau { // this list of binary operator type families is used for better stringification of type families errors -static const std::unordered_map kBinaryOps{ - {"add", "+"}, {"sub", "-"}, {"mul", "*"}, {"div", "/"}, {"idiv", "//"}, {"pow", "^"}, {"mod", "%"}, {"concat", ".."}, {"and", "and"}, - {"or", "or"}, {"lt", "< or >="}, {"le", "<= or >"}, {"eq", "== or ~="} -}; +static const std::unordered_map kBinaryOps{{"add", "+"}, {"sub", "-"}, {"mul", "*"}, {"div", "/"}, {"idiv", "//"}, + {"pow", "^"}, {"mod", "%"}, {"concat", ".."}, {"and", "and"}, {"or", "or"}, {"lt", "< or >="}, {"le", "<= or >"}, {"eq", "== or ~="}}; // this list of unary operator type families is used for better stringification of type families errors -static const std::unordered_map kUnaryOps{ - {"unm", "-"}, {"len", "#"}, {"not", "not"} -}; +static const std::unordered_map kUnaryOps{{"unm", "-"}, {"len", "#"}, {"not", "not"}}; // this list of type families will receive a special error indicating that the user should file a bug on the GitHub repository // putting a type family in this list indicates that it is expected to _always_ reduce -static const std::unordered_set kUnreachableTypeFamilies{ - "refine", "singleton", "union", "intersect" -}; +static const std::unordered_set kUnreachableTypeFamilies{"refine", "singleton", "union", "intersect"}; struct ErrorConverter { @@ -682,7 +676,7 @@ struct ErrorConverter if (kUnreachableTypeFamilies.count(tfit->family->name)) { return "Type family instance " + Luau::toString(e.ty) + " is 
uninhabited\n" + - "This is likely to be a bug, please report it at https://github.com/luau-lang/luau/issues"; + "This is likely to be a bug, please report it at https://github.com/luau-lang/luau/issues"; } // Everything should be specialized above to report a more descriptive error that hopefully does not mention "type families" explicitly. diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index c2c44d96..5020ea58 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -24,15 +24,17 @@ struct MutatingGeneralizer : TypeOnceVisitor std::vector genericPacks; bool isWithinFunction = false; + bool avoidSealingTables = false; - MutatingGeneralizer(NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, DenseHashMap positiveTypes, - DenseHashMap negativeTypes) + MutatingGeneralizer(NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, + DenseHashMap positiveTypes, DenseHashMap negativeTypes, bool avoidSealingTables) : TypeOnceVisitor(/* skipBoundTypes */ true) , builtinTypes(builtinTypes) , scope(scope) , cachedTypes(cachedTypes) , positiveTypes(std::move(positiveTypes)) , negativeTypes(std::move(negativeTypes)) + , avoidSealingTables(avoidSealingTables) { } @@ -268,7 +270,8 @@ struct MutatingGeneralizer : TypeOnceVisitor TableType* tt = getMutable(ty); LUAU_ASSERT(tt); - tt->state = TableState::Sealed; + if (!avoidSealingTables) + tt->state = TableState::Sealed; return true; } @@ -338,31 +341,31 @@ struct FreeTypeSearcher : TypeVisitor { switch (polarity) { - case Positive: - { - if (seenPositive.contains(ty)) - return true; + case Positive: + { + if (seenPositive.contains(ty)) + return true; - seenPositive.insert(ty); - return false; - } - case Negative: - { - if (seenNegative.contains(ty)) - return true; + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; - seenNegative.insert(ty); - return false; - } - case Both: - { - if 
(seenPositive.contains(ty) && seenNegative.contains(ty)) - return true; + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; - seenPositive.insert(ty); - seenNegative.insert(ty); - return false; - } + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } } return false; @@ -519,7 +522,8 @@ struct TypeCacher : TypeOnceVisitor explicit TypeCacher(NotNull> cachedTypes) : TypeOnceVisitor(/* skipBoundTypes */ true) , cachedTypes(cachedTypes) - {} + { + } void cache(TypeId ty) { @@ -611,7 +615,7 @@ struct TypeCacher : TypeOnceVisitor traverse(ft.argTypes); traverse(ft.retTypes); - for (TypeId gen: ft.generics) + for (TypeId gen : ft.generics) traverse(gen); bool uncacheable = false; @@ -622,7 +626,7 @@ struct TypeCacher : TypeOnceVisitor else if (isUncacheable(ft.retTypes)) uncacheable = true; - for (TypeId argTy: ft.argTypes) + for (TypeId argTy : ft.argTypes) { if (isUncacheable(argTy)) { @@ -631,7 +635,7 @@ struct TypeCacher : TypeOnceVisitor } } - for (TypeId retTy: ft.retTypes) + for (TypeId retTy : ft.retTypes) { if (isUncacheable(retTy)) { @@ -640,7 +644,7 @@ struct TypeCacher : TypeOnceVisitor } } - for (TypeId g: ft.generics) + for (TypeId g : ft.generics) { if (isUncacheable(g)) { @@ -863,7 +867,8 @@ struct TypeCacher : TypeOnceVisitor } }; -std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, TypeId ty) +std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, + NotNull> cachedTypes, TypeId ty, bool avoidSealingTables) { ty = follow(ty); @@ -876,14 +881,14 @@ std::optional generalize(NotNull arena, NotNull FreeTypeSearcher fts{scope, cachedTypes}; fts.traverse(ty); - MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; + MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), 
std::move(fts.negativeTypes), avoidSealingTables}; gen.traverse(ty); /* MutatingGeneralizer mutates types in place, so it is possible that ty has - * been transmuted to a BoundType. We must follow it again and verify that - * we are allowed to mutate it before we attach generics to it. - */ + * been transmuted to a BoundType. We must follow it again and verify that + * we are allowed to mutate it before we attach generics to it. + */ ty = follow(ty); if (ty->owningArena != arena || ty->persistent) diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 525319c6..811aa048 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -11,10 +11,23 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauReusableSubstitutions) namespace Luau { +void Instantiation::resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope) +{ + LUAU_ASSERT(FFlag::LuauReusableSubstitutions); + + Substitution::resetState(log, arena); + + this->builtinTypes = builtinTypes; + + this->level = level; + this->scope = scope; +} + bool Instantiation::isDirty(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) @@ -58,13 +71,26 @@ TypeId Instantiation::clean(TypeId ty) clone.argNames = ftv->argNames; TypeId result = addType(std::move(clone)); - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks}; + if (FFlag::LuauReusableSubstitutions) + { + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + reusableReplaceGenerics.resetState(log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks); - // TODO: What to do if this returns nullopt? 
- // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = reusableReplaceGenerics.substitute(result).value_or(result); + } + else + { + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks}; + + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + } asMutable(result)->documentationSymbol = ty->documentationSymbol; return result; @@ -76,6 +102,22 @@ TypePackId Instantiation::clean(TypePackId tp) return tp; } +void ReplaceGenerics::resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope, + const std::vector& generics, const std::vector& genericPacks) +{ + LUAU_ASSERT(FFlag::LuauReusableSubstitutions); + + Substitution::resetState(log, arena); + + this->builtinTypes = builtinTypes; + + this->level = level; + this->scope = scope; + + this->generics = generics; + this->genericPacks = genericPacks; +} + bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index d79361c0..e9d4ca53 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -16,6 +16,11 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauAttributeSyntax) +LUAU_FASTFLAG(LuauAttribute) +LUAU_FASTFLAG(LuauNativeAttribute) +LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute, false) + namespace Luau { @@ -2922,6 +2927,64 @@ static void lintComments(LintContext& context, const std::vector& ho } } +static bool 
hasNativeCommentDirective(const std::vector& hotcomments) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + LUAU_ASSERT(FFlag::LuauNativeAttribute); + LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); + + for (const HotComment& hc : hotcomments) + { + if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t') + continue; + + if (hc.header) + { + size_t space = hc.content.find_first_of(" \t"); + std::string_view first = std::string_view(hc.content).substr(0, space); + + if (first == "native") + return true; + } + } + + return false; +} + +struct LintRedundantNativeAttribute : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + LUAU_ASSERT(FFlag::LuauNativeAttribute); + LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); + + LintRedundantNativeAttribute pass; + pass.context = &context; + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprFunction* node) override + { + node->body->visit(this); + + for (const auto attribute : node->attributes) + { + if (attribute->type == AstAttr::Type::Native) + { + emitWarning(*context, LintWarning::Code_RedundantNativeAttribute, attribute->location, + "native attribute on a function is redundant in a native module; consider removing it"); + } + } + + return false; + } +}; + std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const std::vector& hotcomments, const LintOptions& options) { @@ -3008,6 +3071,13 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence)) LintComparisonPrecedence::process(context); + if (FFlag::LuauAttributeSyntax && FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && + context.warningEnabled(LintWarning::Code_RedundantNativeAttribute)) + { + if (hasNativeCommentDirective(hotcomments)) + LintRedundantNativeAttribute::process(context); + } 
+ std::sort(context.result.begin(), context.result.end(), WarningComparator()); return context.result; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index ea9c3178..4e5dae07 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -11,6 +11,7 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256); +LUAU_FASTFLAG(LuauReusableSubstitutions) namespace Luau { @@ -146,6 +147,8 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a } Tarjan::Tarjan() + : typeToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0) + , packToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0) { nodes.reserve(FInt::LuauTarjanPreallocationSize); stack.reserve(FInt::LuauTarjanPreallocationSize); @@ -446,14 +449,31 @@ TarjanResult Tarjan::visitRoot(TypePackId tp) return loop(); } -void Tarjan::clearTarjan() +void Tarjan::clearTarjan(const TxnLog* log) { - typeToIndex.clear(); - packToIndex.clear(); + if (FFlag::LuauReusableSubstitutions) + { + typeToIndex.clear(~0u); + packToIndex.clear(~0u); + } + else + { + typeToIndex.clear(); + packToIndex.clear(); + } + nodes.clear(); stack.clear(); + if (FFlag::LuauReusableSubstitutions) + { + childCount = 0; + // childLimit setting stays the same + + this->log = log; + } + edgesTy.clear(); edgesTp.clear(); worklist.clear(); @@ -528,7 +548,6 @@ Substitution::Substitution(const TxnLog* log_, TypeArena* arena) { log = log_; LUAU_ASSERT(log); - LUAU_ASSERT(arena); } void Substitution::dontTraverseInto(TypeId ty) @@ -546,7 +565,7 @@ std::optional Substitution::substitute(TypeId ty) ty = log->follow(ty); // clear algorithm state for reentrancy - clearTarjan(); + clearTarjan(log); auto result = findDirty(ty); if (result != TarjanResult::Ok) @@ -579,7 +598,7 @@ std::optional 
Substitution::substitute(TypePackId tp) tp = log->follow(tp); // clear algorithm state for reentrancy - clearTarjan(); + clearTarjan(log); auto result = findDirty(tp); if (result != TarjanResult::Ok) @@ -607,6 +626,23 @@ std::optional Substitution::substitute(TypePackId tp) return newTp; } +void Substitution::resetState(const TxnLog* log, TypeArena* arena) +{ + LUAU_ASSERT(FFlag::LuauReusableSubstitutions); + + clearTarjan(log); + + this->arena = arena; + + newTypes.clear(); + newPacks.clear(); + replacedTypes.clear(); + replacedTypePacks.clear(); + + noTraverseTypes.clear(); + noTraverseTypePacks.clear(); +} + TypeId Substitution::clone(TypeId ty) { return shallowClone(ty, *arena, log, /* alwaysClone */ true); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index cc02bea6..fe0bf2dd 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1540,6 +1540,24 @@ struct TypeChecker2 visitExprName(indexName->expr, indexName->location, indexName->index.value, context, builtinTypes->stringType); } + void indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const MetatableType* metaTable, TypeId exprType, TypeId indexType) + { + if (auto tt = get(follow(metaTable->table)); tt && tt->indexer) + testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location); + else if (auto mt = get(follow(metaTable->table))) + indexExprMetatableHelper(indexExpr, mt, exprType, indexType); + else if (auto tmt = get(follow(metaTable->metatable)); tmt && tmt->indexer) + testIsSubtype(indexType, tmt->indexer->indexType, indexExpr->index->location); + else if (auto mtmt = get(follow(metaTable->metatable))) + indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType); + else + { + LUAU_ASSERT(tt || get(follow(metaTable->table))); + + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } + } + void visit(AstExprIndexExpr* indexExpr, ValueContext context) { if (auto str = 
indexExpr->index->as()) @@ -1565,15 +1583,7 @@ struct TypeChecker2 } else if (auto mt = get(exprType)) { - const TableType* tt = get(follow(mt->table)); - LUAU_ASSERT(tt); - if (tt->indexer) - testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location); - else - { - // TODO: Maybe the metatable has a suitable indexer? - reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); - } + return indexExprMetatableHelper(indexExpr, mt, exprType, indexType); } else if (auto cls = get(exprType)) { diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index c65fde00..54d89a15 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -344,8 +344,7 @@ struct FamilyReducer if (tryGuessing(subject)) return; - TypeFamilyReductionResult result = - tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + TypeFamilyReductionResult result = tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); handleFamilyReduction(subject, result); } } @@ -369,8 +368,7 @@ struct FamilyReducer if (tryGuessing(subject)) return; - TypeFamilyReductionResult result = - tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + TypeFamilyReductionResult result = tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); handleFamilyReduction(subject, result); } } @@ -451,8 +449,8 @@ bool isPending(TypeId ty, ConstraintSolver* solver) } template -static std::optional> tryDistributeTypeFamilyApp(F f, TypeId instance, - const std::vector& typeParams, const std::vector& packParams, NotNull ctx, Args&& ...args) +static std::optional> tryDistributeTypeFamilyApp(F f, TypeId instance, const std::vector& typeParams, + const std::vector& packParams, NotNull ctx, Args&&... 
args) { // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) bool uninhabited = false; @@ -527,8 +525,8 @@ static std::optional> tryDistributeTypeFamilyA return std::nullopt; } -TypeFamilyReductionResult notFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult notFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -551,8 +549,8 @@ TypeFamilyReductionResult notFamilyFn(TypeId instance, const std::vector return {ctx->builtins->booleanType, false, {}, {}}; } -TypeFamilyReductionResult lenFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult lenFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -573,7 +571,7 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, const std::vector // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. 
if (ctx->solver) { - std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy); + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy, /* avoidSealingTables */ true); if (!maybeGeneralized) return {std::nullopt, false, {operandTy}, {}}; operandTy = *maybeGeneralized; @@ -643,8 +641,8 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, const std::vector return {ctx->builtins->numberType, false, {}, {}}; } -TypeFamilyReductionResult unmFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult unmFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -846,8 +844,8 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const st return {extracted.head.front(), false, {}, {}}; } -TypeFamilyReductionResult addFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult addFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -858,8 +856,8 @@ TypeFamilyReductionResult addFamilyFn(TypeId instance, const std::vector return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__add"); } -TypeFamilyReductionResult subFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult subFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -870,8 +868,8 @@ TypeFamilyReductionResult subFamilyFn(TypeId instance, const std::vector return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__sub"); } -TypeFamilyReductionResult mulFamilyFn(TypeId instance, const 
std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult mulFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -882,8 +880,8 @@ TypeFamilyReductionResult mulFamilyFn(TypeId instance, const std::vector return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__mul"); } -TypeFamilyReductionResult divFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult divFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -894,8 +892,8 @@ TypeFamilyReductionResult divFamilyFn(TypeId instance, const std::vector return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__div"); } -TypeFamilyReductionResult idivFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult idivFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -906,8 +904,8 @@ TypeFamilyReductionResult idivFamilyFn(TypeId instance, const std::vecto return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__idiv"); } -TypeFamilyReductionResult powFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult powFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -918,8 +916,8 @@ TypeFamilyReductionResult powFamilyFn(TypeId instance, const std::vector return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__pow"); } -TypeFamilyReductionResult modFamilyFn(TypeId instance, const std::vector& 
typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult modFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -930,8 +928,8 @@ TypeFamilyReductionResult modFamilyFn(TypeId instance, const std::vector return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__mod"); } -TypeFamilyReductionResult concatFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult concatFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1038,8 +1036,8 @@ TypeFamilyReductionResult concatFamilyFn(TypeId instance, const std::vec return {ctx->builtins->stringType, false, {}, {}}; } -TypeFamilyReductionResult andFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult andFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1089,8 +1087,8 @@ TypeFamilyReductionResult andFamilyFn(TypeId instance, const std::vector return {overallResult.result, false, std::move(blockedTypes), {}}; } -TypeFamilyReductionResult orFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult orFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1279,8 +1277,8 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, con return {ctx->builtins->booleanType, false, {}, {}}; } -TypeFamilyReductionResult ltFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) 
+TypeFamilyReductionResult ltFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1291,8 +1289,8 @@ TypeFamilyReductionResult ltFamilyFn(TypeId instance, const std::vector< return comparisonFamilyFn(instance, typeParams, packParams, ctx, "__lt"); } -TypeFamilyReductionResult leFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult leFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1303,8 +1301,8 @@ TypeFamilyReductionResult leFamilyFn(TypeId instance, const std::vector< return comparisonFamilyFn(instance, typeParams, packParams, ctx, "__le"); } -TypeFamilyReductionResult eqFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult eqFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1434,8 +1432,8 @@ struct FindRefinementBlockers : TypeOnceVisitor }; -TypeFamilyReductionResult refineFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult refineFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1519,8 +1517,8 @@ TypeFamilyReductionResult refineFamilyFn(TypeId instance, const std::vec return {resultTy, false, {}, {}}; } -TypeFamilyReductionResult singletonFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult singletonFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if 
(typeParams.size() != 1 || !packParams.empty()) { @@ -1556,8 +1554,8 @@ TypeFamilyReductionResult singletonFamilyFn(TypeId instance, const std:: return {ctx->builtins->unknownType, false, {}, {}}; } -TypeFamilyReductionResult unionFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult unionFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (!packParams.empty()) { @@ -1617,8 +1615,8 @@ TypeFamilyReductionResult unionFamilyFn(TypeId instance, const std::vect } -TypeFamilyReductionResult intersectFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult intersectFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (!packParams.empty()) { @@ -1841,8 +1839,8 @@ TypeFamilyReductionResult keyofFamilyImpl( return {ctx->arena->addType(UnionType{singletons}), false, {}, {}}; } -TypeFamilyReductionResult keyofFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult keyofFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -1853,8 +1851,8 @@ TypeFamilyReductionResult keyofFamilyFn(TypeId instance, const std::vect return keyofFamilyImpl(typeParams, packParams, ctx, /* isRaw */ false); } -TypeFamilyReductionResult rawkeyofFamilyFn(TypeId instance, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult rawkeyofFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 3050f09e..9ce1a58a 100644 --- a/Analysis/src/TypeInfer.cpp 
+++ b/Analysis/src/TypeInfer.cpp @@ -33,11 +33,11 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAGVARIABLE(LuauMetatableInstantiationCloneCheck, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) +LUAU_FASTFLAGVARIABLE(LuauReusableSubstitutions, false) namespace Luau { @@ -214,6 +214,7 @@ TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, , iceHandler(iceHandler) , unifierState(iceHandler) , normalizer(nullptr, builtinTypes, NotNull{&unifierState}) + , reusableInstantiation(TxnLog::empty(), nullptr, builtinTypes, {}, nullptr) , nilType(builtinTypes->nilType) , numberType(builtinTypes->numberType) , stringType(builtinTypes->stringType) @@ -4865,12 +4866,27 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat if (ftv && ftv->hasNoFreeOrGenericTypes) return ty; - Instantiation instantiation{log, &currentModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr}; + std::optional<TypeId> instantiated; - if (instantiationChildLimit) - instantiation.childLimit = *instantiationChildLimit; + if (FFlag::LuauReusableSubstitutions) + { + reusableInstantiation.resetState(log, &currentModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr); + + if (instantiationChildLimit) + reusableInstantiation.childLimit = *instantiationChildLimit; + + instantiated = reusableInstantiation.substitute(ty); + } + else + { + Instantiation instantiation{log, &currentModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr}; + + if (instantiationChildLimit) + instantiation.childLimit = *instantiationChildLimit; + + instantiated =
instantiation.substitute(ty); + } - std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; else @@ -5619,8 +5635,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId instantiated = *maybeInstantiated; TypeId target = follow(instantiated); - const TableType* tfTable = FFlag::LuauMetatableInstantiationCloneCheck ? getTableType(tf.type) : nullptr; - bool needsClone = follow(tf.type) == target || (FFlag::LuauMetatableInstantiationCloneCheck && tfTable != nullptr && tfTable == getTableType(target)); + const TableType* tfTable = getTableType(tf.type); + bool needsClone = follow(tf.type) == target || (tfTable != nullptr && tfTable == getTableType(target)); bool shouldMutate = getTableType(tf.type); TableType* ttv = getMutableTableType(target); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 484e45d0..a0c802dd 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -974,8 +974,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp if (!subNorm || !superNorm) reportError(location, NormalizationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. 
For example:", *failedOption); else tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); } diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index f46c3372..6dcd7197 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -479,7 +479,7 @@ bool Unifier2::unify(const FunctionType* subFn, const AnyType* superAny) bool Unifier2::unify(const AnyType* subAny, const TableType* superTable) { - for (const auto& [propName, prop]: superTable->props) + for (const auto& [propName, prop] : superTable->props) { if (prop.readTy) unify(builtinTypes->anyType, *prop.readTy); @@ -499,7 +499,7 @@ bool Unifier2::unify(const AnyType* subAny, const TableType* superTable) bool Unifier2::unify(const TableType* subTable, const AnyType* superAny) { - for (const auto& [propName, prop]: subTable->props) + for (const auto& [propName, prop] : subTable->props) { if (prop.readTy) unify(*prop.readTy, builtinTypes->anyType); @@ -658,31 +658,31 @@ struct FreeTypeSearcher : TypeVisitor { switch (polarity) { - case Positive: - { - if (seenPositive.contains(ty)) - return true; + case Positive: + { + if (seenPositive.contains(ty)) + return true; - seenPositive.insert(ty); - return false; - } - case Negative: - { - if (seenNegative.contains(ty)) - return true; + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; - seenNegative.insert(ty); - return false; - } - case Both: - { - if (seenPositive.contains(ty) && seenNegative.contains(ty)) - return true; + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; - seenPositive.insert(ty); - seenNegative.insert(ty); - return false; - } + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } } return false; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index d0e40ca3..c136c721 100644 
--- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -326,13 +326,12 @@ enum class IrCmd : uint8_t // This is used to recover after calling a variadic function ADJUST_STACK_TO_TOP, - // Execute fastcall builtin function in-place + // Execute fastcall builtin function with 1 argument in-place + // This is used for a few builtins that can have more than 1 result and cannot be represented as a regular instruction // A: unsigned int (builtin id) // B: Rn (result start) - // C: Rn (argument start) - // D: Rn or Kn or undef (optional second argument) - // E: int (argument count) - // F: int (result count) + // C: Rn (first argument) + // D: int (result count) FASTCALL, // Call the fastcall builtin function @@ -340,8 +339,9 @@ enum class IrCmd : uint8_t // B: Rn (result start) // C: Rn (argument start) // D: Rn or Kn or undef (optional second argument) - // E: int (argument count or -1 to use all arguments up to stack top) - // F: int (result count or -1 to preserve all results and adjust stack top) + // E: Rn or Kn or undef (optional third argument) + // F: int (argument count or -1 to use all arguments up to stack top) + // G: int (result count or -1 to preserve all results and adjust stack top) INVOKE_FASTCALL, // Check that fastcall builtin function invocation was successful (negative result count jumps to fallback) diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 8486921e..bc81fc68 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -64,6 +64,7 @@ inline bool isFastCall(LuauOpcode op) case LOP_FASTCALL1: case LOP_FASTCALL2: case LOP_FASTCALL2K: + case LOP_FASTCALL3: return true; default: diff --git a/CodeGen/include/Luau/IrVisitUseDef.h b/CodeGen/include/Luau/IrVisitUseDef.h index 32dd6c2a..6744bd65 100644 --- a/CodeGen/include/Luau/IrVisitUseDef.h +++ b/CodeGen/include/Luau/IrVisitUseDef.h @@ -4,7 +4,7 @@ #include "Luau/Common.h" #include "Luau/IrData.h" 
-LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) +LUAU_FASTFLAG(LuauCodegenFastcall3) namespace Luau { @@ -112,12 +112,48 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.useRange(vmRegOp(inst.a), function.intOp(inst.b)); break; - // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it case IrCmd::FASTCALL: - case IrCmd::INVOKE_FASTCALL: - if (int count = function.intOp(inst.e); count != -1) + if (FFlag::LuauCodegenFastcall3) { - if (count >= 3) + visitor.use(inst.c); + + if (int nresults = function.intOp(inst.d); nresults != -1) + visitor.defRange(vmRegOp(inst.b), nresults); + } + else + { + if (int count = function.intOp(inst.e); count != -1) + { + if (count >= 3) + { + CODEGEN_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1); + + visitor.useRange(vmRegOp(inst.c), count); + } + else + { + if (count >= 1) + visitor.use(inst.c); + + if (count >= 2) + visitor.maybeUse(inst.d); // Argument can also be a VmConst + } + } + else + { + visitor.useVarargs(vmRegOp(inst.c)); + } + + // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG + if (int count = function.intOp(inst.f); count != -1) + visitor.defRange(vmRegOp(inst.b), count); + } + break; + case IrCmd::INVOKE_FASTCALL: + if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? 
inst.f : inst.e); count != -1) + { + // Only LOP_FASTCALL3 lowering is allowed to have third optional argument + if (count >= 3 && (!FFlag::LuauCodegenFastcall3 || inst.e.kind == IrOpKind::Undef)) { CODEGEN_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1); @@ -130,6 +166,9 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i if (count >= 2) visitor.maybeUse(inst.d); // Argument can also be a VmConst + + if (FFlag::LuauCodegenFastcall3 && count >= 3) + visitor.maybeUse(inst.e); // Argument can also be a VmConst } } else @@ -138,7 +177,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i } // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG - if (int count = function.intOp(inst.f); count != -1) + if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f); count != -1) visitor.defRange(vmRegOp(inst.b), count); break; case IrCmd::FORGLOOP: @@ -188,15 +227,8 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.def(inst.b); break; case IrCmd::FALLBACK_FORGPREP: - if (FFlag::LuauCodegenRemoveDeadStores5) - { - // This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use - visitor.useRange(vmRegOp(inst.b), 3); - } - else - { - visitor.use(inst.b); - } + // This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use + visitor.useRange(vmRegOp(inst.b), 3); visitor.defRange(vmRegOp(inst.b), 3); break; @@ -214,12 +246,6 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.use(inst.a); break; - // After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG Rn, tag, block instructions are generated - case IrCmd::CHECK_TAG: - if (!FFlag::LuauCodegenRemoveDeadStores5) - visitor.maybeUse(inst.a); - break; - default: // All instructions which reference registers have to be handled 
explicitly CODEGEN_ASSERT(inst.a.kind != IrOpKind::VmReg); diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index fc8eb900..c429188d 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -11,30 +11,19 @@ #include -LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow) -LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used -LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false) LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false) LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenFastcall3, false) namespace Luau { namespace CodeGen { -static bool hasTypedParameters(Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo); - - return proto->typeinfo && proto->numparams != 0; -} - template static T read(uint8_t* data, size_t& offset) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - T result; memcpy(&result, data + offset, sizeof(T)); offset += sizeof(T); @@ -44,8 +33,6 @@ static T read(uint8_t* data, size_t& offset) static uint32_t readVarInt(uint8_t* data, size_t& offset) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - uint32_t result = 0; uint32_t shift = 0; @@ -63,8 +50,6 @@ static uint32_t readVarInt(uint8_t* data, size_t& offset) void loadBytecodeTypeInfo(IrFunction& function) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - Proto* proto = function.proto; if (!proto) @@ -173,8 +158,6 @@ static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo) static BytecodeRegTypeInfo* findRegType(BytecodeTypeInfo& info, uint8_t reg, int pc) { - CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); - auto b = info.regTypes.begin() + info.regTypeOffsets[reg]; auto e = info.regTypes.begin() + info.regTypeOffsets[reg + 1]; @@ -199,8 +182,6 @@ static BytecodeRegTypeInfo* findRegType(BytecodeTypeInfo& info, uint8_t reg, int static 
void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t ty) { - CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); - if (ty != LBC_TYPE_ANY) { if (BytecodeRegTypeInfo* regType = findRegType(info, reg, pc)) @@ -219,8 +200,6 @@ static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t t static void refineUpvalueType(BytecodeTypeInfo& info, int up, uint8_t ty) { - CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); - if (ty != LBC_TYPE_ANY) { if (size_t(up) < info.upvalueTypes.size()) @@ -662,28 +641,12 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) // At the block start, reset or knowledge to the starting state // In the future we might be able to propagate some info between the blocks as well - if (FFlag::LuauLoadTypeInfo) + for (size_t i = 0; i < bcTypeInfo.argumentTypes.size(); i++) { - for (size_t i = 0; i < bcTypeInfo.argumentTypes.size(); i++) - { - uint8_t et = bcTypeInfo.argumentTypes[i]; + uint8_t et = bcTypeInfo.argumentTypes[i]; - // TODO: if argument is optional, this might force a VM exit unnecessarily - regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT; - } - } - else - { - if (hasTypedParameters(proto)) - { - for (int i = 0; i < proto->numparams; ++i) - { - uint8_t et = proto->typeinfo[2 + i]; - - // TODO: if argument is optional, this might force a VM exit unnecessarily - regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT; - } - } + // TODO: if argument is optional, this might force a VM exit unnecessarily + regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT; } for (int i = proto->numparams; i < proto->maxstacksize; ++i) @@ -696,16 +659,13 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) const Instruction* pc = &proto->code[i]; LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc)); - if (FFlag::LuauCodegenTypeInfo) + // Assign known register types from local type information + // TODO: this is an expensive walk for each instruction + // TODO: it's best to lookup when register is actually used in the 
instruction + for (BytecodeRegTypeInfo& el : bcTypeInfo.regTypes) { - // Assign known register types from local type information - // TODO: this is an expensive walk for each instruction - // TODO: it's best to lookup when register is actually used in the instruction - for (BytecodeRegTypeInfo& el : bcTypeInfo.regTypes) - { - if (el.type != LBC_TYPE_ANY && i >= el.startpc && i < el.endpc) - regTags[el.reg] = el.type; - } + if (el.type != LBC_TYPE_ANY && i >= el.startpc && i < el.endpc) + regTags[el.reg] = el.type; } BytecodeTypes& bcType = function.bcTypes[i]; @@ -727,8 +687,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_BOOLEAN; bcType.result = regTags[ra]; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_LOADN: @@ -737,8 +696,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; bcType.result = regTags[ra]; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_LOADK: @@ -749,8 +707,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = bcType.a; bcType.result = regTags[ra]; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_LOADKX: @@ -761,8 +718,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = bcType.a; bcType.result = regTags[ra]; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_MOVE: @@ -773,8 +729,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = regTags[rb]; bcType.result = regTags[ra]; - if 
(FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_GETTABLE: @@ -1142,8 +1097,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra + 3] = bcType.c; regTags[ra] = bcType.result; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_FASTCALL1: @@ -1161,8 +1115,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[ra] = bcType.result; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_FASTCALL2: @@ -1180,8 +1133,29 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[int(pc[1])] = bcType.b; regTags[ra] = bcType.result; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); + break; + } + case LOP_FASTCALL3: + { + CODEGEN_ASSERT(FFlag::LuauCodegenFastcall3); + + int bfid = LUAU_INSN_A(*pc); + int skip = LUAU_INSN_C(*pc); + int aux = pc[1]; + + Instruction call = pc[skip + 1]; + CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + int ra = LUAU_INSN_A(call); + + applyBuiltinCall(bfid, bcType); + + regTags[LUAU_INSN_B(*pc)] = bcType.a; + regTags[aux & 0xff] = bcType.b; + regTags[(aux >> 8) & 0xff] = bcType.c; + regTags[ra] = bcType.result; + + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_FORNPREP: @@ -1192,12 +1166,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra + 1] = LBC_TYPE_NUMBER; regTags[ra + 2] = LBC_TYPE_NUMBER; - if (FFlag::LuauCodegenTypeInfo) - { - refineRegType(bcTypeInfo, ra, i, regTags[ra]); - refineRegType(bcTypeInfo, ra + 1, i, regTags[ra + 1]); - 
refineRegType(bcTypeInfo, ra + 2, i, regTags[ra + 2]); - } + refineRegType(bcTypeInfo, ra, i, regTags[ra]); + refineRegType(bcTypeInfo, ra + 1, i, regTags[ra + 1]); + refineRegType(bcTypeInfo, ra + 2, i, regTags[ra + 2]); break; } case LOP_FORNLOOP: @@ -1227,42 +1198,39 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) } case LOP_NAMECALL: { - if (FFlag::LuauCodegenDirectUserdataFlow) + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + uint32_t kc = pc[1]; + + bcType.a = regTags[rb]; + bcType.b = getBytecodeConstantTag(proto, kc); + + // While namecall might result in a callable table, we assume the function fast path + regTags[ra] = LBC_TYPE_FUNCTION; + + // Namecall places source register into target + 1 + regTags[ra + 1] = bcType.a; + + bcType.result = LBC_TYPE_FUNCTION; + + if (FFlag::LuauCodegenUserdataOps) { - int ra = LUAU_INSN_A(*pc); - int rb = LUAU_INSN_B(*pc); - uint32_t kc = pc[1]; + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); - bcType.a = regTags[rb]; - bcType.b = getBytecodeConstantTag(proto, kc); - - // While namecall might result in a callable table, we assume the function fast path - regTags[ra] = LBC_TYPE_FUNCTION; - - // Namecall places source register into target + 1 - regTags[ra + 1] = bcType.a; - - bcType.result = LBC_TYPE_FUNCTION; - - if (FFlag::LuauCodegenUserdataOps) + if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len)); + } + else + { + if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) { TString* str = 
gco2ts(function.proto->k[kc].value.gc); const char* field = getstr(str); - if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) - knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); - else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType) - knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len)); - } - else - { - if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) - { - TString* str = gco2ts(function.proto->k[kc].value.gc); - const char* field = getstr(str); - - knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); - } + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); } } break; @@ -1282,42 +1250,35 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = bcType.result; } - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); } break; } case LOP_GETUPVAL: { - if (FFlag::LuauCodegenTypeInfo) + int ra = LUAU_INSN_A(*pc); + int up = LUAU_INSN_B(*pc); + + bcType.a = LBC_TYPE_ANY; + + if (size_t(up) < bcTypeInfo.upvalueTypes.size()) { - int ra = LUAU_INSN_A(*pc); - int up = LUAU_INSN_B(*pc); + uint8_t et = bcTypeInfo.upvalueTypes[up]; - bcType.a = LBC_TYPE_ANY; - - if (size_t(up) < bcTypeInfo.upvalueTypes.size()) - { - uint8_t et = bcTypeInfo.upvalueTypes[up]; - - // TODO: if argument is optional, this might force a VM exit unnecessarily - bcType.a = et & ~LBC_TYPE_OPTIONAL_BIT; - } - - regTags[ra] = bcType.a; - bcType.result = regTags[ra]; + // TODO: if argument is optional, this might force a VM exit unnecessarily + bcType.a = et & ~LBC_TYPE_OPTIONAL_BIT; } + + regTags[ra] = bcType.a; + bcType.result = regTags[ra]; 
break; } case LOP_SETUPVAL: { - if (FFlag::LuauCodegenTypeInfo) - { - int ra = LUAU_INSN_A(*pc); - int up = LUAU_INSN_B(*pc); + int ra = LUAU_INSN_A(*pc); + int up = LUAU_INSN_B(*pc); - refineUpvalueType(bcTypeInfo, up, regTags[ra]); - } + refineUpvalueType(bcTypeInfo, up, regTags[ra]); break; } case LOP_GETGLOBAL: diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index 121535be..de8dcecf 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -12,7 +12,6 @@ #include "lapi.h" -LUAU_FASTFLAG(LuauCodegenTypeInfo) LUAU_FASTFLAG(LuauLoadUserdataInfo) LUAU_FASTFLAG(LuauNativeAttribute) @@ -87,7 +86,6 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) template static void logFunctionTypes_DEPRECATED(AssemblyBuilder& build, const IrFunction& function) { - CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo); const BytecodeTypeInfo& typeInfo = function.bcTypeInfo; @@ -131,7 +129,6 @@ static void logFunctionTypes_DEPRECATED(AssemblyBuilder& build, const IrFunction template static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function, const char* const* userdataTypes) { - CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo); const BytecodeTypeInfo& typeInfo = function.bcTypeInfo; @@ -240,7 +237,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A if (options.includeAssembly || options.includeIr) logFunctionHeader(build, p); - if (FFlag::LuauCodegenTypeInfo && options.includeIrTypes) + if (options.includeIrTypes) { if (FFlag::LuauLoadUserdataInfo) logFunctionTypes(build, ir.function, options.compilationOptions.userdataTypes); diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 4523d62b..e7701361 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -27,7 +27,6 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering) 
LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauLoadUserdataInfo) LUAU_FASTFLAG(LuauNativeAttribute) @@ -347,8 +346,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& } } - if (FFlag::LuauCodegenRemoveDeadStores5) - markDeadStoresInBlockChains(ir); + markDeadStoresInBlockChains(ir); } std::vector sortedBlocks = getSortedBlockOrder(ir.function); diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 96d22e13..09f69d69 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -12,8 +12,6 @@ #include "lstate.h" -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) - namespace Luau { namespace CodeGen @@ -29,17 +27,13 @@ static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_frexp)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra), LUA_TNUMBER); + build.mov(luauRegTag(ra), LUA_TNUMBER); if (nresults > 1) { build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra + 1), xmm0); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra + 1), LUA_TNUMBER); + build.mov(luauRegTag(ra + 1), LUA_TNUMBER); } } @@ -52,16 +46,12 @@ static void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra), LUA_TNUMBER); + build.mov(luauRegTag(ra), LUA_TNUMBER); if (nresults > 1) { build.vmovsd(luauRegValue(ra + 1), xmm0); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra + 1), LUA_TNUMBER); + build.mov(luauRegTag(ra + 1), LUA_TNUMBER); } } @@ -90,23 +80,21 @@ static void 
emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vblendvpd(tmp0.reg, tmp2.reg, build.f64x2(1, 1), tmp0.reg); build.vmovsd(luauRegValue(ra), tmp0.reg); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra), LUA_TNUMBER); + build.mov(luauRegTag(ra), LUA_TNUMBER); } -void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults) +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, int nresults) { switch (bfid) { case LBF_MATH_FREXP: - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); + CODEGEN_ASSERT(nresults == 1 || nresults == 2); return emitBuiltinMathFrexp(regs, build, ra, arg, nresults); case LBF_MATH_MODF: - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); + CODEGEN_ASSERT(nresults == 1 || nresults == 2); return emitBuiltinMathModf(regs, build, ra, arg, nresults); case LBF_MATH_SIGN: - CODEGEN_ASSERT(nparams == 1 && nresults == 1); + CODEGEN_ASSERT(nresults == 1); return emitBuiltinMathSign(regs, build, ra, arg); default: CODEGEN_ASSERT(!"Missing x64 lowering"); diff --git a/CodeGen/src/EmitBuiltinsX64.h b/CodeGen/src/EmitBuiltinsX64.h index cd8b5251..72a1ad15 100644 --- a/CodeGen/src/EmitBuiltinsX64.h +++ b/CodeGen/src/EmitBuiltinsX64.h @@ -16,7 +16,7 @@ class AssemblyBuilderX64; struct OperandX64; struct IrRegAllocX64; -void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults); +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, int nresults); } // namespace X64 } // namespace CodeGen diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index e62885eb..672c27ad 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -13,10 +13,10 @@ #include -LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format 
used by Codegen, same flag is used LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauLoadUserdataInfo) LUAU_FASTFLAG(LuauCodegenInstG) +LUAU_FASTFLAG(LuauCodegenFastcall3) namespace Luau { @@ -30,96 +30,9 @@ IrBuilder::IrBuilder(const HostIrHooks& hostHooks) , constantMap({IrConstKind::Tag, ~0ull}) { } -static bool hasTypedParameters_DEPRECATED(Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo); - - return proto->typeinfo && proto->numparams != 0; -} - -static void buildArgumentTypeChecks_DEPRECATED(IrBuilder& build, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo); - CODEGEN_ASSERT(hasTypedParameters_DEPRECATED(proto)); - - for (int i = 0; i < proto->numparams; ++i) - { - uint8_t et = proto->typeinfo[2 + i]; - - uint8_t tag = et & ~LBC_TYPE_OPTIONAL_BIT; - uint8_t optional = et & LBC_TYPE_OPTIONAL_BIT; - - if (tag == LBC_TYPE_ANY) - continue; - - IrOp load = build.inst(IrCmd::LOAD_TAG, build.vmReg(i)); - - IrOp nextCheck; - if (optional) - { - nextCheck = build.block(IrBlockKind::Internal); - IrOp fallbackCheck = build.block(IrBlockKind::Internal); - - build.inst(IrCmd::JUMP_EQ_TAG, load, build.constTag(LUA_TNIL), nextCheck, fallbackCheck); - - build.beginBlock(fallbackCheck); - } - - switch (tag) - { - case LBC_TYPE_NIL: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_BOOLEAN: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_NUMBER: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_STRING: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_TABLE: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_FUNCTION: - build.inst(IrCmd::CHECK_TAG, load, 
build.constTag(LUA_TFUNCTION), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_THREAD: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_USERDATA: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_VECTOR: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_BUFFER: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc)); - break; - } - - if (optional) - { - build.inst(IrCmd::JUMP, nextCheck); - build.beginBlock(nextCheck); - } - } - - // If the last argument is optional, we can skip creating a new internal block since one will already have been created. - if (!(proto->typeinfo[2 + proto->numparams - 1] & LBC_TYPE_OPTIONAL_BIT)) - { - IrOp next = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP, next); - - build.beginBlock(next); - } -} static bool hasTypedParameters(const BytecodeTypeInfo& typeInfo) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - for (auto el : typeInfo.argumentTypes) { if (el != LBC_TYPE_ANY) @@ -131,8 +44,6 @@ static bool hasTypedParameters(const BytecodeTypeInfo& typeInfo) static void buildArgumentTypeChecks(IrBuilder& build) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - const BytecodeTypeInfo& typeInfo = build.function.bcTypeInfo; CODEGEN_ASSERT(hasTypedParameters(typeInfo)); @@ -228,11 +139,10 @@ void IrBuilder::buildFunctionIr(Proto* proto) function.proto = proto; function.variadic = proto->is_vararg != 0; - if (FFlag::LuauLoadTypeInfo) - loadBytecodeTypeInfo(function); + loadBytecodeTypeInfo(function); // Reserve entry block - bool generateTypeChecks = FFlag::LuauLoadTypeInfo ? 
hasTypedParameters(function.bcTypeInfo) : hasTypedParameters_DEPRECATED(proto); + bool generateTypeChecks = hasTypedParameters(function.bcTypeInfo); IrOp entry = generateTypeChecks ? block(IrBlockKind::Internal) : IrOp{}; // Rebuild original control flow blocks @@ -247,10 +157,7 @@ void IrBuilder::buildFunctionIr(Proto* proto) { beginBlock(entry); - if (FFlag::LuauLoadTypeInfo) - buildArgumentTypeChecks(*this); - else - buildArgumentTypeChecks_DEPRECATED(*this, proto); + buildArgumentTypeChecks(*this); inst(IrCmd::JUMP, blockAtInst(0)); } @@ -544,16 +451,21 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstCloseUpvals(*this, pc); break; case LOP_FASTCALL: - handleFastcallFallback(translateFastCallN(*this, pc, i, false, 0, {}), pc, i); + handleFastcallFallback(translateFastCallN(*this, pc, i, false, 0, {}, {}), pc, i); break; case LOP_FASTCALL1: - handleFastcallFallback(translateFastCallN(*this, pc, i, true, 1, undef()), pc, i); + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 1, undef(), undef()), pc, i); break; case LOP_FASTCALL2: - handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1])), pc, i); + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), undef()), pc, i); break; case LOP_FASTCALL2K: - handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1])), pc, i); + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), undef()), pc, i); + break; + case LOP_FASTCALL3: + CODEGEN_ASSERT(FFlag::LuauCodegenFastcall3); + + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 3, vmReg(pc[1] & 0xff), vmReg((pc[1] >> 8) & 0xff)), pc, i); break; case LOP_FORNPREP: translateInstForNPrep(*this, pc, i); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index ea83bb99..5b333374 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -11,11 +11,11 @@ #include 
"lstate.h" #include "lgc.h" -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenSplitDoarith) LUAU_FASTFLAG(LuauCodegenUserdataOps) LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false) LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false) +LUAU_FASTFLAG(LuauCodegenFastcall3) namespace Luau { @@ -197,78 +197,50 @@ static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg) build.blr(x1); } -static bool emitBuiltin( - AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, IrOp args, int nparams, int nresults) +static bool emitBuiltin(AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, int nresults) { switch (bfid) { case LBF_MATH_FREXP: { - if (FFlag::LuauCodegenRemoveDeadStores5) - { - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); - build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); + CODEGEN_ASSERT(nresults == 1 || nresults == 2); + emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); + build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - RegisterA64 temp = regs.allocTemp(KindA64::w); - build.mov(temp, LUA_TNUMBER); - build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, LUA_TNUMBER); + build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); - if (nresults == 2) - { - build.ldr(w0, sTemporary); - build.scvtf(d1, w0); - build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); - build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt))); - } - } - else + if (nresults == 2) { - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); - build.str(d0, mem(rBase, 
res * sizeof(TValue) + offsetof(TValue, value.n))); - if (nresults == 2) - { - build.ldr(w0, sTemporary); - build.scvtf(d1, w0); - build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); - } + build.ldr(w0, sTemporary); + build.scvtf(d1, w0); + build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); + build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt))); } return true; } case LBF_MATH_MODF: { - if (FFlag::LuauCodegenRemoveDeadStores5) - { - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); - build.ldr(d1, sTemporary); - build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); + CODEGEN_ASSERT(nresults == 1 || nresults == 2); + emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); + build.ldr(d1, sTemporary); + build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - RegisterA64 temp = regs.allocTemp(KindA64::w); - build.mov(temp, LUA_TNUMBER); - build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, LUA_TNUMBER); + build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); - if (nresults == 2) - { - build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); - build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt))); - } - } - else + if (nresults == 2) { - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); - build.ldr(d1, sTemporary); - build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - if (nresults == 2) - build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); + build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); + build.str(temp, 
mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt))); } return true; } case LBF_MATH_SIGN: { - CODEGEN_ASSERT(nparams == 1 && nresults == 1); + CODEGEN_ASSERT(nresults == 1); build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); build.fcmpz(d0); build.fmov(d0, 0.0); @@ -278,12 +250,10 @@ static bool emitBuiltin( build.fcsel(d0, d1, d0, getConditionFP(IrCondition::Less)); build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - if (FFlag::LuauCodegenRemoveDeadStores5) - { - RegisterA64 temp = regs.allocTemp(KindA64::w); - build.mov(temp, LUA_TNUMBER); - build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); - } + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, LUA_TNUMBER); + build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); + return true; } @@ -1205,34 +1175,88 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) } case IrCmd::FASTCALL: regs.spill(build, index); - error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); + + if (FFlag::LuauCodegenFastcall3) + error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d)); + else + error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.f)); + break; case IrCmd::INVOKE_FASTCALL: { - regs.spill(build, index); - build.mov(x0, rState); - build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); - build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); - build.mov(w3, intOp(inst.f)); // nresults - - if (inst.d.kind == IrOpKind::VmReg) - build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); - else if (inst.d.kind == IrOpKind::VmConst) - emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue)); - else - CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef); - - // 
nparams - if (intOp(inst.e) == LUA_MULTRET) + if (FFlag::LuauCodegenFastcall3) { - // L->top - (ra + 1) - build.ldr(x5, mem(rState, offsetof(lua_State, top))); - build.sub(x5, x5, rBase); - build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue))); - build.lsr(x5, x5, kTValueSizeLog2); + // We might need a temporary and we have to preserve it over the spill + RegisterA64 temp = regs.allocTemp(KindA64::q); + regs.spill(build, index, {temp}); + + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.mov(w3, intOp(inst.g)); // nresults + + // 'E' argument can only be produced by LOP_FASTCALL3 lowering + if (inst.e.kind != IrOpKind::Undef) + { + CODEGEN_ASSERT(intOp(inst.f) == 3); + + build.ldr(x4, mem(rState, offsetof(lua_State, top))); + + build.ldr(temp, mem(rBase, vmRegOp(inst.d) * sizeof(TValue))); + build.str(temp, mem(x4, 0)); + + build.ldr(temp, mem(rBase, vmRegOp(inst.e) * sizeof(TValue))); + build.str(temp, mem(x4, sizeof(TValue))); + } + else + { + if (inst.d.kind == IrOpKind::VmReg) + build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); + else if (inst.d.kind == IrOpKind::VmConst) + emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue)); + else + CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef); + } + + // nparams + if (intOp(inst.f) == LUA_MULTRET) + { + // L->top - (ra + 1) + build.ldr(x5, mem(rState, offsetof(lua_State, top))); + build.sub(x5, x5, rBase); + build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue))); + build.lsr(x5, x5, kTValueSizeLog2); + } + else + build.mov(w5, intOp(inst.f)); } else - build.mov(w5, intOp(inst.e)); + { + regs.spill(build, index); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.mov(w3, intOp(inst.f)); // nresults + + if (inst.d.kind == 
IrOpKind::VmReg) + build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); + else if (inst.d.kind == IrOpKind::VmConst) + emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue)); + else + CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef); + + // nparams + if (intOp(inst.e) == LUA_MULTRET) + { + // L->top - (ra + 1) + build.ldr(x5, mem(rState, offsetof(lua_State, top))); + build.sub(x5, x5, rBase); + build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue))); + build.lsr(x5, x5, kTValueSizeLog2); + } + else + build.mov(w5, intOp(inst.e)); + } build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction))); build.blr(x6); @@ -1443,35 +1467,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) Label fresh; // used when guard aborts execution or jumps to a VM exit Label& fail = getTargetLabel(inst.c, fresh); - if (FFlag::LuauCodegenRemoveDeadStores5) + if (tagOp(inst.b) == 0) { - if (tagOp(inst.b) == 0) - { - build.cbnz(regOp(inst.a), fail); - } - else - { - build.cmp(regOp(inst.a), tagOp(inst.b)); - build.b(ConditionA64::NotEqual, fail); - } + build.cbnz(regOp(inst.a), fail); } else { - // To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled - RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? 
regs.allocTemp(KindA64::w) : regOp(inst.a); - - if (inst.a.kind == IrOpKind::VmReg) - build.ldr(tag, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); - - if (tagOp(inst.b) == 0) - { - build.cbnz(tag, fail); - } - else - { - build.cmp(tag, tagOp(inst.b)); - build.b(ConditionA64::NotEqual, fail); - } + build.cmp(regOp(inst.a), tagOp(inst.b)); + build.b(ConditionA64::NotEqual, fail); } finalizeTargetLabel(inst.c, fresh); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 00768c70..5128dce5 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -17,6 +17,7 @@ LUAU_FASTFLAG(LuauCodegenUserdataOps) LUAU_FASTFLAG(LuauCodegenUserdataAlloc) +LUAU_FASTFLAG(LuauCodegenFastcall3) namespace Luau { @@ -1008,9 +1009,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::FASTCALL: { - OperandX64 arg2 = inst.d.kind != IrOpKind::Undef ? memRegDoubleOp(inst.d) : OperandX64{0}; - - emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), arg2, intOp(inst.e), intOp(inst.f)); + if (FFlag::LuauCodegenFastcall3) + emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d)); + else + emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.f)); break; } case IrCmd::INVOKE_FASTCALL: @@ -1018,25 +1020,49 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) unsigned bfid = uintOp(inst.a); OperandX64 args = 0; + ScopedRegX64 argsAlt{regs}; - if (inst.d.kind == IrOpKind::VmReg) - args = luauRegAddress(vmRegOp(inst.d)); - else if (inst.d.kind == IrOpKind::VmConst) - args = luauConstantAddress(vmConstOp(inst.d)); + // 'E' argument can only be produced by LOP_FASTCALL3 + if (FFlag::LuauCodegenFastcall3 && inst.e.kind != IrOpKind::Undef) + { + CODEGEN_ASSERT(intOp(inst.f) == 3); + + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + argsAlt.alloc(SizeX64::qword); + + 
build.mov(argsAlt.reg, qword[rState + offsetof(lua_State, top)]); + + build.vmovups(tmp.reg, luauReg(vmRegOp(inst.d))); + build.vmovups(xmmword[argsAlt.reg], tmp.reg); + + build.vmovups(tmp.reg, luauReg(vmRegOp(inst.e))); + build.vmovups(xmmword[argsAlt.reg + sizeof(TValue)], tmp.reg); + } else - CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef); + { + if (inst.d.kind == IrOpKind::VmReg) + args = luauRegAddress(vmRegOp(inst.d)); + else if (inst.d.kind == IrOpKind::VmConst) + args = luauConstantAddress(vmConstOp(inst.d)); + else + CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef); + } int ra = vmRegOp(inst.b); int arg = vmRegOp(inst.c); - int nparams = intOp(inst.e); - int nresults = intOp(inst.f); + int nparams = intOp(FFlag::LuauCodegenFastcall3 ? inst.f : inst.e); + int nresults = intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f); IrCallWrapperX64 callWrap(regs, build, index); callWrap.addArgument(SizeX64::qword, rState); callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); callWrap.addArgument(SizeX64::qword, luauRegAddress(arg)); callWrap.addArgument(SizeX64::dword, nresults); - callWrap.addArgument(SizeX64::qword, args); + + if (FFlag::LuauCodegenFastcall3 && inst.e.kind != IrOpKind::Undef) + callWrap.addArgument(SizeX64::qword, argsAlt); + else + callWrap.addArgument(SizeX64::qword, args); if (nparams == LUA_MULTRET) { diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index bec5deea..668bdfe0 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -8,7 +8,7 @@ #include -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) +LUAU_FASTFLAG(LuauCodegenFastcall3) // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results @@ -46,19 +46,17 @@ static BuiltinImplResult translateBuiltinNumberToNumber( return {BuiltinImplType::None, -1}; builtinCheckDouble(build, build.vmReg(arg), pcpos); - build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), 
build.vmReg(arg), args, build.constInt(1), build.constInt(1)); - if (!FFlag::LuauCodegenRemoveDeadStores5) - { - if (ra != arg) - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - } + if (FFlag::LuauCodegenFastcall3) + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), build.constInt(1)); + else + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinNumberToNumberLibm( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -109,17 +107,12 @@ static BuiltinImplResult translateBuiltinNumberTo2Number( return {BuiltinImplType::None, -1}; builtinCheckDouble(build, build.vmReg(arg), pcpos); - build.inst( - IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(nresults == 1 ? 1 : 2)); - if (!FFlag::LuauCodegenRemoveDeadStores5) - { - if (ra != arg) - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - - if (nresults != 1) - build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); - } + if (FFlag::LuauCodegenFastcall3) + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), build.constInt(nresults == 1 ? 1 : 2)); + else + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), build.undef(), build.constInt(1), + build.constInt(nresults == 1 ? 
1 : 2)); return {BuiltinImplType::Full, 2}; } @@ -198,7 +191,8 @@ static BuiltinImplResult translateBuiltinMathLog(IrBuilder& build, int nparams, return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinMathMinMax(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) +static BuiltinImplResult translateBuiltinMathMinMax( + IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos) { if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; @@ -206,7 +200,10 @@ static BuiltinImplResult translateBuiltinMathMinMax(IrBuilder& build, IrCmd cmd, builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); - for (int i = 3; i <= nparams; ++i) + if (FFlag::LuauCodegenFastcall3 && nparams >= 3) + builtinCheckDouble(build, arg3, pcpos); + + for (int i = (FFlag::LuauCodegenFastcall3 ? 4 : 3); i <= nparams; ++i) builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), pcpos); IrOp varg1 = builtinLoadDouble(build, build.vmReg(arg)); @@ -214,7 +211,13 @@ static BuiltinImplResult translateBuiltinMathMinMax(IrBuilder& build, IrCmd cmd, IrOp res = build.inst(cmd, varg2, varg1); // Swapped arguments are required for consistency with VM builtins - for (int i = 3; i <= nparams; ++i) + if (FFlag::LuauCodegenFastcall3 && nparams >= 3) + { + IrOp arg = builtinLoadDouble(build, arg3); + res = build.inst(cmd, arg, res); + } + + for (int i = (FFlag::LuauCodegenFastcall3 ? 
4 : 3); i <= nparams; ++i) { IrOp arg = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2))); res = build.inst(cmd, arg, res); @@ -228,7 +231,8 @@ static BuiltinImplResult translateBuiltinMathMinMax(IrBuilder& build, IrCmd cmd, return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos) +static BuiltinImplResult translateBuiltinMathClamp( + IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, IrOp fallback, int pcpos) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -239,10 +243,10 @@ static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), pcpos); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(vmRegOp(args) + 1), pcpos); IrOp min = builtinLoadDouble(build, args); - IrOp max = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); + IrOp max = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? 
arg3 : build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::JUMP_CMP_NUM, min, max, build.cond(IrCondition::NotLessEqual), fallback, block); build.beginBlock(block); @@ -305,7 +309,7 @@ static BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, i } static BuiltinImplResult translateBuiltinBit32BinaryOp( - IrBuilder& build, IrCmd cmd, bool btest, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) + IrBuilder& build, IrCmd cmd, bool btest, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos) { if (nparams < 2 || nparams > kBit32BinaryOpUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; @@ -313,7 +317,10 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp( builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); - for (int i = 3; i <= nparams; ++i) + if (FFlag::LuauCodegenFastcall3 && nparams >= 3) + builtinCheckDouble(build, arg3, pcpos); + + for (int i = (FFlag::LuauCodegenFastcall3 ? 4 : 3); i <= nparams; ++i) builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); @@ -324,7 +331,15 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp( IrOp res = build.inst(cmd, vaui, vbui); - for (int i = 3; i <= nparams; ++i) + if (FFlag::LuauCodegenFastcall3 && nparams >= 3) + { + IrOp vc = builtinLoadDouble(build, arg3); + IrOp arg = build.inst(IrCmd::NUM_TO_UINT, vc); + + res = build.inst(cmd, res, arg); + } + + for (int i = (FFlag::LuauCodegenFastcall3 ? 
4 : 3); i <= nparams; ++i) { IrOp vc = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2))); IrOp arg = build.inst(IrCmd::NUM_TO_UINT, vc); @@ -449,7 +464,7 @@ static BuiltinImplResult translateBuiltinBit32Rotate(IrBuilder& build, IrCmd cmd } static BuiltinImplResult translateBuiltinBit32Extract( - IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos) + IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, IrOp fallback, int pcpos) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -497,8 +512,8 @@ static BuiltinImplResult translateBuiltinBit32Extract( { IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); - builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos); - IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(args.index + 1), pcpos); + IrOp vc = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(args.index + 1)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vc); IrOp block1 = build.block(IrBlockKind::Internal); @@ -587,18 +602,18 @@ static BuiltinImplResult translateBuiltinBit32Unary(IrBuilder& build, IrCmd cmd, } static BuiltinImplResult translateBuiltinBit32Replace( - IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos) + IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, IrOp fallback, int pcpos) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); - builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? 
arg3 : build.vmReg(args.index + 1), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); - IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); + IrOp vc = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(args.index + 1)); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); IrOp v = build.inst(IrCmd::NUM_TO_UINT, vb); @@ -623,8 +638,8 @@ static BuiltinImplResult translateBuiltinBit32Replace( } else { - builtinCheckDouble(build, build.vmReg(args.index + 2), pcpos); - IrOp vd = builtinLoadDouble(build, build.vmReg(args.index + 2)); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? build.vmReg(vmRegOp(args) + 2) : build.vmReg(args.index + 2), pcpos); + IrOp vd = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? build.vmReg(vmRegOp(args) + 2) : build.vmReg(args.index + 2)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vd); IrOp block1 = build.block(IrBlockKind::Internal); @@ -661,7 +676,7 @@ static BuiltinImplResult translateBuiltinBit32Replace( return {BuiltinImplType::UsesFallback, 1}; } -static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) +static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -670,11 +685,11 @@ static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, i builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), pcpos); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? 
arg3 : build.vmReg(vmRegOp(args) + 1), pcpos); IrOp x = builtinLoadDouble(build, build.vmReg(arg)); IrOp y = builtinLoadDouble(build, args); - IrOp z = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); + IrOp z = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), x, y, z); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); @@ -736,13 +751,14 @@ static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams return {BuiltinImplType::Full, 1}; } -static void translateBufferArgsAndCheckBounds(IrBuilder& build, int nparams, int arg, IrOp args, int size, int pcpos, IrOp& buf, IrOp& intIndex) +static void translateBufferArgsAndCheckBounds( + IrBuilder& build, int nparams, int arg, IrOp args, IrOp arg3, int size, int pcpos, IrOp& buf, IrOp& intIndex) { build.loadAndCheckTag(build.vmReg(arg), LUA_TBUFFER, build.vmExit(pcpos)); builtinCheckDouble(build, args, pcpos); if (nparams == 3) - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), pcpos); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? 
arg3 : build.vmReg(vmRegOp(args) + 1), pcpos); buf = build.inst(IrCmd::LOAD_POINTER, build.vmReg(arg)); @@ -753,13 +769,13 @@ static void translateBufferArgsAndCheckBounds(IrBuilder& build, int nparams, int } static BuiltinImplResult translateBuiltinBufferRead( - IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos, IrCmd readCmd, int size, IrCmd convCmd) + IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos, IrCmd readCmd, int size, IrCmd convCmd) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; IrOp buf, intIndex; - translateBufferArgsAndCheckBounds(build, nparams, arg, args, size, pcpos, buf, intIndex); + translateBufferArgsAndCheckBounds(build, nparams, arg, args, arg3, size, pcpos, buf, intIndex); IrOp result = build.inst(readCmd, buf, intIndex); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), convCmd == IrCmd::NOP ? result : build.inst(convCmd, result)); @@ -769,21 +785,22 @@ static BuiltinImplResult translateBuiltinBufferRead( } static BuiltinImplResult translateBuiltinBufferWrite( - IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos, IrCmd writeCmd, int size, IrCmd convCmd) + IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos, IrCmd writeCmd, int size, IrCmd convCmd) { if (nparams < 3 || nresults > 0) return {BuiltinImplType::None, -1}; IrOp buf, intIndex; - translateBufferArgsAndCheckBounds(build, nparams, arg, args, size, pcpos, buf, intIndex); + translateBufferArgsAndCheckBounds(build, nparams, arg, args, arg3, size, pcpos, buf, intIndex); - IrOp numValue = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); + IrOp numValue = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(vmRegOp(args) + 1)); build.inst(writeCmd, buf, intIndex, convCmd == IrCmd::NOP ? 
numValue : build.inst(convCmd, numValue)); return {BuiltinImplType::Full, 0}; } -BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback, int pcpos) +BuiltinImplResult translateBuiltin( + IrBuilder& build, int bfid, int ra, int arg, IrOp args, IrOp arg3, int nparams, int nresults, IrOp fallback, int pcpos) { // Builtins are not allowed to handle variadic arguments if (nparams == LUA_MULTRET) @@ -800,11 +817,11 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_LOG: return translateBuiltinMathLog(build, nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_MIN: - return translateBuiltinMathMinMax(build, IrCmd::MIN_NUM, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinMathMinMax(build, IrCmd::MIN_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_MATH_MAX: - return translateBuiltinMathMinMax(build, IrCmd::MAX_NUM, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinMathMinMax(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_MATH_CLAMP: - return translateBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback, pcpos); + return translateBuiltinMathClamp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos); case LBF_MATH_FLOOR: return translateBuiltinMathUnary(build, IrCmd::FLOOR_NUM, nparams, ra, arg, nresults, pcpos); case LBF_MATH_CEIL: @@ -826,7 +843,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_TAN: case LBF_MATH_TANH: case LBF_MATH_LOG10: - return translateBuiltinNumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinNumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, nresults, pcpos); case LBF_MATH_SIGN: return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, 
pcpos); case LBF_MATH_POW: @@ -838,13 +855,13 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_MODF: return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_BAND: - return translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ false, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ false, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_BIT32_BOR: - return translateBuiltinBit32BinaryOp(build, IrCmd::BITOR_UINT, /* btest= */ false, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinBit32BinaryOp(build, IrCmd::BITOR_UINT, /* btest= */ false, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_BIT32_BXOR: - return translateBuiltinBit32BinaryOp(build, IrCmd::BITXOR_UINT, /* btest= */ false, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinBit32BinaryOp(build, IrCmd::BITXOR_UINT, /* btest= */ false, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_BIT32_BTEST: - return translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ true, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ true, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_BIT32_BNOT: return translateBuiltinBit32Bnot(build, nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_LSHIFT: @@ -858,7 +875,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_BIT32_RROTATE: return translateBuiltinBit32Rotate(build, IrCmd::BITRROTATE_UINT, nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_EXTRACT: - return translateBuiltinBit32Extract(build, nparams, ra, arg, args, nresults, fallback, pcpos); + return translateBuiltinBit32Extract(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos); case LBF_BIT32_EXTRACTK: 
return translateBuiltinBit32ExtractK(build, nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_COUNTLZ: @@ -866,13 +883,13 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_BIT32_COUNTRZ: return translateBuiltinBit32Unary(build, IrCmd::BITCOUNTRZ_UINT, nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_REPLACE: - return translateBuiltinBit32Replace(build, nparams, ra, arg, args, nresults, fallback, pcpos); + return translateBuiltinBit32Replace(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos); case LBF_TYPE: return translateBuiltinType(build, nparams, ra, arg, args, nresults); case LBF_TYPEOF: return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults); case LBF_VECTOR: - return translateBuiltinVector(build, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinVector(build, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_TABLE_INSERT: return translateBuiltinTableInsert(build, nparams, ra, arg, args, nresults, pcpos); case LBF_STRING_LEN: @@ -880,31 +897,31 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_BIT32_BYTESWAP: return translateBuiltinBit32Unary(build, IrCmd::BYTESWAP_UINT, nparams, ra, arg, args, nresults, pcpos); case LBF_BUFFER_READI8: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI8, 1, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READI8, 1, IrCmd::INT_TO_NUM); case LBF_BUFFER_READU8: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READU8, 1, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READU8, 1, IrCmd::INT_TO_NUM); case LBF_BUFFER_WRITEU8: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEI8, 1, 
IrCmd::NUM_TO_UINT); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEI8, 1, IrCmd::NUM_TO_UINT); case LBF_BUFFER_READI16: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI16, 2, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READI16, 2, IrCmd::INT_TO_NUM); case LBF_BUFFER_READU16: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READU16, 2, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READU16, 2, IrCmd::INT_TO_NUM); case LBF_BUFFER_WRITEU16: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEI16, 2, IrCmd::NUM_TO_UINT); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEI16, 2, IrCmd::NUM_TO_UINT); case LBF_BUFFER_READI32: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::INT_TO_NUM); case LBF_BUFFER_READU32: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::UINT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::UINT_TO_NUM); case LBF_BUFFER_WRITEU32: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEI32, 4, IrCmd::NUM_TO_UINT); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEI32, 4, IrCmd::NUM_TO_UINT); case LBF_BUFFER_READF32: - return translateBuiltinBufferRead(build, nparams, 
ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READF32, 4, IrCmd::NOP); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READF32, 4, IrCmd::NOP); case LBF_BUFFER_WRITEF32: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEF32, 4, IrCmd::NOP); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEF32, 4, IrCmd::NOP); case LBF_BUFFER_READF64: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READF64, 8, IrCmd::NOP); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READF64, 8, IrCmd::NOP); case LBF_BUFFER_WRITEF64: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEF64, 8, IrCmd::NOP); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEF64, 8, IrCmd::NOP); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslateBuiltins.h b/CodeGen/src/IrTranslateBuiltins.h index 8ae64b94..54a05aba 100644 --- a/CodeGen/src/IrTranslateBuiltins.h +++ b/CodeGen/src/IrTranslateBuiltins.h @@ -22,7 +22,8 @@ struct BuiltinImplResult int actualResultCount; }; -BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback, int pcpos); +BuiltinImplResult translateBuiltin( + IrBuilder& build, int bfid, int ra, int arg, IrOp args, IrOp arg3, int nparams, int nresults, IrOp fallback, int pcpos); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 5798f3e9..e06f14f8 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -13,9 +13,9 @@ #include "lstate.h" #include "ltm.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false) 
LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenFastcall3) namespace Luau { @@ -743,7 +743,7 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra)); } -IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs) +IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp customArg3) { LuauOpcode opcode = LuauOpcode(LUAU_INSN_OP(*pc)); int bfid = LUAU_INSN_A(*pc); @@ -769,13 +769,15 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool builtinArgs = build.constDouble(protok.value.n); } + IrOp builtinArg3 = FFlag::LuauCodegenFastcall3 ? (customParams ? customArg3 : build.vmReg(ra + 3)) : IrOp{}; + IrOp fallback = build.block(IrBlockKind::Fallback); // In unsafe environment, instead of retrying fastcall at 'pcpos' we side-exit directly to fallback sequence build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos + getOpLength(opcode))); - BuiltinImplResult br = - translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, nparams, nresults, fallback, pcpos + getOpLength(opcode)); + BuiltinImplResult br = translateBuiltin( + build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, builtinArg3, nparams, nresults, fallback, pcpos + getOpLength(opcode)); if (br.type != BuiltinImplType::None) { @@ -792,6 +794,22 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool return build.undef(); } } + else if (FFlag::LuauCodegenFastcall3) + { + IrOp arg3 = customParams ? 
customArg3 : build.undef(); + + // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + getOpLength(opcode))); + + IrOp res = build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, arg3, build.constInt(nparams), + build.constInt(nresults)); + build.inst(IrCmd::CHECK_FASTCALL_RES, res, fallback); + + if (nresults == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(ra), res); + else if (nparams == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_TOP); + } else { // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline @@ -1277,7 +1295,7 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) return; } - if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) + if (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); @@ -1324,7 +1342,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); - if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) + if (FFlag::LuauCodegenUserdataOps ? 
isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); @@ -1446,7 +1464,7 @@ bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos); - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_VECTOR) + if (bcTypes.a == LBC_TYPE_VECTOR) { build.loadAndCheckTag(build.vmReg(rb), LUA_TVECTOR, build.vmExit(pcpos)); @@ -1470,7 +1488,7 @@ bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) return false; } - if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA)) + if (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA) { build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos)); @@ -1499,8 +1517,7 @@ bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) IrOp firstFastPathSuccess = build.block(IrBlockKind::Internal); IrOp secondFastPath = build.block(IrBlockKind::Internal); - build.loadAndCheckTag( - build.vmReg(rb), LUA_TTABLE, FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_TABLE ? build.vmExit(pcpos) : fallback); + build.loadAndCheckTag(build.vmReg(rb), LUA_TTABLE, bcTypes.a == LBC_TYPE_TABLE ? 
build.vmExit(pcpos) : fallback); IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); CODEGEN_ASSERT(build.function.proto); diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 5eb01450..8b514cc1 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -44,7 +44,8 @@ void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc); -IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs); +IrOp translateFastCallN( + IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp customArg3); void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos); diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index c6b2d044..0224b49b 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -3,6 +3,8 @@ #include "Luau/IrUtils.h" +LUAU_FASTFLAG(LuauCodegenFastcall3) + namespace Luau { namespace CodeGen @@ -44,11 +46,11 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) invalidateRestoreVmRegs(vmRegOp(inst.a), -1); break; case IrCmd::FASTCALL: - invalidateRestoreVmRegs(vmRegOp(inst.b), function.intOp(inst.f)); + invalidateRestoreVmRegs(vmRegOp(inst.b), function.intOp(FFlag::LuauCodegenFastcall3 ? 
inst.d : inst.f)); break; case IrCmd::INVOKE_FASTCALL: // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG - if (int count = function.intOp(inst.f); count != -1) + if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f); count != -1) invalidateRestoreVmRegs(vmRegOp(inst.b), count); break; case IrCmd::DO_ARITH: @@ -119,7 +121,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) break; // These instructions read VmReg only after optimizeMemoryOperandsX64 - case IrCmd::CHECK_TAG: // TODO: remove with FFlagLuauCodegenRemoveDeadStores5 + case IrCmd::CHECK_TAG: case IrCmd::CHECK_TRUTHY: case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 4ff49570..0cd2aa51 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -18,10 +18,10 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAGVARIABLE(LuauCodegenFixSplitStoreConstMismatch, false) LUAU_FASTFLAG(LuauCodegenUserdataOps) LUAU_FASTFLAG(LuauCodegenUserdataAlloc) +LUAU_FASTFLAG(LuauCodegenFastcall3) namespace Luau { @@ -621,16 +621,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& std::tie(activeLoadCmd, activeLoadValue) = state.getPreviousVersionedLoadForTag(value, source); if (state.tryGetTag(source) == value) - { - if (FFlag::DebugLuauAbortingChecks && !FFlag::LuauCodegenRemoveDeadStores5) - replace(function, block, index, {IrCmd::CHECK_TAG, inst.a, inst.b, build.undef()}); - else - kill(function, inst); - } + kill(function, inst); else - { state.saveTag(source, value); - } } else { @@ -1150,39 +1143,33 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case 
IrCmd::FASTCALL: { - if (FFlag::LuauCodegenRemoveDeadStores5) + LuauBuiltinFunction bfid = LuauBuiltinFunction(function.uintOp(inst.a)); + int firstReturnReg = vmRegOp(inst.b); + int nresults = function.intOp(FFlag::LuauCodegenFastcall3 ? inst.d : inst.f); + + // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it + handleBuiltinEffects(state, bfid, firstReturnReg, nresults); + + switch (bfid) { - LuauBuiltinFunction bfid = LuauBuiltinFunction(function.uintOp(inst.a)); - int firstReturnReg = vmRegOp(inst.b); - int nresults = function.intOp(inst.f); + case LBF_MATH_MODF: + case LBF_MATH_FREXP: + state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg)}, LUA_TNUMBER); - // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it - handleBuiltinEffects(state, bfid, firstReturnReg, nresults); - - switch (bfid) - { - case LBF_MATH_MODF: - case LBF_MATH_FREXP: - state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg)}, LUA_TNUMBER); - - if (nresults > 1) - state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg + 1)}, LUA_TNUMBER); - break; - case LBF_MATH_SIGN: - state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg)}, LUA_TNUMBER); - break; - default: - break; - } - } - else - { - handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), function.intOp(inst.f)); + if (nresults > 1) + state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg + 1)}, LUA_TNUMBER); + break; + case LBF_MATH_SIGN: + state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg)}, LUA_TNUMBER); + break; + default: + break; } break; } case IrCmd::INVOKE_FASTCALL: - handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), function.intOp(inst.f)); + handleBuiltinEffects( + state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), 
function.intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f)); break; // These instructions don't have an effect on register/memory state we are tracking diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index d18b75c5..9fa6f062 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -9,7 +9,6 @@ #include "lobject.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores5, false) LUAU_FASTFLAG(LuauCodegenUserdataOps) // TODO: optimization can be improved by knowing which registers are live in at each VM exit diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 85fef5aa..f971391b 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -46,6 +46,7 @@ // Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported. // Version 4: Adds Proto::flags, typeinfo, and floor division opcodes IDIV/IDIVK. Currently supported. // Version 5: Adds SUBRK/DIVRK and vector constants. Currently supported. +// Version 6: Adds FASTCALL3. Currently supported. // # Bytecode type information history // Version 1: (from bytecode version 4) Type information for function signature. Currently supported. 
@@ -299,8 +300,13 @@ enum LuauOpcode // A: target register (see FORGLOOP for register layout) LOP_FORGPREP_INEXT, - // removed in v3 - LOP_DEP_FORGLOOP_INEXT, + // FASTCALL3: perform a fast call of a built-in function using 3 register arguments + // A: builtin function id (see LuauBuiltinFunction) + // B: source argument register + // C: jump offset to get to following CALL + // AUX: source register 2 in least-significant byte + // AUX: source register 3 in second least-significant byte + LOP_FASTCALL3, // FORGPREP_NEXT: prepare FORGLOOP with 2 output variables (no AUX encoding), assuming generator is luaB_next, and jump to FORGLOOP // A: target register (see FORGLOOP for register layout) @@ -434,7 +440,7 @@ enum LuauBytecodeTag { // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled LBC_VERSION_MIN = 3, - LBC_VERSION_MAX = 5, + LBC_VERSION_MAX = 6, LBC_VERSION_TARGET = 5, // Type encoding version LBC_TYPE_VERSION_DEPRECATED = 1, diff --git a/Common/include/Luau/BytecodeUtils.h b/Common/include/Luau/BytecodeUtils.h index 957c804c..6f110311 100644 --- a/Common/include/Luau/BytecodeUtils.h +++ b/Common/include/Luau/BytecodeUtils.h @@ -28,6 +28,7 @@ inline int getOpLength(LuauOpcode op) case LOP_LOADKX: case LOP_FASTCALL2: case LOP_FASTCALL2K: + case LOP_FASTCALL3: case LOP_JUMPXEQKNIL: case LOP_JUMPXEQKB: case LOP_JUMPXEQKN: diff --git a/Common/include/Luau/DenseHash.h b/Common/include/Luau/DenseHash.h index 507a9c48..39e50f92 100644 --- a/Common/include/Luau/DenseHash.h +++ b/Common/include/Luau/DenseHash.h @@ -120,12 +120,12 @@ public: return *this; } - void clear() + void clear(size_t thresholdToDestroy = 32) { if (count == 0) return; - if (capacity > 32) + if (capacity > thresholdToDestroy) { destroy(); } @@ -583,9 +583,9 @@ public: { } - void clear() + void clear(size_t thresholdToDestroy = 32) { - impl.clear(); + impl.clear(thresholdToDestroy); } // Note: this reference is 
invalidated by any insert operation (i.e. operator[]) diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 59aee1e7..fac740c2 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -9,6 +9,7 @@ LUAU_FASTFLAGVARIABLE(LuauCompileTypeInfo, false) LUAU_FASTFLAG(LuauCompileUserdataInfo) +LUAU_FASTFLAG(LuauCompileFastcall3) namespace Luau { @@ -113,6 +114,7 @@ inline bool isFastCall(LuauOpcode op) case LOP_FASTCALL1: case LOP_FASTCALL2: case LOP_FASTCALL2K: + case LOP_FASTCALL3: return true; default: @@ -1241,6 +1243,9 @@ std::string BytecodeBuilder::getError(const std::string& message) uint8_t BytecodeBuilder::getVersion() { // This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags + if (FFlag::LuauCompileFastcall3) + return 6; + return LBC_VERSION_TARGET; } @@ -1621,6 +1626,16 @@ void BytecodeBuilder::validateInstructions() const VCONSTANY(insns[i + 1]); break; + case LOP_FASTCALL3: + LUAU_ASSERT(FFlag::LuauCompileFastcall3); + + VREG(LUAU_INSN_B(insn)); + VJUMP(LUAU_INSN_C(insn)); + LUAU_ASSERT(LUAU_INSN_OP(insns[i + 1 + LUAU_INSN_C(insn)]) == LOP_CALL); + VREG(insns[i + 1] & 0xff); + VREG((insns[i + 1] >> 8) & 0xff); + break; + case LOP_COVERAGE: break; @@ -2235,6 +2250,13 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, code++; break; + case LOP_FASTCALL3: + LUAU_ASSERT(FFlag::LuauCompileFastcall3); + + formatAppend(result, "FASTCALL3 %d R%d R%d R%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code & 0xff, (*code >> 8) & 0xff, targetLabel); + code++; + break; + case LOP_COVERAGE: formatAppend(result, "COVERAGE\n"); break; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index db86fbc6..26d3100c 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -29,6 +29,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAG(LuauCompileTypeInfo) 
LUAU_FASTFLAGVARIABLE(LuauCompileTempTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauCompileUserdataInfo, false) +LUAU_FASTFLAGVARIABLE(LuauCompileFastcall3, false) LUAU_FASTFLAG(LuauNativeAttribute) @@ -473,10 +474,32 @@ struct Compiler { LUAU_ASSERT(!expr->self); LUAU_ASSERT(expr->args.size >= 1); - LUAU_ASSERT(expr->args.size <= 2 || (bfid == LBF_BIT32_EXTRACTK && expr->args.size == 3)); + + if (FFlag::LuauCompileFastcall3) + LUAU_ASSERT(expr->args.size <= 3); + else + LUAU_ASSERT(expr->args.size <= 2 || (bfid == LBF_BIT32_EXTRACTK && expr->args.size == 3)); + LUAU_ASSERT(bfid == LBF_BIT32_EXTRACTK ? bfK >= 0 : bfK < 0); - LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : (bfK >= 0 || isConstant(expr->args.data[1])) ? LOP_FASTCALL2K : LOP_FASTCALL2; + LuauOpcode opc = LOP_NOP; + + if (FFlag::LuauCompileFastcall3) + { + if (expr->args.size == 1) + opc = LOP_FASTCALL1; + else if (bfK >= 0 || (expr->args.size == 2 && isConstant(expr->args.data[1]))) + opc = LOP_FASTCALL2K; + else if (expr->args.size == 2) + opc = LOP_FASTCALL2; + else + opc = LOP_FASTCALL3; + } + else + { + opc = expr->args.size == 1 ? LOP_FASTCALL1 + : (bfK >= 0 || (expr->args.size == 2 && isConstant(expr->args.data[1]))) ? LOP_FASTCALL2K : LOP_FASTCALL2; + } uint32_t args[3] = {}; @@ -504,8 +527,16 @@ struct Compiler size_t fastcallLabel = bytecode.emitLabel(); bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); - if (opc != LOP_FASTCALL1) + + if (FFlag::LuauCompileFastcall3 && opc == LOP_FASTCALL3) + { + LUAU_ASSERT(bfK < 0); + bytecode.emitAux(args[1] | (args[2] << 8)); + } + else if (opc != LOP_FASTCALL1) + { bytecode.emitAux(bfK >= 0 ? bfK : args[1]); + } // Set up a traditional Lua stack for the subsequent LOP_CALL. 
// Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for @@ -857,11 +888,28 @@ struct Compiler } } - // Optimization: for 1/2 argument fast calls use specialized opcodes - if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) + unsigned maxFastcallArgs = 2; + + // Fastcall with 3 arguments is only used if it can help save one or more move instructions + if (FFlag::LuauCompileFastcall3 && bfid >= 0 && expr->args.size == 3) + { + for (size_t i = 0; i < expr->args.size; ++i) + { + if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0) + { + maxFastcallArgs = 3; + break; + } + } + } + + // Optimization: for 1/2/3 argument fast calls use specialized opcodes + if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= (FFlag::LuauCompileFastcall3 ? maxFastcallArgs : 2u)) { if (!isExprMultRet(expr->args.data[expr->args.size - 1])) + { return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + } else if (options.optimizationLevel >= 2) { // when a builtin is none-safe with matching arity, even if the last expression returns 0 or >1 arguments, diff --git a/Config/include/Luau/LinterConfig.h b/Config/include/Luau/LinterConfig.h index a598a3df..3a68c0d7 100644 --- a/Config/include/Luau/LinterConfig.h +++ b/Config/include/Luau/LinterConfig.h @@ -49,6 +49,7 @@ struct LintWarning Code_CommentDirective = 26, Code_IntegerParsing = 27, Code_ComparisonPrecedence = 28, + Code_RedundantNativeAttribute = 29, Code__Count }; @@ -115,6 +116,7 @@ static const char* kWarningNames[] = { "CommentDirective", "IntegerParsing", "ComparisonPrecedence", + "RedundantNativeAttribute", }; // clang-format on diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 0e792366..07cc117e 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauPushErrorStackCheck, false) - static const char* getfuncname(Closure* f); static 
int currentpc(lua_State* L, CallInfo* ci) @@ -332,8 +330,7 @@ l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...) vsnprintf(result, sizeof(result), fmt, argp); va_end(argp); - if (FFlag::LuauPushErrorStackCheck) - lua_rawcheckstack(L, 1); + lua_rawcheckstack(L, 1); pusherror(L, result); luaD_throw(L, LUA_ERRRUN); @@ -341,8 +338,7 @@ l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...) void luaG_pusherror(lua_State* L, const char* error) { - if (FFlag::LuauPushErrorStackCheck) - lua_rawcheckstack(L, 1); + lua_rawcheckstack(L, 1); pusherror(L, error); } diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index b33fe9dd..2a1e45c4 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,8 +6,6 @@ #include "lmem.h" #include "lgc.h" -LUAU_FASTFLAGVARIABLE(LuauLoadTypeInfo, false) - Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_newgco(L, Proto, sizeof(Proto), L->activememcat); @@ -52,9 +50,7 @@ Proto* luaF_newproto(lua_State* L) f->linegaplog2 = 0; f->linedefined = 0; f->bytecodeid = 0; - - if (FFlag::LuauLoadTypeInfo) - f->sizetypeinfo = 0; + f->sizetypeinfo = 0; return f; } @@ -178,16 +174,8 @@ void luaF_freeproto(lua_State* L, Proto* f, lua_Page* page) if (f->execdata) L->global->ecb.destroy(L, f); - if (FFlag::LuauLoadTypeInfo) - { - if (f->typeinfo) - luaM_freearray(L, f->typeinfo, f->sizetypeinfo, uint8_t, f->memcat); - } - else - { - if (f->typeinfo) - luaM_freearray(L, f->typeinfo, f->numparams + 2, uint8_t, f->memcat); - } + if (f->typeinfo) + luaM_freearray(L, f->typeinfo, f->sizetypeinfo, uint8_t, f->memcat); luaM_freegco(L, f, sizeof(Proto), f->memcat, page); } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index f8389422..4473f04f 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -14,8 +14,6 @@ #include -LUAU_FASTFLAG(LuauLoadTypeInfo) - /* * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. 
* @@ -507,16 +505,8 @@ static size_t propagatemark(global_State* g) g->gray = p->gclist; traverseproto(g, p); - if (FFlag::LuauLoadTypeInfo) - { - return sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + - sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues + p->sizetypeinfo; - } - else - { - return sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + - sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; - } + return sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + + sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues + p->sizetypeinfo; } default: LUAU_ASSERT(0); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index f6cc07c9..5ff5de72 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -124,11 +124,16 @@ static_assert(offsetof(Udata, data) == ABISWITCH(16, 16, 12), "size mismatch for static_assert(sizeof(Table) == ABISWITCH(48, 32, 32), "size mismatch for table header"); static_assert(offsetof(Buffer, data) == ABISWITCH(8, 8, 8), "size mismatch for buffer header"); -LUAU_FASTFLAGVARIABLE(LuauExtendedSizeClasses, false) - const size_t kSizeClasses = LUA_SIZECLASSES; -const size_t kMaxSmallSize_DEPRECATED = 512; // TODO: remove with FFlagLuauExtendedSizeClasses + +// Controls the number of entries in SizeClassConfig and define the maximum possible paged allocation size +// Modifications require updates the SizeClassConfig initialization const size_t kMaxSmallSize = 1024; + +// Effective limit on object size to use paged allocation +// Can be modified without additional changes to code, provided it is smaller or equal to kMaxSmallSize +const size_t kMaxSmallSizeUsed = 1024; + const size_t kLargePageThreshold = 512; // larger pages are used for objects larger than this size to fit more of them 
into a page // constant factor to reduce our page sizes by, to increase the chances that pages we allocate will @@ -187,8 +192,7 @@ struct SizeClassConfig const SizeClassConfig kSizeClassConfig; // size class for a block of size sz; returns -1 for size=0 because empty allocations take no space -#define sizeclass(sz) \ - (size_t((sz)-1) < (FFlag::LuauExtendedSizeClasses ? kMaxSmallSize : kMaxSmallSize_DEPRECATED) ? kSizeClassConfig.classForSize[sz] : -1) +#define sizeclass(sz) (size_t((sz)-1) < kMaxSmallSizeUsed ? kSizeClassConfig.classForSize[sz] : -1) // metadata for a block is stored in the first pointer of the block #define metadata(block) (*(void**)(block)) @@ -275,34 +279,18 @@ static lua_Page* newpage(lua_State* L, lua_Page** pageset, int pageSize, int blo // if it is inlined, then the compiler may determine those functions are "too big" to be profitably inlined, which results in reduced performance LUAU_NOINLINE static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** pageset, uint8_t sizeClass, bool storeMetadata) { - if (FFlag::LuauExtendedSizeClasses) - { - int sizeOfClass = kSizeClassConfig.sizeOfClass[sizeClass]; - int pageSize = sizeOfClass > int(kLargePageThreshold) ? kLargePageSize : kSmallPageSize; - int blockSize = sizeOfClass + (storeMetadata ? kBlockHeader : 0); - int blockCount = (pageSize - offsetof(lua_Page, data)) / blockSize; + int sizeOfClass = kSizeClassConfig.sizeOfClass[sizeClass]; + int pageSize = sizeOfClass > int(kLargePageThreshold) ? kLargePageSize : kSmallPageSize; + int blockSize = sizeOfClass + (storeMetadata ? kBlockHeader : 0); + int blockCount = (pageSize - offsetof(lua_Page, data)) / blockSize; - lua_Page* page = newpage(L, pageset, pageSize, blockSize, blockCount); + lua_Page* page = newpage(L, pageset, pageSize, blockSize, blockCount); - // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) 
- LUAU_ASSERT(!freepageset[sizeClass]); - freepageset[sizeClass] = page; + // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) + LUAU_ASSERT(!freepageset[sizeClass]); + freepageset[sizeClass] = page; - return page; - } - else - { - int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + (storeMetadata ? kBlockHeader : 0); - int blockCount = (kSmallPageSize - offsetof(lua_Page, data)) / blockSize; - - lua_Page* page = newpage(L, pageset, kSmallPageSize, blockSize, blockCount); - - // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) - LUAU_ASSERT(!freepageset[sizeClass]); - freepageset[sizeClass] = page; - - return page; - } + return page; } static void freepage(lua_State* L, lua_Page** pageset, lua_Page* page) diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 3e14d4ad..75d9f400 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -11,9 +11,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastTableMaxn, false) -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFasterConcat, false) - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -56,41 +53,24 @@ static int maxn(lua_State* L) double max = 0; luaL_checktype(L, 1, LUA_TTABLE); - if (DFFlag::LuauFastTableMaxn) + Table* t = hvalue(L->base); + + for (int i = 0; i < t->sizearray; i++) { - Table* t = hvalue(L->base); - - for (int i = 0; i < t->sizearray; i++) - { - if (!ttisnil(&t->array[i])) - max = i + 1; - } - - for (int i = 0; i < sizenode(t); i++) - { - LuaNode* n = gnode(t, i); - - if (!ttisnil(gval(n)) && ttisnumber(gkey(n))) - { - double v = nvalue(gkey(n)); - - if (v > max) - max = v; - } - } + if (!ttisnil(&t->array[i])) + max = i + 1; } - else + + for (int i = 0; i < sizenode(t); i++) { - lua_pushnil(L); // first key - while (lua_next(L, 1)) + LuaNode* n = gnode(t, i); + + if (!ttisnil(gval(n)) && ttisnumber(gkey(n))) { - lua_pop(L, 1); // remove value - 
if (lua_type(L, -1) == LUA_TNUMBER) - { - double v = lua_tonumber(L, -1); - if (v > max) - max = v; - } + double v = nvalue(gkey(n)); + + if (v > max) + max = v; } } @@ -251,7 +231,7 @@ static int tmove(lua_State* L) static void addfield(lua_State* L, luaL_Strbuf* b, int i, Table* t) { - if (DFFlag::LuauFasterConcat && t && unsigned(i - 1) < unsigned(t->sizearray) && ttisstring(&t->array[i - 1])) + if (t && unsigned(i - 1) < unsigned(t->sizearray) && ttisstring(&t->array[i - 1])) { TString* ts = tsvalue(&t->array[i - 1]); luaL_addlstring(b, getstr(ts), ts->len); @@ -273,14 +253,14 @@ static int tconcat(lua_State* L) int i = luaL_optinteger(L, 3, 1); int last = luaL_opt(L, luaL_checkinteger, 4, lua_objlen(L, 1)); - Table* t = DFFlag::LuauFasterConcat ? hvalue(L->base) : NULL; + Table* t = hvalue(L->base); luaL_Strbuf b; luaL_buffinit(L, &b); for (; i < last; i++) { addfield(L, &b, i, t); - if (!DFFlag::LuauFasterConcat || lsep != 0) + if (lsep != 0) luaL_addlstring(&b, sep, lsep); } if (i == last) // add last value (if interval was not empty) diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 4ac21db3..bc89458e 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -98,7 +98,7 @@ LUAU_FASTFLAGVARIABLE(LuauVmSplitDoarith, false) VM_DISPATCH_OP(LOP_POWK), VM_DISPATCH_OP(LOP_AND), VM_DISPATCH_OP(LOP_OR), VM_DISPATCH_OP(LOP_ANDK), VM_DISPATCH_OP(LOP_ORK), \ VM_DISPATCH_OP(LOP_CONCAT), VM_DISPATCH_OP(LOP_NOT), VM_DISPATCH_OP(LOP_MINUS), VM_DISPATCH_OP(LOP_LENGTH), VM_DISPATCH_OP(LOP_NEWTABLE), \ VM_DISPATCH_OP(LOP_DUPTABLE), VM_DISPATCH_OP(LOP_SETLIST), VM_DISPATCH_OP(LOP_FORNPREP), VM_DISPATCH_OP(LOP_FORNLOOP), \ - VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_DEP_FORGLOOP_INEXT), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \ + VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_FASTCALL3), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \ VM_DISPATCH_OP(LOP_NATIVECALL), 
VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_SUBRK), VM_DISPATCH_OP(LOP_DIVRK), VM_DISPATCH_OP(LOP_FASTCALL1), \ @@ -2539,12 +2539,6 @@ reentry: VM_NEXT(); } - VM_CASE(LOP_DEP_FORGLOOP_INEXT) - { - LUAU_ASSERT(!"Unsupported deprecated opcode"); - LUAU_UNREACHABLE(); - } - VM_CASE(LOP_FORGPREP_NEXT) { Instruction insn = *pc++; @@ -3013,6 +3007,60 @@ reentry: } } + VM_CASE(LOP_FASTCALL3) + { + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + int skip = LUAU_INSN_C(insn) - 1; + uint32_t aux = *pc++; + TValue* arg1 = VM_REG(LUAU_INSN_B(insn)); + TValue* arg2 = VM_REG(aux & 0xff); + TValue* arg3 = VM_REG((aux >> 8) & 0xff); + + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = 3; + int nresults = LUAU_INSN_C(call) - 1; + + luau_FastFunction f = luauF_table[bfid]; + LUAU_ASSERT(f); + + if (cl->env->safeenv) + { + VM_PROTECT_PC(); // f may fail due to OOM + + setobj2s(L, L->top, arg2); + setobj2s(L, L->top + 1, arg3); + + int n = f(L, ra, arg1, nresults, L->top, nparams); + + if (n >= 0) + { + if (nresults == LUA_MULTRET) + L->top = ra + n; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + VM_CASE(LOP_BREAK) { LUAU_ASSERT(cl->l.p->debuginsn); diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index ed564bba..112a7197 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,6 @@ 
#include -LUAU_FASTFLAG(LuauLoadTypeInfo) LUAU_FASTFLAGVARIABLE(LuauLoadUserdataInfo, false) // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens @@ -357,70 +356,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size { p->flags = read(data, size, offset); - if (FFlag::LuauLoadTypeInfo) - { - if (typesversion == 1) - { - uint32_t typesize = readVarInt(data, size, offset); - - if (typesize) - { - uint8_t* types = (uint8_t*)data + offset; - - LUAU_ASSERT(typesize == unsigned(2 + p->numparams)); - LUAU_ASSERT(types[0] == LBC_TYPE_FUNCTION); - LUAU_ASSERT(types[1] == p->numparams); - - // transform v1 into v2 format - int headersize = typesize > 127 ? 4 : 3; - - p->typeinfo = luaM_newarray(L, headersize + typesize, uint8_t, p->memcat); - p->sizetypeinfo = headersize + typesize; - - if (headersize == 4) - { - p->typeinfo[0] = (typesize & 127) | (1 << 7); - p->typeinfo[1] = typesize >> 7; - p->typeinfo[2] = 0; - p->typeinfo[3] = 0; - } - else - { - p->typeinfo[0] = uint8_t(typesize); - p->typeinfo[1] = 0; - p->typeinfo[2] = 0; - } - - memcpy(p->typeinfo + headersize, types, typesize); - } - - offset += typesize; - } - else if (typesversion == 2 || (FFlag::LuauLoadUserdataInfo && typesversion == 3)) - { - uint32_t typesize = readVarInt(data, size, offset); - - if (typesize) - { - uint8_t* types = (uint8_t*)data + offset; - - p->typeinfo = luaM_newarray(L, typesize, uint8_t, p->memcat); - p->sizetypeinfo = typesize; - memcpy(p->typeinfo, types, typesize); - offset += typesize; - - if (FFlag::LuauLoadUserdataInfo && typesversion == 3) - { - remapUserdataTypes((char*)(uint8_t*)p->typeinfo, p->sizetypeinfo, userdataRemapping, userdataTypeLimit); - } - } - } - } - else + if (typesversion == 1) { uint32_t typesize = readVarInt(data, size, offset); - if (typesize && typesversion == LBC_TYPE_VERSION_DEPRECATED) + if (typesize) { uint8_t* types = (uint8_t*)data + offset; @@ -428,12 +368,50 @@ int luau_load(lua_State* 
L, const char* chunkname, const char* data, size_t size LUAU_ASSERT(types[0] == LBC_TYPE_FUNCTION); LUAU_ASSERT(types[1] == p->numparams); - p->typeinfo = luaM_newarray(L, typesize, uint8_t, p->memcat); - memcpy(p->typeinfo, types, typesize); + // transform v1 into v2 format + int headersize = typesize > 127 ? 4 : 3; + + p->typeinfo = luaM_newarray(L, headersize + typesize, uint8_t, p->memcat); + p->sizetypeinfo = headersize + typesize; + + if (headersize == 4) + { + p->typeinfo[0] = (typesize & 127) | (1 << 7); + p->typeinfo[1] = typesize >> 7; + p->typeinfo[2] = 0; + p->typeinfo[3] = 0; + } + else + { + p->typeinfo[0] = uint8_t(typesize); + p->typeinfo[1] = 0; + p->typeinfo[2] = 0; + } + + memcpy(p->typeinfo + headersize, types, typesize); } offset += typesize; } + else if (typesversion == 2 || (FFlag::LuauLoadUserdataInfo && typesversion == 3)) + { + uint32_t typesize = readVarInt(data, size, offset); + + if (typesize) + { + uint8_t* types = (uint8_t*)data + offset; + + p->typeinfo = luaM_newarray(L, typesize, uint8_t, p->memcat); + p->sizetypeinfo = typesize; + memcpy(p->typeinfo, types, typesize); + offset += typesize; + + if (FFlag::LuauLoadUserdataInfo && typesversion == 3) + { + remapUserdataTypes((char*)(uint8_t*)p->typeinfo, p->sizetypeinfo, userdataRemapping, userdataTypeLimit); + } + } + } } const int sizecode = readVarInt(data, size, offset); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 6255d73f..eeca416c 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -25,6 +25,7 @@ LUAU_FASTINT(LuauRecursionLimit) LUAU_FASTFLAG(LuauCompileTypeInfo) LUAU_FASTFLAG(LuauCompileTempTypeInfo) LUAU_FASTFLAG(LuauCompileUserdataInfo) +LUAU_FASTFLAG(LuauCompileFastcall3) using namespace Luau; @@ -3486,6 +3487,33 @@ RETURN R1 -1 )"); } +TEST_CASE("Fastcall3") +{ + ScopedFastFlag luauCompileFastcall3{FFlag::LuauCompileFastcall3, true}; + + CHECK_EQ("\n" + compileFunction0(R"( +local a, b, c = ... 
+return math.min(a, b, c) + math.clamp(a, b, c) +)"), + R"( +GETVARARGS R0 3 +FASTCALL3 19 R0 R1 R2 L0 +MOVE R5 R0 +MOVE R6 R1 +MOVE R7 R2 +GETIMPORT R4 2 [math.min] +CALL R4 3 1 +L0: FASTCALL3 46 R0 R1 R2 L1 +MOVE R6 R0 +MOVE R7 R1 +MOVE R8 R2 +GETIMPORT R5 4 [math.clamp] +CALL R5 3 1 +L1: ADD R3 R4 R5 +RETURN R3 1 +)"); +} + TEST_CASE("FastcallSelect") { // select(_, ...) compiles to a builtin call @@ -4668,6 +4696,34 @@ L0: RETURN R0 -1 )"); } +TEST_CASE("VectorFastCall3") +{ + ScopedFastFlag luauCompileFastcall3{FFlag::LuauCompileFastcall3, true}; + + const char* source = R"( +local a, b, c = ... +return Vector3.new(a, b, c) +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.vectorLib = "Vector3"; + options.vectorCtor = "new"; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +GETVARARGS R0 3 +FASTCALL3 54 R0 R1 R2 L0 +MOVE R4 R0 +MOVE R5 R1 +MOVE R6 R2 +GETIMPORT R3 2 [Vector3.new] +CALL R3 3 -1 +L0: RETURN R3 -1 +)"); +} + TEST_CASE("VectorLiterals") { CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, /*enableVectors*/ true), R"( diff --git a/tests/Error.test.cpp b/tests/Error.test.cpp index 8dfcbde0..00a5a2e7 100644 --- a/tests/Error.test.cpp +++ b/tests/Error.test.cpp @@ -48,7 +48,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "binary_op_type_family_errors") LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Operator '+' could not be applied to operands of types number and string; there is no corresponding overload for __add", toString(result.errors[0])); + CHECK_EQ("Operator '+' could not be applied to operands of types number and string; there is no corresponding overload for __add", + toString(result.errors[0])); else CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } @@ -66,7 +67,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, 
"unary_op_type_family_errors") if (FFlag::DebugLuauDeferredConstraintResolution) { LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Operator '-' could not be applied to operand of type string; there is no corresponding overload for __unm", toString(result.errors[0])); + CHECK_EQ( + "Operator '-' could not be applied to operand of type string; there is no corresponding overload for __unm", toString(result.errors[0])); CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[1])); } else diff --git a/tests/Fixture.h b/tests/Fixture.h index 481f79d3..e0c04e8b 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -241,3 +241,21 @@ using DifferFixtureWithBuiltins = DifferFixtureGeneric; } while (false) #define LUAU_REQUIRE_NO_ERRORS(result) LUAU_REQUIRE_ERROR_COUNT(0, result) + +#define LUAU_CHECK_ERRORS(result) \ + do \ + { \ + auto&& r = (result); \ + validateErrors(r.errors); \ + CHECK(!r.errors.empty()); \ + } while (false) + +#define LUAU_CHECK_ERROR_COUNT(count, result) \ + do \ + { \ + auto&& r = (result); \ + validateErrors(r.errors); \ + CHECK_MESSAGE(count == r.errors.size(), getErrors(r)); \ + } while (false) + +#define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result) diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp index 43bd7325..e9344911 100644 --- a/tests/Generalization.test.cpp +++ b/tests/Generalization.test.cpp @@ -125,11 +125,7 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "cache_fully_generalized_types") CHECK(generalizedTypes->empty()); TypeId tinyTable = arena.addType(TableType{ - TableType::Props{{"one", builtinTypes.numberType}, {"two", builtinTypes.stringType}}, - std::nullopt, - TypeLevel{}, - TableState::Sealed - }); + TableType::Props{{"one", builtinTypes.numberType}, {"two", builtinTypes.stringType}}, std::nullopt, TypeLevel{}, TableState::Sealed}); generalize(tinyTable); @@ -142,17 +138,10 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "dont_cache_types_that_arent_done_yet") { 
TypeId freeTy = arena.addType(FreeType{NotNull{globalScope.get()}, builtinTypes.neverType, builtinTypes.stringType}); - TypeId fnTy = arena.addType(FunctionType{ - builtinTypes.emptyTypePack, - arena.addTypePack(TypePack{{builtinTypes.numberType}}) - }); + TypeId fnTy = arena.addType(FunctionType{builtinTypes.emptyTypePack, arena.addTypePack(TypePack{{builtinTypes.numberType}})}); TypeId tableTy = arena.addType(TableType{ - TableType::Props{{"one", builtinTypes.numberType}, {"two", freeTy}, {"three", fnTy}}, - std::nullopt, - TypeLevel{}, - TableState::Sealed - }); + TableType::Props{{"one", builtinTypes.numberType}, {"two", freeTy}, {"three", fnTy}}, std::nullopt, TypeLevel{}, TableState::Sealed}); generalize(tableTy); @@ -174,11 +163,7 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can }); asMutable(selfTy)->ty.emplace( - TableType::Props{{"count", builtinTypes.numberType}, {"method", methodTy}}, - std::nullopt, - TypeLevel{}, - TableState::Sealed - ); + TableType::Props{{"count", builtinTypes.numberType}, {"method", methodTy}}, std::nullopt, TypeLevel{}, TableState::Sealed); generalize(methodTy); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index da7cd9b1..bd7a02b2 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -12,9 +12,10 @@ #include -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTFLAG(LuauCodegenFixSplitStoreConstMismatch) +LUAU_FASTFLAG(LuauCodegenInstG) +LUAU_FASTFLAG(LuauCodegenFastcall3) using namespace Luau::CodeGen; @@ -1117,6 +1118,8 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") { + ScopedFastFlag luauCodegenInstG{FFlag::LuauCodegenInstG, true}; + IrOp block = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); @@ -1129,8 +1132,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") build.inst(IrCmd::CHECK_NO_METATABLE, table, 
fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); - build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(LBF_SETMETATABLE), build.vmReg(1), build.vmReg(2), build.vmReg(3), build.constInt(3), - build.constInt(1)); + build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(LBF_SETMETATABLE), build.vmReg(1), build.vmReg(2), build.vmReg(3), build.undef(), + build.constInt(3), build.constInt(1)); build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); @@ -1151,7 +1154,7 @@ bb_0: %1 = LOAD_POINTER R0 CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 - %4 = INVOKE_FASTCALL 61u, R1, R2, R3, 3i, 1i + %4 = INVOKE_FASTCALL 61u, R1, R2, R3, undef, 3i, 1i CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 STORE_DOUBLE R1, 0.5 @@ -2546,8 +2549,6 @@ bb_0: ; useCount: 0 TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp block = build.block(IrBlockKind::Internal); IrOp followup = build.block(IrBlockKind::Internal); @@ -2587,14 +2588,14 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects1") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_FREXP), build.vmReg(1), build.vmReg(2), build.undef(), build.constInt(1), build.constInt(2)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(1), build.constTag(tnumber), build.vmExit(1)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(2), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_FREXP), build.vmReg(1), build.vmReg(2), build.constInt(2)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(1)), build.constTag(tnumber), 
build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(2)), build.constTag(tnumber), build.vmExit(1)); build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); updateUseCounts(build.function); @@ -2604,7 +2605,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects1") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: ; in regs: R2 - FASTCALL 14u, R1, R2, undef, 1i, 2i + FASTCALL 14u, R1, R2, 2i RETURN R1, 2i )"); @@ -2612,14 +2613,14 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_MODF), build.vmReg(1), build.vmReg(2), build.undef(), build.constInt(1), build.constInt(1)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(1), build.constTag(tnumber), build.vmExit(1)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(2), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_MODF), build.vmReg(1), build.vmReg(2), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(1)), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(2)), build.constTag(tnumber), build.vmExit(1)); build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); updateUseCounts(build.function); @@ -2629,8 +2630,9 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects2") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: ; in regs: R2 - FASTCALL 20u, R1, R2, undef, 1i, 1i - CHECK_TAG R2, tnumber, exit(1) + FASTCALL 20u, R1, R2, 1i + %3 = LOAD_TAG R2 + CHECK_TAG %3, tnumber, exit(1) RETURN R1, 2i )"); @@ -2642,7 +2644,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, 
"InferNumberTagFromLimitedContext") build.beginBlock(entry); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(0), build.constTag(ttable), build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(ttable), build.vmExit(1)); build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); @@ -2666,7 +2668,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore1") build.beginBlock(entry); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(1)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(0), build.constTag(ttable), build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(ttable), build.vmExit(1)); build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); @@ -2677,9 +2679,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore1") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_INT R0, 1i - CHECK_TAG R0, ttable, exit(1) - %2 = LOAD_TVALUE R0 - STORE_TVALUE R1, %2 + %1 = LOAD_TAG R0 + CHECK_TAG %1, ttable, exit(1) + %3 = LOAD_TVALUE R0 + STORE_TVALUE R1, %3 RETURN R1, 1i )"); @@ -2693,7 +2696,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore2") build.beginBlock(entry); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(1)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(0), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), build.vmExit(1)); build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); @@ -2704,9 +2707,10 @@ 
TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore2") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: STORE_INT R0, 1i - CHECK_TAG R0, tnumber, exit(1) - %2 = LOAD_TVALUE R0 - STORE_TVALUE R1, %2 + %1 = LOAD_TAG R0 + CHECK_TAG %1, tnumber, exit(1) + %3 = LOAD_TVALUE R0 + STORE_TVALUE R1, %3 RETURN R1, 1i )"); @@ -2809,13 +2813,16 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") { + ScopedFastFlag luauCodegenInstG{FFlag::LuauCodegenInstG, true}; + ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; + IrOp entry = build.block(IrBlockKind::Internal); IrOp exit = build.block(IrBlockKind::Internal); build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); - IrOp results = build.inst( - IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.constInt(-1), build.constInt(-1)); + IrOp results = build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.undef(), + build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(0), results); build.inst(IrCmd::JUMP, exit); @@ -2830,7 +2837,7 @@ bb_0: ; successors: bb_1 ; out regs: R0... 
FALLBACK_GETVARARGS 0u, R1, -1i - %1 = INVOKE_FASTCALL 0u, R0, R1, R2, -1i, -1i + %1 = INVOKE_FASTCALL 0u, R0, R1, R2, undef, -1i, -1i ADJUST_STACK_TO_REG R0, %1 JUMP bb_1 @@ -3023,8 +3030,6 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepImplicitUse") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp direct = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Internal); @@ -3585,8 +3590,6 @@ TEST_SUITE_BEGIN("DeadStoreRemoval"); TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDoubleStore") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3636,8 +3639,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturn") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3669,8 +3670,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturnPartial") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3699,8 +3698,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse1") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3729,8 +3726,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3763,8 +3758,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse3") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = 
build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3793,8 +3786,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse4") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3827,8 +3818,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "PartialVsFullStoresWithRecombination") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3852,8 +3841,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "IgnoreFastcallAdjustment") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3880,8 +3867,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "JumpImplicitLiveOut") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp next = build.block(IrBlockKind::Internal); @@ -3917,8 +3902,6 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "KeepCapturedRegisterStores") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3956,7 +3939,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "StoreCannotBeReplacedWithCheck") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; ScopedFastFlag debugLuauAbortingChecks{FFlag::DebugLuauAbortingChecks, true}; IrOp block = build.block(IrBlockKind::Internal); @@ -4025,8 +4007,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); 
IrOp last = build.block(IrBlockKind::Internal); @@ -4083,8 +4063,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); IrOp last = build.block(IrBlockKind::Internal); @@ -4139,8 +4117,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks3") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); IrOp last = build.block(IrBlockKind::Internal); @@ -4198,8 +4174,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "SafePartialValueStoresWithPreservedTag") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); IrOp last = build.block(IrBlockKind::Internal); @@ -4253,8 +4227,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "SafePartialValueStoresWithPreservedTag2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); IrOp last = build.block(IrBlockKind::Internal); @@ -4307,8 +4279,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotReturnWithPartialStores") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp success = build.block(IrBlockKind::Internal); IrOp fail = build.block(IrBlockKind::Internal); @@ -4379,8 +4349,6 @@ bb_3: TEST_CASE_FIXTURE(IrBuilderFixture, "PartialOverFullValue") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = 
build.block(IrBlockKind::Internal); build.beginBlock(entry); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index ecdb522c..33d5602b 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -15,17 +15,15 @@ #include #include -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) -LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow) LUAU_FASTFLAG(LuauCompileTypeInfo) -LUAU_FASTFLAG(LuauLoadTypeInfo) -LUAU_FASTFLAG(LuauCodegenTypeInfo) LUAU_FASTFLAG(LuauCompileTempTypeInfo) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauCompileUserdataInfo) LUAU_FASTFLAG(LuauLoadUserdataInfo) LUAU_FASTFLAG(LuauCodegenUserdataOps) LUAU_FASTFLAG(LuauCodegenUserdataAlloc) +LUAU_FASTFLAG(LuauCompileFastcall3) +LUAU_FASTFLAG(LuauCodegenFastcall3) static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) { @@ -159,8 +157,6 @@ bb_bytecode_1: TEST_CASE("VectorComponentRead") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function compsum(a: vector) return a.X + a.Y + a.Z @@ -238,8 +234,6 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) return a * b - c / d @@ -272,8 +266,6 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector) local tmp = a * a @@ -302,8 +294,6 @@ bb_bytecode_1: TEST_CASE("VectorMulDivMixed") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) return a * 2 + b / 4 + 0.5 * c + 40 / d @@ -344,8 +334,6 @@ 
bb_bytecode_1: TEST_CASE("ExtraMathMemoryOperands") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: number, b: number, c: number, d: number, e: number) return math.floor(a) + math.ceil(b) + math.round(c) + math.sqrt(d) + math.abs(e) @@ -382,8 +370,6 @@ bb_bytecode_1: TEST_CASE("DseInitialStackState") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo() while {} do @@ -422,7 +408,7 @@ bb_5: TEST_CASE("DseInitialStackState2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a) @@ -435,7 +421,7 @@ end bb_bytecode_0: CHECK_SAFE_ENV exit(1) CHECK_TAG R0, tnumber, exit(1) - FASTCALL 14u, R1, R0, undef, 1i, 2i + FASTCALL 14u, R1, R0, 2i INTERRUPT 5u RETURN R0, 1i )"); @@ -443,7 +429,7 @@ bb_bytecode_0: TEST_CASE("DseInitialStackState3") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a) @@ -456,7 +442,7 @@ end bb_bytecode_0: CHECK_SAFE_ENV exit(1) CHECK_TAG R0, tnumber, exit(1) - FASTCALL 47u, R1, R0, undef, 1i, 1i + FASTCALL 47u, R1, R0, 1i INTERRUPT 5u RETURN R0, 1i )"); @@ -464,8 +450,6 @@ bb_bytecode_0: TEST_CASE("VectorConstantTag") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function vecrcp(a: vector) return vector(1, 2, 3) + a @@ -491,8 +475,6 @@ bb_bytecode_1: TEST_CASE("VectorNamecall") { - ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local 
function abs(a: vector) return a:Abs() @@ -517,8 +499,6 @@ bb_bytecode_1: TEST_CASE("VectorRandomProp") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: vector) return a.XX + a.YY + a.ZZ @@ -559,7 +539,6 @@ bb_6: TEST_CASE("VectorCustomAccess") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( @@ -594,8 +573,6 @@ bb_bytecode_1: TEST_CASE("VectorCustomNamecall") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( @@ -634,8 +611,6 @@ bb_bytecode_1: TEST_CASE("VectorCustomAccessChain") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( @@ -688,8 +663,6 @@ bb_bytecode_1: TEST_CASE("VectorCustomNamecallChain") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( @@ -749,9 +722,7 @@ bb_bytecode_1: TEST_CASE("VectorCustomNamecallChain2") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, 
{FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, - {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {n: vector, b: vector} @@ -827,8 +798,6 @@ bb_6: TEST_CASE("UserDataGetIndex") { - ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function getxy(a: Point) return a.x + a.y @@ -859,8 +828,6 @@ bb_4: TEST_CASE("UserDataSetIndex") { - ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function setxy(a: Point) a.x = 3 @@ -887,8 +854,6 @@ bb_bytecode_1: TEST_CASE("UserDataNamecall") { - ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function getxy(a: Point) return a:GetX() + a:GetY() @@ -925,8 +890,7 @@ bb_4: TEST_CASE("ExplicitUpvalueAndLocalTypes") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local y: vector = ... 
@@ -969,8 +933,7 @@ bb_bytecode_0: TEST_CASE("FastcallTypeInferThroughLocal") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(x, c) @@ -985,43 +948,40 @@ end /* includeIrTypes */ true), R"( ; function getsum($arg0, $arg1) line 2 -; R2: vector from 0 to 17 +; R2: vector from 0 to 18 bb_bytecode_0: - %0 = LOAD_TVALUE R0 - STORE_TVALUE R3, %0 STORE_DOUBLE R4, 2 STORE_TAG R4, tnumber STORE_DOUBLE R5, 3 STORE_TAG R5, tnumber CHECK_SAFE_ENV exit(4) - CHECK_TAG R3, tnumber, exit(4) - %13 = LOAD_DOUBLE R3 - STORE_VECTOR R2, %13, 2, 3 + CHECK_TAG R0, tnumber, exit(4) + %11 = LOAD_DOUBLE R0 + STORE_VECTOR R2, %11, 2, 3 STORE_TAG R2, tvector JUMP_IF_FALSY R1, bb_bytecode_1, bb_3 bb_3: - CHECK_TAG R2, tvector, exit(8) - %21 = LOAD_FLOAT R2, 0i - %26 = LOAD_FLOAT R2, 4i - %35 = ADD_NUM %21, %26 - STORE_DOUBLE R3, %35 + CHECK_TAG R2, tvector, exit(9) + %19 = LOAD_FLOAT R2, 0i + %24 = LOAD_FLOAT R2, 4i + %33 = ADD_NUM %19, %24 + STORE_DOUBLE R3, %33 STORE_TAG R3, tnumber - INTERRUPT 13u + INTERRUPT 14u RETURN R3, 1i bb_bytecode_1: - CHECK_TAG R2, tvector, exit(14) - %42 = LOAD_FLOAT R2, 8i - STORE_DOUBLE R3, %42 + CHECK_TAG R2, tvector, exit(15) + %40 = LOAD_FLOAT R2, 8i + STORE_DOUBLE R3, %40 STORE_TAG R3, tnumber - INTERRUPT 16u + INTERRUPT 17u RETURN R3, 1i )"); } TEST_CASE("FastcallTypeInferThroughUpvalue") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( 
local v = ... @@ -1040,48 +1000,45 @@ end ; function getsum($arg0, $arg1) line 4 ; U0: vector bb_bytecode_0: - %0 = LOAD_TVALUE R0 - STORE_TVALUE R3, %0 STORE_DOUBLE R4, 2 STORE_TAG R4, tnumber STORE_DOUBLE R5, 3 STORE_TAG R5, tnumber CHECK_SAFE_ENV exit(4) - CHECK_TAG R3, tnumber, exit(4) - %13 = LOAD_DOUBLE R3 - STORE_VECTOR R2, %13, 2, 3 + CHECK_TAG R0, tnumber, exit(4) + %11 = LOAD_DOUBLE R0 + STORE_VECTOR R2, %11, 2, 3 STORE_TAG R2, tvector SET_UPVALUE U0, R2, tvector JUMP_IF_FALSY R1, bb_bytecode_1, bb_3 bb_3: GET_UPVALUE R4, U0 - CHECK_TAG R4, tvector, exit(10) - %23 = LOAD_FLOAT R4, 0i - STORE_DOUBLE R3, %23 + CHECK_TAG R4, tvector, exit(11) + %21 = LOAD_FLOAT R4, 0i + STORE_DOUBLE R3, %21 STORE_TAG R3, tnumber GET_UPVALUE R5, U0 - CHECK_TAG R5, tvector, exit(13) - %29 = LOAD_FLOAT R5, 4i - %38 = ADD_NUM %23, %29 - STORE_DOUBLE R2, %38 + CHECK_TAG R5, tvector, exit(14) + %27 = LOAD_FLOAT R5, 4i + %36 = ADD_NUM %21, %27 + STORE_DOUBLE R2, %36 STORE_TAG R2, tnumber - INTERRUPT 16u + INTERRUPT 17u RETURN R2, 1i bb_bytecode_1: GET_UPVALUE R3, U0 - CHECK_TAG R3, tvector, exit(18) - %46 = LOAD_FLOAT R3, 8i - STORE_DOUBLE R2, %46 + CHECK_TAG R3, tvector, exit(19) + %44 = LOAD_FLOAT R3, 8i + STORE_DOUBLE R2, %44 STORE_TAG R2, tnumber - INTERRUPT 20u + INTERRUPT 21u RETURN R2, 1i )"); } TEST_CASE("LoadAndMoveTypePropagation") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(n) @@ -1148,8 +1105,7 @@ bb_bytecode_4: TEST_CASE("ArgumentTypeRefinement") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileFastcall3, true}, 
{FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(x, y) @@ -1164,31 +1120,28 @@ end bb_bytecode_0: STORE_DOUBLE R3, 1 STORE_TAG R3, tnumber - %2 = LOAD_TVALUE R1 - STORE_TVALUE R4, %2 STORE_DOUBLE R5, 3 STORE_TAG R5, tnumber CHECK_SAFE_ENV exit(4) - CHECK_TAG R4, tnumber, exit(4) - %14 = LOAD_DOUBLE R4 - STORE_VECTOR R2, 1, %14, 3 + CHECK_TAG R1, tnumber, exit(4) + %12 = LOAD_DOUBLE R1 + STORE_VECTOR R2, 1, %12, 3 STORE_TAG R2, tvector - %18 = LOAD_TVALUE R2 - STORE_TVALUE R0, %18 - %22 = LOAD_FLOAT R0, 4i - %27 = LOAD_FLOAT R0, 8i - %36 = ADD_NUM %22, %27 - STORE_DOUBLE R2, %36 + %16 = LOAD_TVALUE R2 + STORE_TVALUE R0, %16 + %20 = LOAD_FLOAT R0, 4i + %25 = LOAD_FLOAT R0, 8i + %34 = ADD_NUM %20, %25 + STORE_DOUBLE R2, %34 STORE_TAG R2, tnumber - INTERRUPT 13u + INTERRUPT 14u RETURN R2, 1i )"); } TEST_CASE("InlineFunctionType") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function inl(v: vector, s: number) @@ -1236,8 +1189,7 @@ bb_bytecode_0: TEST_CASE("ResolveTablePathTypes") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1291,8 +1243,7 @@ bb_6: TEST_CASE("ResolvableSimpleMath") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; + ScopedFastFlag 
sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = { p: vector, uv: vector, n: vector, t: vector, b: vector, h: number } @@ -1348,9 +1299,7 @@ end TEST_CASE("ResolveVectorNamecalls") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, - {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1414,8 +1363,7 @@ bb_6: TEST_CASE("ImmediateTypeAnnotationHelp") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(arr, i) @@ -1453,8 +1401,8 @@ bb_2: TEST_CASE("UnaryTypeResolve") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileFastcall3, true}, + {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( local function foo(a, b: vector, c) @@ -1467,17 +1415,16 @@ end R"( ; function foo(a, b, c) line 2 ; R1: vector [argument 'b'] -; R3: boolean from 0 to 16 [local 'd'] -; R4: vector from 1 to 16 [local 'e'] -; R5: number from 2 to 16 [local 
'f'] -; R7: vector from 13 to 15 +; R3: boolean from 0 to 17 [local 'd'] +; R4: vector from 1 to 17 [local 'e'] +; R5: number from 2 to 17 [local 'f'] +; R7: vector from 14 to 16 )"); } TEST_CASE("ForInManualAnnotation") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1572,8 +1519,7 @@ bb_12: TEST_CASE("ForInAutoAnnotationIpairs") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = {pos: vector, normal: vector} @@ -1600,8 +1546,7 @@ end TEST_CASE("ForInAutoAnnotationPairs") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = {pos: vector, normal: vector} @@ -1628,8 +1573,7 @@ end TEST_CASE("ForInAutoAnnotationGeneric") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = 
{pos: vector, normal: vector} @@ -1661,8 +1605,7 @@ TEST_CASE("CustomUserdataTypesTemp") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, false}, + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, false}, {FFlag::LuauLoadUserdataInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( @@ -1683,8 +1626,7 @@ TEST_CASE("CustomUserdataTypes") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( @@ -1705,9 +1647,8 @@ TEST_CASE("CustomUserdataPropertyAccess") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(v: vec2) @@ -1742,9 +1683,8 @@ 
TEST_CASE("CustomUserdataPropertyAccess2") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: mat3) @@ -1781,10 +1721,8 @@ TEST_CASE("CustomUserdataNamecall1") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, - {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: vec2, b: vec2) @@ -1830,10 +1768,9 @@ TEST_CASE("CustomUserdataNamecall2") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - 
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, - {FFlag::LuauCodegenUserdataOps, true}, {FFlag::LuauCodegenUserdataAlloc, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, {FFlag::LuauCodegenUserdataOps, true}, + {FFlag::LuauCodegenUserdataAlloc, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: vec2, b: vec2) @@ -1882,9 +1819,8 @@ TEST_CASE("CustomUserdataMetamethodDirectFlow") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: mat3, b: mat3) @@ -1916,9 +1852,8 @@ TEST_CASE("CustomUserdataMetamethodDirectFlow2") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, 
{FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: mat3) @@ -1948,9 +1883,8 @@ TEST_CASE("CustomUserdataMetamethodDirectFlow3") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: sequence) @@ -1980,10 +1914,8 @@ TEST_CASE("CustomUserdataMetamethod") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}, - {FFlag::LuauCodegenUserdataAlloc, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, + {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}, {FFlag::LuauCodegenUserdataAlloc, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: vec2, b: vec2, c: vec2) diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 9f7aa77a..807b5e73 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -8,6 
+8,9 @@ #include "doctest.h" LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauAttributeSyntax); +LUAU_FASTFLAG(LuauNativeAttribute); +LUAU_FASTFLAG(LintRedundantNativeAttribute); using namespace Luau; @@ -1955,4 +1958,32 @@ local _ = a <= (b == 0) CHECK_EQ(result.warnings[4].text, "X <= Y <= Z is equivalent to (X <= Y) <= Z; did you mean X <= Y and Y <= Z?"); } +TEST_CASE_FIXTURE(Fixture, "RedundantNativeAttribute") +{ + ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauNativeAttribute, true}, {FFlag::LintRedundantNativeAttribute, true}}; + + LintResult result = lint(R"( +--!native + +@native +local function f(a) + @native + local function g(b) + return (a + b) + end + return g +end + +f(3)(4) +)"); + + REQUIRE(2 == result.warnings.size()); + + CHECK_EQ(result.warnings[0].text, "native attribute on a function is redundant in a native module; consider removing it"); + CHECK_EQ(result.warnings[0].location, Location(Position(3, 0), Position(3, 7))); + + CHECK_EQ(result.warnings[1].text, "native attribute on a function is redundant in a native module; consider removing it"); + CHECK_EQ(result.warnings[1].location, Location(Position(5, 4), Position(5, 11))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 48d130dd..4f8ed3eb 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2298,10 +2298,14 @@ end if (FFlag::DebugLuauDeferredConstraintResolution) { LUAU_REQUIRE_ERROR_COUNT(4, result); - CHECK(toString(result.errors[0]) == "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); - CHECK(toString(result.errors[1]) == "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); - CHECK(toString(result.errors[2]) == "Operator '-' could not be applied to operands of types unknown and number; 
there is no corresponding overload for __sub"); - CHECK(toString(result.errors[3]) == "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[0]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[1]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[2]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[3]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); } else { @@ -2719,7 +2723,6 @@ end _ = _,{} )"); - } diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 8e81b0cc..a34af12d 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1230,4 +1230,45 @@ TEST_CASE_FIXTURE(Fixture, "table_containing_non_final_type_is_erroneously_cache CHECK(n1 == n2); } +// This is doable with the new solver, but there are some problems we have to work out first. +// CLI-111113 +TEST_CASE_FIXTURE(Fixture, "we_cannot_infer_functions_that_return_inconsistently") +{ + CheckResult result = check(R"( + function find_first(tbl: {T}, el) + for i, e in tbl do + if e == el then + return i + end + end + return nil + end + )"); + +#if 0 + // This #if block describes what should happen. + LUAU_CHECK_NO_ERRORS(result); + + // The second argument has type unknown because the == operator does not + // constrain the type of el. + CHECK("({T}, unknown) -> number?" == toString(requireType("find_first"))); +#else + // This is what actually happens right now. 
+ + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_CHECK_ERROR_COUNT(2, result); + + // The second argument should be unknown. CLI-111111 + CHECK("({T}, 'b) -> number" == toString(requireType("find_first"))); + } + else + { + LUAU_CHECK_ERROR_COUNT(1, result); + + CHECK("({T}, b) -> number" == toString(requireType("find_first"))); + } +#endif +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2c9614d0..4dbedd51 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -21,12 +21,30 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls); LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering); LUAU_FASTFLAG(DebugLuauSharedSelf); -LUAU_FASTFLAG(LuauMetatableInstantiationCloneCheck); LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError) TEST_SUITE_BEGIN("TableTests"); +TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_shouldnt_seal_table_in_len_family_fn") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + CheckResult result = check(R"( +local t = {} +for i = #t, 2, -1 do + t[i] = t[i + 1] +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + const TableType* tType = get(requireType("t")); + REQUIRE(tType != nullptr); + REQUIRE(tType->indexer); + CHECK_EQ(tType->indexer->indexType, builtinTypes->numberType); + CHECK_EQ(follow(tType->indexer->indexResultType), builtinTypes->unknownType); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "LUAU_ASSERT_arg_exprs_doesnt_trigger_assert") { CheckResult result = check(R"( @@ -4150,9 +4168,7 @@ TEST_CASE_FIXTURE(Fixture, "write_annotations_are_unsupported_even_with_the_new_ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported") { - ScopedFastFlag sff[] = { - {FFlag::DebugLuauDeferredConstraintResolution, false} - }; + ScopedFastFlag sff[] = {{FFlag::DebugLuauDeferredConstraintResolution, false}}; CheckResult result = check(R"( type W = {read x: number} @@ 
-4176,9 +4192,7 @@ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported") { - ScopedFastFlag sff[] = { - {FFlag::DebugLuauDeferredConstraintResolution, false} - }; + ScopedFastFlag sff[] = {{FFlag::DebugLuauDeferredConstraintResolution, false}}; CheckResult result = check(R"( type T = {read [string]: number} @@ -4198,9 +4212,7 @@ TEST_CASE_FIXTURE(Fixture, "table_writes_introduce_write_properties") if (!FFlag::DebugLuauDeferredConstraintResolution) return; - ScopedFastFlag sff[] = { - {FFlag::DebugLuauDeferredConstraintResolution, true} - }; + ScopedFastFlag sff[] = {{FFlag::DebugLuauDeferredConstraintResolution, true}}; CheckResult result = check(R"( function oc(player, speaker) @@ -4354,8 +4366,6 @@ TEST_CASE_FIXTURE(Fixture, "mymovie_read_write_tables_bug_2") TEST_CASE_FIXTURE(BuiltinsFixture, "instantiated_metatable_frozen_table_clone_mutation") { - ScopedFastFlag luauMetatableInstantiationCloneCheck{FFlag::LuauMetatableInstantiationCloneCheck, true}; - fileResolver.source["game/worker"] = R"( type WorkerImpl = { destroy: (self: Worker) -> boolean, @@ -4533,8 +4543,24 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_literal_inference_assert") end; } )"); +} +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_table_assertion_crash") +{ + CheckResult result = check(R"( + local NexusInstance = {} + function NexusInstance:__InitMetaMethods(): () + local Metatable = {} + local OriginalIndexTable = getmetatable(self).__index + setmetatable(self, Metatable) + Metatable.__newindex = function(_, Index: string, Value: any): () + --Return if the new and old values are the same. 
+ if self[Index] == Value then + end + end + end + )"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 60903733..1d1dd999 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1536,7 +1536,7 @@ TEST_CASE_FIXTURE(Fixture, "typeof_cannot_refine_builtin_alias") freeze(arena); - (void) check(R"( + (void)check(R"( function foo(x) if typeof(x) == 'GlobalTable' then end diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index f394dc5b..c2536508 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -72,6 +72,7 @@ for _, b in pairs(c) do assert(bit32.bxor(b, b) == 0) assert(bit32.bxor(b, 0) == b) assert(bit32.bxor(b, b, b) == b) + assert(bit32.bxor(b, b, b, b) == 0) assert(bit32.bnot(b) ~= b) assert(bit32.bnot(bit32.bnot(b)) == b) assert(bit32.bnot(b) == 2^32 - 1 - b) diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 9262f4ea..b8fc882a 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -268,10 +268,33 @@ assert(math.min(1) == 1) assert(math.min(1, 2) == 1) assert(math.min(1, 2, -1) == -1) assert(math.min(1, -1, 2) == -1) +assert(math.min(1, -1, 2, -2) == -2) assert(math.max(1) == 1) assert(math.max(1, 2) == 2) assert(math.max(1, 2, -1) == 2) assert(math.max(1, -1, 2) == 2) +assert(math.max(1, -1, 2, -2) == 2) + +local ma, mb, mc, md + +assert(pcall(function() + ma = 1 + mb = -1 + mc = 2 + md = -2 +end) == true) + +-- min/max without constant-folding +assert(math.min(ma) == 1) +assert(math.min(ma, mc) == 1) +assert(math.min(ma, mc, mb) == -1) +assert(math.min(ma, mb, mc) == -1) +assert(math.min(ma, mb, mc, md) == -2) +assert(math.max(ma) == 1) +assert(math.max(ma, mc) == 2) +assert(math.max(ma, mc, mb) == 2) +assert(math.max(ma, mb, mc) == 2) +assert(math.max(ma, mb, mc, md) == 2) -- noise assert(math.noise(0.5) == 0) From 0d2688844ab285af1ef52f15878b57911c3cf056 Mon Sep 17 00:00:00 2001 From: aaron Date: Fri, 28
Jun 2024 20:34:49 -0400 Subject: [PATCH 20/20] Sync to upstream/release/632 (#1307) # What's Changed? - Fix #1137 by appropriately retaining additional metadata from definition files throughout the type system. - Improve Frontend for LSPs by appropriately allowing the cancellation of typechecking while running its destructor. ## New Solver - Added support for the `rawget` type function. - Reduced overall static memory usage of builtin type functions. - Fixed a crash where visitors could mutate a union or intersection type and fail to invalidate iteration over them in doing so. - Revised autocomplete functionality to not rely on a separate run of the type solver when using the new solver. - Implemented a more relaxed semantic rule for casting. - Fixed some smaller crashes in the new solver. ## Native Code Generation - Add additional codegen specialization for `math.sign` - Cleaned up a large number of outstanding fflags in the code. ### Internal Contributors Co-authored-by: Aaron Weiss Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: James McNellis Co-authored-by: Jeremy Yoo Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --------- Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Vighnesh Co-authored-by: Aviral Goel Co-authored-by: David Cope Co-authored-by: Lily Brown Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Frontend.h | 2 +- Analysis/include/Luau/TypeFamily.h | 5 +- Analysis/include/Luau/VisitType.h | 22 +++ Analysis/src/AstJsonEncoder.cpp | 27 ++++ Analysis/src/Autocomplete.cpp | 13 +- Analysis/src/BuiltinDefinitions.cpp | 4 +- Analysis/src/ConstraintGenerator.cpp | 106 +++++++++++---- Analysis/src/Error.cpp | 10 +- Analysis/src/Frontend.cpp | 26 +++- Analysis/src/TypeChecker2.cpp | 31 +++-- Analysis/src/TypeFamily.cpp | 56 ++++++-- Analysis/src/TypeInfer.cpp | 59 +++++++- Analysis/src/Unifier.cpp | 27 +++- Ast/include/Luau/Ast.h | 18 ++- Ast/include/Luau/TimeTrace.h | 8 
++ Ast/src/Ast.cpp | 19 ++- Ast/src/Parser.cpp | 83 +++++++++--- Ast/src/TimeTrace.cpp | 4 + CodeGen/include/Luau/IrData.h | 4 + CodeGen/include/Luau/IrUtils.h | 1 + CodeGen/src/BytecodeAnalysis.cpp | 39 ++---- CodeGen/src/CodeGenContext.cpp | 4 +- CodeGen/src/EmitBuiltinsX64.cpp | 5 + CodeGen/src/EmitCommonX64.cpp | 66 ++++----- CodeGen/src/IrBuilder.cpp | 12 +- CodeGen/src/IrDump.cpp | 2 + CodeGen/src/IrLoweringA64.cpp | 90 +++++++------ CodeGen/src/IrLoweringX64.cpp | 28 ++++ CodeGen/src/IrTranslateBuiltins.cpp | 8 +- CodeGen/src/IrTranslation.cpp | 6 +- CodeGen/src/IrUtils.cpp | 9 ++ CodeGen/src/NativeState.cpp | 1 - CodeGen/src/NativeState.h | 1 - CodeGen/src/OptimizeConstProp.cpp | 61 +++------ Common/include/Luau/Bytecode.h | 1 - Compiler/include/luacode.h | 2 +- Compiler/src/BytecodeBuilder.cpp | 185 ++++++++++++-------------- Compiler/src/Compiler.cpp | 121 +++++------------ Compiler/src/Types.cpp | 97 ++------------ VM/src/lvm.h | 1 - VM/src/lvmexecute.cpp | 155 +++------------------ VM/src/lvmutils.cpp | 134 ------------------- tests/AstJsonEncoder.test.cpp | 24 +++- tests/Autocomplete.test.cpp | 21 ++- tests/Compiler.test.cpp | 4 - tests/Conformance.test.cpp | 3 - tests/Frontend.test.cpp | 54 ++++++++ tests/Generalization.test.cpp | 75 +++++++++++ tests/IrBuilder.test.cpp | 11 +- tests/IrLowering.test.cpp | 96 +++---------- tests/Parser.test.cpp | 17 +++ tests/TypeFamily.test.cpp | 118 ++++++++++++++++ tests/TypeInfer.anyerror.test.cpp | 9 ++ tests/TypeInfer.classes.test.cpp | 26 ++++ tests/TypeInfer.definitions.test.cpp | 20 +++ tests/TypeInfer.functions.test.cpp | 22 +++ tests/TypeInfer.provisional.test.cpp | 1 - tests/TypeInfer.tables.test.cpp | 2 - tests/TypeInfer.tryUnify.test.cpp | 31 +++++ tests/TypeInfer.unknownnever.test.cpp | 11 ++ tests/conformance/math.lua | 87 ++++++++---- 61 files changed, 1243 insertions(+), 942 deletions(-) diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index fc9bc54f..27a67f40 100644 
--- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -191,7 +191,7 @@ struct Frontend void queueModuleCheck(const std::vector& names); void queueModuleCheck(const ModuleName& name); std::vector checkQueuedModules(std::optional optionOverride = {}, - std::function task)> executeTask = {}, std::function progress = {}); + std::function task)> executeTask = {}, std::function progress = {}); std::optional getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index 7c68d815..9d2182df 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -180,12 +180,11 @@ struct BuiltinTypeFamilies TypeFamily rawkeyofFamily; TypeFamily indexFamily; + TypeFamily rawgetFamily; void addToScope(NotNull arena, NotNull scope) const; }; - - -const BuiltinTypeFamilies kBuiltinTypeFamilies{}; +const BuiltinTypeFamilies& builtinTypeFunctions(); } // namespace Luau diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index ff0656d6..8c0f5ed9 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -348,16 +348,38 @@ struct GenericTypeVisitor { if (visit(ty, *utv)) { + bool unionChanged = false; for (TypeId optTy : utv->options) + { traverse(optTy); + if (!get(follow(ty))) + { + unionChanged = true; + break; + } + } + + if (unionChanged) + traverse(ty); } } else if (auto itv = get(ty)) { if (visit(ty, *itv)) { + bool intersectionChanged = false; for (TypeId partTy : itv->parts) + { traverse(partTy); + if (!get(follow(ty))) + { + intersectionChanged = true; + break; + } + } + + if (intersectionChanged) + traverse(ty); } } else if (auto ltv = get(ty)) diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index 470d69b3..3507a68f 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -8,6 +8,8 @@ 
#include +LUAU_FASTFLAG(LuauDeclarationExtraPropData) + namespace Luau { @@ -735,8 +737,21 @@ struct AstJsonEncoder : public AstVisitor void write(class AstStatDeclareFunction* node) { writeNode(node, "AstStatDeclareFunction", [&]() { + // TODO: attributes PROP(name); + + if (FFlag::LuauDeclarationExtraPropData) + PROP(nameLocation); + PROP(params); + + if (FFlag::LuauDeclarationExtraPropData) + { + PROP(paramNames); + PROP(vararg); + PROP(varargLocation); + } + PROP(retTypes); PROP(generics); PROP(genericPacks); @@ -747,6 +762,10 @@ struct AstJsonEncoder : public AstVisitor { writeNode(node, "AstStatDeclareGlobal", [&]() { PROP(name); + + if (FFlag::LuauDeclarationExtraPropData) + PROP(nameLocation); + PROP(type); }); } @@ -756,8 +775,16 @@ struct AstJsonEncoder : public AstVisitor writeRaw("{"); bool c = pushComma(); write("name", prop.name); + + if (FFlag::LuauDeclarationExtraPropData) + write("nameLocation", prop.nameLocation); + writeType("AstDeclaredClassProp"); write("luauType", prop.ty); + + if (FFlag::LuauDeclarationExtraPropData) + write("location", prop.location); + popComma(c); writeRaw("}"); } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index d6f0ab83..0dab640f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1830,12 +1830,21 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName if (!sourceModule) return {}; - ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName); + ModulePtr module; + if (FFlag::DebugLuauDeferredConstraintResolution) + module = frontend.moduleResolver.getModule(moduleName); + else + module = frontend.moduleResolverForAutocomplete.getModule(moduleName); + if (!module) return {}; NotNull builtinTypes = frontend.builtinTypes; - Scope* globalScope = frontend.globalsForAutocomplete.globalScope.get(); + Scope* globalScope; + if (FFlag::DebugLuauDeferredConstraintResolution) + globalScope = 
frontend.globals.globalScope.get(); + else + globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 2393bd2a..582d5a7d 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -216,7 +216,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC NotNull builtinTypes = globals.builtinTypes; if (FFlag::DebugLuauDeferredConstraintResolution) - kBuiltinTypeFamilies.addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()}); + builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()}); LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile( globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete); @@ -313,7 +313,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC // declare function assert(value: T, errorMessage: string?): intersect TypeId genericT = arena.addType(GenericType{"T"}); TypeId refinedTy = arena.addType(TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}}); + NotNull{&builtinTypeFunctions().intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}}); TypeId assertTy = arena.addType(FunctionType{ {genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})}); diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 7d92d9ff..f1ae5eae 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -29,6 +29,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); 
LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauAttributeSyntax); +LUAU_FASTFLAG(LuauDeclarationExtraPropData); namespace Luau { @@ -431,7 +432,7 @@ void ConstraintGenerator::computeRefinement(const ScopePtr& scope, Location loca discriminantTy = arena->addType(NegationType{discriminantTy}); if (eq) - discriminantTy = createTypeFamilyInstance(kBuiltinTypeFamilies.singletonFamily, {discriminantTy}, {}, scope, location); + discriminantTy = createTypeFamilyInstance(builtinTypeFunctions().singletonFamily, {discriminantTy}, {}, scope, location); for (const RefinementKey* key = proposition->key; key; key = key->parent) { @@ -543,7 +544,7 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat { if (mustDeferIntersection(ty) || mustDeferIntersection(dt)) { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.refineFamily, {ty, dt}, {}, scope, location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().refineFamily, {ty, dt}, {}, scope, location); ty = resultType; } @@ -1389,6 +1390,18 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas ftv->argTypes = addTypePack({classTy}, ftv->argTypes); ftv->hasSelf = true; + + if (FFlag::LuauDeclarationExtraPropData) + { + FunctionDefinition defn; + + defn.definitionModuleName = module->name; + defn.definitionLocation = prop.location; + // No data is preserved for varargLocation + defn.originalNameLocation = prop.nameLocation; + + ftv->definition = defn; + } } } @@ -1396,7 +1409,38 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas if (props.count(propName) == 0) { - props[propName] = {propTy}; + if (FFlag::LuauDeclarationExtraPropData) + props[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location}; + else + props[propName] = {propTy}; + } + else if (FFlag::LuauDeclarationExtraPropData) + { + Luau::Property& prop = 
props[propName]; + TypeId currentTy = prop.type(); + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionType* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = arena->addType(IntersectionType{std::move(options)}); + + prop.readTy = newItv; + prop.writeTy = newItv; + } + else if (get(currentTy)) + { + TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); + + prop.readTy = intersection; + prop.writeTy = intersection; + } + else + { + reportError(declaredClass->location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } } else { @@ -1453,7 +1497,18 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunc TypePackId paramPack = resolveTypePack(funScope, global->params, /* inTypeArguments */ false); TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); - TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); + + FunctionDefinition defn; + + if (FFlag::LuauDeclarationExtraPropData) + { + defn.definitionModuleName = module->name; + defn.definitionLocation = global->location; + defn.varargLocation = global->vararg ? std::make_optional(global->varargLocation) : std::nullopt; + defn.originalNameLocation = global->nameLocation; + } + + TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack, defn}); FunctionType* ftv = getMutable(fnType); ftv->isCheckedFunction = FFlag::LuauAttributeSyntax ? 
global->isCheckedFunction() : false; @@ -2032,17 +2087,17 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprUnary* unary) { case AstExprUnary::Op::Not: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.notFamily, {operandType}, {}, scope, unary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().notFamily, {operandType}, {}, scope, unary->location); return Inference{resultType, refinementArena.negation(refinement)}; } case AstExprUnary::Op::Len: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.lenFamily, {operandType}, {}, scope, unary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().lenFamily, {operandType}, {}, scope, unary->location); return Inference{resultType, refinementArena.negation(refinement)}; } case AstExprUnary::Op::Minus: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.unmFamily, {operandType}, {}, scope, unary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().unmFamily, {operandType}, {}, scope, unary->location); return Inference{resultType, refinementArena.negation(refinement)}; } default: // msvc can't prove that this is exhaustive. 
@@ -2058,74 +2113,75 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar { case AstExprBinary::Op::Add: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.addFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().addFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Sub: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.subFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().subFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Mul: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.mulFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().mulFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Div: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.divFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().divFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::FloorDiv: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.idivFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().idivFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Pow: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.powFamily, {leftType, 
rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().powFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Mod: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.modFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().modFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Concat: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.concatFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().concatFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::And: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.andFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().andFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Or: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.orFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().orFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareLt: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.ltFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().ltFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; 
} case AstExprBinary::Op::CompareGe: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.ltFamily, + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().ltFamily, {rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)` {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareLe: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.leFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().leFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareGt: { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.leFamily, + TypeId resultType = createTypeFamilyInstance( +builtinTypeFunctions().leFamily, {rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)` {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; @@ -2147,7 +2203,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar else if (rightSubscripted) rightType = makeUnion(scope, binary->location, rightType, builtinTypes->nilType); - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.eqFamily, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().eqFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Op__Count: @@ -3100,14 +3156,14 @@ TypeId ConstraintGenerator::makeUnion(const ScopePtr& scope, Location location, if (get(follow(rhs))) return lhs; - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.unionFamily, {lhs, rhs}, {}, scope, location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().unionFamily, {lhs, 
rhs}, {}, scope, location); return resultType; } TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) { - TypeId resultType = createTypeFamilyInstance(kBuiltinTypeFamilies.intersectFamily, {lhs, rhs}, {}, scope, location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().intersectFamily, {lhs, rhs}, {}, scope, location); return resultType; } @@ -3225,7 +3281,7 @@ void ConstraintGenerator::fillInInferredBindings(const ScopePtr& globalScope, As scope->bindings[symbol] = Binding{tys.front(), location}; else { - TypeId ty = createTypeFamilyInstance(kBuiltinTypeFamilies.unionFamily, std::move(tys), {}, globalScope, location); + TypeId ty = createTypeFamilyInstance(builtinTypeFunctions().unionFamily, std::move(tys), {}, globalScope, location); scope->bindings[symbol] = Binding{ty, location}; } diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index cb8ef20d..5a9e42a7 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -661,14 +661,14 @@ struct ErrorConverter return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; } - if ("index" == tfit->family->name) + if ("index" == tfit->family->name || "rawget" == tfit->family->name) { if (tfit->typeArguments.size() != 2) return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; - if (auto errType = get(tfit->typeArguments[1])) // Second argument to index<_,_> is not a type - return "Second argument to index<" + Luau::toString(tfit->typeArguments[0]) + ", _> is not a valid index type"; - else // Second argument to index<_,_> is not a property of the first argument + if (auto errType = get(tfit->typeArguments[1])) // Second argument to (index | rawget)<_,_> is not a type + return "Second argument to " + tfit->family->name + "<" + Luau::toString(tfit->typeArguments[0]) + ", _> is not a valid index type"; + else // Property `indexer` does not exist on type `indexee` 
return "Property '" + Luau::toString(tfit->typeArguments[1]) + "' does not exist on type '" + Luau::toString(tfit->typeArguments[0]) + "'"; } @@ -1321,7 +1321,7 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState) else if constexpr (std::is_same_v) { e.recommendedReturn = clone(e.recommendedReturn); - for (auto [_, t] : e.recommendedArgs) + for (auto& [_, t] : e.recommendedArgs) t = clone(t); } else if constexpr (std::is_same_v) diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 618a9a9c..4339960d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -34,6 +34,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) +LUAU_FASTFLAGVARIABLE(LuauCancelFromProgress, false) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile, false) @@ -440,6 +441,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optional result = getCheckResult(name, true, frontendOptions.forAutocomplete)) return std::move(*result); @@ -492,9 +495,11 @@ void Frontend::queueModuleCheck(const ModuleName& name) } std::vector Frontend::checkQueuedModules(std::optional optionOverride, - std::function task)> executeTask, std::function progress) + std::function task)> executeTask, std::function progress) { FrontendOptions frontendOptions = optionOverride.value_or(options); + if (FFlag::DebugLuauDeferredConstraintResolution) + frontendOptions.forAutocomplete = false; // By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown std::vector currModuleQueue; @@ -673,7 +678,17 @@ std::vector Frontend::checkQueuedModules(std::optional Frontend::checkQueuedModules(std::optional Frontend::getCheckResult(const ModuleName& name, bool accumulateNested, bool 
forAutocomplete) { + if (FFlag::DebugLuauDeferredConstraintResolution) + forAutocomplete = false; + auto it = sourceNodes.find(name); if (it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete)) @@ -1006,9 +1024,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) module->astCompoundAssignResultTypes.clear(); module->astScopes.clear(); module->upperBoundContributors.clear(); - - if (!FFlag::DebugLuauDeferredConstraintResolution) - module->scopes.clear(); + module->scopes.clear(); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index fe0bf2dd..c53a5d30 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -2138,24 +2138,39 @@ struct TypeChecker2 TypeId annotationType = lookupAnnotation(expr->annotation); TypeId computedType = lookupType(expr->expr); - // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (subtyping->isSubtype(annotationType, computedType).isSubtype) - return; - - if (subtyping->isSubtype(computedType, annotationType).isSubtype) - return; - switch (shouldSuppressErrors(NotNull{&normalizer}, computedType).orElse(shouldSuppressErrors(NotNull{&normalizer}, annotationType))) { case ErrorSuppression::Suppress: return; case ErrorSuppression::NormalizationFailed: reportError(NormalizationTooComplex{}, expr->location); + return; case ErrorSuppression::DoNotSuppress: break; } - reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); + switch (normalizer.isInhabited(computedType)) + { + case NormalizationResult::True: + break; + case NormalizationResult::False: + return; + case NormalizationResult::HitLimits: + reportError(NormalizationTooComplex{}, expr->location); + return; + } + + switch (normalizer.isIntersectionInhabited(computedType, annotationType)) + { + case NormalizationResult::True: + return; + case NormalizationResult::False: + 
reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); + break; + case NormalizationResult::HitLimits: + reportError(NormalizationTooComplex{}, expr->location); + break; + } } void visit(AstExprIfElse* expr) diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 54d89a15..816cf005 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -180,7 +180,10 @@ struct FamilyReducer void replace(T subject, T replacement) { if (subject->owningArena != ctx.arena.get()) - ctx.ice->ice("Attempting to modify a type family instance from another arena", location); + { + result.errors.emplace_back(location, InternalError{"Attempting to modify a type family instance from another arena"}); + return; + } if (FFlag::DebugLuauLogTypeFamilies) printf("%s -> %s\n", toString(subject, {true}).c_str(), toString(replacement, {true}).c_str()); @@ -514,7 +517,7 @@ static std::optional> tryDistributeTypeFamilyA return {{results[0], false, {}, {}}}; TypeId resultTy = ctx->arena->addType(TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, + NotNull{&builtinTypeFunctions().unionFamily}, std::move(results), {}, }); @@ -1957,15 +1960,9 @@ bool tblIndexInto(TypeId indexer, TypeId indexee, DenseHashSet& result, /* Vocabulary note: indexee refers to the type that contains the properties, indexer refers to the type that is used to access indexee Example: index => `Person` is the indexee and `"name"` is the indexer */ -TypeFamilyReductionResult indexFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult indexFamilyImpl( + const std::vector& typeParams, const std::vector& packParams, NotNull ctx, bool isRaw) { - if (typeParams.size() != 2 || !packParams.empty()) - { - ctx->ice->ice("index type family: encountered a type family instance without the required argument structure"); - LUAU_ASSERT(false); - } - TypeId indexeeTy = 
follow(typeParams.at(0)); std::shared_ptr indexeeNormTy = ctx->normalizer->normalize(indexeeTy); @@ -2003,12 +2000,14 @@ TypeFamilyReductionResult indexFamilyFn( typesToFind = &singleType; DenseHashSet properties{{}}; // vector of types that will be returned - bool isRaw = false; if (indexeeNormTy->hasClasses()) { LUAU_ASSERT(!indexeeNormTy->hasTables()); + if (isRaw) // rawget should never reduce for classes (to match the behavior of the rawget global function) + return {std::nullopt, true, {}, {}}; + // at least one class is guaranteed to be in the iterator by .hasClasses() for (auto classesIter = indexeeNormTy->classes.ordering.begin(); classesIter != indexeeNormTy->classes.ordering.end(); ++classesIter) { @@ -2021,7 +2020,7 @@ TypeFamilyReductionResult indexFamilyFn( for (TypeId ty : *typesToFind) { - // Search for all instances of indexer in class->props and class->indexer using `indexInto` + // Search for all instances of indexer in class->props and class->indexer if (searchPropsAndIndexer(ty, classTy->props, classTy->indexer, properties, ctx)) continue; // Indexer was found in this class, so we can move on to the next @@ -2065,6 +2064,30 @@ TypeFamilyReductionResult indexFamilyFn( return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), false, {}, {}}; } +TypeFamilyReductionResult indexFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("index type family: encountered a type family instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return indexFamilyImpl(typeParams, packParams, ctx, /* isRaw */ false); +} + +TypeFamilyReductionResult rawgetFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("rawget type family: encountered a type family 
instance without the required argument structure"); + LUAU_ASSERT(false); + } + + return indexFamilyImpl(typeParams, packParams, ctx, /* isRaw */ true); +} + BuiltinTypeFamilies::BuiltinTypeFamilies() : notFamily{"not", notFamilyFn} , lenFamily{"len", lenFamilyFn} @@ -2089,6 +2112,7 @@ BuiltinTypeFamilies::BuiltinTypeFamilies() , keyofFamily{"keyof", keyofFamilyFn} , rawkeyofFamily{"rawkeyof", rawkeyofFamilyFn} , indexFamily{"index", indexFamilyFn} + , rawgetFamily{"rawget", rawgetFamilyFn} { } @@ -2132,6 +2156,14 @@ void BuiltinTypeFamilies::addToScope(NotNull arena, NotNull sc scope->exportedTypeBindings[rawkeyofFamily.name] = mkUnaryTypeFamily(&rawkeyofFamily); scope->exportedTypeBindings[indexFamily.name] = mkBinaryTypeFamily(&indexFamily); + scope->exportedTypeBindings[rawgetFamily.name] = mkBinaryTypeFamily(&rawgetFamily); +} + +const BuiltinTypeFamilies& builtinTypeFunctions() +{ + static std::unique_ptr result = std::make_unique(); + + return *result; } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 9ce1a58a..d4c25c34 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -38,6 +38,7 @@ LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) LUAU_FASTFLAGVARIABLE(LuauReusableSubstitutions, false) +LUAU_FASTFLAG(LuauDeclarationExtraPropData) namespace Luau { @@ -1783,12 +1784,55 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); ftv->hasSelf = true; + + if (FFlag::LuauDeclarationExtraPropData) + { + FunctionDefinition defn; + + defn.definitionModuleName = currentModule->name; + defn.definitionLocation = prop.location; + // No data is preserved for varargLocation + 
defn.originalNameLocation = prop.nameLocation; + + ftv->definition = defn; + } } } if (assignTo.count(propName) == 0) { - assignTo[propName] = {propTy}; + if (FFlag::LuauDeclarationExtraPropData) + assignTo[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location}; + else + assignTo[propName] = {propTy}; + } + else if (FFlag::LuauDeclarationExtraPropData) + { + Luau::Property& prop = assignTo[propName]; + TypeId currentTy = prop.type(); + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionType* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionType{std::move(options)}); + + prop.readTy = newItv; + prop.writeTy = newItv; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionType{{currentTy, propTy}}); + + prop.readTy = intersection; + prop.writeTy = intersection; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } } else { @@ -1840,7 +1884,18 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFuncti TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); + + FunctionDefinition defn; + + if (FFlag::LuauDeclarationExtraPropData) + { + defn.definitionModuleName = currentModule->name; + defn.definitionLocation = global.location; + defn.varargLocation = global.vararg ? 
std::make_optional(global.varargLocation) : std::nullopt; + defn.originalNameLocation = global.nameLocation; + } + + TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, defn}); FunctionType* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index a0c802dd..1802345d 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false) +LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart, false) namespace Luau { @@ -2179,7 +2180,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, // If one of the types stopped being a table altogether, we need to restart from the top if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); + { + if (FFlag::LuauUnifierRecursionOnRestart) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnify(subTy, superTy, false, isIntersection); + return; + } + else + { + return tryUnify(subTy, superTy, false, isIntersection); + } + } // Otherwise, restart only the table unification TableType* newSuperTable = log.getMutable(superTyNew); @@ -2258,7 +2270,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, // If one of the types stopped being a table altogether, we need to restart from the top if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); + { + if (FFlag::LuauUnifierRecursionOnRestart) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, 
sharedState.counters.recursionLimit); + tryUnify(subTy, superTy, false, isIntersection); + return; + } + else + { + return tryUnify(subTy, superTy, false, isIntersection); + } + } // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index ab0d40e2..e2ac8b7d 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -826,11 +826,12 @@ class AstStatDeclareGlobal : public AstStat public: LUAU_RTTI(AstStatDeclareGlobal) - AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type); + AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type); void visit(AstVisitor* visitor) override; AstName name; + Location nameLocation; AstType* type; }; @@ -839,13 +840,13 @@ class AstStatDeclareFunction : public AstStat public: LUAU_RTTI(AstStatDeclareFunction) - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes); + AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, + const Location& varargLocation, const AstTypeList& retTypes); - AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, + AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, const Location& nameLocation, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, - const AstArray& paramNames, const AstTypeList& retTypes); + const AstArray& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes); void 
visit(AstVisitor* visitor) override; @@ -854,18 +855,23 @@ public: AstArray attributes; AstName name; + Location nameLocation; AstArray generics; AstArray genericPacks; AstTypeList params; AstArray paramNames; + bool vararg = false; + Location varargLocation; AstTypeList retTypes; }; struct AstDeclaredClassProp { AstName name; + Location nameLocation; AstType* ty = nullptr; bool isMethod = false; + Location location; }; enum class AstTableAccess diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 2f7daf2c..bd2ca86b 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -134,6 +134,14 @@ struct ThreadContext static constexpr size_t kEventFlushLimit = 8192; }; +using ThreadContextProvider = ThreadContext& (*)(); + +inline ThreadContextProvider& threadContextProvider() +{ + static ThreadContextProvider handler = nullptr; + return handler; +} + ThreadContext& getThreadContext(); struct Scope diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 14b79767..a3e53af5 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -705,9 +705,10 @@ void AstStatTypeAlias::visit(AstVisitor* visitor) } } -AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) +AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type) : AstStat(ClassIndex(), location) , name(name) + , nameLocation(nameLocation) , type(type) { } @@ -718,30 +719,36 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor) type->visit(visitor); } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation, + const AstArray& generics, const AstArray& 
genericPacks, const AstTypeList& params, + const AstArray& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) , attributes() , name(name) + , nameLocation(nameLocation) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) + , vararg(vararg) + , varargLocation(varargLocation) , retTypes(retTypes) { } AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, - const AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, - const AstArray& paramNames, const AstTypeList& retTypes) + const Location& nameLocation, const AstArray& generics, const AstArray& genericPacks, + const AstTypeList& params, const AstArray& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) , attributes(attributes) , name(name) + , nameLocation(nameLocation) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) + , vararg(vararg) + , varargLocation(varargLocation) , retTypes(retTypes) { } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 3a6625a5..87af53cb 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -21,6 +21,7 @@ LUAU_FASTFLAG(LuauAttributeSyntax) LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand2, false) LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) +LUAU_FASTFLAGVARIABLE(LuauDeclarationExtraPropData, false) namespace Luau { @@ -909,8 +910,16 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) AstDeclaredClassProp Parser::parseDeclaredClassMethod() { + Location start; + + if (FFlag::LuauDeclarationExtraPropData) + start = lexer.current().location; + nextLexeme(); - Location start = lexer.current().location; + + if (!FFlag::LuauDeclarationExtraPropData) + start = lexer.current().location; + Name 
fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 @@ -935,15 +944,15 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() expectMatchAndConsume(')', matchParen); AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0), nullptr}); - Location end = lexer.current().location; + Location end = FFlag::LuauDeclarationExtraPropData ? lexer.previousLocation() : lexer.current().location; TempVector vars(scratchType); TempVector> varNames(scratchOptArgName); if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { - return AstDeclaredClassProp{ - fnName.name, reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; + return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, + reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; } // Skip the first index. @@ -963,7 +972,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() AstType* fnType = allocator.alloc( Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); - return AstDeclaredClassProp{fnName.name, fnType, true}; + return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, fnType, true, + FFlag::LuauDeclarationExtraPropData ? 
Location(start, end) : Location{}}; } AstStat* Parser::parseDeclaration(const Location& start, const AstArray& attributes) @@ -1014,8 +1024,12 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray(Location(start, end), attributes, globalName.name, generics, genericPacks, - AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); + if (FFlag::LuauDeclarationExtraPropData) + return allocator.alloc(Location(start, end), attributes, globalName.name, globalName.location, generics, + genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), vararg, varargLocation, retTypes); + else + return allocator.alloc(Location(start, end), attributes, globalName.name, Location{}, generics, genericPacks, + AstTypeList{copy(vars), varargAnnotation}, copy(varNames), false, Location{}, retTypes); } else if (AstName(lexer.current().name) == "class") { @@ -1045,19 +1059,42 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray> chars = parseCharArray(); + if (FFlag::LuauDeclarationExtraPropData) + { + const Location nameBegin = lexer.current().location; + std::optional> chars = parseCharArray(); - expectMatchAndConsume(']', begin); - expectAndConsume(':', "property type annotation"); - AstType* type = parseType(); + const Location nameEnd = lexer.previousLocation(); - // since AstName contains a char*, it can't contain null - bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + expectMatchAndConsume(']', begin); + expectAndConsume(':', "property type annotation"); + AstType* type = parseType(); - if (chars && !containsNull) - props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false}); + // since AstName contains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back(AstDeclaredClassProp{ + AstName(chars->data), Location(nameBegin, nameEnd), type, false, 
Location(begin.location, lexer.previousLocation())}); + else + report(begin.location, "String literal contains malformed escape sequence or \\0"); + } else - report(begin.location, "String literal contains malformed escape sequence or \\0"); + { + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "property type annotation"); + AstType* type = parseType(); + + // since AstName contains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back(AstDeclaredClassProp{AstName(chars->data), Location{}, type, false}); + else + report(begin.location, "String literal contains malformed escape sequence or \\0"); + } } else if (lexer.current().type == '[') { @@ -1075,12 +1112,21 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray(Location(start, type->location), globalName->name, type); + return allocator.alloc( + Location(start, type->location), globalName->name, FFlag::LuauDeclarationExtraPropData ? 
globalName->location : Location{}, type); } else { diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index bc3f3538..cfcf9ce2 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -250,6 +250,10 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector -LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false) -LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false) LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false) LUAU_FASTFLAGVARIABLE(LuauCodegenFastcall3, false) @@ -72,11 +70,6 @@ void loadBytecodeTypeInfo(IrFunction& function) uint32_t upvalCount = readVarInt(data, offset); uint32_t localCount = readVarInt(data, offset); - if (!FFlag::LuauCodegenLoadTypeUpvalCheck) - { - CODEGEN_ASSERT(upvalCount == unsigned(proto->nups)); - } - if (typeSize != 0) { uint8_t* types = (uint8_t*)data + offset; @@ -94,10 +87,7 @@ void loadBytecodeTypeInfo(IrFunction& function) if (upvalCount != 0) { - if (FFlag::LuauCodegenLoadTypeUpvalCheck) - { - CODEGEN_ASSERT(upvalCount == unsigned(proto->nups)); - } + CODEGEN_ASSERT(upvalCount == unsigned(proto->nups)); typeInfo.upvalueTypes.resize(upvalCount); @@ -775,7 +765,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; } - if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + if (regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); } else if (isCustomUserdataBytecodeType(bcType.a)) @@ -800,7 +790,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; } - if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + if (regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); } } @@ -1218,14 
+1208,14 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) TString* str = gco2ts(function.proto->k[kc].value.gc); const char* field = getstr(str); - if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + if (bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType) knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len)); } else { - if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + if (bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) { TString* str = gco2ts(function.proto->k[kc].value.gc); const char* field = getstr(str); @@ -1237,21 +1227,18 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) } case LOP_CALL: { - if (FFlag::LuauCodegenAnalyzeHostVectorOps) + int ra = LUAU_INSN_A(*pc); + + if (knownNextCallResult != LBC_TYPE_ANY) { - int ra = LUAU_INSN_A(*pc); + bcType.result = knownNextCallResult; - if (knownNextCallResult != LBC_TYPE_ANY) - { - bcType.result = knownNextCallResult; + knownNextCallResult = LBC_TYPE_ANY; - knownNextCallResult = LBC_TYPE_ANY; - - regTags[ra] = bcType.result; - } - - refineRegType(bcTypeInfo, ra, i, bcType.result); + regTags[ra] = bcType.result; } + + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_GETUPVAL: diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index 67a2676e..a31a08ba 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -12,8 +12,6 @@ #include "lapi.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false) - LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) 
LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) LUAU_FASTFLAG(LuauNativeAttribute) @@ -360,7 +358,7 @@ static size_t getMemorySize(lua_State* L, Proto* proto) static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeGenContext) noexcept { - CODEGEN_ASSERT(!FFlag::LuauCodegenCheckNullContext || codeGenContext != nullptr); + CODEGEN_ASSERT(codeGenContext != nullptr); lua_ExecutionCallbacks* ecb = &L->global->ecb; diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 09f69d69..15aab4b6 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -12,6 +12,8 @@ #include "lstate.h" +LUAU_FASTFLAG(LuauCodegenMathSign) + namespace Luau { namespace CodeGen @@ -57,6 +59,8 @@ static void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) { + CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); + ScopedRegX64 tmp0{regs, SizeX64::xmmword}; ScopedRegX64 tmp1{regs, SizeX64::xmmword}; ScopedRegX64 tmp2{regs, SizeX64::xmmword}; @@ -94,6 +98,7 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r CODEGEN_ASSERT(nresults == 1 || nresults == 2); return emitBuiltinMathModf(regs, build, ra, arg, nresults); case LBF_MATH_SIGN: + CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); CODEGEN_ASSERT(nresults == 1); return emitBuiltinMathSign(regs, build, ra, arg); default: diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 50f2208b..79562b88 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -14,8 +14,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauCodegenSplitDoarith, false) - namespace Luau { namespace CodeGen @@ -158,43 +156,35 @@ void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, Ope callWrap.addArgument(SizeX64::qword, b); callWrap.addArgument(SizeX64::qword, c); - if 
(FFlag::LuauCodegenSplitDoarith) + switch (tm) { - switch (tm) - { - case TM_ADD: - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithadd)]); - break; - case TM_SUB: - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithsub)]); - break; - case TM_MUL: - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmul)]); - break; - case TM_DIV: - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithdiv)]); - break; - case TM_IDIV: - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithidiv)]); - break; - case TM_MOD: - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmod)]); - break; - case TM_POW: - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithpow)]); - break; - case TM_UNM: - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithunm)]); - break; - default: - CODEGEN_ASSERT(!"Invalid doarith helper operation tag"); - break; - } - } - else - { - callWrap.addArgument(SizeX64::dword, tm); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); + case TM_ADD: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithadd)]); + break; + case TM_SUB: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithsub)]); + break; + case TM_MUL: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmul)]); + break; + case TM_DIV: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithdiv)]); + break; + case TM_IDIV: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithidiv)]); + break; + case TM_MOD: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmod)]); + break; + case TM_POW: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithpow)]); + break; + case TM_UNM: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithunm)]); + 
break; + default: + CODEGEN_ASSERT(!"Invalid doarith helper operation tag"); + break; } emitUpdateBase(build); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 672c27ad..1f4342f6 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -13,7 +13,6 @@ #include -LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauLoadUserdataInfo) LUAU_FASTFLAG(LuauCodegenInstG) LUAU_FASTFLAG(LuauCodegenFastcall3) @@ -534,15 +533,8 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstCapture(*this, pc, i); break; case LOP_NAMECALL: - if (FFlag::LuauCodegenAnalyzeHostVectorOps) - { - if (translateInstNamecall(*this, pc, i)) - cmdSkipTarget = i + 3; - } - else - { - translateInstNamecall(*this, pc, i); - } + if (translateInstNamecall(*this, pc, i)) + cmdSkipTarget = i + 3; break; case LOP_PREPVARARGS: inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc))); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 5465d0a0..c4114d89 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -154,6 +154,8 @@ const char* getCmdName(IrCmd cmd) return "SQRT_NUM"; case IrCmd::ABS_NUM: return "ABS_NUM"; + case IrCmd::SIGN_NUM: + return "SIGN_NUM"; case IrCmd::ADD_VEC: return "ADD_VEC"; case IrCmd::SUB_VEC: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 5b333374..ef51a4b1 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -11,11 +11,11 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAG(LuauCodegenSplitDoarith) LUAU_FASTFLAG(LuauCodegenUserdataOps) LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false) LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false) LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAG(LuauCodegenMathSign) namespace Luau { @@ -240,6 +240,7 @@ static bool emitBuiltin(AssemblyBuilderA64& build, IrFunction& function, IrRegAl } case LBF_MATH_SIGN: { + 
CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); CODEGEN_ASSERT(nresults == 1); build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); build.fcmpz(d0); @@ -697,6 +698,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fabs(inst.regA64, temp); break; } + case IrCmd::SIGN_NUM: + { + CODEGEN_ASSERT(FFlag::LuauCodegenMathSign); + + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + + RegisterA64 temp = tempDouble(inst.a); + RegisterA64 temp0 = regs.allocTemp(KindA64::d); + RegisterA64 temp1 = regs.allocTemp(KindA64::d); + + build.fcmpz(temp); + build.fmov(temp0, 0.0); + build.fmov(temp1, 1.0); + build.fcsel(inst.regA64, temp1, temp0, getConditionFP(IrCondition::Greater)); + build.fmov(temp1, -1.0); + build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less)); + break; + } case IrCmd::ADD_VEC: { inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b}); @@ -1283,47 +1302,38 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) else build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); - if (FFlag::LuauCodegenSplitDoarith) + switch (TMS(intOp(inst.d))) { - switch (TMS(intOp(inst.d))) - { - case TM_ADD: - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithadd))); - break; - case TM_SUB: - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithsub))); - break; - case TM_MUL: - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmul))); - break; - case TM_DIV: - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithdiv))); - break; - case TM_IDIV: - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithidiv))); - break; - case TM_MOD: - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmod))); - break; - case TM_POW: - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithpow))); - break; - case TM_UNM: - 
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithunm))); - break; - default: - CODEGEN_ASSERT(!"Invalid doarith helper operation tag"); - break; - } + case TM_ADD: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithadd))); + break; + case TM_SUB: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithsub))); + break; + case TM_MUL: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmul))); + break; + case TM_DIV: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithdiv))); + break; + case TM_IDIV: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithidiv))); + break; + case TM_MOD: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmod))); + break; + case TM_POW: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithpow))); + break; + case TM_UNM: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithunm))); + break; + default: + CODEGEN_ASSERT(!"Invalid doarith helper operation tag"); + break; + } - build.blr(x4); - } - else - { - build.mov(w4, TMS(intOp(inst.d))); - build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith))); - build.blr(x5); - } + build.blr(x4); emitUpdateBase(build); break; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 5128dce5..f372a7ec 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -18,6 +18,7 @@ LUAU_FASTFLAG(LuauCodegenUserdataOps) LUAU_FASTFLAG(LuauCodegenUserdataAlloc) LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAG(LuauCodegenMathSign) namespace Luau { @@ -590,6 +591,33 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vandpd(inst.regX64, inst.regX64, build.i64(~(1LL << 63))); break; + case IrCmd::SIGN_NUM: + { + CODEGEN_ASSERT(FFlag::LuauCodegenMathSign); + + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); + + 
ScopedRegX64 tmp0{regs, SizeX64::xmmword}; + ScopedRegX64 tmp1{regs, SizeX64::xmmword}; + ScopedRegX64 tmp2{regs, SizeX64::xmmword}; + + build.vxorpd(tmp0.reg, tmp0.reg, tmp0.reg); + + // Set tmp1 to -1 if arg < 0, else 0 + build.vcmpltsd(tmp1.reg, regOp(inst.a), tmp0.reg); + build.vmovsd(tmp2.reg, build.f64(-1)); + build.vandpd(tmp1.reg, tmp1.reg, tmp2.reg); + + // Set mask bit to 1 if 0 < arg, else 0 + build.vcmpltsd(inst.regX64, tmp0.reg, regOp(inst.a)); + + // Result = (mask-bit == 1) ? 1.0 : tmp1 + // If arg < 0 then tmp1 is -1 and mask-bit is 0, result is -1 + // If arg == 0 then tmp1 is 0 and mask-bit is 0, result is 0 + // If arg > 0 then tmp1 is 0 and mask-bit is 1, result is 1 + build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64); + break; + } case IrCmd::ADD_VEC: { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 668bdfe0..f6a77f21 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -9,6 +9,7 @@ #include LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAGVARIABLE(LuauCodegenMathSign, false) // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results @@ -42,6 +43,8 @@ static IrOp builtinLoadDouble(IrBuilder& build, IrOp arg) static BuiltinImplResult translateBuiltinNumberToNumber( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { + CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); + if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -845,7 +848,10 @@ BuiltinImplResult translateBuiltin( case LBF_MATH_LOG10: return translateBuiltinNumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, nresults, pcpos); case LBF_MATH_SIGN: - return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); + if 
(FFlag::LuauCodegenMathSign) + return translateBuiltinMathUnary(build, IrCmd::SIGN_NUM, nparams, ra, arg, nresults, pcpos); + else + return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_POW: case LBF_MATH_FMOD: case LBF_MATH_ATAN2: diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index e06f14f8..db867fc9 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -13,7 +13,6 @@ #include "lstate.h" #include "ltm.h" -LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauCodegenUserdataOps) LUAU_FASTFLAG(LuauCodegenFastcall3) @@ -1285,8 +1284,7 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) } else { - if (FFlag::LuauCodegenAnalyzeHostVectorOps && build.hostHooks.vectorAccess && - build.hostHooks.vectorAccess(build, field, str->len, ra, rb, pcpos)) + if (build.hostHooks.vectorAccess && build.hostHooks.vectorAccess(build, field, str->len, ra, rb, pcpos)) return; build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); @@ -1468,7 +1466,7 @@ bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) { build.loadAndCheckTag(build.vmReg(rb), LUA_TVECTOR, build.vmExit(pcpos)); - if (FFlag::LuauCodegenAnalyzeHostVectorOps && build.hostHooks.vectorNamecall) + if (build.hostHooks.vectorNamecall) { Instruction call = pc[2]; CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 2244c4d3..129945df 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -69,6 +69,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::ROUND_NUM: case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: + case IrCmd::SIGN_NUM: return IrValueKind::Double; case IrCmd::ADD_VEC: case IrCmd::SUB_VEC: @@ -658,6 +659,14 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, 
uint3 if (inst.a.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(fabs(function.doubleOp(inst.a)))); break; + case IrCmd::SIGN_NUM: + if (inst.a.kind == IrOpKind::Constant) + { + double v = function.doubleOp(inst.a); + + substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0)); + } + break; case IrCmd::NOT_ANY: if (inst.a.kind == IrOpKind::Constant) { diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 248f0cd3..7aa35f23 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -29,7 +29,6 @@ void initFunctions(NativeContext& context) context.luaV_lessthan = luaV_lessthan; context.luaV_lessequal = luaV_lessequal; context.luaV_equalval = luaV_equalval; - context.luaV_doarith = luaV_doarith; context.luaV_doarithadd = luaV_doarithimpl; context.luaV_doarithsub = luaV_doarithimpl; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index be73815d..941db252 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -33,7 +33,6 @@ struct NativeContext int (*luaV_lessthan)(lua_State* L, const TValue* l, const TValue* r) = nullptr; int (*luaV_lessequal)(lua_State* L, const TValue* l, const TValue* r) = nullptr; int (*luaV_equalval)(lua_State* L, const TValue* t1, const TValue* t2) = nullptr; - void (*luaV_doarith)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op) = nullptr; void (*luaV_doarithadd)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; void (*luaV_doarithsub)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; void (*luaV_doarithmul)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 0cd2aa51..ac90f8e5 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -18,10 +18,10 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) 
LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) -LUAU_FASTFLAGVARIABLE(LuauCodegenFixSplitStoreConstMismatch, false) LUAU_FASTFLAG(LuauCodegenUserdataOps) LUAU_FASTFLAG(LuauCodegenUserdataAlloc) LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAG(LuauCodegenMathSign) namespace Luau { @@ -757,48 +757,29 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } - if (FFlag::LuauCodegenFixSplitStoreConstMismatch) + // If we have constant tag and value, replace TValue store with tag/value pair store + bool canSplitTvalueStore = false; + + if (tag == LUA_TBOOLEAN && + (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int))) + canSplitTvalueStore = true; + else if (tag == LUA_TNUMBER && + (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) + canSplitTvalueStore = true; + else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst) + canSplitTvalueStore = true; + + if (canSplitTvalueStore) { - // If we have constant tag and value, replace TValue store with tag/value pair store - bool canSplitTvalueStore = false; + replace(function, block, index, {IrCmd::STORE_SPLIT_TVALUE, inst.a, build.constTag(tag), value, inst.c}); - if (tag == LUA_TBOOLEAN && - (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int))) - canSplitTvalueStore = true; - else if (tag == LUA_TNUMBER && - (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) - canSplitTvalueStore = true; - else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst) - canSplitTvalueStore = true; - - if (canSplitTvalueStore) - { - replace(function, block, index, {IrCmd::STORE_SPLIT_TVALUE, inst.a, 
build.constTag(tag), value, inst.c}); - - // Value can be propagated to future loads of the same register - if (inst.a.kind == IrOpKind::VmReg && activeLoadValue != kInvalidInstIdx) - state.valueMap[state.versionedVmRegLoad(activeLoadCmd, inst.a)] = activeLoadValue; - } - else if (inst.a.kind == IrOpKind::VmReg) - { - state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); - } + // Value can be propagated to future loads of the same register + if (inst.a.kind == IrOpKind::VmReg && activeLoadValue != kInvalidInstIdx) + state.valueMap[state.versionedVmRegLoad(activeLoadCmd, inst.a)] = activeLoadValue; } - else + else if (inst.a.kind == IrOpKind::VmReg) { - // If we have constant tag and value, replace TValue store with tag/value pair store - if (tag != 0xff && value.kind != IrOpKind::None && (tag == LUA_TBOOLEAN || tag == LUA_TNUMBER || isGCO(tag))) - { - replace(function, block, index, {IrCmd::STORE_SPLIT_TVALUE, inst.a, build.constTag(tag), value, inst.c}); - - // Value can be propagated to future loads of the same register - if (inst.a.kind == IrOpKind::VmReg && activeLoadValue != kInvalidInstIdx) - state.valueMap[state.versionedVmRegLoad(activeLoadCmd, inst.a)] = activeLoadValue; - } - else if (inst.a.kind == IrOpKind::VmReg) - { - state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); - } + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); } } break; @@ -1160,6 +1141,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg + 1)}, LUA_TNUMBER); break; case LBF_MATH_SIGN: + CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg)}, LUA_TNUMBER); break; default: @@ -1225,6 +1207,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::ROUND_NUM: case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: + case IrCmd::SIGN_NUM: case IrCmd::NOT_ANY: state.substituteOrRecord(inst, index); break; 
diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index f971391b..604b8b86 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -443,7 +443,6 @@ enum LuauBytecodeTag LBC_VERSION_MAX = 6, LBC_VERSION_TARGET = 5, // Type encoding version - LBC_TYPE_VERSION_DEPRECATED = 1, LBC_TYPE_VERSION_MIN = 1, LBC_TYPE_VERSION_MAX = 3, LBC_TYPE_VERSION_TARGET = 3, diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h index 1d200817..1440a699 100644 --- a/Compiler/include/luacode.h +++ b/Compiler/include/luacode.h @@ -44,7 +44,7 @@ struct lua_CompileOptions const char* const* mutableGlobals; // null-terminated array of userdata types that will be included in the type information - const char* const* userdataTypes = nullptr; + const char* const* userdataTypes; }; // compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index fac740c2..f68884c5 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -7,7 +7,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileTypeInfo, false) LUAU_FASTFLAG(LuauCompileUserdataInfo) LUAU_FASTFLAG(LuauCompileFastcall3) @@ -283,11 +282,8 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues, uin debugLocals.clear(); debugUpvals.clear(); - if (FFlag::LuauCompileTypeInfo) - { - typedLocals.clear(); - typedUpvals.clear(); - } + typedLocals.clear(); + typedUpvals.clear(); constantMap.clear(); tableShapeMap.clear(); @@ -559,8 +555,6 @@ void BytecodeBuilder::setFunctionTypeInfo(std::string value) void BytecodeBuilder::pushLocalTypeInfo(LuauBytecodeType type, uint8_t reg, uint32_t startpc, uint32_t endpc) { - LUAU_ASSERT(FFlag::LuauCompileTypeInfo); - TypedLocal local; local.type = type; local.reg = reg; @@ -572,8 +566,6 @@ void 
BytecodeBuilder::pushLocalTypeInfo(LuauBytecodeType type, uint8_t reg, uint void BytecodeBuilder::pushUpvalTypeInfo(LuauBytecodeType type) { - LUAU_ASSERT(FFlag::LuauCompileTypeInfo); - TypedUpval upval; upval.type = type; @@ -712,7 +704,7 @@ void BytecodeBuilder::finalize() writeStringTable(bytecode); - if (FFlag::LuauCompileTypeInfo && FFlag::LuauCompileUserdataInfo) + if (FFlag::LuauCompileUserdataInfo) { // Write the mapping between used type name indices and their name for (uint32_t i = 0; i < uint32_t(userdataTypes.size()); i++) @@ -747,42 +739,34 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id, uint8_t flags) writeByte(ss, flags); - if (FFlag::LuauCompileTypeInfo) + if (!func.typeinfo.empty() || !typedUpvals.empty() || !typedLocals.empty()) { - if (!func.typeinfo.empty() || !typedUpvals.empty() || !typedLocals.empty()) + // collect type info into a temporary string to know the overall size of type data + tempTypeInfo.clear(); + writeVarInt(tempTypeInfo, uint32_t(func.typeinfo.size())); + writeVarInt(tempTypeInfo, uint32_t(typedUpvals.size())); + writeVarInt(tempTypeInfo, uint32_t(typedLocals.size())); + + tempTypeInfo.append(func.typeinfo); + + for (const TypedUpval& l : typedUpvals) + writeByte(tempTypeInfo, l.type); + + for (const TypedLocal& l : typedLocals) { - // collect type info into a temporary string to know the overall size of type data - tempTypeInfo.clear(); - writeVarInt(tempTypeInfo, uint32_t(func.typeinfo.size())); - writeVarInt(tempTypeInfo, uint32_t(typedUpvals.size())); - writeVarInt(tempTypeInfo, uint32_t(typedLocals.size())); - - tempTypeInfo.append(func.typeinfo); - - for (const TypedUpval& l : typedUpvals) - writeByte(tempTypeInfo, l.type); - - for (const TypedLocal& l : typedLocals) - { - writeByte(tempTypeInfo, l.type); - writeByte(tempTypeInfo, l.reg); - writeVarInt(tempTypeInfo, l.startpc); - LUAU_ASSERT(l.endpc >= l.startpc); - writeVarInt(tempTypeInfo, l.endpc - l.startpc); - } - - writeVarInt(ss, 
uint32_t(tempTypeInfo.size())); - ss.append(tempTypeInfo); - } - else - { - writeVarInt(ss, 0); + writeByte(tempTypeInfo, l.type); + writeByte(tempTypeInfo, l.reg); + writeVarInt(tempTypeInfo, l.startpc); + LUAU_ASSERT(l.endpc >= l.startpc); + writeVarInt(tempTypeInfo, l.endpc - l.startpc); } + + writeVarInt(ss, uint32_t(tempTypeInfo.size())); + ss.append(tempTypeInfo); } else { - writeVarInt(ss, uint32_t(func.typeinfo.size())); - ss.append(func.typeinfo); + writeVarInt(ss, 0); } // instructions @@ -1251,10 +1235,10 @@ uint8_t BytecodeBuilder::getVersion() uint8_t BytecodeBuilder::getTypeEncodingVersion() { - if (FFlag::LuauCompileTypeInfo && FFlag::LuauCompileUserdataInfo) + if (FFlag::LuauCompileUserdataInfo) return LBC_TYPE_VERSION_TARGET; - return FFlag::LuauCompileTypeInfo ? 2 : LBC_TYPE_VERSION_DEPRECATED; + return 2; } #ifdef LUAU_ASSERTENABLED @@ -2368,80 +2352,77 @@ std::string BytecodeBuilder::dumpCurrentFunction(std::vector& dumpinstoffs) } } - if (FFlag::LuauCompileTypeInfo) + if (dumpFlags & Dump_Types) { - if (dumpFlags & Dump_Types) + const std::string& typeinfo = functions.back().typeinfo; + + if (FFlag::LuauCompileUserdataInfo) { - const std::string& typeinfo = functions.back().typeinfo; - - if (FFlag::LuauCompileUserdataInfo) + // Arguments start from third byte in function typeinfo string + for (uint8_t i = 2; i < typeinfo.size(); ++i) { - // Arguments start from third byte in function typeinfo string - for (uint8_t i = 2; i < typeinfo.size(); ++i) - { - uint8_t et = typeinfo[i]; + uint8_t et = typeinfo[i]; - const char* userdata = tryGetUserdataTypeName(LuauBytecodeType(et)); - const char* name = userdata ? userdata : getBaseTypeString(et); - const char* optional = (et & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + const char* userdata = tryGetUserdataTypeName(LuauBytecodeType(et)); + const char* name = userdata ? userdata : getBaseTypeString(et); + const char* optional = (et & LBC_TYPE_OPTIONAL_BIT) ? "?" 
: ""; - formatAppend(result, "R%d: %s%s [argument]\n", i - 2, name, optional); - } - - for (size_t i = 0; i < typedUpvals.size(); ++i) - { - const TypedUpval& l = typedUpvals[i]; - - const char* userdata = tryGetUserdataTypeName(l.type); - const char* name = userdata ? userdata : getBaseTypeString(l.type); - const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; - - formatAppend(result, "U%d: %s%s\n", int(i), name, optional); - } - - for (size_t i = 0; i < typedLocals.size(); ++i) - { - const TypedLocal& l = typedLocals[i]; - - const char* userdata = tryGetUserdataTypeName(l.type); - const char* name = userdata ? userdata : getBaseTypeString(l.type); - const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; - - formatAppend(result, "R%d: %s%s from %d to %d\n", l.reg, name, optional, l.startpc, l.endpc); - } + formatAppend(result, "R%d: %s%s [argument]\n", i - 2, name, optional); } - else + + for (size_t i = 0; i < typedUpvals.size(); ++i) { - // Arguments start from third byte in function typeinfo string - for (uint8_t i = 2; i < typeinfo.size(); ++i) - { - uint8_t et = typeinfo[i]; + const TypedUpval& l = typedUpvals[i]; - const char* base = getBaseTypeString(et); - const char* optional = (et & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + const char* userdata = tryGetUserdataTypeName(l.type); + const char* name = userdata ? userdata : getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; - formatAppend(result, "R%d: %s%s [argument]\n", i - 2, base, optional); - } + formatAppend(result, "U%d: %s%s\n", int(i), name, optional); + } - for (size_t i = 0; i < typedUpvals.size(); ++i) - { - const TypedUpval& l = typedUpvals[i]; + for (size_t i = 0; i < typedLocals.size(); ++i) + { + const TypedLocal& l = typedLocals[i]; - const char* base = getBaseTypeString(l.type); - const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" 
: ""; + const char* userdata = tryGetUserdataTypeName(l.type); + const char* name = userdata ? userdata : getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; - formatAppend(result, "U%d: %s%s\n", int(i), base, optional); - } + formatAppend(result, "R%d: %s%s from %d to %d\n", l.reg, name, optional, l.startpc, l.endpc); + } + } + else + { + // Arguments start from third byte in function typeinfo string + for (uint8_t i = 2; i < typeinfo.size(); ++i) + { + uint8_t et = typeinfo[i]; - for (size_t i = 0; i < typedLocals.size(); ++i) - { - const TypedLocal& l = typedLocals[i]; + const char* base = getBaseTypeString(et); + const char* optional = (et & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; - const char* base = getBaseTypeString(l.type); - const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + formatAppend(result, "R%d: %s%s [argument]\n", i - 2, base, optional); + } - formatAppend(result, "R%d: %s%s from %d to %d\n", l.reg, base, optional, l.startpc, l.endpc); - } + for (size_t i = 0; i < typedUpvals.size(); ++i) + { + const TypedUpval& l = typedUpvals[i]; + + const char* base = getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + + formatAppend(result, "U%d: %s%s\n", int(i), base, optional); + } + + for (size_t i = 0; i < typedLocals.size(); ++i) + { + const TypedLocal& l = typedLocals[i]; + + const char* base = getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" 
: ""; + + formatAppend(result, "R%d: %s%s from %d to %d\n", l.reg, base, optional, l.startpc, l.endpc); } } } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 26d3100c..98520a7f 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,8 +26,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAG(LuauCompileTypeInfo) -LUAU_FASTFLAGVARIABLE(LuauCompileTempTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauCompileUserdataInfo, false) LUAU_FASTFLAGVARIABLE(LuauCompileFastcall3, false) @@ -215,13 +213,6 @@ struct Compiler setDebugLine(func); - if (!FFlag::LuauCompileTypeInfo) - { - // note: we move types out of typeMap which is safe because compileFunction is only called once per function - if (std::string* funcType = functionTypes.find(func)) - bytecode.setFunctionTypeInfo(std::move(*funcType)); - } - if (func->vararg) bytecode.emitABC(LOP_PREPVARARGS, uint8_t(self + func->args.size), 0, 0); @@ -233,8 +224,7 @@ struct Compiler for (size_t i = 0; i < func->args.size; ++i) pushLocal(func->args.data[i], uint8_t(args + self + i), kDefaultAllocPc); - if (FFlag::LuauCompileTypeInfo) - argCount = localStack.size(); + argCount = localStack.size(); AstStatBlock* stat = func->body; @@ -266,7 +256,7 @@ struct Compiler bytecode.pushDebugUpval(sref(l->name)); } - if (FFlag::LuauCompileTypeInfo && options.typeInfoLevel >= 1) + if (options.typeInfoLevel >= 1) { for (AstLocal* l : upvals) { @@ -289,12 +279,9 @@ struct Compiler if (bytecode.getInstructionCount() > kMaxInstructionCount) CompileError::raise(func->location, "Exceeded function instruction limit; split the function into parts to compile"); - if (FFlag::LuauCompileTypeInfo) - { - // note: we move types out of typeMap which is safe because compileFunction is only called once per function - if (std::string* funcType = functionTypes.find(func)) - 
bytecode.setFunctionTypeInfo(std::move(*funcType)); - } + // note: we move types out of typeMap which is safe because compileFunction is only called once per function + if (std::string* funcType = functionTypes.find(func)) + bytecode.setFunctionTypeInfo(std::move(*funcType)); // top-level code only executes once so it can be marked as cold if it has no loops; code with loops might be profitable to compile natively if (func->functionDepth == 0 && !hasLoops) @@ -328,8 +315,7 @@ struct Compiler upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes stackSize = 0; - if (FFlag::LuauCompileTypeInfo) - argCount = 0; + argCount = 0; hasLoops = false; @@ -659,7 +645,7 @@ struct Compiler // if the last argument can return multiple values, we need to compute all of them into the remaining arguments unsigned int tail = unsigned(func->args.size - expr->args.size) + 1; uint8_t reg = allocReg(arg, tail); - uint32_t allocpc = FFlag::LuauCompileTypeInfo ? bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t allocpc = bytecode.getDebugPC(); if (AstExprCall* expr = arg->as()) compileExprCall(expr, reg, tail, /* targetTop= */ true); @@ -669,12 +655,7 @@ struct Compiler LUAU_ASSERT(!"Unexpected expression type"); for (size_t j = i; j < func->args.size; ++j) - { - if (FFlag::LuauCompileTypeInfo) - args.push_back({func->args.data[j], uint8_t(reg + (j - i)), {Constant::Type_Unknown}, allocpc}); - else - args.push_back({func->args.data[j], uint8_t(reg + (j - i))}); - } + args.push_back({func->args.data[j], uint8_t(reg + (j - i)), {Constant::Type_Unknown}, allocpc}); // all remaining function arguments have been allocated and assigned to break; @@ -683,17 +664,14 @@ struct Compiler { // if the argument is mutated, we need to allocate a fresh register even if it's a constant uint8_t reg = allocReg(arg, 1); - uint32_t allocpc = FFlag::LuauCompileTypeInfo ? 
bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t allocpc = bytecode.getDebugPC(); if (arg) compileExprTemp(arg, reg); else bytecode.emitABC(LOP_LOADNIL, reg, 0, 0); - if (FFlag::LuauCompileTypeInfo) - args.push_back({var, reg, {Constant::Type_Unknown}, allocpc}); - else - args.push_back({var, reg}); + args.push_back({var, reg, {Constant::Type_Unknown}, allocpc}); } else if (arg == nullptr) { @@ -718,14 +696,11 @@ struct Compiler else { uint8_t temp = allocReg(arg, 1); - uint32_t allocpc = FFlag::LuauCompileTypeInfo ? bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t allocpc = bytecode.getDebugPC(); compileExprTemp(arg, temp); - if (FFlag::LuauCompileTypeInfo) - args.push_back({var, temp, {Constant::Type_Unknown}, allocpc}); - else - args.push_back({var, temp}); + args.push_back({var, temp, {Constant::Type_Unknown}, allocpc}); } } } @@ -739,16 +714,9 @@ struct Compiler for (InlineArg& arg : args) { if (arg.value.type == Constant::Type_Unknown) - { - if (FFlag::LuauCompileTypeInfo) - pushLocal(arg.local, arg.reg, arg.allocpc); - else - pushLocal(arg.local, arg.reg, kDefaultAllocPc); - } + pushLocal(arg.local, arg.reg, arg.allocpc); else - { locstants[arg.local] = arg.value; - } } // the inline frame will be used to compile return statements as well as to reject recursive inlining attempts @@ -970,8 +938,7 @@ struct Compiler bytecode.emitABC(LOP_NAMECALL, regs, selfreg, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); - if (FFlag::LuauCompileTempTypeInfo) - hintTemporaryExprRegType(fi->expr, selfreg, LBC_TYPE_TABLE, /* instLength */ 2); + hintTemporaryExprRegType(fi->expr, selfreg, LBC_TYPE_TABLE, /* instLength */ 2); } else if (bfid >= 0) { @@ -1627,8 +1594,7 @@ struct Compiler bytecode.emitABC(getBinaryOpArith(expr->op, /* k= */ true), target, rl, uint8_t(rc)); - if (FFlag::LuauCompileTempTypeInfo) - hintTemporaryExprRegType(expr->left, rl, LBC_TYPE_NUMBER, /* instLength */ 1); + hintTemporaryExprRegType(expr->left, rl, LBC_TYPE_NUMBER, 
/* instLength */ 1); } else { @@ -1643,8 +1609,7 @@ struct Compiler bytecode.emitABC(op, target, uint8_t(lc), uint8_t(rr)); - if (FFlag::LuauCompileTempTypeInfo) - hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1); + hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1); return; } } @@ -1654,11 +1619,8 @@ struct Compiler bytecode.emitABC(getBinaryOpArith(expr->op), target, rl, rr); - if (FFlag::LuauCompileTempTypeInfo) - { - hintTemporaryExprRegType(expr->left, rl, LBC_TYPE_NUMBER, /* instLength */ 1); - hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1); - } + hintTemporaryExprRegType(expr->left, rl, LBC_TYPE_NUMBER, /* instLength */ 1); + hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1); } } break; @@ -2099,8 +2061,7 @@ struct Compiler bytecode.emitABC(LOP_GETTABLEKS, target, reg, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); - if (FFlag::LuauCompileTempTypeInfo) - hintTemporaryExprRegType(expr->expr, reg, LBC_TYPE_TABLE, /* instLength */ 2); + hintTemporaryExprRegType(expr->expr, reg, LBC_TYPE_TABLE, /* instLength */ 2); } void compileExprIndexExpr(AstExprIndexExpr* expr, uint8_t target) @@ -2984,7 +2945,7 @@ struct Compiler // note: allocReg in this case allocates into parent block register - note that we don't have RegScope here uint8_t vars = allocReg(stat, unsigned(stat->vars.size)); - uint32_t allocpc = FFlag::LuauCompileTypeInfo ? bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t allocpc = bytecode.getDebugPC(); compileExprListTemp(stat->values, vars, uint8_t(stat->vars.size), /* targetTop= */ true); @@ -3116,7 +3077,7 @@ struct Compiler // this makes sure the code inside the loop can't interfere with the iteration process (other than modifying the table we're iterating // through) uint8_t varreg = regs + 2; - uint32_t varregallocpc = FFlag::LuauCompileTypeInfo ? 
bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t varregallocpc = bytecode.getDebugPC(); if (Variable* il = variables.find(stat->var); il && il->written) varreg = allocReg(stat, 1); @@ -3183,7 +3144,7 @@ struct Compiler // note that we reserve at least 2 variables; this allows our fast path to assume that we need 2 variables instead of 1 or 2 uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u)); LUAU_ASSERT(vars == regs + 3); - uint32_t varsallocpc = FFlag::LuauCompileTypeInfo ? bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t varsallocpc = bytecode.getDebugPC(); LuauOpcode skipOp = LOP_FORGPREP; @@ -3480,13 +3441,10 @@ struct Compiler bytecode.emitABC(getBinaryOpArith(stat->op), target, target, rr); - if (FFlag::LuauCompileTempTypeInfo) - { - if (var.kind != LValue::Kind_Local) - hintTemporaryRegType(stat->var, target, LBC_TYPE_NUMBER, /* instLength */ 1); + if (var.kind != LValue::Kind_Local) + hintTemporaryRegType(stat->var, target, LBC_TYPE_NUMBER, /* instLength */ 1); - hintTemporaryExprRegType(stat->value, rr, LBC_TYPE_NUMBER, /* instLength */ 1); - } + hintTemporaryExprRegType(stat->value, rr, LBC_TYPE_NUMBER, /* instLength */ 1); } } break; @@ -3720,9 +3678,7 @@ struct Compiler l.reg = reg; l.allocated = true; l.debugpc = bytecode.getDebugPC(); - - if (FFlag::LuauCompileTypeInfo) - l.allocpc = allocpc == kDefaultAllocPc ? l.debugpc : allocpc; + l.allocpc = allocpc == kDefaultAllocPc ? 
l.debugpc : allocpc; } bool areLocalsCaptured(size_t start) @@ -3785,7 +3741,7 @@ struct Compiler bytecode.pushDebugLocal(sref(localStack[i]->name), l->reg, l->debugpc, debugpc); } - if (FFlag::LuauCompileTypeInfo && options.typeInfoLevel >= 1 && i >= argCount) + if (options.typeInfoLevel >= 1 && i >= argCount) { uint32_t debugpc = bytecode.getDebugPC(); LuauBytecodeType ty = LBC_TYPE_ANY; @@ -3873,8 +3829,6 @@ struct Compiler void hintTemporaryRegType(AstExpr* expr, int reg, LuauBytecodeType expectedType, int instLength) { - LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); - // If we know the type of a temporary and it's not the type that would be expected by codegen, provide a hint if (LuauBytecodeType* ty = exprTypes.find(expr)) { @@ -3885,8 +3839,6 @@ struct Compiler void hintTemporaryExprRegType(AstExpr* expr, int reg, LuauBytecodeType expectedType, int instLength) { - LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); - // If we allocated a temporary register for the operation argument, try hinting its type if (!getExprLocal(expr)) hintTemporaryRegType(expr, reg, expectedType, instLength); @@ -4175,9 +4127,7 @@ struct Compiler static void setCompileOptionsForNativeCompilation(CompileOptions& options) { options.optimizationLevel = 2; // note: this might be removed in the future in favor of --!optimize - - if (FFlag::LuauCompileTypeInfo) - options.typeInfoLevel = 1; + options.typeInfoLevel = 1; } void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions) @@ -4266,18 +4216,9 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c } // computes type information for all functions based on type annotations - if (FFlag::LuauCompileTypeInfo) - { - if (options.typeInfoLevel >= 1) - buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.userdataTypes, - compiler.builtinTypes, compiler.builtins, compiler.globals, 
bytecode); - } - else - { - if (functionVisitor.hasTypes) - buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.userdataTypes, - compiler.builtinTypes, compiler.builtins, compiler.globals, bytecode); - } + if (options.typeInfoLevel >= 1) + buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.userdataTypes, + compiler.builtinTypes, compiler.builtins, compiler.globals, bytecode); for (AstExprFunction* expr : functions) { diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index 4454114c..447b51d3 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -3,8 +3,6 @@ #include "Luau/BytecodeBuilder.h" -LUAU_FASTFLAG(LuauCompileTypeInfo) -LUAU_FASTFLAG(LuauCompileTempTypeInfo) LUAU_FASTFLAG(LuauCompileUserdataInfo) namespace Luau @@ -160,8 +158,6 @@ static std::string getFunctionType(const AstExprFunction* func, const DenseHashM static bool isMatchingGlobal(const DenseHashMap& globals, AstExpr* node, const char* name) { - LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); - if (AstExprGlobal* expr = node->as()) return Compile::getGlobalState(globals, expr->name) == Compile::Global::Default && expr->name == name; @@ -233,8 +229,6 @@ struct TypeMapVisitor : AstVisitor const AstType* resolveAliases(const AstType* ty) { - LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); - if (const AstTypeReference* ref = ty->as()) { if (ref->prefix) @@ -249,8 +243,6 @@ struct TypeMapVisitor : AstVisitor const AstTableIndexer* tryGetTableIndexer(AstExpr* expr) { - LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); - if (const AstType** typePtr = resolvedExprs.find(expr)) { if (const AstTypeTable* tableTy = (*typePtr)->as()) @@ -262,8 +254,6 @@ struct TypeMapVisitor : AstVisitor LuauBytecodeType recordResolvedType(AstExpr* expr, const AstType* ty) { - LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); - ty = resolveAliases(ty); resolvedExprs[expr] = ty; @@ -275,8 +265,6 @@ 
struct TypeMapVisitor : AstVisitor LuauBytecodeType recordResolvedType(AstLocal* local, const AstType* ty) { - LUAU_ASSERT(FFlag::LuauCompileTempTypeInfo); - ty = resolveAliases(ty); resolvedLocals[local] = ty; @@ -319,9 +307,6 @@ struct TypeMapVisitor : AstVisitor // for...in statement can contain type annotations on locals (we might even infer some for ipairs/pairs/generalized iteration) bool visit(AstStatForIn* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - for (AstExpr* expr : node->values) expr->visit(this); @@ -382,51 +367,25 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprLocal* node) override { - if (FFlag::LuauCompileTempTypeInfo) + AstLocal* local = node->local; + + if (AstType* annotation = local->annotation) { - if (FFlag::LuauCompileTypeInfo) - { - AstLocal* local = node->local; + LuauBytecodeType ty = recordResolvedType(node, annotation); - if (AstType* annotation = local->annotation) - { - LuauBytecodeType ty = recordResolvedType(node, annotation); - - if (ty != LBC_TYPE_ANY) - localTypes[local] = ty; - } - else if (const AstType** typePtr = resolvedLocals.find(local)) - { - localTypes[local] = recordResolvedType(node, *typePtr); - } - } - - return false; + if (ty != LBC_TYPE_ANY) + localTypes[local] = ty; } - else + else if (const AstType** typePtr = resolvedLocals.find(local)) { - if (FFlag::LuauCompileTypeInfo) - { - AstLocal* local = node->local; - - if (AstType* annotation = local->annotation) - { - LuauBytecodeType ty = getType(annotation, {}, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode); - - if (ty != LBC_TYPE_ANY) - localTypes[local] = ty; - } - } - - return true; + localTypes[local] = recordResolvedType(node, *typePtr); } + + return false; } bool visit(AstStatLocal* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - for (AstExpr* expr : node->values) expr->visit(this); @@ -451,9 +410,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprIndexExpr* 
node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - node->expr->visit(this); node->index->visit(this); @@ -465,9 +421,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprIndexName* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - node->expr->visit(this); if (const AstType** typePtr = resolvedExprs.find(node->expr)) @@ -499,9 +452,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprUnary* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - node->expr->visit(this); switch (node->op) @@ -534,9 +484,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprBinary* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - node->left->visit(this); node->right->visit(this); @@ -575,9 +522,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprGroup* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - node->expr->visit(this); if (const AstType** typePtr = resolvedExprs.find(node->expr)) @@ -588,9 +532,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprTypeAssertion* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - node->expr->visit(this); recordResolvedType(node, node->annotation); @@ -600,9 +541,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprConstantBool* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - recordResolvedType(node, &builtinTypes.booleanType); return false; @@ -610,9 +548,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprConstantNumber* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - recordResolvedType(node, &builtinTypes.numberType); return false; @@ -620,9 +555,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprConstantString* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - recordResolvedType(node, &builtinTypes.stringType); return false; @@ -630,9 +562,6 @@ struct TypeMapVisitor : AstVisitor bool 
visit(AstExprInterpString* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - recordResolvedType(node, &builtinTypes.stringType); return false; @@ -640,9 +569,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprIfElse* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - node->condition->visit(this); node->trueExpr->visit(this); node->falseExpr->visit(this); @@ -660,9 +586,6 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprCall* node) override { - if (!FFlag::LuauCompileTempTypeInfo) - return true; - if (const int* bfid = builtinCalls.find(node)) { switch (LuauBuiltinFunction(*bfid)) diff --git a/VM/src/lvm.h b/VM/src/lvm.h index 96bc37f3..0b8690be 100644 --- a/VM/src/lvm.h +++ b/VM/src/lvm.h @@ -15,7 +15,6 @@ LUAI_FUNC int luaV_strcmp(const TString* ls, const TString* rs); LUAI_FUNC int luaV_lessthan(lua_State* L, const TValue* l, const TValue* r); LUAI_FUNC int luaV_lessequal(lua_State* L, const TValue* l, const TValue* r); LUAI_FUNC int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2); -LUAI_FUNC void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op); template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index bc89458e..fb253c6a 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauVmSplitDoarith, false) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -1489,14 +1487,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_ADD)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1542,14 +1533,7 @@ reentry: else { // slow-path, may invoke C/Lua via 
metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_SUB)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1610,14 +1594,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MUL)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1678,14 +1655,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_DIV)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1733,14 +1703,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_IDIV)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1764,14 +1727,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MOD)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1792,14 +1748,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_POW)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1820,14 +1769,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_ADD)); - } + 
VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -1848,14 +1790,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_SUB)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -1900,14 +1835,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MUL)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -1953,14 +1881,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_DIV)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -2007,14 +1928,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_IDIV)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -2038,14 +1952,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MOD)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -2072,14 +1979,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_POW)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -2192,14 +2092,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - 
VM_PROTECT(luaV_doarithimpl(L, ra, rb, rb)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, rb, rb, TM_UNM)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rb)); VM_NEXT(); } } @@ -2812,14 +2705,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, kv, rc)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_SUB)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, kv, rc)); VM_NEXT(); } } @@ -2847,14 +2733,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - if (FFlag::LuauVmSplitDoarith) - { - VM_PROTECT(luaV_doarithimpl(L, ra, kv, rc)); - } - else - { - VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_DIV)); - } + VM_PROTECT(luaV_doarithimpl(L, ra, kv, rc)); VM_NEXT(); } } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 6ee542b0..41990742 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -519,140 +519,6 @@ template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); template void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); -void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op) -{ - TValue tempb, tempc; - const TValue *b, *c; - if ((b = luaV_tonumber(rb, &tempb)) != NULL && (c = luaV_tonumber(rc, &tempc)) != NULL) - { - double nb = nvalue(b), nc = nvalue(c); - switch (op) - { - case TM_ADD: - setnvalue(ra, luai_numadd(nb, nc)); - break; - case TM_SUB: - setnvalue(ra, luai_numsub(nb, nc)); - break; - case TM_MUL: - setnvalue(ra, luai_nummul(nb, nc)); - break; - case TM_DIV: - setnvalue(ra, luai_numdiv(nb, nc)); - break; - case TM_IDIV: - setnvalue(ra, luai_numidiv(nb, nc)); - break; - case TM_MOD: - setnvalue(ra, luai_nummod(nb, nc)); - break; - case TM_POW: - setnvalue(ra, luai_numpow(nb, nc)); - break; - case TM_UNM: - setnvalue(ra, luai_numunm(nb)); - break; - default: - 
LUAU_ASSERT(0); - break; - } - } - else - { - // vector operations that we support: - // v+v v-v -v (add/sub/neg) - // v*v s*v v*s (mul) - // v/v s/v v/s (div) - // v//v s//v v//s (floor div) - - const float* vb = luaV_tovector(rb); - const float* vc = luaV_tovector(rc); - - if (vb && vc) - { - switch (op) - { - case TM_ADD: - setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); - return; - case TM_SUB: - setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); - return; - case TM_MUL: - setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); - return; - case TM_DIV: - setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); - return; - case TM_IDIV: - setvvalue(ra, float(luai_numidiv(vb[0], vc[0])), float(luai_numidiv(vb[1], vc[1])), float(luai_numidiv(vb[2], vc[2])), - float(luai_numidiv(vb[3], vc[3]))); - return; - case TM_UNM: - setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); - return; - default: - break; - } - } - else if (vb) - { - c = luaV_tonumber(rc, &tempc); - - if (c) - { - float nc = cast_to(float, nvalue(c)); - - switch (op) - { - case TM_MUL: - setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc, vb[3] * nc); - return; - case TM_DIV: - setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc, vb[3] / nc); - return; - case TM_IDIV: - setvvalue(ra, float(luai_numidiv(vb[0], nc)), float(luai_numidiv(vb[1], nc)), float(luai_numidiv(vb[2], nc)), - float(luai_numidiv(vb[3], nc))); - return; - default: - break; - } - } - } - else if (vc) - { - b = luaV_tonumber(rb, &tempb); - - if (b) - { - float nb = cast_to(float, nvalue(b)); - - switch (op) - { - case TM_MUL: - setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2], nb * vc[3]); - return; - case TM_DIV: - setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2], nb / vc[3]); - return; - case TM_IDIV: - setvvalue(ra, float(luai_numidiv(nb, vc[0])), float(luai_numidiv(nb, vc[1])), float(luai_numidiv(nb, vc[2])), - float(luai_numidiv(nb, vc[3]))); - return; 
- default: - break; - } - } - } - - if (!call_binTM(L, rb, rc, ra, op)) - { - luaG_aritherror(L, rb, rc, op); - } - } -} - void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) { const TValue* tm = NULL; diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index 82e8f139..1c8b2127 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -9,6 +9,8 @@ #include #include +LUAU_FASTFLAG(LuauDeclarationExtraPropData) + using namespace Luau; struct JsonEncoderFixture @@ -408,16 +410,32 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatTypeAlias") TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction") { + ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true}; + AstStat* statement = expectParseStatement("declare function foo(x: number): string"); std::string_view expected = - R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}]},"retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","nameLocation":"0,33 - 0,39","parameters":[]}]},"generics":[],"genericPacks":[]})"; + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}]},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":false,"varargLocation":"0,0 - 0,0","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","nameLocation":"0,33 - 0,39","parameters":[]}]},"generics":[],"genericPacks":[]})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, 
"encode_AstStatDeclareFunction2") +{ + ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true}; + + AstStat* statement = expectParseStatement("declare function foo(x: number, ...: string): string"); + + std::string_view expected = + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,52","name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}],"tailType":{"type":"AstTypePackVariadic","location":"0,37 - 0,43","variadicType":{"type":"AstTypeReference","location":"0,37 - 0,43","name":"string","nameLocation":"0,37 - 0,43","parameters":[]}}},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":true,"varargLocation":"0,32 - 0,35","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,46 - 0,52","name":"string","nameLocation":"0,46 - 0,52","parameters":[]}]},"generics":[],"genericPacks":[]})"; CHECK(toJson(statement) == expected); } TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") { + ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true}; + AstStatBlock* root = expectParse(R"( declare class Foo prop: number @@ -432,11 +450,11 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") REQUIRE(2 == root->body.size); std::string_view expected1 = - R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]}},{"name":"method","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 
3,45","parameters":[]}]},"argNames":[{"type":"AstArgumentName","name":"foo","location":"3,34 - 3,37"}],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}}}],"indexer":null})"; + R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","nameLocation":"2,12 - 2,16","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]},"location":"2,12 - 2,24"},{"name":"method","nameLocation":"3,21 - 3,27","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,12 - 3,54","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"argNames":[{"type":"AstArgumentName","name":"foo","location":"3,34 - 3,37"}],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}},"location":"3,12 - 3,54"}],"indexer":null})"; CHECK(toJson(root->body.data[0]) == expected1); std::string_view expected2 = - R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","nameLocation":"7,19 - 7,25","parameters":[]}}],"indexer":null})"; + R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","nameLocation":"7,12 - 7,17","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","nameLocation":"7,19 - 7,25","parameters":[]},"location":"7,12 - 7,25"}],"indexer":null})"; CHECK(toJson(root->body.data[1]) == expected2); } diff --git 
a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index c220f30b..4e8a0442 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -35,6 +35,7 @@ struct ACFixtureImpl : BaseType { FrontendOptions opts; opts.forAutocomplete = true; + opts.retainFullTypeGraphs = true; this->frontend.check("MainModule", opts); return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); @@ -44,6 +45,7 @@ struct ACFixtureImpl : BaseType { FrontendOptions opts; opts.forAutocomplete = true; + opts.retainFullTypeGraphs = true; this->frontend.check("MainModule", opts); return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), callback); @@ -53,6 +55,7 @@ struct ACFixtureImpl : BaseType { FrontendOptions opts; opts.forAutocomplete = true; + opts.retainFullTypeGraphs = true; this->frontend.check(name, opts); return Luau::autocomplete(this->frontend, name, pos, callback); @@ -3681,6 +3684,8 @@ a.@1 auto ac = autocomplete('1'); + CHECK(2 == ac.entryMap.size()); + CHECK(ac.entryMap.count("x")); CHECK(ac.entryMap.count("y")); @@ -3733,11 +3738,13 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") declare function require(path: string): any )"); - std::optional require = frontend.globalsForAutocomplete.globalScope->linearSearchForBinding("require"); + GlobalTypes& globals = FFlag::DebugLuauDeferredConstraintResolution ? 
frontend.globals : frontend.globalsForAutocomplete; + + std::optional require = globals.globalScope->linearSearchForBinding("require"); REQUIRE(require); - Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); + Luau::unfreeze(globals.globalTypes); attachTag(require->typeId, "RequireCall"); - Luau::freeze(frontend.globalsForAutocomplete.globalTypes); + Luau::freeze(globals.globalTypes); check(R"( local x = require("testing/@1") @@ -3837,11 +3844,13 @@ TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") declare function require(path: string): any )"); - std::optional require = frontend.globalsForAutocomplete.globalScope->linearSearchForBinding("require"); + GlobalTypes& globals = FFlag::DebugLuauDeferredConstraintResolution ? frontend.globals : frontend.globalsForAutocomplete; + + std::optional require = globals.globalScope->linearSearchForBinding("require"); REQUIRE(require); - Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); + Luau::unfreeze(globals.globalTypes); attachTag(require->typeId, "RequireCall"); - Luau::freeze(frontend.globalsForAutocomplete.globalTypes); + Luau::freeze(globals.globalTypes); check(R"( local x = require(@1"@2"@3) diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index eeca416c..250de6e4 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -22,8 +22,6 @@ LUAU_FASTINT(LuauCompileLoopUnrollThreshold) LUAU_FASTINT(LuauCompileLoopUnrollThresholdMaxBoost) LUAU_FASTINT(LuauRecursionLimit) -LUAU_FASTFLAG(LuauCompileTypeInfo) -LUAU_FASTFLAG(LuauCompileTempTypeInfo) LUAU_FASTFLAG(LuauCompileUserdataInfo) LUAU_FASTFLAG(LuauCompileFastcall3) @@ -3226,8 +3224,6 @@ RETURN R0 0 TEST_CASE("DebugTypes") { - ScopedFastFlag luauCompileTypeInfo{FFlag::LuauCompileTypeInfo, true}; - ScopedFastFlag luauCompileTempTypeInfo{FFlag::LuauCompileTempTypeInfo, true}; ScopedFastFlag luauCompileUserdataInfo{FFlag::LuauCompileUserdataInfo, true}; const char* source = R"( diff --git a/tests/Conformance.test.cpp 
b/tests/Conformance.test.cpp index 516e02f4..65af4e4d 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -33,7 +33,6 @@ void luaC_validate(lua_State* L); LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) -LUAU_FASTFLAG(LuauCodegenFixSplitStoreConstMismatch) LUAU_FASTFLAG(LuauAttributeSyntax) LUAU_FASTFLAG(LuauNativeAttribute) @@ -2358,8 +2357,6 @@ TEST_CASE("Native") if (!codegen || !luau_codegen_supported()) return; - ScopedFastFlag luauCodegenFixSplitStoreConstMismatch{FFlag::LuauCodegenFixSplitStoreConstMismatch, true}; - SUBCASE("Checked") { FFlag::DebugLuauAbortingChecks.value = true; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 411d4914..967dea43 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1333,4 +1333,58 @@ TEST_CASE_FIXTURE(FrontendFixture, "checked_modules_have_the_correct_mode") CHECK(moduleC->mode == Mode::Strict); } +TEST_CASE_FIXTURE(FrontendFixture, "separate_caches_for_autocomplete") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, false}; + + fileResolver.source["game/A"] = R"( + --!nonstrict + local exports = {} + function exports.hello() end + return exports + )"; + + FrontendOptions opts; + opts.forAutocomplete = true; + + frontend.check("game/A", opts); + + CHECK(nullptr == frontend.moduleResolver.getModule("game/A")); + + ModulePtr acModule = frontend.moduleResolverForAutocomplete.getModule("game/A"); + REQUIRE(acModule != nullptr); + CHECK(acModule->mode == Mode::Strict); + + frontend.check("game/A"); + + ModulePtr module = frontend.moduleResolver.getModule("game/A"); + + REQUIRE(module != nullptr); + CHECK(module->mode == Mode::Nonstrict); +} + +TEST_CASE_FIXTURE(FrontendFixture, "no_separate_caches_with_the_new_solver") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + fileResolver.source["game/A"] = R"( + --!nonstrict + local exports = {} + function exports.hello() end + return 
exports + )"; + + FrontendOptions opts; + opts.forAutocomplete = true; + + frontend.check("game/A", opts); + + CHECK(nullptr == frontend.moduleResolverForAutocomplete.getModule("game/A")); + + ModulePtr module = frontend.moduleResolver.getModule("game/A"); + + REQUIRE(module != nullptr); + CHECK(module->mode == Mode::Nonstrict); +} + TEST_SUITE_END(); diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp index e9344911..901461ae 100644 --- a/tests/Generalization.test.cpp +++ b/tests/Generalization.test.cpp @@ -7,6 +7,7 @@ #include "Luau/TypeArena.h" #include "Luau/Error.h" +#include "Fixture.h" #include "ScopedFlags.h" #include "doctest.h" @@ -172,4 +173,78 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can CHECK(generalizedTypes->contains(builtinTypes.numberType)); } +TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash") +{ + // t1 where t1 = ('h <: (t1 <: 'i)) | ('j <: (t1 <: 'i)) + TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId unionType = arena.addType(UnionType{{h, j}}); + getMutable(h)->upperBound = i; + getMutable(h)->lowerBound = builtinTypes.neverType; + getMutable(i)->upperBound = builtinTypes.unknownType; + getMutable(i)->lowerBound = unionType; + getMutable(j)->upperBound = i; + getMutable(j)->lowerBound = builtinTypes.neverType; + + generalize(unionType); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "intersection_type_traversal_doesnt_crash") +{ + // t1 where t1 = ('h <: (t1 <: 'i)) & ('j <: (t1 <: 'i)) + TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId intersectionType = arena.addType(IntersectionType{{h, j}}); + + getMutable(h)->upperBound = i; + 
getMutable(h)->lowerBound = builtinTypes.neverType; + getMutable(i)->upperBound = builtinTypes.unknownType; + getMutable(i)->lowerBound = intersectionType; + getMutable(j)->upperBound = i; + getMutable(j)->lowerBound = builtinTypes.neverType; + + generalize(intersectionType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_traversal_should_re_traverse_unions_if_they_change_type") +{ + // This test case should just not assert + CheckResult result = check(R"( +function byId(p) + return p.id +end + +function foo() + + local productButtonPairs = {} + local func = byId + local dir = -1 + + local function updateSearch() + for product, button in pairs(productButtonPairs) do + button.LayoutOrder = func(product) * dir + end + end + + function(mode) + if mode == 'Name'then + else + if mode == 'New'then + func = function(p) + return p.id + end + elseif mode == 'Price'then + func = function(p) + return p.price + end + end + + end + end +end +)"); +} + TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index bd7a02b2..611eb7b5 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -13,9 +13,9 @@ #include LUAU_FASTFLAG(DebugLuauAbortingChecks) -LUAU_FASTFLAG(LuauCodegenFixSplitStoreConstMismatch) LUAU_FASTFLAG(LuauCodegenInstG) LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAG(LuauCodegenMathSign) using namespace Luau::CodeGen; @@ -335,6 +335,8 @@ TEST_SUITE_BEGIN("ConstantFolding"); TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") { + ScopedFastFlag luauCodegenMathSign{FFlag::LuauCodegenMathSign, true}; + IrOp block = build.block(IrBlockKind::Internal); build.beginBlock(block); @@ -365,6 +367,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") build.inst(IrCmd::STORE_INT, build.vmReg(20), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(0))); build.inst(IrCmd::STORE_INT, build.vmReg(21), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(1))); + build.inst(IrCmd::STORE_DOUBLE, 
build.vmReg(22), build.inst(IrCmd::SIGN_NUM, build.constDouble(-4))); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); @@ -393,6 +397,7 @@ bb_0: STORE_INT R19, 0i STORE_INT R20, 1i STORE_INT R21, 0i + STORE_DOUBLE R22, -1 RETURN 0u )"); @@ -2662,8 +2667,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore1") { - ScopedFastFlag luauCodegenFixSplitStoreConstMismatch{FFlag::LuauCodegenFixSplitStoreConstMismatch, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -2690,8 +2693,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore2") { - ScopedFastFlag luauCodegenFixSplitStoreConstMismatch{FFlag::LuauCodegenFixSplitStoreConstMismatch, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 33d5602b..0ff8a12c 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -15,9 +15,6 @@ #include #include -LUAU_FASTFLAG(LuauCompileTypeInfo) -LUAU_FASTFLAG(LuauCompileTempTypeInfo) -LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauCompileUserdataInfo) LUAU_FASTFLAG(LuauLoadUserdataInfo) LUAU_FASTFLAG(LuauCodegenUserdataOps) @@ -427,27 +424,6 @@ bb_bytecode_0: )"); } -TEST_CASE("DseInitialStackState3") -{ - ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; - - CHECK_EQ("\n" + getCodegenAssembly(R"( -local function foo(a) - math.sign(a) - return a -end -)"), - R"( -; function foo($arg0) line 2 -bb_bytecode_0: - CHECK_SAFE_ENV exit(1) - CHECK_TAG R0, tnumber, exit(1) - FASTCALL 47u, R1, R0, 1i - INTERRUPT 5u - RETURN R0, 1i -)"); -} - TEST_CASE("VectorConstantTag") { CHECK_EQ("\n" + getCodegenAssembly(R"( @@ -539,8 +515,6 @@ bb_6: TEST_CASE("VectorCustomAccess") { - ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function 
vec3magn(a: vector) return a.Magnitude * 2 @@ -573,8 +547,6 @@ bb_bytecode_1: TEST_CASE("VectorCustomNamecall") { - ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3dot(a: vector, b: vector) return (a:Dot(b)) @@ -611,8 +583,6 @@ bb_bytecode_1: TEST_CASE("VectorCustomAccessChain") { - ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: vector, b: vector) return a.Unit * b.Magnitude @@ -663,8 +633,6 @@ bb_bytecode_1: TEST_CASE("VectorCustomNamecallChain") { - ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(n: vector, b: vector, t: vector) return n:Cross(t):Dot(b) + 1 @@ -722,8 +690,6 @@ bb_bytecode_1: TEST_CASE("VectorCustomNamecallChain2") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {n: vector, b: vector} @@ -890,8 +856,6 @@ bb_4: TEST_CASE("ExplicitUpvalueAndLocalTypes") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local y: vector = ... 
@@ -933,7 +897,7 @@ bb_bytecode_0: TEST_CASE("FastcallTypeInferThroughLocal") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(x, c) @@ -981,7 +945,7 @@ bb_bytecode_1: TEST_CASE("FastcallTypeInferThroughUpvalue") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local v = ... @@ -1038,8 +1002,6 @@ bb_bytecode_1: TEST_CASE("LoadAndMoveTypePropagation") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(n) local seqsum = 0 @@ -1105,7 +1067,7 @@ bb_bytecode_4: TEST_CASE("ArgumentTypeRefinement") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(x, y) @@ -1141,8 +1103,6 @@ bb_bytecode_0: TEST_CASE("InlineFunctionType") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function inl(v: vector, s: number) return v.Y * s @@ -1189,8 +1149,6 @@ bb_bytecode_0: TEST_CASE("ResolveTablePathTypes") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1243,8 +1201,6 @@ bb_6: TEST_CASE("ResolvableSimpleMath") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, 
true}}; - CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = { p: vector, uv: vector, n: vector, t: vector, b: vector, h: number } local mesh: { vertices: {Vertex}, indices: {number} } = ... @@ -1299,8 +1255,6 @@ end TEST_CASE("ResolveVectorNamecalls") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1363,8 +1317,6 @@ bb_6: TEST_CASE("ImmediateTypeAnnotationHelp") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(arr, i) return (arr[i] :: vector) / 5 @@ -1401,8 +1353,7 @@ bb_2: TEST_CASE("UnaryTypeResolve") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileFastcall3, true}, - {FFlag::LuauCodegenFastcall3, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( local function foo(a, b: vector, c) @@ -1424,8 +1375,6 @@ end TEST_CASE("ForInManualAnnotation") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( type Vertex = {pos: vector, normal: vector} @@ -1519,8 +1468,6 @@ bb_12: TEST_CASE("ForInAutoAnnotationIpairs") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; - CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = {pos: vector, normal: vector} @@ -1546,8 +1493,6 @@ end TEST_CASE("ForInAutoAnnotationPairs") { - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; - CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = {pos: vector, normal: vector} @@ -1573,8 +1518,6 @@ end TEST_CASE("ForInAutoAnnotationGeneric") { - 
ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}}; - CHECK_EQ("\n" + getCodegenHeader(R"( type Vertex = {pos: vector, normal: vector} @@ -1605,8 +1548,7 @@ TEST_CASE("CustomUserdataTypesTemp") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, false}, - {FFlag::LuauLoadUserdataInfo, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, false}, {FFlag::LuauLoadUserdataInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( local function foo(v: vec2, x: mat3) @@ -1626,8 +1568,7 @@ TEST_CASE("CustomUserdataTypes") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}}; CHECK_EQ("\n" + getCodegenHeader(R"( local function foo(v: vec2, x: mat3) @@ -1647,8 +1588,7 @@ TEST_CASE("CustomUserdataPropertyAccess") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(v: vec2) @@ -1683,8 +1623,7 @@ TEST_CASE("CustomUserdataPropertyAccess2") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag 
sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: mat3) @@ -1721,8 +1660,7 @@ TEST_CASE("CustomUserdataNamecall1") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: vec2, b: vec2) @@ -1768,8 +1706,7 @@ TEST_CASE("CustomUserdataNamecall2") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}, {FFlag::LuauCodegenUserdataOps, true}, + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}, {FFlag::LuauCodegenUserdataAlloc, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( @@ -1819,8 +1756,7 @@ TEST_CASE("CustomUserdataMetamethodDirectFlow") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: mat3, b: mat3) @@ -1852,8 +1788,7 @@ TEST_CASE("CustomUserdataMetamethodDirectFlow2") if 
(!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: mat3) @@ -1883,8 +1818,7 @@ TEST_CASE("CustomUserdataMetamethodDirectFlow3") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: sequence) @@ -1914,8 +1848,8 @@ TEST_CASE("CustomUserdataMetamethod") if (!Luau::CodeGen::isSupported()) return; - ScopedFastFlag sffs[]{{FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true}, - {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}, {FFlag::LuauCodegenUserdataAlloc, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}, + {FFlag::LuauCodegenUserdataAlloc, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: vec2, b: vec2, c: vec2) diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index fabb897f..972d0edd 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauAttributeSyntax); LUAU_FASTFLAG(LuauLeadingBarAndAmpersand2); LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr); 
+LUAU_FASTFLAG(LuauDeclarationExtraPropData); namespace { @@ -1858,6 +1859,8 @@ function func():end TEST_CASE_FIXTURE(Fixture, "parse_declarations") { + ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true}; + AstStatBlock* stat = parseEx(R"( declare foo: number declare function bar(x: number): string @@ -1871,18 +1874,23 @@ TEST_CASE_FIXTURE(Fixture, "parse_declarations") AstStatDeclareGlobal* global = stat->body.data[0]->as(); REQUIRE(global); CHECK(global->name == "foo"); + CHECK(global->nameLocation == Location({1, 16}, {1, 19})); CHECK(global->type); AstStatDeclareFunction* func = stat->body.data[1]->as(); REQUIRE(func); CHECK(func->name == "bar"); + CHECK(func->nameLocation == Location({2, 25}, {2, 28})); REQUIRE_EQ(func->params.types.size, 1); REQUIRE_EQ(func->retTypes.types.size, 1); AstStatDeclareFunction* varFunc = stat->body.data[2]->as(); REQUIRE(varFunc); CHECK(varFunc->name == "var"); + CHECK(varFunc->nameLocation == Location({3, 25}, {3, 28})); CHECK(varFunc->params.tailType); + CHECK(varFunc->vararg); + CHECK(varFunc->varargLocation == Location({3, 29}, {3, 32})); matchParseError("declare function foo(x)", "All declaration parameters must be annotated"); matchParseError("declare foo", "Expected ':' when parsing global variable declaration, got "); @@ -1890,6 +1898,8 @@ TEST_CASE_FIXTURE(Fixture, "parse_declarations") TEST_CASE_FIXTURE(Fixture, "parse_class_declarations") { + ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true}; + AstStatBlock* stat = parseEx(R"( declare class Foo prop: number @@ -1913,11 +1923,16 @@ TEST_CASE_FIXTURE(Fixture, "parse_class_declarations") AstDeclaredClassProp& prop = declaredClass->props.data[0]; CHECK(prop.name == "prop"); + CHECK(prop.nameLocation == Location({2, 12}, {2, 16})); CHECK(prop.ty->is()); + CHECK(prop.location == Location({2, 12}, {2, 24})); AstDeclaredClassProp& method = declaredClass->props.data[1]; CHECK(method.name == "method"); + 
CHECK(method.nameLocation == Location({3, 21}, {3, 27})); CHECK(method.ty->is()); + CHECK(method.location == Location({3, 12}, {3, 54})); + CHECK(method.isMethod); AstStatDeclareClass* subclass = stat->body.data[1]->as(); REQUIRE(subclass); @@ -1928,7 +1943,9 @@ TEST_CASE_FIXTURE(Fixture, "parse_class_declarations") REQUIRE_EQ(subclass->props.size, 1); AstDeclaredClassProp& prop2 = subclass->props.data[0]; CHECK(prop2.name == "prop2"); + CHECK(prop2.nameLocation == Location({7, 12}, {7, 17})); CHECK(prop2.ty->is()); + CHECK(prop2.location == Location({7, 12}, {7, 25})); } TEST_CASE_FIXTURE(Fixture, "class_method_properties") diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index 063ed39c..068e8684 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -1045,4 +1045,122 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_index_metatables") CHECK(toString(result.errors[0]) == "Property '\"Car\"' does not exist on type 'exampleClass2'"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type RawAType = rawget + type RawBType = rawget> + local function ok(idx: RawAType): string return idx end + local function ok2(idx: RawBType): string | number | boolean return idx end + local function err(idx: RawAType): boolean return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK_EQ("boolean", toString(tpm->wantedTp)); + CHECK_EQ("string", toString(tpm->givenTp)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works_w_array") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local MyObject = {"hello", 1, true} + type RawAType = rawget + local function ok(idx: RawAType): string | number | boolean return idx end + )"); 
+ + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_errors_w_var_indexer") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + local key = "a" + type errType1 = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Second argument to rawget is not a valid index type"); + CHECK(toString(result.errors[1]) == "Unknown type 'key'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works_w_union_type_indexer") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type rawType = rawget + local function ok(idx: rawType): string | number return idx end + type errType = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"a\" | \"d\"' does not exist on type 'MyObject'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works_w_union_type_indexee") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type MyObject = {a: string, b: number, c: boolean} + type MyObject2 = {a: number} + type rawTypeA = rawget + local function ok(idx: rawTypeA): string | number return idx end + type errType = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject | MyObject2'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works_w_index_metatables") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local exampleClass = { Foo = "text", Bar = true } + local exampleClass2 = setmetatable({ Foo = 8 }, { __index = exampleClass }) + type exampleTy2 = rawget + local function ok(idx: exampleTy2): number return idx end + local exampleClass3 
= setmetatable({ Bar = 5 }, { __index = exampleClass }) + type errType = rawget + type errType2 = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Property '\"Foo\"' does not exist on type 'exampleClass3'"); + CHECK(toString(result.errors[1]) == "Property '\"Bar\" | \"Foo\"' does not exist on type 'exampleClass3'"); +} + +TEST_CASE_FIXTURE(ClassFixture, "rawget_type_family_errors_w_classes") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type PropsOfMyObject = rawget + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Property '\"BaseField\"' does not exist on type 'BaseClass'"); +} + TEST_SUITE_END(); \ No newline at end of file diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index b305d97d..c532c069 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -401,4 +401,13 @@ end CHECK("(any, any) -> any" == toString(requireType("foo"))); } +TEST_CASE_FIXTURE(Fixture, "cast_to_table_of_any") +{ + CheckResult result = check(R"( + local v = {true} :: {any} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 8f8aef84..f90324d7 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -7,6 +7,7 @@ #include "Fixture.h" #include "ClassFixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; @@ -507,6 +508,31 @@ Type 'ChildClass' could not be converted into 'BaseClass' in an invariant contex } } +TEST_CASE_FIXTURE(ClassFixture, "optional_class_casts_work_in_new_solver") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + CheckResult result = check(R"( + type A = { x: ChildClass } + type B = { x: BaseClass } + + local a = { x = ChildClass.New() } :: A + local opt_a = a :: A? 
+ local b = { x = BaseClass.New() } :: B + local opt_b = b :: B? + local b_from_a = a :: B + local b_from_opt_a = opt_a :: B + local opt_b_from_a = a :: B? + local opt_b_from_opt_a = opt_a :: B? + local a_from_b = b :: A + local a_from_opt_b = opt_b :: A + local opt_a_from_b = b :: A? + local opt_a_from_opt_b = opt_b :: A? + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(ClassFixture, "callable_classes") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index c57eab79..688f27b7 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauDeclarationExtraPropData) + using namespace Luau; TEST_SUITE_BEGIN("DefinitionTests"); @@ -319,6 +321,8 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_referenced_types") { + ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true}; + loadDefinition(R"( declare class MyClass function myMethod(self) @@ -330,6 +334,22 @@ TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_re std::optional myClassTy = frontend.globals.globalScope->lookupType("MyClass"); REQUIRE(bool(myClassTy)); CHECK_EQ(myClassTy->type->documentationSymbol, "@test/globaltype/MyClass"); + + ClassType* cls = getMutable(myClassTy->type); + REQUIRE(bool(cls)); + REQUIRE_EQ(cls->props.count("myMethod"), 1); + + const auto& method = cls->props["myMethod"]; + CHECK_EQ(method.documentationSymbol, "@test/globaltype/MyClass.myMethod"); + + FunctionType* function = getMutable(method.type()); + REQUIRE(function); + + REQUIRE(function->definition.has_value()); + CHECK(function->definition->definitionModuleName == "@test"); + CHECK(function->definition->definitionLocation == Location({2, 12}, {2, 35})); + 
CHECK(!function->definition->varargLocation.has_value()); + CHECK(function->definition->originalNameLocation == Location({2, 21}, {2, 29})); } TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_types") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 4f8ed3eb..410a9859 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2379,6 +2379,28 @@ end CHECK("number" == toString(err->recommendedArgs[1].second)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type_2") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + // Make sure the error types are cloned to module interface + frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( +local function escape_fslash(pre) + return (#pre % 2 == 0 and '\\' or '') .. pre .. '.' +end +)"); + + LUAU_REQUIRE_ERRORS(result); + auto err = get(result.errors.back()); + LUAU_ASSERT(err); + CHECK("unknown" == toString(err->recommendedReturn)); + REQUIRE(err->recommendedArgs.size() == 1); + CHECK("a" == toString(err->recommendedArgs[0].second)); +} + TEST_CASE_FIXTURE(Fixture, "local_function_fwd_decl_doesnt_crash") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index a34af12d..3072169c 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -312,7 +312,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated } } -// FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { // In-place quantification causes these types to have the wrong types but only because of nasty interaction with prototyping. 
diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 4dbedd51..6f8b4f50 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -15,7 +15,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls); @@ -3257,7 +3256,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") { ScopedFastFlag sff[] = { - // {FFlag::LuauLowerBoundsCalculation, true}, {FFlag::DebugLuauSharedSelf, true}, }; diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 58ccea89..92f07c43 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -13,6 +13,7 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls); +LUAU_FASTFLAG(LuauUnifierRecursionOnRestart); struct TryUnifyFixture : Fixture { @@ -480,4 +481,34 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_two_unions_under_dcr_does_not_creat } } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_full_restart_recursion") +{ + ScopedFastFlag luauUnifierRecursionOnRestart{FFlag::LuauUnifierRecursionOnRestart, true}; + + CheckResult result = check(R"( +local A, B, C, D + +E = function(a, b) + local mt = getmetatable(b) + if mt.tm:bar(A) == nil and mt.tm:bar(B) == nil then end + if mt.foo == true then D(b, 3) end + mt.foo:call(false, b) +end + +A = function(a, b) + local mt = getmetatable(b) + if mt.foo == true then D(b, 3) end + C(mt, 3) +end + +B = function(a, b) + local mt = getmetatable(b) + if mt.foo == true then D(b, 3) end + C(mt, 3) +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index 
8ec70d11..f1924b1c 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -396,4 +396,15 @@ TEST_CASE_FIXTURE(Fixture, "lti_permit_explicit_never_annotation") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "cast_from_never_does_not_error") +{ + CheckResult result = check(R"( + local function f(x: never): number + return x :: number + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index b8fc882a..98d5b317 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -296,14 +296,26 @@ assert(math.max(ma, mc, mb) == 2) assert(math.max(ma, mb, mc) == 2) assert(math.max(ma, mb, mc, md) == 2) +local inf = math.huge * 2 +local nan = 0 / 0 + +assert(math.min(nan, 2) ~= math.min(nan, 2)) +assert(math.min(1, nan) == 1) +assert(math.max(nan, 2) ~= math.max(nan, 2)) +assert(math.max(1, nan) == 1) + +local function noinline(x, ...) local s, r = pcall(function(y) return y end, x) return r end + -- noise assert(math.noise(0.5) == 0) assert(math.noise(0.5, 0.5) == -0.25) assert(math.noise(0.5, 0.5, -0.5) == 0.125) assert(math.noise(455.7204209769105, 340.80410508750134, 121.80087666537628) == 0.5010709762573242) -local inf = math.huge * 2 -local nan = 0 / 0 +assert(math.noise(noinline(0.5)) == 0) +assert(math.noise(noinline(0.5), 0.5) == -0.25) +assert(math.noise(noinline(0.5), 0.5, -0.5) == 0.125) +assert(math.noise(noinline(455.7204209769105), 340.80410508750134, 121.80087666537628) == 0.5010709762573242) -- sign assert(math.sign(0) == 0) @@ -313,10 +325,12 @@ assert(math.sign(inf) == 1) assert(math.sign(-inf) == -1) assert(math.sign(nan) == 0) -assert(math.min(nan, 2) ~= math.min(nan, 2)) -assert(math.min(1, nan) == 1) -assert(math.max(nan, 2) ~= math.max(nan, 2)) -assert(math.max(1, nan) == 1) +assert(math.sign(noinline(0)) == 0) +assert(math.sign(noinline(42)) == 1) +assert(math.sign(noinline(-42)) == -1) 
+assert(math.sign(noinline(inf)) == 1) +assert(math.sign(noinline(-inf)) == -1) +assert(math.sign(noinline(nan)) == 0) -- clamp assert(math.clamp(-1, 0, 1) == 0) @@ -324,6 +338,11 @@ assert(math.clamp(0.5, 0, 1) == 0.5) assert(math.clamp(2, 0, 1) == 1) assert(math.clamp(4, 0, 0) == 0) +assert(math.clamp(noinline(-1), 0, 1) == 0) +assert(math.clamp(noinline(0.5), 0, 1) == 0.5) +assert(math.clamp(noinline(2), 0, 1) == 1) +assert(math.clamp(noinline(4), 0, 0) == 0) + -- round assert(math.round(0) == 0) assert(math.round(0.4) == 0) @@ -336,19 +355,58 @@ assert(math.round(math.huge) == math.huge) assert(math.round(0.49999999999999994) == 0) assert(math.round(-0.49999999999999994) == 0) +assert(math.round(noinline(0)) == 0) +assert(math.round(noinline(0.4)) == 0) +assert(math.round(noinline(0.5)) == 1) +assert(math.round(noinline(3.5)) == 4) +assert(math.round(noinline(-0.4)) == 0) +assert(math.round(noinline(-0.5)) == -1) +assert(math.round(noinline(-3.5)) == -4) +assert(math.round(noinline(math.huge)) == math.huge) +assert(math.round(noinline(0.49999999999999994)) == 0) +assert(math.round(noinline(-0.49999999999999994)) == 0) + -- fmod assert(math.fmod(3, 2) == 1) assert(math.fmod(-3, 2) == -1) assert(math.fmod(3, -2) == 1) assert(math.fmod(-3, -2) == -1) +assert(math.fmod(noinline(3), 2) == 1) +assert(math.fmod(noinline(-3), 2) == -1) +assert(math.fmod(noinline(3), -2) == 1) +assert(math.fmod(noinline(-3), -2) == -1) + -- pow assert(math.pow(2, 0) == 1) assert(math.pow(2, 2) == 4) assert(math.pow(4, 0.5) == 2) assert(math.pow(-2, 2) == 4) + +assert(math.pow(noinline(2), 0) == 1) +assert(math.pow(noinline(2), 2) == 4) +assert(math.pow(noinline(4), 0.5) == 2) +assert(math.pow(noinline(-2), 2) == 4) + assert(tostring(math.pow(-2, 0.5)) == "nan") +-- test that fastcalls return correct number of results +assert(select('#', math.floor(1.4)) == 1) +assert(select('#', math.ceil(1.6)) == 1) +assert(select('#', math.sqrt(9)) == 1) +assert(select('#', math.deg(9)) == 1) 
+assert(select('#', math.rad(9)) == 1) +assert(select('#', math.sin(1.5)) == 1) +assert(select('#', math.atan2(1.5, 0.5)) == 1) +assert(select('#', math.modf(1.5)) == 2) +assert(select('#', math.frexp(1.5)) == 2) + +-- test that fastcalls that return variadic results return them correctly in variadic position +assert(select(1, math.modf(1.5)) == 1) +assert(select(2, math.modf(1.5)) == 0.5) +assert(select(1, math.frexp(1.5)) == 0.75) +assert(select(2, math.frexp(1.5)) == 1) + -- most of the tests above go through fastcall path -- to make sure the basic implementations are also correct we test these functions with string->number coercions assert(math.abs("-4") == 4) @@ -393,21 +451,4 @@ assert(math.sign("-2") == -1) assert(math.sign("0") == 0) assert(math.round("1.8") == 2) --- test that fastcalls return correct number of results -assert(select('#', math.floor(1.4)) == 1) -assert(select('#', math.ceil(1.6)) == 1) -assert(select('#', math.sqrt(9)) == 1) -assert(select('#', math.deg(9)) == 1) -assert(select('#', math.rad(9)) == 1) -assert(select('#', math.sin(1.5)) == 1) -assert(select('#', math.atan2(1.5, 0.5)) == 1) -assert(select('#', math.modf(1.5)) == 2) -assert(select('#', math.frexp(1.5)) == 2) - --- test that fastcalls that return variadic results return them correctly in variadic position -assert(select(1, math.modf(1.5)) == 1) -assert(select(2, math.modf(1.5)) == 0.5) -assert(select(1, math.frexp(1.5)) == 0.75) -assert(select(2, math.frexp(1.5)) == 1) - return('OK')