// Scraped file-viewer metadata, kept as comments so the file stays compilable:
// luau/tests/TypeInfer.operators.test.cpp
// 2022-03-24 15:04:14 -07:00
//
// 785 lines
// 20 KiB
// C++

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AstQuery.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h"
#include "Fixture.h"
#include "doctest.h"
using namespace Luau;
// Opens the doctest suite "TypeInferOperators"; the operator type-inference tests below belong to it.
TEST_SUITE_BEGIN("TypeInferOperators");
// `"a" or 10` infers the union of both operand types, and a `string|number`
// annotation accepts that union without complaint.
TEST_CASE_FIXTURE(Fixture, "or_joins_types")
{
    CheckResult checked = check(R"(
local s = "a" or 10
local x:string|number = s
)");
    LUAU_REQUIRE_NO_ERRORS(checked);

    // Both the inferred local and the annotated local print the same union.
    const std::string expectedUnion = "number | string";
    CHECK_EQ(expectedUnion, toString(*requireType("s")));
    CHECK_EQ(expectedUnion, toString(*requireType("x")));
}
// `or`-ing a `number|string` value with another string literal must keep the
// union at exactly `number | string` (no extra members appear).
TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras")
{
    CheckResult result = check(R"(
local s = "a" or 10
local x:number|string = s
local y = x or "s"
)");
    // Use LUAU_REQUIRE_NO_ERRORS like the rest of this file instead of a bare
    // size comparison. It is presumably a REQUIRE-level check that also reports
    // the offending errors (confirm in Fixture.h), so a failure stops the test
    // before the type queries below run.
    LUAU_REQUIRE_NO_ERRORS(result);
    CHECK_EQ(toString(*requireType("s")), "number | string");
    CHECK_EQ(toString(*requireType("y")), "number | string");
}
// `"a" or "b"` has two string operands, so the result is plain `string`
// rather than a union of string with itself.
TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union")
{
    CheckResult result = check(R"(
local s = "a" or "b"
local x:string = s
)");
    // Use LUAU_REQUIRE_NO_ERRORS like the rest of this file instead of a bare
    // size comparison. It is presumably a REQUIRE-level check that also reports
    // the offending errors (confirm in Fixture.h), so a failure stops the test
    // before the type query below runs.
    LUAU_REQUIRE_NO_ERRORS(result);
    CHECK_EQ(*requireType("s"), *typeChecker.stringType);
}
// `"a" and 10` infers `boolean | number`, which a `boolean|number` annotation accepts.
TEST_CASE_FIXTURE(Fixture, "and_adds_boolean")
{
    CheckResult result = check(R"(
local s = "a" and 10
local x:boolean|number = s
)");
    // Use LUAU_REQUIRE_NO_ERRORS like the rest of this file instead of a bare
    // size comparison. It is presumably a REQUIRE-level check that also reports
    // the offending errors (confirm in Fixture.h), so a failure stops the test
    // before the type query below runs.
    LUAU_REQUIRE_NO_ERRORS(result);
    CHECK_EQ(toString(*requireType("s")), "boolean | number");
}
// `"a" and true` must be assignable to a plain `boolean` local without errors.
TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union")
{
    CheckResult result = check(R"(
local s = "a" and true
local x:boolean = s
)");
    // Use LUAU_REQUIRE_NO_ERRORS like the rest of this file instead of a bare
    // size comparison. It is presumably a REQUIRE-level check that also reports
    // the offending errors (confirm in Fixture.h), so a failure stops the test
    // before the type query below runs.
    LUAU_REQUIRE_NO_ERRORS(result);
    // NOTE(review): this inspects `x`, whose `boolean` annotation already fixes
    // its type. Checking `s` may be what the test name intends; confirm against
    // current inference before changing it.
    CHECK_EQ(*requireType("x"), *typeChecker.booleanType);
}
// The `cond and A or B` idiom infers the union of the two branch types.
TEST_CASE_FIXTURE(Fixture, "and_or_ternary")
{
    CheckResult result = check(R"(
local s = (1/2) > 0.5 and "a" or 10
)");
    // Use LUAU_REQUIRE_NO_ERRORS like the rest of this file instead of a bare
    // size comparison. It is presumably a REQUIRE-level check that also reports
    // the offending errors (confirm in Fixture.h), so a failure stops the test
    // before the type query below runs.
    LUAU_REQUIRE_NO_ERRORS(result);
    CHECK_EQ(toString(*requireType("s")), "number | string");
}
// `+` on numbers and `..` on number/string resolve to the primitive result
// types, both in the function's return pack and at the call site.
TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable")
{
    CheckResult result = check(R"(
function add(a: number, b: string)
return a + (tonumber(b) :: number), a .. b
end
local n, s = add(2,"3")
)");
    LUAU_REQUIRE_NO_ERRORS(result);

    // Guard the downcast: if `add` were not inferred as a function, the
    // dereference below would crash the test binary instead of failing cleanly.
    const FunctionTypeVar* functionType = get<FunctionTypeVar>(requireType("add"));
    REQUIRE(functionType != nullptr);

    // The first return value comes from `a + number`.
    std::optional<TypeId> retType = first(functionType->retType);
    CHECK_EQ(std::optional<TypeId>(typeChecker.numberType), retType);
    CHECK_EQ(requireType("n"), typeChecker.numberType);
    CHECK_EQ(requireType("s"), typeChecker.stringType);
}
// Arithmetic on locals that were bound to number literals yields the shared
// `number` type. Per the test name, this exercises the path where the
// operands' types have to be followed first.
TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows")
{
    CheckResult checked = check(R"(
local PI=3.1415926535897931
local SOLAR_MASS=4*PI * PI
)");
    LUAU_REQUIRE_NO_ERRORS(checked);

    TypeId solarMass = requireType("SOLAR_MASS");
    CHECK_EQ(solarMass, typeChecker.numberType);
}
// With an `any` operand, `number + any` is accepted and the call result is `any`.
TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable")
{
    CheckResult checked = check(R"(
function add(a: number, b: any)
return a + b
end
local t = add(1,2)
)");
    LUAU_REQUIRE_NO_ERRORS(checked);

    const TypeId callResult = requireType("t");
    CHECK_EQ("any", toString(callResult));
}
// `+`/`-` on numbers and `..` on strings infer the obvious primitive types.
TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops")
{
    CheckResult checked = check(R"(
local a = 4 + 8
local b = a + 9
local s = 'hotdogs'
local t = s .. s
local c = b - a
)");
    LUAU_REQUIRE_NO_ERRORS(checked);

    // (local name, expected printed type), in declaration order.
    const std::pair<const char*, const char*> expectations[] = {
        {"a", "number"},
        {"b", "number"},
        {"s", "string"},
        {"t", "string"},
        {"c", "number"},
    };
    for (const auto& expectation : expectations)
        CHECK_EQ(expectation.second, toString(requireType(expectation.first)));
}
// `__mul` typed as an intersection of two overloads: `Vec3 * Vec3` and
// `Vec3 * number` typecheck, `Vec3 * 'cabbage'` is rejected, and every product
// (including the rejected one, with error recovery enabled) is still `Vec3`.
TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection")
{
    ScopedFastFlag errorRecoveryFlag{"LuauErrorRecoveryType", true};

    CheckResult checked = check(R"(
--!strict
local Vec3 = {}
Vec3.__index = Vec3
function Vec3.new()
return setmetatable({x=0, y=0, z=0}, Vec3)
end
export type Vec3 = typeof(Vec3.new())
local thefun: any = function(self, o) return self end
local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun
Vec3.__mul = multiply
local a = Vec3.new()
local b = Vec3.new()
local c = a * b
local d = a * 2
local e = a * 'cabbage'
)");
    // Exactly one diagnostic, presumably from `a * 'cabbage'`, which matches neither overload.
    LUAU_REQUIRE_ERROR_COUNT(1, checked);

    for (const char* name : {"a", "b", "c", "d", "e"})
        CHECK_EQ("Vec3", toString(requireType(name)));
}
// Mirror of the previous test with the `Vec3` operand on the right-hand side
// of `*`. Again exactly one error is expected, and every product is `Vec3`.
TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs")
{
    ScopedFastFlag errorRecoveryFlag{"LuauErrorRecoveryType", true};

    CheckResult checked = check(R"(
--!strict
local Vec3 = {}
Vec3.__index = Vec3
function Vec3.new()
return setmetatable({x=0, y=0, z=0}, Vec3)
end
export type Vec3 = typeof(Vec3.new())
local thefun: any = function(self, o) return self end
local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun
Vec3.__mul = multiply
local a = Vec3.new()
local b = Vec3.new()
local c = b * a
local d = 2 * a
local e = 'cabbage' * a
)");
    LUAU_REQUIRE_ERROR_COUNT(1, checked);

    for (const char* name : {"a", "b", "c", "d", "e"})
        CHECK_EQ("Vec3", toString(requireType(name)));
}
// `<` between two numbers is always allowed.
TEST_CASE_FIXTURE(Fixture, "compare_numbers")
{
    CheckResult numericComparison = check(R"(
local a = 441
local b = 0
local c = a < b
)");
    LUAU_REQUIRE_NO_ERRORS(numericComparison);
}
// `<` between two strings is always allowed.
TEST_CASE_FIXTURE(Fixture, "compare_strings")
{
    CheckResult stringComparison = check(R"(
local a = '441'
local b = '0'
local c = a < b
)");
    LUAU_REQUIRE_NO_ERRORS(stringComparison);
}
// `<` between two plain tables (no metatable at all) produces a single
// GenericError explaining that the left operand has no metatable.
TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_metatable")
{
    CheckResult result = check(R"(
local a = {}
local b = {}
local c = a < b
)");
    LUAU_REQUIRE_ERROR_COUNT(1, result);

    // Make sure the error has the expected kind before reading its message, as
    // the sibling comparison tests do. Otherwise a different error kind would
    // dereference a null pointer instead of failing the test.
    GenericError* gen = get<GenericError>(result.errors[0]);
    REQUIRE(gen != nullptr);
    REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable");
}
// Both operands share a metatable, but that metatable defines no `__lt`, so
// `<` reports a GenericError naming the missing metamethod.
TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators")
{
    CheckResult checked = check(R"(
local M = {}
function M.new()
return setmetatable({}, M)
end
type M = typeof(M.new())
local a = M.new()
local b = M.new()
local c = a < b
)");
    LUAU_REQUIRE_ERROR_COUNT(1, checked);

    auto* missingLt = get<GenericError>(checked.errors[0]);
    REQUIRE(missingLt != nullptr);
    REQUIRE_EQ(missingLt->message, "Table M does not offer metamethod __lt");
}
// Comparing a table that has a `__lt` metatable with a plain table is an error
// in either operand order. Each error is located on its `<` expression.
TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable")
{
    CheckResult result = check(R"(
--!strict
local M = {}
function M.new()
return setmetatable({}, M)
end
function M.__lt(left, right) return true end
local a = M.new()
local b = {}
local c = a < b -- line 9
local d = b < a -- line 10
)");
    LUAU_REQUIRE_ERROR_COUNT(2, result);
    // Line 0 of the script is the empty line right after `R"(`, so the two
    // comparisons sit on lines 9 and 10. The script lines have no leading
    // indentation, so in "local c = a < b" the expression `a < b` starts at
    // column 10 and ends at column 15. That is one past its last character,
    // matching how the other Location expectations in this file are written.
    REQUIRE_EQ((Location{{9, 10}, {9, 15}}), result.errors[0].location);
    REQUIRE_EQ((Location{{10, 10}, {10, 15}}), result.errors[1].location);
}
// When only one side of `<` has a metatable, the diagnostic says the two
// types do not share a metatable (rather than, e.g., "no metatable").
TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not")
{
    CheckResult checked = check(R"(
--!strict
local M = {}
function M.new()
return setmetatable({}, M)
end
function M.__lt(left, right) return true end
type M = typeof(M.new())
local a = M.new()
local b = {}
local c = a < b -- line 10
)");
    LUAU_REQUIRE_ERROR_COUNT(1, checked);

    const GenericError* mismatchError = get<GenericError>(checked.errors[0]);
    REQUIRE(mismatchError != nullptr);
    // Fragile: the exact wording depends on how the operand types are printed.
    REQUIRE_EQ("Types M and b cannot be compared with < because they do not have the same metatable", mismatchError->message);
}
// In nonstrict mode, comparing two `number?` values with `<` produces no errors.
TEST_CASE_FIXTURE(Fixture, "in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators")
{
    CheckResult nonstrictCheck = check(R"(
--!nonstrict
function maybe_a_number(): number?
return 50
end
local a = maybe_a_number() < maybe_a_number()
)");
    LUAU_REQUIRE_NO_ERRORS(nonstrictCheck);
}
// `+=` on a number local typechecks and leaves the local typed as `number`.
TEST_CASE_FIXTURE(Fixture, "compound_assign_basic")
{
    CheckResult result = check(R"(
local s = 10
s += 20
)");
    // Use LUAU_REQUIRE_NO_ERRORS like the rest of this file instead of a bare
    // size comparison. It is presumably a REQUIRE-level check that also reports
    // the offending errors (confirm in Fixture.h), so a failure stops the test
    // before the type query below runs.
    LUAU_REQUIRE_NO_ERRORS(result);
    CHECK_EQ(toString(*requireType("s")), "number");
}
// `number += boolean` reports one TypeMismatch (wanted number, got boolean)
// located on the right-hand operand.
TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op")
{
    CheckResult result = check(R"(
local s = 10
s += true
)");
    LUAU_REQUIRE_ERROR_COUNT(1, result);
    // Line 0 of the script is the empty line right after `R"(`, so `s += true`
    // is line 2. It has no leading indentation, so `true` spans columns 5..9
    // (the end column is one past the last character, as in the other
    // Location expectations in this file).
    CHECK_EQ(result.errors[0], (TypeError{Location{{2, 5}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}}));
}
// `string += number` reports two mismatches: one for the `s` operand (wanted
// number, got string) and one for assigning the number result back into the
// string-typed local (wanted string, got number).
TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result")
{
    CheckResult result = check(R"(
local s = 'hello'
s += 10
)");
    LUAU_REQUIRE_ERROR_COUNT(2, result);
    // Line 0 of the script is the empty line right after `R"(`, so `s += 10`
    // is line 2. It has no leading indentation, so `s` spans columns 0..1 and
    // the whole statement spans columns 0..7 (end columns are one past the
    // last character, as in the other Location expectations in this file).
    CHECK_EQ(result.errors[0], (TypeError{Location{{2, 0}, {2, 1}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}}));
    CHECK_EQ(result.errors[1], (TypeError{Location{{2, 0}, {2, 7}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}}));
}
// `+=` on a metatable-backed value whose `__add` returns the same type (V2)
// typechecks cleanly in strict mode.
TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable")
{
    CheckResult result = check(R"(
--!strict
type V2B = { x: number, y: number }
local v2b: V2B = { x = 0, y = 0 }
local VMT = {}
type V2 = typeof(setmetatable(v2b, VMT))
function VMT.__add(a: V2, b: V2): V2
return setmetatable({ x = a.x + b.x, y = a.y + b.y }, VMT)
end
local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT)
local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT)
v1 += v2
)");
    // Use LUAU_REQUIRE_NO_ERRORS like the rest of this file instead of a bare
    // size comparison. It is presumably a REQUIRE-level check that also reports
    // any unexpected errors (confirm in Fixture.h), which makes failures here
    // diagnosable.
    LUAU_REQUIRE_NO_ERRORS(result);
}
// `__mod` returns `number`, so `v1 %= v2` tries to store a number into a
// V2-typed local. Expect one TypeMismatch: wanted V2, given number.
TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable")
{
    CheckResult result = check(R"(
--!strict
type V2B = { x: number, y: number }
local v2b: V2B = { x = 0, y = 0 }
local VMT = {}
type V2 = typeof(setmetatable(v2b, VMT))
function VMT.__mod(a: V2, b: V2): number
return a.x * b.x + a.y * b.y
end
local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT)
local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT)
v1 %= v2
)");
    LUAU_REQUIRE_ERROR_COUNT(1, result);

    // Make sure the error is a TypeMismatch before dereferencing it; a different
    // error kind would otherwise crash the test instead of failing it.
    TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
    REQUIRE(tm != nullptr);
    CHECK_EQ(*tm->wantedType, *requireType("v2"));
    CHECK_EQ(*tm->givenType, *typeChecker.numberType);
}
// Calling the result of `f or g` (both zero-argument functions) is accepted.
TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions")
{
    CheckResult callOfOr = check(R"(
function f() return 1; end
function g() return 2; end
(f or g)()
)");
    LUAU_REQUIRE_NO_ERRORS(callOfOr);
}
// Calling the result of `x and f or g` (a ternary-style selection between two
// functions) is accepted.
TEST_CASE_FIXTURE(Fixture, "CallAndOrOfFunctions")
{
    CheckResult callOfTernary = check(R"(
function f() return 1; end
function g() return 2; end
local x = false
(x and f or g)()
)");
    LUAU_REQUIRE_NO_ERRORS(callOfTernary);
}
// Unary minus uses `__unm` when one is defined (`-foo` is a string here),
// works on numbers, and is rejected on a plain table without a metatable.
TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus")
{
    CheckResult result = check(R"(
--!strict
local foo = {
value = 10
}
local mt = {}
setmetatable(foo, mt)
mt.__unm = function(val: typeof(foo)): string
return val.value .. "test"
end
local a = -foo
local b = 1+-1
local bar = {
value = 10
}
local c = -bar -- disallowed
)");
    LUAU_REQUIRE_ERROR_COUNT(1, result);
    CHECK_EQ("string", toString(requireType("a")));
    CHECK_EQ("number", toString(requireType("b")));

    // The single error comes from `-bar`. Check its kind before reading the
    // message so an unexpected error kind fails the test instead of crashing it.
    GenericError* gen = get<GenericError>(result.errors[0]);
    REQUIRE(gen != nullptr);
    REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'");
}
// `not` always produces `boolean`, whatever its operand's type.
TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean")
{
    CheckResult checked = check(R"(
local b = not "string"
local c = not (math.random() > 0.5 and "string" or 7)
)");
    LUAU_REQUIRE_NO_ERRORS(checked);

    for (const char* name : {"b", "c"})
        REQUIRE_EQ("boolean", toString(requireType(name)));
}
// Arithmetic is rejected on strings and on tables without `__add`. It is
// accepted when the metatable's `__add` accepts the operands, and rejected
// again when the right-hand operand does not match `__add`'s parameter type.
TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops")
{
    CheckResult result = check(R"(
--!strict
local a = "1.24" + 123 -- not allowed
local foo = {
value = 10
}
local b = foo + 1 -- not allowed
local bar = {
value = 1
}
local mt = {}
setmetatable(bar, mt)
mt.__add = function(a: typeof(bar), b: number): number
return a.value + b
end
local c = bar + 1 -- allowed
local d = bar + foo -- not allowed
)");
    LUAU_REQUIRE_ERROR_COUNT(3, result);

    // Every error is kind-checked before it is dereferenced, so an unexpected
    // error kind fails the test instead of crashing it.

    // errors[0]: `"1.24" + 123`, a string where a number was wanted.
    TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
    REQUIRE(tm != nullptr);
    REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType);
    REQUIRE_EQ(*tm->givenType, *typeChecker.stringType);

    // errors[2]: `bar + foo`, where `__add` declares `b: number` but receives foo.
    TypeMismatch* tm2 = get<TypeMismatch>(result.errors[2]);
    REQUIRE(tm2 != nullptr);
    CHECK_EQ(*tm2->wantedType, *typeChecker.numberType);
    CHECK_EQ(*tm2->givenType, *requireType("foo"));

    // errors[1]: `foo + 1`, where foo has no metatable providing `+`.
    GenericError* gen2 = get<GenericError>(result.errors[1]);
    REQUIRE(gen2 != nullptr);
    REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'");
}
// CLI-29033
// `==` between fields of two unannotated parameters must not produce errors.
TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison")
{
    CheckResult equalityOnUnknowns = check(R"(
function merge(lower, greater)
if lower.y == greater.y then
end
end
)");
    LUAU_REQUIRE_NO_ERRORS(equalityOnUnknowns);
}
// `x .. "y"` with an unannotated (free) left operand cannot be inferred and
// yields exactly one CannotInferBinaryOperation error.
TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs")
{
    CheckResult checked = check(R"(
local function f(x)
return x .. "y"
end
)");
    LUAU_REQUIRE_ERROR_COUNT(1, checked);

    const bool isCannotInfer = get<CannotInferBinaryOperation>(checked.errors[0]) != nullptr;
    REQUIRE(isCannotInfer);
}
// `"foo" .. x` with a free right operand infers `x: string`, so `f` becomes
// `(string) -> string` with no errors.
TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs")
{
    CheckResult checked = check(R"(
local function f(x)
return "foo" .. x
end
)");
    LUAU_REQUIRE_NO_ERRORS(checked);

    const TypeId fType = requireType("f");
    CHECK_EQ("(string) -> string", toString(fType));
}
// Each arithmetic/concat operator applied to the unannotated parameter `a`
// produces one error, and the first one names the operator and the operand.
TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown")
{
    const std::vector<std::string> operators = {"+", "-", "*", "/", "%", "^", ".."};

    // Builds the script "\nfunction foo(a, b)\n", then one
    // "local _ = a <op>b\n" line per operator, then "end".
    std::string script = "\nfunction foo(a, b)\n";
    for (const std::string& op : operators)
    {
        script += "local _ = a ";
        script += op;
        script += "b\n";
    }
    script += "end";

    CheckResult checked = check(script);
    LUAU_REQUIRE_ERROR_COUNT(operators.size(), checked);
    CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(checked.errors[0]));
}
// `true and t[1]` used as a loop condition must not unify its operands in a
// way that breaks the later `t[1].test` access; strict mode reports nothing.
TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify")
{
    CheckResult strictCheck = check(R"(
--!strict
local t = {}
while true and t[1] do
print(t[1].test)
end
)");
    LUAU_REQUIRE_NO_ERRORS(strictCheck);
}
// Booleans do not support `<`; expect one GenericError naming the type.
TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators")
{
    CheckResult checked = check(R"(
local a: boolean = true
local b: boolean = false
local foo = a < b
)");
    LUAU_REQUIRE_ERROR_COUNT(1, checked);

    const GenericError* relationalError = get<GenericError>(checked.errors[0]);
    REQUIRE(relationalError);
    CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", relationalError->message);
}
// A `number | string` union does not support `<` (the two sides could hold
// different members); expect one GenericError naming the union.
TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2")
{
    CheckResult checked = check(R"(
local a: number | string = ""
local b: number | string = 1
local foo = a < b
)");
    LUAU_REQUIRE_ERROR_COUNT(1, checked);

    const GenericError* relationalError = get<GenericError>(checked.errors[0]);
    REQUIRE(relationalError);
    CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", relationalError->message);
}
// Regression test (CLI-38355): a compound assignment whose right-hand side is
// a long `and`/`or` chain over the same untyped local must report one
// self-recursion error rather than looping or crashing.
TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union")
{
    CheckResult checked = check(R"(
--!strict
local _
_ += _ and _ or _ and _ or _ and _
)");
    LUAU_REQUIRE_ERROR_COUNT(1, checked);

    const std::string expectedMessage = "Type contains a self-recursive construct that cannot be resolved";
    CHECK_EQ(expectedMessage, toString(checked.errors[0]));
}
// How reads of an undefined global `a` are reported across modes and
// assignment forms. The first reported error is always "Unknown global 'a'".
TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign")
{
    // Nonstrict, plain assignment: `a = a + 1` defines the global, so only the
    // read on the right-hand side is reported.
    {
        CheckResult plainAssign = check(R"(
--!nonstrict
a = a + 1
print(a)
)");
        LUAU_REQUIRE_ERROR_COUNT(1, plainAssign);
        CHECK_EQ(toString(plainAssign.errors[0]), "Unknown global 'a'");
    }

    // Strict, compound assignment: at least one error, the first being the
    // unknown global. The left-hand side is no longer reported twice.
    {
        CheckResult strictCompound = check(R"(
--!strict
a += 1
print(a)
)");
        LUAU_REQUIRE_ERRORS(strictCompound);
        CHECK_EQ(toString(strictCompound.errors[0]), "Unknown global 'a'");
    }

    // Nonstrict, compound assignment: `+=` modifies rather than defines the
    // global, so two errors are reported.
    {
        CheckResult nonstrictCompound = check(R"(
--!nonstrict
a += 1
print(a)
)");
        LUAU_REQUIRE_ERROR_COUNT(2, nonstrictCompound);
        CHECK_EQ(toString(nonstrictCompound.errors[0]), "Unknown global 'a'");
    }
}
// Strict mode: `a or 1` with `a: number?` drops the nil from the left operand,
// so the result fits a `number` local.
TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator")
{
    CheckResult strictCheck = check(R"(
--!strict
local a: number? = nil
local b: number = a or 1
)");
    LUAU_REQUIRE_NO_ERRORS(strictCheck);
}
// Nonstrict counterpart of strip_nil_from_lhs_or_operator: `a or 1` also fits
// a `number` local here.
TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator2")
{
    CheckResult nonstrictCheck = check(R"(
--!nonstrict
local a: number? = nil
local b: number = a or 1
)");
    LUAU_REQUIRE_NO_ERRORS(nonstrictCheck);
}
// Only the left operand of `or` loses its nil: `1 or a` stays `number?`, so
// assigning it to a `number` local is a TypeMismatch.
TEST_CASE_FIXTURE(Fixture, "dont_strip_nil_from_rhs_or_operator")
{
    CheckResult checked = check(R"(
--!strict
local a: number? = nil
local b: number = 1 or a
)");
    LUAU_REQUIRE_ERROR_COUNT(1, checked);

    const TypeMismatch* mismatch = get<TypeMismatch>(checked.errors[0]);
    REQUIRE(mismatch);
    CHECK_EQ(typeChecker.numberType, mismatch->wantedType);
    CHECK_EQ("number?", toString(mismatch->givenType));
}
// `~=` between a `Fiber` and an element of `Array<Fiber | null>` is accepted
// in both operand orders, because the operand types overlap.
TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect")
{
    CheckResult inequalityCheck = check(R"(
type Array<T> = { [number]: T }
type Fiber = { id: number }
type null = {}
local fiberStack: Array<Fiber | null> = {}
local index = 0
local function f(fiber: Fiber)
local a = fiber ~= fiberStack[index]
local b = fiberStack[index] ~= fiber
end
return f
)");
    LUAU_REQUIRE_NO_ERRORS(inequalityCheck);
}
// `==` between `string | number` and `boolean | number`: neither is a subtype
// of the other, but they overlap on `number`.
TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap")
{
    ScopedFastFlag eqConstraintFlag{"LuauEqConstraint", true};

    CheckResult equalityCheck = check(R"(
local function f(a: string | number, b: boolean | number)
return a == b
end
)");

    // No errors are reported here, though not for the right reasons. Keep this
    // test as a reminder that the operands of `==`/`~=` must not be unified.
    LUAU_REQUIRE_NO_ERRORS(equalityCheck);
}
// `t and t.x or 5` with optional `t` and optional `t.x` refines to plain `number`.
TEST_CASE_FIXTURE(Fixture, "refine_and_or")
{
    CheckResult checked = check(R"(
local t: {x: number?}? = {x = nil}
local u = t and t.x or 5
)");
    LUAU_REQUIRE_NO_ERRORS(checked);

    const TypeId refined = requireType("u");
    CHECK_EQ("number", toString(refined));
}
// `x + y` on two unannotated parameters: strict mode reports one "unknown
// type" error naming `x`, while nonstrict mode accepts the same code silently.
TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown")
{
    ScopedFastFlag decoupleFlag{"LuauDecoupleOperatorInferenceFromUnifiedTypeInference", true};

    CheckResult strictResult = check(Mode::Strict, R"(
local function f(x, y)
return x + y
end
)");
    LUAU_REQUIRE_ERROR_COUNT(1, strictResult);
    CHECK_EQ(toString(strictResult.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'");

    CheckResult nonstrictResult = check(Mode::Nonstrict, R"(
local function f(x, y)
return x + y
end
)");
    LUAU_REQUIRE_NO_ERRORS(nonstrictResult);

    // Once type inference is unified, this test could also assert that the
    // strict and nonstrict types are equivalent. Today they are not.
}
// Closes the doctest suite "TypeInferOperators" opened near the top of this file.
TEST_SUITE_END();