mirror of
https://github.com/luau-lang/luau.git
synced 2025-01-19 09:18:07 +00:00
CodeGen: Implement support for math.lerp lowering (#1609)
Some checks are pending
benchmark / callgrind (map[branch:main name:luau-lang/benchmark-data], ubuntu-22.04) (push) Waiting to run
build / macos (push) Waiting to run
build / macos-arm (push) Waiting to run
build / ubuntu (push) Waiting to run
build / windows (Win32) (push) Waiting to run
build / windows (x64) (push) Waiting to run
build / coverage (push) Waiting to run
build / web (push) Waiting to run
release / macos (push) Waiting to run
release / ubuntu (push) Waiting to run
release / windows (push) Waiting to run
release / web (push) Waiting to run
Some checks are pending
benchmark / callgrind (map[branch:main name:luau-lang/benchmark-data], ubuntu-22.04) (push) Waiting to run
build / macos (push) Waiting to run
build / macos-arm (push) Waiting to run
build / ubuntu (push) Waiting to run
build / windows (Win32) (push) Waiting to run
build / windows (x64) (push) Waiting to run
build / coverage (push) Waiting to run
build / web (push) Waiting to run
release / macos (push) Waiting to run
release / ubuntu (push) Waiting to run
release / windows (push) Waiting to run
release / web (push) Waiting to run
To implement math.lerp without branches, we add SELECT_NUM which selects one of the two inputs based on the comparison condition. For simplicity, we only support C == D for now; this can be extended to a more generic version with an IrCondition operand E, but that requires more work on the SSE side (to flip the comparison for some conditions like Greater, and expose more generic vcmpsd). Note: On AArch64 this will effectively result in a change in floating point behavior between native code and non-native code: clang synthesizes fmadd (because floating point contraction is allowed by default, and the arch always has the instruction), whereas this change will use fmul+fadd. I am not sure if this is good or bad, and if this is a problem in C or not. Specifically, clang's behavior results in different results between X64 and AArch64 when *not* using codegen, and with this change the behavior when using codegen is... the same? :) Fixing this will require either using LERP_NUM instead and hand-coding lowering, or exposing some sort of "quasi" MADD_NUM (which would lower to fma on AArch64 and mul+add on X64). A small benefit to the current approach is that `lerp(1, 5, t)` constant-folds the subtraction. With LERP_NUM this optimization will need to be implemented manually as a partial constant-folding for LERP_NUM. A similar problem exists today for vector.cross & vector.dot. So maybe this is not something we need to fix, unsure.
This commit is contained in:
parent
c759cd5581
commit
24cacc94ed
12 changed files with 108 additions and 0 deletions
|
@ -160,6 +160,7 @@ public:
|
||||||
void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
||||||
void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
||||||
|
|
||||||
|
void vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
||||||
void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
|
||||||
|
|
||||||
void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3);
|
void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3);
|
||||||
|
|
|
@ -185,6 +185,11 @@ enum class IrCmd : uint8_t
|
||||||
// A: double
|
// A: double
|
||||||
SIGN_NUM,
|
SIGN_NUM,
|
||||||
|
|
||||||
|
// Select B if C == D, otherwise select A
|
||||||
|
// A, B: double (endpoints)
|
||||||
|
// C, D: double (condition arguments)
|
||||||
|
SELECT_NUM,
|
||||||
|
|
||||||
// Add/Sub/Mul/Div/Idiv two vectors
|
// Add/Sub/Mul/Div/Idiv two vectors
|
||||||
// A, B: TValue
|
// A, B: TValue
|
||||||
ADD_VEC,
|
ADD_VEC,
|
||||||
|
|
|
@ -174,6 +174,7 @@ inline bool hasResult(IrCmd cmd)
|
||||||
case IrCmd::SQRT_NUM:
|
case IrCmd::SQRT_NUM:
|
||||||
case IrCmd::ABS_NUM:
|
case IrCmd::ABS_NUM:
|
||||||
case IrCmd::SIGN_NUM:
|
case IrCmd::SIGN_NUM:
|
||||||
|
case IrCmd::SELECT_NUM:
|
||||||
case IrCmd::ADD_VEC:
|
case IrCmd::ADD_VEC:
|
||||||
case IrCmd::SUB_VEC:
|
case IrCmd::SUB_VEC:
|
||||||
case IrCmd::MUL_VEC:
|
case IrCmd::MUL_VEC:
|
||||||
|
|
|
@ -927,6 +927,11 @@ void AssemblyBuilderX64::vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2
|
||||||
placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2);
|
placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Emits the AVX scalar-double compare `vcmpeqsd dst, src1, src2` through placeAvx.
// The call mirrors vcmpltsd: both pass 0xc2 in the same position, while this one passes
// 0x00 where vcmpltsd passes 0x01 (presumably the compare-predicate immediate, 0 = "equal").
void AssemblyBuilderX64::vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
|
||||||
|
{
|
||||||
|
placeAvx("vcmpeqsd", dst, src1, src2, 0x00, 0xc2, false, AVX_0F, AVX_F2);
|
||||||
|
}
|
||||||
|
|
||||||
void AssemblyBuilderX64::vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
|
void AssemblyBuilderX64::vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
|
||||||
{
|
{
|
||||||
placeAvx("vcmpltsd", dst, src1, src2, 0x01, 0xc2, false, AVX_0F, AVX_F2);
|
placeAvx("vcmpltsd", dst, src1, src2, 0x01, 0xc2, false, AVX_0F, AVX_F2);
|
||||||
|
|
|
@ -169,6 +169,8 @@ const char* getCmdName(IrCmd cmd)
|
||||||
return "ABS_NUM";
|
return "ABS_NUM";
|
||||||
case IrCmd::SIGN_NUM:
|
case IrCmd::SIGN_NUM:
|
||||||
return "SIGN_NUM";
|
return "SIGN_NUM";
|
||||||
|
case IrCmd::SELECT_NUM:
|
||||||
|
return "SELECT_NUM";
|
||||||
case IrCmd::ADD_VEC:
|
case IrCmd::ADD_VEC:
|
||||||
return "ADD_VEC";
|
return "ADD_VEC";
|
||||||
case IrCmd::SUB_VEC:
|
case IrCmd::SUB_VEC:
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
|
|
||||||
LUAU_FASTFLAG(LuauVectorLibNativeDot)
|
LUAU_FASTFLAG(LuauVectorLibNativeDot)
|
||||||
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
|
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
|
||||||
|
LUAU_FASTFLAG(LuauCodeGenLerp)
|
||||||
|
|
||||||
namespace Luau
|
namespace Luau
|
||||||
{
|
{
|
||||||
|
@ -703,6 +704,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
|
||||||
build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less));
|
build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case IrCmd::SELECT_NUM:
|
||||||
|
{
|
||||||
|
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
|
||||||
|
inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b, inst.c, inst.d});
|
||||||
|
|
||||||
|
RegisterA64 temp1 = tempDouble(inst.a);
|
||||||
|
RegisterA64 temp2 = tempDouble(inst.b);
|
||||||
|
RegisterA64 temp3 = tempDouble(inst.c);
|
||||||
|
RegisterA64 temp4 = tempDouble(inst.d);
|
||||||
|
|
||||||
|
build.fcmp(temp3, temp4);
|
||||||
|
build.fcsel(inst.regA64, temp2, temp1, getConditionFP(IrCondition::Equal));
|
||||||
|
break;
|
||||||
|
}
|
||||||
case IrCmd::ADD_VEC:
|
case IrCmd::ADD_VEC:
|
||||||
{
|
{
|
||||||
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
|
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
LUAU_FASTFLAG(LuauVectorLibNativeDot)
|
LUAU_FASTFLAG(LuauVectorLibNativeDot)
|
||||||
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
|
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
|
||||||
|
LUAU_FASTFLAG(LuauCodeGenLerp)
|
||||||
|
|
||||||
namespace Luau
|
namespace Luau
|
||||||
{
|
{
|
||||||
|
@ -622,6 +623,30 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
|
||||||
build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64);
|
build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case IrCmd::SELECT_NUM:
|
||||||
|
{
|
||||||
|
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
|
||||||
|
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.c, inst.d}); // can't reuse b if a is a memory operand
|
||||||
|
|
||||||
|
ScopedRegX64 tmp{regs, SizeX64::xmmword};
|
||||||
|
|
||||||
|
if (inst.c.kind == IrOpKind::Inst)
|
||||||
|
build.vcmpeqsd(tmp.reg, regOp(inst.c), memRegDoubleOp(inst.d));
|
||||||
|
else
|
||||||
|
{
|
||||||
|
build.vmovsd(tmp.reg, memRegDoubleOp(inst.c));
|
||||||
|
build.vcmpeqsd(tmp.reg, tmp.reg, memRegDoubleOp(inst.d));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inst.a.kind == IrOpKind::Inst)
|
||||||
|
build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg);
|
||||||
|
else
|
||||||
|
{
|
||||||
|
build.vmovsd(inst.regX64, memRegDoubleOp(inst.a));
|
||||||
|
build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
case IrCmd::ADD_VEC:
|
case IrCmd::ADD_VEC:
|
||||||
{
|
{
|
||||||
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
|
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
|
||||||
|
|
|
@ -15,6 +15,7 @@ static const int kBit32BinaryOpUnrolledParams = 5;
|
||||||
|
|
||||||
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
|
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
|
||||||
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
|
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
|
||||||
|
LUAU_FASTFLAGVARIABLE(LuauCodeGenLerp);
|
||||||
|
|
||||||
namespace Luau
|
namespace Luau
|
||||||
{
|
{
|
||||||
|
@ -284,6 +285,42 @@ static BuiltinImplResult translateBuiltinMathClamp(
|
||||||
return {BuiltinImplType::UsesFallback, 1};
|
return {BuiltinImplType::UsesFallback, 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Translates a math.lerp builtin call into IR that computes a + (b - a) * t and then
// uses SELECT_NUM to pick b when t == 1.0.
// Inputs are read from vmReg(arg), args and arg3; the result is stored to vmReg(ra).
// Returns {BuiltinImplType::None, -1} when fewer than 3 parameters are supplied or more
// than 1 result is requested, and {BuiltinImplType::Full, 1} otherwise.
// NOTE(review): `fallback` is accepted but not referenced anywhere in this body.
static BuiltinImplResult translateBuiltinMathLerp(
|
||||||
|
IrBuilder& build,
|
||||||
|
int nparams,
|
||||||
|
int ra,
|
||||||
|
int arg,
|
||||||
|
IrOp args,
|
||||||
|
IrOp arg3,
|
||||||
|
int nresults,
|
||||||
|
IrOp fallback,
|
||||||
|
int pcpos
|
||||||
|
)
|
||||||
|
{
|
||||||
|
// This translation is only expected to be reached when the LuauCodeGenLerp flag is enabled
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
|
||||||
|

||||||
|
// Only the 3-argument, at-most-1-result form is translated here
if (nparams < 3 || nresults > 1)
|
||||||
|
return {BuiltinImplType::None, -1};
|
||||||
|

||||||
|
// Check each of the three inputs before loading it as a double
builtinCheckDouble(build, build.vmReg(arg), pcpos);
|
||||||
|
builtinCheckDouble(build, args, pcpos);
|
||||||
|
builtinCheckDouble(build, arg3, pcpos);
|
||||||
|

||||||
|
// a and b are the endpoints, t is the interpolation parameter
IrOp a = builtinLoadDouble(build, build.vmReg(arg));
|
||||||
|
IrOp b = builtinLoadDouble(build, args);
|
||||||
|
IrOp t = builtinLoadDouble(build, arg3);
|
||||||
|

||||||
|
// l = a + (b - a) * t
IrOp l = build.inst(IrCmd::ADD_NUM, a, build.inst(IrCmd::MUL_NUM, build.inst(IrCmd::SUB_NUM, b, a), t));
|
||||||
|
// SELECT_NUM operands are (A = l, B = b, C = t, D = 1.0): the result is b when t == 1.0, otherwise l
IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0)); // select on t==1.0
|
||||||
|

||||||
|
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r);
|
||||||
|

||||||
|
// The number tag is written only when the destination register differs from the first argument register
// (presumably because vmReg(arg) was already verified to hold a number above - confirm)
if (ra != arg)
|
||||||
|
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
|
||||||
|

||||||
|
return {BuiltinImplType::Full, 1};
|
||||||
|
}
|
||||||
|
|
||||||
static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos)
|
static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos)
|
||||||
{
|
{
|
||||||
if (nparams < 1 || nresults > 1)
|
if (nparams < 1 || nresults > 1)
|
||||||
|
@ -1387,6 +1424,8 @@ BuiltinImplResult translateBuiltin(
|
||||||
case LBF_VECTOR_MAX:
|
case LBF_VECTOR_MAX:
|
||||||
return FFlag::LuauVectorLibNativeCodegen ? translateBuiltinVectorMap2(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos)
|
return FFlag::LuauVectorLibNativeCodegen ? translateBuiltinVectorMap2(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos)
|
||||||
: noneResult;
|
: noneResult;
|
||||||
|
case LBF_MATH_LERP:
|
||||||
|
return FFlag::LuauCodeGenLerp ? translateBuiltinMathLerp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos) : noneResult;
|
||||||
default:
|
default:
|
||||||
return {BuiltinImplType::None, -1};
|
return {BuiltinImplType::None, -1};
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
|
||||||
LUAU_FASTFLAG(LuauVectorLibNativeDot);
|
LUAU_FASTFLAG(LuauVectorLibNativeDot);
|
||||||
|
LUAU_FASTFLAG(LuauCodeGenLerp);
|
||||||
|
|
||||||
namespace Luau
|
namespace Luau
|
||||||
{
|
{
|
||||||
|
@ -70,6 +71,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
|
||||||
case IrCmd::SQRT_NUM:
|
case IrCmd::SQRT_NUM:
|
||||||
case IrCmd::ABS_NUM:
|
case IrCmd::ABS_NUM:
|
||||||
case IrCmd::SIGN_NUM:
|
case IrCmd::SIGN_NUM:
|
||||||
|
case IrCmd::SELECT_NUM:
|
||||||
return IrValueKind::Double;
|
return IrValueKind::Double;
|
||||||
case IrCmd::ADD_VEC:
|
case IrCmd::ADD_VEC:
|
||||||
case IrCmd::SUB_VEC:
|
case IrCmd::SUB_VEC:
|
||||||
|
@ -656,6 +658,16 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
|
||||||
substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0));
|
substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case IrCmd::SELECT_NUM:
|
||||||
|
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
|
||||||
|
if (inst.c.kind == IrOpKind::Constant && inst.d.kind == IrOpKind::Constant)
|
||||||
|
{
|
||||||
|
double c = function.doubleOp(inst.c);
|
||||||
|
double d = function.doubleOp(inst.d);
|
||||||
|
|
||||||
|
substitute(function, inst, c == d ? inst.b : inst.a);
|
||||||
|
}
|
||||||
|
break;
|
||||||
case IrCmd::NOT_ANY:
|
case IrCmd::NOT_ANY:
|
||||||
if (inst.a.kind == IrOpKind::Constant)
|
if (inst.a.kind == IrOpKind::Constant)
|
||||||
{
|
{
|
||||||
|
|
|
@ -1382,6 +1382,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
|
||||||
case IrCmd::SQRT_NUM:
|
case IrCmd::SQRT_NUM:
|
||||||
case IrCmd::ABS_NUM:
|
case IrCmd::ABS_NUM:
|
||||||
case IrCmd::SIGN_NUM:
|
case IrCmd::SIGN_NUM:
|
||||||
|
case IrCmd::SELECT_NUM:
|
||||||
case IrCmd::NOT_ANY:
|
case IrCmd::NOT_ANY:
|
||||||
state.substituteOrRecord(inst, index);
|
state.substituteOrRecord(inst, index);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -506,6 +506,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms")
|
||||||
SINGLE_COMPARE(vmaxsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5f, 0xc6);
|
SINGLE_COMPARE(vmaxsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5f, 0xc6);
|
||||||
SINGLE_COMPARE(vminsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5d, 0xc6);
|
SINGLE_COMPARE(vminsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5d, 0xc6);
|
||||||
|
|
||||||
|
SINGLE_COMPARE(vcmpeqsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x00);
|
||||||
SINGLE_COMPARE(vcmpltsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x01);
|
SINGLE_COMPARE(vcmpltsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x01);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -408,6 +408,7 @@ assert(math.lerp(1, 5, 1) == 5)
|
||||||
assert(math.lerp(1, 5, 0.5) == 3)
|
assert(math.lerp(1, 5, 0.5) == 3)
|
||||||
assert(math.lerp(1, 5, 1.5) == 7)
|
assert(math.lerp(1, 5, 1.5) == 7)
|
||||||
assert(math.lerp(1, 5, -0.5) == -1)
|
assert(math.lerp(1, 5, -0.5) == -1)
|
||||||
|
assert(math.lerp(1, 5, noinline(0.5)) == 3)
|
||||||
|
|
||||||
-- lerp properties
|
-- lerp properties
|
||||||
local sq2, sq3 = math.sqrt(2), math.sqrt(3)
|
local sq2, sq3 = math.sqrt(2), math.sqrt(3)
|
||||||
|
|
Loading…
Reference in a new issue