Sync to upstream/release/508 (#301)

This version isn't for release because we've skipped some internal
numbers due to year-end schedule changes, but it's better to merge
separately.
This commit is contained in:
Arseny Kapoulkine 2022-01-06 15:26:14 -08:00 committed by GitHub
parent 82587bef29
commit d323237b6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 570 additions and 284 deletions

View file

@ -0,0 +1,63 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Variant.h"
#include "Luau/Symbol.h"
#include <map> // TODO: Kill with LuauLValueAsKey.
#include <memory>
#include <unordered_map>
namespace Luau
{
struct TypeVar;
using TypeId = const TypeVar*;
struct Field;
using LValue = Variant<Symbol, Field>;
struct Field
{
std::shared_ptr<LValue> parent;
std::string key;
bool operator==(const Field& rhs) const;
bool operator!=(const Field& rhs) const;
};
struct LValueHasher
{
size_t operator()(const LValue& lvalue) const;
};
const LValue* baseof(const LValue& lvalue);
std::optional<LValue> tryGetLValue(const class AstExpr& expr);
// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys.
std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue);
// Kill with LuauLValueAsKey.
std::string toString(const LValue& lvalue);
template<typename T>
const T* get(const LValue& lvalue)
{
return get_if<T>(&lvalue);
}
using NEW_RefinementMap = std::unordered_map<LValue, TypeId, LValueHasher>;
using DEPRECATED_RefinementMap = std::map<std::string, TypeId>;
// Transient. Kill with LuauLValueAsKey.
struct RefinementMap
{
NEW_RefinementMap NEW_refinements;
DEPRECATED_RefinementMap DEPRECATED_refinements;
};
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f);
void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty);
} // namespace Luau

View file

@ -1,12 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Variant.h"
#include "Luau/Location.h"
#include "Luau/Symbol.h"
#include "Luau/LValue.h"
#include "Luau/Variant.h"
#include <map>
#include <memory>
#include <vector>
namespace Luau
@ -15,34 +13,6 @@ namespace Luau
struct TypeVar;
using TypeId = const TypeVar*;
struct Field;
using LValue = Variant<Symbol, Field>;
struct Field
{
std::shared_ptr<LValue> parent; // TODO: Eventually use unique_ptr to enforce non-copyable trait.
std::string key;
};
std::optional<LValue> tryGetLValue(const class AstExpr& expr);
// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys.
std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue);
std::string toString(const LValue& lvalue);
template<typename T>
const T* get(const LValue& lvalue)
{
return get_if<T>(&lvalue);
}
// Key is a stringified encoding of an LValue.
using RefinementMap = std::map<std::string, TypeId>;
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f);
void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty);
struct TruthyPredicate;
struct IsAPredicate;
struct TypeGuardPredicate;

View file

@ -350,6 +350,7 @@ public:
private:
std::optional<TypeId> resolveLValue(const ScopePtr& scope, const LValue& lvalue);
std::optional<TypeId> DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue);
std::optional<TypeId> resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue);
void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false);

View file

