diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 9fcbce04..548a58f5 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -25,4 +25,6 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); +TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log); + } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h new file mode 100644 index 00000000..4234f2f6 --- /dev/null +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -0,0 +1,162 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include +#include + +#include "Luau/Ast.h" +#include "Luau/Module.h" +#include "Luau/Symbol.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +namespace Luau +{ + +struct Scope2; + +// subType <: superType +struct SubtypeConstraint +{ + TypeId subType; + TypeId superType; +}; + +// subPack <: superPack +struct PackSubtypeConstraint +{ + TypePackId subPack; + TypePackId superPack; +}; + +// subType ~ gen superType +struct GeneralizationConstraint +{ + TypeId subType; + TypeId superType; + Scope2* scope; +}; + +// subType ~ inst superType +struct InstantiationConstraint +{ + TypeId subType; + TypeId superType; +}; + +using ConstraintV = Variant; +using ConstraintPtr = std::unique_ptr; + +struct Constraint +{ + Constraint(ConstraintV&& c); + Constraint(ConstraintV&& c, std::vector dependencies); + + Constraint(const Constraint&) = delete; + Constraint& operator=(const Constraint&) = delete; + + ConstraintV c; + std::vector dependencies; +}; + +inline Constraint& asMutable(const Constraint& c) +{ + return const_cast(c); +} + +template +T* getMutable(Constraint& c) +{ + return ::Luau::get_if(&c.c); +} + +template +const T* get(const Constraint& c) +{ + return getMutable(asMutable(c)); +} + +struct Scope2 +{ + // The parent scope of this scope. Null if there is no parent (i.e. this + // is the module-level scope). + Scope2* parent = nullptr; + // All the children of this scope. + std::vector children; + std::unordered_map bindings; // TODO: I think this can be a DenseHashMap + TypePackId returnType; + // All constraints belonging to this scope. + std::vector constraints; + + std::optional lookup(Symbol sym); +}; + +struct ConstraintGraphBuilder +{ + // A list of all the scopes in the module. This vector holds ownership of the + // scope pointers; the scopes themselves borrow pointers to other scopes to + // define the scope hierarchy. + std::vector>> scopes; + SingletonTypes& singletonTypes; + TypeArena* const arena; + // The root scope of the module we're generating constraints for. + Scope2* rootScope; + + explicit ConstraintGraphBuilder(TypeArena* arena); + + /** + * Fabricates a new free type belonging to a given scope. + * @param scope the scope the free type belongs to. Must not be null. + */ + TypeId freshType(Scope2* scope); + + /** + * Fabricates a new free type pack belonging to a given scope. + * @param scope the scope the free type pack belongs to. Must not be null. + */ + TypePackId freshTypePack(Scope2* scope); + + /** + * Fabricates a scope that is a child of another scope. + * @param location the lexical extent of the scope in the source code. + * @param parent the parent scope of the new scope. Must not be null. + */ + Scope2* childScope(Location location, Scope2* parent); + + /** + * Adds a new constraint with no dependencies to a given scope. + * @param scope the scope to add the constraint to. Must not be null. + * @param cv the constraint variant to add. + */ + void addConstraint(Scope2* scope, ConstraintV cv); + + /** + * Adds a constraint to a given scope. + * @param scope the scope to add the constraint to. Must not be null. + * @param c the constraint to add. + */ + void addConstraint(Scope2* scope, std::unique_ptr c); + + /** + * The entry point to the ConstraintGraphBuilder. This will construct a set + * of scopes, constraints, and free types that can be solved later. + * @param block the root block to generate constraints for. + */ + void visit(AstStatBlock* block); + + void visit(Scope2* scope, AstStat* stat); + void visit(Scope2* scope, AstStatBlock* block); + void visit(Scope2* scope, AstStatLocal* local); + void visit(Scope2* scope, AstStatLocalFunction* local); + void visit(Scope2* scope, AstStatReturn* local); + + TypePackId checkPack(Scope2* scope, AstArray exprs); + TypePackId checkPack(Scope2* scope, AstExpr* expr); + + TypeId check(Scope2* scope, AstExpr* expr); +}; + +std::vector collectConstraints(Scope2* rootScope); + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h new file mode 100644 index 00000000..85006e68 --- /dev/null +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -0,0 +1,106 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Error.h" +#include "Luau/Variant.h" +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/TypeVar.h" + +#include + +namespace Luau +{ + +// TypeId, TypePackId, or Constraint*. It is impossible to know which, but we +// never dereference this pointer. +using BlockedConstraintId = const void*; + +struct ConstraintSolver +{ + TypeArena* arena; + InternalErrorReporter iceReporter; + // The entire set of constraints that the solver is trying to resolve. + std::vector constraints; + Scope2* rootScope; + std::vector errors; + + // This includes every constraint that has not been fully solved. + // A constraint can be both blocked and unsolved, for instance. + std::unordered_set unsolvedConstraints; + + // A mapping of constraint pointer to how many things the constraint is + // blocked on. Can be empty or 0 for constraints that are not blocked on + // anything. + std::unordered_map blockedConstraints; + // A mapping of type/pack pointers to the constraints they block. + std::unordered_map> blocked; + + explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope); + + /** + * Attempts to dispatch all pending constraints and reach a type solution + * that satisfies all of the constraints, recording any errors that are + * encountered. + **/ + void run(); + + bool done(); + + bool tryDispatch(const Constraint* c); + bool tryDispatch(const SubtypeConstraint& c); + bool tryDispatch(const PackSubtypeConstraint& c); + bool tryDispatch(const GeneralizationConstraint& c); + bool tryDispatch(const InstantiationConstraint& c, const Constraint* constraint); + + /** + * Marks a constraint as being blocked on a type or type pack. The constraint + * solver will not attempt to dispatch blocked constraints until their + * dependencies have made progress. + * @param target the type or type pack pointer that the constraint is blocked on. + * @param constraint the constraint to block. + **/ + void block_(BlockedConstraintId target, const Constraint* constraint); + void block(const Constraint* target, const Constraint* constraint); + void block(TypeId target, const Constraint* constraint); + void block(TypePackId target, const Constraint* constraint); + + /** + * Informs the solver that progress has been made on a type or type pack. The + * solver will wake up all constraints that are blocked on the type or type pack, + * and will resume attempting to dispatch them. + * @param progressed the type or type pack pointer that has progressed. + **/ + void unblock_(BlockedConstraintId progressed); + void unblock(const Constraint* progressed); + void unblock(TypeId progressed); + void unblock(TypePackId progressed); + + /** + * Returns whether the constraint is blocked on anything. + * @param constraint the constraint to check. + */ + bool isBlocked(const Constraint* constraint); + + void reportErrors(const std::vector& errors); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result and reports errors if necessary. + * @param subType the sub-type to unify. + * @param superType the super-type to unify. + */ + void unify(TypeId subType, TypeId superType); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result and reports errors if necessary. + * @param subPack the sub-type pack to unify. + * @param superPack the super-type pack to unify. + */ + void unify(TypePackId subPack, TypePackId superPack); +}; + +void dump(Scope2* rootScope, struct ToStringOptions& opts); + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index d7c9ca40..58be0ffe 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -159,6 +159,8 @@ struct Frontend void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); private: + ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope); + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 00e1e635..f6e077dc 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -19,6 +19,7 @@ struct Module; using ScopePtr = std::shared_ptr; using ModulePtr = std::shared_ptr; +struct Scope2; /// Root of the AST of a parsed source file struct SourceModule @@ -65,6 +66,7 @@ struct Module std::shared_ptr names; std::vector> scopes; // never empty + std::vector>> scope2s; // never empty DenseHashMap astTypes{nullptr}; DenseHashMap astExpectedTypes{nullptr}; @@ -78,6 +80,7 @@ struct Module bool timeout = false; ScopePtr getModuleScope() const; + Scope2* getModuleScope2() const; // Once a module has been typechecked, we clone its public interface into a separate arena. // This helps us to force TypeVar ownership into a DAG rather than a DCG. diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h new file mode 100644 index 00000000..3d05fdea --- /dev/null +++ b/Analysis/include/Luau/NotNull.h @@ -0,0 +1,75 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +namespace Luau +{ + +/** A non-owning, non-null pointer to a T. + * + * A NotNull is notionally identical to a T* with the added restriction that it + * can never store nullptr. + * + * The sole conversion rule from T* to NotNull is the single-argument constructor, which + * is intentionally marked explicit. This constructor performs a runtime test to verify + * that the passed pointer is never nullptr. + * + * Pointer arithmetic, increment, decrement, and array indexing are all forbidden. + * + * An implicit coersion from NotNull to T* is afforded, as are the pointer indirection and member + * access operators. (*p and p->prop) + * + * The explicit delete statement is permitted on a NotNull through this implicit conversion. + */ +template +struct NotNull +{ + explicit NotNull(T* t) + : ptr(t) + { + LUAU_ASSERT(t); + } + + explicit NotNull(std::nullptr_t) = delete; + void operator=(std::nullptr_t) = delete; + + operator T*() const noexcept + { + return ptr; + } + + T& operator*() const noexcept + { + return *ptr; + } + + T* operator->() const noexcept + { + return ptr; + } + + T& operator[](int) = delete; + + T& operator+(int) = delete; + T& operator-(int) = delete; + + T* ptr; +}; + +} + +namespace std +{ + +template struct hash> +{ + size_t operator()(const Luau::NotNull& p) const + { + return std::hash()(p.ptr); + } +}; + +} diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index e48cad40..b32d684e 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -6,6 +6,9 @@ namespace Luau { +struct Scope2; + void quantify(TypeId ty, TypeLevel level); +void quantify(TypeId ty, Scope2* scope); } // namespace Luau diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index b5dd9c89..1fe037e5 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -30,6 +30,9 @@ struct Symbol { } + template + Symbol(const T&) = delete; + AstLocal* local; AstName global; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 3b380a60..a50fef78 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -3,6 +3,7 @@ #include "Luau/Common.h" #include "Luau/TypeVar.h" +#include "Luau/ConstraintGraphBuilder.h" #include #include @@ -53,6 +54,7 @@ ToStringResult toStringDetailed(TypePackId ty, const ToStringOptions& opts = {}) std::string toString(TypeId ty, const ToStringOptions& opts); std::string toString(TypePackId ty, const ToStringOptions& opts); +std::string toString(const Constraint& c, ToStringOptions& opts); // These are offered as overloads rather than a default parameter so that they can be easily invoked from within the MSVC debugger. // You can use them in watch expressions! @@ -64,6 +66,11 @@ inline std::string toString(TypePackId ty) { return toString(ty, ToStringOptions{}); } +inline std::string toString(const Constraint& c) +{ + ToStringOptions opts; + return toString(c, opts); +} std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); @@ -74,6 +81,7 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression std::string dump(TypeId ty); std::string dump(TypePackId ty); +std::string dump(const Constraint& c); std::string dump(const std::shared_ptr& scope, const char* name); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 9cacbc6d..b3c455cf 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -24,6 +24,7 @@ namespace Luau { struct TypeArena; +struct Scope2; /** * There are three kinds of type variables: @@ -124,6 +125,7 @@ struct ConstrainedTypeVar std::vector parts; TypeLevel level; + Scope2* scope = nullptr; }; // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md @@ -255,6 +257,7 @@ struct FunctionTypeVar std::optional defn = {}, bool hasSelf = false); TypeLevel level; + Scope2* scope = nullptr; /// These should all be generic std::vector generics; std::vector genericPacks; @@ -266,6 +269,7 @@ struct FunctionTypeVar bool hasSelf; Tags tags; bool hasNoGenerics = false; + bool generalized = false; }; enum class TableState @@ -323,6 +327,7 @@ struct TableTypeVar TableState state = TableState::Unsealed; TypeLevel level; + Scope2* scope = nullptr; std::optional name; // Sometimes we throw a type on a name to make for nicer error messages, but without creating any entry in the type namespace diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h index c1c04d10..f67e3d8e 100644 --- a/Analysis/include/Luau/TypedAllocator.h +++ b/Analysis/include/Luau/TypedAllocator.h @@ -23,6 +23,12 @@ public: currentBlockSize = kBlockSize; } + TypedAllocator(const TypedAllocator&) = delete; + TypedAllocator& operator=(const TypedAllocator&) = delete; + + TypedAllocator(TypedAllocator&&) = default; + TypedAllocator& operator=(TypedAllocator&&) = default; + ~TypedAllocator() { if (frozen) diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 64fa131d..fdc39481 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -8,6 +8,8 @@ namespace Luau { +struct Scope2; + /** * The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too. * To start, read http://okmij.org/ftp/ML/generalization.html @@ -82,9 +84,11 @@ using Name = std::string; struct Free { explicit Free(TypeLevel level); + explicit Free(Scope2* scope); int index; TypeLevel level; + Scope2* scope = nullptr; // True if this free type variable is part of a mutually // recursive type alias whose definitions haven't been // resolved yet. @@ -111,12 +115,14 @@ struct Generic Generic(); explicit Generic(TypeLevel level); explicit Generic(const Name& name); + explicit Generic(Scope2* scope); Generic(TypeLevel level, const Name& name); int index; TypeLevel level; + Scope2* scope = nullptr; Name name; - bool explicitName; + bool explicitName = false; private: static int nextIndex; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index 5efe89ed..c9c97c92 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -95,6 +95,20 @@ public: return *this; } + template + T& emplace(Args&&... args) + { + using TT = std::decay_t; + constexpr int tid = getTypeId(); + static_assert(tid >= 0, "unsupported T"); + + tableDtor[typeId](&storage); + typeId = tid; + new (&storage) TT(std::forward(args)...); + + return *reinterpret_cast(&storage); + } + template const T* get_if() const { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index a3611f53..19e3383e 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -2,6 +2,7 @@ #include "Luau/Clone.h" #include "Luau/RecursionCounter.h" +#include "Luau/TxnLog.h" #include "Luau/TypePack.h" #include "Luau/Unifiable.h" @@ -382,4 +383,67 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) return result; } +TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log) +{ + ty = log->follow(ty); + + TypeId result = ty; + + if (auto pty = log->pending(ty)) + ty = &pty->pending; + + if (const FunctionTypeVar* ftv = get(ty)) + { + FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + clone.generics = ftv->generics; + clone.genericPacks = ftv->genericPacks; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + result = dest.addType(std::move(clone)); + } + else if (const TableTypeVar* ttv = get(ty)) + { + LUAU_ASSERT(!ttv->boundTo); + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + clone.definitionModuleName = ttv->definitionModuleName; + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; + clone.tags = ttv->tags; + result = dest.addType(std::move(clone)); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; + clone.syntheticName = mtv->syntheticName; + result = dest.addType(std::move(clone)); + } + else if (const UnionTypeVar* utv = get(ty)) + { + UnionTypeVar clone; + clone.options = utv->options; + result = dest.addType(std::move(clone)); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + IntersectionTypeVar clone; + clone.parts = itv->parts; + result = dest.addType(std::move(clone)); + } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + ConstrainedTypeVar clone{ctv->level, ctv->parts}; + result = dest.addType(std::move(clone)); + } + else + return result; + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp new file mode 100644 index 00000000..c8f77ddf --- /dev/null +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -0,0 +1,300 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintGraphBuilder.h" + +namespace Luau +{ + +Constraint::Constraint(ConstraintV&& c) + : c(std::move(c)) +{ +} + +Constraint::Constraint(ConstraintV&& c, std::vector dependencies) + : c(std::move(c)) + , dependencies(dependencies) +{ +} + +std::optional Scope2::lookup(Symbol sym) +{ + Scope2* s = this; + + while (true) + { + auto it = s->bindings.find(sym); + if (it != s->bindings.end()) + return it->second; + + if (s->parent) + s = s->parent; + else + return std::nullopt; + } +} + +ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena) + : singletonTypes(getSingletonTypes()) + , arena(arena) + , rootScope(nullptr) +{ + LUAU_ASSERT(arena); +} + +TypeId ConstraintGraphBuilder::freshType(Scope2* scope) +{ + LUAU_ASSERT(scope); + return arena->addType(FreeTypeVar{scope}); +} + +TypePackId ConstraintGraphBuilder::freshTypePack(Scope2* scope) +{ + LUAU_ASSERT(scope); + FreeTypePack f{scope}; + return arena->addTypePack(TypePackVar{std::move(f)}); +} + +Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) +{ + LUAU_ASSERT(parent); + auto scope = std::make_unique(); + Scope2* borrow = scope.get(); + scopes.emplace_back(location, std::move(scope)); + + borrow->parent = parent; + borrow->returnType = parent->returnType; + parent->children.push_back(borrow); + + return borrow; +} + +void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) +{ + LUAU_ASSERT(scope); + scope->constraints.emplace_back(new Constraint{std::move(cv)}); +} + +void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) +{ + LUAU_ASSERT(scope); + scope->constraints.emplace_back(std::move(c)); +} + +void ConstraintGraphBuilder::visit(AstStatBlock* block) +{ + LUAU_ASSERT(scopes.empty()); + LUAU_ASSERT(rootScope == nullptr); + scopes.emplace_back(block->location, std::make_unique()); + rootScope = scopes.back().second.get(); + rootScope->returnType = freshTypePack(rootScope); + + visit(rootScope, block); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) +{ + LUAU_ASSERT(scope); + + if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto f = stat->as()) + visit(scope, f); + else if (auto r = stat->as()) + visit(scope, r); + else + LUAU_ASSERT(0); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) +{ + LUAU_ASSERT(scope); + + std::vector varTypes; + + for (AstLocal* local : local->vars) + { + // TODO annotations + TypeId ty = freshType(scope); + varTypes.push_back(ty); + scope->bindings[local] = ty; + } + + for (size_t i = 0; i < local->vars.size; ++i) + { + if (i < local->values.size) + { + TypeId exprType = check(scope, local->values.data[i]); + addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); + } + } +} + +void addConstraints(Constraint* constraint, Scope2* scope) +{ + LUAU_ASSERT(scope); + + scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); + + for (const auto& c : scope->constraints) + constraint->dependencies.push_back(c.get()); + + for (Scope2* childScope : scope->children) + addConstraints(constraint, childScope); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function) +{ + LUAU_ASSERT(scope); + + // Local + // Global + // Dotted path + // Self? + + TypeId functionType = nullptr; + auto ty = scope->lookup(function->name); + LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. + + functionType = freshType(scope); + scope->bindings[function->name] = functionType; + + Scope2* innerScope = childScope(function->func->body->location, scope); + TypePackId returnType = freshTypePack(scope); + innerScope->returnType = returnType; + + std::vector argTypes; + + for (AstLocal* local : function->func->args) + { + TypeId t = freshType(innerScope); + argTypes.push_back(t); + innerScope->bindings[local] = t; // TODO annotations + } + + for (AstStat* stat : function->func->body->body) + visit(innerScope, stat); + + FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; + addConstraints(c.get(), innerScope); + + addConstraint(scope, std::move(c)); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) +{ + LUAU_ASSERT(scope); + + TypePackId exprTypes = checkPack(scope, ret->list); + addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) +{ + LUAU_ASSERT(scope); + + for (AstStat* stat : block->body) + visit(scope, stat); +} + +TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) +{ + LUAU_ASSERT(scope); + + if (exprs.size == 0) + return arena->addTypePack({}); + + std::vector types; + TypePackId last = nullptr; + + for (size_t i = 0; i < exprs.size; ++i) + { + if (i < exprs.size - 1) + types.push_back(check(scope, exprs.data[i])); + else + last = checkPack(scope, exprs.data[i]); + } + + LUAU_ASSERT(last != nullptr); + + return arena->addTypePack(TypePack{std::move(types), last}); +} + +TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) +{ + LUAU_ASSERT(scope); + + // TEMP TEMP TEMP HACK HACK HACK FIXME FIXME + TypeId t = check(scope, expr); + return arena->addTypePack({t}); +} + +TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) +{ + LUAU_ASSERT(scope); + + if (auto a = expr->as()) + return singletonTypes.stringType; + else if (auto a = expr->as()) + return singletonTypes.numberType; + else if (auto a = expr->as()) + return singletonTypes.booleanType; + else if (auto a = expr->as()) + return singletonTypes.nilType; + else if (auto a = expr->as()) + { + std::optional ty = scope->lookup(a->local); + if (ty) + return *ty; + else + return singletonTypes.errorRecoveryType(singletonTypes.anyType); // FIXME? Record an error at this point? + } + else if (auto a = expr->as()) + { + std::vector args; + + for (AstExpr* arg : a->args) + { + args.push_back(check(scope, arg)); + } + + TypeId fnType = check(scope, a->func); + TypeId instantiatedType = freshType(scope); + addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); + + TypeId firstRet = freshType(scope); + TypePackId rets = arena->addTypePack(TypePack{{firstRet}, arena->addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); + FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); + TypeId inferredFnType = arena->addType(ftv); + + addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); + return firstRet; + } + else + { + LUAU_ASSERT(0); + return freshType(scope); + } +} + +static void collectConstraints(std::vector& result, Scope2* scope) +{ + for (const auto& c : scope->constraints) + result.push_back(c.get()); + + for (Scope2* child : scope->children) + collectConstraints(result, child); +} + +std::vector collectConstraints(Scope2* rootScope) +{ + std::vector result; + collectConstraints(result, rootScope); + return result; +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp new file mode 100644 index 00000000..f40cd4b3 --- /dev/null +++ b/Analysis/src/ConstraintSolver.cpp @@ -0,0 +1,306 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintSolver.h" +#include "Luau/Instantiation.h" +#include "Luau/Quantify.h" +#include "Luau/ToString.h" +#include "Luau/Unifier.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); + +namespace Luau +{ + +[[maybe_unused]] static void dumpBindings(Scope2* scope, ToStringOptions& opts) +{ + for (const auto& [k, v] : scope->bindings) + { + auto d = toStringDetailed(v, opts); + opts.nameMap = d.nameMap; + printf("\t%s : %s\n", k.c_str(), d.name.c_str()); + } + + for (Scope2* child : scope->children) + dumpBindings(child, opts); +} + +static void dumpConstraints(Scope2* scope, ToStringOptions& opts) +{ + for (const ConstraintPtr& c : scope->constraints) + { + printf("\t%s\n", toString(*c, opts).c_str()); + } + + for (Scope2* child : scope->children) + dumpConstraints(child, opts); +} + +void dump(Scope2* rootScope, ToStringOptions& opts) +{ + printf("constraints:\n"); + dumpConstraints(rootScope, opts); +} + +void dump(ConstraintSolver* cs, ToStringOptions& opts) +{ + printf("constraints:\n"); + for (const Constraint* c : cs->unsolvedConstraints) + { + printf("\t%s\n", toString(*c, opts).c_str()); + + for (const Constraint* dep : c->dependencies) + printf("\t\t%s\n", toString(*dep, opts).c_str()); + } +} + +ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope) + : arena(arena) + , constraints(collectConstraints(rootScope)) + , rootScope(rootScope) +{ + for (const Constraint* c : constraints) + { + unsolvedConstraints.insert(c); + + for (const Constraint* dep : c->dependencies) + { + block(dep, c); + } + } +} + +void ConstraintSolver::run() +{ + if (done()) + return; + + bool progress = false; + + ToStringOptions opts; + + if (FFlag::DebugLuauLogSolver) + { + printf("Starting solver\n"); + dump(this, opts); + } + + do + { + progress = false; + + auto it = begin(unsolvedConstraints); + auto endIt = end(unsolvedConstraints); + + while (it != endIt) + { + if (isBlocked(*it)) + { + ++it; + continue; + } + + std::string saveMe = FFlag::DebugLuauLogSolver ? toString(**it, opts) : std::string{}; + + bool success = tryDispatch(*it); + progress = progress || success; + + auto saveIt = it; + ++it; + if (success) + { + unsolvedConstraints.erase(saveIt); + if (FFlag::DebugLuauLogSolver) + { + printf("Dispatched\n\t%s\n", saveMe.c_str()); + dump(this, opts); + } + } + } + } while (progress); + + if (FFlag::DebugLuauLogSolver) + dumpBindings(rootScope, opts); + + LUAU_ASSERT(done()); +} + +bool ConstraintSolver::done() +{ + return unsolvedConstraints.empty(); +} + +bool ConstraintSolver::tryDispatch(const Constraint* constraint) +{ + if (isBlocked(constraint)) + return false; + + bool success = false; + + if (auto sc = get(*constraint)) + success = tryDispatch(*sc); + else if (auto psc = get(*constraint)) + success = tryDispatch(*psc); + else if (auto gc = get(*constraint)) + success = tryDispatch(*gc); + else if (auto ic = get(*constraint)) + success = tryDispatch(*ic, constraint); + else + LUAU_ASSERT(0); + + if (success) + { + unblock(constraint); + } + + return success; +} + +bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c) +{ + unify(c.subType, c.superType); + unblock(c.subType); + unblock(c.superType); + + return true; +} + +bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c) +{ + unify(c.subPack, c.superPack); + unblock(c.subPack); + unblock(c.superPack); + + return true; +} + +bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& constraint) +{ + unify(constraint.subType, constraint.superType); + + quantify(constraint.superType, constraint.scope); + unblock(constraint.subType); + unblock(constraint.superType); + + return true; +} + +bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, const Constraint* constraint) +{ + TypeId superType = follow(c.superType); + if (const FunctionTypeVar* ftv = get(superType)) + { + if (!ftv->generalized) + { + block(superType, constraint); + return false; + } + } + else if (get(superType)) + { + block(superType, constraint); + return false; + } + // TODO: Error if it's a primitive or something + + Instantiation inst(TxnLog::empty(), arena, TypeLevel{}); + + std::optional instantiated = inst.substitute(c.superType); + LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS + + unify(c.subType, *instantiated); + unblock(c.subType); + + return true; +} + +void ConstraintSolver::block_(BlockedConstraintId target, const Constraint* constraint) +{ + blocked[target].push_back(constraint); + + auto& count = blockedConstraints[constraint]; + count += 1; +} + +void ConstraintSolver::block(const Constraint* target, const Constraint* constraint) +{ + block_(target, constraint); +} + +void ConstraintSolver::block(TypeId target, const Constraint* constraint) +{ + block_(target, constraint); +} + +void ConstraintSolver::block(TypePackId target, const Constraint* constraint) +{ + block_(target, constraint); +} + +void ConstraintSolver::unblock_(BlockedConstraintId progressed) +{ + auto it = blocked.find(progressed); + if (it == blocked.end()) + return; + + // unblocked should contain a value always, because of the above check + for (const Constraint* unblockedConstraint : it->second) + { + auto& count = blockedConstraints[unblockedConstraint]; + // This assertion being hit indicates that `blocked` and + // `blockedConstraints` desynchronized at some point. This is problematic + // because we rely on this count being correct to skip over blocked + // constraints. + LUAU_ASSERT(count > 0); + count -= 1; + } + + blocked.erase(it); +} + +void ConstraintSolver::unblock(const Constraint* progressed) +{ + return unblock_(progressed); +} + +void ConstraintSolver::unblock(TypeId progressed) +{ + return unblock_(progressed); +} + +void ConstraintSolver::unblock(TypePackId progressed) +{ + return unblock_(progressed); +} + +bool ConstraintSolver::isBlocked(const Constraint* constraint) +{ + auto blockedIt = blockedConstraints.find(constraint); + return blockedIt != blockedConstraints.end() && blockedIt->second > 0; +} + +void ConstraintSolver::reportErrors(const std::vector& errors) +{ + this->errors.insert(end(this->errors), begin(errors), end(errors)); +} + +void ConstraintSolver::unify(TypeId subType, TypeId superType) +{ + UnifierSharedState sharedState{&iceReporter}; + Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + + u.tryUnify(subType, superType); + u.log.commit(); + reportErrors(u.errors); +} + +void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) +{ + UnifierSharedState sharedState{&iceReporter}; + Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + + u.tryUnify(subPack, superPack); + u.log.commit(); + reportErrors(u.errors); +} + +} // namespace Luau diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 9a2259f1..f184b74e 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -179,7 +179,7 @@ declare debug: { } declare utf8: { - char: (number, ...number) -> string, + char: (...number) -> string, charpattern: string, codes: (string) -> ((string, number) -> (number, number), string, number), -- FIXME diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 1d33f131..741a35cf 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -5,6 +5,8 @@ #include "Luau/Clone.h" #include "Luau/Config.h" #include "Luau/FileResolver.h" +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintSolver.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" @@ -22,6 +24,7 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) +LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) namespace Luau { @@ -470,7 +473,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optional(this)->getSourceModule(moduleName); } +ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope) +{ + ModulePtr result = std::make_shared(); + + ConstraintGraphBuilder cgb{&result->internalTypes}; + cgb.visit(sourceModule.root); + + ConstraintSolver cs{&result->internalTypes, cgb.rootScope}; + cs.run(); + + result->scope2s = std::move(cgb.scopes); + + result->clonePublicInterface(iceHandler); + + return result; +} + // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) { diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 6591d60a..4d157e6f 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -3,6 +3,7 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/ConstraintGraphBuilder.h" #include "Luau/Normalize.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" @@ -10,11 +11,13 @@ #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" +#include "Luau/ConstraintGraphBuilder.h" // FIXME: For Scope2 TODO pull out into its own header #include LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauNormalizeFlagIsConservative); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); namespace Luau { @@ -97,38 +100,60 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) CloneState cloneState; - ScopePtr moduleScope = getModuleScope(); + ScopePtr moduleScope = FFlag::DebugLuauDeferredConstraintResolution ? nullptr : getModuleScope(); + Scope2* moduleScope2 = FFlag::DebugLuauDeferredConstraintResolution ? getModuleScope2() : nullptr; - moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, cloneState); - if (moduleScope->varargPack) - moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, cloneState); + TypePackId returnType = FFlag::DebugLuauDeferredConstraintResolution ? moduleScope2->returnType : moduleScope->returnType; + std::optional varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack; + std::unordered_map* exportedTypeBindings = + FFlag::DebugLuauDeferredConstraintResolution ? nullptr : &moduleScope->exportedTypeBindings; + + returnType = clone(returnType, interfaceTypes, cloneState); + + if (moduleScope) + { + moduleScope->returnType = returnType; + if (varargPack) + { + varargPack = clone(*varargPack, interfaceTypes, cloneState); + moduleScope->varargPack = varargPack; + } + } + else + { + LUAU_ASSERT(moduleScope2); + moduleScope2->returnType = returnType; // TODO varargPack + } if (FFlag::LuauLowerBoundsCalculation) { - normalize(moduleScope->returnType, interfaceTypes, ice); - if (moduleScope->varargPack) - normalize(*moduleScope->varargPack, interfaceTypes, ice); + normalize(returnType, interfaceTypes, ice); + if (varargPack) + normalize(*varargPack, interfaceTypes, ice); } ForceNormal forceNormal{&interfaceTypes}; - for (auto& [name, tf] : moduleScope->exportedTypeBindings) + if (exportedTypeBindings) { - tf = clone(tf, interfaceTypes, cloneState); - if (FFlag::LuauLowerBoundsCalculation) + for (auto& [name, tf] : *exportedTypeBindings) { - normalize(tf.type, interfaceTypes, ice); - - if (FFlag::LuauNormalizeFlagIsConservative) + tf = clone(tf, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) { - // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables - // won't be marked normal. If the types aren't normal by now, they never will be. - forceNormal.traverse(tf.type); + normalize(tf.type, interfaceTypes, ice); + + if (FFlag::LuauNormalizeFlagIsConservative) + { + // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables + // won't be marked normal. If the types aren't normal by now, they never will be. + forceNormal.traverse(tf.type); + } } } } - for (TypeId ty : moduleScope->returnType) + for (TypeId ty : returnType) { if (get(follow(ty))) { @@ -155,4 +180,10 @@ ScopePtr Module::getModuleScope() const return scopes.front().second; } +Scope2* Module::getModuleScope2() const +{ + LUAU_ASSERT(!scope2s.empty()); + return scope2s.front().second.get(); +} + } // namespace Luau diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index fb31df1e..11403be5 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -16,6 +16,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); +LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false); namespace Luau { @@ -231,11 +232,30 @@ struct Replacer : Substitution TypeId smartClone(TypeId t) { - std::optional res = replace(t); - LUAU_ASSERT(res.has_value()); // TODO think about this - if (*res == t) - return clone(t); - return *res; + if (FFlag::LuauReplaceReplacer) + { + // The new smartClone is just a memoized clone() + // TODO: Remove the Substitution base class and all other methods from this struct. + // Add DenseHashMap newTypes; + t = log->follow(t); + TypeId* res = newTypes.find(t); + if (res) + return *res; + + TypeId result = shallowClone(t, *arena, TxnLog::empty()); + newTypes[t] = result; + newTypes[result] = result; + + return result; + } + else + { + std::optional res = replace(t); + LUAU_ASSERT(res.has_value()); // TODO think about this + if (*res == t) + return clone(t); + return *res; + } } }; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index c0f677d7..8f2cc8e3 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -3,8 +3,10 @@ #include "Luau/Quantify.h" #include "Luau/VisitTypeVar.h" +#include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header -LUAU_FASTFLAG(LuauAlwaysQuantify) +LUAU_FASTFLAG(LuauAlwaysQuantify); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); namespace Luau { @@ -14,12 +16,20 @@ struct Quantifier final : TypeVarOnceVisitor TypeLevel level; std::vector generics; std::vector genericPacks; + Scope2* scope = nullptr; bool seenGenericType = false; bool seenMutableType = false; explicit Quantifier(TypeLevel level) : level(level) { + LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); + } + + explicit Quantifier(Scope2* scope) + : scope(scope) + { + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); } void cycle(TypeId) override {} @@ -57,14 +67,31 @@ struct Quantifier final : TypeVarOnceVisitor return visit(tp, ftp); } + /// @return true if outer encloses inner + bool subsumes(Scope2* outer, Scope2* inner) + { + while (inner) + { + if (inner == outer) + return true; + inner = inner->parent; + } + + return false; + } + bool visit(TypeId ty, const FreeTypeVar& ftv) override { seenMutableType = true; - if (!level.subsumes(ftv.level)) + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ftv.scope) : !level.subsumes(ftv.level)) return false; - *asMutable(ty) = GenericTypeVar{level}; + if (FFlag::DebugLuauDeferredConstraintResolution) + *asMutable(ty) = GenericTypeVar{scope}; + else + *asMutable(ty) = GenericTypeVar{level}; + generics.push_back(ty); return false; @@ -83,7 +110,7 @@ struct Quantifier final : TypeVarOnceVisitor if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) return false; - if (!level.subsumes(ttv.level)) + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level)) { if (ttv.state == TableState::Unsealed) seenMutableType = true; @@ -107,7 +134,7 @@ struct Quantifier final : TypeVarOnceVisitor { seenMutableType = true; - if (!level.subsumes(ftp.level)) + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ftp.scope) : !level.subsumes(ftp.level)) return false; *asMutable(tp) = GenericTypePack{level}; @@ -136,6 +163,32 @@ void quantify(TypeId ty, TypeLevel level) if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) ftv->hasNoGenerics = true; + + ftv->generalized = true; +} + +void quantify(TypeId ty, Scope2* scope) +{ + Quantifier q{scope}; + q.traverse(ty); + + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + if (FFlag::LuauAlwaysQuantify) + { + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + } + else + { + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; + } + + if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; + + ftv->generalized = true; } } // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index e40bedb0..50c516db 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -2,6 +2,7 @@ #include "Luau/Substitution.h" #include "Luau/Common.h" +#include "Luau/Clone.h" #include "Luau/TxnLog.h" #include @@ -362,63 +363,7 @@ std::optional Substitution::substitute(TypePackId tp) TypeId Substitution::clone(TypeId ty) { - ty = log->follow(ty); - - TypeId result = ty; - - if (auto pty = log->pending(ty)) - ty = &pty->pending; - - if (const FunctionTypeVar* ftv = get(ty)) - { - FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.generics = ftv->generics; - clone.genericPacks = ftv->genericPacks; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - result = addType(std::move(clone)); - } - else if (const TableTypeVar* ttv = get(ty)) - { - LUAU_ASSERT(!ttv->boundTo); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - clone.tags = ttv->tags; - result = addType(std::move(clone)); - } - else if (const MetatableTypeVar* mtv = get(ty)) - { - MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; - clone.syntheticName = mtv->syntheticName; - result = addType(std::move(clone)); - } - else if (const UnionTypeVar* utv = get(ty)) - { - UnionTypeVar clone; - clone.options = utv->options; - result = addType(std::move(clone)); - } - else if (const IntersectionTypeVar* itv = get(ty)) - { - IntersectionTypeVar clone; - clone.parts = itv->parts; - result = addType(std::move(clone)); - } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - ConstrainedTypeVar clone{ctv->level, ctv->parts}; - result = addType(std::move(clone)); - } - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; + return shallowClone(ty, *arena, log); } TypePackId Substitution::clone(TypePackId tp) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index a4a3ec49..8490350d 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) * Fair warning: Setting this will break a lot of Luau unit tests. */ LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) -LUAU_FASTFLAGVARIABLE(LuauDocFuncParameters, false) namespace Luau { @@ -1196,65 +1195,38 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp auto argPackIter = begin(ftv.argTypes); bool first = true; - if (FFlag::LuauDocFuncParameters) + size_t idx = 0; + while (argPackIter != end(ftv.argTypes)) { - size_t idx = 0; - while (argPackIter != end(ftv.argTypes)) + // ftv takes a self parameter as the first argument, skip it if specified in option + if (idx == 0 && ftv.hasSelf && opts.hideFunctionSelfArgument) { - // ftv takes a self parameter as the first argument, skip it if specified in option - if (idx == 0 && ftv.hasSelf && opts.hideFunctionSelfArgument) - { - ++argPackIter; - ++idx; - continue; - } - - if (!first) - state.emit(", "); - first = false; - - // We don't respect opts.functionTypeArguments - if (idx < opts.namedFunctionOverrideArgNames.size()) - { - state.emit(opts.namedFunctionOverrideArgNames[idx] + ": "); - } - else if (idx < ftv.argNames.size() && ftv.argNames[idx]) - { - state.emit(ftv.argNames[idx]->name + ": "); - } - else - { - state.emit("_: "); - } - tvs.stringify(*argPackIter); - ++argPackIter; ++idx; + continue; } - } - else - { - auto argNameIter = ftv.argNames.begin(); - while (argPackIter != end(ftv.argTypes)) + + if (!first) + state.emit(", "); + first = false; + + // We don't respect opts.functionTypeArguments + if (idx < opts.namedFunctionOverrideArgNames.size()) { - if (!first) - state.emit(", "); - first = false; - - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (argNameIter != ftv.argNames.end()) - { - state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); - ++argNameIter; - } - else - { - state.emit("_: "); - } - - tvs.stringify(*argPackIter); - ++argPackIter; + state.emit(opts.namedFunctionOverrideArgNames[idx] + ": "); } + else if (idx < ftv.argNames.size() && ftv.argNames[idx]) + { + state.emit(ftv.argNames[idx]->name + ": "); + } + else + { + state.emit("_: "); + } + tvs.stringify(*argPackIter); + + ++argPackIter; + ++idx; } if (argPackIter.tail()) @@ -1337,4 +1309,55 @@ std::string generateName(size_t i) return n; } +std::string toString(const Constraint& c, ToStringOptions& opts) +{ + if (const SubtypeConstraint* sc = Luau::get_if(&c.c)) + { + ToStringResult subStr = toStringDetailed(sc->subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(sc->superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if (const PackSubtypeConstraint* psc = Luau::get_if(&c.c)) + { + ToStringResult subStr = toStringDetailed(psc->subPack, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(psc->superPack, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if (const GeneralizationConstraint* gc = Luau::get_if(&c.c)) + { + ToStringResult subStr = toStringDetailed(gc->subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(gc->superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ gen " + superStr.name; + } + else if (const InstantiationConstraint* ic = Luau::get_if(&c.c)) + { + ToStringResult subStr = toStringDetailed(ic->subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(ic->superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ inst " + superStr.name; + } + else + { + LUAU_ASSERT(false); + return ""; + } +} + +std::string dump(const Constraint& c) +{ + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string s = toString(c, opts); + printf("%s\n", s.c_str()); + return s; +} + } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 11813c76..4931bc59 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -30,7 +30,6 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) @@ -1182,12 +1181,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(varTy, var, forin.location); if (!get(iterTy) && !get(iterTy) && !get(iterTy)) - { - if (FFlag::LuauDoNotRelyOnNextBinding) - reportError(firstValue->location, CannotCallNonFunction{iterTy}); - else - reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); - } + reportError(firstValue->location, CannotCallNonFunction{iterTy}); return check(loopScope, *forin.body); } @@ -3714,7 +3708,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A { retPack = freshTypePack(free->level); TypePackId freshArgPack = freshTypePack(free->level); - *asMutable(actualFunctionType) = FunctionTypeVar(free->level, freshArgPack, retPack); + asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); } else retPack = freshTypePack(scope->level); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 2355dab2..12cbed91 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -24,7 +24,6 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) -LUAU_FASTFLAGVARIABLE(LuauClassDefinitionModuleInError, false) namespace Luau { @@ -302,7 +301,7 @@ std::optional getDefinitionModuleName(TypeId type) if (ftv->definition) return ftv->definition->definitionModuleName; } - else if (auto ctv = get(type); ctv && FFlag::LuauClassDefinitionModuleInError) + else if (auto ctv = get(type)) { if (!ctv->definitionModuleName.empty()) return ctv->definitionModuleName; @@ -724,7 +723,7 @@ TypeId SingletonTypes::makeStringMetatable() TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionTypeVar{arena->addTypePack(TypePack{{numberType}, numberVariadicList}), arena->addTypePack({stringType})})}}, + {"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}}, {"find", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber, optionalBoolean}, {}, {optionalNumber, optionalNumber})}}, {"format", {formatFn}}, // FIXME {"gmatch", {gmatchFunc}}, diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index dc554664..fe878358 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -12,12 +12,16 @@ Free::Free(TypeLevel level) { } +Free::Free(Scope2* scope) + : scope(scope) +{ +} + int Free::nextIndex = 0; Generic::Generic() : index(++nextIndex) , name("g" + std::to_string(index)) - , explicitName(false) { } @@ -25,7 +29,6 @@ Generic::Generic(TypeLevel level) : index(++nextIndex) , level(level) , name("g" + std::to_string(index)) - , explicitName(false) { } @@ -36,6 +39,12 @@ Generic::Generic(const Name& name) { } +Generic::Generic(Scope2* scope) + : index(++nextIndex) + , scope(scope) +{ +} + Generic::Generic(TypeLevel level, const Name& name) : index(++nextIndex) , level(level) diff --git a/Sources.cmake b/Sources.cmake index 82993349..99007e89 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -65,9 +65,12 @@ target_sources(Luau.CodeGen PRIVATE target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h + Analysis/include/Luau/NotNull.h Analysis/include/Luau/BuiltinDefinitions.h - Analysis/include/Luau/Config.h Analysis/include/Luau/Clone.h + Analysis/include/Luau/Config.h + Analysis/include/Luau/ConstraintGraphBuilder.h + Analysis/include/Luau/ConstraintSolver.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h @@ -108,8 +111,10 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/AstQuery.cpp Analysis/src/Autocomplete.cpp Analysis/src/BuiltinDefinitions.cpp - Analysis/src/Config.cpp Analysis/src/Clone.cpp + Analysis/src/Config.cpp + Analysis/src/ConstraintGraphBuilder.cpp + Analysis/src/ConstraintSolver.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp Analysis/src/Instantiation.cpp @@ -240,6 +245,7 @@ if(TARGET Luau.UnitTest) tests/AstQuery.test.cpp tests/AstVisitor.test.cpp tests/Autocomplete.test.cpp + tests/NotNull.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp tests/Config.test.cpp @@ -252,6 +258,8 @@ if(TARGET Luau.UnitTest) tests/Module.test.cpp tests/NonstrictMode.test.cpp tests/Normalize.test.cpp + tests/ConstraintGraphBuilder.test.cpp + tests/ConstraintSolver.test.cpp tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/RuntimeLimits.test.cpp diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f3be64b8..f86371da 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,8 +14,6 @@ #include -LUAU_FASTFLAG(LuauGcWorkTrackFix) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -488,11 +486,11 @@ void* lua_touserdata(lua_State* L, int idx) { StkId o = index2addr(L, idx); if (ttisuserdata(o)) - return uvalue(o)->data; + return uvalue(o)->data; else if (ttislightuserdata(o)) - return pvalue(o); + return pvalue(o); else - return NULL; + return NULL; } void* lua_touserdatatagged(lua_State* L, int idx, int tag) @@ -1054,7 +1052,6 @@ int lua_gc(lua_State* L, int what, int data) } case LUA_GCSTEP: { - size_t prevthreshold = g->GCthreshold; size_t amount = (cast_to(size_t, data) << 10); ptrdiff_t oldcredit = g->gcstate == GCSpause ? 0 : g->GCthreshold - g->totalbytes; @@ -1064,8 +1061,6 @@ int lua_gc(lua_State* L, int what, int data) else g->GCthreshold = 0; - bool waspaused = g->gcstate == GCSpause; - #ifdef LUAI_GCMETRICS double startmarktime = g->gcmetrics.currcycle.marktime; double startsweeptime = g->gcmetrics.currcycle.sweeptime; @@ -1078,7 +1073,7 @@ int lua_gc(lua_State* L, int what, int data) { size_t stepsize = luaC_step(L, false); - actualwork += FFlag::LuauGcWorkTrackFix ? stepsize : g->gcstepsize; + actualwork += stepsize; if (g->gcstate == GCSpause) { /* end of cycle? */ @@ -1114,20 +1109,9 @@ int lua_gc(lua_State* L, int what, int data) // if cycle hasn't finished, advance threshold forward for the amount of extra work performed if (g->gcstate != GCSpause) { - if (FFlag::LuauGcWorkTrackFix) - { - // if a new cycle was triggered by explicit step, old 'credit' of GC work is 0 - ptrdiff_t newthreshold = g->totalbytes + actualwork + oldcredit; - g->GCthreshold = newthreshold < 0 ? 0 : newthreshold; - } - else - { - // if a new cycle was triggered by explicit step, we ignore old threshold as that shows an incorrect 'credit' of GC work - if (waspaused) - g->GCthreshold = g->totalbytes + actualwork; - else - g->GCthreshold = prevthreshold + actualwork; - } + // if a new cycle was triggered by explicit step, old 'credit' of GC work is 0 + ptrdiff_t newthreshold = g->totalbytes + actualwork + oldcredit; + g->GCthreshold = newthreshold < 0 ? 0 : newthreshold; } break; } diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 4cab746a..a71fce52 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -213,7 +213,7 @@ CallInfo* luaD_growCI(lua_State* L) return ++L->ci; } -void luaD_checkCstack(lua_State *L) +void luaD_checkCstack(lua_State* L) { if (L->nCcalls == LUAI_MAXCCALLS) luaG_runerror(L, "C stack overflow"); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index e7b73fe7..70b4dbf9 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,10 +13,7 @@ #include -LUAU_FASTFLAGVARIABLE(LuauGcWorkTrackFix, false) -LUAU_FASTFLAGVARIABLE(LuauGcSweepCostFix, false) - -#define GC_SWEEPPAGESTEPCOST (FFlag::LuauGcSweepCostFix ? 16 : 4) +#define GC_SWEEPPAGESTEPCOST 16 #define GC_INTERRUPT(state) \ { \ @@ -881,7 +878,7 @@ size_t luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - int lim = FFlag::LuauGcWorkTrackFix ? g->gcstepsize * g->gcstepmul / 100 : (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ + int lim = g->gcstepsize * g->gcstepmul / 100; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -927,10 +924,10 @@ size_t luaC_step(lua_State* L, bool assist) } else { - g->GCthreshold = g->totalbytes + (FFlag::LuauGcWorkTrackFix ? actualstepsize : g->gcstepsize); + g->GCthreshold = g->totalbytes + actualstepsize; // compensate if GC is "behind schedule" (has some debt to pay) - if (FFlag::LuauGcWorkTrackFix ? g->GCthreshold >= debt : g->GCthreshold > debt) + if (g->GCthreshold >= debt) g->GCthreshold -= debt; } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index f723e0d1..20139650 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4471,7 +4471,6 @@ end RETURN R0 0 RETURN R0 0 )"); - } TEST_CASE("LoopUnrollNestedClosure") diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 7b4e83ba..f7f2b4ac 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -242,11 +242,14 @@ TEST_CASE("Math") TEST_CASE("Table") { runConformance("nextvar.lua", [](lua_State* L) { - lua_pushcfunction(L, [](lua_State* L) { - unsigned v = luaL_checkunsigned(L, 1); - lua_pushlightuserdata(L, reinterpret_cast(uintptr_t(v))); - return 1; - }, "makelud"); + lua_pushcfunction( + L, + [](lua_State* L) { + unsigned v = luaL_checkunsigned(L, 1); + lua_pushlightuserdata(L, reinterpret_cast(uintptr_t(v))); + return 1; + }, + "makelud"); lua_setglobal(L, "makelud"); }); } @@ -1150,129 +1153,171 @@ TEST_CASE("Userdata") gInt64MT = lua_ref(L, -1); // __index - lua_pushcfunction(L, [](lua_State* L) { - void* p = lua_touserdatatagged(L, 1, kInt64Tag); - if (!p) - luaL_typeerror(L, 1, "int64"); + lua_pushcfunction( + L, + [](lua_State* L) { + void* p = lua_touserdatatagged(L, 1, kInt64Tag); + if (!p) + luaL_typeerror(L, 1, "int64"); - const char* name = luaL_checkstring(L, 2); + const char* name = luaL_checkstring(L, 2); - if (strcmp(name, "value") == 0) - { - lua_pushnumber(L, double(*static_cast(p))); - return 1; - } + if (strcmp(name, "value") == 0) + { + lua_pushnumber(L, double(*static_cast(p))); + return 1; + } - luaL_error(L, "unknown field %s", name); - }, nullptr); + luaL_error(L, "unknown field %s", name); + }, + nullptr); lua_setfield(L, -2, "__index"); // __newindex - lua_pushcfunction(L, [](lua_State* L) { - void* p = lua_touserdatatagged(L, 1, kInt64Tag); - if (!p) - luaL_typeerror(L, 1, "int64"); + lua_pushcfunction( + L, + [](lua_State* L) { + void* p = lua_touserdatatagged(L, 1, kInt64Tag); + if (!p) + luaL_typeerror(L, 1, "int64"); - const char* name = luaL_checkstring(L, 2); + const char* name = luaL_checkstring(L, 2); - if (strcmp(name, "value") == 0) - { - double value = luaL_checknumber(L, 3); - *static_cast(p) = int64_t(value); - return 0; - } + if (strcmp(name, "value") == 0) + { + double value = luaL_checknumber(L, 3); + *static_cast(p) = int64_t(value); + return 0; + } - luaL_error(L, "unknown field %s", name); - }, nullptr); + luaL_error(L, "unknown field %s", name); + }, + nullptr); lua_setfield(L, -2, "__newindex"); // __eq - lua_pushcfunction(L, [](lua_State* L) { - lua_pushboolean(L, getInt64(L, 1) == getInt64(L, 2)); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + lua_pushboolean(L, getInt64(L, 1) == getInt64(L, 2)); + return 1; + }, + nullptr); lua_setfield(L, -2, "__eq"); // __lt - lua_pushcfunction(L, [](lua_State* L) { - lua_pushboolean(L, getInt64(L, 1) < getInt64(L, 2)); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + lua_pushboolean(L, getInt64(L, 1) < getInt64(L, 2)); + return 1; + }, + nullptr); lua_setfield(L, -2, "__lt"); // __le - lua_pushcfunction(L, [](lua_State* L) { - lua_pushboolean(L, getInt64(L, 1) <= getInt64(L, 2)); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + lua_pushboolean(L, getInt64(L, 1) <= getInt64(L, 2)); + return 1; + }, + nullptr); lua_setfield(L, -2, "__le"); // __add - lua_pushcfunction(L, [](lua_State* L) { - pushInt64(L, getInt64(L, 1) + getInt64(L, 2)); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, getInt64(L, 1) + getInt64(L, 2)); + return 1; + }, + nullptr); lua_setfield(L, -2, "__add"); // __sub - lua_pushcfunction(L, [](lua_State* L) { - pushInt64(L, getInt64(L, 1) - getInt64(L, 2)); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, getInt64(L, 1) - getInt64(L, 2)); + return 1; + }, + nullptr); lua_setfield(L, -2, "__sub"); // __mul - lua_pushcfunction(L, [](lua_State* L) { - pushInt64(L, getInt64(L, 1) * getInt64(L, 2)); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, getInt64(L, 1) * getInt64(L, 2)); + return 1; + }, + nullptr); lua_setfield(L, -2, "__mul"); // __div - lua_pushcfunction(L, [](lua_State* L) { - // ideally we'd guard against 0 but it's a test so eh - pushInt64(L, getInt64(L, 1) / getInt64(L, 2)); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + // ideally we'd guard against 0 but it's a test so eh + pushInt64(L, getInt64(L, 1) / getInt64(L, 2)); + return 1; + }, + nullptr); lua_setfield(L, -2, "__div"); // __mod - lua_pushcfunction(L, [](lua_State* L) { - // ideally we'd guard against 0 and INT64_MIN but it's a test so eh - pushInt64(L, getInt64(L, 1) % getInt64(L, 2)); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + // ideally we'd guard against 0 and INT64_MIN but it's a test so eh + pushInt64(L, getInt64(L, 1) % getInt64(L, 2)); + return 1; + }, + nullptr); lua_setfield(L, -2, "__mod"); // __pow - lua_pushcfunction(L, [](lua_State* L) { - pushInt64(L, int64_t(pow(double(getInt64(L, 1)), double(getInt64(L, 2))))); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, int64_t(pow(double(getInt64(L, 1)), double(getInt64(L, 2))))); + return 1; + }, + nullptr); lua_setfield(L, -2, "__pow"); // __unm - lua_pushcfunction(L, [](lua_State* L) { - pushInt64(L, -getInt64(L, 1)); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, -getInt64(L, 1)); + return 1; + }, + nullptr); lua_setfield(L, -2, "__unm"); // __tostring - lua_pushcfunction(L, [](lua_State* L) { - int64_t value = getInt64(L, 1); - std::string str = std::to_string(value); - lua_pushlstring(L, str.c_str(), str.length()); - return 1; - }, nullptr); + lua_pushcfunction( + L, + [](lua_State* L) { + int64_t value = getInt64(L, 1); + std::string str = std::to_string(value); + lua_pushlstring(L, str.c_str(), str.length()); + return 1; + }, + nullptr); lua_setfield(L, -2, "__tostring"); // ctor - lua_pushcfunction(L, [](lua_State* L) { - double v = luaL_checknumber(L, 1); - pushInt64(L, int64_t(v)); - return 1; - }, "int64"); + lua_pushcfunction( + L, + [](lua_State* L) { + double v = luaL_checknumber(L, 1); + pushInt64(L, int64_t(v)); + return 1; + }, + "int64"); lua_setglobal(L, "int64"); }); } diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp new file mode 100644 index 00000000..ab5af4f6 --- /dev/null +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -0,0 +1,107 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" +#include "Luau/ConstraintGraphBuilder.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("ConstraintGraphBuilder"); + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world") +{ + AstStatBlock* block = parse(R"( + local a = "hello" + local b = a + )"); + + cgb.visit(block); + std::vector constraints = collectConstraints(cgb.rootScope); + + REQUIRE(2 == constraints.size()); + + ToStringOptions opts; + CHECK("a <: string" == toString(*constraints[0], opts)); + CHECK("b <: a" == toString(*constraints[1], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") +{ + AstStatBlock* block = parse(R"( + local s = "hello" + local n = 555 + local b = true + local n2 = nil + )"); + + cgb.visit(block); + std::vector constraints = collectConstraints(cgb.rootScope); + + REQUIRE(4 == constraints.size()); + + ToStringOptions opts; + CHECK("a <: string" == toString(*constraints[0], opts)); + CHECK("b <: number" == toString(*constraints[1], opts)); + CHECK("c <: boolean" == toString(*constraints[2], opts)); + CHECK("d <: nil" == toString(*constraints[3], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") +{ + AstStatBlock* block = parse(R"( + local a = "hello" + local b = a("world") + )"); + + cgb.visit(block); + std::vector constraints = collectConstraints(cgb.rootScope); + + REQUIRE(4 == constraints.size()); + + ToStringOptions opts; + CHECK("a <: string" == toString(*constraints[0], opts)); + CHECK("b ~ inst a" == toString(*constraints[1], opts)); + CHECK("(string) -> (c, d...) <: b" == toString(*constraints[2], opts)); + CHECK("e <: c" == toString(*constraints[3], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") +{ + AstStatBlock* block = parse(R"( + local function f(a) + return a + end + )"); + + cgb.visit(block); + std::vector constraints = collectConstraints(cgb.rootScope); + + REQUIRE(2 == constraints.size()); + + ToStringOptions opts; + CHECK("a ~ gen (b) -> (c...)" == toString(*constraints[0], opts)); + CHECK("b <: c..." == toString(*constraints[1], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") +{ + AstStatBlock* block = parse(R"( + local function f(a) + return f(a) + end + )"); + + cgb.visit(block); + std::vector constraints = collectConstraints(cgb.rootScope); + + REQUIRE(4 == constraints.size()); + + ToStringOptions opts; + CHECK("a ~ gen (b) -> (c...)" == toString(*constraints[0], opts)); + CHECK("d ~ inst a" == toString(*constraints[1], opts)); + CHECK("(b) -> (e, f...) <: d" == toString(*constraints[2], opts)); + CHECK("e <: c..." == toString(*constraints[3], opts)); +} + +TEST_SUITE_END(); diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp new file mode 100644 index 00000000..5959f55c --- /dev/null +++ b/tests/ConstraintSolver.test.cpp @@ -0,0 +1,87 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintSolver.h" + +using namespace Luau; + +static TypeId requireBinding(Scope2* scope, const char* name) +{ + auto b = linearSearchForBinding(scope, name); + LUAU_ASSERT(b.has_value()); + return *b; +} + +TEST_SUITE_BEGIN("ConstraintSolver"); + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") +{ + AstStatBlock* block = parse(R"( + local a = 55 + local b = a + )"); + + cgb.visit(block); + + ConstraintSolver cs{&arena, cgb.rootScope}; + + cs.run(); + + TypeId bType = requireBinding(cgb.rootScope, "b"); + + CHECK("number" == toString(bType)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") +{ + AstStatBlock* block = parse(R"( + local function id(a) + return a + end + )"); + + cgb.visit(block); + + ConstraintSolver cs{&arena, cgb.rootScope}; + + cs.run(); + + TypeId idType = requireBinding(cgb.rootScope, "id"); + + CHECK("(a) -> a" == toString(idType)); +} + +#if 1 +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") +{ + AstStatBlock* block = parse(R"( + local function a(c) + local function d(e) + return c + end + + return d + end + + local b = a(5) + )"); + + cgb.visit(block); + + ToStringOptions opts; + + ConstraintSolver cs{&arena, cgb.rootScope}; + + cs.run(); + + TypeId idType = requireBinding(cgb.rootScope, "b"); + + CHECK("(a) -> number" == toString(idType, opts)); +} +#endif + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 03f3e15c..232ec2de 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -17,6 +17,8 @@ static const char* mainModuleName = "MainModule"; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -249,7 +251,10 @@ std::optional Fixture::getType(const std::string& name) ModulePtr module = getMainModule(); REQUIRE(module); - return lookupName(module->getModuleScope(), name); + if (FFlag::DebugLuauDeferredConstraintResolution) + return linearSearchForBinding(module->getModuleScope2(), name.c_str()); + else + return lookupName(module->getModuleScope(), name); } TypeId Fixture::requireType(const std::string& name) @@ -421,6 +426,12 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); } +ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() + : Fixture() + , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} +{ +} + ModuleName fromString(std::string_view name) { return ModuleName(name); @@ -460,4 +471,27 @@ std::optional lookupName(ScopePtr scope, const std::string& name) return std::nullopt; } +std::optional linearSearchForBinding(Scope2* scope, const char* name) +{ + while (scope) + { + for (const auto& [n, ty] : scope->bindings) + { + if (n.astName() == name) + return ty; + } + + scope = scope->parent; + } + + return std::nullopt; +} + +void dump(const std::vector& constraints) +{ + ToStringOptions opts; + for (const auto& c : constraints) + printf("%s\n", toString(c, opts).c_str()); +} + } // namespace Luau diff --git a/tests/Fixture.h b/tests/Fixture.h index 901f7d42..ffcd4b9e 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Config.h" +#include "Luau/ConstraintGraphBuilder.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" @@ -156,6 +157,16 @@ struct BuiltinsFixture : Fixture BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); }; +struct ConstraintGraphBuilderFixture : Fixture +{ + TypeArena arena; + ConstraintGraphBuilder cgb{&arena}; + + ScopedFastFlag forceTheFlag; + + ConstraintGraphBuilderFixture(); +}; + ModuleName fromString(std::string_view name); template @@ -175,9 +186,12 @@ bool isInArena(TypeId t, const TypeArena& arena); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void dump(const std::string& name, TypeId ty); +void dump(const std::vector& constraints); std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) +std::optional linearSearchForBinding(Scope2* scope, const char* name); + } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 33b81be8..c0554669 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1031,8 +1031,6 @@ return false; TEST_CASE("check_without_builtin_next") { - ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; - TestFileResolver fileResolver; TestConfigResolver configResolver; Frontend frontend(&fileResolver, &configResolver); diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp new file mode 100644 index 00000000..1a323c85 --- /dev/null +++ b/tests/NotNull.test.cpp @@ -0,0 +1,116 @@ +#include "Luau/NotNull.h" + +#include "doctest.h" + +#include +#include +#include + +using Luau::NotNull; + +namespace +{ + +struct Test +{ + int x; + float y; + + static int count; + Test() + { + ++count; + } + + ~Test() + { + --count; + } +}; + +int Test::count = 0; + +} + +int foo(NotNull p) +{ + return *p; +} + +void bar(int* q) +{} + +TEST_SUITE_BEGIN("NotNull"); + +TEST_CASE("basic_stuff") +{ + NotNull a = NotNull{new int(55)}; // Does runtime test + NotNull b{new int(55)}; // As above + // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not good. + + // a = nullptr; // nope + + NotNull d = a; // No runtime test. a is known not to be null. + + int e = *d; + *d = 1; + CHECK(e == 55); + + const NotNull f = d; + *f = 5; // valid: there is a difference between const NotNull and NotNull + // f = a; // nope + + CHECK_EQ(a, d); + CHECK(a != b); + + NotNull g(a); + CHECK(g == a); + + // *g = 123; // nope + + (void)f; + + NotNull t{new Test}; + t->x = 5; + t->y = 3.14f; + + const NotNull u = t; + // u->x = 44; // nope + int v = u->x; + CHECK(v == 5); + + bar(a); + + // a++; // nope + // a[41]; // nope + // a + 41; // nope + // a - 41; // nope + + delete a; + delete b; + delete t; + + CHECK_EQ(0, Test::count); +} + +TEST_CASE("hashable") +{ + std::unordered_map, const char*> map; + NotNull a{new int(8)}; + NotNull b{new int(10)}; + + std::string hello = "hello"; + std::string world = "world"; + + map[a] = hello.c_str(); + map[b] = world.c_str(); + + CHECK_EQ(2, map.size()); + CHECK_EQ(hello.c_str(), map[a]); + CHECK_EQ(world.c_str(), map[b]); + + delete a; + delete b; +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b854bc51..4d9fad14 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -505,7 +505,6 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function id(x) return x end )"); @@ -518,7 +517,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function map(arr, fn) local t = {} @@ -537,7 +535,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(a: number, b: string) end local function test(...: T...): U... @@ -554,7 +551,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") TEST_CASE("toStringNamedFunction_unit_f") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; TypePackVar empty{TypePack{}}; FunctionTypeVar ftv{&empty, &empty, {}, false}; CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); @@ -562,7 +558,6 @@ TEST_CASE("toStringNamedFunction_unit_f") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: a, ...): (a, a, b...) return x, x, ... @@ -577,7 +572,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): ...number return 1, 2, 3 @@ -592,7 +586,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): (string, ...number) return 'a', 1, 2, 3 @@ -607,7 +600,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local f: (number, y: number) -> number )"); @@ -620,7 +612,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_ar TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: T, g: (T) -> U)): () end @@ -636,8 +627,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; - CheckResult result = check(R"( local function test(a, b : string, ... : number) return a end )"); @@ -665,7 +654,6 @@ TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () @@ -682,7 +670,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index d90129d7..6f4191e3 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -470,8 +470,6 @@ caused by: TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") { - ScopedFastFlag luauClassDefinitionModuleInError{"LuauClassDefinitionModuleInError", true}; - CheckResult result = check(R"( local i = ChildClass.New() type ChildClass = { x: number } diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index a3cae3de..4444cd66 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -78,8 +78,6 @@ TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") { - ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; - CheckResult result = check(R"( local foo = "bar" for i, v in foo do diff --git a/tests/Variant.test.cpp b/tests/Variant.test.cpp index fcf37875..aa0731ca 100644 --- a/tests/Variant.test.cpp +++ b/tests/Variant.test.cpp @@ -13,6 +13,25 @@ struct Foo int x = 42; }; +struct Bar +{ + explicit Bar(int x) + : prop(x * 2) + { + ++count; + } + + ~Bar() + { + --count; + } + + int prop; + static int count; +}; + +int Bar::count = 0; + TEST_SUITE_BEGIN("Variant"); TEST_CASE("DefaultCtor") @@ -46,6 +65,29 @@ TEST_CASE("Create") CHECK(get_if(&v3)->x == 3); } +TEST_CASE("Emplace") +{ + { + Variant v1; + + CHECK(0 == Bar::count); + int& i = v1.emplace(5); + CHECK(5 == i); + + CHECK(0 == Bar::count); + + CHECK(get_if(&v1) == &i); + + Bar& bar = v1.emplace(11); + CHECK(22 == bar.prop); + CHECK(1 == Bar::count); + + CHECK(get_if(&v1) == &bar); + } + + CHECK(0 == Bar::count); +} + TEST_CASE("NonPOD") { // initialize (copy) diff --git a/tools/natvis/CodeGen.natvis b/tools/natvis/CodeGen.natvis index 47ff0db1..5ff6e143 100644 --- a/tools/natvis/CodeGen.natvis +++ b/tools/natvis/CodeGen.natvis @@ -2,7 +2,7 @@ - noreg + noreg rip al @@ -36,14 +36,20 @@ - {reg} - {mem.size,en} ptr[{mem.base} + {mem.index}*{(int)mem.scale,d} + {disp}] + {base} + {memSize,en} ptr[{base} + {index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{base} + {imm}] + {memSize,en} ptr[{imm}] {imm} - reg - mem + base imm - disp + memSize + base + index + scale + imm