mirror of
https://github.com/luau-lang/luau.git
synced 2024-12-12 13:00:38 +00:00
CodeGen: Optimize arithmetics for basic identities (#1545)
This change folds: a * 1 => a a / 1 => a a * -1 => -a a / -1 => -a a * 2 => a + a a / 2^k => a * 2^-k a - 0 => a a + (-0) => a Note that the following folds are all invalid: a + 0 => a (breaks for negative zero) a - (-0) => a (breaks for negative zero) a - a => 0 (breaks for Inf/NaN) 0 - a => -a (breaks for negative zero) Various cases of UNM_NUM could be optimized (eg (-a) * (-b) = a * b), but that doesn't happen in benchmarks. While it would be possible to also fold inverse multiplications (k * v), these do not happen in benchmarks and rarely happen in bytecode due to type based optimizations. Maybe this can be improved with some sort of IR canonicalization in the future if necessary. I've considered moving some of these, like division strength reduction, to IR translation (as this is where POW is lowered presently) but it didn't seem better one way or the other. This change improves performance on some benchmarks, e.g. trig and voxelgen, and should be a strict uplift as it never generates more instructions or longer latency chains. On Apple M2, without division->multiplication optimization, both benchmarks see 0.1-0.2% uplift. Division optimization makes trig 3% faster; I expect the gains on X64 will be more muted, but on Apple this seems to allow loop iterations to overlap better by removing the division bottleneck.
This commit is contained in:
parent
d19a5f0699
commit
b5801d3377
3 changed files with 75 additions and 6 deletions
|
@ -9,6 +9,7 @@
|
|||
#include "lua.h"
|
||||
|
||||
#include <limits.h>
|
||||
#include <math.h>
|
||||
|
||||
#include <array>
|
||||
#include <utility>
|
||||
|
@ -19,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64)
|
|||
LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64)
|
||||
LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks)
|
||||
LUAU_FASTFLAG(LuauVectorLibNativeDot);
|
||||
LUAU_FASTFLAGVARIABLE(LuauCodeGenArithOpt);
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
@ -1192,10 +1194,67 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
|
|||
break;
|
||||
case IrCmd::ADD_INT:
|
||||
case IrCmd::SUB_INT:
|
||||
state.substituteOrRecord(inst, index);
|
||||
break;
|
||||
case IrCmd::ADD_NUM:
|
||||
case IrCmd::SUB_NUM:
|
||||
if (FFlag::LuauCodeGenArithOpt)
|
||||
{
|
||||
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
|
||||
{
|
||||
// a + 0.0 and a - (-0.0) can't be folded since the behavior is different for negative zero
|
||||
// however, a - 0.0 and a + (-0.0) can be folded into a
|
||||
if (*k == 0.0 && bool(signbit(*k)) == (inst.cmd == IrCmd::ADD_NUM))
|
||||
substitute(function, inst, inst.a);
|
||||
else
|
||||
state.substituteOrRecord(inst, index);
|
||||
}
|
||||
else
|
||||
state.substituteOrRecord(inst, index);
|
||||
}
|
||||
else
|
||||
state.substituteOrRecord(inst, index);
|
||||
break;
|
||||
case IrCmd::MUL_NUM:
|
||||
if (FFlag::LuauCodeGenArithOpt)
|
||||
{
|
||||
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
|
||||
{
|
||||
if (*k == 1.0) // a * 1.0 = a
|
||||
substitute(function, inst, inst.a);
|
||||
else if (*k == 2.0) // a * 2.0 = a + a
|
||||
replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a});
|
||||
else if (*k == -1.0) // a * -1.0 = -a
|
||||
replace(function, block, index, {IrCmd::UNM_NUM, inst.a});
|
||||
else
|
||||
state.substituteOrRecord(inst, index);
|
||||
}
|
||||
else
|
||||
state.substituteOrRecord(inst, index);
|
||||
}
|
||||
else
|
||||
state.substituteOrRecord(inst, index);
|
||||
break;
|
||||
case IrCmd::DIV_NUM:
|
||||
if (FFlag::LuauCodeGenArithOpt)
|
||||
{
|
||||
if (std::optional<double> k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)))
|
||||
{
|
||||
if (*k == 1.0) // a / 1.0 = a
|
||||
substitute(function, inst, inst.a);
|
||||
else if (*k == -1.0) // a / -1.0 = -a
|
||||
replace(function, block, index, {IrCmd::UNM_NUM, inst.a});
|
||||
else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k
|
||||
replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)});
|
||||
else
|
||||
state.substituteOrRecord(inst, index);
|
||||
}
|
||||
else
|
||||
state.substituteOrRecord(inst, index);
|
||||
}
|
||||
else
|
||||
state.substituteOrRecord(inst, index);
|
||||
break;
|
||||
case IrCmd::IDIV_NUM:
|
||||
case IrCmd::MOD_NUM:
|
||||
case IrCmd::MIN_NUM:
|
||||
|
|
|
@ -540,7 +540,7 @@ TEST_CASE("VectorCustomAccess")
|
|||
CHECK_EQ(
|
||||
"\n" + getCodegenAssembly(R"(
|
||||
local function vec3magn(a: vector)
|
||||
return a.Magnitude * 2
|
||||
return a.Magnitude * 3
|
||||
end
|
||||
)"),
|
||||
R"(
|
||||
|
@ -560,7 +560,7 @@ bb_bytecode_1:
|
|||
%12 = ADD_NUM %9, %10
|
||||
%13 = ADD_NUM %12, %11
|
||||
%14 = SQRT_NUM %13
|
||||
%20 = MUL_NUM %14, 2
|
||||
%20 = MUL_NUM %14, 3
|
||||
STORE_DOUBLE R1, %20
|
||||
STORE_TAG R1, tnumber
|
||||
INTERRUPT 3u
|
||||
|
@ -1167,7 +1167,7 @@ local function inl(v: vector, s: number)
|
|||
end
|
||||
|
||||
local function getsum(x)
|
||||
return inl(x, 2) + inl(x, 5)
|
||||
return inl(x, 3) + inl(x, 5)
|
||||
end
|
||||
)",
|
||||
/* includeIrTypes */ true
|
||||
|
@ -1195,7 +1195,7 @@ bb_bytecode_1:
|
|||
bb_bytecode_0:
|
||||
CHECK_TAG R0, tvector, exit(0)
|
||||
%2 = LOAD_FLOAT R0, 4i
|
||||
%8 = MUL_NUM %2, 2
|
||||
%8 = MUL_NUM %2, 3
|
||||
%13 = LOAD_FLOAT R0, 4i
|
||||
%19 = MUL_NUM %13, 5
|
||||
%28 = ADD_NUM %8, %19
|
||||
|
|
|
@ -92,6 +92,16 @@ assert((function() local a = 1 a = a - 2 return a end)() == -1)
|
|||
assert((function() local a = 1 a = a * 2 return a end)() == 2)
|
||||
assert((function() local a = 1 a = a / 2 return a end)() == 0.5)
|
||||
|
||||
-- binary ops with fp specials, neg zero, large constants
|
||||
-- argument is passed into anonymous function to prevent constant folding
|
||||
assert((function(a) return tostring(a + 0) end)(-0) == "0")
|
||||
assert((function(a) return tostring(a - 0) end)(-0) == "-0")
|
||||
assert((function(a) return tostring(0 - a) end)(0) == "0")
|
||||
assert((function(a) return tostring(a - a) end)(1 / 0) == "nan")
|
||||
assert((function(a) return tostring(a * 0) end)(0 / 0) == "nan")
|
||||
assert((function(a) return tostring(a / (2^1000)) end)(2^1000) == "1")
|
||||
assert((function(a) return tostring(a / (2^-1000)) end)(2^-1000) == "1")
|
||||
|
||||
-- floor division should always round towards -Infinity
|
||||
assert((function() local a = 1 a = a // 2 return a end)() == 0)
|
||||
assert((function() local a = 3 a = a // 2 return a end)() == 1)
|
||||
|
|
Loading…
Reference in a new issue