CodeGen: Rewrite dot product lowering using a dedicated IR instruction (#1512)

Instead of doing the dot product related math in scalar IR, we lift the
computation into a dedicated IR instruction.

On x64, we can use VDPPS which was more or less tailor made for this
purpose. This is better than manual scalar lowering that requires
reloading components from memory; it's not always a strict improvement
over the shuffle+add version (which we never had), but this can now be
adjusted in the IR lowering in an optimal fashion (maybe even based on
CPU vendor, although that'd create issues for offline compilation).

On A64, we can either use naive adds or paired adds, as there is no
dedicated vector-wide horizontal instruction until SVE. Both run at
about the same performance on M2, but paired adds require fewer
instructions and temporaries.

I've measured this using mesh-normal-vector benchmark, changing the
benchmark to just report the time of the second loop inside
`calculate_normals`, testing master vs #1504 vs this PR, also increasing
the grid size to 400 for more stable timings.

On Zen 4 (7950X), this PR is comfortably ~8% faster vs master, while I
see neutral to negative results in #1504.
On M2 (base), this PR is ~28% faster vs master, while #1504 is only
about ~10% faster.

If I measure the second loop in `calculate_tangent_space` instead, I
get:

On Zen 4 (7950X), this PR is ~12% faster vs master, while #1504 is ~3%
faster
On M2 (base), this PR is ~24% faster vs master, while #1504 is only
about ~13% faster.

Note that the loops in question are not quite optimal, as they store and
reload various vectors to dictionary values due to inappropriate use of
locals. The underlying gains in individual functions are thus larger
than the numbers above; for example, changing the `calculate_normals`
loop to use a local variable to store the normalized vector (but still
saving the result to dictionary value), I get a ~24% performance
increase from this PR on Zen4 vs master instead of just 8% (#1504 is
~15% slower in this setup).
This commit is contained in:
Arseny Kapoulkine 2024-11-08 16:23:09 -08:00 committed by GitHub
parent a36a3c41cc
commit e6bf71871a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 135 additions and 32 deletions

View file

@ -138,6 +138,7 @@ public:
void fneg(RegisterA64 dst, RegisterA64 src); void fneg(RegisterA64 dst, RegisterA64 src);
void fsqrt(RegisterA64 dst, RegisterA64 src); void fsqrt(RegisterA64 dst, RegisterA64 src);
void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void faddp(RegisterA64 dst, RegisterA64 src);
// Vector component manipulation // Vector component manipulation
void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index);

View file

@ -167,6 +167,8 @@ public:
void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle); void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle);
void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset); void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset);
void vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask);
// Run final checks // Run final checks
bool finalize(); bool finalize();

View file

@ -194,6 +194,10 @@ enum class IrCmd : uint8_t
// A: TValue // A: TValue
UNM_VEC, UNM_VEC,
// Compute dot product between two vectors
// A, B: TValue
DOT_VEC,
// Compute Luau 'not' operation on destructured TValue // Compute Luau 'not' operation on destructured TValue
// A: tag // A: tag
// B: int (value) // B: int (value)

View file

@ -176,6 +176,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::SUB_VEC: case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC: case IrCmd::MUL_VEC:
case IrCmd::DIV_VEC: case IrCmd::DIV_VEC:
case IrCmd::DOT_VEC:
case IrCmd::UNM_VEC: case IrCmd::UNM_VEC:
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
case IrCmd::CMP_ANY: case IrCmd::CMP_ANY:

View file

@ -586,6 +586,14 @@ void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src)
placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000); placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000);
} }
void AssemblyBuilderA64::faddp(RegisterA64 dst, RegisterA64 src)
{
CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::s);
CODEGEN_ASSERT(dst.kind == src.kind);
placeR1("faddp", dst, src, 0b011'11110'0'0'11000'01101'10 | ((dst.kind == KindA64::d) << 12));
}
void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{ {
if (dst.kind == KindA64::d) if (dst.kind == KindA64::d)

View file

@ -946,6 +946,11 @@ void AssemblyBuilderX64::vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 s
placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66); placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66);
} }
void AssemblyBuilderX64::vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask)
{
placeAvx("vdpps", dst, src1, src2, mask, 0x40, false, AVX_0F3A, AVX_66);
}
bool AssemblyBuilderX64::finalize() bool AssemblyBuilderX64::finalize()
{ {
code.resize(codePos - code.data()); code.resize(codePos - code.data());

View file

@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd)
return "DIV_VEC"; return "DIV_VEC";
case IrCmd::UNM_VEC: case IrCmd::UNM_VEC:
return "UNM_VEC"; return "UNM_VEC";
case IrCmd::DOT_VEC:
return "DOT_VEC";
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
return "NOT_ANY"; return "NOT_ANY";
case IrCmd::CMP_ANY: case IrCmd::CMP_ANY:

View file

@ -728,6 +728,21 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.fneg(inst.regA64, regOp(inst.a)); build.fneg(inst.regA64, regOp(inst.a));
break; break;
} }
case IrCmd::DOT_VEC:
{
inst.regA64 = regs.allocReg(KindA64::d, index);
RegisterA64 temp = regs.allocTemp(KindA64::q);
RegisterA64 temps = castReg(KindA64::s, temp);
RegisterA64 regs = castReg(KindA64::s, inst.regA64);
build.fmul(temp, regOp(inst.a), regOp(inst.b));
build.faddp(regs, temps); // x+y
build.dup_4s(temp, temp, 2);
build.fadd(regs, regs, temps); // +z
build.fcvt(inst.regA64, regs);
break;
}
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
{ {
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});

View file

@ -675,6 +675,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0)); build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0));
break; break;
} }
case IrCmd::DOT_VEC:
{
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
ScopedRegX64 tmp1{regs};
ScopedRegX64 tmp2{regs};
RegisterX64 tmpa = vecOp(inst.a, tmp1);
RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2);
build.vdpps(inst.regX64, tmpa, tmpb, 0x71); // 7 = 0b0111, sum first 3 products into first float
build.vcvtss2sd(inst.regX64, inst.regX64, inst.regX64);
break;
}
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
{ {
// TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target // TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target

View file

@ -14,6 +14,7 @@ static const int kMinMaxUnrolledParams = 5;
static const int kBit32BinaryOpUnrolledParams = 5; static const int kBit32BinaryOpUnrolledParams = 5;
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen); LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
namespace Luau namespace Luau
{ {
@ -907,6 +908,16 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
IrOp sum;
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
sum = build.inst(IrCmd::DOT_VEC, a, a);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
@ -915,7 +926,8 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
}
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
@ -945,6 +957,23 @@ static BuiltinImplResult translateBuiltinVectorNormalize(
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp sum = build.inst(IrCmd::DOT_VEC, a, a);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
IrOp invvec = build.inst(IrCmd::NUM_TO_VEC, inv);
IrOp result = build.inst(IrCmd::MUL_VEC, a, invvec);
result = build.inst(IrCmd::TAG_VECTOR, result);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
@ -964,6 +993,7 @@ static BuiltinImplResult translateBuiltinVectorNormalize(
build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr); build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
}
return {BuiltinImplType::Full, 1}; return {BuiltinImplType::Full, 1};
} }
@ -1019,6 +1049,17 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos));
IrOp sum;
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp b = build.inst(IrCmd::LOAD_TVALUE, args, build.constInt(0));
sum = build.inst(IrCmd::DOT_VEC, a, b);
}
else
{
IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0)); IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
@ -1031,7 +1072,8 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8)); IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2); IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz); sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
}
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));

View file

@ -75,6 +75,8 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::DIV_VEC: case IrCmd::DIV_VEC:
case IrCmd::UNM_VEC: case IrCmd::UNM_VEC:
return IrValueKind::Tvalue; return IrValueKind::Tvalue;
case IrCmd::DOT_VEC:
return IrValueKind::Double;
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
case IrCmd::CMP_ANY: case IrCmd::CMP_ANY:
return IrValueKind::Int; return IrValueKind::Int;

View file

@ -768,7 +768,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
if (tag == LUA_TBOOLEAN && if (tag == LUA_TBOOLEAN &&
(value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int))) (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int)))
canSplitTvalueStore = true; canSplitTvalueStore = true;
else if (tag == LUA_TNUMBER && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) else if (tag == LUA_TNUMBER &&
(value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
canSplitTvalueStore = true; canSplitTvalueStore = true;
else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst) else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst)
canSplitTvalueStore = true; canSplitTvalueStore = true;
@ -1342,6 +1343,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
case IrCmd::SUB_VEC: case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC: case IrCmd::MUL_VEC:
case IrCmd::DIV_VEC: case IrCmd::DIV_VEC:
case IrCmd::DOT_VEC:
if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR) if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR)
replace(function, inst.a, a->a); replace(function, inst.a, a->a);

View file

@ -400,6 +400,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath")
SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841); SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841);
SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD); SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD);
SINGLE_COMPARE(faddp(s29, s28), 0x7E30DB9D);
SINGLE_COMPARE(faddp(d29, d28), 0x7E70DB9D);
SINGLE_COMPARE(frinta(d1, d2), 0x1E664041); SINGLE_COMPARE(frinta(d1, d2), 0x1E664041);
SINGLE_COMPARE(frintm(d1, d2), 0x1E654041); SINGLE_COMPARE(frintm(d1, d2), 0x1E654041);
SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041); SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041);

View file

@ -577,6 +577,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms")
SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4); SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4);
SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02); SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02);
SINGLE_COMPARE(vdpps(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x40, 0x3c, 0x11, 0x02);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions")