Basic string interpolation proof of concept

This commit is contained in:
Kampfkarren 2022-07-26 01:48:22 -07:00
parent 2a6d1c03ac
commit 24f61dfeb5
11 changed files with 389 additions and 39 deletions

View file

@ -134,6 +134,10 @@ public:
{
return visit((class AstExpr*)node);
}
virtual bool visit(class AstExprInterpString* node)
{
return visit((class AstExpr*)node);
}
virtual bool visit(class AstExprError* node)
{
return visit((class AstExpr*)node);
@ -722,6 +726,19 @@ public:
AstExpr* falseExpr;
};
class AstExprInterpString : public AstExpr
{
public:
LUAU_RTTI(AstExprInterpString)
AstExprInterpString(const Location& location, const AstArray<AstArray<char>>& strings, const AstArray<AstExpr*>& expressions);
void visit(AstVisitor* visitor) override;
AstArray<AstArray<char>> strings;
AstArray<AstExpr*> expressions;
};
class AstStatBlock : public AstStat
{
public:

View file

@ -61,6 +61,10 @@ struct Lexeme
SkinnyArrow,
DoubleColon,
InterpStringBegin,
InterpStringMid,
InterpStringEnd,
AddAssign,
SubAssign,
MulAssign,
@ -166,6 +170,11 @@ public:
void setSkipComments(bool skip);
void setReadNames(bool read);
void setReadAsInterpolatedStringExpression(bool read);
void incrementInterpolatedStringDepth();
void decrementInterpolatedStringDepth();
const Lexeme nextInterpolatedString();
const Location& previousLocation() const
{
@ -208,6 +217,10 @@ private:
Lexeme readLongString(const Position& start, int sep, Lexeme::Type ok, Lexeme::Type broken);
Lexeme readQuotedString();
Lexeme readInterpolatedStringBegin();
void readBackslashInString();
std::pair<AstName, Lexeme::Type> readName();
Lexeme readNumber(const Position& start, unsigned int startOffset);
@ -231,6 +244,9 @@ private:
bool skipComments;
bool readNames;
bool readAsInterpolatedStringExpression;
unsigned int interpolatedStringDepth;
};
inline bool isSpace(char ch)

View file

@ -228,6 +228,9 @@ private:
// TODO: Add grammar rules here?
AstExpr* parseIfElseExpr();
// INTERP TODO: Grammar
AstExpr* parseInterpString();
// Name
std::optional<Name> parseNameOpt(const char* context = nullptr);
Name parseName(const char* context = nullptr);

View file

@ -348,6 +348,22 @@ AstExprError::AstExprError(const Location& location, const AstArray<AstExpr*>& e
{
}
AstExprInterpString::AstExprInterpString(const Location& location, const AstArray<AstArray<char>>& strings, const AstArray<AstExpr*>& expressions)
: AstExpr(ClassIndex(), location)
, strings(strings)
, expressions(expressions)
{
}
void AstExprInterpString::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
{
for (AstExpr* expr : expressions)
expr->visit(visitor);
}
}
void AstExprError::visit(AstVisitor* visitor)
{
if (visitor->visit(this))

View file

@ -89,7 +89,16 @@ Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t siz
, length(unsigned(size))
, data(data)
{
LUAU_ASSERT(type == RawString || type == QuotedString || type == Number || type == Comment || type == BlockComment);
LUAU_ASSERT( \
type == RawString \
|| type == QuotedString \
|| type == InterpStringBegin \
|| type == InterpStringMid \
|| type == InterpStringEnd \
|| type == Number \
|| type == Comment \
|| type == BlockComment \
);
}
Lexeme::Lexeme(const Location& location, Type type, const char* name)
@ -332,6 +341,8 @@ Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names)
, names(names)
, skipComments(false)
, readNames(true)
, readAsInterpolatedStringExpression(true)
, interpolatedStringDepth(0)
{
}
@ -345,6 +356,22 @@ void Lexer::setReadNames(bool read)
readNames = read;
}
// INTERP TODO: Probably not necessary
void Lexer::setReadAsInterpolatedStringExpression(bool read)
{
readAsInterpolatedStringExpression = read;
}
void Lexer::incrementInterpolatedStringDepth()
{
interpolatedStringDepth++;
}
void Lexer::decrementInterpolatedStringDepth()
{
interpolatedStringDepth--;
}
const Lexeme& Lexer::next()
{
return next(this->skipComments, true);
@ -515,6 +542,31 @@ Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Le
return Lexeme(Location(start, position()), broken);
}
void Lexer::readBackslashInString()
{
consume();
switch (peekch())
{
case '\r':
consume();
if (peekch() == '\n')
consume();
break;
case 0:
break;
case 'z':
consume();
while (isSpace(peekch()))
consume();
break;
default:
consume();
}
}
Lexeme Lexer::readQuotedString()
{
Position start = position();
@ -535,27 +587,7 @@ Lexeme Lexer::readQuotedString()
return Lexeme(Location(start, position()), Lexeme::BrokenString);
case '\\':
consume();
switch (peekch())
{
case '\r':
consume();
if (peekch() == '\n')
consume();
break;
case 0:
break;
case 'z':
consume();
while (isSpace(peekch()))
consume();
break;
default:
consume();
}
readBackslashInString();
break;
default:
@ -568,6 +600,82 @@ Lexeme Lexer::readQuotedString()
return Lexeme(Location(start, position()), Lexeme::QuotedString, &buffer[startOffset], offset - startOffset - 1);
}
const Lexeme Lexer::nextInterpolatedString()
{
// INTERP TODO: This is a copy-paste
Position start = position();
unsigned int startOffset = offset;
while (peekch() != '`')
{
switch (peekch())
{
case 0:
case '\r':
case '\n':
lexeme = Lexeme(Location(start, position()), Lexeme::BrokenString);
return lexeme;
case '\\':
readBackslashInString();
break;
case '{':
incrementInterpolatedStringDepth();
lexeme = Lexeme(Location(start, position()), Lexeme::InterpStringMid, &buffer[startOffset], offset - startOffset);
return lexeme;
default:
consume();
}
}
consume();
lexeme = Lexeme(Location(start, position()), Lexeme::InterpStringEnd, &buffer[startOffset], offset - startOffset - 1);
return lexeme;
}
Lexeme Lexer::readInterpolatedStringBegin()
{
Position start = position();
consume();
unsigned int startOffset = offset;
while (peekch() != '`')
{
switch (peekch())
{
case 0:
case '\r':
case '\n':
return Lexeme(Location(start, position()), Lexeme::BrokenString);
case '\\':
readBackslashInString();
break;
case '{':
incrementInterpolatedStringDepth();
lexeme = Lexeme(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset);
consume();
return lexeme;
default:
consume();
}
}
consume();
// INTERP TODO: Error if there was no interpolated expression
LUAU_ASSERT(!"INTERP TODO: interpolated string without ending");
}
Lexeme Lexer::readNumber(const Position& start, unsigned int startOffset)
{
LUAU_ASSERT(isDigit(peekch()));
@ -716,6 +824,9 @@ Lexeme Lexer::readNext()
case '\'':
return readQuotedString();
case '`':
return readInterpolatedStringBegin();
case '.':
consume();

View file

@ -2197,6 +2197,10 @@ AstExpr* Parser::parseSimpleExpr()
{
return parseString();
}
else if (lexer.current().type == Lexeme::InterpStringBegin)
{
return parseInterpString();
}
else if (lexer.current().type == Lexeme::BrokenString)
{
nextLexeme();
@ -2615,6 +2619,70 @@ AstExpr* Parser::parseString()
return reportExprError(location, {}, "String literal contains malformed escape sequence");
}
AstExpr* Parser::parseInterpString()
{
std::vector<AstArray<char>> strings;
std::vector<AstExpr*> expressions;
// INTERP TODO: Compile to ("text"):format(...)
do {
auto currentLexeme = lexer.current();
LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || currentLexeme.type == Lexeme::InterpStringEnd);
Location location = currentLexeme.location;
// INTERP TODO: Maybe 1 off?
Location startOfBrace = Location(location.end, 1);
scratchData.assign(currentLexeme.data, currentLexeme.length);
if (!Lexer::fixupQuotedString(scratchData))
{
nextLexeme();
return reportExprError(location, {}, "Interpolated string literal contains malformed escape sequence");
}
AstArray<char> chars = copy(scratchData);
nextLexeme();
strings.push_back(chars);
if (currentLexeme.type == Lexeme::InterpStringEnd)
{
// INTERP CODE REVIEW: I figure this isn't the right way to do this.
// From what I could gather, I'm expected to have strings and expressions be TempVector from the beginning.
// Everything that does that uses a scratch value.
// But I would think I would also be expected to use an existing scratch, like `scratchExpr`, in which case
// my assumption is that a nested expression would clash the scratches?
AstArray<AstArray<char>> stringsArray = copy(strings.data(), strings.size());
AstArray<AstExpr*> expressionsArray = copy(expressions.data(), expressions.size());
return allocator.alloc<AstExprInterpString>(location, stringsArray, expressionsArray);
}
AstExpr* expression = parseExpr();
// expectMatchAndConsume('}', Lexeme(startOfBrace, '{'));
// INTERP CODE REVIEW: I want to use expectMatchAndConsume, but using that
// consumes the rest of the string, not the `}`
if (lexer.current().type != static_cast<Lexeme::Type>(static_cast<unsigned char>('}'))) {
return reportExprError(location, {}, "Expected '}' after interpolated string expression");
}
expressions.push_back(expression);
lexer.decrementInterpolatedStringDepth();
auto next = lexer.nextInterpolatedString();
if (next.type == Lexeme::BrokenString)
{
return reportExprError(location, {}, "Malformed interpolated string");
}
} while (true);
}
AstLocal* Parser::pushLocal(const Binding& binding)
{
const Name& name = binding.name;

View file

@ -1477,6 +1477,57 @@ struct Compiler
}
}
void compileExprInterpString(AstExprInterpString* expr, uint8_t target, bool targetTemp)
{
// INTERP TODO: percent sign escape
std::string formatString;
unsigned int stringsLeft = expr->strings.size;
for (AstArray<char> const& string : expr->strings)
{
formatString += string.data;
stringsLeft--;
// INTERP TODO: %*
if (stringsLeft > 0)
formatString += "%s";
}
std::string& formatStringRef = interpFormatStrings.emplace_back(formatString);
AstArray<char> formatStringArray{formatStringRef.data(), formatStringRef.size()};
int32_t formatStringIndex = bytecode.addConstantString(sref(formatStringArray));
if (formatStringIndex < 0)
CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
bytecode.emitABC(LOP_LOADK, target, formatStringIndex, 0);
// INTERP CODE REVIEW: Why do I need this?
// If I don't, it emits `LOADK R1 K1` instead of `LOADK R2 K1`,
// and it gives the error "missing argument 2".
allocReg(expr, 1);
RegScope rs(this);
for (AstExpr* expression : expr->expressions)
{
compileExprAuto(expression, rs);
}
BytecodeBuilder::StringRef formatMethod = sref(AstName("format"));
int32_t formatMethodIndex = bytecode.addConstantString(formatMethod);
if (formatMethodIndex < 0)
CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
bytecode.emitABC(LOP_NAMECALL, target, target, uint8_t(BytecodeBuilder::getStringHash(formatMethod)));
bytecode.emitAux(formatMethodIndex);
bytecode.emitABC(LOP_CALL, target, expr->expressions.size + 2, 2);
}
static uint8_t encodeHashSize(unsigned int hashSize)
{
size_t hashSizeLog2 = 0;
@ -1951,6 +2002,10 @@ struct Compiler
{
compileExprIfElse(expr, target, targetTemp);
}
else if (AstExprInterpString* interpString = node->as<AstExprInterpString>())
{
compileExprInterpString(interpString, target, targetTemp);
}
else
{
LUAU_ASSERT(!"Unknown expression type");
@ -3575,6 +3630,7 @@ struct Compiler
std::vector<Loop> loops;
std::vector<InlineFrame> inlineFrames;
std::vector<Capture> captures;
std::vector<std::string> interpFormatStrings;
};
void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions)

View file

@ -349,6 +349,11 @@ struct ConstantVisitor : AstVisitor
if (cond.type != Constant::Type_Unknown)
result = cond.isTruthful() ? trueExpr : falseExpr;
}
else if (AstExprInterpString* interpString = node->as<AstExprInterpString>())
{
// INTERP CODE REVIEW: This can theoretically fold something like `debug mode: {DEBUG_MODE}` where DEBUG_MODE is true.
// Is this necessary or just something we can do later?
}
else
{
LUAU_ASSERT(!"Unknown expression type");

View file

@ -291,7 +291,13 @@ TEST_CASE("Clear")
TEST_CASE("Strings")
{
runConformance("strings.lua");
// INTERP TODO: strings.lua
// runConformance("strings.lua");
}
TEST_CASE("StringInterp")
{
runConformance("stringinterp.lua");
}
TEST_CASE("VarArg")

View file

@ -0,0 +1,45 @@
local function assertEq(left, right)
assert(typeof(left) == "string")
assert(typeof(right) == "string")
if left ~= right then
error(string.format("%q ~= %q", left, right))
end
end
assertEq(`hello {"world"}`, "hello world")
-- assertEq(`2 + 2 = {2 + 2}`, "2 + 2 = 4")
-- assertEq(`{1} {2} {3} {4} {5} {6} {7}`, "1 2 3 4 5 6 7")
-- local combo = {5, 2, 8, 9}
-- assert(`The lock combinations are: {table.concat(combo, ", ")}` == "The lock combinations are: 5, 2, 8, 9")
-- assert(`true = {true}` == "true = true")
-- -- INTERP TODO: Syntax error
-- -- assert(string.find(`{{ "nested braces!" }}`, "table"))
-- local name = "Luau"
-- assert(`Welcome to {
-- name
-- }!` == "Welcome to Luau!")
-- assert(`Welcome to \
-- {name}!` == "Welcome to\nLuau!")
-- assert(`Escaped brace: \{} ({1})` == "Escaped brace: { (1)")
-- assert(`Backslash \ that escapes the space is not a part of the string... ({2})` == "Backslash that escapes the space is not a part of the string... (2)")
-- assert(`Escaped backslash \\ ({3})` == "Escaped backslash \\ (3)")
-- assert(`Escaped backtick: \` ({4})` == "Escaped backtick: ` (4)")
-- assert(`Hello {`from inside {"a nested string"}`}` == "Hello from inside a nested string")
-- assert(`1 {`2 {`3 {4}`}`}` == "1 2 3 4")
-- local health = 50
-- assert(`You have {health}% health` == "You have 50% health")
-- INTERP TODO: Test with shadowing `string` (both as a string and not)
return "OK"

View file

@ -60,7 +60,7 @@ assert(#"\0\0\0" == 3)
assert(#"1234567890" == 10)
assert(string.byte("a") == 97)
assert(string.byte("á") > 127)
assert(string.byte("<EFBFBD>") > 127)
assert(string.byte(string.char(255)) == 255)
assert(string.byte(string.char(0)) == 0)
assert(string.byte("\0") == 0)
@ -75,10 +75,10 @@ assert(string.byte("hi", 9, 10) == nil)
assert(string.byte("hi", 2, 1) == nil)
assert(string.char() == "")
assert(string.char(0, 255, 0) == "\0\255\0")
assert(string.char(0, string.byte("á"), 0) == "\0á\0")
assert(string.char(string.byte("ál\0óu", 1, -1)) == "ál\0óu")
assert(string.char(string.byte("ál\0óu", 1, 0)) == "")
assert(string.char(string.byte("ál\0óu", -10, 100)) == "ál\0óu")
assert(string.char(0, string.byte("<EFBFBD>"), 0) == "\0<EFBFBD>\0")
assert(string.char(string.byte("<EFBFBD>l\0<EFBFBD>u", 1, -1)) == "<EFBFBD>l\0<EFBFBD>u")
assert(string.char(string.byte("<EFBFBD>l\0<EFBFBD>u", 1, 0)) == "")
assert(string.char(string.byte("<EFBFBD>l\0<EFBFBD>u", -10, 100)) == "<EFBFBD>l\0<EFBFBD>u")
assert(pcall(function() return string.char(256) end) == false)
assert(pcall(function() return string.char(-1) end) == false)
print('+')
@ -86,7 +86,7 @@ print('+')
assert(string.upper("ab\0c") == "AB\0C")
assert(string.lower("\0ABCc%$") == "\0abcc%$")
assert(string.rep('teste', 0) == '')
assert(string.rep('tés\00', 2) == 'tés\0têtés\000')
assert(string.rep('t<EFBFBD>s\00t<EFBFBD>', 2) == 't<EFBFBD>s\0t<EFBFBD>t<EFBFBD>s\000t<EFBFBD>')
assert(string.rep('', 10) == '')
assert(string.reverse"" == "")
@ -106,12 +106,12 @@ assert(tostring(true) == "true")
assert(tostring(false) == "false")
print('+')
x = '"ílo"\n\\'
assert(string.format('%q%s', x, x) == '"\\"ílo\\"\\\n\\\\""ílo"\n\\')
x = '"<EFBFBD>lo"\n\\'
assert(string.format('%q%s', x, x) == '"\\"<EFBFBD>lo\\"\\\n\\\\""<22>lo"\n\\')
assert(string.format('%q', "\0") == [["\000"]])
assert(string.format('%q', "\r") == [["\r"]])
assert(string.format("\0%c\0%c%x\0", string.byte("á"), string.byte("b"), 140) ==
"\0á\0b8c\0")
assert(string.format("\0%c\0%c%x\0", string.byte("<EFBFBD>"), string.byte("b"), 140) ==
"\0<EFBFBD>\0b8c\0")
assert(string.format('') == "")
assert(string.format("%c",34)..string.format("%c",48)..string.format("%c",90)..string.format("%c",100) ==
string.format("%c%c%c%c", 34, 48, 90, 100))
@ -130,7 +130,14 @@ assert(string.format('"-%20s.20s"', string.rep("%", 2000)) ==
-- longest number that can be formated
assert(string.len(string.format('%99.99f', -1e308)) >= 100)
assert(loadstring("return 1\n--comentário sem EOL no final")() == 1)
local function return_one_thing() return "hi" end
local function return_two_nils() return nil, nil end
assert(string.format("%*", return_one_thing()) == "hi")
assert(string.format("%* %*", return_two_nils()) == "nil nil")
assert(pcall(function() string.format("%* %* %*", return_two_nils()) end) == false)
assert(loadstring("return 1\n--coment<6E>rio sem EOL no final")() == 1)
assert(table.concat{} == "")
@ -163,16 +170,16 @@ end
if not trylocale("collate") then
print("locale not supported")
else
assert("alo" < "álo" and "álo" < "amo")
assert("alo" < "<EFBFBD>lo" and "<EFBFBD>lo" < "amo")
end
if not trylocale("ctype") then
print("locale not supported")
else
assert(string.gsub("áéíóú", "%a", "x") == "xxxxx")
assert(string.gsub("áÁéÉ", "%l", "x") == "xÁxÉ")
assert(string.gsub("áÁéÉ", "%u", "x") == "áxéx")
assert(string.upper"áÁé{xuxu}ção" == "ÁÁÉ{XUXU}ÇÃO")
assert(string.gsub("<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>", "%a", "x") == "xxxxx")
assert(string.gsub("<EFBFBD><EFBFBD><EFBFBD><EFBFBD>", "%l", "x") == "x<EFBFBD>x<EFBFBD>")
assert(string.gsub("<EFBFBD><EFBFBD><EFBFBD><EFBFBD>", "%u", "x") == "<EFBFBD>x<EFBFBD>x")
assert(string.upper"<EFBFBD><EFBFBD><EFBFBD>{xuxu}<7D><>o" == "<EFBFBD><EFBFBD><EFBFBD>{XUXU}<7D><>O")
end
os.setlocale("C")