@ -1,11 +1,59 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Predicate.h"
#include "Luau/LValue.h"
#include "Luau/Ast.h"
#include <vector>
LUAU_FASTFLAG(LuauLValueAsKey)
namespace Luau
{
bool Field::operator==(const Field& rhs) const
{
LUAU_ASSERT(parent && rhs.parent);
return key == rhs.key && (parent == rhs.parent || *parent == *rhs.parent);
}
bool Field::operator!=(const Field& rhs) const
{
return !(*this == rhs);
}
size_t LValueHasher::operator()(const LValue& lvalue) const
{
// Most likely doesn't produce high quality hashes, but we're probably ok enough with it.
// When an evidence is shown that operator==(LValue) is used more often than it should, we can have a look at improving the hash quality.
size_t acc = 0;
size_t offset = 0;
const LValue* current = &lvalue;
while (current)
{
if (auto field = get<Field>(*current))
acc ^= (std::hash<std::string>{}(field->key) << 1) >> ++offset;
else if (auto symbol = get<Symbol>(*current))
acc ^= std::hash<Symbol>{}(*symbol) << 1;
else
LUAU_ASSERT(!"Hash not accumulated for this new LValue alternative.");
current = baseof(*current);
}
return acc;
}
const LValue* baseof(const LValue& lvalue)
{
if (auto field = get<Field>(lvalue))
return field->parent.get();
auto symbol = get<Symbol>(lvalue);
LUAU_ASSERT(symbol);
return nullptr; // Base of root is null.
}
std::optional<LValue> tryGetLValue(const AstExpr& node)
{
const AstExpr* expr = &node;
@ -38,15 +86,15 @@ std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue)
while (auto field = get<Field>(*current))
{
keys.push_back(field->key);
current = field->parent.get();
if (!current)
LUAU_ASSERT(!"LValue root is a Field?");
current = baseof(*current);
}
const Symbol* symbol = get<Symbol>(*current);
LUAU_ASSERT(symbol);
return {*symbol, std::vector<std::string>(keys.rbegin(), keys.rend())};
}
// Kill with LuauLValueAsKey.
std::string toString(const LValue& lvalue)
{
auto [symbol, keys] = getFullName(lvalue);
@ -56,7 +104,18 @@ std::string toString(const LValue& lvalue)
return s;
}
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f)
static void merge(NEW_RefinementMap& l, const NEW_RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f)
{
for (const auto& [k, a] : r)
{
if (auto it = l.find(k); it != l.end())
l[k] = f(it->second, a);
else
l[k] = a;
}
}
static void merge(DEPRECATED_RefinementMap& l, const DEPRECATED_RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f)
{
auto itL = l.begin();
auto itR = r.begin();
@ -69,21 +128,32 @@ void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId
++itL;
++itR;
}
else if (itL->first > k)
else if (itL->first < k)
++itL;
else
{
l[k] = a;
++itR;
}
else
++itL;
}
l.insert(itR, r.end());
}
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f)
{
if (FFlag::LuauLValueAsKey)
return merge(l.NEW_refinements, r.NEW_refinements, f);
else
return merge(l.DEPRECATED_refinements, r.DEPRECATED_refinements, f);
}
void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty)
{
refis[toString(lvalue)] = ty;
if (FFlag::LuauLValueAsKey)
refis.NEW_refinements[lvalue] = ty;
else
refis.DEPRECATED_refinements[toString(lvalue)] = ty;
}
} // namespace Luau

View file

