Implement support for math.lerp (#1608)
Some checks are pending
benchmark / callgrind (map[branch:main name:luau-lang/benchmark-data], ubuntu-22.04) (push) Waiting to run
build / macos (push) Waiting to run
build / macos-arm (push) Waiting to run
build / ubuntu (push) Waiting to run
build / windows (Win32) (push) Waiting to run
build / windows (x64) (push) Waiting to run
build / coverage (push) Waiting to run
build / web (push) Waiting to run
release / macos (push) Waiting to run
release / ubuntu (push) Waiting to run
release / windows (push) Waiting to run
release / web (push) Waiting to run

This change implements math.lerp RFC with C function definition, builtin
function, builtin constant folding and tests.

The tests validate a few lerp properties by providing counter-examples
for popular lerp implementations; the testing is of course not
exhaustive, as exhaustive testing was done offline using fuzzing.

Type definitions will be updated separately.

Codegen support will be implemented separately: it requires new IR for
conditional
selects to represent the desired logic without using a branch.
This commit is contained in:
Arseny Kapoulkine 2025-01-09 09:42:07 -08:00 committed by GitHub
parent 9a102e2aff
commit 8a4ef26f89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 96 additions and 8 deletions

View file

@ -235,7 +235,7 @@ static uint8_t getBytecodeConstantTag(Proto* proto, unsigned ki)
return LBC_TYPE_ANY; return LBC_TYPE_ANY;
} }
static void applyBuiltinCall(int bfid, BytecodeTypes& types) static void applyBuiltinCall(LuauBuiltinFunction bfid, BytecodeTypes& types)
{ {
switch (bfid) switch (bfid)
{ {
@ -549,6 +549,12 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types)
types.b = LBC_TYPE_VECTOR; types.b = LBC_TYPE_VECTOR;
types.c = LBC_TYPE_VECTOR; // We can mark optional arguments types.c = LBC_TYPE_VECTOR; // We can mark optional arguments
break; break;
case LBF_MATH_LERP:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
} }
} }
@ -1086,7 +1092,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call); int ra = LUAU_INSN_A(call);
applyBuiltinCall(bfid, bcType); applyBuiltinCall(LuauBuiltinFunction(bfid), bcType);
regTags[ra + 1] = bcType.a; regTags[ra + 1] = bcType.a;
regTags[ra + 2] = bcType.b; regTags[ra + 2] = bcType.b;
regTags[ra + 3] = bcType.c; regTags[ra + 3] = bcType.c;
@ -1105,7 +1111,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call); int ra = LUAU_INSN_A(call);
applyBuiltinCall(bfid, bcType); applyBuiltinCall(LuauBuiltinFunction(bfid), bcType);
regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[LUAU_INSN_B(*pc)] = bcType.a;
regTags[ra] = bcType.result; regTags[ra] = bcType.result;
@ -1122,7 +1128,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call); int ra = LUAU_INSN_A(call);
applyBuiltinCall(bfid, bcType); applyBuiltinCall(LuauBuiltinFunction(bfid), bcType);
regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[LUAU_INSN_B(*pc)] = bcType.a;
regTags[int(pc[1])] = bcType.b; regTags[int(pc[1])] = bcType.b;
@ -1141,7 +1147,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call); int ra = LUAU_INSN_A(call);
applyBuiltinCall(bfid, bcType); applyBuiltinCall(LuauBuiltinFunction(bfid), bcType);
regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[LUAU_INSN_B(*pc)] = bcType.a;
regTags[aux & 0xff] = bcType.b; regTags[aux & 0xff] = bcType.b;

View file

@ -550,6 +550,7 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid
case LBF_VECTOR_CLAMP: case LBF_VECTOR_CLAMP:
case LBF_VECTOR_MIN: case LBF_VECTOR_MIN:
case LBF_VECTOR_MAX: case LBF_VECTOR_MAX:
case LBF_MATH_LERP:
break; break;
case LBF_TABLE_INSERT: case LBF_TABLE_INSERT:
state.invalidateHeap(); state.invalidateHeap();

View file

@ -613,6 +613,9 @@ enum LuauBuiltinFunction
LBF_VECTOR_CLAMP, LBF_VECTOR_CLAMP,
LBF_VECTOR_MIN, LBF_VECTOR_MIN,
LBF_VECTOR_MAX, LBF_VECTOR_MAX,
// math.lerp
LBF_MATH_LERP,
}; };
// Capture type, used in LOP_CAPTURE // Capture type, used in LOP_CAPTURE

