CodeGen: Optimize arithmetics for basic identities (#1545)

This change folds:

	a * 1 => a
	a / 1 => a
	a * -1 => -a
	a / -1 => -a
	a * 2 => a + a
	a / 2^k => a * 2^-k
	a - 0 => a
	a + (-0) => a

Note that the following folds are all invalid:

	a + 0 => a (breaks for negative zero)
	a - (-0) => a (breaks for negative zero)
	a - a => 0 (breaks for Inf/NaN)
	0 - a => -a (breaks for negative zero)

Various cases of UNM_NUM could be optimized (eg (-a) * (-b) = a * b),
but that doesn't happen in benchmarks.

While it would be possible to also fold inverse multiplications (k * v),
these do not happen in benchmarks and rarely happen in bytecode due
to type based optimizations. Maybe this can be improved with some sort
of
IR canonicalization in the future if necessary.

I've considered moving some of these, like division strength reduction,
to IR translation (as this is where POW is lowered presently) but it
didn't
seem better one way or the other.

This change improves performance on some benchmarks, e.g. trig and
voxelgen,
and should be a strict uplift as it never generates more instructions or
longer
latency chains. On Apple M2, without division->multiplication
optimization, both
benchmarks see 0.1-0.2% uplift. Division optimization makes trig 3%
faster; I expect
the gains on X64 will be more muted, but on Apple this seems to allow
loop iterations
to overlap better by removing the division bottleneck.
This commit is contained in:
Arseny Kapoulkine 2024-11-27 21:44:39 +09:00 committed by GitHub
parent d19a5f0699
commit b5801d3377
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 75 additions and 6 deletions

View file

@ -9,6 +9,7 @@
#include "lua.h" #include "lua.h"
#include <limits.h> #include <limits.h>
#include <math.h>
#include <array> #include <array>
#include <utility> #include <utility>
@ -19,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64)
LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64)
LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks)
LUAU_FASTFLAG(LuauVectorLibNativeDot); LUAU_FASTFLAG(LuauVectorLibNativeDot);
LUAU_FASTFLAGVARIABLE(LuauCodeGenArithOpt);
namespace Luau namespace Luau
{ {
@ -1192,10 +1194,67 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
break; break;
case IrCmd::ADD_INT: case IrCmd::ADD_INT:
case IrCmd::SUB_INT: case IrCmd::SUB_INT:
state.substituteOrRecord(inst, index);
break;
case IrCmd::ADD_NUM: case IrCmd::ADD_NUM:
case IrCmd::SUB_NUM: case IrCmd::SUB_NUM:
if (FFlag::LuauCodeGenArithOpt)
{
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
{
// a + 0.0 and a - (-0.0) can't be folded since the behavior is different for negative zero
// however, a - 0.0 and a + (-0.0) can be folded into a
if (*k == 0.0 && bool(signbit(*k)) == (inst.cmd == IrCmd::ADD_NUM))
substitute(function, inst, inst.a);
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
break;
case IrCmd::MUL_NUM: case IrCmd::MUL_NUM:
if (FFlag::LuauCodeGenArithOpt)
{
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
{
if (*k == 1.0) // a * 1.0 = a
substitute(function, inst, inst.a);
else if (*k == 2.0) // a * 2.0 = a + a
replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a});
else if (*k == -1.0) // a * -1.0 = -a
replace(function, block, index, {IrCmd::UNM_NUM, inst.a});
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
break;
case IrCmd::DIV_NUM: case IrCmd::DIV_NUM:
if (FFlag::LuauCodeGenArithOpt)
{
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
{
if (*k == 1.0) // a / 1.0 = a
substitute(function, inst, inst.a);
else if (*k == -1.0) // a / -1.0 = -a
replace(function, block, index, {IrCmd::UNM_NUM, inst.a});
else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k
replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)});
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
break;
case IrCmd::IDIV_NUM: case IrCmd::IDIV_NUM:
case IrCmd::MOD_NUM: case IrCmd::MOD_NUM:
case IrCmd::MIN_NUM: case IrCmd::MIN_NUM:

View file

@ -540,7 +540,7 @@ TEST_CASE("VectorCustomAccess")
CHECK_EQ( CHECK_EQ(
"\n" + getCodegenAssembly(R"( "\n" + getCodegenAssembly(R"(
local function vec3magn(a: vector) local function vec3magn(a: vector)
return a.Magnitude * 2 return a.Magnitude * 3
end end
)"), )"),
R"( R"(
@ -560,7 +560,7 @@ bb_bytecode_1:
%12 = ADD_NUM %9, %10 %12 = ADD_NUM %9, %10
%13 = ADD_NUM %12, %11 %13 = ADD_NUM %12, %11
%14 = SQRT_NUM %13 %14 = SQRT_NUM %13
%20 = MUL_NUM %14, 2 %20 = MUL_NUM %14, 3
STORE_DOUBLE R1, %20 STORE_DOUBLE R1, %20
STORE_TAG R1, tnumber STORE_TAG R1, tnumber
INTERRUPT 3u INTERRUPT 3u
@ -1167,7 +1167,7 @@ local function inl(v: vector, s: number)
end end
local function getsum(x) local function getsum(x)
return inl(x, 2) + inl(x, 5) return inl(x, 3) + inl(x, 5)
end end
)", )",
/* includeIrTypes */ true /* includeIrTypes */ true
@ -1195,7 +1195,7 @@ bb_bytecode_1:
bb_bytecode_0: bb_bytecode_0:
CHECK_TAG R0, tvector, exit(0) CHECK_TAG R0, tvector, exit(0)
%2 = LOAD_FLOAT R0, 4i %2 = LOAD_FLOAT R0, 4i
%8 = MUL_NUM %2, 2 %8 = MUL_NUM %2, 3
%13 = LOAD_FLOAT R0, 4i %13 = LOAD_FLOAT R0, 4i
%19 = MUL_NUM %13, 5 %19 = MUL_NUM %13, 5
%28 = ADD_NUM %8, %19 %28 = ADD_NUM %8, %19

View file

@ -92,6 +92,16 @@ assert((function() local a = 1 a = a - 2 return a end)() == -1)
assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a * 2 return a end)() == 2)
assert((function() local a = 1 a = a / 2 return a end)() == 0.5) assert((function() local a = 1 a = a / 2 return a end)() == 0.5)
-- binary ops with fp specials, neg zero, large constants
-- argument is passed into anonymous function to prevent constant folding
assert((function(a) return tostring(a + 0) end)(-0) == "0")
assert((function(a) return tostring(a - 0) end)(-0) == "-0")
assert((function(a) return tostring(0 - a) end)(0) == "0")
assert((function(a) return tostring(a - a) end)(1 / 0) == "nan")
assert((function(a) return tostring(a * 0) end)(0 / 0) == "nan")
assert((function(a) return tostring(a / (2^1000)) end)(2^1000) == "1")
assert((function(a) return tostring(a / (2^-1000)) end)(2^-1000) == "1")
-- floor division should always round towards -Infinity -- floor division should always round towards -Infinity
assert((function() local a = 1 a = a // 2 return a end)() == 0) assert((function() local a = 1 a = a // 2 return a end)() == 0)
assert((function() local a = 3 a = a // 2 return a end)() == 1) assert((function() local a = 3 a = a // 2 return a end)() == 1)
@ -290,7 +300,7 @@ assert((function() local t = {[1] = 1, [2] = 2} return t[1] + t[2] end)() == 3)
assert((function() return table.concat({}, ',') end)() == "") assert((function() return table.concat({}, ',') end)() == "")
assert((function() return table.concat({1}, ',') end)() == "1") assert((function() return table.concat({1}, ',') end)() == "1")
assert((function() return table.concat({1,2}, ',') end)() == "1,2") assert((function() return table.concat({1,2}, ',') end)() == "1,2")
assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() == assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() ==
"1,2,3,4,5,6,7,8,9,10,11,12,13,14,15") "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15")
assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16") assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16")
assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17") assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17")
@ -770,7 +780,7 @@ assert(tostring(0) == "0")
assert(tostring(-0) == "-0") assert(tostring(-0) == "-0")
-- test newline handling in long strings -- test newline handling in long strings
assert((function() assert((function()
local s1 = [[ local s1 = [[
]] ]]
local s2 = [[ local s2 = [[