diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 6f39e3fd..08fc3338 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -134,6 +134,10 @@ public: { return visit((class AstExpr*)node); } + virtual bool visit(class AstExprInterpString* node) + { + return visit((class AstExpr*)node); + } virtual bool visit(class AstExprError* node) { return visit((class AstExpr*)node); @@ -722,6 +726,19 @@ public: AstExpr* falseExpr; }; +class AstExprInterpString : public AstExpr +{ +public: + LUAU_RTTI(AstExprInterpString) + + AstExprInterpString(const Location& location, const AstArray>& strings, const AstArray& expressions); + + void visit(AstVisitor* visitor) override; + + AstArray> strings; + AstArray expressions; +}; + class AstStatBlock : public AstStat { public: diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 4f3dbbd5..e72a6d44 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -61,6 +61,10 @@ struct Lexeme SkinnyArrow, DoubleColon, + InterpStringBegin, + InterpStringMid, + InterpStringEnd, + AddAssign, SubAssign, MulAssign, @@ -166,6 +170,11 @@ public: void setSkipComments(bool skip); void setReadNames(bool read); + void setReadAsInterpolatedStringExpression(bool read); + + void incrementInterpolatedStringDepth(); + void decrementInterpolatedStringDepth(); + const Lexeme nextInterpolatedString(); const Location& previousLocation() const { @@ -208,6 +217,10 @@ private: Lexeme readLongString(const Position& start, int sep, Lexeme::Type ok, Lexeme::Type broken); Lexeme readQuotedString(); + Lexeme readInterpolatedStringBegin(); + + void readBackslashInString(); + std::pair readName(); Lexeme readNumber(const Position& start, unsigned int startOffset); @@ -231,6 +244,9 @@ private: bool skipComments; bool readNames; + bool readAsInterpolatedStringExpression; + + unsigned int interpolatedStringDepth; }; inline bool isSpace(char ch) diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 4b5ae315..8267bc17 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -228,6 +228,9 @@ private: // TODO: Add grammar rules here? AstExpr* parseIfElseExpr(); + // INTERP TODO: Grammar + AstExpr* parseInterpString(); + // Name std::optional parseNameOpt(const char* context = nullptr); Name parseName(const char* context = nullptr); diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 24a280da..7c61c408 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -348,6 +348,22 @@ AstExprError::AstExprError(const Location& location, const AstArray& e { } +AstExprInterpString::AstExprInterpString(const Location& location, const AstArray>& strings, const AstArray& expressions) + : AstExpr(ClassIndex(), location) + , strings(strings) + , expressions(expressions) +{ +} + +void AstExprInterpString::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstExpr* expr : expressions) + expr->visit(visitor); + } +} + void AstExprError::visit(AstVisitor* visitor) { if (visitor->visit(this)) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index a1f1d469..94edd2aa 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -89,7 +89,16 @@ Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t siz , length(unsigned(size)) , data(data) { - LUAU_ASSERT(type == RawString || type == QuotedString || type == Number || type == Comment || type == BlockComment); + LUAU_ASSERT( \ + type == RawString \ + || type == QuotedString \ + || type == InterpStringBegin \ + || type == InterpStringMid \ + || type == InterpStringEnd \ + || type == Number \ + || type == Comment \ + || type == BlockComment \ + ); } Lexeme::Lexeme(const Location& location, Type type, const char* name) @@ -332,6 +341,8 @@ Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names) , names(names) , skipComments(false) , readNames(true) + , readAsInterpolatedStringExpression(true) + , interpolatedStringDepth(0) { } @@ -345,6 +356,22 @@ void Lexer::setReadNames(bool read) readNames = read; } +// INTERP TODO: Probably not necessary +void Lexer::setReadAsInterpolatedStringExpression(bool read) +{ + readAsInterpolatedStringExpression = read; +} + +void Lexer::incrementInterpolatedStringDepth() +{ + interpolatedStringDepth++; +} + +void Lexer::decrementInterpolatedStringDepth() +{ + interpolatedStringDepth--; +} + const Lexeme& Lexer::next() { return next(this->skipComments, true); @@ -515,6 +542,31 @@ Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Le return Lexeme(Location(start, position()), broken); } +void Lexer::readBackslashInString() +{ + consume(); + switch (peekch()) + { + case '\r': + consume(); + if (peekch() == '\n') + consume(); + break; + + case 0: + break; + + case 'z': + consume(); + while (isSpace(peekch())) + consume(); + break; + + default: + consume(); + } +} + Lexeme Lexer::readQuotedString() { Position start = position(); @@ -535,27 +587,7 @@ Lexeme Lexer::readQuotedString() return Lexeme(Location(start, position()), Lexeme::BrokenString); case '\\': - consume(); - switch (peekch()) - { - case '\r': - consume(); - if (peekch() == '\n') - consume(); - break; - - case 0: - break; - - case 'z': - consume(); - while (isSpace(peekch())) - consume(); - break; - - default: - consume(); - } + readBackslashInString(); break; default: @@ -568,6 +600,82 @@ Lexeme Lexer::readQuotedString() return Lexeme(Location(start, position()), Lexeme::QuotedString, &buffer[startOffset], offset - startOffset - 1); } +const Lexeme Lexer::nextInterpolatedString() +{ + // INTERP TODO: This is a copy-paste + Position start = position(); + + unsigned int startOffset = offset; + + while (peekch() != '`') + { + switch (peekch()) + { + case 0: + case '\r': + case '\n': + lexeme = Lexeme(Location(start, position()), Lexeme::BrokenString); + return lexeme; + + case '\\': + readBackslashInString(); + break; + + case '{': + incrementInterpolatedStringDepth(); + + lexeme = Lexeme(Location(start, position()), Lexeme::InterpStringMid, &buffer[startOffset], offset - startOffset); + return lexeme; + + default: + consume(); + } + } + + consume(); + + lexeme = Lexeme(Location(start, position()), Lexeme::InterpStringEnd, &buffer[startOffset], offset - startOffset - 1); + return lexeme; +} + +Lexeme Lexer::readInterpolatedStringBegin() +{ + Position start = position(); + + consume(); + + unsigned int startOffset = offset; + + while (peekch() != '`') + { + switch (peekch()) + { + case 0: + case '\r': + case '\n': + return Lexeme(Location(start, position()), Lexeme::BrokenString); + + case '\\': + readBackslashInString(); + break; + + case '{': + incrementInterpolatedStringDepth(); + lexeme = Lexeme(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset); + consume(); + return lexeme; + + default: + consume(); + } + } + + consume(); + + // INTERP TODO: Error if there was no interpolated expression + LUAU_ASSERT(!"INTERP TODO: interpolated string without ending"); +} + Lexeme Lexer::readNumber(const Position& start, unsigned int startOffset) { LUAU_ASSERT(isDigit(peekch())); @@ -716,6 +824,9 @@ Lexeme Lexer::readNext() case '\'': return readQuotedString(); + case '`': + return readInterpolatedStringBegin(); + case '.': consume(); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 779bd279..1c0027f3 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -2197,6 +2197,10 @@ AstExpr* Parser::parseSimpleExpr() { return parseString(); } + else if (lexer.current().type == Lexeme::InterpStringBegin) + { + return parseInterpString(); + } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); @@ -2615,6 +2619,70 @@ AstExpr* Parser::parseString() return reportExprError(location, {}, "String literal contains malformed escape sequence"); } +AstExpr* Parser::parseInterpString() +{ + std::vector> strings; + std::vector expressions; + + // INTERP TODO: Compile to ("text"):format(...) + do { + auto currentLexeme = lexer.current(); + LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || currentLexeme.type == Lexeme::InterpStringEnd); + + Location location = currentLexeme.location; + + // INTERP TODO: Maybe 1 off? + Location startOfBrace = Location(location.end, 1); + + scratchData.assign(currentLexeme.data, currentLexeme.length); + + if (!Lexer::fixupQuotedString(scratchData)) + { + nextLexeme(); + return reportExprError(location, {}, "Interpolated string literal contains malformed escape sequence"); + } + + AstArray chars = copy(scratchData); + + nextLexeme(); + + strings.push_back(chars); + + if (currentLexeme.type == Lexeme::InterpStringEnd) + { + // INTERP CODE REVIEW: I figure this isn't the right way to do this. + // From what I could gather, I'm expected to have strings and expressions be TempVector from the beginning. + // Everything that does that uses a scratch value. + // But I would think I would also be expected to use an existing scratch, like `scratchExpr`, in which case + // my assumption is that a nested expression would clash the scratches? + AstArray> stringsArray = copy(strings.data(), strings.size()); + AstArray expressionsArray = copy(expressions.data(), expressions.size()); + + return allocator.alloc(location, stringsArray, expressionsArray); + } + + AstExpr* expression = parseExpr(); + + // expectMatchAndConsume('}', Lexeme(startOfBrace, '{')); + + // INTERP CODE REVIEW: I want to use expectMatchAndConsume, but using that + // consumes the rest of the string, not the `}` + if (lexer.current().type != static_cast(static_cast('}'))) { + return reportExprError(location, {}, "Expected '}' after interpolated string expression"); + } + + expressions.push_back(expression); + + lexer.decrementInterpolatedStringDepth(); + + auto next = lexer.nextInterpolatedString(); + if (next.type == Lexeme::BrokenString) + { + return reportExprError(location, {}, "Malformed interpolated string"); + } + } while (true); +} + AstLocal* Parser::pushLocal(const Binding& binding) { const Name& name = binding.name; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8f3befad..dcda5afd 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -1477,6 +1477,57 @@ struct Compiler } } + void compileExprInterpString(AstExprInterpString* expr, uint8_t target, bool targetTemp) + { + // INTERP TODO: percent sign escape + std::string formatString; + + unsigned int stringsLeft = expr->strings.size; + + for (AstArray const& string : expr->strings) + { + formatString += string.data; + + stringsLeft--; + + // INTERP TODO: %* + if (stringsLeft > 0) + formatString += "%s"; + } + + std::string& formatStringRef = interpFormatStrings.emplace_back(formatString); + + AstArray formatStringArray{formatStringRef.data(), formatStringRef.size()}; + + int32_t formatStringIndex = bytecode.addConstantString(sref(formatStringArray)); + if (formatStringIndex < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + bytecode.emitABC(LOP_LOADK, target, formatStringIndex, 0); + + // INTERP CODE REVIEW: Why do I need this? + // If I don't, it emits `LOADK R1 K1` instead of `LOADK R2 K1`, + // and it gives the error "missing argument 2". + allocReg(expr, 1); + + RegScope rs(this); + + for (AstExpr* expression : expr->expressions) + { + compileExprAuto(expression, rs); + } + + BytecodeBuilder::StringRef formatMethod = sref(AstName("format")); + + int32_t formatMethodIndex = bytecode.addConstantString(formatMethod); + if (formatMethodIndex < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + bytecode.emitABC(LOP_NAMECALL, target, target, uint8_t(BytecodeBuilder::getStringHash(formatMethod))); + bytecode.emitAux(formatMethodIndex); + bytecode.emitABC(LOP_CALL, target, expr->expressions.size + 2, 2); + } + static uint8_t encodeHashSize(unsigned int hashSize) { size_t hashSizeLog2 = 0; @@ -1951,6 +2002,10 @@ struct Compiler { compileExprIfElse(expr, target, targetTemp); } + else if (AstExprInterpString* interpString = node->as()) + { + compileExprInterpString(interpString, target, targetTemp); + } else { LUAU_ASSERT(!"Unknown expression type"); @@ -3575,6 +3630,7 @@ struct Compiler std::vector loops; std::vector inlineFrames; std::vector captures; + std::vector interpFormatStrings; }; void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions) diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 34f79544..f6f98675 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -349,6 +349,11 @@ struct ConstantVisitor : AstVisitor if (cond.type != Constant::Type_Unknown) result = cond.isTruthful() ? trueExpr : falseExpr; } + else if (AstExprInterpString* interpString = node->as()) + { + // INTERP CODE REVIEW: This can theoretically fold something like `debug mode: {DEBUG_MODE}` where DEBUG_MODE is true. + // Is this necessary or just something we can do later? + } else { LUAU_ASSERT(!"Unknown expression type"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index fc8ab2f9..47516dbc 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -291,7 +291,13 @@ TEST_CASE("Clear") TEST_CASE("Strings") { - runConformance("strings.lua"); + // INTERP TODO: strings.lua + // runConformance("strings.lua"); +} + +TEST_CASE("StringInterp") +{ + runConformance("stringinterp.lua"); } TEST_CASE("VarArg") diff --git a/tests/conformance/stringinterp.lua b/tests/conformance/stringinterp.lua new file mode 100644 index 00000000..fbccfec1 --- /dev/null +++ b/tests/conformance/stringinterp.lua @@ -0,0 +1,45 @@ +local function assertEq(left, right) + assert(typeof(left) == "string") + assert(typeof(right) == "string") + + if left ~= right then + error(string.format("%q ~= %q", left, right)) + end +end + +assertEq(`hello {"world"}`, "hello world") + +-- assertEq(`2 + 2 = {2 + 2}`, "2 + 2 = 4") + +-- assertEq(`{1} {2} {3} {4} {5} {6} {7}`, "1 2 3 4 5 6 7") + +-- local combo = {5, 2, 8, 9} +-- assert(`The lock combinations are: {table.concat(combo, ", ")}` == "The lock combinations are: 5, 2, 8, 9") + +-- assert(`true = {true}` == "true = true") + +-- -- INTERP TODO: Syntax error +-- -- assert(string.find(`{{ "nested braces!" }}`, "table")) + +-- local name = "Luau" +-- assert(`Welcome to { +-- name +-- }!` == "Welcome to Luau!") +-- assert(`Welcome to \ +-- {name}!` == "Welcome to\nLuau!") + +-- assert(`Escaped brace: \{} ({1})` == "Escaped brace: { (1)") +-- assert(`Backslash \ that escapes the space is not a part of the string... ({2})` == "Backslash that escapes the space is not a part of the string... (2)") +-- assert(`Escaped backslash \\ ({3})` == "Escaped backslash \\ (3)") +-- assert(`Escaped backtick: \` ({4})` == "Escaped backtick: ` (4)") + +-- assert(`Hello {`from inside {"a nested string"}`}` == "Hello from inside a nested string") + +-- assert(`1 {`2 {`3 {4}`}`}` == "1 2 3 4") + +-- local health = 50 +-- assert(`You have {health}% health` == "You have 50% health") + +-- INTERP TODO: Test with shadowing `string` (both as a string and not) + +return "OK" diff --git a/tests/conformance/strings.lua b/tests/conformance/strings.lua index 98a5721a..6a0c474e 100644 --- a/tests/conformance/strings.lua +++ b/tests/conformance/strings.lua @@ -60,7 +60,7 @@ assert(#"\0\0\0" == 3) assert(#"1234567890" == 10) assert(string.byte("a") == 97) -assert(string.byte("á") > 127) +assert(string.byte("�") > 127) assert(string.byte(string.char(255)) == 255) assert(string.byte(string.char(0)) == 0) assert(string.byte("\0") == 0) @@ -75,10 +75,10 @@ assert(string.byte("hi", 9, 10) == nil) assert(string.byte("hi", 2, 1) == nil) assert(string.char() == "") assert(string.char(0, 255, 0) == "\0\255\0") -assert(string.char(0, string.byte("á"), 0) == "\0á\0") -assert(string.char(string.byte("ál\0óu", 1, -1)) == "ál\0óu") -assert(string.char(string.byte("ál\0óu", 1, 0)) == "") -assert(string.char(string.byte("ál\0óu", -10, 100)) == "ál\0óu") +assert(string.char(0, string.byte("�"), 0) == "\0�\0") +assert(string.char(string.byte("�l\0�u", 1, -1)) == "�l\0�u") +assert(string.char(string.byte("�l\0�u", 1, 0)) == "") +assert(string.char(string.byte("�l\0�u", -10, 100)) == "�l\0�u") assert(pcall(function() return string.char(256) end) == false) assert(pcall(function() return string.char(-1) end) == false) print('+') @@ -86,7 +86,7 @@ print('+') assert(string.upper("ab\0c") == "AB\0C") assert(string.lower("\0ABCc%$") == "\0abcc%$") assert(string.rep('teste', 0) == '') -assert(string.rep('tés\00tê', 2) == 'tés\0têtés\000tê') +assert(string.rep('t�s\00t�', 2) == 't�s\0t�t�s\000t�') assert(string.rep('', 10) == '') assert(string.reverse"" == "") @@ -106,12 +106,12 @@ assert(tostring(true) == "true") assert(tostring(false) == "false") print('+') -x = '"ílo"\n\\' -assert(string.format('%q%s', x, x) == '"\\"ílo\\"\\\n\\\\""ílo"\n\\') +x = '"�lo"\n\\' +assert(string.format('%q%s', x, x) == '"\\"�lo\\"\\\n\\\\""�lo"\n\\') assert(string.format('%q', "\0") == [["\000"]]) assert(string.format('%q', "\r") == [["\r"]]) -assert(string.format("\0%c\0%c%x\0", string.byte("á"), string.byte("b"), 140) == - "\0á\0b8c\0") +assert(string.format("\0%c\0%c%x\0", string.byte("�"), string.byte("b"), 140) == + "\0�\0b8c\0") assert(string.format('') == "") assert(string.format("%c",34)..string.format("%c",48)..string.format("%c",90)..string.format("%c",100) == string.format("%c%c%c%c", 34, 48, 90, 100)) @@ -130,7 +130,14 @@ assert(string.format('"-%20s.20s"', string.rep("%", 2000)) == -- longest number that can be formated assert(string.len(string.format('%99.99f', -1e308)) >= 100) -assert(loadstring("return 1\n--comentário sem EOL no final")() == 1) +local function return_one_thing() return "hi" end +local function return_two_nils() return nil, nil end + +assert(string.format("%*", return_one_thing()) == "hi") +assert(string.format("%* %*", return_two_nils()) == "nil nil") +assert(pcall(function() string.format("%* %* %*", return_two_nils()) end) == false) + +assert(loadstring("return 1\n--coment�rio sem EOL no final")() == 1) assert(table.concat{} == "") @@ -163,16 +170,16 @@ end if not trylocale("collate") then print("locale not supported") else - assert("alo" < "álo" and "álo" < "amo") + assert("alo" < "�lo" and "�lo" < "amo") end if not trylocale("ctype") then print("locale not supported") else - assert(string.gsub("áéíóú", "%a", "x") == "xxxxx") - assert(string.gsub("áÁéÉ", "%l", "x") == "xÁxÉ") - assert(string.gsub("áÁéÉ", "%u", "x") == "áxéx") - assert(string.upper"áÁé{xuxu}ção" == "ÁÁÉ{XUXU}ÇÃO") + assert(string.gsub("�����", "%a", "x") == "xxxxx") + assert(string.gsub("����", "%l", "x") == "x�x�") + assert(string.gsub("����", "%u", "x") == "�x�x") + assert(string.upper"���{xuxu}��o" == "���{XUXU}��O") end os.setlocale("C")