luau/Analysis/include/Luau/ConstraintGenerator.h
vegorov-rbx f5dabc2998
Sync to upstream/release/644 (#1432)
In this update we improve overall stability of the new type solver and
address some type inference issues with it.

If you use the new solver and want to use all new fixes included in this
release, you have to reference an additional Luau flag:
```c++
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
```
And set its value to `644`:
```c++
DFInt::LuauTypeSolverRelease.value = 644; // Or a higher value for future updates
```

## New Solver
* Fixed a debug assertion failure in autocomplete (Fixes #1391)
* Fixed type function distribution issue which transformed `len<>` and
`unm<>` into `not<>` (Fixes #1416)
* Placed a limit on the possible normalized table intersection size as a
temporary measure to avoid hangs and out-of-memory issues for complex
type refinements
* Internal recursion limits are now respected in the subtyping
operations and in autocomplete, to avoid stack overflow crashes
* Fixed false positive errors on assignments to tables whose indexers
are unions of strings
* Fixed memory corruption crashes in subtyping of generic types
containing other generic types in their bounds

---

Internal Contributors:

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
2024-09-20 09:53:26 -07:00

436 lines
18 KiB
C++

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Constraint.h"
#include "Luau/ControlFlow.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Normalize.h"
#include "Luau/NotNull.h"
#include "Luau/Refinement.h"
#include "Luau/Symbol.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h"
#include "Luau/Variant.h"
#include "Luau/Normalize.h"
#include <memory>
#include <vector>
#include <unordered_map>
namespace Luau
{
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct DcrLogger;
struct Inference
{
TypeId ty = nullptr;
RefinementId refinement = nullptr;
Inference() = default;
explicit Inference(TypeId ty, RefinementId refinement = nullptr)
: ty(ty)
, refinement(refinement)
{
}
};
struct InferencePack
{
TypePackId tp = nullptr;
std::vector<RefinementId> refinements;
InferencePack() = default;
explicit InferencePack(TypePackId tp, const std::vector<RefinementId>& refinements = {})
: tp(tp)
, refinements(refinements)
{
}
};
struct ConstraintGenerator
{
// A list of all the scopes in the module. This vector holds ownership of the
// scope pointers; the scopes themselves borrow pointers to other scopes to
// define the scope hierarchy.
std::vector<std::pair<Location, ScopePtr>> scopes;
ModulePtr module;
NotNull<BuiltinTypes> builtinTypes;
const NotNull<TypeArena> arena;
// The root scope of the module we're generating constraints for.
// This is null when the CG is initially constructed.
Scope* rootScope;
TypeContext typeContext = TypeContext::Default;
struct InferredBinding
{
Scope* scope;
Location location;
TypeIds types;
};
// Some locals have multiple type states. We wish for Scope::bindings to
// map each local name onto the union of every type that the local can have
// over its lifetime, so we use this map to accumulate the set of types it
// might have.
//
// See the functions recordInferredBinding and fillInInferredBindings.
DenseHashMap<Symbol, InferredBinding> inferredBindings{{}};
// Constraints that go straight to the solver.
std::vector<ConstraintPtr> constraints;
// Constraints that do not go to the solver right away. Other constraints
// will enqueue them during solving.
std::vector<ConstraintPtr> unqueuedConstraints;
// The private scope of type aliases for which the type parameters belong to.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
NotNull<const DataFlowGraph> dfg;
RefinementArena refinementArena;
int recursionCount = 0;
// It is pretty uncommon for constraint generation to itself produce errors, but it can happen.
std::vector<TypeError> errors;
// Needed to be able to enable error-suppression preservation for immediate refinements.
NotNull<Normalizer> normalizer;
// Needed to resolve modules to make 'require' import types properly.
NotNull<ModuleResolver> moduleResolver;
// Occasionally constraint generation needs to produce an ICE.
const NotNull<InternalErrorReporter> ice;
ScopePtr globalScope;
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
std::vector<RequireCycle> requireCycles;
DenseHashMap<TypeId, TypeIds> localTypes{nullptr};
DcrLogger* logger;
ConstraintGenerator(
ModulePtr module,
NotNull<Normalizer> normalizer,
NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
DcrLogger* logger,
NotNull<DataFlowGraph> dfg,
std::vector<RequireCycle> requireCycles
);
/**
* The entry point to the ConstraintGenerator. This will construct a set
* of scopes, constraints, and free types that can be solved later.
* @param block the root block to generate constraints for.
*/
void visitModuleRoot(AstStatBlock* block);
private:
std::vector<std::vector<TypeId>> interiorTypes;
/**
* Fabricates a new free type belonging to a given scope.
* @param scope the scope the free type belongs to.
*/
TypeId freshType(const ScopePtr& scope);
/**
* Fabricates a new free type pack belonging to a given scope.
* @param scope the scope the free type pack belongs to.
*/
TypePackId freshTypePack(const ScopePtr& scope);
/**
* Allocate a new TypePack with the given head and tail.
*
* Avoids allocating 0-length type packs:
*
* If the head is non-empty, allocate and return a type pack with the given
* head and tail.
* If the head is empty and tail is non-empty, return *tail.
* If both the head and tail are empty, return an empty type pack.
*/
TypePackId addTypePack(std::vector<TypeId> head, std::optional<TypePackId> tail);
/**
* Fabricates a scope that is a child of another scope.
* @param node the lexical node that the scope belongs to.
* @param parent the parent scope of the new scope. Must not be null.
*/
ScopePtr childScope(AstNode* node, const ScopePtr& parent);
std::optional<TypeId> lookup(const ScopePtr& scope, Location location, DefId def, bool prototype = true);
/**
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to.
* @param cv the constraint variant to add.
* @return the pointer to the inserted constraint
*/
NotNull<Constraint> addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv);
/**
* Adds a constraint to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param c the constraint to add.
* @return the pointer to the inserted constraint
*/
NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
struct RefinementPartition
{
// Types that we want to intersect against the type of the expression.
std::vector<TypeId> discriminantTypes;
// Sometimes the type we're discriminating against is implicitly nil.
bool shouldAppendNilType = false;
};
using RefinementContext = InsertionOrderedMap<DefId, RefinementPartition>;
void unionRefinements(
const ScopePtr& scope,
Location location,
const RefinementContext& lhs,
const RefinementContext& rhs,
RefinementContext& dest,
std::vector<ConstraintV>* constraints
);
void computeRefinement(
const ScopePtr& scope,
Location location,
RefinementId refinement,
RefinementContext* refis,
bool sense,
bool eq,
std::vector<ConstraintV>* constraints
);
void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement);
ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visit(const ScopePtr& scope, AstStat* stat);
ControlFlow visit(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visit(const ScopePtr& scope, AstStatLocal* local);
ControlFlow visit(const ScopePtr& scope, AstStatFor* for_);
ControlFlow visit(const ScopePtr& scope, AstStatForIn* forIn);
ControlFlow visit(const ScopePtr& scope, AstStatWhile* while_);
ControlFlow visit(const ScopePtr& scope, AstStatRepeat* repeat);
ControlFlow visit(const ScopePtr& scope, AstStatLocalFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatReturn* ret);
ControlFlow visit(const ScopePtr& scope, AstStatAssign* assign);
ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign);
ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement);
ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias);
ControlFlow visit(const ScopePtr& scope, AstStatTypeFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
ControlFlow visit(const ScopePtr& scope, AstStatError* error);
InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<std::optional<TypeId>>& expectedTypes = {});
InferencePack checkPack(
const ScopePtr& scope,
AstExpr* expr,
const std::vector<std::optional<TypeId>>& expectedTypes = {},
bool generalize = true
);
InferencePack checkPack(const ScopePtr& scope, AstExprCall* call);
/**
* Checks an expression that is expected to evaluate to one type.
* @param scope the scope the expression is contained within.
* @param expr the expression to check.
* @param expectedType the type of the expression that is expected from its
* surrounding context. Used to implement bidirectional type checking.
* @param generalize If true, generalize any lambdas that are encountered.
* @return the type of the expression.
*/
Inference check(
const ScopePtr& scope,
AstExpr* expr,
std::optional<TypeId> expectedType = {},
bool forceSingleton = false,
bool generalize = true
);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprLocal* local);
Inference check(const ScopePtr& scope, AstExprGlobal* global);
Inference checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation);
Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional<TypeId> expectedType, bool generalize);
Inference check(const ScopePtr& scope, AstExprUnary* unary);
Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprInterpString* interpString);
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId rhsType);
struct FunctionSignature
{
// The type of the function.
TypeId signature;
// The scope that encompasses the function's signature. May be nullptr
// if there was no need for a signature scope (the function has no
// generics).
ScopePtr signatureScope;
// The scope that encompasses the function's body. Is a child scope of
// signatureScope, if present.
ScopePtr bodyScope;
};
FunctionSignature checkFunctionSignature(
const ScopePtr& parent,
AstExprFunction* fn,
std::optional<TypeId> expectedType = {},
std::optional<Location> originalName = {}
);
/**
* Checks the body of a function expression.
* @param scope the interior scope of the body of the function.
* @param fn the function expression to check.
*/
void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn);
// Specializations of 'resolveType' below
TypeId resolveReferenceType(const ScopePtr& scope, AstType* ty, AstTypeReference* ref, bool inTypeArguments, bool replaceErrorWithFresh);
TypeId resolveTableType(const ScopePtr& scope, AstType* ty, AstTypeTable* tab, bool inTypeArguments, bool replaceErrorWithFresh);
TypeId resolveFunctionType(const ScopePtr& scope, AstType* ty, AstTypeFunction* fn, bool inTypeArguments, bool replaceErrorWithFresh);
/**
* Resolves a type from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param ty the AST annotation to resolve.
* @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`.
* @return the type of the AST annotation.
**/
TypeId resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh = false);
/**
* Resolves a type pack from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param tp the AST annotation to resolve.
* @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`.
* @return the type pack of the AST annotation.
**/
TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments, bool replaceErrorWithFresh = false);
/**
* Resolves a type pack from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param list the AST annotation to resolve.
* @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`.
* @return the type pack of the AST annotation.
**/
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh = false);
/**
* Creates generic types given a list of AST definitions, resolving default
* types as required.
* @param scope the scope that the generics should belong to.
* @param generics the AST generics to create types for.
* @param useCache whether to use the generic type cache for the given
* scope.
* @param addTypes whether to add the types to the scope's
* privateTypeBindings map.
**/
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(
const ScopePtr& scope,
AstArray<AstGenericType> generics,
bool useCache = false,
bool addTypes = true
);
/**
* Creates generic type packs given a list of AST definitions, resolving
* default type packs as required.
* @param scope the scope that the generic packs should belong to.
* @param generics the AST generics to create type packs for.
* @param useCache whether to use the generic type pack cache for the given
* scope.
* @param addTypes whether to add the types to the scope's
* privateTypePackBindings map.
**/
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(
const ScopePtr& scope,
AstArray<AstGenericTypePack> packs,
bool useCache = false,
bool addTypes = true
);
Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);
void reportError(Location location, TypeErrorData err);
void reportCodeTooComplex(Location location);
// make a union type function of these two types
TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
// make an intersect type function of these two types
TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
/** Scan the program for global definitions.
*
* ConstraintGenerator needs to differentiate between globals and accesses to undefined symbols. Doing this "for
* real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an
* initial scan of the AST and note what globals are defined.
*/
void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);
bool recordPropertyAssignment(TypeId ty);
// Record the fact that a particular local has a particular type in at least
// one of its states.
void recordInferredBinding(AstLocal* local, TypeId ty);
void fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block);
/** Given a function type annotation, return a vector describing the expected types of the calls to the function
* For example, calling a function with annotation ((number) -> string & ((string) -> number))
* yields a vector of size 1, with value: [number | string]
*/
std::vector<std::optional<TypeId>> getExpectedCallTypesForFunctionOverloads(const TypeId fnType);
TypeId createTypeFunctionInstance(
const TypeFunction& function,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments,
const ScopePtr& scope,
Location location
);
};
/** Borrow a vector of pointers from a vector of owning pointers to constraints.
*/
std::vector<NotNull<Constraint>> borrowConstraints(const std::vector<ConstraintPtr>& constraints);
} // namespace Luau