Optimize vector literals by storing them in the constant table (#1096)

With this optimization, built-in vector constructor calls with 3/4 arguments are detected by the compiler and turned into vector constants when the arguments are constant numbers.
Requires optimization level 2 because built-ins are not folded otherwise by the compiler.
Bytecode version is bumped because of the new constant type, but old bytecode versions can still be loaded.

The following synthetic benchmark shows ~6.6x improvement.
```
local v
for i = 1, 10000000 do
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
	v = vector(1, 2, 3)
end
```

Also tried a more real world scenario and could see a few percent improvement.

Added a new fast flag LuauVectorLiterals for enabling the feature.

---------

Co-authored-by: Petri Häkkinen <petrih@rmd.remedy.fi>
Co-authored-by: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com>
Co-authored-by: Arseny Kapoulkine <arseny.kapoulkine@gmail.com>
This commit is contained in:
Petri Häkkinen 2023-11-17 14:54:32 +02:00 committed by GitHub
parent 0492ecffdf
commit 298cd70154
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 196 additions and 20 deletions

View file

@ -45,6 +45,7 @@
// Version 2: Adds Proto::linedefined. Supported until 0.544.
// Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported.
// Version 4: Adds Proto::flags, typeinfo, and floor division opcodes IDIV/IDIVK. Currently supported.
// Version 5: Adds vector constants. Currently supported.
// Bytecode opcode, part of the instruction header
enum LuauOpcode
@ -70,7 +71,7 @@ enum LuauOpcode
// D: value (-32768..32767)
LOP_LOADN,
// LOADK: sets register to an entry from the constant table from the proto (number/string)
// LOADK: sets register to an entry from the constant table from the proto (number/vector/string)
// A: target register
// D: constant table index (0..32767)
LOP_LOADK,
@ -426,7 +427,7 @@ enum LuauBytecodeTag
{
// Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled
LBC_VERSION_MIN = 3,
LBC_VERSION_MAX = 4,
LBC_VERSION_MAX = 5,
LBC_VERSION_TARGET = 4,
// Type encoding version
LBC_TYPE_VERSION = 1,
@ -438,6 +439,7 @@ enum LuauBytecodeTag
LBC_CONSTANT_IMPORT,
LBC_CONSTANT_TABLE,
LBC_CONSTANT_CLOSURE,
LBC_CONSTANT_VECTOR,
};
// Type table tags

View file

@ -54,6 +54,7 @@ public:
int32_t addConstantNil();
int32_t addConstantBoolean(bool value);
int32_t addConstantNumber(double value);
int32_t addConstantVector(float x, float y, float z, float w);
int32_t addConstantString(StringRef value);
int32_t addImport(uint32_t iid);
int32_t addConstantTable(const TableShape& shape);
@ -146,6 +147,7 @@ private:
Type_Nil,
Type_Boolean,
Type_Number,
Type_Vector,
Type_String,
Type_Import,
Type_Table,
@ -157,6 +159,7 @@ private:
{
bool valueBoolean;
double valueNumber;
float valueVector[4];
unsigned int valueString; // index into string table
uint32_t valueImport; // 10-10-10-2 encoded import id
uint32_t valueTable; // index into tableShapes[]
@ -167,12 +170,14 @@ private:
struct ConstantKey
{
Constant::Type type;
// Note: this stores value* from Constant; when type is Number_Double, this stores the same bits as double does but in uint64_t.
// Note: this stores value* from Constant; when type is Type_Number, this stores the same bits as double does but in uint64_t.
// For Type_Vector, x and y are stored in 'value' and z and w are stored in 'extra'.
uint64_t value;
uint64_t extra = 0;
bool operator==(const ConstantKey& key) const
{
return type == key.type && value == key.value;
return type == key.type && value == key.value && extra == key.extra;
}
};

View file

@ -5,6 +5,8 @@
#include <math.h>
LUAU_FASTFLAGVARIABLE(LuauVectorLiterals, false)
namespace Luau
{
namespace Compile
@ -32,6 +34,16 @@ static Constant cnum(double v)
return res;
}
static Constant cvector(double x, double y, double z, double w)
{
Constant res = {Constant::Type_Vector};
res.valueVector[0] = (float)x;
res.valueVector[1] = (float)y;
res.valueVector[2] = (float)z;
res.valueVector[3] = (float)w;
return res;
}
static Constant cstring(const char* v)
{
Constant res = {Constant::Type_String};
@ -55,6 +67,9 @@ static Constant ctype(const Constant& c)
case Constant::Type_Number:
return cstring("number");
case Constant::Type_Vector:
return cstring("vector");
case Constant::Type_String:
return cstring("string");
@ -456,6 +471,19 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count)
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(round(args[0].valueNumber));
break;
case LBF_VECTOR:
if (FFlag::LuauVectorLiterals && count >= 3 &&
args[0].type == Constant::Type_Number &&
args[1].type == Constant::Type_Number &&
args[2].type == Constant::Type_Number)
{
if (count == 3)
return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, 0.0);
else if (count == 4 && args[3].type == Constant::Type_Number)
return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, args[3].valueNumber);
}
break;
}
return cvar();

