diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 519630f0..b820826a 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -19,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks) LUAU_FASTFLAG(LuauVectorLibNativeDot); +LUAU_FASTFLAGVARIABLE(LuauCodeGenArithOpt); namespace Luau { @@ -1194,8 +1195,46 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SUB_INT: case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: + state.substituteOrRecord(inst, index); + break; case IrCmd::MUL_NUM: + if (FFlag::LuauCodeGenArithOpt) + { + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + if (*k == 1.0) // a * 1.0 = a + substitute(function, inst, inst.a); + else if (*k == 2.0) // a * 2.0 = a + a + replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a}); + else if (*k == -1.0) // a * -1.0 = -a + replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::DIV_NUM: + if (FFlag::LuauCodeGenArithOpt) + { + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + if (*k == 1.0) // a / 1.0 = a + substitute(function, inst, inst.a); + else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k + replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)}); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::IDIV_NUM: case IrCmd::MOD_NUM: case IrCmd::MIN_NUM: diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 396678468..27376777 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -540,7 +540,7 @@ TEST_CASE("VectorCustomAccess") CHECK_EQ( "\n" + getCodegenAssembly(R"( local function vec3magn(a: vector) - return a.Magnitude * 2 + return a.Magnitude * 3 end )"), R"( @@ -560,7 +560,7 @@ bb_bytecode_1: %12 = ADD_NUM %9, %10 %13 = ADD_NUM %12, %11 %14 = SQRT_NUM %13 - %20 = MUL_NUM %14, 2 + %20 = MUL_NUM %14, 3 STORE_DOUBLE R1, %20 STORE_TAG R1, tnumber INTERRUPT 3u @@ -1167,7 +1167,7 @@ local function inl(v: vector, s: number) end local function getsum(x) - return inl(x, 2) + inl(x, 5) + return inl(x, 3) + inl(x, 5) end )", /* includeIrTypes */ true @@ -1195,7 +1195,7 @@ bb_bytecode_1: bb_bytecode_0: CHECK_TAG R0, tvector, exit(0) %2 = LOAD_FLOAT R0, 4i - %8 = MUL_NUM %2, 2 + %8 = MUL_NUM %2, 3 %13 = LOAD_FLOAT R0, 4i %19 = MUL_NUM %13, 5 %28 = ADD_NUM %8, %19