luau/Analysis/src/NonStrictTypeChecker.cpp

620 lines
18 KiB
C++
Raw Normal View History

2023-09-30 01:22:06 +01:00
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/NonStrictTypeChecker.h"
2023-10-13 20:38:31 +01:00
#include "Luau/Ast.h"
#include "Luau/Common.h"
2023-11-10 18:05:48 +00:00
#include "Luau/Simplify.h"
2023-09-30 01:22:06 +01:00
#include "Luau/Type.h"
2023-11-10 18:05:48 +00:00
#include "Luau/Simplify.h"
2023-09-30 01:22:06 +01:00
#include "Luau/Subtyping.h"
#include "Luau/Normalize.h"
#include "Luau/Error.h"
#include "Luau/TypeArena.h"
2023-10-13 20:38:31 +01:00
#include "Luau/TypeFamily.h"
2023-09-30 01:22:06 +01:00
#include "Luau/Def.h"
2023-10-13 20:38:31 +01:00
#include <iostream>
2023-09-30 01:22:06 +01:00
namespace Luau
{
2023-10-13 20:38:31 +01:00
/* Push a scope onto the end of a stack for the lifetime of the StackPusher instance.
 * NonStrictTypeChecker uses this to maintain knowledge about which scope encloses every
 * given AstNode.
 *
 * StackPusher is move-only. Moving from an instance nulls out its `stack` pointer, so the
 * moved-from object's destructor does nothing and exactly one live instance pops the scope.
 */
struct StackPusher
{
    // Stack being maintained; null once this instance has been moved from.
    std::vector<NotNull<Scope>>* stack;
    // The scope this instance pushed; the destructor asserts it is still on top before popping.
    NotNull<Scope> scope;

    explicit StackPusher(std::vector<NotNull<Scope>>& stack, Scope* scope)
        : stack(&stack)
        , scope(scope)
    {
        stack.push_back(NotNull{scope});
    }

    ~StackPusher()
    {
        if (stack)
        {
            // Pushes and pops must nest strictly (LIFO).
            LUAU_ASSERT(stack->back() == scope);
            stack->pop_back();
        }
    }

    StackPusher(const StackPusher&) = delete;
    // Conventional signature for a (deleted) copy-assignment operator: returns an lvalue reference.
    StackPusher& operator=(const StackPusher&) = delete;

    // Only exchanges/copies pointers, so it cannot throw.
    StackPusher(StackPusher&& other) noexcept
        : stack(std::exchange(other.stack, nullptr))
        , scope(other.scope)
    {
    }
};
2023-09-30 01:22:06 +01:00
struct NonStrictContext
{
2023-10-13 20:38:31 +01:00
std::unordered_map<const Def*, TypeId> context;
2023-09-30 01:22:06 +01:00
NonStrictContext() = default;
NonStrictContext(const NonStrictContext&) = delete;
NonStrictContext& operator=(const NonStrictContext&) = delete;
NonStrictContext(NonStrictContext&&) = default;
NonStrictContext& operator=(NonStrictContext&&) = default;
2023-11-10 18:05:48 +00:00
static NonStrictContext disjunction(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
2023-09-30 01:22:06 +01:00
{
2023-11-10 18:05:48 +00:00
// disjunction implements union over the domain of keys
// if the default value for a defId not in the map is `never`
// then never | T is T
NonStrictContext disj{};
for (auto [def, leftTy] : left.context)
{
if (std::optional<TypeId> rightTy = right.find(def))
disj.context[def] = simplifyUnion(builtinTypes, arena, leftTy, *rightTy).result;
else
disj.context[def] = leftTy;
}
for (auto [def, rightTy] : right.context)
{
2023-11-17 18:15:31 +00:00
if (!left.find(def).has_value())
2023-11-10 18:05:48 +00:00
disj.context[def] = rightTy;
}
return disj;
2023-09-30 01:22:06 +01:00
}
2023-11-10 18:05:48 +00:00
static NonStrictContext conjunction(
NotNull<BuiltinTypes> builtins, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
2023-09-30 01:22:06 +01:00
{
2023-11-10 18:05:48 +00:00
NonStrictContext conj{};
for (auto [def, leftTy] : left.context)
{
if (std::optional<TypeId> rightTy = right.find(def))
conj.context[def] = simplifyIntersection(builtins, arena, leftTy, *rightTy).result;
}
return conj;
2023-09-30 01:22:06 +01:00
}
void removeFromContext(const std::vector<DefId>& defs)
{
// TODO: unimplemented
}
2023-10-13 20:38:31 +01:00
std::optional<TypeId> find(const DefId& def) const
2023-09-30 01:22:06 +01:00
{
2023-10-13 20:38:31 +01:00
const Def* d = def.get();
2023-11-10 18:05:48 +00:00
return find(d);
}
private:
std::optional<TypeId> find(const Def* d) const
{
2023-10-13 20:38:31 +01:00
auto it = context.find(d);
if (it != context.end())
return {it->second};
2023-09-30 01:22:06 +01:00
return {};
}
};
struct NonStrictTypeChecker
{
NotNull<BuiltinTypes> builtinTypes;
const NotNull<InternalErrorReporter> ice;
TypeArena arena;
Module* module;
Normalizer normalizer;
Subtyping subtyping;
2023-10-13 20:38:31 +01:00
NotNull<const DataFlowGraph> dfg;
DenseHashSet<TypeId> noTypeFamilyErrors{nullptr};
std::vector<NotNull<Scope>> stack;
2023-09-30 01:22:06 +01:00
2023-10-13 20:38:31 +01:00
const NotNull<TypeCheckLimits> limits;
2023-09-30 01:22:06 +01:00
2023-10-13 20:38:31 +01:00
NonStrictTypeChecker(NotNull<BuiltinTypes> builtinTypes, const NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, Module* module)
2023-09-30 01:22:06 +01:00
: builtinTypes(builtinTypes)
, ice(ice)
, module(module)
, normalizer{&arena, builtinTypes, unifierState, /* cache inhabitance */ true}
, subtyping{builtinTypes, NotNull{&arena}, NotNull(&normalizer), ice, NotNull{module->getModuleScope().get()}}
2023-10-13 20:38:31 +01:00
, dfg(dfg)
, limits(limits)
{
}
std::optional<StackPusher> pushStack(AstNode* node)
{
if (Scope** scope = module->astScopes.find(node))
return StackPusher{stack, *scope};
else
return std::nullopt;
}
TypeId flattenPack(TypePackId pack)
{
pack = follow(pack);
if (auto fst = first(pack, /*ignoreHiddenVariadics*/ false))
return *fst;
else if (auto ftp = get<FreeTypePack>(pack))
{
TypeId result = arena.addType(FreeType{ftp->scope});
TypePackId freeTail = arena.addTypePack(FreeTypePack{ftp->scope});
TypePack& resultPack = asMutable(pack)->ty.emplace<TypePack>();
resultPack.head.assign(1, result);
resultPack.tail = freeTail;
return result;
}
else if (get<Unifiable::Error>(pack))
return builtinTypes->errorRecoveryType();
else if (finite(pack) && size(pack) == 0)
return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil`
else
ice->ice("flattenPack got a weird pack!");
}
TypeId checkForFamilyInhabitance(TypeId instance, Location location)
{
if (noTypeFamilyErrors.find(instance))
return instance;
ErrorVec errors = reduceFamilies(
instance, location, TypeFamilyContext{NotNull{&arena}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
.errors;
if (errors.empty())
noTypeFamilyErrors.insert(instance);
// TODO??
// if (!isErrorSuppressing(location, instance))
// reportErrors(std::move(errors));
return instance;
}
TypeId lookupType(AstExpr* expr)
{
TypeId* ty = module->astTypes.find(expr);
if (ty)
return checkForFamilyInhabitance(follow(*ty), expr->location);
TypePackId* tp = module->astTypePacks.find(expr);
if (tp)
return checkForFamilyInhabitance(flattenPack(*tp), expr->location);
return builtinTypes->anyType;
}
2023-11-10 18:05:48 +00:00
NonStrictContext visit(AstStat* stat)
2023-10-13 20:38:31 +01:00
{
auto pusher = pushStack(stat);
if (auto s = stat->as<AstStatBlock>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatIf>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatWhile>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatRepeat>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatBreak>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatContinue>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatReturn>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatExpr>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatLocal>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatFor>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatForIn>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatAssign>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatCompoundAssign>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatFunction>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatLocalFunction>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatTypeAlias>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatDeclareFunction>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatDeclareGlobal>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatDeclareClass>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else if (auto s = stat->as<AstStatError>())
2023-11-10 18:05:48 +00:00
return visit(s);
2023-10-13 20:38:31 +01:00
else
2023-11-10 18:05:48 +00:00
{
LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown statement type");
ice->ice("NonStrictTypeChecker encountered an unknown statement type");
}
2023-10-13 20:38:31 +01:00
}
2023-11-10 18:05:48 +00:00
NonStrictContext visit(AstStatBlock* block)
2023-10-13 20:38:31 +01:00
{
auto StackPusher = pushStack(block);
2023-11-17 18:15:31 +00:00
NonStrictContext ctx;
2023-10-13 20:38:31 +01:00
for (AstStat* statement : block->body)
2023-11-17 18:15:31 +00:00
ctx = NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, ctx, visit(statement));
return ctx;
2023-11-10 18:05:48 +00:00
}
NonStrictContext visit(AstStatIf* ifStatement)
{
NonStrictContext condB = visit(ifStatement->condition);
2023-11-17 18:15:31 +00:00
NonStrictContext branchContext;
// If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error
if (ifStatement->elsebody)
{
NonStrictContext thenBody = visit(ifStatement->thenbody);
NonStrictContext elseBody = visit(ifStatement->elsebody);
branchContext = NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenBody, elseBody);
}
return NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, condB, branchContext);
2023-11-10 18:05:48 +00:00
}
NonStrictContext visit(AstStatWhile* whileStatement)
{
return {};
}
NonStrictContext visit(AstStatRepeat* repeatStatement)
{
return {};
}
NonStrictContext visit(AstStatBreak* breakStatement)
{
return {};
}
NonStrictContext visit(AstStatContinue* continueStatement)
{
return {};
}
NonStrictContext visit(AstStatReturn* returnStatement)
{
return {};
}
NonStrictContext visit(AstStatExpr* expr)
{
return visit(expr->expr);
}
NonStrictContext visit(AstStatLocal* local)
{
2023-11-17 18:15:31 +00:00
for (AstExpr* rhs : local->values)
visit(rhs);
2023-11-10 18:05:48 +00:00
return {};
}
NonStrictContext visit(AstStatFor* forStatement)
{
return {};
}
NonStrictContext visit(AstStatForIn* forInStatement)
{
return {};
}
NonStrictContext visit(AstStatAssign* assign)
{
return {};
}
NonStrictContext visit(AstStatCompoundAssign* compoundAssign)
{
return {};
}
NonStrictContext visit(AstStatFunction* statFn)
{
2023-11-17 18:15:31 +00:00
return visit(statFn->func);
2023-11-10 18:05:48 +00:00
}
NonStrictContext visit(AstStatLocalFunction* localFn)
{
2023-11-17 18:15:31 +00:00
return visit(localFn->func);
2023-11-10 18:05:48 +00:00
}
NonStrictContext visit(AstStatTypeAlias* typeAlias)
{
return {};
}
NonStrictContext visit(AstStatDeclareFunction* declFn)
{
return {};
}
NonStrictContext visit(AstStatDeclareGlobal* declGlobal)
{
return {};
}
NonStrictContext visit(AstStatDeclareClass* declClass)
{
return {};
}
NonStrictContext visit(AstStatError* error)
{
return {};
}
NonStrictContext visit(AstExpr* expr)
2023-10-13 20:38:31 +01:00
{
auto pusher = pushStack(expr);
if (auto e = expr->as<AstExprGroup>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprConstantNil>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprConstantBool>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprConstantNumber>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprConstantString>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprLocal>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprGlobal>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprVarargs>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprCall>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprIndexName>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprIndexExpr>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprFunction>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprTable>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprUnary>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprBinary>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprTypeAssertion>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprIfElse>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprInterpString>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else if (auto e = expr->as<AstExprError>())
2023-11-10 18:05:48 +00:00
return visit(e);
2023-10-13 20:38:31 +01:00
else
2023-11-10 18:05:48 +00:00
{
2023-10-13 20:38:31 +01:00
LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown expression type");
2023-11-10 18:05:48 +00:00
ice->ice("NonStrictTypeChecker encountered an unknown expression type");
}
}
NonStrictContext visit(AstExprGroup* group)
{
return {};
}
NonStrictContext visit(AstExprConstantNil* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantBool* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantNumber* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantString* expr)
{
return {};
}
NonStrictContext visit(AstExprLocal* local)
{
return {};
}
NonStrictContext visit(AstExprGlobal* global)
{
return {};
2023-10-13 20:38:31 +01:00
}
2023-11-10 18:05:48 +00:00
NonStrictContext visit(AstExprVarargs* global)
{
return {};
}
2023-10-13 20:38:31 +01:00
2023-11-10 18:05:48 +00:00
NonStrictContext visit(AstExprCall* call)
2023-10-13 20:38:31 +01:00
{
2023-11-10 18:05:48 +00:00
NonStrictContext fresh{};
2023-10-13 20:38:31 +01:00
TypeId* originalCallTy = module->astOriginalCallTypes.find(call);
if (!originalCallTy)
2023-11-10 18:05:48 +00:00
return fresh;
2023-10-13 20:38:31 +01:00
TypeId fnTy = *originalCallTy;
if (auto fn = get<FunctionType>(follow(fnTy)))
{
if (fn->isCheckedFunction)
{
// We know fn is a checked function, which means it looks like:
// (S1, ... SN) -> T &
// (~S1, unknown^N-1) -> error &
// (unknown, ~S2, unknown^N-2) -> error
// ...
// ...
// (unknown^N-1, ~S_N) -> error
std::vector<TypeId> argTypes;
for (TypeId ty : fn->argTypes)
argTypes.push_back(ty);
// For a checked function, these gotta be the same size
LUAU_ASSERT(call->args.size == argTypes.size());
for (size_t i = 0; i < call->args.size; i++)
{
// For example, if the arg is "hi"
// The actual arg type is string
// The expected arg type is number
// The type of the argument in the overload is ~number
// We will compare arg and ~number
AstExpr* arg = call->args.data[i];
TypeId expectedArgType = argTypes[i];
2023-10-20 21:36:26 +01:00
DefId def = dfg->getDef(arg);
2023-10-13 20:38:31 +01:00
// TODO: Cache negations created here!!!
// See Jira Ticket: https://roblox.atlassian.net/browse/CLI-87539
2023-10-20 21:36:26 +01:00
TypeId runTimeErrorTy = arena.addType(NegationType{expectedArgType});
fresh.context[def.get()] = runTimeErrorTy;
2023-10-13 20:38:31 +01:00
}
// Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types
2023-10-20 21:36:26 +01:00
AstName name = getIdentifier(call->func);
2023-10-13 20:38:31 +01:00
for (size_t i = 0; i < call->args.size; i++)
{
AstExpr* arg = call->args.data[i];
if (auto runTimeFailureType = willRunTimeError(arg, fresh))
2023-10-20 21:36:26 +01:00
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, name.value, i}, arg->location);
2023-10-13 20:38:31 +01:00
}
}
}
2023-11-10 18:05:48 +00:00
return fresh;
}
NonStrictContext visit(AstExprIndexName* indexName)
{
return {};
2023-10-13 20:38:31 +01:00
}
2023-11-10 18:05:48 +00:00
NonStrictContext visit(AstExprIndexExpr* indexExpr)
{
return {};
}
NonStrictContext visit(AstExprFunction* exprFn)
2023-10-13 20:38:31 +01:00
{
auto pusher = pushStack(exprFn);
2023-11-17 18:15:31 +00:00
return visit(exprFn->body);
2023-11-10 18:05:48 +00:00
}
NonStrictContext visit(AstExprTable* table)
{
return {};
}
NonStrictContext visit(AstExprUnary* unary)
{
return {};
}
NonStrictContext visit(AstExprBinary* binary)
{
return {};
}
NonStrictContext visit(AstExprTypeAssertion* typeAssertion)
{
return {};
}
NonStrictContext visit(AstExprIfElse* ifElse)
{
NonStrictContext condB = visit(ifElse->condition);
NonStrictContext thenB = visit(ifElse->trueExpr);
NonStrictContext elseB = visit(ifElse->falseExpr);
return NonStrictContext::disjunction(
builtinTypes, NotNull{&arena}, condB, NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenB, elseB));
}
NonStrictContext visit(AstExprInterpString* interpString)
{
return {};
}
NonStrictContext visit(AstExprError* error)
{
return {};
2023-10-13 20:38:31 +01:00
}
void reportError(TypeErrorData data, const Location& location)
2023-09-30 01:22:06 +01:00
{
2023-10-13 20:38:31 +01:00
module->errors.emplace_back(location, module->name, std::move(data));
// TODO: weave in logger here?
}
// If this fragment of the ast will run time error, return the type that causes this
std::optional<TypeId> willRunTimeError(AstExpr* fragment, const NonStrictContext& context)
{
2023-10-20 21:36:26 +01:00
DefId def = dfg->getDef(fragment);
if (std::optional<TypeId> contextTy = context.find(def))
2023-10-13 20:38:31 +01:00
{
2023-10-20 21:36:26 +01:00
TypeId actualType = lookupType(fragment);
SubtypingResult r = subtyping.isSubtype(actualType, *contextTy);
if (r.normalizationTooComplex)
reportError(NormalizationTooComplex{}, fragment->location);
if (r.isSubtype)
return {actualType};
2023-10-13 20:38:31 +01:00
}
2023-10-20 21:36:26 +01:00
2023-10-13 20:38:31 +01:00
return {};
2023-09-30 01:22:06 +01:00
}
};
2023-10-13 20:38:31 +01:00
/* Runs the non-strict checker over `sourceModule.root`, appending any diagnostics to
 * `module->errors`. It then calls copyErrors with `module->interfaceTypes`, which is unfrozen
 * for the call and frozen again afterwards.
 * NOTE(review): copyErrors presumably copies the types referenced by the errors into the
 * interface arena; confirm against its definition.
 */
void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
    NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, const SourceModule& sourceModule, Module* module)
{
    NonStrictTypeChecker typeChecker{builtinTypes, ice, unifierState, dfg, limits, module};
    typeChecker.visit(sourceModule.root);

    unfreeze(module->interfaceTypes);
    copyErrors(module->errors, module->interfaceTypes, builtinTypes);
    freeze(module->interfaceTypes);
}
2023-10-13 20:38:31 +01:00
2023-09-30 01:22:06 +01:00
} // namespace Luau