luau/Analysis/include/Luau/ConstraintGenerator.h
ariel 640ebbc0a5
Sync to upstream/release/663 (#1699)
Hey folks, another week means another Luau release! This one features a
number of bug fixes in the New Type Solver including improvements to
user-defined type functions and a bunch of work to untangle some of the
outstanding issues we've been seeing with constraint solving not
completing in real-world use. We're also continuing to make progress on
crashes and other problems that affect the stability of fragment
autocomplete, as we work towards delivering consistent, low-latency
autocomplete for any editor environment.

## New Type Solver

- Fix a bug in user-defined type functions where `print` would
incorrectly insert `\1` a number of times.
- Fix a bug where attempting to refine an optional generic with a type
test will cause a false positive type error (fixes #1666)
- Fix a bug where the `refine` type family would not skip over
`*no-refine*` discriminants (partial resolution for #1424)
- Fix a constraint solving bug where recursive function calls would
consistently produce cyclic constraints leading to incomplete or
inaccurate type inference.
- Implement `readparent` and `writeparent` for class types in
user-defined type functions, replacing the incorrectly included `parent`
method.
- Add initial groundwork (under a debug flag) for eager free type
generalization, moving us towards further improvements to constraint
solving incomplete errors.

## Fragment Autocomplete

- Ease up some assertions to improve stability of mixed-mode use of the
two type solvers (i.e. using Fragment Autocomplete on a type graph
originally produced by the old type solver)
- Resolve a bug with type compatibility checks causing internal compiler
errors in autocomplete.

## Lexer and Parser

- Improve the accuracy of the roundtrippable AST parsing mode by
correctly placing closing parentheses on type groupings.
- Add a getter for `offset` in the Lexer by @aduermael in #1688
- Add a second entry point to the parser to parse an expression,
`parseExpr`

## Internal Contributors

Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: James McNellis <jmcnellis@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <61795485+vrn-sn@users.noreply.github.com>
Co-authored-by: Alexander Youngblood <ayoungblood@roblox.com>
Co-authored-by: Menarul Alam <malam@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Vighnesh <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
2025-02-28 14:42:30 -08:00

469 lines
19 KiB
C++

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Constraint.h"
#include "Luau/ControlFlow.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Normalize.h"
#include "Luau/NotNull.h"
#include "Luau/Refinement.h"
#include "Luau/Symbol.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h"
#include "Luau/Variant.h"
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
namespace Luau
{
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct DcrLogger;
struct TypeFunctionRuntime;
struct Inference
{
TypeId ty = nullptr;
RefinementId refinement = nullptr;
Inference() = default;
explicit Inference(TypeId ty, RefinementId refinement = nullptr)
: ty(ty)
, refinement(refinement)
{
}
};
struct InferencePack
{
TypePackId tp = nullptr;
std::vector<RefinementId> refinements;
InferencePack() = default;
explicit InferencePack(TypePackId tp, const std::vector<RefinementId>& refinements = {})
: tp(tp)
, refinements(refinements)
{
}
};
struct ConstraintGenerator
{
// A list of all the scopes in the module. This vector holds ownership of the
// scope pointers; the scopes themselves borrow pointers to other scopes to
// define the scope hierarchy.
std::vector<std::pair<Location, ScopePtr>> scopes;
ModulePtr module;
NotNull<BuiltinTypes> builtinTypes;
const NotNull<TypeArena> arena;
// The root scope of the module we're generating constraints for.
// This is null when the CG is initially constructed.
Scope* rootScope;
TypeContext typeContext = TypeContext::Default;
struct InferredBinding
{
Scope* scope;
Location location;
TypeIds types;
};
// Some locals have multiple type states. We wish for Scope::bindings to
// map each local name onto the union of every type that the local can have
// over its lifetime, so we use this map to accumulate the set of types it
// might have.
//
// See the functions recordInferredBinding and fillInInferredBindings.
DenseHashMap<Symbol, InferredBinding> inferredBindings{{}};
// Constraints that go straight to the solver.
std::vector<ConstraintPtr> constraints;
// Constraints that do not go to the solver right away. Other constraints
// will enqueue them during solving.
std::vector<ConstraintPtr> unqueuedConstraints;
// Map a function's signature scope back to its signature type.
DenseHashMap<Scope*, TypeId> scopeToFunction{nullptr};
// Maps each type alias to the private scope that its type parameters belong to.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
NotNull<const DataFlowGraph> dfg;
RefinementArena refinementArena;
int recursionCount = 0;
// It is pretty uncommon for constraint generation to itself produce errors, but it can happen.
std::vector<TypeError> errors;
// Needed to be able to enable error-suppression preservation for immediate refinements.
NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
// Needed to register all available type functions for execution at later stages.
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// Needed to resolve modules to make 'require' import types properly.
NotNull<ModuleResolver> moduleResolver;
// Occasionally constraint generation needs to produce an ICE.
const NotNull<InternalErrorReporter> ice;
ScopePtr globalScope;
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
std::vector<RequireCycle> requireCycles;
DenseHashMap<TypeId, TypeIds> localTypes{nullptr};
DcrLogger* logger;
ConstraintGenerator(
ModulePtr module,
NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
DcrLogger* logger,
NotNull<DataFlowGraph> dfg,
std::vector<RequireCycle> requireCycles
);
/**
* The entry point to the ConstraintGenerator. This will construct a set
* of scopes, constraints, and free types that can be solved later.
* @param block the root block to generate constraints for.
*/
void visitModuleRoot(AstStatBlock* block);
void visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block);
private:
std::vector<std::vector<TypeId>> interiorTypes;
/**
* Fabricates a new free type belonging to a given scope.
* @param scope the scope the free type belongs to.
*/
TypeId freshType(const ScopePtr& scope);
/**
* Fabricates a new free type pack belonging to a given scope.
* @param scope the scope the free type pack belongs to.
*/
TypePackId freshTypePack(const ScopePtr& scope);
/**
* Allocate a new TypePack with the given head and tail.
*
* Avoids allocating 0-length type packs:
*
* If the head is non-empty, allocate and return a type pack with the given
* head and tail.
* If the head is empty and tail is non-empty, return *tail.
* If both the head and tail are empty, return an empty type pack.
*/
TypePackId addTypePack(std::vector<TypeId> head, std::optional<TypePackId> tail);
/**
* Fabricates a scope that is a child of another scope.
* @param node the lexical node that the scope belongs to.
* @param parent the parent scope of the new scope. Must not be null.
*/
ScopePtr childScope(AstNode* node, const ScopePtr& parent);
std::optional<TypeId> lookup(const ScopePtr& scope, Location location, DefId def, bool prototype = true);
/**
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to.
* @param cv the constraint variant to add.
* @return the pointer to the inserted constraint
*/
NotNull<Constraint> addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv);
/**
* Adds a constraint to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param c the constraint to add.
* @return the pointer to the inserted constraint
*/
NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
struct RefinementPartition
{
// Types that we want to intersect against the type of the expression.
std::vector<TypeId> discriminantTypes;
// Sometimes the type we're discriminating against is implicitly nil.
bool shouldAppendNilType = false;
};
using RefinementContext = InsertionOrderedMap<DefId, RefinementPartition>;
void unionRefinements(
const ScopePtr& scope,
Location location,
const RefinementContext& lhs,
const RefinementContext& rhs,
RefinementContext& dest,
std::vector<ConstraintV>* constraints
);
void computeRefinement(
const ScopePtr& scope,
Location location,
RefinementId refinement,
RefinementContext* refis,
bool sense,
bool eq,
std::vector<ConstraintV>* constraints
);
void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement);
LUAU_NOINLINE void checkAliases(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visitBlockWithoutChildScope_DEPRECATED(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visit(const ScopePtr& scope, AstStat* stat);
ControlFlow visit(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visit(const ScopePtr& scope, AstStatLocal* local);
ControlFlow visit(const ScopePtr& scope, AstStatFor* for_);
ControlFlow visit(const ScopePtr& scope, AstStatForIn* forIn);
ControlFlow visit(const ScopePtr& scope, AstStatWhile* while_);
ControlFlow visit(const ScopePtr& scope, AstStatRepeat* repeat);
ControlFlow visit(const ScopePtr& scope, AstStatLocalFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatReturn* ret);
ControlFlow visit(const ScopePtr& scope, AstStatAssign* assign);
ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign);
ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement);
ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias);
ControlFlow visit(const ScopePtr& scope, AstStatTypeFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
ControlFlow visit(const ScopePtr& scope, AstStatError* error);
InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<std::optional<TypeId>>& expectedTypes = {});
InferencePack checkPack(
const ScopePtr& scope,
AstExpr* expr,
const std::vector<std::optional<TypeId>>& expectedTypes = {},
bool generalize = true
);
InferencePack checkPack(const ScopePtr& scope, AstExprCall* call);
/**
* Checks an expression that is expected to evaluate to one type.
* @param scope the scope the expression is contained within.
* @param expr the expression to check.
* @param expectedType the type of the expression that is expected from its
* surrounding context. Used to implement bidirectional type checking.
* @param generalize If true, generalize any lambdas that are encountered.
* @return the type of the expression.
*/
Inference check(
const ScopePtr& scope,
AstExpr* expr,
std::optional<TypeId> expectedType = {},
bool forceSingleton = false,
bool generalize = true
);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprLocal* local);
Inference check(const ScopePtr& scope, AstExprGlobal* global);
Inference checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation);
Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional<TypeId> expectedType, bool generalize);
Inference check(const ScopePtr& scope, AstExprUnary* unary);
Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
Inference checkAstExprBinary(
const ScopePtr& scope,
const Location& location,
AstExprBinary::Op op,
AstExpr* left,
AstExpr* right,
std::optional<TypeId> expectedType
);
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprInterpString* interpString);
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(
const ScopePtr& scope,
AstExprBinary::Op op,
AstExpr* left,
AstExpr* right,
std::optional<TypeId> expectedType
);
void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId rhsType);
struct FunctionSignature
{
// The type of the function.
TypeId signature;
// The scope that encompasses the function's signature. May be nullptr
// if there was no need for a signature scope (the function has no
// generics).
ScopePtr signatureScope;
// The scope that encompasses the function's body. Is a child scope of
// signatureScope, if present.
ScopePtr bodyScope;
};
FunctionSignature checkFunctionSignature(
const ScopePtr& parent,
AstExprFunction* fn,
std::optional<TypeId> expectedType = {},
std::optional<Location> originalName = {}
);
/**
* Checks the body of a function expression.
* @param scope the interior scope of the body of the function.
* @param fn the function expression to check.
*/
void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn);
// Specializations of 'resolveType' below
TypeId resolveReferenceType(const ScopePtr& scope, AstType* ty, AstTypeReference* ref, bool inTypeArguments, bool replaceErrorWithFresh);
TypeId resolveTableType(const ScopePtr& scope, AstType* ty, AstTypeTable* tab, bool inTypeArguments, bool replaceErrorWithFresh);
TypeId resolveFunctionType(const ScopePtr& scope, AstType* ty, AstTypeFunction* fn, bool inTypeArguments, bool replaceErrorWithFresh);
/**
* Resolves a type from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param ty the AST annotation to resolve.
* @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`.
* @return the type of the AST annotation.
**/
TypeId resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh = false);
/**
* Resolves a type pack from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param tp the AST annotation to resolve.
* @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`.
* @return the type pack of the AST annotation.
**/
TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments, bool replaceErrorWithFresh = false);
/**
* Resolves a type pack from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param list the AST annotation to resolve.
* @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`.
* @return the type pack of the AST annotation.
**/
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh = false);
/**
* Creates generic types given a list of AST definitions, resolving default
* types as required.
* @param scope the scope that the generics should belong to.
* @param generics the AST generics to create types for.
* @param useCache whether to use the generic type cache for the given
* scope.
* @param addTypes whether to add the types to the scope's
* privateTypeBindings map.
**/
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(
const ScopePtr& scope,
AstArray<AstGenericType*> generics,
bool useCache = false,
bool addTypes = true
);
/**
* Creates generic type packs given a list of AST definitions, resolving
* default type packs as required.
* @param scope the scope that the generic packs should belong to.
* @param packs the AST generic packs to create type packs for.
* @param useCache whether to use the generic type pack cache for the given
* scope.
* @param addTypes whether to add the type packs to the scope's
* privateTypePackBindings map.
**/
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(
const ScopePtr& scope,
AstArray<AstGenericTypePack*> packs,
bool useCache = false,
bool addTypes = true
);
Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);
void reportError(Location location, TypeErrorData err);
void reportCodeTooComplex(Location location);
// make a union type function of these two types
TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
// make an intersect type function of these two types
TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program);
/** Scan the program for global definitions.
*
* ConstraintGenerator needs to differentiate between globals and accesses to undefined symbols. Doing this "for
* real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an
* initial scan of the AST and note what globals are defined.
*/
void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);
bool recordPropertyAssignment(TypeId ty);
// Record the fact that a particular local has a particular type in at least
// one of its states.
void recordInferredBinding(AstLocal* local, TypeId ty);
void fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block);
/** Given a function type annotation, return a vector describing the expected argument types of calls to the function.
* For example, calling a function with annotation ((number) -> string) & ((string) -> number)
* yields a vector of size 1, with value: [number | string]
*/
std::vector<std::optional<TypeId>> getExpectedCallTypesForFunctionOverloads(const TypeId fnType);
TypeId createTypeFunctionInstance(
const TypeFunction& function,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments,
const ScopePtr& scope,
Location location
);
TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right);
};
/** Borrow a vector of pointers from a vector of owning pointers to constraints.
*/
std::vector<NotNull<Constraint>> borrowConstraints(const std::vector<ConstraintPtr>& constraints);
} // namespace Luau
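
The `borrowConstraints` declaration at the end of the header takes a vector of owning pointers and hands back non-owning ones. The pattern can be sketched in isolation as below. This is a minimal illustration, not Luau's implementation: it assumes `ConstraintPtr` behaves like `std::unique_ptr<Constraint>`, uses a raw pointer where the header uses `NotNull<Constraint>`, and the `Constraint` struct and the name `borrowAll` are hypothetical stand-ins.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for Luau::Constraint; it exists only so the sketch compiles.
struct Constraint
{
    std::string description;
};

// Assumption: the owning pointer type is a unique_ptr.
using ConstraintPtr = std::unique_ptr<Constraint>;

// Returns non-owning pointers to every constraint, in the same order.
// The result is valid only while `owned` and its elements stay alive.
std::vector<Constraint*> borrowAll(const std::vector<ConstraintPtr>& owned)
{
    std::vector<Constraint*> result;
    result.reserve(owned.size());
    for (const ConstraintPtr& c : owned)
        result.push_back(c.get());
    return result;
}
```

The borrowed vector can be copied and passed around freely, while ownership (and therefore lifetime) stays with the original vector.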
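
The comment on `addTypePack` describes three cases chosen to avoid allocating 0-length type packs. The sketch below mirrors that case analysis on a toy data model. Everything in it is an assumption made for illustration: `Pack`, `Arena`, and `makePack` are hypothetical names, packs are identified by an index rather than a `TypePackId`, and element types are plain strings.

```cpp
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified pack: a list of element names plus an optional tail.
struct Pack
{
    std::vector<std::string> head;
    std::optional<int> tail; // index of another pack in the arena, if any
};

// Hypothetical arena that owns packs and identifies them by index.
struct Arena
{
    std::vector<Pack> packs;

    int add(Pack p)
    {
        packs.push_back(std::move(p));
        return static_cast<int>(packs.size()) - 1;
    }
};

// Follows the three cases listed in the addTypePack comment.
int makePack(Arena& arena, std::vector<std::string> head, std::optional<int> tail)
{
    if (!head.empty())
        return arena.add(Pack{std::move(head), tail}); // non-empty head: allocate head + tail
    if (tail)
        return *tail; // empty head, tail present: reuse the tail, allocate nothing
    return arena.add(Pack{}); // both empty: an empty pack
}
```

The middle case is the one that saves an allocation: a pack with no head and an existing tail is just that tail.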
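
The comment on `getExpectedCallTypesForFunctionOverloads` gives an example in which two overloads produce a single expected type, `number | string`. One way to read that example is as a per-position union over the parameter lists of every overload. The sketch below shows only that idea under loud assumptions: types are plain strings, a union is a `std::set`, and `expectedArgumentTypes` is a hypothetical name. The real function works on `TypeId` values and returns `std::optional<TypeId>` entries.

```cpp
#include <cstddef>
#include <set>
#include <string>
#include <vector>

using TypeName = std::string;   // hypothetical stand-in for TypeId
using Union = std::set<TypeName>; // hypothetical stand-in for a union type

// For each argument position, collect the parameter types that any overload
// accepts at that position.
std::vector<Union> expectedArgumentTypes(const std::vector<std::vector<TypeName>>& overloads)
{
    std::vector<Union> result;
    for (const std::vector<TypeName>& params : overloads)
    {
        // Grow the result so every position of this overload has a slot.
        if (params.size() > result.size())
            result.resize(params.size());
        for (std::size_t i = 0; i < params.size(); ++i)
            result[i].insert(params[i]);
    }
    return result;
}
```

Given the parameter lists of `(number) -> string` and `(string) -> number`, this returns a vector of size 1 whose only entry contains `number` and `string`, matching the example in the comment.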