From d0222bb55465e05f03d4e4a84f7eb53d49c0c3f4 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Mon, 2 Dec 2024 16:16:33 -0800 Subject: [PATCH 1/2] Sync to upstream/release/654 (#1552) # What's Changed * Support dead store elimination for `STORE_VECTOR` instruction * Fix parser hang when a separator is used between Luau class declaration properties * Provide properties and metatable for built-in vector type definition to fix type errors * Fix Fragment Autocomplete to ensure correct parentheses insertion behavior. * Add support for 'thread' and 'buffer' primitive types in user-defined type functions --------- Co-authored-by: Andy Friesen Co-authored-by: Hunter Goldstein Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Frontend.h | 1 + Analysis/include/Luau/TypeFunctionRuntime.h | 2 + Analysis/src/AutocompleteCore.cpp | 2 +- Analysis/src/BuiltinDefinitions.cpp | 29 ++- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 37 ++- Analysis/src/FragmentAutocomplete.cpp | 56 ++--- Analysis/src/Frontend.cpp | 26 ++ Analysis/src/TypeFunctionRuntime.cpp | 37 ++- Analysis/src/TypeFunctionRuntimeBuilder.cpp | 39 ++- Analysis/src/TypeInfer.cpp | 14 +- Ast/include/Luau/Lexer.h | 13 +- Ast/include/Luau/ParseOptions.h | 1 + Ast/src/Lexer.cpp | 18 +- Ast/src/Parser.cpp | 33 ++- CodeGen/include/Luau/IrData.h | 2 + CodeGen/src/IrLoweringA64.cpp | 10 +- CodeGen/src/IrLoweringX64.cpp | 6 +- CodeGen/src/OptimizeConstProp.cpp | 6 +- CodeGen/src/OptimizeDeadStore.cpp | 159 +++++++++++- tests/Fixture.h | 3 + tests/FragmentAutocomplete.test.cpp | 12 +- tests/Frontend.test.cpp | 75 +++++- tests/IrBuilder.test.cpp | 258 ++++++++++++++++++++ tests/Parser.test.cpp | 15 ++ tests/TypeFunction.user.test.cpp | 31 +++ tests/TypeInfer.modules.test.cpp | 3 - tests/TypeInfer.primitives.test.cpp | 32 +++ 27 files changed, 825 insertions(+), 95 deletions(-) diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 49d7a36d..272ee52a 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -194,6 +194,7 @@ struct Frontend ); std::optional getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); + std::vector getRequiredScripts(const ModuleName& name); private: ModulePtr check( diff --git a/Analysis/include/Luau/TypeFunctionRuntime.h b/Analysis/include/Luau/TypeFunctionRuntime.h index be091351..356d34a5 100644 --- a/Analysis/include/Luau/TypeFunctionRuntime.h +++ b/Analysis/include/Luau/TypeFunctionRuntime.h @@ -31,6 +31,8 @@ struct TypeFunctionPrimitiveType Boolean, Number, String, + Thread, + Buffer, }; Type type; diff --git a/Analysis/src/AutocompleteCore.cpp b/Analysis/src/AutocompleteCore.cpp index 06080b8c..3e231acf 100644 --- a/Analysis/src/AutocompleteCore.cpp +++ b/Analysis/src/AutocompleteCore.cpp @@ -1701,7 +1701,7 @@ AutocompleteResult autocomplete_( NotNull builtinTypes, TypeArena* typeArena, std::vector& ancestry, - Scope* globalScope, + Scope* globalScope, // [TODO] This is unused argument, do we really need this? const ScopePtr& scopeAtPosition, Position position, FileResolver* fileResolver, diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index c0a50ec3..6306b5b1 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -31,8 +31,8 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix) - -LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2); +LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2) +LUAU_FASTFLAG(LuauVectorDefinitionsExtra) namespace Luau { @@ -300,6 +300,31 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC addGlobalBinding(globals, "string", it->second.type(), "@luau"); + // Setup 'vector' metatable + if (FFlag::LuauVectorDefinitionsExtra) + { + if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end()) + { + TypeId vectorTy = it->second.type; + ClassType* vectorCls = getMutable(vectorTy); + + vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + TableType* metatableTy = Luau::getMutable(vectorCls->metatable); + + metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; + metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; + metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})}; + + std::initializer_list mulOverloads{ + makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}), + makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}), + }; + metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)}; + metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)}; + metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)}; + } + } + // next(t: Table, i: K?) -> (K?, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 209bcdbe..828fc7ed 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -4,6 +4,7 @@ LUAU_FASTFLAG(LuauMathMap) LUAU_FASTFLAGVARIABLE(LuauVectorDefinitions) +LUAU_FASTFLAGVARIABLE(LuauVectorDefinitionsExtra) namespace Luau { @@ -452,7 +453,7 @@ declare buffer: { )BUILTIN_SRC"; -static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC( +static const std::string kBuiltinDefinitionVectorSrc_DEPRECATED = R"BUILTIN_SRC( -- TODO: this will be replaced with a built-in primitive type declare class vector end @@ -478,12 +479,44 @@ declare vector: { )BUILTIN_SRC"; +static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC( + +-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties +declare class vector + x: number + y: number + z: number +end + +declare vector: { + create: @checked (x: number, y: number, z: number) -> vector, + magnitude: @checked (vec: vector) -> number, + normalize: @checked (vec: vector) -> vector, + cross: @checked (vec1: vector, vec2: vector) -> vector, + dot: @checked (vec1: vector, vec2: vector) -> number, + angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number, + floor: @checked (vec: vector) -> vector, + ceil: @checked (vec: vector) -> vector, + abs: @checked (vec: vector) -> vector, + sign: @checked (vec: vector) -> vector, + clamp: @checked (vec: vector, min: vector, max: vector) -> vector, + max: @checked (vector, ...vector) -> vector, + min: @checked (vector, ...vector) -> vector, + + zero: vector, + one: vector, +} + +)BUILTIN_SRC"; + std::string getBuiltinDefinitionSource() { std::string result = FFlag::LuauMathMap ? kBuiltinDefinitionLuaSrcChecked : kBuiltinDefinitionLuaSrcChecked_DEPRECATED; - if (FFlag::LuauVectorDefinitions) + if (FFlag::LuauVectorDefinitionsExtra) result += kBuiltinDefinitionVectorSrc; + else if (FFlag::LuauVectorDefinitions) + result += kBuiltinDefinitionVectorSrc_DEPRECATED; return result; } diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index 33989a2b..5819d309 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -121,19 +121,13 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* ro /** * Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that - * document and attempts to get the concrete text between those points. It returns a tuple of: + * document and attempts to get the concrete text between those points. It returns a pair of: * - start offset that represents an index in the source `char*` corresponding to startPos * - length, that represents how many more bytes to read to get to endPos. - * - cursorPos, that represents the position of the cursor relative to the start offset. - * Example - your document is "foo bar baz" and getDocumentOffsets is passed (0, 4), (0, 7), (0, 8). This function returns the tuple {3, 5, - * Position{0, 4}}, which corresponds to the string " bar " + * Example - your document is "foo bar baz" and getDocumentOffsets is passed (0, 4), (0, 8). This function returns the pair {3, 5} + * which corresponds to the string " bar " */ -std::tuple getDocumentOffsets( - const std::string_view& src, - const Position& startPos, - Position cursorPos, - const Position& endPos -) +std::pair getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) { size_t lineCount = 0; size_t colCount = 0; @@ -142,12 +136,8 @@ std::tuple getDocumentOffsets( size_t startOffset = 0; size_t endOffset = 0; bool foundStart = false; - bool foundCursor = false; bool foundEnd = false; - unsigned int colOffsetFromStart = 0; - unsigned int lineOffsetFromStart = 0; - for (char c : src) { if (foundStart && foundEnd) @@ -159,15 +149,11 @@ std::tuple getDocumentOffsets( startOffset = docOffset; } - if (cursorPos.line == lineCount && cursorPos.column == colCount) - { - foundCursor = true; - cursorPos = {lineOffsetFromStart, colOffsetFromStart}; - } - if (endPos.line == lineCount && endPos.column == colCount) { endOffset = docOffset; + while (endOffset < src.size() && src[endOffset] != '\n') + endOffset++; foundEnd = true; } @@ -180,18 +166,11 @@ std::tuple getDocumentOffsets( if (c == '\n') { - if (foundStart) - { - lineOffsetFromStart++; - colOffsetFromStart = 0; - } lineCount++; colCount = 0; } else { - if (foundStart) - colOffsetFromStart++; colCount++; } docOffset++; @@ -200,12 +179,9 @@ std::tuple getDocumentOffsets( if (foundStart && !foundEnd) endOffset = src.length(); - if (foundStart && !foundCursor) - cursorPos = {lineOffsetFromStart, colOffsetFromStart}; - size_t min = std::min(startOffset, endOffset); size_t len = std::max(startOffset, endOffset) - min; - return {min, len, cursorPos}; + return {min, len}; } ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement) @@ -232,10 +208,6 @@ FragmentParseResult parseFragment( ) { FragmentAutocompleteAncestryResult result = findAncestryForFragmentParse(srcModule.root, cursorPos); - ParseOptions opts; - opts.allowDeclarationSyntax = false; - opts.captureComments = true; - opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)}; AstStat* nearestStatement = result.nearestStatement; const Location& rootSpan = srcModule.root->location; @@ -260,15 +232,18 @@ FragmentParseResult parseFragment( else startPos = nearestStatement->location.begin; - auto [offsetStart, parseLength, cursorInFragment] = getDocumentOffsets(src, startPos, cursorPos, endPos); - - + auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos); const char* srcStart = src.data() + offsetStart; std::string_view dbg = src.substr(offsetStart, parseLength); const std::shared_ptr& nameTbl = srcModule.names; FragmentParseResult fragmentResult; fragmentResult.fragmentToParse = std::string(dbg.data(), parseLength); // For the duration of the incremental parse, we want to allow the name table to re-use duplicate names + + ParseOptions opts; + opts.allowDeclarationSyntax = false; + opts.captureComments = true; + opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos}; ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts); std::vector fabricatedAncestry = std::move(result.ancestry); @@ -276,7 +251,7 @@ FragmentParseResult parseFragment( // Get the ancestry for the fragment at the offset cursor position. // Consumers have the option to request with fragment end position, so we cannot just use the end position of our parse result as the // cursor position. Instead, use the cursor position calculated as an offset from our start position. - std::vector fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, cursorInFragment); + std::vector fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, cursorPos); fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end()); if (nearestStatement == nullptr) nearestStatement = p.root; @@ -524,6 +499,7 @@ FragmentAutocompleteResult fragmentAutocomplete( } auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src, fragmentEndPosition); + auto globalScope = (opts && opts->forAutocomplete) ? frontend.globalsForAutocomplete.globalScope.get() : frontend.globals.globalScope.get(); TypeArena arenaForFragmentAutocomplete; auto result = Luau::autocomplete_( @@ -531,7 +507,7 @@ FragmentAutocompleteResult fragmentAutocomplete( frontend.builtinTypes, &arenaForFragmentAutocomplete, tcResult.ancestry, - frontend.globals.globalScope.get(), + globalScope, tcResult.freshScope, cursorPosition, frontend.fileResolver, diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 396678463..053e99c2 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -743,6 +743,32 @@ std::optional Frontend::getCheckResult(const ModuleName& name, bool return checkResult; } +std::vector Frontend::getRequiredScripts(const ModuleName& name) +{ + RequireTraceResult require = requireTrace[name]; + if (isDirty(name)) + { + std::optional source = fileResolver->readSource(name); + if (!source) + { + return {}; + } + const Config& config = configResolver->getConfig(name); + ParseOptions opts = config.parseOptions; + opts.captureComments = true; + SourceModule result = parse(name, source->source, opts); + result.type = source->type; + require = traceRequires(fileResolver, result.root, name); + } + std::vector requiredModuleNames; + requiredModuleNames.reserve(require.requireList.size()); + for (const auto& [moduleName, _] : require.requireList) + { + requiredModuleNames.push_back(moduleName); + } + return requiredModuleNames; +} + bool Frontend::parseGraph( std::vector& buildQueue, const ModuleName& root, diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp index ad38a1f3..8a129462 100644 --- a/Analysis/src/TypeFunctionRuntime.cpp +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -16,6 +16,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixRegister) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunThreadBuffer) namespace Luau { @@ -133,6 +134,12 @@ static std::string getTag(lua_State* L, TypeFunctionTypeId ty) return "number"; else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::String) return "string"; + else if (auto s = get(ty); + FFlag::LuauUserTypeFunThreadBuffer && s && s->type == TypeFunctionPrimitiveType::Type::Thread) + return "thread"; + else if (auto s = get(ty); + FFlag::LuauUserTypeFunThreadBuffer && s && s->type == TypeFunctionPrimitiveType::Type::Buffer) + return "buffer"; else if (get(ty)) return "unknown"; else if (get(ty)) @@ -212,6 +219,22 @@ static int createString(lua_State* L) return 1; } +// Luau: `type.thread` +static int createThread(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Thread}); + + return 1; +} + +// Luau: `type.buffer` +static int createBuffer(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Buffer}); + + return 1; +} + // Luau: `type.singleton(value: string | boolean | nil) -> type` // Returns the type instance representing string or boolean singleton or nil static int createSingleton(lua_State* L) @@ -1394,6 +1417,8 @@ void registerTypesLibrary(lua_State* L) {"boolean", createBoolean}, {"number", createNumber}, {"string", createString}, + {FFlag::LuauUserTypeFunThreadBuffer ? "thread" : nullptr, FFlag::LuauUserTypeFunThreadBuffer ? createThread : nullptr}, + {FFlag::LuauUserTypeFunThreadBuffer ? "buffer" : nullptr, FFlag::LuauUserTypeFunThreadBuffer ? createBuffer : nullptr}, {nullptr, nullptr} }; @@ -2118,10 +2143,10 @@ private: { switch (p->type) { - case TypeFunctionPrimitiveType::Type::NilType: + case TypeFunctionPrimitiveType::NilType: target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); break; - case TypeFunctionPrimitiveType::Type::Boolean: + case TypeFunctionPrimitiveType::Boolean: target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); break; case TypeFunctionPrimitiveType::Number: @@ -2130,6 +2155,14 @@ private: case TypeFunctionPrimitiveType::String: target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); break; + case TypeFunctionPrimitiveType::Thread: + if (FFlag::LuauUserTypeFunThreadBuffer) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread)); + break; + case TypeFunctionPrimitiveType::Buffer: + if (FFlag::LuauUserTypeFunThreadBuffer) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer)); + break; default: break; } diff --git a/Analysis/src/TypeFunctionRuntimeBuilder.cpp b/Analysis/src/TypeFunctionRuntimeBuilder.cpp index ca0c1b72..a102e5da 100644 --- a/Analysis/src/TypeFunctionRuntimeBuilder.cpp +++ b/Analysis/src/TypeFunctionRuntimeBuilder.cpp @@ -21,6 +21,7 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixMetatable) +LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer) namespace Luau { @@ -147,10 +148,10 @@ private: { switch (p->type) { - case PrimitiveType::Type::NilType: + case PrimitiveType::NilType: target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); break; - case PrimitiveType::Type::Boolean: + case PrimitiveType::Boolean: target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); break; case PrimitiveType::Number: @@ -160,9 +161,29 @@ private: target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); break; case PrimitiveType::Thread: + if (FFlag::LuauUserTypeFunThreadBuffer) + { + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread)); + } + else + { + std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + break; + case PrimitiveType::Buffer: + if (FFlag::LuauUserTypeFunThreadBuffer) + { + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer)); + } + else + { + std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + break; case PrimitiveType::Function: case PrimitiveType::Table: - case PrimitiveType::Buffer: default: { std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str()); @@ -565,6 +586,18 @@ private: case TypeFunctionPrimitiveType::Type::String: target = state->ctx->builtins->stringType; break; + case TypeFunctionPrimitiveType::Type::Thread: + if (FFlag::LuauUserTypeFunThreadBuffer) + target = state->ctx->builtins->threadType; + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + break; + case TypeFunctionPrimitiveType::Type::Buffer: + if (FFlag::LuauUserTypeFunThreadBuffer) + target = state->ctx->builtins->bufferType; + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + break; default: state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 7ed3290b..911d4b5e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,7 +33,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauMetatableFollow) -LUAU_FASTFLAGVARIABLE(LuauRequireCyclesDontAlwaysReturnAny) namespace Luau { @@ -264,18 +263,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); - if (FFlag::LuauRequireCyclesDontAlwaysReturnAny) - { - moduleScope->returnType = freshTypePack(moduleScope); - } - else - { - if (module.cyclic) - moduleScope->returnType = addTypePack(TypePack{{anyType}, std::nullopt}); - else - moduleScope->returnType = freshTypePack(moduleScope); - } - + moduleScope->returnType = freshTypePack(moduleScope); moduleScope->varargPack = anyTypePack; currentModule->scopes.push_back(std::make_pair(module.root->location, moduleScope)); diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 6c8f21c1..f91f6115 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -153,7 +153,7 @@ private: class Lexer { public: - Lexer(const char* buffer, std::size_t bufferSize, AstNameTable& names); + Lexer(const char* buffer, std::size_t bufferSize, AstNameTable& names, Position startPosition = {0, 0}); void setSkipComments(bool skip); void setReadNames(bool read); @@ -230,6 +230,17 @@ private: bool skipComments; bool readNames; + // This offset represents a column offset to be applied to any positions created by the lexer until the next new line. + // For example: + // local x = 4 + // local y = 5 + // If we start lexing from the position of `l` in `local x = 4`, the line number will be 1, and the column will be 4 + // However, because the lexer calculates line offsets by 'index in source buffer where there is a newline', the column + // count will start at 0. For this reason, for just the first line, we'll need to store the offset. + unsigned int lexResumeOffset; + + + enum class BraceType { InterpolatedString, diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h index 804d16fc..ff727a0b 100644 --- a/Ast/include/Luau/ParseOptions.h +++ b/Ast/include/Luau/ParseOptions.h @@ -21,6 +21,7 @@ struct FragmentParseResumeSettings { DenseHashMap localMap{AstName()}; std::vector localStack; + Position resumePosition; }; struct ParseOptions diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 4fb9c936..03532e06 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -8,6 +8,7 @@ #include +LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition) namespace Luau { @@ -303,16 +304,20 @@ static char unescape(char ch) } } -Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names) +Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names, Position startPosition) : buffer(buffer) , bufferSize(bufferSize) , offset(0) - , line(0) + , line(FFlag::LexerResumesFromPosition ? startPosition.line : 0) , lineOffset(0) - , lexeme(Location(Position(0, 0), 0), Lexeme::Eof) + , lexeme( + (FFlag::LexerResumesFromPosition ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)), + Lexeme::Eof + ) , names(names) , skipComments(false) , readNames(true) + , lexResumeOffset(FFlag::LexerResumesFromPosition ? startPosition.column : 0) { } @@ -367,6 +372,7 @@ Lexeme Lexer::lookahead() Location currentPrevLocation = prevLocation; size_t currentBraceStackSize = braceStack.size(); BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back(); + unsigned int currentLexResumeOffset = lexResumeOffset; Lexeme result = next(); @@ -375,6 +381,7 @@ Lexeme Lexer::lookahead() lineOffset = currentLineOffset; lexeme = currentLexeme; prevLocation = currentPrevLocation; + lexResumeOffset = currentLexResumeOffset; if (braceStack.size() < currentBraceStackSize) braceStack.push_back(currentBraceType); @@ -407,7 +414,7 @@ char Lexer::peekch(unsigned int lookahead) const Position Lexer::position() const { - return Position(line, offset - lineOffset); + return Position(line, offset - lineOffset + (FFlag::LexerResumesFromPosition ? lexResumeOffset : 0)); } LUAU_FORCEINLINE @@ -426,6 +433,9 @@ void Lexer::consumeAny() { line++; lineOffset = offset + 1; + // every new line, we reset + if (FFlag::LexerResumesFromPosition) + lexResumeOffset = 0; } offset++; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 8d665688..1a533fa5 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -24,6 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing) LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck) LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes) +LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames) namespace Luau { @@ -179,7 +180,7 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, const ParseOptions& options) : options(options) - , lexer(buffer, bufferSize, names) + , lexer(buffer, bufferSize, names, options.parseFragment ? options.parseFragment->resumePosition : Position(0, 0)) , allocator(allocator) , recursionCounter(0) , endMismatchSuspect(Lexeme(Location(), Lexeme::Eof)) @@ -1165,12 +1166,30 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray propName = parseNameOpt("property name"); + + if (!propName) + break; + + expectAndConsume(':', "property type annotation"); + AstType* propType = parseType(); + props.push_back( + AstDeclaredClassProp{propName->name, propName->location, propType, false, Location(propStart, lexer.previousLocation())} + ); + } + else + { + Location propStart = lexer.current().location; + Name propName = parseName("property name"); + expectAndConsume(':', "property type annotation"); + AstType* propType = parseType(); + props.push_back( + AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())} + ); + } } } diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index b603af9e..779fe012 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -114,10 +114,12 @@ enum class IrCmd : uint8_t STORE_INT, // Store a vector into TValue + // When optional 'E' tag is present, it is written out to the TValue as well // A: Rn // B: double (x) // C: double (y) // D: double (z) + // E: tag (optional) STORE_VECTOR, // Store a TValue into memory diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 91fbf0bf..c7fcac27 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -11,7 +11,8 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAG(LuauVectorLibNativeDot); +LUAU_FASTFLAG(LuauVectorLibNativeDot) +LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim) namespace Luau { @@ -497,6 +498,13 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.str(temp4, AddressA64(addr.base, addr.data + 4)); build.fcvt(temp4, temp3); build.str(temp4, AddressA64(addr.base, addr.data + 8)); + + if (FFlag::LuauCodeGenVectorDeadStoreElim && inst.e.kind != IrOpKind::None) + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, tagOp(inst.e)); + build.str(temp, tempAddr(inst.a, offsetof(TValue, tt))); + } break; } case IrCmd::STORE_TVALUE: diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 796894ed..814c6d8c 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -15,7 +15,8 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAG(LuauVectorLibNativeDot); +LUAU_FASTFLAG(LuauVectorLibNativeDot) +LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim) namespace Luau { @@ -297,6 +298,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 0), inst.b); storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 1), inst.c); storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 2), inst.d); + + if (FFlag::LuauCodeGenVectorDeadStoreElim && inst.e.kind != IrOpKind::None) + build.mov(luauRegTag(vmRegOp(inst.a)), tagOp(inst.e)); break; case IrCmd::STORE_TVALUE: { diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index f93354a3..1e532280 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -364,7 +364,7 @@ struct ConstPropState return; // To avoid captured register invalidation tracking in lowering later, values from loads from captured registers are not propagated - // This prevents the case where load value location is linked to memory in case of a spill and is then cloberred in a user call + // This prevents the case where load value location is linked to memory in case of a spill and is then clobbered in a user call if (function.cfg.captured.regs.test(vmRegOp(loadInst.a))) return; @@ -378,7 +378,7 @@ struct ConstPropState if (!instLink.contains(*prevIdx)) createRegLink(*prevIdx, loadInst.a); - // Substitute load instructon with the previous value + // Substitute load instruction with the previous value substitute(function, loadInst, IrOp{IrOpKind::Inst, *prevIdx}); return; } @@ -401,7 +401,7 @@ struct ConstPropState return; // To avoid captured register invalidation tracking in lowering later, values from stores into captured registers are not propagated - // This prevents the case where store creates an alternative value location in case of a spill and is then cloberred in a user call + // This prevents the case where store creates an alternative value location in case of a spill and is then clobbered in a user call if (function.cfg.captured.regs.test(vmRegOp(storeInst.a))) return; diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index b4b4c7b5..8362cf2b 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -9,6 +9,8 @@ #include "lobject.h" +LUAU_FASTFLAGVARIABLE(LuauCodeGenVectorDeadStoreElim) + // TODO: optimization can be improved by knowing which registers are live in at each VM exit namespace Luau @@ -324,8 +326,29 @@ static bool tryReplaceTagWithFullStore( // And value store has to follow, as the pre-DSO code would not allow GC to observe an incomplete stack variable if (tag != LUA_TNIL && regInfo.valueInstIdx != ~0u) { - IrOp prevValueOp = function.instructions[regInfo.valueInstIdx].b; - replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp}); + if (FFlag::LuauCodeGenVectorDeadStoreElim) + { + IrInst& prevValueInst = function.instructions[regInfo.valueInstIdx]; + + if (prevValueInst.cmd == IrCmd::STORE_VECTOR) + { + CODEGEN_ASSERT(prevValueInst.e.kind == IrOpKind::None); + IrOp prevValueX = prevValueInst.b; + IrOp prevValueY = prevValueInst.c; + IrOp prevValueZ = prevValueInst.d; + replace(function, block, instIndex, IrInst{IrCmd::STORE_VECTOR, targetOp, prevValueX, prevValueY, prevValueZ, tagOp}); + } + else + { + IrOp prevValueOp = prevValueInst.b; + replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp}); + } + } + else + { + IrOp prevValueOp = function.instructions[regInfo.valueInstIdx].b; + replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp}); + } } state.killTagStore(regInfo); @@ -356,6 +379,25 @@ static bool tryReplaceTagWithFullStore( state.killTValueStore(regInfo); + regInfo.tvalueInstIdx = instIndex; + regInfo.maybeGco = isGCO(tag); + regInfo.knownTag = tag; + state.hasGcoToClear |= regInfo.maybeGco; + return true; + } + else if (FFlag::LuauCodeGenVectorDeadStoreElim && prev.cmd == IrCmd::STORE_VECTOR) + { + // If the 'nil' is stored, we keep 'STORE_TAG Rn, tnil' as it writes the 'full' TValue + if (tag != LUA_TNIL) + { + IrOp prevValueX = prev.b; + IrOp prevValueY = prev.c; + IrOp prevValueZ = prev.d; + replace(function, block, instIndex, IrInst{IrCmd::STORE_VECTOR, targetOp, prevValueX, prevValueY, prevValueZ, tagOp}); + } + + state.killTValueStore(regInfo); + regInfo.tvalueInstIdx = instIndex; regInfo.maybeGco = isGCO(tag); regInfo.knownTag = tag; @@ -410,6 +452,94 @@ static bool tryReplaceValueWithFullStore( state.killTValueStore(regInfo); + regInfo.tvalueInstIdx = instIndex; + return true; + } + else if (FFlag::LuauCodeGenVectorDeadStoreElim && prev.cmd == IrCmd::STORE_VECTOR) + { + IrOp prevTagOp = prev.e; + CODEGEN_ASSERT(prevTagOp.kind != IrOpKind::None); + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, prevTagOp, valueOp}); + + state.killTValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + return true; + } + } + + return false; +} + +static bool tryReplaceVectorValueWithFullStore( + RemoveDeadStoreState& state, + IrBuilder& build, + IrFunction& function, + IrBlock& block, + uint32_t instIndex, + StoreRegInfo& regInfo +) +{ + CODEGEN_ASSERT(FFlag::LuauCodeGenVectorDeadStoreElim); + + // If the tag+value pair is established, we can mark both as dead and use a single split TValue store + if (regInfo.tagInstIdx != ~0u && regInfo.valueInstIdx != ~0u) + { + IrOp prevTagOp = function.instructions[regInfo.tagInstIdx].b; + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + + IrInst& storeInst = function.instructions[instIndex]; + CODEGEN_ASSERT(storeInst.cmd == IrCmd::STORE_VECTOR); + replace(function, storeInst.e, prevTagOp); + + state.killTagStore(regInfo); + state.killValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + return true; + } + + // We can also replace a dead split TValue store with a new one, while keeping the value the same + if (regInfo.tvalueInstIdx != ~0u) + { + IrInst& prev = function.instructions[regInfo.tvalueInstIdx]; + + if (prev.cmd == IrCmd::STORE_SPLIT_TVALUE) + { + IrOp prevTagOp = prev.b; + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + CODEGEN_ASSERT(prev.d.kind == IrOpKind::None); + + IrInst& storeInst = function.instructions[instIndex]; + CODEGEN_ASSERT(storeInst.cmd == IrCmd::STORE_VECTOR); + replace(function, storeInst.e, prevTagOp); + + state.killTValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + return true; + } + else if (prev.cmd == IrCmd::STORE_VECTOR) + { + IrOp prevTagOp = prev.e; + CODEGEN_ASSERT(prevTagOp.kind != IrOpKind::None); + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + + IrInst& storeInst = function.instructions[instIndex]; + CODEGEN_ASSERT(storeInst.cmd == IrCmd::STORE_VECTOR); + replace(function, storeInst.e, prevTagOp); + + state.killTValueStore(regInfo); + regInfo.tvalueInstIdx = instIndex; return true; } @@ -499,10 +629,31 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, } break; case IrCmd::STORE_VECTOR: - // Partial vector value store cannot be combined into a STORE_SPLIT_TVALUE, so we skip dead store optimization for it if (inst.a.kind == IrOpKind::VmReg) { - state.useReg(vmRegOp(inst.a)); + if (FFlag::LuauCodeGenVectorDeadStoreElim) + { + int reg = vmRegOp(inst.a); + + if (function.cfg.captured.regs.test(reg)) + return; + + StoreRegInfo& regInfo = state.info[reg]; + + if (tryReplaceVectorValueWithFullStore(state, build, function, block, index, regInfo)) + break; + + // Partial value store can be removed by a new one if the tag is known + if (regInfo.knownTag != kUnknownTag) + state.killValueStore(regInfo); + + regInfo.valueInstIdx = index; + regInfo.maybeGco = false; + } + else + { + state.useReg(vmRegOp(inst.a)); + } } break; case IrCmd::STORE_TVALUE: diff --git a/tests/Fixture.h b/tests/Fixture.h index e273a642..ba038403 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -27,6 +27,7 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauForceAllNewSolverTests) +LUAU_FASTFLAG(LuauVectorDefinitionsExtra) #define DOES_NOT_PASS_NEW_SOLVER_GUARD_IMPL(line) \ ScopedFastFlag sff_##line{FFlag::LuauSolverV2, FFlag::DebugLuauForceAllNewSolverTests}; @@ -113,6 +114,8 @@ struct Fixture // In that case, flag can be forced to 'true' using the example below: // ScopedFastFlag sff_LuauExampleFlagDefinition{FFlag::LuauExampleFlagDefinition, true}; + ScopedFastFlag sff_LuauVectorDefinitionsExtra{FFlag::LuauVectorDefinitionsExtra, true}; + // Arena freezing marks the `TypeArena`'s underlying memory as read-only, raising an access violation whenever you mutate it. // This is useful for tracking down violations of Luau's memory model. ScopedFastFlag sff_DebugLuauFreezeArena{FFlag::DebugLuauFreezeArena, true}; diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index f0e2ae91..42f2bf09 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -24,6 +24,7 @@ LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) LUAU_FASTFLAG(LuauSymbolEquality); LUAU_FASTFLAG(LuauStoreSolverTypeOnModule); +LUAU_FASTFLAG(LexerResumesFromPosition) static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) { @@ -46,11 +47,12 @@ static FrontendOptions getOptions() template struct FragmentAutocompleteFixtureImpl : BaseType { - ScopedFastFlag sffs[4] = { + ScopedFastFlag sffs[5] = { {FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true}, {FFlag::LuauStoreSolverTypeOnModule, true}, {FFlag::LuauSymbolEquality, true}, + {FFlag::LexerResumesFromPosition, true} }; FragmentAutocompleteFixtureImpl() @@ -288,6 +290,7 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") check("local a ="); auto fragment = parseFragment("local a =", Position(0, 10)); CHECK_EQ("local a =", fragment.fragmentToParse); + CHECK_EQ(Location{Position{0, 0}, 9}, fragment.root->location); } TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") @@ -334,6 +337,8 @@ local z = x + y Position{3, 15} ); + CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment.root->location); + CHECK_EQ("local y = 5\nlocal z = x + y", fragment.fragmentToParse); CHECK_EQ(5, fragment.ancestry.size()); REQUIRE(fragment.root); @@ -380,6 +385,7 @@ local y = 5 CHECK_EQ("local z = x + y", fragment.fragmentToParse); CHECK_EQ(5, fragment.ancestry.size()); REQUIRE(fragment.root); + CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment.root->location); CHECK_EQ(1, fragment.root->body.size); auto stat = fragment.root->body.data[0]->as(); REQUIRE(stat); @@ -421,7 +427,7 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope") Position{6, 0} ); - CHECK_EQ("\n", fragment.fragmentToParse); + CHECK_EQ("\n ", fragment.fragmentToParse); } TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_single_line_fragment_override") @@ -465,7 +471,7 @@ abc("bar") Position{1, 9} ); - CHECK_EQ("function abc(foo: string) end\nabc(\"foo\"", stringFragment.fragmentToParse); + CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment.fragmentToParse); CHECK(stringFragment.nearestStatement->is()); CHECK_GE(stringFragment.ancestry.size(), 1); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index d7028fd7..bfa69fe4 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -13,7 +13,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAG(LuauRequireCyclesDontAlwaysReturnAny); LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTFLAG(DebugLuauMagicTypes); @@ -313,11 +312,9 @@ TEST_CASE_FIXTURE(FrontendFixture, "nocheck_cycle_used_by_checked") REQUIRE(bool(cExports)); if (FFlag::LuauSolverV2) - CHECK_EQ("{ a: { hello: any }, b: { hello: any } }", toString(*cExports)); - else if (FFlag::LuauRequireCyclesDontAlwaysReturnAny) - CHECK("{| a: any, b: any |}, {| a: {| hello: any |}, b: {| hello: any |} |}" == toString(*cExports)); + CHECK("{ a: { hello: any }, b: { hello: any } }" == toString(*cExports)); else - CHECK_EQ("{| a: any, b: any |}", toString(*cExports)); + CHECK("{| a: {| hello: any |}, b: {| hello: any |} |}" == toString(*cExports)); } TEST_CASE_FIXTURE(FrontendFixture, "cycle_detection_disabled_in_nocheck") @@ -1457,4 +1454,72 @@ TEST_CASE_FIXTURE(Fixture, "exported_tables_have_position_metadata") CHECK(Location{Position{1, 17}, Position{1, 20}} == prop.location); } +TEST_CASE_FIXTURE(FrontendFixture, "get_required_scripts") +{ + fileResolver.source["game/workspace/MyScript"] = R"( + local MyModuleScript = require(game.workspace.MyModuleScript) + local MyModuleScript2 = require(game.workspace.MyModuleScript2) + MyModuleScript.myPrint() + )"; + + fileResolver.source["game/workspace/MyModuleScript"] = R"( + local module = {} + function module.myPrint() + print("Hello World") + end + return module + )"; + + fileResolver.source["game/workspace/MyModuleScript2"] = R"( + local module = {} + return module + )"; + + // isDirty(name) is true, getRequiredScripts should not hit the cache. + frontend.markDirty("game/workspace/MyScript"); + std::vector requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 2); + CHECK(requiredScripts[0] == "game/workspace/MyModuleScript"); + CHECK(requiredScripts[1] == "game/workspace/MyModuleScript2"); + + // Call frontend.check first, then getRequiredScripts should hit the cache because isDirty(name) is false. + frontend.check("game/workspace/MyScript"); + requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 2); + CHECK(requiredScripts[0] == "game/workspace/MyModuleScript"); + CHECK(requiredScripts[1] == "game/workspace/MyModuleScript2"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "get_required_scripts_dirty") +{ + fileResolver.source["game/workspace/MyScript"] = R"( + print("Hello World") + )"; + + fileResolver.source["game/workspace/MyModuleScript"] = R"( + local module = {} + function module.myPrint() + print("Hello World") + end + return module + )"; + + frontend.check("game/workspace/MyScript"); + std::vector requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 0); + + fileResolver.source["game/workspace/MyScript"] = R"( + local MyModuleScript = require(game.workspace.MyModuleScript) + MyModuleScript.myPrint() + )"; + + requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 0); + + frontend.markDirty("game/workspace/MyScript"); + requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 1); + CHECK(requiredScripts[0] == "game/workspace/MyModuleScript"); +} + TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index d02fd9f1..ba4e7f04 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -13,6 +13,8 @@ #include LUAU_FASTFLAG(DebugLuauAbortingChecks) +LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim) +LUAU_FASTFLAG(LuauCodeGenArithOpt) using namespace Luau::CodeGen; @@ -119,6 +121,7 @@ public: static const int tnil = 0; static const int tboolean = 1; static const int tnumber = 3; + static const int tvector = 4; static const int tstring = 5; static const int ttable = 6; static const int tfunction = 7; @@ -1720,6 +1723,57 @@ bb_fallback_1: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericSimplifications") +{ + ScopedFastFlag luauCodeGenArithOpt{FFlag::LuauCodeGenArithOpt, true}; + + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp value = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::SUB_NUM, value, build.constDouble(0.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.inst(IrCmd::ADD_NUM, value, build.constDouble(-0.0))); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.inst(IrCmd::MUL_NUM, value, build.constDouble(1.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(4), build.inst(IrCmd::MUL_NUM, value, build.constDouble(2.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::MUL_NUM, value, build.constDouble(-1.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(6), build.inst(IrCmd::MUL_NUM, value, build.constDouble(3.0))); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), build.inst(IrCmd::DIV_NUM, value, build.constDouble(1.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(8), build.inst(IrCmd::DIV_NUM, value, build.constDouble(-1.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(9), build.inst(IrCmd::DIV_NUM, value, build.constDouble(32.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(10), build.inst(IrCmd::DIV_NUM, value, build.constDouble(6.0))); + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(9)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + STORE_DOUBLE R1, %0 + STORE_DOUBLE R2, %0 + STORE_DOUBLE R3, %0 + %7 = ADD_NUM %0, %0 + STORE_DOUBLE R4, %7 + %9 = UNM_NUM %0 + STORE_DOUBLE R5, %9 + %11 = MUL_NUM %0, 3 + STORE_DOUBLE R6, %11 + STORE_DOUBLE R7, %0 + %15 = UNM_NUM %0 + STORE_DOUBLE R8, %15 + %17 = MUL_NUM %0, 0.03125 + STORE_DOUBLE R9, %17 + %19 = DIV_NUM %0, 6 + STORE_DOUBLE R10, %19 + RETURN R1, 9i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); @@ -4416,6 +4470,210 @@ bb_0: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverNumber") +{ + ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 1, 2, 4, tvector + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverVector") +{ + ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(4.0), build.constDouble(2.0), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 1, 2, 4, tvector + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumberOverVector") +{ + ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_SPLIT_TVALUE R0, tnumber, 2 + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumberOverNil") +{ + ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnil)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_SPLIT_TVALUE R0, tnumber, 2 + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverNil") +{ + ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnil)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 1, 2, 4, tvector + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumberOverCombinedVector") +{ + ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(3.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_SPLIT_TVALUE R0, tnumber, 3 + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverCombinedVector") +{ + ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(8.0), build.constDouble(16.0), build.constDouble(32.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 8, 16, 32, tvector + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverCombinedNumber") +{ + ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(8.0), build.constDouble(16.0), build.constDouble(32.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 8, 16, 32, tvector + RETURN R0, 1i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Dump"); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index b0e0caa0..b35466cb 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -20,6 +20,7 @@ LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) LUAU_FASTFLAG(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes) +LUAU_FASTFLAG(LuauErrorRecoveryForClassNames) namespace { @@ -2085,6 +2086,20 @@ TEST_CASE_FIXTURE(Fixture, "variadic_definition_parsing") matchParseError("declare class Foo function a(self, ...) end", "All declaration parameters aside from 'self' must be annotated"); } +TEST_CASE_FIXTURE(Fixture, "missing_declaration_prop") +{ + ScopedFastFlag luauErrorRecoveryForClassNames{FFlag::LuauErrorRecoveryForClassNames, true}; + + matchParseError( + R"( + declare class Foo + a: number, + end + )", + "Expected identifier when parsing property name, got ','" + ); +} + TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") { ParseResult result = parseEx(R"( diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp index 909017cc..145772fd 100644 --- a/tests/TypeFunction.user.test.cpp +++ b/tests/TypeFunction.user.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState) LUAU_FASTFLAG(LuauUserTypeFunNonstrict) LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal) LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) +LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer) TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); @@ -235,6 +236,36 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_methods_work") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "thread_and_buffer_types") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + ScopedFastFlag luauUserTypeFunThreadBuffer{FFlag::LuauUserTypeFunThreadBuffer, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + type function work_with_thread(x) + if x:is("thread") then + return types.thread + end + return types.string + end + type X = thread + local function ok(idx: work_with_thread): thread return idx end + )")); + + LUAU_REQUIRE_NO_ERRORS(check(R"( + type function work_with_buffer(x) + if x:is("buffer") then + return types.buffer + end + return types.string + end + type X = buffer + local function ok(idx: work_with_buffer): buffer return idx end + )")); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 5d5df24a..a9109e1d 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -11,7 +11,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauRequireCyclesDontAlwaysReturnAny) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) @@ -756,8 +755,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "spooky_blocked_type_laundered_by_bound_type" TEST_CASE_FIXTURE(BuiltinsFixture, "cycles_dont_make_everything_any") { - ScopedFastFlag sff{FFlag::LuauRequireCyclesDontAlwaysReturnAny, true}; - fileResolver.source["game/A"] = R"( --!strict local module = {} diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index c3cce9df..0c14a448 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -12,6 +12,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauVectorDefinitions) + using namespace Luau; TEST_SUITE_BEGIN("TypeInferPrimitives"); @@ -120,4 +122,34 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "property_of_buffers") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "properties_of_vectors") +{ + CheckResult result = check(R"( + local a = vector.create(1, 2, 3) + local b = vector.create(4, 5, 6) + + local t1 = { + a + b, + a - b, + a * 3, + a * b, + 3 * b, + a / 3, + a / b, + 3 / b, + a // 4, + a // b, + 4 // b, + -a, + } + local t2 = { + a.x, + a.y, + a.z, + } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); From 8b8118b027b10585af90273b649c1de286c05497 Mon Sep 17 00:00:00 2001 From: jkelaty-rbx <78873527+jkelaty-rbx@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:04:57 -0800 Subject: [PATCH 2/2] Convert Luau heap dumps to Chrome heap snapshots (#1554) Adds a script for (approximately) converting Luau heap dumps to Chrome heap snapshots. Useful for visually inspecting a heap dump within Chrome's UI. --- tools/heapsnapshot.py | 221 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 tools/heapsnapshot.py diff --git a/tools/heapsnapshot.py b/tools/heapsnapshot.py new file mode 100644 index 00000000..d3c0c92d --- /dev/null +++ b/tools/heapsnapshot.py @@ -0,0 +1,221 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a Luau heap dump, this tool generates a heap snapshot which can be imported by Chrome's DevTools Memory panel +# To generate a snapshot, use luaC_dump, ideally preceded by luaC_fullgc +# To import in Chrome, ensure the snapshot has the .heapsnapshot extension and go to: Inspect -> Memory -> Load Profile +# A reference for the heap snapshot schema can be found here: https://learn.microsoft.com/en-us/microsoft-edge/devtools-guide-chromium/memory-problems/heap-snapshot-schema + +# Usage: python3 heapsnapshot.py luauDump.json heapSnapshot.heapsnapshot + +import json +import sys + +# Header describing the snapshot format, copied from a real Chrome heap snapshot +snapshotMeta = { + "node_fields": ["type", "name", "id", "self_size", "edge_count", "trace_node_id", "detachedness"], + "node_types": [ + ["hidden", "array", "string", "object", "code", "closure", "regexp", "number", "native", "synthetic", "concatenated string", "sliced string", "symbol", "bigint", "object shape"], + "string", "number", "number", "number", "number", "number" + ], + "edge_fields": ["type", "name_or_index", "to_node"], + "edge_types": [ + ["context", "element", "property", "internal", "hidden", "shortcut", "weak"], + "string_or_number", "node" + ], + "trace_function_info_fields": ["function_id", "name", "script_name", "script_id", "line", "column"], + "trace_node_fields": ["id", "function_info_index", "count", "size", "children"], + "sample_fields": ["timestamp_us", "last_assigned_id"], + "location_fields": ["object_index", "script_id", "line", "column"], +} + +# These indices refer to the index in the snapshot's metadata header +nodeTypeToMetaIndex = {type: i for i, type in enumerate(snapshotMeta["node_types"][0])} +edgeTypeToMetaIndex = {type: i for i, type in enumerate(snapshotMeta["edge_types"][0])} + +nodeFieldCount = len(snapshotMeta["node_fields"]) +edgeFieldCount = len(snapshotMeta["edge_fields"]) + + +def readAddresses(data): + # Ordered list of addresses to ensure the registry is the first node, and also so we can process nodes in index order + addresses = [] + addressToNodeIndex = {} + + def addAddress(address): + assert address not in addressToNodeIndex, f"Address already exists in the snapshot: '{address}'" + addresses.append(address) + addressToNodeIndex[address] = len(addresses) - 1 + + # The registry is a special case that needs to be either the first or last node to ensure gc "distances" are calculated correctly + registryAddress = data["roots"]["registry"] + addAddress(registryAddress) + + for address, obj in data["objects"].items(): + if address == registryAddress: + continue + addAddress(address) + + return addresses, addressToNodeIndex + + +def convertToSnapshot(data): + addresses, addressToNodeIndex = readAddresses(data) + + # Some notable idiosyncrasies with the heap snapshot format: + # 1. The snapshot format contains a flat array of nodes and edges. Oddly, edges must reference the "absolute" index of a node's first element after flattening. + # 2. A node's outgoing edges are implicitly represented by a contiguous block of edges in the edges array which correspond to the node's position + # in the nodes array and its edge count. So if the first node has 3 edges, the first 3 edges in the edges array are its edges, and so on. + + nodes = [] + edges = [] + strings = [] + + stringToSnapshotIndex = {} + + def getUniqueId(address): + # TODO: we should hash this to an int32 instead of using the address directly + # Addresses are hexadecimal strings + return int(address, 16) + + def addNode(node): + assert len(node) == nodeFieldCount, f"Expected {nodeFieldCount} fields, got {len(node)}" + nodes.append(node) + + def addEdge(edge): + assert len(edge) == edgeFieldCount, f"Expected {edgeFieldCount} fields, got {len(edge)}" + edges.append(edge) + + def getStringSnapshotIndex(string): + assert isinstance(string, str), f"'{string}' is not of type string" + if string not in stringToSnapshotIndex: + strings.append(string) + stringToSnapshotIndex[string] = len(strings) - 1 + return stringToSnapshotIndex[string] + + def getNodeSnapshotIndex(address): + # This is the index of the first element of the node in the flattened nodes array + return addressToNodeIndex[address] * nodeFieldCount + + for address in addresses: + obj = data["objects"][address] + edgeCount = 0 + + if obj["type"] == "table": + # TODO: support weak references + name = f"Registry ({address})" if address == data["roots"]["registry"] else f"Luau table ({address})" + if "pairs" in obj: + for i in range(0, len(obj["pairs"]), 2): + key = obj["pairs"][i] + value = obj["pairs"][i + 1] + if key is None and value is None: + # Both the key and value are value types, nothing meaningful to add here + continue + elif key is None: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["property"], getStringSnapshotIndex("(Luau table key value type)"), getNodeSnapshotIndex(value)]) + elif value is None: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'Luau table key ref: {data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(key)]) + elif data["objects"][key]["type"] == "string": + edgeCount += 2 + # This is a special case where the key is a string, so we can use it as the edge name + addEdge([edgeTypeToMetaIndex["property"], getStringSnapshotIndex(data["objects"][key]["data"]), getNodeSnapshotIndex(value)]) + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'Luau table key ref: {data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(key)]) + else: + edgeCount += 2 + addEdge([edgeTypeToMetaIndex["property"], getStringSnapshotIndex(f'{data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(value)]) + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'Luau table key ref: {data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(key)]) + if "array" in obj: + for i, element in enumerate(obj["array"]): + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["element"], i, getNodeSnapshotIndex(element)]) + if "metatable" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'metatable ({obj["metatable"]})'), getNodeSnapshotIndex(obj["metatable"])]) + # TODO: consider distinguishing "object" and "array" node types + addNode([nodeTypeToMetaIndex["object"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "thread": + name = f'Luau thread: {obj["source"]}:{obj["line"]} ({address})' if "source" in obj else f"Luau thread ({address})" + if address == data["roots"]["mainthread"]: + name += " (main thread)" + if "env" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'env ({obj["env"]})'), getNodeSnapshotIndex(obj["env"])]) + if "stack" in obj: + for i, frame in enumerate(obj["stack"]): + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f"callstack[{i}]"), getNodeSnapshotIndex(frame)]) + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "function": + name = f'Luau function: {obj["name"]} ({address})' if "name" in obj else f"Luau anonymous function ({address})" + if "env" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'env ({obj["env"]})'), getNodeSnapshotIndex(obj["env"])]) + if "proto" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'proto ({obj["proto"]})'), getNodeSnapshotIndex(obj["proto"])]) + if "upvalues" in obj: + for i, upvalue in enumerate(obj["upvalues"]): + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f"up value ({upvalue})"), getNodeSnapshotIndex(upvalue)]) + addNode([nodeTypeToMetaIndex["closure"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "upvalue": + if "object" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'upvalue object ({obj["object"]})'), getNodeSnapshotIndex(obj["object"])]) + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(f"Luau upvalue ({address})"), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "userdata": + if "metatable" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'metatable ({obj["metatable"]})'), getNodeSnapshotIndex(obj["metatable"])]) + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(f"Luau userdata ({address})"), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "proto": + name = f'Luau proto: {obj["source"]}:{obj["line"]} ({address})' if "source" in obj else f"Luau proto ({address})" + if "constants" in obj: + for constant in obj["constants"]: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(constant), getNodeSnapshotIndex(constant)]) + if "protos" in obj: + for proto in obj["protos"]: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(proto), getNodeSnapshotIndex(proto)]) + addNode([nodeTypeToMetaIndex["code"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "string": + addNode([nodeTypeToMetaIndex["string"], getStringSnapshotIndex(obj["data"]), getUniqueId(address), obj["size"], 0, 0, 0]) + elif obj["type"] == "buffer": + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(f'buffer ({address})'), getUniqueId(address), obj["size"], 0, 0, 0]) + else: + raise Exception(f"Unknown object type: '{obj['type']}'") + + return { + "snapshot": { + "meta": snapshotMeta, + "node_count": len(nodes), + "edge_count": len(edges), + "trace_function_count": 0, + }, + # flatten the nodes and edges arrays + "nodes": [field for node in nodes for field in node], + "edges": [field for edge in edges for field in edge], + "trace_function_infos": [], + "trace_tree": [], + "samples": [], + "locations": [], + "strings": strings, + } + + +if __name__ == "__main__": + luauDump = sys.argv[1] + heapSnapshot = sys.argv[2] + + with open(luauDump, "r") as file: + dump = json.load(file) + + snapshot = convertToSnapshot(dump) + + with open(heapSnapshot, "w") as file: + json.dump(snapshot, file) + + print(f"Heap snapshot written to: '{heapSnapshot}'")