@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false)
LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false)
LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false)
LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false)
LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false)
LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false)
LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false)
LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false)
@ -1626,6 +1627,10 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
{
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
// Not needed when we normalize types.
if (FFlag::LuauLValueAsKey && get<AnyTypeVar>(follow(t)))
return t;
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
goodOptions.push_back(*ty);
else
@ -4967,13 +4972,83 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
std::optional<TypeId> TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue)
{
std::string path = toString(lvalue);
if (!FFlag::LuauLValueAsKey)
return DEPRECATED_resolveLValue(scope, lvalue);
// We want to be walking the Scope parents.
// We'll also want to walk up the LValue path. As we do this, we need to save each LValue because we must walk back.
// For example:
// There exists an entry t.x.
// We are asked to look for t.x.y.
// We need to search in the provided Scope. Find t.x.y first.
// We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x.
// If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate.
const auto& [symbol, keys] = getFullName(lvalue);
ScopePtr currentScope = scope;
while (currentScope)
{
std::optional<TypeId> found;
std::vector<LValue> childKeys;
const LValue* currentLValue = &lvalue;
while (currentLValue)
{
if (auto it = currentScope->refinements.NEW_refinements.find(*currentLValue); it != currentScope->refinements.NEW_refinements.end())
{
found = it->second;
break;
}
childKeys.push_back(*currentLValue);
currentLValue = baseof(*currentLValue);
}
if (!found)
{
// Should not be using scope->lookup. This is already recursive.
if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end())
found = it->second.typeId;
else
{
// Nothing exists in this Scope. Just skip and try the parent one.
currentScope = currentScope->parent;
continue;
}
}
for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it)
{
const LValue& key = *it;
// Symbol can happen. Skip.
if (get<Symbol>(key))
continue;
else if (auto field = get<Field>(key))
{
found = getIndexTypeFromType(scope, *found, field->key, Location(), false);
if (!found)
return std::nullopt; // Turns out this type doesn't have the property at all. We're done.
}
else
LUAU_ASSERT(!"New LValue alternative not handled here.");
}
return found;
}
// No entry for it at all. Can happen when LValue root is a global.
return std::nullopt;
}
std::optional<TypeId> TypeChecker::DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue)
{
auto [symbol, keys] = getFullName(lvalue);
ScopePtr currentScope = scope;
while (currentScope)
{
if (auto it = currentScope->refinements.find(path); it != currentScope->refinements.end())
if (auto it = currentScope->refinements.DEPRECATED_refinements.find(toString(lvalue)); it != currentScope->refinements.DEPRECATED_refinements.end())
return it->second;
// Should not be using scope->lookup. This is already recursive.
@ -5000,7 +5075,9 @@ std::optional<TypeId> TypeChecker::resolveLValue(const ScopePtr& scope, const LV
std::optional<TypeId> TypeChecker::resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue)
{
if (auto it = refis.find(toString(lvalue)); it != refis.end())
if (auto it = refis.DEPRECATED_refinements.find(toString(lvalue)); it != refis.DEPRECATED_refinements.end())
return it->second;
else if (auto it = refis.NEW_refinements.find(lvalue); it != refis.NEW_refinements.end())
return it->second;
else
return resolveLValue(scope, lvalue);

View file

@ -996,18 +996,19 @@ std::optional<ExprResult<TypePackId>> magicFunctionFormat(
std::vector<TypeId> expected = parseFormatString(typechecker, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(paramPack);
const size_t dataOffset = 1;
size_t paramOffset = 1;
size_t dataOffset = expr.self ? 0 : 1;
// unify the prefix one argument at a time
for (size_t i = 0; i < expected.size() && i + dataOffset < params.size(); ++i)
for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
{
Location location = expr.args.data[std::min(i, expr.args.size - 1)]->location;
Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location;
typechecker.unify(expected[i], params[i + dataOffset], location);
typechecker.unify(expected[i], params[i + paramOffset], location);
}
// if we know the argument count or if we have too many arguments for sure, we can issue an error
const size_t actualParamSize = params.size() - dataOffset;
size_t actualParamSize = params.size() - paramOffset;
if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize))
typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}});

View file

@ -18,9 +18,7 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false);
LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false)
LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false)
LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false)
LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false)
LUAU_FASTFLAG(LuauSingletonTypes)
LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false)
LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauProperTypeLevels);
LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false)
@ -416,7 +414,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
else if (!innerState.errors.empty())
{
// 'nil' option is skipped from extended report because we present the type in a special way - 'T?'
if (FFlag::LuauExtendedTypeMismatchError && !firstFailedOption && !isNil(type))
if (!firstFailedOption && !isNil(type))
firstFailedOption = {innerState.errors.front()};
failed = true;
@ -434,7 +432,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
errors.push_back(*unificationTooComplex);
else if (failed)
{
if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption)
if (firstFailedOption)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}});
else
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
@ -536,15 +534,11 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption)
errors.push_back(
TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}});
else if (FFlag::LuauExtendedTypeMismatchError)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}});
else
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}});
}
}
else if (const IntersectionTypeVar* uv = get<IntersectionTypeVar>(superTy))
{
if (FFlag::LuauExtendedTypeMismatchError)
{
std::optional<TypeError> unificationTooComplex;
std::optional<TypeError> firstFailedOption;
@ -571,15 +565,6 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
else if (firstFailedOption)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}});
}
else
{
// T <: A & B if A <: T and B <: T
for (TypeId type : uv->parts)
{
tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true);
}
}
}
else if (const IntersectionTypeVar* uv = get<IntersectionTypeVar>(subTy))
{
// A & B <: T if T <: A or T <: B
@ -626,10 +611,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
errors.push_back(*unificationTooComplex);
else if (!found)
{
if (FFlag::LuauExtendedTypeMismatchError)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}});
else
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
}
}
else if (get<PrimitiveTypeVar>(superTy) && get<PrimitiveTypeVar>(subTy))
@ -1241,10 +1223,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(prop.type, r->second.type);
if (FFlag::LuauExtendedTypeMismatchError)
checkChildUnifierTypeMismatch(innerState.errors, name, left, right);
else
checkChildUnifierTypeMismatch(innerState.errors, left, right);
if (innerState.errors.empty())
log.concat(std::move(innerState.log));
@ -1261,10 +1240,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(prop.type, rt->indexer->indexResultType);
if (FFlag::LuauExtendedTypeMismatchError)
checkChildUnifierTypeMismatch(innerState.errors, name, left, right);
else
checkChildUnifierTypeMismatch(innerState.errors, left, right);
if (innerState.errors.empty())
log.concat(std::move(innerState.log));
@ -1302,10 +1278,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(prop.type, lt->indexer->indexResultType);
if (FFlag::LuauExtendedTypeMismatchError)
checkChildUnifierTypeMismatch(innerState.errors, name, left, right);
else
checkChildUnifierTypeMismatch(innerState.errors, left, right);
if (innerState.errors.empty())
log.concat(std::move(innerState.log));
@ -1723,18 +1696,11 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse
innerState.tryUnify_(lhs->table, rhs->table);
innerState.tryUnify_(lhs->metatable, rhs->metatable);
if (FFlag::LuauExtendedTypeMismatchError)
{
if (auto e = hasUnificationTooComplex(innerState.errors))
errors.push_back(*e);
else if (!innerState.errors.empty())
errors.push_back(
TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}});
}
else
{
checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other);
}
log.concat(std::move(innerState.log));
}
@ -1821,12 +1787,8 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed)
{
ok = false;
errors.push_back(TypeError{location, UnknownProperty{superTy, propName}});
if (!FFlag::LuauExtendedClassMismatchError)
tryUnify_(prop.type, getSingletonTypes().errorRecoveryType());
}
else
{
if (FFlag::LuauExtendedClassMismatchError)
{
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(prop.type, classProp->type);
@ -1843,11 +1805,6 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed)
innerState.log.rollback();
}
}
else
{
tryUnify_(prop.type, classProp->type);
}
}
}
if (table->indexer)
@ -2185,8 +2142,6 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId
void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType)
{
LUAU_ASSERT(FFlag::LuauExtendedTypeMismatchError || FFlag::LuauExtendedClassMismatchError);
if (auto e = hasUnificationTooComplex(innerErrors))
errors.push_back(*e);
else if (!innerErrors.empty())

View file

@ -14,7 +14,6 @@ LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false)
LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false)
LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false)
LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false)
namespace Luau
{
@ -1368,9 +1367,6 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
Lexeme parameterStart = lexer.current();
if (!FFlag::LuauParseGenericFunctionTypeBegin)
begin = parameterStart;
expectAndConsume('(', "function parameters");
matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++;

View file

@ -45,6 +45,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/IostreamHelpers.h
Analysis/include/Luau/JsonEncoder.h
Analysis/include/Luau/Linter.h
Analysis/include/Luau/LValue.h
Analysis/include/Luau/Module.h
Analysis/include/Luau/ModuleResolver.h
Analysis/include/Luau/Predicate.h
@ -80,8 +81,8 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/IostreamHelpers.cpp
Analysis/src/JsonEncoder.cpp
Analysis/src/Linter.cpp
Analysis/src/LValue.cpp
Analysis/src/Module.cpp
Analysis/src/Predicate.cpp
Analysis/src/Quantify.cpp
Analysis/src/RequireTracer.cpp
Analysis/src/Scope.cpp
@ -194,10 +195,10 @@ if(TARGET Luau.UnitTest)
tests/Frontend.test.cpp
tests/JsonEncoder.test.cpp
tests/Linter.test.cpp
tests/LValue.test.cpp
tests/Module.test.cpp
tests/NonstrictMode.test.cpp
tests/Parser.test.cpp
tests/Predicate.test.cpp
tests/RequireTracer.test.cpp
tests/StringUtils.test.cpp
tests/Symbol.test.cpp

View file

@ -36,12 +36,14 @@ static int luaB_tonumber(lua_State* L)
int base = luaL_optinteger(L, 2, 10);
if (base == 10)
{ /* standard conversion */
luaL_checkany(L, 1);
if (lua_isnumber(L, 1))
int isnum = 0;
double n = lua_tonumberx(L, 1, &isnum);
if (isnum)
{
lua_pushnumber(L, lua_tonumber(L, 1));
lua_pushnumber(L, n);
return 1;
}
luaL_checkany(L, 1); /* error if we don't have any argument */
}
else
{

View file

@ -1394,7 +1394,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_function_return_types")
check(R"(
local function target(a: number, b: string) return a + #b end
local function bar1(a: number) return -a end
local function bar2(a: string) reutrn a .. 'x' end
local function bar2(a: string) return a .. 'x' end
return target(b@1
)");
@ -1422,7 +1422,7 @@ return target(bar1, b@1
check(R"(
local function target(a: number, b: string) return a + #b end
local function bar1(a: number): (...number) return -a, a end
local function bar2(a: string) reutrn a .. 'x' end
local function bar2(a: string) return a .. 'x' end
return target(b@1
)");
@ -1918,7 +1918,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis")
check(R"(
local function target(a: (number) -> number) return a(4) end
local function bar1(a: number) return -a end
local function bar2(a: string) reutrn a .. 'x' end
local function bar2(a: string) return a .. 'x' end
return target(b@1
)");

198
tests/LValue.test.cpp Normal file
View file

@ -0,0 +1,198 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeInfer.h"
#include "Fixture.h"
#include "ScopedFlags.h"
#include "doctest.h"
using namespace Luau;
static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r)
{
Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId {
// TODO: normalize here also.
std::unordered_set<TypeId> s;
if (auto utv = get<UnionTypeVar>(follow(a)))
s.insert(begin(utv), end(utv));
else
s.insert(a);
if (auto utv = get<UnionTypeVar>(follow(b)))
s.insert(begin(utv), end(utv));
else
s.insert(b);
std::vector<TypeId> options(s.begin(), s.end());
return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)});
});
}
static LValue mkSymbol(const std::string& s)
{
return Symbol{AstName{s.data()}};
}
TEST_SUITE_BEGIN("LValue");
TEST_CASE("Luau_merge_hashmap_order")
{
ScopedFastFlag sff{"LuauLValueAsKey", true};
std::string a = "a";
std::string b = "b";
std::string c = "c";
RefinementMap m{{
{mkSymbol(b), getSingletonTypes().stringType},
{mkSymbol(c), getSingletonTypes().numberType},
}};
RefinementMap other{{
{mkSymbol(a), getSingletonTypes().stringType},
{mkSymbol(b), getSingletonTypes().stringType},
{mkSymbol(c), getSingletonTypes().booleanType},
}};
TypeArena arena;
merge(arena, m, other);
REQUIRE_EQ(3, m.NEW_refinements.size());
REQUIRE(m.NEW_refinements.count(mkSymbol(a)));
REQUIRE(m.NEW_refinements.count(mkSymbol(b)));
REQUIRE(m.NEW_refinements.count(mkSymbol(c)));
CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)]));
CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)]));
CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)]));
}
TEST_CASE("Luau_merge_hashmap_order2")
{
ScopedFastFlag sff{"LuauLValueAsKey", true};
std::string a = "a";
std::string b = "b";
std::string c = "c";
RefinementMap m{{
{mkSymbol(a), getSingletonTypes().stringType},
{mkSymbol(b), getSingletonTypes().stringType},
{mkSymbol(c), getSingletonTypes().numberType},
}};
RefinementMap other{{
{mkSymbol(b), getSingletonTypes().stringType},
{mkSymbol(c), getSingletonTypes().booleanType},
}};
TypeArena arena;
merge(arena, m, other);
REQUIRE_EQ(3, m.NEW_refinements.size());
REQUIRE(m.NEW_refinements.count(mkSymbol(a)));
REQUIRE(m.NEW_refinements.count(mkSymbol(b)));
REQUIRE(m.NEW_refinements.count(mkSymbol(c)));
CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)]));
CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)]));
CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)]));
}
TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start")
{
ScopedFastFlag sff{"LuauLValueAsKey", true};
std::string a = "a";
std::string b = "b";
std::string c = "c";
std::string d = "d";
std::string e = "e";
RefinementMap m{{
{mkSymbol(a), getSingletonTypes().stringType},
{mkSymbol(b), getSingletonTypes().numberType},
{mkSymbol(c), getSingletonTypes().booleanType},
}};
RefinementMap other{{
{mkSymbol(c), getSingletonTypes().stringType},
{mkSymbol(d), getSingletonTypes().numberType},
{mkSymbol(e), getSingletonTypes().booleanType},
}};
TypeArena arena;
merge(arena, m, other);
REQUIRE_EQ(5, m.NEW_refinements.size());
REQUIRE(m.NEW_refinements.count(mkSymbol(a)));
REQUIRE(m.NEW_refinements.count(mkSymbol(b)));
REQUIRE(m.NEW_refinements.count(mkSymbol(c)));
REQUIRE(m.NEW_refinements.count(mkSymbol(d)));
REQUIRE(m.NEW_refinements.count(mkSymbol(e)));
CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)]));
CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(b)]));
CHECK_EQ("boolean | string", toString(m.NEW_refinements[mkSymbol(c)]));
CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(d)]));
CHECK_EQ("boolean", toString(m.NEW_refinements[mkSymbol(e)]));
}
TEST_CASE("hashing_lvalue_global_prop_access")
{
std::string t1 = "t";
std::string x1 = "x";
LValue t_x1{Field{std::make_shared<LValue>(Symbol{AstName{t1.data()}}), x1}};
std::string t2 = "t";
std::string x2 = "x";
LValue t_x2{Field{std::make_shared<LValue>(Symbol{AstName{t2.data()}}), x2}};
CHECK_EQ(t_x1, t_x1);
CHECK_EQ(t_x1, t_x2);
CHECK_EQ(t_x2, t_x2);
CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1));
CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x2));
CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2));
NEW_RefinementMap m;
m[t_x1] = getSingletonTypes().stringType;
m[t_x2] = getSingletonTypes().numberType;
CHECK_EQ(1, m.size());
}
TEST_CASE("hashing_lvalue_local_prop_access")
{
std::string t1 = "t";
std::string x1 = "x";
AstLocal localt1{AstName{t1.data()}, Location(), nullptr, 0, 0, nullptr};
LValue t_x1{Field{std::make_shared<LValue>(Symbol{&localt1}), x1}};
std::string t2 = "t";
std::string x2 = "x";
AstLocal localt2{AstName{t2.data()}, Location(), &localt1, 0, 0, nullptr};
LValue t_x2{Field{std::make_shared<LValue>(Symbol{&localt2}), x2}};
CHECK_EQ(t_x1, t_x1);
CHECK_NE(t_x1, t_x2);
CHECK_EQ(t_x2, t_x2);
CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1));
CHECK_NE(LValueHasher{}(t_x1), LValueHasher{}(t_x2));
CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2));
NEW_RefinementMap m;
m[t_x1] = getSingletonTypes().stringType;
m[t_x2] = getSingletonTypes().numberType;
CHECK_EQ(2, m.size());
}
TEST_SUITE_END();