View file

@ -5,6 +5,8 @@
#include <math.h> #include <math.h>
LUAU_FASTFLAG(LuauCompileMathLerp)
namespace Luau namespace Luau
{ {
namespace Compile namespace Compile
@ -479,6 +481,19 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count)
return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, args[3].valueNumber); return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, args[3].valueNumber);
} }
break; break;
case LBF_MATH_LERP:
if (FFlag::LuauCompileMathLerp && count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number &&
args[2].type == Constant::Type_Number)
{
double a = args[0].valueNumber;
double b = args[1].valueNumber;
double t = args[2].valueNumber;
double v = (t == 1.0) ? b : a + (b - a) * t;
return cnum(v);
}
break;
} }
return cvar(); return cvar();

View file

@ -9,6 +9,7 @@
LUAU_FASTFLAGVARIABLE(LuauVectorBuiltins) LUAU_FASTFLAGVARIABLE(LuauVectorBuiltins)
LUAU_FASTFLAGVARIABLE(LuauCompileDisabledBuiltins) LUAU_FASTFLAGVARIABLE(LuauCompileDisabledBuiltins)
LUAU_FASTFLAGVARIABLE(LuauCompileMathLerp)
namespace Luau namespace Luau
{ {
@ -140,6 +141,8 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op
return LBF_MATH_SIGN; return LBF_MATH_SIGN;
if (builtin.method == "round") if (builtin.method == "round")
return LBF_MATH_ROUND; return LBF_MATH_ROUND;
if (FFlag::LuauCompileMathLerp && builtin.method == "lerp")
return LBF_MATH_LERP;
} }
if (builtin.object == "bit32") if (builtin.object == "bit32")
@ -556,6 +559,10 @@ BuiltinInfo getBuiltinInfo(int bfid)
case LBF_VECTOR_MIN: case LBF_VECTOR_MIN:
case LBF_VECTOR_MAX: case LBF_VECTOR_MAX:
return {-1, 1}; // variadic return {-1, 1}; // variadic
case LBF_MATH_LERP:
LUAU_ASSERT(FFlag::LuauCompileMathLerp);
return {3, 1, BuiltinInfo::Flag_NoneSafe};
} }
LUAU_UNREACHABLE(); LUAU_UNREACHABLE();

View file

@ -747,6 +747,7 @@ struct TypeMapVisitor : AstVisitor
case LBF_BUFFER_READF64: case LBF_BUFFER_READF64:
case LBF_VECTOR_MAGNITUDE: case LBF_VECTOR_MAGNITUDE:
case LBF_VECTOR_DOT: case LBF_VECTOR_DOT:
case LBF_MATH_LERP:
recordResolvedType(node, &builtinTypes.numberType); recordResolvedType(node, &builtinTypes.numberType);
break; break;

View file

@ -1694,6 +1694,23 @@ static int luauF_vectormax(lua_State* L, StkId res, TValue* arg0, int nresults,
return -1; return -1;
} }
static int luauF_lerp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1))
{
double a = nvalue(arg0);
double b = nvalue(args);
double t = nvalue(args + 1);
double r = (t == 1.0) ? b : a + (b - a) * t;
setnvalue(res, r);
return 1;
}
return -1;
}
static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
return -1; return -1;
@ -1889,6 +1906,8 @@ const luau_FastFunction luauF_table[256] = {
luauF_vectormin, luauF_vectormin,
luauF_vectormax, luauF_vectormax,
luauF_lerp,
// When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback. // When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback.
// This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing. // This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing.
// Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest. // Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest.

View file

@ -8,6 +8,7 @@
#include <time.h> #include <time.h>
LUAU_FASTFLAGVARIABLE(LuauMathMap) LUAU_FASTFLAGVARIABLE(LuauMathMap)
LUAU_FASTFLAGVARIABLE(LuauMathLerp)
#undef PI #undef PI
#define PI (3.14159265358979323846) #define PI (3.14159265358979323846)
@ -418,6 +419,17 @@ static int math_map(lua_State* L)
return 1; return 1;
} }
static int math_lerp(lua_State* L)
{
double a = luaL_checknumber(L, 1);
double b = luaL_checknumber(L, 2);
double t = luaL_checknumber(L, 3);
double r = (t == 1.0) ? b : a + (b - a) * t;
lua_pushnumber(L, r);
return 1;
}
static const luaL_Reg mathlib[] = { static const luaL_Reg mathlib[] = {
{"abs", math_abs}, {"abs", math_abs},
{"acos", math_acos}, {"acos", math_acos},
@ -477,5 +489,11 @@ int luaopen_math(lua_State* L)
lua_setfield(L, -2, "map"); lua_setfield(L, -2, "map");
} }
if (FFlag::LuauMathLerp)
{
lua_pushcfunction(L, math_lerp, "lerp");
lua_setfield(L, -2, "lerp");
}
return 1; return 1;
} }

