From b5801d33772cad05d0cc2d708f0ca41fea168aa5 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Wed, 27 Nov 2024 21:44:39 +0900 Subject: [PATCH] CodeGen: Optimize arithmetics for basic identities (#1545) This change folds: a * 1 => a a / 1 => a a * -1 => -a a / -1 => -a a * 2 => a + a a / 2^k => a * 2^-k a - 0 => a a + (-0) => a Note that the following folds are all invalid: a + 0 => a (breaks for negative zero) a - (-0) => a (breaks for negative zero) a - a => 0 (breaks for Inf/NaN) 0 - a => -a (breaks for negative zero) Various cases of UNM_NUM could be optimized (eg (-a) * (-b) = a * b), but that doesn't happen in benchmarks. While it would be possible to also fold inverse multiplications (k * v), these do not happen in benchmarks and rarely happen in bytecode due to type based optimizations. Maybe this can be improved with some sort of IR canonicalization in the future if necessary. I've considered moving some of these, like division strength reduction, to IR translation (as this is where POW is lowered presently) but it didn't seem better one way or the other. This change improves performance on some benchmarks, e.g. trig and voxelgen, and should be a strict uplift as it never generates more instructions or longer latency chains. On Apple M2, without division->multiplication optimization, both benchmarks see 0.1-0.2% uplift. Division optimization makes trig 3% faster; I expect the gains on X64 will be more muted, but on Apple this seems to allow loop iterations to overlap better by removing the division bottleneck. --- CodeGen/src/OptimizeConstProp.cpp | 59 +++++++++++++++++++++++++++++++ tests/IrLowering.test.cpp | 8 ++--- tests/conformance/basic.lua | 14 ++++++-- 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 519630f0..f93354a3 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -9,6 +9,7 @@ #include "lua.h" #include +#include #include #include @@ -19,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks) LUAU_FASTFLAG(LuauVectorLibNativeDot); +LUAU_FASTFLAGVARIABLE(LuauCodeGenArithOpt); namespace Luau { @@ -1192,10 +1194,67 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::ADD_INT: case IrCmd::SUB_INT: + state.substituteOrRecord(inst, index); + break; case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: + if (FFlag::LuauCodeGenArithOpt) + { + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + // a + 0.0 and a - (-0.0) can't be folded since the behavior is different for negative zero + // however, a - 0.0 and a + (-0.0) can be folded into a + if (*k == 0.0 && bool(signbit(*k)) == (inst.cmd == IrCmd::ADD_NUM)) + substitute(function, inst, inst.a); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::MUL_NUM: + if (FFlag::LuauCodeGenArithOpt) + { + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + if (*k == 1.0) // a * 1.0 = a + substitute(function, inst, inst.a); + else if (*k == 2.0) // a * 2.0 = a + a + replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a}); + else if (*k == -1.0) // a * -1.0 = -a + replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::DIV_NUM: + if (FFlag::LuauCodeGenArithOpt) + { + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + if (*k == 1.0) // a / 1.0 = a + substitute(function, inst, inst.a); + else if (*k == -1.0) // a / -1.0 = -a + replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); + else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k + replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)}); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::IDIV_NUM: case IrCmd::MOD_NUM: case IrCmd::MIN_NUM: diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 396678468..27376777 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -540,7 +540,7 @@ TEST_CASE("VectorCustomAccess") CHECK_EQ( "\n" + getCodegenAssembly(R"( local function vec3magn(a: vector) - return a.Magnitude * 2 + return a.Magnitude * 3 end )"), R"( @@ -560,7 +560,7 @@ bb_bytecode_1: %12 = ADD_NUM %9, %10 %13 = ADD_NUM %12, %11 %14 = SQRT_NUM %13 - %20 = MUL_NUM %14, 2 + %20 = MUL_NUM %14, 3 STORE_DOUBLE R1, %20 STORE_TAG R1, tnumber INTERRUPT 3u @@ -1167,7 +1167,7 @@ local function inl(v: vector, s: number) end local function getsum(x) - return inl(x, 2) + inl(x, 5) + return inl(x, 3) + inl(x, 5) end )", /* includeIrTypes */ true @@ -1195,7 +1195,7 @@ bb_bytecode_1: bb_bytecode_0: CHECK_TAG R0, tvector, exit(0) %2 = LOAD_FLOAT R0, 4i - %8 = MUL_NUM %2, 2 + %8 = MUL_NUM %2, 3 %13 = LOAD_FLOAT R0, 4i %19 = MUL_NUM %13, 5 %28 = ADD_NUM %8, %19 diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 98f8000e..05d851ea 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -92,6 +92,16 @@ assert((function() local a = 1 a = a - 2 return a end)() == -1) assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a / 2 return a end)() == 0.5) +-- binary ops with fp specials, neg zero, large constants +-- argument is passed into anonymous function to prevent constant folding +assert((function(a) return tostring(a + 0) end)(-0) == "0") +assert((function(a) return tostring(a - 0) end)(-0) == "-0") +assert((function(a) return tostring(0 - a) end)(0) == "0") +assert((function(a) return tostring(a - a) end)(1 / 0) == "nan") +assert((function(a) return tostring(a * 0) end)(0 / 0) == "nan") +assert((function(a) return tostring(a / (2^1000)) end)(2^1000) == "1") +assert((function(a) return tostring(a / (2^-1000)) end)(2^-1000) == "1") + -- floor division should always round towards -Infinity assert((function() local a = 1 a = a // 2 return a end)() == 0) assert((function() local a = 3 a = a // 2 return a end)() == 1) @@ -290,7 +300,7 @@ assert((function() local t = {[1] = 1, [2] = 2} return t[1] + t[2] end)() == 3) assert((function() return table.concat({}, ',') end)() == "") assert((function() return table.concat({1}, ',') end)() == "1") assert((function() return table.concat({1,2}, ',') end)() == "1,2") -assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() == +assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15") assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16") assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17") @@ -770,7 +780,7 @@ assert(tostring(0) == "0") assert(tostring(-0) == "-0") -- test newline handling in long strings -assert((function() +assert((function() local s1 = [[ ]] local s2 = [[