View file

@ -7,6 +7,7 @@
#include <algorithm>
#include <string.h>
LUAU_FASTFLAG(LuauVectorLiterals)
namespace Luau
{
@ -41,6 +42,11 @@ static void writeInt(std::string& ss, int value)
ss.append(reinterpret_cast<const char*>(&value), sizeof(value));
}
static void writeFloat(std::string& ss, float value)
{
ss.append(reinterpret_cast<const char*>(&value), sizeof(value));
}
static void writeDouble(std::string& ss, double value)
{
ss.append(reinterpret_cast<const char*>(&value), sizeof(value));
@ -146,23 +152,43 @@ size_t BytecodeBuilder::StringRefHash::operator()(const StringRef& v) const
size_t BytecodeBuilder::ConstantKeyHash::operator()(const ConstantKey& key) const
{
// finalizer from MurmurHash64B
const uint32_t m = 0x5bd1e995;
if (key.type == Constant::Type_Vector)
{
uint32_t i[4];
static_assert(sizeof(key.value) + sizeof(key.extra) == sizeof(i), "Expecting vector to have four 32-bit components");
memcpy(i, &key.value, sizeof(i));
uint32_t h1 = uint32_t(key.value);
uint32_t h2 = uint32_t(key.value >> 32) ^ (key.type * m);
// scramble bits to make sure that integer coordinates have entropy in lower bits
i[0] ^= i[0] >> 17;
i[1] ^= i[1] >> 17;
i[2] ^= i[2] >> 17;
i[3] ^= i[3] >> 17;
h1 ^= h2 >> 18;
h1 *= m;
h2 ^= h1 >> 22;
h2 *= m;
h1 ^= h2 >> 17;
h1 *= m;
h2 ^= h1 >> 19;
h2 *= m;
// Optimized Spatial Hashing for Collision Detection of Deformable Objects
uint32_t h = (i[0] * 73856093) ^ (i[1] * 19349663) ^ (i[2] * 83492791) ^ (i[3] * 39916801);
// ... truncated to 32-bit output (normally hash is equal to (uint64_t(h1) << 32) | h2, but we only really need the lower 32-bit half)
return size_t(h2);
return size_t(h);
}
else
{
// finalizer from MurmurHash64B
const uint32_t m = 0x5bd1e995;
uint32_t h1 = uint32_t(key.value);
uint32_t h2 = uint32_t(key.value >> 32) ^ (key.type * m);
h1 ^= h2 >> 18;
h1 *= m;
h2 ^= h1 >> 22;
h2 *= m;
h1 ^= h2 >> 17;
h1 *= m;
h2 ^= h1 >> 19;
h2 *= m;
// ... truncated to 32-bit output (normally hash is equal to (uint64_t(h1) << 32) | h2, but we only really need the lower 32-bit half)
return size_t(h2);
}
}
size_t BytecodeBuilder::TableShapeHash::operator()(const TableShape& v) const
@ -330,6 +356,24 @@ int32_t BytecodeBuilder::addConstantNumber(double value)
return addConstant(k, c);
}
int32_t BytecodeBuilder::addConstantVector(float x, float y, float z, float w)
{
Constant c = {Constant::Type_Vector};
c.valueVector[0] = x;
c.valueVector[1] = y;
c.valueVector[2] = z;
c.valueVector[3] = w;
ConstantKey k = {Constant::Type_Vector};
static_assert(sizeof(k.value) == sizeof(x) + sizeof(y) && sizeof(k.extra) == sizeof(z) + sizeof(w), "Expecting vector to have four 32-bit components");
memcpy(&k.value, &x, sizeof(x));
memcpy((char*)&k.value + sizeof(x), &y, sizeof(y));
memcpy(&k.extra, &z, sizeof(z));
memcpy((char*)&k.extra + sizeof(z), &w, sizeof(w));
return addConstant(k, c);
}
int32_t BytecodeBuilder::addConstantString(StringRef value)
{
unsigned int index = addStringTableEntry(value);
@ -647,6 +691,14 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id, uint8_t flags)
writeDouble(ss, c.valueNumber);
break;
case Constant::Type_Vector:
writeByte(ss, LBC_CONSTANT_VECTOR);
writeFloat(ss, c.valueVector[0]);
writeFloat(ss, c.valueVector[1]);
writeFloat(ss, c.valueVector[2]);
writeFloat(ss, c.valueVector[3]);
break;
case Constant::Type_String:
writeByte(ss, LBC_CONSTANT_STRING);
writeVarInt(ss, c.valueString);
@ -1071,7 +1123,7 @@ std::string BytecodeBuilder::getError(const std::string& message)
uint8_t BytecodeBuilder::getVersion()
{
// This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags
return LBC_VERSION_TARGET;
return (FFlag::LuauVectorLiterals ? 5 : LBC_VERSION_TARGET);
}
uint8_t BytecodeBuilder::getTypeEncodingVersion()
@ -1635,6 +1687,13 @@ void BytecodeBuilder::dumpConstant(std::string& result, int k) const
case Constant::Type_Number:
formatAppend(result, "%.17g", data.valueNumber);
break;
case Constant::Type_Vector:
// 3-vectors is the most common configuration, so truncate to three components if possible
if (data.valueVector[3] == 0.0)
formatAppend(result, "%.9g, %.9g, %.9g", data.valueVector[0], data.valueVector[1], data.valueVector[2]);
else
formatAppend(result, "%.9g, %.9g, %.9g, %.9g", data.valueVector[0], data.valueVector[1], data.valueVector[2], data.valueVector[3]);
break;
case Constant::Type_String:
{
const StringRef& str = debugStrings[data.valueString - 1];

View file

@ -1094,6 +1094,13 @@ struct Compiler
return cv && cv->type != Constant::Type_Unknown && !cv->isTruthful();
}
bool isConstantVector(AstExpr* node)
{
const Constant* cv = constants.find(node);
return cv && cv->type == Constant::Type_Vector;
}
Constant getConstant(AstExpr* node)
{
const Constant* cv = constants.find(node);
@ -1117,6 +1124,10 @@ struct Compiler
std::swap(left, right);
}
// disable fast path for vectors because supporting it would require a new opcode
if (operandIsConstant && isConstantVector(right))
operandIsConstant = false;
uint8_t rl = compileExprAuto(left, rs);
if (isEq && operandIsConstant)
@ -1226,6 +1237,10 @@ struct Compiler
cid = bytecode.addConstantNumber(c->valueNumber);
break;
case Constant::Type_Vector:
cid = bytecode.addConstantVector(c->valueVector[0], c->valueVector[1], c->valueVector[2], c->valueVector[3]);
break;
case Constant::Type_String:
cid = bytecode.addConstantString(sref(c->getString()));
break;
@ -2052,6 +2067,13 @@ struct Compiler
}
break;
case Constant::Type_Vector:
{
int32_t cid = bytecode.addConstantVector(cv->valueVector[0], cv->valueVector[1], cv->valueVector[2], cv->valueVector[3]);
emitLoadK(target, cid);
}
break;
case Constant::Type_String:
{
int32_t cid = bytecode.addConstantString(sref(cv->getString()));

View file

@ -26,6 +26,13 @@ static bool constantsEqual(const Constant& la, const Constant& ra)
case Constant::Type_Number:
return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber;
case Constant::Type_Vector:
return ra.type == Constant::Type_Vector &&
la.valueVector[0] == ra.valueVector[0] &&
la.valueVector[1] == ra.valueVector[1] &&
la.valueVector[2] == ra.valueVector[2] &&
la.valueVector[3] == ra.valueVector[3];
case Constant::Type_String:
return ra.type == Constant::Type_String && la.stringLength == ra.stringLength && memcmp(la.valueString, ra.valueString, la.stringLength) == 0;

View file

@ -16,6 +16,7 @@ struct Constant
Type_Nil,
Type_Boolean,
Type_Number,
Type_Vector,
Type_String,
};
@ -26,6 +27,7 @@ struct Constant
{
bool valueBoolean;
double valueNumber;
float valueVector[4];
const char* valueString = nullptr; // length stored in stringLength
};

View file

@ -287,6 +287,17 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
break;
}
case LBC_CONSTANT_VECTOR:
{
float x = read<float>(data, size, offset);
float y = read<float>(data, size, offset);
float z = read<float>(data, size, offset);
float w = read<float>(data, size, offset);
(void)w;
setvvalue(&p->k[j], x, y, z, w);
break;
}
case LBC_CONSTANT_STRING:
{
TString* v = readString(strings, data, size, offset);

View file

@ -17,12 +17,17 @@ std::string rep(const std::string& s, size_t n);
using namespace Luau;
static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1)
static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, bool enableVectors = false)
{
Luau::BytecodeBuilder bcb;
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code);
Luau::CompileOptions options;
options.optimizationLevel = optimizationLevel;
if (enableVectors)
{
options.vectorLib = "Vector3";
options.vectorCtor = "new";
}
Luau::compileOrThrow(bcb, source, options);
return bcb.dumpFunction(id);
@ -4475,6 +4480,41 @@ L0: RETURN R0 -1
)");
}
TEST_CASE("VectorLiterals")
{
ScopedFastFlag sff("LuauVectorLiterals", true);
CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, /*enableVectors*/ true), R"(
LOADK R0 K0 [1, 2, 3]
RETURN R0 1
)");
CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3))", 0, 2, /*enableVectors*/ true), R"(
GETIMPORT R0 1 [print]
LOADK R1 K2 [1, 2, 3]
CALL R0 1 0
RETURN R0 0
)");
CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3, 4))", 0, 2, /*enableVectors*/ true), R"(
GETIMPORT R0 1 [print]
LOADK R1 K2 [1, 2, 3, 4]
CALL R0 1 0
RETURN R0 0
)");
CHECK_EQ("\n" + compileFunction("return Vector3.new(0, 0, 0), Vector3.new(-0, 0, 0)", 0, 2, /*enableVectors*/ true), R"(
LOADK R0 K0 [0, 0, 0]
LOADK R1 K1 [-0, 0, 0]
RETURN R0 2
)");
CHECK_EQ("\n" + compileFunction("return type(Vector3.new(0, 0, 0))", 0, 2, /*enableVectors*/ true), R"(
LOADK R0 K0 ['vector']
RETURN R0 1
)");
}
TEST_CASE("TypeAssertion")
{
// validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated