diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 519630f0..f93354a3 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -9,6 +9,7 @@ #include "lua.h" #include +#include #include #include @@ -19,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks) LUAU_FASTFLAG(LuauVectorLibNativeDot); +LUAU_FASTFLAGVARIABLE(LuauCodeGenArithOpt); namespace Luau { @@ -1192,10 +1194,67 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::ADD_INT: case IrCmd::SUB_INT: + state.substituteOrRecord(inst, index); + break; case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: + if (FFlag::LuauCodeGenArithOpt) + { + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + // a + 0.0 and a - (-0.0) can't be folded since the behavior is different for negative zero + // however, a - 0.0 and a + (-0.0) can be folded into a + if (*k == 0.0 && bool(signbit(*k)) == (inst.cmd == IrCmd::ADD_NUM)) + substitute(function, inst, inst.a); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::MUL_NUM: + if (FFlag::LuauCodeGenArithOpt) + { + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + if (*k == 1.0) // a * 1.0 = a + substitute(function, inst, inst.a); + else if (*k == 2.0) // a * 2.0 = a + a + replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a}); + else if (*k == -1.0) // a * -1.0 = -a + replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::DIV_NUM: + if (FFlag::LuauCodeGenArithOpt) + { + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + if (*k == 1.0) // a / 1.0 = a + substitute(function, inst, inst.a); + else if (*k == -1.0) // a / -1.0 = -a + replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); + else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k + replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)}); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::IDIV_NUM: case IrCmd::MOD_NUM: case IrCmd::MIN_NUM: diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 0b26f079..d73f6496 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -2923,10 +2923,13 @@ reentry: { VM_PROTECT_PC(); // f may fail due to OOM - setobj2s(L, L->top, arg2); - setobj2s(L, L->top + 1, arg3); + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 2 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top, arg2); + setobj2s(L, top + 1, arg3); - int n = f(L, ra, arg1, nresults, L->top, nparams); + int n = f(L, ra, arg1, nresults, top, nparams); if (n >= 0) { diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 396678468..27376777 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -540,7 +540,7 @@ TEST_CASE("VectorCustomAccess") CHECK_EQ( "\n" + getCodegenAssembly(R"( local function vec3magn(a: vector) - return a.Magnitude * 2 + return a.Magnitude * 3 end )"), R"( @@ -560,7 +560,7 @@ bb_bytecode_1: %12 = ADD_NUM %9, %10 %13 = ADD_NUM %12, %11 %14 = SQRT_NUM %13 - %20 = MUL_NUM %14, 2 + %20 = MUL_NUM %14, 3 STORE_DOUBLE R1, %20 STORE_TAG R1, tnumber INTERRUPT 3u @@ -1167,7 +1167,7 @@ local function inl(v: vector, s: number) end local function getsum(x) - return inl(x, 2) + inl(x, 5) + return inl(x, 3) + inl(x, 5) end )", /* includeIrTypes */ true @@ -1195,7 +1195,7 @@ bb_bytecode_1: bb_bytecode_0: CHECK_TAG R0, tvector, exit(0) %2 = LOAD_FLOAT R0, 4i - %8 = MUL_NUM %2, 2 + %8 = MUL_NUM %2, 3 %13 = LOAD_FLOAT R0, 4i %19 = MUL_NUM %13, 5 %28 = ADD_NUM %8, %19 diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 5aacceb4..b0e0caa0 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -39,24 +39,6 @@ struct Counter int Counter::instanceCount = 0; -// TODO: delete this and replace all other use of this function with matchParseError -std::string getParseError(const std::string& code) -{ - Fixture f; - - try - { - f.parse(code); - } - catch (const Luau::ParseErrors& e) - { - // in general, tests check only the first error - return e.getErrors().front().getMessage(); - } - - throw std::runtime_error("Expected a parse error in '" + code + "'"); -} - } // namespace TEST_SUITE_BEGIN("AllocatorTests"); @@ -465,62 +447,38 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_span_is_correct") TEST_CASE_FIXTURE(Fixture, "parse_error_messages") { - CHECK_EQ( - getParseError(R"( - local a: (number, number) -> (string - )"), - "Expected ')' (to close '(' at line 2), got " - ); + matchParseError(R"( + local a: (number, number) -> (string + )", "Expected ')' (to close '(' at line 2), got "); - CHECK_EQ( - getParseError(R"( - local a: (number, number) -> ( - string - )"), - "Expected ')' (to close '(' at line 2), got " - ); + matchParseError(R"( + local a: (number, number) -> ( + string + )", "Expected ')' (to close '(' at line 2), got "); - CHECK_EQ( - getParseError(R"( - local a: (number, number) - )"), - "Expected '->' when parsing function type, got " - ); + matchParseError(R"( + local a: (number, number) + )", "Expected '->' when parsing function type, got "); - CHECK_EQ( - getParseError(R"( - local a: (number, number - )"), - "Expected ')' (to close '(' at line 2), got " - ); + matchParseError(R"( + local a: (number, number + )", "Expected ')' (to close '(' at line 2), got "); - CHECK_EQ( - getParseError(R"( - local a: {foo: string, - )"), - "Expected identifier when parsing table field, got " - ); + matchParseError(R"( + local a: {foo: string, + )", "Expected identifier when parsing table field, got "); - CHECK_EQ( - getParseError(R"( - local a: {foo: string - )"), - "Expected '}' (to close '{' at line 2), got " - ); + matchParseError(R"( + local a: {foo: string + )", "Expected '}' (to close '{' at line 2), got "); - CHECK_EQ( - getParseError(R"( - local a: { [string]: number, [number]: string } - )"), - "Cannot have more than one table indexer" - ); + matchParseError(R"( + local a: { [string]: number, [number]: string } + )", "Cannot have more than one table indexer"); - CHECK_EQ( - getParseError(R"( - type T = foo - )"), - "Expected '(' when parsing function parameters, got 'foo'" - ); + matchParseError(R"( + type T = foo + )", "Expected '(' when parsing function parameters, got 'foo'"); } TEST_CASE_FIXTURE(Fixture, "mixed_intersection_and_union_not_allowed") @@ -548,10 +506,10 @@ TEST_CASE_FIXTURE(Fixture, "cannot_write_multiple_values_in_type_groups") TEST_CASE_FIXTURE(Fixture, "type_alias_error_messages") { - CHECK_EQ(getParseError("type 5 = number"), "Expected identifier when parsing type name, got '5'"); - CHECK_EQ(getParseError("type A"), "Expected '=' when parsing type alias, got "); - CHECK_EQ(getParseError("type A<"), "Expected identifier, got "); - CHECK_EQ(getParseError("type A' (to close '<' at column 7), got "); + matchParseError("type 5 = number", "Expected identifier when parsing type name, got '5'"); + matchParseError("type A", "Expected '=' when parsing type alias, got "); + matchParseError("type A<", "Expected identifier, got "); + matchParseError("type A' (to close '<' at column 7), got "); } TEST_CASE_FIXTURE(Fixture, "type_assertion_expression") @@ -655,12 +613,9 @@ TEST_CASE_FIXTURE(Fixture, "vertical_space") TEST_CASE_FIXTURE(Fixture, "parse_error_type_name") { - CHECK_EQ( - getParseError(R"( - local a: Foo.= - )"), - "Expected identifier when parsing field name, got '='" - ); + matchParseError(R"( + local a: Foo.= + )", "Expected identifier when parsing field name, got '='"); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_decimal") @@ -706,28 +661,25 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary") TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") { - CHECK_EQ(getParseError("return 0b123"), "Malformed number"); - CHECK_EQ(getParseError("return 123x"), "Malformed number"); - CHECK_EQ(getParseError("return 0xg"), "Malformed number"); - CHECK_EQ(getParseError("return 0x0x123"), "Malformed number"); - CHECK_EQ(getParseError("return 0xffffffffffffffffffffllllllg"), "Malformed number"); - CHECK_EQ(getParseError("return 0x0xffffffffffffffffffffffffffff"), "Malformed number"); + matchParseError("return 0b123", "Malformed number"); + matchParseError("return 123x", "Malformed number"); + matchParseError("return 0xg", "Malformed number"); + matchParseError("return 0x0x123", "Malformed number"); + matchParseError("return 0xffffffffffffffffffffllllllg", "Malformed number"); + matchParseError("return 0x0xffffffffffffffffffffffffffff", "Malformed number"); } TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") { - CHECK_EQ(getParseError("return 0 print(5)"), "Expected , got 'print'"); - CHECK_EQ(getParseError("while true do break print(5) end"), "Expected 'end' (to close 'do' at column 12), got 'print'"); + matchParseError("return 0 print(5)", "Expected , got 'print'"); + matchParseError("while true do break print(5) end", "Expected 'end' (to close 'do' at column 12), got 'print'"); } TEST_CASE_FIXTURE(Fixture, "error_on_unicode") { - CHECK_EQ( - getParseError(R"( + matchParseError(R"( local ☃ = 10 - )"), - "Expected identifier when parsing variable name, got Unicode character U+2603" - ); + )", "Expected identifier when parsing variable name, got Unicode character U+2603"); } TEST_CASE_FIXTURE(Fixture, "allow_unicode_in_string") @@ -738,20 +690,17 @@ TEST_CASE_FIXTURE(Fixture, "allow_unicode_in_string") TEST_CASE_FIXTURE(Fixture, "error_on_confusable") { - CHECK_EQ( - getParseError(R"( - local pi = 3․13 - )"), - "Expected identifier when parsing expression, got Unicode character U+2024 (did you mean '.'?)" - ); + matchParseError(R"( + local pi = 3․13 + )", "Expected identifier when parsing expression, got Unicode character U+2024 (did you mean '.'?)"); } TEST_CASE_FIXTURE(Fixture, "error_on_non_utf8_sequence") { const char* expected = "Expected identifier when parsing expression, got invalid UTF-8 sequence"; - CHECK_EQ(getParseError("local pi = \xFF!"), expected); - CHECK_EQ(getParseError("local pi = \xE2!"), expected); + matchParseError("local pi = \xFF!", expected); + matchParseError("local pi = \xE2!", expected); } TEST_CASE_FIXTURE(Fixture, "lex_broken_unicode") @@ -819,7 +768,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_continue") TEST_CASE_FIXTURE(Fixture, "continue_not_last_error") { - CHECK_EQ(getParseError("while true do continue print(5) end"), "Expected 'end' (to close 'do' at column 12), got 'print'"); + matchParseError("while true do continue print(5) end", "Expected 'end' (to close 'do' at column 12), got 'print'"); } TEST_CASE_FIXTURE(Fixture, "parse_export_type") @@ -862,7 +811,7 @@ TEST_CASE_FIXTURE(Fixture, "export_is_an_identifier_only_when_followed_by_type") TEST_CASE_FIXTURE(Fixture, "incomplete_statement_error") { - CHECK_EQ(getParseError("fiddlesticks"), "Incomplete statement: expected assignment or a function call"); + matchParseError("fiddlesticks", "Incomplete statement: expected assignment or a function call"); } TEST_CASE_FIXTURE(Fixture, "parse_compound_assignment") diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 98f8000e..05d851ea 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -92,6 +92,16 @@ assert((function() local a = 1 a = a - 2 return a end)() == -1) assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a / 2 return a end)() == 0.5) +-- binary ops with fp specials, neg zero, large constants +-- argument is passed into anonymous function to prevent constant folding +assert((function(a) return tostring(a + 0) end)(-0) == "0") +assert((function(a) return tostring(a - 0) end)(-0) == "-0") +assert((function(a) return tostring(0 - a) end)(0) == "0") +assert((function(a) return tostring(a - a) end)(1 / 0) == "nan") +assert((function(a) return tostring(a * 0) end)(0 / 0) == "nan") +assert((function(a) return tostring(a / (2^1000)) end)(2^1000) == "1") +assert((function(a) return tostring(a / (2^-1000)) end)(2^-1000) == "1") + -- floor division should always round towards -Infinity assert((function() local a = 1 a = a // 2 return a end)() == 0) assert((function() local a = 3 a = a // 2 return a end)() == 1) @@ -290,7 +300,7 @@ assert((function() local t = {[1] = 1, [2] = 2} return t[1] + t[2] end)() == 3) assert((function() return table.concat({}, ',') end)() == "") assert((function() return table.concat({1}, ',') end)() == "1") assert((function() return table.concat({1,2}, ',') end)() == "1,2") -assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() == +assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15") assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16") assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17") @@ -770,7 +780,7 @@ assert(tostring(0) == "0") assert(tostring(-0) == "-0") -- test newline handling in long strings -assert((function() +assert((function() local s1 = [[ ]] local s2 = [[