Compiler: Optimize k*n and k+n when types are known

When type information is specified, we can compile k*n and k+n (with a
constant k on the left-hand side) into the MULK/ADDK forms, which are faster
to execute, as long as we think n is a number. Since we generally restrict
type-aware optimizations to O2, this optimization is restricted to O2 as well.

This makes the trig benchmark ~4% faster on Apple M2 in the VM; a tiny
improvement on scimark (~0.1%) can also be observed. The optimization
only affects interpreted execution, as NCG can already synthesize
optimal code here.

If the type information is not truthful (e.g. the user annotates a value as
a number and it is not one), the worst-case scenario is that metamethods like
__add/__mul receive their arguments in flipped order when the left-hand side
is a constant.
This commit is contained in:
Arseny Kapoulkine 2024-11-16 16:02:05 +09:00
parent d1025d0029
commit 859475d315
2 changed files with 58 additions and 8 deletions

View file

@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300)
LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAG(LuauNativeAttribute) LUAU_FASTFLAG(LuauNativeAttribute)
LUAU_FASTFLAGVARIABLE(LuauCompileOptimizeRevArith)
namespace Luau namespace Luau
{ {
@ -1623,6 +1624,24 @@ struct Compiler
return; return;
} }
} }
else if (FFlag::LuauCompileOptimizeRevArith && options.optimizationLevel >= 2 && (expr->op == AstExprBinary::Add || expr->op == AstExprBinary::Mul))
{
// Optimization: replace k*r with r*k when r is known to be a number (otherwise metamethods may be called)
if (LuauBytecodeType* ty = exprTypes.find(expr); ty && *ty == LBC_TYPE_NUMBER)
{
int32_t lc = getConstantNumber(expr->left);
if (lc >= 0 && lc <= 255)
{
uint8_t rr = compileExprAuto(expr->right, rs);
bytecode.emitABC(getBinaryOpArith(expr->op, /* k= */ true), target, uint8_t(rr), uint8_t(lc));
hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1);
return;
}
}
}
uint8_t rl = compileExprAuto(expr->left, rs); uint8_t rl = compileExprAuto(expr->left, rs);
uint8_t rr = compileExprAuto(expr->right, rs); uint8_t rr = compileExprAuto(expr->right, rs);

View file

@ -23,15 +23,17 @@ LUAU_FASTINT(LuauCompileLoopUnrollThresholdMaxBoost)
LUAU_FASTINT(LuauRecursionLimit) LUAU_FASTINT(LuauRecursionLimit)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAG(LuauCompileVectorTypeInfo) LUAU_FASTFLAG(LuauCompileVectorTypeInfo)
LUAU_FASTFLAG(LuauCompileOptimizeRevArith)
using namespace Luau; using namespace Luau;
static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, bool enableVectors = false) static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, int typeInfoLevel = 0, bool enableVectors = false)
{ {
Luau::BytecodeBuilder bcb; Luau::BytecodeBuilder bcb;
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code);
Luau::CompileOptions options; Luau::CompileOptions options;
options.optimizationLevel = optimizationLevel; options.optimizationLevel = optimizationLevel;
options.typeInfoLevel = typeInfoLevel;
if (enableVectors) if (enableVectors)
{ {
options.vectorLib = "Vector3"; options.vectorLib = "Vector3";
@ -4931,32 +4933,32 @@ L0: RETURN R3 -1
TEST_CASE("VectorLiterals") TEST_CASE("VectorLiterals")
{ {
CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, /*enableVectors*/ true), R"( CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, 0, /*enableVectors*/ true), R"(
LOADK R0 K0 [1, 2, 3] LOADK R0 K0 [1, 2, 3]
RETURN R0 1 RETURN R0 1
)"); )");
CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3))", 0, 2, /*enableVectors*/ true), R"( CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3))", 0, 2, 0, /*enableVectors*/ true), R"(
GETIMPORT R0 1 [print] GETIMPORT R0 1 [print]
LOADK R1 K2 [1, 2, 3] LOADK R1 K2 [1, 2, 3]
CALL R0 1 0 CALL R0 1 0
RETURN R0 0 RETURN R0 0
)"); )");
CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3, 4))", 0, 2, /*enableVectors*/ true), R"( CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3, 4))", 0, 2, 0, /*enableVectors*/ true), R"(
GETIMPORT R0 1 [print] GETIMPORT R0 1 [print]
LOADK R1 K2 [1, 2, 3, 4] LOADK R1 K2 [1, 2, 3, 4]
CALL R0 1 0 CALL R0 1 0
RETURN R0 0 RETURN R0 0
)"); )");
CHECK_EQ("\n" + compileFunction("return Vector3.new(0, 0, 0), Vector3.new(-0, 0, 0)", 0, 2, /*enableVectors*/ true), R"( CHECK_EQ("\n" + compileFunction("return Vector3.new(0, 0, 0), Vector3.new(-0, 0, 0)", 0, 2, 0, /*enableVectors*/ true), R"(
LOADK R0 K0 [0, 0, 0] LOADK R0 K0 [0, 0, 0]
LOADK R1 K1 [-0, 0, 0] LOADK R1 K1 [-0, 0, 0]
RETURN R0 2 RETURN R0 2
)"); )");
CHECK_EQ("\n" + compileFunction("return type(Vector3.new(0, 0, 0))", 0, 2, /*enableVectors*/ true), R"( CHECK_EQ("\n" + compileFunction("return type(Vector3.new(0, 0, 0))", 0, 2, 0, /*enableVectors*/ true), R"(
LOADK R0 K0 ['vector'] LOADK R0 K0 ['vector']
RETURN R0 1 RETURN R0 1
)"); )");
@ -8845,8 +8847,9 @@ RETURN R0 1
TEST_CASE("ArithRevK") TEST_CASE("ArithRevK")
{ {
// - and / have special optimized form for reverse constants; in the future, + and * will likely get compiled to ADDK/MULK ScopedFastFlag sff(FFlag::LuauCompileOptimizeRevArith, true);
// other operators are not important enough to optimize reverse constant forms for
// - and / have special optimized form for reverse constants; in absence of type information, we can't optimize other ops
CHECK_EQ( CHECK_EQ(
"\n" + compileFunction0(R"( "\n" + compileFunction0(R"(
local x: number = unknown local x: number = unknown
@ -8867,6 +8870,34 @@ IDIV R6 R7 R0
LOADN R8 2 LOADN R8 2
POW R7 R8 R0 POW R7 R8 R0
RETURN R1 7 RETURN R1 7
)"
);
// the same code with type information can optimize commutative operators (+ and *) as well
// other operators are not important enough to optimize reverse constant forms for
CHECK_EQ(
"\n" + compileFunction(
R"(
local x: number = unknown
return 2 + x, 2 - x, 2 * x, 2 / x, 2 % x, 2 // x, 2 ^ x
)",
0,
2,
1
),
R"(
GETIMPORT R0 1 [unknown]
ADDK R1 R0 K2 [2]
SUBRK R2 K2 [2] R0
MULK R3 R0 K2 [2]
DIVRK R4 K2 [2] R0
LOADN R6 2
MOD R5 R6 R0
LOADN R7 2
IDIV R6 R7 R0
LOADN R8 2
POW R7 R8 R0
RETURN R1 7
)" )"
); );
} }