luau/Analysis/src/NonStrictTypeChecker.cpp
Andy Friesen 74c532053f
Sync to upstream/release/604 (#1106)
New Solver

* New algorithm for inferring the types of locals that have no
annotations. This
algorithm is very conservative by default, but is augmented with some
control
  flow awareness to handle most common scenarios.
* Fix bugs in type inference of tables
* Improve performance of by switching out standard C++ containers for
`DenseHashMap`
* Infrastructure to support clearer error messages in strict mode

Native Code Generation

* Fix a lowering issue with buffer.writeu8 and 0x80-0xff values: A
constant
  argument wasn't truncated to the target type range and that causes an
  assertion failure in `build.mov`.
* Store full lightuserdata value in loop iteration protocol lowering
* Add analysis to compute function bytecode distribution
* This includes a class to analyze the bytecode operator distribution
per
function and a CLI tool that produces a JSON report. See the new cmake
      target `Luau.Bytecode.CLI`

---------

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Alexander McCord <amccord@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Lily Brown <lbrown@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
2023-11-17 10:46:18 -08:00

619 lines
18 KiB
C++

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/NonStrictTypeChecker.h"
#include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/Simplify.h"
#include "Luau/Type.h"
#include "Luau/Simplify.h"
#include "Luau/Subtyping.h"
#include "Luau/Normalize.h"
#include "Luau/Error.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFamily.h"
#include "Luau/Def.h"
#include <iostream>
namespace Luau
{
/* Push a scope onto the end of a stack for the lifetime of the StackPusher instance.
* NonStrictTypeChecker uses this to maintain knowledge about which scope encloses every
* given AstNode.
*/
struct StackPusher
{
std::vector<NotNull<Scope>>* stack;
NotNull<Scope> scope;
explicit StackPusher(std::vector<NotNull<Scope>>& stack, Scope* scope)
: stack(&stack)
, scope(scope)
{
stack.push_back(NotNull{scope});
}
~StackPusher()
{
if (stack)
{
LUAU_ASSERT(stack->back() == scope);
stack->pop_back();
}
}
StackPusher(const StackPusher&) = delete;
StackPusher&& operator=(const StackPusher&) = delete;
StackPusher(StackPusher&& other)
: stack(std::exchange(other.stack, nullptr))
, scope(other.scope)
{
}
};
struct NonStrictContext
{
std::unordered_map<const Def*, TypeId> context;
NonStrictContext() = default;
NonStrictContext(const NonStrictContext&) = delete;
NonStrictContext& operator=(const NonStrictContext&) = delete;
NonStrictContext(NonStrictContext&&) = default;
NonStrictContext& operator=(NonStrictContext&&) = default;
static NonStrictContext disjunction(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
{
// disjunction implements union over the domain of keys
// if the default value for a defId not in the map is `never`
// then never | T is T
NonStrictContext disj{};
for (auto [def, leftTy] : left.context)
{
if (std::optional<TypeId> rightTy = right.find(def))
disj.context[def] = simplifyUnion(builtinTypes, arena, leftTy, *rightTy).result;
else
disj.context[def] = leftTy;
}
for (auto [def, rightTy] : right.context)
{
if (!left.find(def).has_value())
disj.context[def] = rightTy;
}
return disj;
}
static NonStrictContext conjunction(
NotNull<BuiltinTypes> builtins, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
{
NonStrictContext conj{};
for (auto [def, leftTy] : left.context)
{
if (std::optional<TypeId> rightTy = right.find(def))
conj.context[def] = simplifyIntersection(builtins, arena, leftTy, *rightTy).result;
}
return conj;
}
void removeFromContext(const std::vector<DefId>& defs)
{
// TODO: unimplemented
}
std::optional<TypeId> find(const DefId& def) const
{
const Def* d = def.get();
return find(d);
}
private:
std::optional<TypeId> find(const Def* d) const
{
auto it = context.find(d);
if (it != context.end())
return {it->second};
return {};
}
};
struct NonStrictTypeChecker
{
NotNull<BuiltinTypes> builtinTypes;
const NotNull<InternalErrorReporter> ice;
TypeArena arena;
Module* module;
Normalizer normalizer;
Subtyping subtyping;
NotNull<const DataFlowGraph> dfg;
DenseHashSet<TypeId> noTypeFamilyErrors{nullptr};
std::vector<NotNull<Scope>> stack;
const NotNull<TypeCheckLimits> limits;
NonStrictTypeChecker(NotNull<BuiltinTypes> builtinTypes, const NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, Module* module)
: builtinTypes(builtinTypes)
, ice(ice)
, module(module)
, normalizer{&arena, builtinTypes, unifierState, /* cache inhabitance */ true}
, subtyping{builtinTypes, NotNull{&arena}, NotNull(&normalizer), ice, NotNull{module->getModuleScope().get()}}
, dfg(dfg)
, limits(limits)
{
}
std::optional<StackPusher> pushStack(AstNode* node)
{
if (Scope** scope = module->astScopes.find(node))
return StackPusher{stack, *scope};
else
return std::nullopt;
}
TypeId flattenPack(TypePackId pack)
{
pack = follow(pack);
if (auto fst = first(pack, /*ignoreHiddenVariadics*/ false))
return *fst;
else if (auto ftp = get<FreeTypePack>(pack))
{
TypeId result = arena.addType(FreeType{ftp->scope});
TypePackId freeTail = arena.addTypePack(FreeTypePack{ftp->scope});
TypePack& resultPack = asMutable(pack)->ty.emplace<TypePack>();
resultPack.head.assign(1, result);
resultPack.tail = freeTail;
return result;
}
else if (get<Unifiable::Error>(pack))
return builtinTypes->errorRecoveryType();
else if (finite(pack) && size(pack) == 0)
return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil`
else
ice->ice("flattenPack got a weird pack!");
}
TypeId checkForFamilyInhabitance(TypeId instance, Location location)
{
if (noTypeFamilyErrors.find(instance))
return instance;
ErrorVec errors = reduceFamilies(
instance, location, TypeFamilyContext{NotNull{&arena}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
.errors;
if (errors.empty())
noTypeFamilyErrors.insert(instance);
// TODO??
// if (!isErrorSuppressing(location, instance))
// reportErrors(std::move(errors));
return instance;
}
TypeId lookupType(AstExpr* expr)
{
TypeId* ty = module->astTypes.find(expr);
if (ty)
return checkForFamilyInhabitance(follow(*ty), expr->location);
TypePackId* tp = module->astTypePacks.find(expr);
if (tp)
return checkForFamilyInhabitance(flattenPack(*tp), expr->location);
return builtinTypes->anyType;
}
NonStrictContext visit(AstStat* stat)
{
auto pusher = pushStack(stat);
if (auto s = stat->as<AstStatBlock>())
return visit(s);
else if (auto s = stat->as<AstStatIf>())
return visit(s);
else if (auto s = stat->as<AstStatWhile>())
return visit(s);
else if (auto s = stat->as<AstStatRepeat>())
return visit(s);
else if (auto s = stat->as<AstStatBreak>())
return visit(s);
else if (auto s = stat->as<AstStatContinue>())
return visit(s);
else if (auto s = stat->as<AstStatReturn>())
return visit(s);
else if (auto s = stat->as<AstStatExpr>())
return visit(s);
else if (auto s = stat->as<AstStatLocal>())
return visit(s);
else if (auto s = stat->as<AstStatFor>())
return visit(s);
else if (auto s = stat->as<AstStatForIn>())
return visit(s);
else if (auto s = stat->as<AstStatAssign>())
return visit(s);
else if (auto s = stat->as<AstStatCompoundAssign>())
return visit(s);
else if (auto s = stat->as<AstStatFunction>())
return visit(s);
else if (auto s = stat->as<AstStatLocalFunction>())
return visit(s);
else if (auto s = stat->as<AstStatTypeAlias>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareFunction>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareGlobal>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareClass>())
return visit(s);
else if (auto s = stat->as<AstStatError>())
return visit(s);
else
{
LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown statement type");
ice->ice("NonStrictTypeChecker encountered an unknown statement type");
}
}
NonStrictContext visit(AstStatBlock* block)
{
auto StackPusher = pushStack(block);
NonStrictContext ctx;
for (AstStat* statement : block->body)
ctx = NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, ctx, visit(statement));
return ctx;
}
NonStrictContext visit(AstStatIf* ifStatement)
{
NonStrictContext condB = visit(ifStatement->condition);
NonStrictContext branchContext;
// If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error
if (ifStatement->elsebody)
{
NonStrictContext thenBody = visit(ifStatement->thenbody);
NonStrictContext elseBody = visit(ifStatement->elsebody);
branchContext = NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenBody, elseBody);
}
return NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, condB, branchContext);
}
NonStrictContext visit(AstStatWhile* whileStatement)
{
return {};
}
NonStrictContext visit(AstStatRepeat* repeatStatement)
{
return {};
}
NonStrictContext visit(AstStatBreak* breakStatement)
{
return {};
}
NonStrictContext visit(AstStatContinue* continueStatement)
{
return {};
}
NonStrictContext visit(AstStatReturn* returnStatement)
{
return {};
}
NonStrictContext visit(AstStatExpr* expr)
{
return visit(expr->expr);
}
NonStrictContext visit(AstStatLocal* local)
{
for (AstExpr* rhs : local->values)
visit(rhs);
return {};
}
NonStrictContext visit(AstStatFor* forStatement)
{
return {};
}
NonStrictContext visit(AstStatForIn* forInStatement)
{
return {};
}
NonStrictContext visit(AstStatAssign* assign)
{
return {};
}
NonStrictContext visit(AstStatCompoundAssign* compoundAssign)
{
return {};
}
NonStrictContext visit(AstStatFunction* statFn)
{
return visit(statFn->func);
}
NonStrictContext visit(AstStatLocalFunction* localFn)
{
return visit(localFn->func);
}
NonStrictContext visit(AstStatTypeAlias* typeAlias)
{
return {};
}
NonStrictContext visit(AstStatDeclareFunction* declFn)
{
return {};
}
NonStrictContext visit(AstStatDeclareGlobal* declGlobal)
{
return {};
}
NonStrictContext visit(AstStatDeclareClass* declClass)
{
return {};
}
NonStrictContext visit(AstStatError* error)
{
return {};
}
NonStrictContext visit(AstExpr* expr)
{
auto pusher = pushStack(expr);
if (auto e = expr->as<AstExprGroup>())
return visit(e);
else if (auto e = expr->as<AstExprConstantNil>())
return visit(e);
else if (auto e = expr->as<AstExprConstantBool>())
return visit(e);
else if (auto e = expr->as<AstExprConstantNumber>())
return visit(e);
else if (auto e = expr->as<AstExprConstantString>())
return visit(e);
else if (auto e = expr->as<AstExprLocal>())
return visit(e);
else if (auto e = expr->as<AstExprGlobal>())
return visit(e);
else if (auto e = expr->as<AstExprVarargs>())
return visit(e);
else if (auto e = expr->as<AstExprCall>())
return visit(e);
else if (auto e = expr->as<AstExprIndexName>())
return visit(e);
else if (auto e = expr->as<AstExprIndexExpr>())
return visit(e);
else if (auto e = expr->as<AstExprFunction>())
return visit(e);
else if (auto e = expr->as<AstExprTable>())
return visit(e);
else if (auto e = expr->as<AstExprUnary>())
return visit(e);
else if (auto e = expr->as<AstExprBinary>())
return visit(e);
else if (auto e = expr->as<AstExprTypeAssertion>())
return visit(e);
else if (auto e = expr->as<AstExprIfElse>())
return visit(e);
else if (auto e = expr->as<AstExprInterpString>())
return visit(e);
else if (auto e = expr->as<AstExprError>())
return visit(e);
else
{
LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown expression type");
ice->ice("NonStrictTypeChecker encountered an unknown expression type");
}
}
NonStrictContext visit(AstExprGroup* group)
{
return {};
}
NonStrictContext visit(AstExprConstantNil* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantBool* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantNumber* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantString* expr)
{
return {};
}
NonStrictContext visit(AstExprLocal* local)
{
return {};
}
NonStrictContext visit(AstExprGlobal* global)
{
return {};
}
NonStrictContext visit(AstExprVarargs* global)
{
return {};
}
NonStrictContext visit(AstExprCall* call)
{
NonStrictContext fresh{};
TypeId* originalCallTy = module->astOriginalCallTypes.find(call);
if (!originalCallTy)
return fresh;
TypeId fnTy = *originalCallTy;
if (auto fn = get<FunctionType>(follow(fnTy)))
{
if (fn->isCheckedFunction)
{
// We know fn is a checked function, which means it looks like:
// (S1, ... SN) -> T &
// (~S1, unknown^N-1) -> error &
// (unknown, ~S2, unknown^N-2) -> error
// ...
// ...
// (unknown^N-1, ~S_N) -> error
std::vector<TypeId> argTypes;
for (TypeId ty : fn->argTypes)
argTypes.push_back(ty);
// For a checked function, these gotta be the same size
LUAU_ASSERT(call->args.size == argTypes.size());
for (size_t i = 0; i < call->args.size; i++)
{
// For example, if the arg is "hi"
// The actual arg type is string
// The expected arg type is number
// The type of the argument in the overload is ~number
// We will compare arg and ~number
AstExpr* arg = call->args.data[i];
TypeId expectedArgType = argTypes[i];
DefId def = dfg->getDef(arg);
// TODO: Cache negations created here!!!
// See Jira Ticket: https://roblox.atlassian.net/browse/CLI-87539
TypeId runTimeErrorTy = arena.addType(NegationType{expectedArgType});
fresh.context[def.get()] = runTimeErrorTy;
}
// Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types
AstName name = getIdentifier(call->func);
for (size_t i = 0; i < call->args.size; i++)
{
AstExpr* arg = call->args.data[i];
if (auto runTimeFailureType = willRunTimeError(arg, fresh))
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, name.value, i}, arg->location);
}
}
}
return fresh;
}
NonStrictContext visit(AstExprIndexName* indexName)
{
return {};
}
NonStrictContext visit(AstExprIndexExpr* indexExpr)
{
return {};
}
NonStrictContext visit(AstExprFunction* exprFn)
{
auto pusher = pushStack(exprFn);
return visit(exprFn->body);
}
NonStrictContext visit(AstExprTable* table)
{
return {};
}
NonStrictContext visit(AstExprUnary* unary)
{
return {};
}
NonStrictContext visit(AstExprBinary* binary)
{
return {};
}
NonStrictContext visit(AstExprTypeAssertion* typeAssertion)
{
return {};
}
NonStrictContext visit(AstExprIfElse* ifElse)
{
NonStrictContext condB = visit(ifElse->condition);
NonStrictContext thenB = visit(ifElse->trueExpr);
NonStrictContext elseB = visit(ifElse->falseExpr);
return NonStrictContext::disjunction(
builtinTypes, NotNull{&arena}, condB, NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenB, elseB));
}
NonStrictContext visit(AstExprInterpString* interpString)
{
return {};
}
NonStrictContext visit(AstExprError* error)
{
return {};
}
void reportError(TypeErrorData data, const Location& location)
{
module->errors.emplace_back(location, module->name, std::move(data));
// TODO: weave in logger here?
}
// If this fragment of the ast will run time error, return the type that causes this
std::optional<TypeId> willRunTimeError(AstExpr* fragment, const NonStrictContext& context)
{
DefId def = dfg->getDef(fragment);
if (std::optional<TypeId> contextTy = context.find(def))
{
TypeId actualType = lookupType(fragment);
SubtypingResult r = subtyping.isSubtype(actualType, *contextTy);
if (r.normalizationTooComplex)
reportError(NormalizationTooComplex{}, fragment->location);
if (r.isSubtype)
return {actualType};
}
return {};
}
};
void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, const SourceModule& sourceModule, Module* module)
{
// TODO: unimplemented
NonStrictTypeChecker typeChecker{builtinTypes, ice, unifierState, dfg, limits, module};
typeChecker.visit(sourceModule.root);
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes, builtinTypes);
freeze(module->interfaceTypes);
}
} // namespace Luau