Optimize vector literals by storing them in the constant table (#1096)

With this optimization, built-in vector constructor calls with 3/4 arguments are detected by the compiler and turned into vector constants when the arguments are constant numbers.
Requires optimization level 2 because built-ins are not folded otherwise by the compiler.
Bytecode version is bumped because of the new constant type, but old bytecode versions can still be loaded.

The following synthetic benchmark shows ~6.6x improvement.
```
local v
for i = 1, 10000000 do
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
end
```

Also tried a more real world scenario and could see a few percent improvement.

Added a new fast flag LuauVectorLiterals for enabling the feature.

---------

Co-authored-by: Petri Häkkinen <petrih@rmd.remedy.fi>
Co-authored-by: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com>
Co-authored-by: Arseny Kapoulkine <arseny.kapoulkine@gmail.com>
This commit is contained in:
Petri Häkkinen 2023-11-17 14:54:32 +02:00 committed by GitHub
parent 0492ecffdf
commit 298cd70154
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 196 additions and 20 deletions

View file

@ -45,6 +45,7 @@
// Version 2: Adds Proto::linedefined. Supported until 0.544. // Version 2: Adds Proto::linedefined. Supported until 0.544.
// Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported. // Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported.
// Version 4: Adds Proto::flags, typeinfo, and floor division opcodes IDIV/IDIVK. Currently supported. // Version 4: Adds Proto::flags, typeinfo, and floor division opcodes IDIV/IDIVK. Currently supported.
// Version 5: Adds vector constants. Currently supported.
// Bytecode opcode, part of the instruction header // Bytecode opcode, part of the instruction header
enum LuauOpcode enum LuauOpcode
@ -70,7 +71,7 @@ enum LuauOpcode
// D: value (-32768..32767) // D: value (-32768..32767)
LOP_LOADN, LOP_LOADN,
// LOADK: sets register to an entry from the constant table from the proto (number/string) // LOADK: sets register to an entry from the constant table from the proto (number/vector/string)
// A: target register // A: target register
// D: constant table index (0..32767) // D: constant table index (0..32767)
LOP_LOADK, LOP_LOADK,
@ -426,7 +427,7 @@ enum LuauBytecodeTag
{ {
// Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled
LBC_VERSION_MIN = 3, LBC_VERSION_MIN = 3,
LBC_VERSION_MAX = 4, LBC_VERSION_MAX = 5,
LBC_VERSION_TARGET = 4, LBC_VERSION_TARGET = 4,
// Type encoding version // Type encoding version
LBC_TYPE_VERSION = 1, LBC_TYPE_VERSION = 1,
@ -438,6 +439,7 @@ enum LuauBytecodeTag
LBC_CONSTANT_IMPORT, LBC_CONSTANT_IMPORT,
LBC_CONSTANT_TABLE, LBC_CONSTANT_TABLE,
LBC_CONSTANT_CLOSURE, LBC_CONSTANT_CLOSURE,
LBC_CONSTANT_VECTOR,
}; };
// Type table tags // Type table tags

View file

@ -54,6 +54,7 @@ public:
int32_t addConstantNil(); int32_t addConstantNil();
int32_t addConstantBoolean(bool value); int32_t addConstantBoolean(bool value);
int32_t addConstantNumber(double value); int32_t addConstantNumber(double value);
int32_t addConstantVector(float x, float y, float z, float w);
int32_t addConstantString(StringRef value); int32_t addConstantString(StringRef value);
int32_t addImport(uint32_t iid); int32_t addImport(uint32_t iid);
int32_t addConstantTable(const TableShape& shape); int32_t addConstantTable(const TableShape& shape);
@ -146,6 +147,7 @@ private:
Type_Nil, Type_Nil,
Type_Boolean, Type_Boolean,
Type_Number, Type_Number,
Type_Vector,
Type_String, Type_String,
Type_Import, Type_Import,
Type_Table, Type_Table,
@ -157,6 +159,7 @@ private:
{ {
bool valueBoolean; bool valueBoolean;
double valueNumber; double valueNumber;
float valueVector[4];
unsigned int valueString; // index into string table unsigned int valueString; // index into string table
uint32_t valueImport; // 10-10-10-2 encoded import id uint32_t valueImport; // 10-10-10-2 encoded import id
uint32_t valueTable; // index into tableShapes[] uint32_t valueTable; // index into tableShapes[]
@ -167,12 +170,14 @@ private:
struct ConstantKey struct ConstantKey
{ {
Constant::Type type; Constant::Type type;
// Note: this stores value* from Constant; when type is Number_Double, this stores the same bits as double does but in uint64_t. // Note: this stores value* from Constant; when type is Type_Number, this stores the same bits as double does but in uint64_t.
// For Type_Vector, x and y are stored in 'value' and z and w are stored in 'extra'.
uint64_t value; uint64_t value;
uint64_t extra = 0;
bool operator==(const ConstantKey& key) const bool operator==(const ConstantKey& key) const
{ {
return type == key.type && value == key.value; return type == key.type && value == key.value && extra == key.extra;
} }
}; };

View file

@ -5,6 +5,8 @@
#include <math.h> #include <math.h>
LUAU_FASTFLAGVARIABLE(LuauVectorLiterals, false)
namespace Luau namespace Luau
{ {
namespace Compile namespace Compile
@ -32,6 +34,16 @@ static Constant cnum(double v)
return res; return res;
} }
static Constant cvector(double x, double y, double z, double w)
{
Constant res = {Constant::Type_Vector};
res.valueVector[0] = (float)x;
res.valueVector[1] = (float)y;
res.valueVector[2] = (float)z;
res.valueVector[3] = (float)w;
return res;
}
static Constant cstring(const char* v) static Constant cstring(const char* v)
{ {
Constant res = {Constant::Type_String}; Constant res = {Constant::Type_String};
@ -55,6 +67,9 @@ static Constant ctype(const Constant& c)
case Constant::Type_Number: case Constant::Type_Number:
return cstring("number"); return cstring("number");
case Constant::Type_Vector:
return cstring("vector");
case Constant::Type_String: case Constant::Type_String:
return cstring("string"); return cstring("string");
@ -456,6 +471,19 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count)
if (count == 1 && args[0].type == Constant::Type_Number) if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(round(args[0].valueNumber)); return cnum(round(args[0].valueNumber));
break; break;
case LBF_VECTOR:
if (FFlag::LuauVectorLiterals && count >= 3 &&
args[0].type == Constant::Type_Number &&
args[1].type == Constant::Type_Number &&
args[2].type == Constant::Type_Number)
{
if (count == 3)
return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, 0.0);
else if (count == 4 && args[3].type == Constant::Type_Number)
return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, args[3].valueNumber);
}
break;
} }
return cvar(); return cvar();

