CodeGen: Optimize arithmetics for basic multiplicative identities

This change folds:

	a * 1 => a
	a / 1 => a
	a * -1 => -a
	a / -1 => -a
	a * 2 => a + a
	a / 2^k => a * 2^-k

Note that the following folds are all invalid:

	a + 0 => a (breaks for negative zero)
	a - (-0) => a (breaks for negative zero)
	a - a => 0 (breaks for NaN)

a - 0 could be folded into a but that doesn't happen in benchmarks.

Various cases of UNM_NUM could be optimized (eg (-a) * (-b) = a * b),
but that doesn't happen in benchmarks either.
This commit is contained in:
Arseny Kapoulkine 2024-11-26 21:03:32 +09:00
parent d19a5f0699
commit 434c8f2c42
2 changed files with 43 additions and 4 deletions

View file

@ -19,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64)
LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64)
LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks)
LUAU_FASTFLAG(LuauVectorLibNativeDot);
LUAU_FASTFLAGVARIABLE(LuauCodeGenArithOpt);
namespace Luau
{
@ -1194,8 +1195,46 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
case IrCmd::SUB_INT:
case IrCmd::ADD_NUM:
case IrCmd::SUB_NUM:
state.substituteOrRecord(inst, index);
break;
case IrCmd::MUL_NUM:
if (FFlag::LuauCodeGenArithOpt)
{
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
{
if (*k == 1.0) // a * 1.0 = a
substitute(function, inst, inst.a);
else if (*k == 2.0) // a * 2.0 = a + a
replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a});
else if (*k == -1.0) // a * -1.0 = -a
replace(function, block, index, {IrCmd::UNM_NUM, inst.a});
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
break;
case IrCmd::DIV_NUM:
if (FFlag::LuauCodeGenArithOpt)
{
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
{
if (*k == 1.0) // a / 1.0 = a
substitute(function, inst, inst.a);
else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k
replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)});
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
}
else
state.substituteOrRecord(inst, index);
break;
case IrCmd::IDIV_NUM:
case IrCmd::MOD_NUM:
case IrCmd::MIN_NUM:

View file

@ -540,7 +540,7 @@ TEST_CASE("VectorCustomAccess")
CHECK_EQ(
"\n" + getCodegenAssembly(R"(
local function vec3magn(a: vector)
return a.Magnitude * 2
return a.Magnitude * 3
end
)"),
R"(
@ -560,7 +560,7 @@ bb_bytecode_1:
%12 = ADD_NUM %9, %10
%13 = ADD_NUM %12, %11
%14 = SQRT_NUM %13
%20 = MUL_NUM %14, 2
%20 = MUL_NUM %14, 3
STORE_DOUBLE R1, %20
STORE_TAG R1, tnumber
INTERRUPT 3u
@ -1167,7 +1167,7 @@ local function inl(v: vector, s: number)
end
local function getsum(x)
return inl(x, 2) + inl(x, 5)
return inl(x, 3) + inl(x, 5)
end
)",
/* includeIrTypes */ true
@ -1195,7 +1195,7 @@ bb_bytecode_1:
bb_bytecode_0:
CHECK_TAG R0, tvector, exit(0)
%2 = LOAD_FLOAT R0, 4i
%8 = MUL_NUM %2, 2
%8 = MUL_NUM %2, 3
%13 = LOAD_FLOAT R0, 4i
%19 = MUL_NUM %13, 5
%28 = ADD_NUM %8, %19