View file

@ -32,6 +32,7 @@ void luaC_fullgc(lua_State* L);
void luaC_validate(lua_State* L); void luaC_validate(lua_State* L);
LUAU_FASTFLAG(LuauMathMap) LUAU_FASTFLAG(LuauMathMap)
LUAU_FASTFLAG(LuauMathLerp)
LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTFLAG(DebugLuauAbortingChecks)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) LUAU_DYNAMIC_FASTFLAG(LuauStackLimit)
@ -660,6 +661,7 @@ TEST_CASE("Buffers")
TEST_CASE("Math") TEST_CASE("Math")
{ {
ScopedFastFlag LuauMathMap{FFlag::LuauMathMap, true}; ScopedFastFlag LuauMathMap{FFlag::LuauMathMap, true};
ScopedFastFlag LuauMathLerp{FFlag::LuauMathLerp, true};
runConformance("math.lua"); runConformance("math.lua");
} }
@ -911,9 +913,7 @@ TEST_CASE("VectorLibrary")
copts.optimizationLevel = 2; copts.optimizationLevel = 2;
} }
runConformance( runConformance("vector_library.lua", [](lua_State* L) {}, nullptr, nullptr, &copts);
"vector_library.lua", [](lua_State* L) {}, nullptr, nullptr, &copts
);
} }
static void populateRTTI(lua_State* L, Luau::TypeId type) static void populateRTTI(lua_State* L, Luau::TypeId type)
@ -988,6 +988,7 @@ static void populateRTTI(lua_State* L, Luau::TypeId type)
TEST_CASE("Types") TEST_CASE("Types")
{ {
ScopedFastFlag luauVectorDefinitions{FFlag::LuauVectorDefinitions, true}; ScopedFastFlag luauVectorDefinitions{FFlag::LuauVectorDefinitions, true};
ScopedFastFlag luauMathLerp{FFlag::LuauMathLerp, false}; // waiting for math.lerp to be added to embedded type definitions
runConformance( runConformance(
"types.lua", "types.lua",

View file

@ -402,6 +402,22 @@ assert(math.map(4, 4, 1, 2, 0) == 2)
assert(math.map(-8, 0, 4, 0, 2) == -4) assert(math.map(-8, 0, 4, 0, 2) == -4)
assert(math.map(16, 0, 4, 0, 2) == 8) assert(math.map(16, 0, 4, 0, 2) == 8)
-- lerp basics
assert(math.lerp(1, 5, 0) == 1)
assert(math.lerp(1, 5, 1) == 5)
assert(math.lerp(1, 5, 0.5) == 3)
assert(math.lerp(1, 5, 1.5) == 7)
assert(math.lerp(1, 5, -0.5) == -1)
-- lerp properties
local sq2, sq3 = math.sqrt(2), math.sqrt(3)
assert(math.lerp(sq2, sq3, 0) == sq2) -- exact at 0
assert(math.lerp(sq2, sq3, 1) == sq3) -- exact at 1
assert(math.lerp(-sq3, sq2, 1) == sq2) -- exact at 1 (fails for a + t*(b-a))
assert(math.lerp(sq2, sq2, sq2 / 2) <= math.lerp(sq2, sq2, 1)) -- monotonic (fails for a*t + b*(1-t))
assert(math.lerp(-sq3, sq2, 1) <= math.sqrt(2)) -- bounded (fails for a + t*(b-a))
assert(math.lerp(sq2, sq2, sq2 / 2) == sq2) -- consistent (fails for a*t + b*(1-t))
assert(tostring(math.pow(-2, 0.5)) == "nan") assert(tostring(math.pow(-2, 0.5)) == "nan")
-- test that fastcalls return correct number of results -- test that fastcalls return correct number of results
@ -464,5 +480,6 @@ assert(math.sign("2") == 1)
assert(math.sign("-2") == -1) assert(math.sign("-2") == -1)
assert(math.sign("0") == 0) assert(math.sign("0") == 0)
assert(math.round("1.8") == 2) assert(math.round("1.8") == 2)
assert(math.lerp("1", "5", 0.5) == 3)
return('OK') return('OK')