View file

@ -1,117 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeInfer.h"
#include "Fixture.h"
#include "ScopedFlags.h"
#include "doctest.h"
using namespace Luau;
static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r)
{
Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId {
// TODO: normalize here also.
std::unordered_set<TypeId> s;
if (auto utv = get<UnionTypeVar>(follow(a)))
s.insert(begin(utv), end(utv));
else
s.insert(a);
if (auto utv = get<UnionTypeVar>(follow(b)))
s.insert(begin(utv), end(utv));
else
s.insert(b);
std::vector<TypeId> options(s.begin(), s.end());
return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)});
});
}
TEST_SUITE_BEGIN("Predicate");
TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order")
{
RefinementMap m{
{"b", typeChecker.stringType},
{"c", typeChecker.numberType},
};
RefinementMap other{
{"a", typeChecker.stringType},
{"b", typeChecker.stringType},
{"c", typeChecker.booleanType},
};
TypeArena arena;
merge(arena, m, other);
REQUIRE_EQ(3, m.size());
REQUIRE(m.count("a"));
REQUIRE(m.count("b"));
REQUIRE(m.count("c"));
CHECK_EQ("string", toString(m["a"]));
CHECK_EQ("string", toString(m["b"]));
CHECK_EQ("boolean | number", toString(m["c"]));
}
TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2")
{
RefinementMap m{
{"a", typeChecker.stringType},
{"b", typeChecker.stringType},
{"c", typeChecker.numberType},
};
RefinementMap other{
{"b", typeChecker.stringType},
{"c", typeChecker.booleanType},
};
TypeArena arena;
merge(arena, m, other);
REQUIRE_EQ(3, m.size());
REQUIRE(m.count("a"));
REQUIRE(m.count("b"));
REQUIRE(m.count("c"));
CHECK_EQ("string", toString(m["a"]));
CHECK_EQ("string", toString(m["b"]));
CHECK_EQ("boolean | number", toString(m["c"]));
}
TEST_CASE_FIXTURE(Fixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start")
{
RefinementMap m{
{"a", typeChecker.stringType},
{"b", typeChecker.numberType},
{"c", typeChecker.booleanType},
};
RefinementMap other{
{"c", typeChecker.stringType},
{"d", typeChecker.numberType},
{"e", typeChecker.booleanType},
};
TypeArena arena;
merge(arena, m, other);
REQUIRE_EQ(5, m.size());
REQUIRE(m.count("a"));
REQUIRE(m.count("b"));
REQUIRE(m.count("c"));
REQUIRE(m.count("d"));
REQUIRE(m.count("e"));
CHECK_EQ("string", toString(m["a"]));
CHECK_EQ("number", toString(m["b"]));
CHECK_EQ("boolean | string", toString(m["c"]));
CHECK_EQ("number", toString(m["d"]));
CHECK_EQ("boolean", toString(m["e"]));
}
TEST_SUITE_END();

View file

@ -10,7 +10,7 @@ using namespace Luau;
TEST_SUITE_BEGIN("SymbolTests");
TEST_CASE("hashing")
TEST_CASE("hashing_globals")
{
std::string s1 = "name";
std::string s2 = "name";
@ -31,10 +31,37 @@ TEST_CASE("hashing")
CHECK_EQ(std::hash<Symbol>()(two), std::hash<Symbol>()(two));
std::unordered_map<Symbol, int> theMap;
theMap[AstName{s1.data()}] = 5;
theMap[AstName{s2.data()}] = 1;
theMap[n1] = 5;
theMap[n2] = 1;
REQUIRE_EQ(1, theMap.size());
}
TEST_CASE("hashing_locals")
{
std::string s1 = "name";
std::string s2 = "name";
// These two names point to distinct memory areas.
AstLocal one{AstName{s1.data()}, Location(), nullptr, 0, 0, nullptr};
AstLocal two{AstName{s2.data()}, Location(), &one, 0, 0, nullptr};
Symbol n1{&one};
Symbol n2{&two};
CHECK(n1 == n1);
CHECK(n1 != n2);
CHECK(n2 == n2);
CHECK_EQ(std::hash<Symbol>()(&one), std::hash<Symbol>()(&one));
CHECK_NE(std::hash<Symbol>()(&one), std::hash<Symbol>()(&two));
CHECK_EQ(std::hash<Symbol>()(&two), std::hash<Symbol>()(&two));
std::unordered_map<Symbol, int> theMap;
theMap[n1] = 5;
theMap[n2] = 1;
REQUIRE_EQ(2, theMap.size());
}
TEST_SUITE_END();

View file

@ -555,8 +555,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple")
TEST_CASE_FIXTURE(Fixture, "transpile_generic_function")
{
ScopedFastFlag luauParseGenericFunctionTypeBegin("LuauParseGenericFunctionTypeBegin", true);
std::string code = R"(
local function foo<T,S...>(a: T, ...: S...) return 1 end
local f: <T,S...>(T, S...)->(number) = foo

View file

@ -798,13 +798,14 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi
{
CheckResult result = check(R"(
("%s%d%s"):format(1, "hello", true)
string.format("%s%d%s", 1, "hello", true)
)");
TypeId stringType = typeChecker.stringType;
TypeId numberType = typeChecker.numberType;
TypeId booleanType = typeChecker.booleanType;
LUAU_REQUIRE_ERROR_COUNT(3, result);
LUAU_REQUIRE_ERROR_COUNT(6, result);
CHECK_EQ(Location(Position{1, 26}, Position{1, 27}), result.errors[0].location);
CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[0].data);
@ -814,6 +815,15 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi
CHECK_EQ(Location(Position{1, 38}, Position{1, 42}), result.errors[2].location);
CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data);
CHECK_EQ(Location(Position{2, 32}, Position{2, 33}), result.errors[3].location);
CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[3].data);
CHECK_EQ(Location(Position{2, 35}, Position{2, 42}), result.errors[4].location);
CHECK_EQ(TypeErrorData(TypeMismatch{numberType, stringType}), result.errors[4].data);
CHECK_EQ(Location(Position{2, 44}, Position{2, 48}), result.errors[5].location);
CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[5].data);
}
TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type")

View file

@ -449,8 +449,6 @@ b.X = 2 -- real Vector2.X is also read-only
TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error")
{
ScopedFastFlag luauExtendedClassMismatchError{"LuauExtendedClassMismatchError", true};
CheckResult result = check(R"(
local function foo(v)
return v.X :: number + string.len(v.Y)

View file

@ -343,8 +343,6 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable")
TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part")
{
ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true};
CheckResult result = check(R"(
type X = { x: number }
type Y = { y: number }
@ -363,8 +361,6 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all")
{
ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true};
CheckResult result = check(R"(
type X = { x: number }
type Y = { y: number }

View file

@ -280,7 +280,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const
TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal")
{
ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true};
ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true};
CheckResult result = check(R"(
local t: {x: number?} = {x = nil}
@ -1085,6 +1084,41 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression")
CHECK_EQ("any", toString(requireTypeAtPosition({6, 66})));
}
TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined")
{
ScopedFastFlag sff{"LuauLValueAsKey", true};
CheckResult result = check(R"(
local foo: string? = "hi"
assert(foo)
local foo: number = 5
print(foo:sub(1, 1))
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' does not have key 'sub'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined")
{
ScopedFastFlag sff{"LuauLValueAsKey", true};
CheckResult result = check(R"(
type T = {x: string | number}
local t: T? = {x = "hi"}
if t then
if type(t.x) == "string" then
local foo = t.x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("string", toString(requireTypeAtPosition({5, 30})));
}
TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string")
{
ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true};
@ -1092,6 +1126,7 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip
CheckResult result = check(R"(
type T = { [string]: { prop: number }? }
local t: T = {}
if t["hello"] then
local foo = t["hello"].prop
end

View file

@ -202,7 +202,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch")
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
{"LuauExtendedTypeMismatchError", true},
};
CheckResult result = check(R"(

View file

@ -1955,7 +1955,6 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in
TEST_CASE_FIXTURE(Fixture, "error_detailed_prop")
{
ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path
ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true};
CheckResult result = check(R"(
type A = { x: number, y: number }
@ -1974,7 +1973,6 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested")
{
ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path
ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true};
CheckResult result = check(R"(
type AS = { x: number, y: number }
@ -1998,7 +1996,6 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop")
{
ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path
ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true};
CheckResult result = check(R"(
local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end });
@ -2062,7 +2059,6 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error")
{"LuauPropertiesGetExpectedType", true},
{"LuauExpectedTypesOfProperties", true},
{"LuauTableSubtypingVariance2", true},
{"LuauExtendedTypeMismatchError", true},
};
CheckResult result = check(R"(

View file

@ -5013,6 +5013,19 @@ caused by:
Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')");
}
TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options")
{
ScopedFastFlag sff{"LuauLValueAsKey", true};
CheckResult result = check(R"(
local function f(thing: any | string)
local foo = thing.SomeRandomKey
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free")
{
ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true};

View file

@ -425,8 +425,6 @@ y = x
TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part")
{
ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true};
CheckResult result = check(R"(
type X = { x: number }
type Y = { y: number }
@ -446,8 +444,6 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all")
{
ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true};
CheckResult result = check(R"(
type X = { x: number }
type Y = { y: number }

View file

@ -26,6 +26,7 @@ function f(...)
end
end
assert(pcall(tonumber) == false)
assert(tonumber{} == nil)
assert(tonumber'+0.01' == 1/100 and tonumber'+.01' == 0.01 and
tonumber'.01' == 0.01 and tonumber'-1.' == -1 and