View file

@ -7,6 +7,7 @@
#include <algorithm> #include <algorithm>
#include <string.h> #include <string.h>
LUAU_FASTFLAG(LuauVectorLiterals)
namespace Luau namespace Luau
{ {
@ -41,6 +42,11 @@ static void writeInt(std::string& ss, int value)
ss.append(reinterpret_cast<const char*>(&value), sizeof(value)); ss.append(reinterpret_cast<const char*>(&value), sizeof(value));
} }
static void writeFloat(std::string& ss, float value)
{
ss.append(reinterpret_cast<const char*>(&value), sizeof(value));
}
static void writeDouble(std::string& ss, double value) static void writeDouble(std::string& ss, double value)
{ {
ss.append(reinterpret_cast<const char*>(&value), sizeof(value)); ss.append(reinterpret_cast<const char*>(&value), sizeof(value));
@ -146,23 +152,43 @@ size_t BytecodeBuilder::StringRefHash::operator()(const StringRef& v) const
size_t BytecodeBuilder::ConstantKeyHash::operator()(const ConstantKey& key) const size_t BytecodeBuilder::ConstantKeyHash::operator()(const ConstantKey& key) const
{ {
// finalizer from MurmurHash64B if (key.type == Constant::Type_Vector)
const uint32_t m = 0x5bd1e995; {
uint32_t i[4];
static_assert(sizeof(key.value) + sizeof(key.extra) == sizeof(i), "Expecting vector to have four 32-bit components");
memcpy(i, &key.value, sizeof(i));
uint32_t h1 = uint32_t(key.value); // scramble bits to make sure that integer coordinates have entropy in lower bits
uint32_t h2 = uint32_t(key.value >> 32) ^ (key.type * m); i[0] ^= i[0] >> 17;
i[1] ^= i[1] >> 17;
i[2] ^= i[2] >> 17;
i[3] ^= i[3] >> 17;
h1 ^= h2 >> 18; // Optimized Spatial Hashing for Collision Detection of Deformable Objects
h1 *= m; uint32_t h = (i[0] * 73856093) ^ (i[1] * 19349663) ^ (i[2] * 83492791) ^ (i[3] * 39916801);
h2 ^= h1 >> 22;
h2 *= m;
h1 ^= h2 >> 17;
h1 *= m;
h2 ^= h1 >> 19;
h2 *= m;
// ... truncated to 32-bit output (normally hash is equal to (uint64_t(h1) << 32) | h2, but we only really need the lower 32-bit half) return size_t(h);
return size_t(h2); }
else
{
// finalizer from MurmurHash64B
const uint32_t m = 0x5bd1e995;
uint32_t h1 = uint32_t(key.value);
uint32_t h2 = uint32_t(key.value >> 32) ^ (key.type * m);
h1 ^= h2 >> 18;
h1 *= m;
h2 ^= h1 >> 22;
h2 *= m;
h1 ^= h2 >> 17;
h1 *= m;
h2 ^= h1 >> 19;
h2 *= m;
// ... truncated to 32-bit output (normally hash is equal to (uint64_t(h1) << 32) | h2, but we only really need the lower 32-bit half)
return size_t(h2);
}
} }
size_t BytecodeBuilder::TableShapeHash::operator()(const TableShape& v) const size_t BytecodeBuilder::TableShapeHash::operator()(const TableShape& v) const
@ -330,6 +356,24 @@ int32_t BytecodeBuilder::addConstantNumber(double value)
return addConstant(k, c); return addConstant(k, c);
} }
int32_t BytecodeBuilder::addConstantVector(float x, float y, float z, float w)
{
Constant c = {Constant::Type_Vector};
c.valueVector[0] = x;
c.valueVector[1] = y;
c.valueVector[2] = z;
c.valueVector[3] = w;
ConstantKey k = {Constant::Type_Vector};
static_assert(sizeof(k.value) == sizeof(x) + sizeof(y) && sizeof(k.extra) == sizeof(z) + sizeof(w), "Expecting vector to have four 32-bit components");
memcpy(&k.value, &x, sizeof(x));
memcpy((char*)&k.value + sizeof(x), &y, sizeof(y));
memcpy(&k.extra, &z, sizeof(z));
memcpy((char*)&k.extra + sizeof(z), &w, sizeof(w));
return addConstant(k, c);
}
int32_t BytecodeBuilder::addConstantString(StringRef value) int32_t BytecodeBuilder::addConstantString(StringRef value)
{ {
unsigned int index = addStringTableEntry(value); unsigned int index = addStringTableEntry(value);
@ -647,6 +691,14 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id, uint8_t flags)
writeDouble(ss, c.valueNumber); writeDouble(ss, c.valueNumber);
break; break;
case Constant::Type_Vector:
writeByte(ss, LBC_CONSTANT_VECTOR);
writeFloat(ss, c.valueVector[0]);
writeFloat(ss, c.valueVector[1]);
writeFloat(ss, c.valueVector[2]);
writeFloat(ss, c.valueVector[3]);
break;
case Constant::Type_String: case Constant::Type_String:
writeByte(ss, LBC_CONSTANT_STRING); writeByte(ss, LBC_CONSTANT_STRING);
writeVarInt(ss, c.valueString); writeVarInt(ss, c.valueString);
@ -1071,7 +1123,7 @@ std::string BytecodeBuilder::getError(const std::string& message)
uint8_t BytecodeBuilder::getVersion() uint8_t BytecodeBuilder::getVersion()
{ {
// This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags // This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags
return LBC_VERSION_TARGET; return (FFlag::LuauVectorLiterals ? 5 : LBC_VERSION_TARGET);
} }
uint8_t BytecodeBuilder::getTypeEncodingVersion() uint8_t BytecodeBuilder::getTypeEncodingVersion()
@ -1635,6 +1687,13 @@ void BytecodeBuilder::dumpConstant(std::string& result, int k) const
case Constant::Type_Number: case Constant::Type_Number:
formatAppend(result, "%.17g", data.valueNumber); formatAppend(result, "%.17g", data.valueNumber);
break; break;
case Constant::Type_Vector:
// 3-vectors is the most common configuration, so truncate to three components if possible
if (data.valueVector[3] == 0.0)
formatAppend(result, "%.9g, %.9g, %.9g", data.valueVector[0], data.valueVector[1], data.valueVector[2]);
else
formatAppend(result, "%.9g, %.9g, %.9g, %.9g", data.valueVector[0], data.valueVector[1], data.valueVector[2], data.valueVector[3]);
break;
case Constant::Type_String: case Constant::Type_String:
{ {
const StringRef& str = debugStrings[data.valueString - 1]; const StringRef& str = debugStrings[data.valueString - 1];

View file

@ -1094,6 +1094,13 @@ struct Compiler
return cv && cv->type != Constant::Type_Unknown && !cv->isTruthful(); return cv && cv->type != Constant::Type_Unknown && !cv->isTruthful();
} }
bool isConstantVector(AstExpr* node)
{
const Constant* cv = constants.find(node);
return cv && cv->type == Constant::Type_Vector;
}
Constant getConstant(AstExpr* node) Constant getConstant(AstExpr* node)
{ {
const Constant* cv = constants.find(node); const Constant* cv = constants.find(node);
@ -1117,6 +1124,10 @@ struct Compiler
std::swap(left, right); std::swap(left, right);
} }
// disable fast path for vectors because supporting it would require a new opcode
if (operandIsConstant && isConstantVector(right))
operandIsConstant = false;
uint8_t rl = compileExprAuto(left, rs); uint8_t rl = compileExprAuto(left, rs);
if (isEq && operandIsConstant) if (isEq && operandIsConstant)
@ -1226,6 +1237,10 @@ struct Compiler
cid = bytecode.addConstantNumber(c->valueNumber); cid = bytecode.addConstantNumber(c->valueNumber);
break; break;
case Constant::Type_Vector:
cid = bytecode.addConstantVector(c->valueVector[0], c->valueVector[1], c->valueVector[2], c->valueVector[3]);
break;
case Constant::Type_String: case Constant::Type_String:
cid = bytecode.addConstantString(sref(c->getString())); cid = bytecode.addConstantString(sref(c->getString()));
break; break;
@ -2052,6 +2067,13 @@ struct Compiler
} }
break; break;
case Constant::Type_Vector:
{
int32_t cid = bytecode.addConstantVector(cv->valueVector[0], cv->valueVector[1], cv->valueVector[2], cv->valueVector[3]);
emitLoadK(target, cid);
}
break;
case Constant::Type_String: case Constant::Type_String:
{ {
int32_t cid = bytecode.addConstantString(sref(cv->getString())); int32_t cid = bytecode.addConstantString(sref(cv->getString()));

View file

@ -26,6 +26,13 @@ static bool constantsEqual(const Constant& la, const Constant& ra)
case Constant::Type_Number: case Constant::Type_Number:
return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber;
case Constant::Type_Vector:
return ra.type == Constant::Type_Vector &&
la.valueVector[0] == ra.valueVector[0] &&
la.valueVector[1] == ra.valueVector[1] &&
la.valueVector[2] == ra.valueVector[2] &&
la.valueVector[3] == ra.valueVector[3];
case Constant::Type_String: case Constant::Type_String:
return ra.type == Constant::Type_String && la.stringLength == ra.stringLength && memcmp(la.valueString, ra.valueString, la.stringLength) == 0; return ra.type == Constant::Type_String && la.stringLength == ra.stringLength && memcmp(la.valueString, ra.valueString, la.stringLength) == 0;

View file

@ -16,6 +16,7 @@ struct Constant
Type_Nil, Type_Nil,
Type_Boolean, Type_Boolean,
Type_Number, Type_Number,
Type_Vector,
Type_String, Type_String,
}; };
@ -26,6 +27,7 @@ struct Constant
{ {
bool valueBoolean; bool valueBoolean;
double valueNumber; double valueNumber;
float valueVector[4];
const char* valueString = nullptr; // length stored in stringLength const char* valueString = nullptr; // length stored in stringLength
}; };

View file

@ -287,6 +287,17 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
break; break;
} }
case LBC_CONSTANT_VECTOR:
{
float x = read<float>(data, size, offset);
float y = read<float>(data, size, offset);
float z = read<float>(data, size, offset);
float w = read<float>(data, size, offset);
(void)w;
setvvalue(&p->k[j], x, y, z, w);
break;
}
case LBC_CONSTANT_STRING: case LBC_CONSTANT_STRING:
{ {
TString* v = readString(strings, data, size, offset); TString* v = readString(strings, data, size, offset);

View file

@ -17,12 +17,17 @@ std::string rep(const std::string& s, size_t n);
using namespace Luau; using namespace Luau;
static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1) static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, bool enableVectors = false)
{ {
Luau::BytecodeBuilder bcb; Luau::BytecodeBuilder bcb;
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code);
Luau::CompileOptions options; Luau::CompileOptions options;
options.optimizationLevel = optimizationLevel; options.optimizationLevel = optimizationLevel;
if (enableVectors)
{
options.vectorLib = "Vector3";
options.vectorCtor = "new";
}
Luau::compileOrThrow(bcb, source, options); Luau::compileOrThrow(bcb, source, options);
return bcb.dumpFunction(id); return bcb.dumpFunction(id);
@ -4475,6 +4480,41 @@ L0: RETURN R0 -1
)"); )");
} }
TEST_CASE("VectorLiterals")
{
ScopedFastFlag sff("LuauVectorLiterals", true);
CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, /*enableVectors*/ true), R"(
LOADK R0 K0 [1, 2, 3]
RETURN R0 1
)");
CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3))", 0, 2, /*enableVectors*/ true), R"(
GETIMPORT R0 1 [print]
LOADK R1 K2 [1, 2, 3]
CALL R0 1 0
RETURN R0 0
)");
CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3, 4))", 0, 2, /*enableVectors*/ true), R"(
GETIMPORT R0 1 [print]
LOADK R1 K2 [1, 2, 3, 4]
CALL R0 1 0
RETURN R0 0
)");
CHECK_EQ("\n" + compileFunction("return Vector3.new(0, 0, 0), Vector3.new(-0, 0, 0)", 0, 2, /*enableVectors*/ true), R"(
LOADK R0 K0 [0, 0, 0]
LOADK R1 K1 [-0, 0, 0]
RETURN R0 2
)");
CHECK_EQ("\n" + compileFunction("return type(Vector3.new(0, 0, 0))", 0, 2, /*enableVectors*/ true), R"(
LOADK R0 K0 ['vector']
RETURN R0 1
)");
}
TEST_CASE("TypeAssertion") TEST_CASE("TypeAssertion")
{ {
// validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated // validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated