CodeGen: Use DOT_VEC under a fast flag for lowering vector lib

This is useful for vector.dot, vector.magnitude and vector.normalize.
This commit is contained in:
Arseny Kapoulkine 2024-11-08 10:09:23 -08:00
parent 74a91289e9
commit cd73807c09

View file

@ -14,6 +14,7 @@ static const int kMinMaxUnrolledParams = 5;
static const int kBit32BinaryOpUnrolledParams = 5; static const int kBit32BinaryOpUnrolledParams = 5;
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen); LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
namespace Luau namespace Luau
{ {
@ -907,6 +908,16 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
IrOp sum;
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
sum = build.inst(IrCmd::DOT_VEC, a, a);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
@ -915,7 +926,8 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
}
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
@ -945,6 +957,23 @@ static BuiltinImplResult translateBuiltinVectorNormalize(
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp sum = build.inst(IrCmd::DOT_VEC, a, a);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
IrOp invvec = build.inst(IrCmd::NUM_TO_VEC, inv);
IrOp result = build.inst(IrCmd::MUL_VEC, a, invvec);
result = build.inst(IrCmd::TAG_VECTOR, result);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
@ -964,6 +993,7 @@ static BuiltinImplResult translateBuiltinVectorNormalize(
build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr); build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
}
return {BuiltinImplType::Full, 1}; return {BuiltinImplType::Full, 1};
} }
@ -1019,6 +1049,17 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos));
IrOp sum;
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp b = build.inst(IrCmd::LOAD_TVALUE, args, build.constInt(0));
sum = build.inst(IrCmd::DOT_VEC, a, b);
}
else
{
IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0)); IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
@ -1031,7 +1072,8 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8)); IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2); IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz); sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
}
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));