CodeGen: Implement support for math.lerp lowering

To implement math.lerp without branches, we add SELECT_NUM which
selects one of the two inputs based on the comparison condition.

For simplicity, we only support C == D for now; this can be extended to
a more generic version with a IrCondition operand E, but that requires
more work on the SSE side (to flip the comparison for some conditions
like Greater, and expose more generic vcmpsd).
This commit is contained in:
Arseny Kapoulkine 2025-01-09 10:08:52 -08:00
parent c759cd5581
commit 07578df79a
12 changed files with 108 additions and 0 deletions

View file

@ -160,6 +160,7 @@ public:
void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3);

View file

@ -185,6 +185,11 @@ enum class IrCmd : uint8_t
// A: double
SIGN_NUM,
// Select B if C == D, otherwise select A
// A, B: double (endpoints)
// C, D: double (condition arguments)
SELECT_NUM,
// Add/Sub/Mul/Div/Idiv two vectors
// A, B: TValue
ADD_VEC,

View file

@ -174,6 +174,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM:
case IrCmd::SIGN_NUM:
case IrCmd::SELECT_NUM:
case IrCmd::ADD_VEC:
case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC:

View file

@ -927,6 +927,11 @@ void AssemblyBuilderX64::vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2
placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vcmpeqsd", dst, src1, src2, 0x00, 0xc2, false, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vcmpltsd", dst, src1, src2, 0x01, 0xc2, false, AVX_0F, AVX_F2);

View file

@ -169,6 +169,8 @@ const char* getCmdName(IrCmd cmd)
return "ABS_NUM";
case IrCmd::SIGN_NUM:
return "SIGN_NUM";
case IrCmd::SELECT_NUM:
return "SELECT_NUM";
case IrCmd::ADD_VEC:
return "ADD_VEC";
case IrCmd::SUB_VEC:

View file

@ -13,6 +13,7 @@
LUAU_FASTFLAG(LuauVectorLibNativeDot)
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
LUAU_FASTFLAG(LuauCodeGenLerp)
namespace Luau
{
@ -703,6 +704,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less));
break;
}
case IrCmd::SELECT_NUM:
{
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b});
RegisterA64 temp1 = tempDouble(inst.a);
RegisterA64 temp2 = tempDouble(inst.b);
RegisterA64 temp3 = tempDouble(inst.c);
RegisterA64 temp4 = tempDouble(inst.d);
build.fcmp(temp3, temp4);
build.fcsel(inst.regA64, temp2, temp1, getConditionFP(IrCondition::Equal));
break;
}
case IrCmd::ADD_VEC:
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});

View file

@ -17,6 +17,7 @@
LUAU_FASTFLAG(LuauVectorLibNativeDot)
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
LUAU_FASTFLAG(LuauCodeGenLerp)
namespace Luau
{
@ -622,6 +623,30 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64);
break;
}
case IrCmd::SELECT_NUM:
{
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); // can't reuse b if a is a memory operand
ScopedRegX64 tmp{regs, SizeX64::xmmword};
if (inst.c.kind == IrOpKind::Inst)
build.vcmpeqsd(tmp.reg, regOp(inst.c), memRegDoubleOp(inst.d));
else
{
build.vmovsd(tmp.reg, memRegDoubleOp(inst.c));
build.vcmpeqsd(tmp.reg, tmp.reg, memRegDoubleOp(inst.d));
}
if (inst.a.kind == IrOpKind::Inst)
build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg);
else
{
build.vmovsd(inst.regX64, memRegDoubleOp(inst.a));
build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg);
}
break;
}
case IrCmd::ADD_VEC:
{
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});

View file

@ -15,6 +15,7 @@ static const int kBit32BinaryOpUnrolledParams = 5;
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
LUAU_FASTFLAGVARIABLE(LuauCodeGenLerp);
namespace Luau
{
@ -284,6 +285,42 @@ static BuiltinImplResult translateBuiltinMathClamp(
return {BuiltinImplType::UsesFallback, 1};
}
static BuiltinImplResult translateBuiltinMathLerp(
IrBuilder& build,
int nparams,
int ra,
int arg,
IrOp args,
IrOp arg3,
int nresults,
IrOp fallback,
int pcpos
)
{
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
if (nparams < 3 || nresults > 1)
return {BuiltinImplType::None, -1};
builtinCheckDouble(build, build.vmReg(arg), pcpos);
builtinCheckDouble(build, args, pcpos);
builtinCheckDouble(build, arg3, pcpos);
IrOp a = builtinLoadDouble(build, build.vmReg(arg));
IrOp b = builtinLoadDouble(build, args);
IrOp t = builtinLoadDouble(build, arg3);
IrOp l = build.inst(IrCmd::ADD_NUM, a, build.inst(IrCmd::MUL_NUM, build.inst(IrCmd::SUB_NUM, b, a), t));
IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0), build.cond(IrCondition::Equal));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r);
if (ra != arg)
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
return {BuiltinImplType::Full, 1};
}
static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos)
{
if (nparams < 1 || nresults > 1)
@ -1387,6 +1424,8 @@ BuiltinImplResult translateBuiltin(
case LBF_VECTOR_MAX:
return FFlag::LuauVectorLibNativeCodegen ? translateBuiltinVectorMap2(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos)
: noneResult;
case LBF_MATH_LERP:
return FFlag::LuauCodeGenLerp ? translateBuiltinMathLerp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos) : noneResult;
default:
return {BuiltinImplType::None, -1};
}

View file

@ -13,6 +13,7 @@
#include <math.h>
LUAU_FASTFLAG(LuauVectorLibNativeDot);
LUAU_FASTFLAG(LuauCodeGenLerp);
namespace Luau
{
@ -70,6 +71,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM:
case IrCmd::SIGN_NUM:
case IrCmd::SELECT_NUM:
return IrValueKind::Double;
case IrCmd::ADD_VEC:
case IrCmd::SUB_VEC:
@ -656,6 +658,16 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0));
}
break;
case IrCmd::SELECT_NUM:
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
if (inst.c.kind == IrOpKind::Constant && inst.d.kind == IrOpKind::Constant)
{
double c = function.doubleOp(inst.c);
double d = function.doubleOp(inst.d);
substitute(function, inst, c == d ? inst.b : inst.a);
}
break;
case IrCmd::NOT_ANY:
if (inst.a.kind == IrOpKind::Constant)
{

View file

@ -1382,6 +1382,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM:
case IrCmd::SIGN_NUM:
case IrCmd::SELECT_NUM:
case IrCmd::NOT_ANY:
state.substituteOrRecord(inst, index);
break;

View file

@ -506,6 +506,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms")
SINGLE_COMPARE(vmaxsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5f, 0xc6);
SINGLE_COMPARE(vminsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5d, 0xc6);
SINGLE_COMPARE(vcmpeqsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x00);
SINGLE_COMPARE(vcmpltsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x01);
}

View file

@ -408,6 +408,7 @@ assert(math.lerp(1, 5, 1) == 5)
assert(math.lerp(1, 5, 0.5) == 3)
assert(math.lerp(1, 5, 1.5) == 7)
assert(math.lerp(1, 5, -0.5) == -1)
assert(math.lerp(1, 5, noinline(0.5)) == 3)
-- lerp properties
local sq2, sq3 = math.sqrt(2), math.sqrt(3)