From 8a4ef26f89d959d94333fee41d09387b278d0e8d Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 9 Jan 2025 09:42:07 -0800 Subject: [PATCH] Implement support for math.lerp (#1608) This change implements math.lerp RFC with C function definition, builtin function, builtin constant folding and tests. The tests validate a few lerp properties by providing counter-examples for popular lerp implementations; the testing is of course not exhaustive, as exhaustive testing was done offline using fuzzing. Type definitions will be updated separately. Codegen support will be implemented separately: it requires new IR for conditional selects to represent the desired logic without using a branch. --- CodeGen/src/BytecodeAnalysis.cpp | 16 +++++++++++----- CodeGen/src/OptimizeConstProp.cpp | 1 + Common/include/Luau/Bytecode.h | 3 +++ Compiler/src/BuiltinFolding.cpp | 15 +++++++++++++++ Compiler/src/Builtins.cpp | 7 +++++++ Compiler/src/Types.cpp | 1 + VM/src/lbuiltins.cpp | 19 +++++++++++++++++++ VM/src/lmathlib.cpp | 18 ++++++++++++++++++ tests/Conformance.test.cpp | 7 ++++--- tests/conformance/math.lua | 17 +++++++++++++++++ 10 files changed, 96 insertions(+), 8 deletions(-) diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index 85317b60..f0f1ec8e 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -235,7 +235,7 @@ static uint8_t getBytecodeConstantTag(Proto* proto, unsigned ki) return LBC_TYPE_ANY; } -static void applyBuiltinCall(int bfid, BytecodeTypes& types) +static void applyBuiltinCall(LuauBuiltinFunction bfid, BytecodeTypes& types) { switch (bfid) { @@ -549,6 +549,12 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types) types.b = LBC_TYPE_VECTOR; types.c = LBC_TYPE_VECTOR; // We can mark optional arguments break; + case LBF_MATH_LERP: + types.result = LBC_TYPE_NUMBER; + types.a = LBC_TYPE_NUMBER; + types.b = LBC_TYPE_NUMBER; + types.c = LBC_TYPE_NUMBER; + break; } } @@ -1086,7 +1092,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[ra + 1] = bcType.a; regTags[ra + 2] = bcType.b; regTags[ra + 3] = bcType.c; @@ -1105,7 +1111,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[ra] = bcType.result; @@ -1122,7 +1128,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[int(pc[1])] = bcType.b; @@ -1141,7 +1147,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[aux & 0xff] = bcType.b; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 1e532280..9c755563 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -550,6 +550,7 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid case LBF_VECTOR_CLAMP: case LBF_VECTOR_MIN: case LBF_VECTOR_MAX: + case LBF_MATH_LERP: break; case LBF_TABLE_INSERT: state.invalidateHeap(); diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 8d281393..a151056c 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -613,6 +613,9 @@ enum LuauBuiltinFunction LBF_VECTOR_CLAMP, LBF_VECTOR_MIN, LBF_VECTOR_MAX, + + // math.lerp + LBF_MATH_LERP, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/BuiltinFolding.cpp b/Compiler/src/BuiltinFolding.cpp index 0886e94a..916021a6 100644 --- a/Compiler/src/BuiltinFolding.cpp +++ b/Compiler/src/BuiltinFolding.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAG(LuauCompileMathLerp) + namespace Luau { namespace Compile @@ -479,6 +481,19 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count) return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, args[3].valueNumber); } break; + + case LBF_MATH_LERP: + if (FFlag::LuauCompileMathLerp && count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && + args[2].type == Constant::Type_Number) + { + double a = args[0].valueNumber; + double b = args[1].valueNumber; + double t = args[2].valueNumber; + + double v = (t == 1.0) ? b : a + (b - a) * t; + return cnum(v); + } + break; } return cvar(); diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index e8b0cd98..902d74ea 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -9,6 +9,7 @@ LUAU_FASTFLAGVARIABLE(LuauVectorBuiltins) LUAU_FASTFLAGVARIABLE(LuauCompileDisabledBuiltins) +LUAU_FASTFLAGVARIABLE(LuauCompileMathLerp) namespace Luau { @@ -140,6 +141,8 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op return LBF_MATH_SIGN; if (builtin.method == "round") return LBF_MATH_ROUND; + if (FFlag::LuauCompileMathLerp && builtin.method == "lerp") + return LBF_MATH_LERP; } if (builtin.object == "bit32") @@ -556,6 +559,10 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_VECTOR_MIN: case LBF_VECTOR_MAX: return {-1, 1}; // variadic + + case LBF_MATH_LERP: + LUAU_ASSERT(FFlag::LuauCompileMathLerp); + return {3, 1, BuiltinInfo::Flag_NoneSafe}; } LUAU_UNREACHABLE(); diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index 02aec11a..41363ce1 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -747,6 +747,7 @@ struct TypeMapVisitor : AstVisitor case LBF_BUFFER_READF64: case LBF_VECTOR_MAGNITUDE: case LBF_VECTOR_DOT: + case LBF_MATH_LERP: recordResolvedType(node, &builtinTypes.numberType); break; diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 0bca4495..6d71836e 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1694,6 +1694,23 @@ static int luauF_vectormax(lua_State* L, StkId res, TValue* arg0, int nresults, return -1; } +static int luauF_lerp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + double a = nvalue(arg0); + double b = nvalue(args); + double t = nvalue(args + 1); + + double r = (t == 1.0) ? b : a + (b - a) * t; + + setnvalue(res, r); + return 1; + } + + return -1; +} + static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { return -1; @@ -1889,6 +1906,8 @@ const luau_FastFunction luauF_table[256] = { luauF_vectormin, luauF_vectormax, + luauF_lerp, + // When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback. // This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing. // Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest. diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 3a93abcf..737583d4 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauMathMap) +LUAU_FASTFLAGVARIABLE(LuauMathLerp) #undef PI #define PI (3.14159265358979323846) @@ -418,6 +419,17 @@ static int math_map(lua_State* L) return 1; } +static int math_lerp(lua_State* L) +{ + double a = luaL_checknumber(L, 1); + double b = luaL_checknumber(L, 2); + double t = luaL_checknumber(L, 3); + + double r = (t == 1.0) ? b : a + (b - a) * t; + lua_pushnumber(L, r); + return 1; +} + static const luaL_Reg mathlib[] = { {"abs", math_abs}, {"acos", math_acos}, @@ -477,5 +489,11 @@ int luaopen_math(lua_State* L) lua_setfield(L, -2, "map"); } + if (FFlag::LuauMathLerp) + { + lua_pushcfunction(L, math_lerp, "lerp"); + lua_setfield(L, -2, "lerp"); + } + return 1; } diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c0e81371..167a88ad 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -32,6 +32,7 @@ void luaC_fullgc(lua_State* L); void luaC_validate(lua_State* L); LUAU_FASTFLAG(LuauMathMap) +LUAU_FASTFLAG(LuauMathLerp) LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) @@ -660,6 +661,7 @@ TEST_CASE("Buffers") TEST_CASE("Math") { ScopedFastFlag LuauMathMap{FFlag::LuauMathMap, true}; + ScopedFastFlag LuauMathLerp{FFlag::LuauMathLerp, true}; runConformance("math.lua"); } @@ -911,9 +913,7 @@ TEST_CASE("VectorLibrary") copts.optimizationLevel = 2; } - runConformance( - "vector_library.lua", [](lua_State* L) {}, nullptr, nullptr, &copts - ); + runConformance("vector_library.lua", [](lua_State* L) {}, nullptr, nullptr, &copts); } static void populateRTTI(lua_State* L, Luau::TypeId type) @@ -988,6 +988,7 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { ScopedFastFlag luauVectorDefinitions{FFlag::LuauVectorDefinitions, true}; + ScopedFastFlag luauMathLerp{FFlag::LuauMathLerp, false}; // waiting for math.lerp to be added to embedded type definitions runConformance( "types.lua", diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 97c44462..fbd8f9dd 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -402,6 +402,22 @@ assert(math.map(4, 4, 1, 2, 0) == 2) assert(math.map(-8, 0, 4, 0, 2) == -4) assert(math.map(16, 0, 4, 0, 2) == 8) +-- lerp basics +assert(math.lerp(1, 5, 0) == 1) +assert(math.lerp(1, 5, 1) == 5) +assert(math.lerp(1, 5, 0.5) == 3) +assert(math.lerp(1, 5, 1.5) == 7) +assert(math.lerp(1, 5, -0.5) == -1) + +-- lerp properties +local sq2, sq3 = math.sqrt(2), math.sqrt(3) +assert(math.lerp(sq2, sq3, 0) == sq2) -- exact at 0 +assert(math.lerp(sq2, sq3, 1) == sq3) -- exact at 1 +assert(math.lerp(-sq3, sq2, 1) == sq2) -- exact at 1 (fails for a + t*(b-a)) +assert(math.lerp(sq2, sq2, sq2 / 2) <= math.lerp(sq2, sq2, 1)) -- monotonic (fails for a*t + b*(1-t)) +assert(math.lerp(-sq3, sq2, 1) <= math.sqrt(2)) -- bounded (fails for a + t*(b-a)) +assert(math.lerp(sq2, sq2, sq2 / 2) == sq2) -- consistent (fails for a*t + b*(1-t)) + assert(tostring(math.pow(-2, 0.5)) == "nan") -- test that fastcalls return correct number of results @@ -464,5 +480,6 @@ assert(math.sign("2") == 1) assert(math.sign("-2") == -1) assert(math.sign("0") == 0) assert(math.round("1.8") == 2) +assert(math.lerp("1", "5", 0.5) == 3) return('OK')