mirror of
https://github.com/luau-lang/luau.git
synced 2025-04-03 02:10:53 +01:00
CodeGen: Implement support for math.lerp lowering
To implement math.lerp without branches, we add SELECT_NUM which selects one of the two inputs based on the comparison condition. For simplicity, we only support C == D for now; this can be extended to a more generic version with a IrCondition operand E, but that requires more work on the SSE side (to flip the comparison for some conditions like Greater, and expose more generic vcmpsd).
This commit is contained in:
parent
c759cd5581
commit
07578df79a
12 changed files with 108 additions and 0 deletions
|
@ -160,6 +160,7 @@ public:
|
|||
void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
||||
void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
||||
|
||||
void vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
||||
void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
||||
|
||||
void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3);
|
||||
|
|
|
@ -185,6 +185,11 @@ enum class IrCmd : uint8_t
|
|||
// A: double
|
||||
SIGN_NUM,
|
||||
|
||||
// Select B if C == D, otherwise select A
|
||||
// A, B: double (endpoints)
|
||||
// C, D: double (condition arguments)
|
||||
SELECT_NUM,
|
||||
|
||||
// Add/Sub/Mul/Div/Idiv two vectors
|
||||
// A, B: TValue
|
||||
ADD_VEC,
|
||||
|
|
|
@ -174,6 +174,7 @@ inline bool hasResult(IrCmd cmd)
|
|||
case IrCmd::SQRT_NUM:
|
||||
case IrCmd::ABS_NUM:
|
||||
case IrCmd::SIGN_NUM:
|
||||
case IrCmd::SELECT_NUM:
|
||||
case IrCmd::ADD_VEC:
|
||||
case IrCmd::SUB_VEC:
|
||||
case IrCmd::MUL_VEC:
|
||||
|
|
|
@ -927,6 +927,11 @@ void AssemblyBuilderX64::vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2
|
|||
placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2);
|
||||
}
|
||||
|
||||
void AssemblyBuilderX64::vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
|
||||
{
|
||||
placeAvx("vcmpeqsd", dst, src1, src2, 0x00, 0xc2, false, AVX_0F, AVX_F2);
|
||||
}
|
||||
|
||||
void AssemblyBuilderX64::vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
|
||||
{
|
||||
placeAvx("vcmpltsd", dst, src1, src2, 0x01, 0xc2, false, AVX_0F, AVX_F2);
|
||||
|
|
|
@ -169,6 +169,8 @@ const char* getCmdName(IrCmd cmd)
|
|||
return "ABS_NUM";
|
||||
case IrCmd::SIGN_NUM:
|
||||
return "SIGN_NUM";
|
||||
case IrCmd::SELECT_NUM:
|
||||
return "SELECT_NUM";
|
||||
case IrCmd::ADD_VEC:
|
||||
return "ADD_VEC";
|
||||
case IrCmd::SUB_VEC:
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
LUAU_FASTFLAG(LuauVectorLibNativeDot)
|
||||
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
|
||||
LUAU_FASTFLAG(LuauCodeGenLerp)
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
@ -703,6 +704,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
|
|||
build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less));
|
||||
break;
|
||||
}
|
||||
case IrCmd::SELECT_NUM:
|
||||
{
|
||||
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
|
||||
inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b});
|
||||
|
||||
RegisterA64 temp1 = tempDouble(inst.a);
|
||||
RegisterA64 temp2 = tempDouble(inst.b);
|
||||
RegisterA64 temp3 = tempDouble(inst.c);
|
||||
RegisterA64 temp4 = tempDouble(inst.d);
|
||||
|
||||
build.fcmp(temp3, temp4);
|
||||
build.fcsel(inst.regA64, temp2, temp1, getConditionFP(IrCondition::Equal));
|
||||
break;
|
||||
}
|
||||
case IrCmd::ADD_VEC:
|
||||
{
|
||||
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
LUAU_FASTFLAG(LuauVectorLibNativeDot)
|
||||
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
|
||||
LUAU_FASTFLAG(LuauCodeGenLerp)
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
@ -622,6 +623,30 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
|
|||
build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64);
|
||||
break;
|
||||
}
|
||||
case IrCmd::SELECT_NUM:
|
||||
{
|
||||
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
|
||||
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); // can't reuse b if a is a memory operand
|
||||
|
||||
ScopedRegX64 tmp{regs, SizeX64::xmmword};
|
||||
|
||||
if (inst.c.kind == IrOpKind::Inst)
|
||||
build.vcmpeqsd(tmp.reg, regOp(inst.c), memRegDoubleOp(inst.d));
|
||||
else
|
||||
{
|
||||
build.vmovsd(tmp.reg, memRegDoubleOp(inst.c));
|
||||
build.vcmpeqsd(tmp.reg, tmp.reg, memRegDoubleOp(inst.d));
|
||||
}
|
||||
|
||||
if (inst.a.kind == IrOpKind::Inst)
|
||||
build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg);
|
||||
else
|
||||
{
|
||||
build.vmovsd(inst.regX64, memRegDoubleOp(inst.a));
|
||||
build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case IrCmd::ADD_VEC:
|
||||
{
|
||||
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
|
||||
|
|
|
@ -15,6 +15,7 @@ static const int kBit32BinaryOpUnrolledParams = 5;
|
|||
|
||||
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
|
||||
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
|
||||
LUAU_FASTFLAGVARIABLE(LuauCodeGenLerp);
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
@ -284,6 +285,42 @@ static BuiltinImplResult translateBuiltinMathClamp(
|
|||
return {BuiltinImplType::UsesFallback, 1};
|
||||
}
|
||||
|
||||
static BuiltinImplResult translateBuiltinMathLerp(
|
||||
IrBuilder& build,
|
||||
int nparams,
|
||||
int ra,
|
||||
int arg,
|
||||
IrOp args,
|
||||
IrOp arg3,
|
||||
int nresults,
|
||||
IrOp fallback,
|
||||
int pcpos
|
||||
)
|
||||
{
|
||||
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
|
||||
|
||||
if (nparams < 3 || nresults > 1)
|
||||
return {BuiltinImplType::None, -1};
|
||||
|
||||
builtinCheckDouble(build, build.vmReg(arg), pcpos);
|
||||
builtinCheckDouble(build, args, pcpos);
|
||||
builtinCheckDouble(build, arg3, pcpos);
|
||||
|
||||
IrOp a = builtinLoadDouble(build, build.vmReg(arg));
|
||||
IrOp b = builtinLoadDouble(build, args);
|
||||
IrOp t = builtinLoadDouble(build, arg3);
|
||||
|
||||
IrOp l = build.inst(IrCmd::ADD_NUM, a, build.inst(IrCmd::MUL_NUM, build.inst(IrCmd::SUB_NUM, b, a), t));
|
||||
IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0), build.cond(IrCondition::Equal));
|
||||
|
||||
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r);
|
||||
|
||||
if (ra != arg)
|
||||
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
|
||||
|
||||
return {BuiltinImplType::Full, 1};
|
||||
}
|
||||
|
||||
static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos)
|
||||
{
|
||||
if (nparams < 1 || nresults > 1)
|
||||
|
@ -1387,6 +1424,8 @@ BuiltinImplResult translateBuiltin(
|
|||
case LBF_VECTOR_MAX:
|
||||
return FFlag::LuauVectorLibNativeCodegen ? translateBuiltinVectorMap2(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos)
|
||||
: noneResult;
|
||||
case LBF_MATH_LERP:
|
||||
return FFlag::LuauCodeGenLerp ? translateBuiltinMathLerp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos) : noneResult;
|
||||
default:
|
||||
return {BuiltinImplType::None, -1};
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <math.h>
|
||||
|
||||
LUAU_FASTFLAG(LuauVectorLibNativeDot);
|
||||
LUAU_FASTFLAG(LuauCodeGenLerp);
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
@ -70,6 +71,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
|
|||
case IrCmd::SQRT_NUM:
|
||||
case IrCmd::ABS_NUM:
|
||||
case IrCmd::SIGN_NUM:
|
||||
case IrCmd::SELECT_NUM:
|
||||
return IrValueKind::Double;
|
||||
case IrCmd::ADD_VEC:
|
||||
case IrCmd::SUB_VEC:
|
||||
|
@ -656,6 +658,16 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
|
|||
substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0));
|
||||
}
|
||||
break;
|
||||
case IrCmd::SELECT_NUM:
|
||||
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
|
||||
if (inst.c.kind == IrOpKind::Constant && inst.d.kind == IrOpKind::Constant)
|
||||
{
|
||||
double c = function.doubleOp(inst.c);
|
||||
double d = function.doubleOp(inst.d);
|
||||
|
||||
substitute(function, inst, c == d ? inst.b : inst.a);
|
||||
}
|
||||
break;
|
||||
case IrCmd::NOT_ANY:
|
||||
if (inst.a.kind == IrOpKind::Constant)
|
||||
{
|
||||
|
|
|
@ -1382,6 +1382,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
|
|||
case IrCmd::SQRT_NUM:
|
||||
case IrCmd::ABS_NUM:
|
||||
case IrCmd::SIGN_NUM:
|
||||
case IrCmd::SELECT_NUM:
|
||||
case IrCmd::NOT_ANY:
|
||||
state.substituteOrRecord(inst, index);
|
||||
break;
|
||||
|
|
|
@ -506,6 +506,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms")
|
|||
SINGLE_COMPARE(vmaxsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5f, 0xc6);
|
||||
SINGLE_COMPARE(vminsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5d, 0xc6);
|
||||
|
||||
SINGLE_COMPARE(vcmpeqsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x00);
|
||||
SINGLE_COMPARE(vcmpltsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x01);
|
||||
}
|
||||
|
||||
|
|
|
@ -408,6 +408,7 @@ assert(math.lerp(1, 5, 1) == 5)
|
|||
assert(math.lerp(1, 5, 0.5) == 3)
|
||||
assert(math.lerp(1, 5, 1.5) == 7)
|
||||
assert(math.lerp(1, 5, -0.5) == -1)
|
||||
assert(math.lerp(1, 5, noinline(0.5)) == 3)
|
||||
|
||||
-- lerp properties
|
||||
local sq2, sq3 = math.sqrt(2), math.sqrt(3)
|
||||
|
|
Loading…
Add table
Reference in a new issue