Sync to upstream/release/534 (#569)

This commit is contained in:
Arseny Kapoulkine 2022-06-30 16:52:43 -07:00 committed by GitHub
parent fc763650d3
commit 2daa6497a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
54 changed files with 1714 additions and 653 deletions

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h" // Used for some of the enumerations
#include "Luau/NotNull.h"
#include "Luau/Variant.h"
@ -47,6 +48,21 @@ struct InstantiationConstraint
TypeId superType;
};
struct UnaryConstraint
{
AstExprUnary::Op op;
TypeId operandType;
TypeId resultType;
};
struct BinaryConstraint
{
AstExprBinary::Op op;
TypeId leftType;
TypeId rightType;
TypeId resultType;
};
// name(namedType) = name
struct NameConstraint
{
@ -54,7 +70,8 @@ struct NameConstraint
std::string name;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, NameConstraint>;
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, NameConstraint>;
using ConstraintPtr = std::unique_ptr<struct Constraint>;
struct Constraint

View file

@ -25,9 +25,12 @@ struct ConstraintGraphBuilder
// scope pointers; the scopes themselves borrow pointers to other scopes to
// define the scope hierarchy.
std::vector<std::pair<Location, std::unique_ptr<Scope2>>> scopes;
ModuleName moduleName;
SingletonTypes& singletonTypes;
TypeArena* const arena;
const NotNull<TypeArena> arena;
// The root scope of the module we're generating constraints for.
// This is null when the CGB is initially constructed.
Scope2* rootScope;
// A mapping of AST node to TypeId.
DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
@ -39,40 +42,50 @@ struct ConstraintGraphBuilder
// Type packs resolved from type annotations. Analogous to astTypePacks.
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
explicit ConstraintGraphBuilder(TypeArena* arena);
int recursionCount = 0;
// It is pretty uncommon for constraint generation to itself produce errors, but it can happen.
std::vector<TypeError> errors;
// Occasionally constraint generation needs to produce an ICE.
const NotNull<InternalErrorReporter> ice;
NotNull<Scope2> globalScope;
ConstraintGraphBuilder(const ModuleName& moduleName, TypeArena* arena, NotNull<InternalErrorReporter> ice, NotNull<Scope2> globalScope);
/**
* Fabricates a new free type belonging to a given scope.
* @param scope the scope the free type belongs to. Must not be null.
* @param scope the scope the free type belongs to.
*/
TypeId freshType(Scope2* scope);
TypeId freshType(NotNull<Scope2> scope);
/**
* Fabricates a new free type pack belonging to a given scope.
* @param scope the scope the free type pack belongs to. Must not be null.
* @param scope the scope the free type pack belongs to.
*/
TypePackId freshTypePack(Scope2* scope);
TypePackId freshTypePack(NotNull<Scope2> scope);
/**
* Fabricates a scope that is a child of another scope.
* @param location the lexical extent of the scope in the source code.
* @param parent the parent scope of the new scope. Must not be null.
*/
Scope2* childScope(Location location, Scope2* parent);
NotNull<Scope2> childScope(Location location, NotNull<Scope2> parent);
/**
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param scope the scope to add the constraint to.
* @param cv the constraint variant to add.
*/
void addConstraint(Scope2* scope, ConstraintV cv);
void addConstraint(NotNull<Scope2> scope, ConstraintV cv);
/**
* Adds a constraint to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param c the constraint to add.
*/
void addConstraint(Scope2* scope, std::unique_ptr<Constraint> c);
void addConstraint(NotNull<Scope2> scope, std::unique_ptr<Constraint> c);
/**
* The entry point to the ConstraintGraphBuilder. This will construct a set
@ -81,20 +94,22 @@ struct ConstraintGraphBuilder
*/
void visit(AstStatBlock* block);
void visit(Scope2* scope, AstStat* stat);
void visit(Scope2* scope, AstStatBlock* block);
void visit(Scope2* scope, AstStatLocal* local);
void visit(Scope2* scope, AstStatLocalFunction* function);
void visit(Scope2* scope, AstStatFunction* function);
void visit(Scope2* scope, AstStatReturn* ret);
void visit(Scope2* scope, AstStatAssign* assign);
void visit(Scope2* scope, AstStatIf* ifStatement);
void visit(Scope2* scope, AstStatTypeAlias* alias);
void visitBlockWithoutChildScope(NotNull<Scope2> scope, AstStatBlock* block);
TypePackId checkExprList(Scope2* scope, const AstArray<AstExpr*>& exprs);
void visit(NotNull<Scope2> scope, AstStat* stat);
void visit(NotNull<Scope2> scope, AstStatBlock* block);
void visit(NotNull<Scope2> scope, AstStatLocal* local);
void visit(NotNull<Scope2> scope, AstStatLocalFunction* function);
void visit(NotNull<Scope2> scope, AstStatFunction* function);
void visit(NotNull<Scope2> scope, AstStatReturn* ret);
void visit(NotNull<Scope2> scope, AstStatAssign* assign);
void visit(NotNull<Scope2> scope, AstStatIf* ifStatement);
void visit(NotNull<Scope2> scope, AstStatTypeAlias* alias);
TypePackId checkPack(Scope2* scope, AstArray<AstExpr*> exprs);
TypePackId checkPack(Scope2* scope, AstExpr* expr);
TypePackId checkExprList(NotNull<Scope2> scope, const AstArray<AstExpr*>& exprs);
TypePackId checkPack(NotNull<Scope2> scope, AstArray<AstExpr*> exprs);
TypePackId checkPack(NotNull<Scope2> scope, AstExpr* expr);
/**
* Checks an expression that is expected to evaluate to one type.
@ -102,19 +117,35 @@ struct ConstraintGraphBuilder
* @param expr the expression to check.
* @return the type of the expression.
*/
TypeId check(Scope2* scope, AstExpr* expr);
TypeId check(NotNull<Scope2> scope, AstExpr* expr);
TypeId checkExprTable(Scope2* scope, AstExprTable* expr);
TypeId check(Scope2* scope, AstExprIndexName* indexName);
TypeId checkExprTable(NotNull<Scope2> scope, AstExprTable* expr);
TypeId check(NotNull<Scope2> scope, AstExprIndexName* indexName);
TypeId check(NotNull<Scope2> scope, AstExprIndexExpr* indexExpr);
TypeId check(NotNull<Scope2> scope, AstExprUnary* unary);
TypeId check(NotNull<Scope2> scope, AstExprBinary* binary);
std::pair<TypeId, Scope2*> checkFunctionSignature(Scope2* parent, AstExprFunction* fn);
struct FunctionSignature
{
// The type of the function.
TypeId signature;
// The scope that encompasses the function's signature. May be nullptr
// if there was no need for a signature scope (the function has no
// generics).
Scope2* signatureScope;
// The scope that encompasses the function's body. Is a child scope of
// signatureScope, if present.
NotNull<Scope2> bodyScope;
};
FunctionSignature checkFunctionSignature(NotNull<Scope2> parent, AstExprFunction* fn);
/**
* Checks the body of a function expression.
* @param scope the interior scope of the body of the function.
* @param fn the function expression to check.
*/
void checkFunctionBody(Scope2* scope, AstExprFunction* fn);
void checkFunctionBody(NotNull<Scope2> scope, AstExprFunction* fn);
/**
* Resolves a type from its AST annotation.
@ -122,7 +153,7 @@ struct ConstraintGraphBuilder
* @param ty the AST annotation to resolve.
* @return the type of the AST annotation.
**/
TypeId resolveType(Scope2* scope, AstType* ty);
TypeId resolveType(NotNull<Scope2> scope, AstType* ty);
/**
* Resolves a type pack from its AST annotation.
@ -130,9 +161,25 @@ struct ConstraintGraphBuilder
* @param tp the AST annotation to resolve.
* @return the type pack of the AST annotation.
**/
TypePackId resolveTypePack(Scope2* scope, AstTypePack* tp);
TypePackId resolveTypePack(NotNull<Scope2> scope, AstTypePack* tp);
TypePackId resolveTypePack(Scope2* scope, const AstTypeList& list);
TypePackId resolveTypePack(NotNull<Scope2> scope, const AstTypeList& list);
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(NotNull<Scope2> scope, AstArray<AstGenericType> generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(NotNull<Scope2> scope, AstArray<AstGenericTypePack> packs);
TypeId flattenPack(NotNull<Scope2> scope, Location location, TypePackId tp);
void reportError(Location location, TypeErrorData err);
void reportCodeTooComplex(Location location);
/** Scan the program for global definitions.
*
* ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for
* real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an
* initial scan of the AST and note what globals are defined.
*/
void prepopulateGlobalScope(NotNull<Scope2> globalScope, AstStatBlock* program);
};
/**
@ -145,6 +192,6 @@ struct ConstraintGraphBuilder
* @return a list of pointers to constraints contained within the scope graph.
* None of these pointers should be null.
*/
std::vector<NotNull<Constraint>> collectConstraints(Scope2* rootScope);
std::vector<NotNull<Constraint>> collectConstraints(NotNull<Scope2> rootScope);
} // namespace Luau

View file

@ -25,7 +25,7 @@ struct ConstraintSolver
// is important to not add elements to this vector, lest the underlying
// storage that we retain pointers to be mutated underneath us.
const std::vector<NotNull<Constraint>> constraints;
Scope2* rootScope;
NotNull<Scope2> rootScope;
// This includes every constraint that has not been fully solved.
// A constraint can be both blocked and unsolved, for instance.
@ -40,7 +40,7 @@ struct ConstraintSolver
ConstraintSolverLogger logger;
explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope);
explicit ConstraintSolver(TypeArena* arena, NotNull<Scope2> rootScope);
/**
* Attempts to dispatch all pending constraints and reach a type solution
@ -50,11 +50,17 @@ struct ConstraintSolver
bool done();
/** Attempt to dispatch a constraint. Returns true if it was successful.
* If tryDispatch() returns false, the constraint remains in the unsolved set and will be retried later.
*/
bool tryDispatch(NotNull<const Constraint> c, bool force);
bool tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const InstantiationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const UnaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const BinaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
@ -115,6 +121,6 @@ private:
void unblock_(BlockedConstraintId progressed);
};
void dump(Scope2* rootScope, struct ToStringOptions& opts);
void dump(NotNull<Scope2> rootScope, struct ToStringOptions& opts);
} // namespace Luau

View file

@ -369,7 +369,8 @@ struct InternalErrorReporter
[[noreturn]] void ice(const std::string& message);
};
class InternalCompilerError : public std::exception {
class InternalCompilerError : public std::exception
{
public:
explicit InternalCompilerError(const std::string& message, const std::string& moduleName)
: message(message)

View file

@ -5,6 +5,7 @@
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
#include "Luau/RequireTracer.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/Variant.h"
@ -158,6 +159,8 @@ struct Frontend
void registerBuiltinDefinition(const std::string& name, std::function<void(TypeChecker&, ScopePtr)>);
void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName);
NotNull<Scope2> getGlobalScope2();
private:
ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope);
@ -173,6 +176,8 @@ private:
std::unordered_map<std::string, ScopePtr> environments;
std::unordered_map<std::string, std::function<void(TypeChecker&, ScopePtr)>> builtinDefinitions;
std::unique_ptr<Scope2> globalScope2;
public:
FileResolver* fileResolver;
FrontendModuleResolver moduleResolver;

View file

@ -26,7 +26,7 @@ namespace Luau
* The explicit delete statement is permitted (but not recommended) on a
* NotNull<T> through this implicit conversion.
*/
template <typename T>
template<typename T>
struct NotNull
{
explicit NotNull(T* t)
@ -38,10 +38,11 @@ struct NotNull
explicit NotNull(std::nullptr_t) = delete;
void operator=(std::nullptr_t) = delete;
template <typename U>
template<typename U>
NotNull(NotNull<U> other)
: ptr(other.get())
{}
{
}
operator T*() const noexcept
{
@ -72,12 +73,13 @@ private:
T* ptr;
};
}
} // namespace Luau
namespace std
{
template <typename T> struct hash<Luau::NotNull<T>>
template<typename T>
struct hash<Luau::NotNull<T>>
{
size_t operator()(const Luau::NotNull<T>& p) const
{
@ -85,4 +87,4 @@ template <typename T> struct hash<Luau::NotNull<T>>
}
};
}
} // namespace std

View file

@ -3,6 +3,7 @@
#include "Luau/Constraint.h"
#include "Luau/Location.h"
#include "Luau/NotNull.h"
#include "Luau/TypeVar.h"
#include <unordered_map>
@ -71,15 +72,18 @@ struct Scope2
// is the module-level scope).
Scope2* parent = nullptr;
// All the children of this scope.
std::vector<Scope2*> children;
std::vector<NotNull<Scope2>> children;
std::unordered_map<Symbol, TypeId> bindings; // TODO: I think this can be a DenseHashMap
std::unordered_map<Name, TypeId> typeBindings;
std::unordered_map<Name, TypePackId> typePackBindings;
TypePackId returnType;
std::optional<TypePackId> varargPack;
// All constraints belonging to this scope.
std::vector<ConstraintPtr> constraints;
std::optional<TypeId> lookup(Symbol sym);
std::optional<TypeId> lookupTypeBinding(const Name& name);
std::optional<TypePackId> lookupTypePackBinding(const Name& name);
};
} // namespace Luau

View file

@ -34,6 +34,12 @@ struct TypeArena
TypePackId addTypePack(std::vector<TypeId> types);
TypePackId addTypePack(TypePack pack);
TypePackId addTypePack(TypePackVar pack);
template<typename T>
TypePackId addTypePack(T tp)
{
return addTypePack(TypePackVar(std::move(tp)));
}
};
void freeze(TypeArena& arena);

View file

@ -173,7 +173,7 @@ struct TypeChecker
TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level);
std::pair<TypeId, ScopePtr> checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr,
std::optional<Location> originalNameLoc, std::optional<TypeId> expectedType);
std::optional<Location> originalNameLoc, std::optional<TypeId> selfType, std::optional<TypeId> expectedType);
void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function);
void checkArgumentList(
@ -424,6 +424,8 @@ private:
* (exported, name) to properly deal with the case where the two duplicates do not have the same export status.
*/
DenseHashSet<std::pair<bool, Name>, HashBoolNamePair> duplicateTypeAliases;
std::vector<std::pair<TypeId, ScopePtr>> deferredQuantification;
};
// Unit test hook

View file

@ -357,6 +357,9 @@ struct TableTypeVar
std::optional<TypeId> boundTo;
Tags tags;
// Methods of this table that have an untyped self will use the same shared self type.
std::optional<TypeId> selfTy;
};
// Represents a metatable attached to a table typevar. Somewhat analogous to a bound typevar.

View file

@ -1,6 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/RecursionCounter.h"
#include "Luau/ToString.h"
LUAU_FASTINT(LuauCheckRecursionLimit);
#include "Luau/Scope.h"
@ -9,32 +13,33 @@ namespace Luau
const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp
ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena)
: singletonTypes(getSingletonTypes())
ConstraintGraphBuilder::ConstraintGraphBuilder(
const ModuleName& moduleName, TypeArena* arena, NotNull<InternalErrorReporter> ice, NotNull<Scope2> globalScope)
: moduleName(moduleName)
, singletonTypes(getSingletonTypes())
, arena(arena)
, rootScope(nullptr)
, ice(ice)
, globalScope(globalScope)
{
LUAU_ASSERT(arena);
}
TypeId ConstraintGraphBuilder::freshType(Scope2* scope)
TypeId ConstraintGraphBuilder::freshType(NotNull<Scope2> scope)
{
LUAU_ASSERT(scope);
return arena->addType(FreeTypeVar{scope});
}
TypePackId ConstraintGraphBuilder::freshTypePack(Scope2* scope)
TypePackId ConstraintGraphBuilder::freshTypePack(NotNull<Scope2> scope)
{
LUAU_ASSERT(scope);
FreeTypePack f{scope};
return arena->addTypePack(TypePackVar{std::move(f)});
}
Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent)
NotNull<Scope2> ConstraintGraphBuilder::childScope(Location location, NotNull<Scope2> parent)
{
LUAU_ASSERT(parent);
auto scope = std::make_unique<Scope2>();
Scope2* borrow = scope.get();
NotNull<Scope2> borrow = NotNull(scope.get());
scopes.emplace_back(location, std::move(scope));
borrow->parent = parent;
@ -44,15 +49,13 @@ Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent)
return borrow;
}
void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv)
void ConstraintGraphBuilder::addConstraint(NotNull<Scope2> scope, ConstraintV cv)
{
LUAU_ASSERT(scope);
scope->constraints.emplace_back(new Constraint{std::move(cv)});
}
void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr<Constraint> c)
void ConstraintGraphBuilder::addConstraint(NotNull<Scope2> scope, std::unique_ptr<Constraint> c)
{
LUAU_ASSERT(scope);
scope->constraints.emplace_back(std::move(c));
}
@ -62,7 +65,11 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block)
LUAU_ASSERT(rootScope == nullptr);
scopes.emplace_back(block->location, std::make_unique<Scope2>());
rootScope = scopes.back().second.get();
rootScope->returnType = freshTypePack(rootScope);
NotNull<Scope2> borrow = NotNull(rootScope);
rootScope->returnType = freshTypePack(borrow);
prepopulateGlobalScope(borrow, block);
// TODO: We should share the global scope.
rootScope->typeBindings["nil"] = singletonTypes.nilType;
@ -71,12 +78,26 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block)
rootScope->typeBindings["boolean"] = singletonTypes.booleanType;
rootScope->typeBindings["thread"] = singletonTypes.threadType;
visit(rootScope, block);
visitBlockWithoutChildScope(borrow, block);
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat)
void ConstraintGraphBuilder::visitBlockWithoutChildScope(NotNull<Scope2> scope, AstStatBlock* block)
{
LUAU_ASSERT(scope);
RecursionCounter counter{&recursionCount};
if (recursionCount >= FInt::LuauCheckRecursionLimit)
{
reportCodeTooComplex(block->location);
return;
}
for (AstStat* stat : block->body)
visit(scope, stat);
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStat* stat)
{
RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit};
if (auto s = stat->as<AstStatBlock>())
visit(scope, s);
@ -100,10 +121,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat)
LUAU_ASSERT(0);
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local)
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatLocal* local)
{
LUAU_ASSERT(scope);
std::vector<TypeId> varTypes;
for (AstLocal* local : local->vars)
@ -148,23 +167,19 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local)
}
}
void addConstraints(Constraint* constraint, Scope2* scope)
void addConstraints(Constraint* constraint, NotNull<Scope2> scope)
{
LUAU_ASSERT(scope);
scope->constraints.reserve(scope->constraints.size() + scope->constraints.size());
for (const auto& c : scope->constraints)
constraint->dependencies.push_back(NotNull{c.get()});
for (Scope2* childScope : scope->children)
for (NotNull<Scope2> childScope : scope->children)
addConstraints(constraint, childScope);
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function)
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatLocalFunction* function)
{
LUAU_ASSERT(scope);
// Local
// Global
// Dotted path
@ -172,36 +187,31 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function
TypeId functionType = nullptr;
auto ty = scope->lookup(function->name);
if (ty.has_value())
{
// TODO: This is duplicate definition of a local function. Is this allowed?
functionType = *ty;
}
else
{
LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name.
functionType = arena->addType(BlockedTypeVar{});
scope->bindings[function->name] = functionType;
}
auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func);
innerScope->bindings[function->name] = actualFunctionType;
FunctionSignature sig = checkFunctionSignature(scope, function->func);
sig.bodyScope->bindings[function->name] = sig.signature;
checkFunctionBody(innerScope, function->func);
checkFunctionBody(sig.bodyScope, function->func);
std::unique_ptr<Constraint> c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}};
addConstraints(c.get(), innerScope);
std::unique_ptr<Constraint> c{
new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}};
addConstraints(c.get(), sig.bodyScope);
addConstraint(scope, std::move(c));
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function)
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatFunction* function)
{
// Name could be AstStatLocal, AstStatGlobal, AstStatIndexName.
// With or without self
TypeId functionType = nullptr;
auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func);
FunctionSignature sig = checkFunctionSignature(scope, function->func);
if (AstExprLocal* localName = function->name->as<AstExprLocal>())
{
@ -216,7 +226,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function)
functionType = arena->addType(BlockedTypeVar{});
scope->bindings[localName->local] = functionType;
}
innerScope->bindings[localName->local] = actualFunctionType;
sig.bodyScope->bindings[localName->local] = sig.signature;
}
else if (AstExprGlobal* globalName = function->name->as<AstExprGlobal>())
{
@ -231,32 +241,48 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function)
functionType = arena->addType(BlockedTypeVar{});
rootScope->bindings[globalName->name] = functionType;
}
innerScope->bindings[globalName->name] = actualFunctionType;
sig.bodyScope->bindings[globalName->name] = sig.signature;
}
else if (AstExprIndexName* indexName = function->name->as<AstExprIndexName>())
{
LUAU_ASSERT(0); // not yet implemented
TypeId containingTableType = check(scope, indexName->expr);
functionType = arena->addType(BlockedTypeVar{});
TypeId prospectiveTableType =
arena->addType(TableTypeVar{}); // TODO look into stack utilization. This is probably ok because it scales with AST depth.
NotNull<TableTypeVar> prospectiveTable{getMutable<TableTypeVar>(prospectiveTableType)};
Property& prop = prospectiveTable->props[indexName->index.value];
prop.type = functionType;
prop.location = function->name->location;
addConstraint(scope, SubtypeConstraint{containingTableType, prospectiveTableType});
}
else if (AstExprError* err = function->name->as<AstExprError>())
{
functionType = singletonTypes.errorRecoveryType();
}
checkFunctionBody(innerScope, function->func);
LUAU_ASSERT(functionType != nullptr);
std::unique_ptr<Constraint> c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}};
addConstraints(c.get(), innerScope);
checkFunctionBody(sig.bodyScope, function->func);
std::unique_ptr<Constraint> c{
new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}};
addConstraints(c.get(), sig.bodyScope);
addConstraint(scope, std::move(c));
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret)
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatReturn* ret)
{
LUAU_ASSERT(scope);
TypePackId exprTypes = checkPack(scope, ret->list);
addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType});
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block)
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatBlock* block)
{
LUAU_ASSERT(scope);
NotNull<Scope2> innerScope = childScope(block->location, scope);
// In order to enable mutually-recursive type aliases, we need to
// populate the type bindings before we actually check any of the
@ -271,11 +297,10 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block)
}
}
for (AstStat* stat : block->body)
visit(scope, stat);
visitBlockWithoutChildScope(innerScope, block);
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign)
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatAssign* assign)
{
TypePackId varPackId = checkExprList(scope, assign->vars);
TypePackId valuePack = checkPack(scope, assign->values);
@ -283,21 +308,21 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign)
addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId});
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement)
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatIf* ifStatement)
{
check(scope, ifStatement->condition);
Scope2* thenScope = childScope(ifStatement->thenbody->location, scope);
NotNull<Scope2> thenScope = childScope(ifStatement->thenbody->location, scope);
visit(thenScope, ifStatement->thenbody);
if (ifStatement->elsebody)
{
Scope2* elseScope = childScope(ifStatement->elsebody->location, scope);
NotNull<Scope2> elseScope = childScope(ifStatement->elsebody->location, scope);
visit(elseScope, ifStatement->elsebody);
}
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias)
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatTypeAlias* alias)
{
// TODO: Exported type aliases
// TODO: Generic type aliases
@ -307,6 +332,10 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias)
// AST to set up typeBindings. If it's not, we've somehow skipped
// this alias in that first pass.
LUAU_ASSERT(it != scope->typeBindings.end());
if (it == scope->typeBindings.end())
{
ice->ice("Type alias does not have a pre-populated binding", alias->location);
}
TypeId ty = resolveType(scope, alias->type);
@ -319,10 +348,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias)
addConstraint(scope, NameConstraint{ty, alias->name.value});
}
TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray<AstExpr*> exprs)
TypePackId ConstraintGraphBuilder::checkPack(NotNull<Scope2> scope, AstArray<AstExpr*> exprs)
{
LUAU_ASSERT(scope);
if (exprs.size == 0)
return arena->addTypePack({});
@ -342,7 +369,7 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray<AstExpr*> e
return arena->addTypePack(TypePack{std::move(types), last});
}
TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray<AstExpr*>& exprs)
TypePackId ConstraintGraphBuilder::checkExprList(NotNull<Scope2> scope, const AstArray<AstExpr*>& exprs)
{
TypePackId result = arena->addTypePack({});
TypePack* resultPack = getMutable<TypePack>(result);
@ -363,9 +390,15 @@ TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray<A
return result;
}
TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr)
TypePackId ConstraintGraphBuilder::checkPack(NotNull<Scope2> scope, AstExpr* expr)
{
LUAU_ASSERT(scope);
RecursionCounter counter{&recursionCount};
if (recursionCount >= FInt::LuauCheckRecursionLimit)
{
reportCodeTooComplex(expr->location);
return singletonTypes.errorRecoveryTypePack();
}
TypePackId result = nullptr;
@ -384,7 +417,7 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr)
astOriginalCallTypes[call->func] = fnType;
TypeId instantiatedType = freshType(scope);
TypeId instantiatedType = arena->addType(BlockedTypeVar{});
addConstraint(scope, InstantiationConstraint{instantiatedType, fnType});
TypePackId rets = freshTypePack(scope);
@ -394,6 +427,13 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr)
addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType});
result = rets;
}
else if (AstExprVarargs* varargs = expr->as<AstExprVarargs>())
{
if (scope->varargPack)
result = *scope->varargPack;
else
result = singletonTypes.errorRecoveryTypePack();
}
else
{
TypeId t = check(scope, expr);
@ -405,9 +445,15 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr)
return result;
}
TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr)
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExpr* expr)
{
LUAU_ASSERT(scope);
RecursionCounter counter{&recursionCount};
if (recursionCount >= FInt::LuauCheckRecursionLimit)
{
reportCodeTooComplex(expr->location);
return singletonTypes.errorRecoveryType();
}
TypeId result = nullptr;
@ -435,37 +481,38 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr)
if (ty)
result = *ty;
else
{
/* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any
* global that is not already in-scope is definitely an unknown symbol.
*/
reportError(g->location, UnknownSymbol{g->name.value});
result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point?
}
else if (auto a = expr->as<AstExprCall>())
{
TypePackId packResult = checkPack(scope, expr);
if (auto f = first(packResult))
return *f;
else if (get<FreeTypePack>(packResult))
{
TypeId typeResult = freshType(scope);
TypePack onePack{{typeResult}, freshTypePack(scope)};
TypePackId oneTypePack = arena->addTypePack(std::move(onePack));
addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack});
return typeResult;
}
}
else if (expr->is<AstExprVarargs>())
result = flattenPack(scope, expr->location, checkPack(scope, expr));
else if (expr->is<AstExprCall>())
result = flattenPack(scope, expr->location, checkPack(scope, expr));
else if (auto a = expr->as<AstExprFunction>())
{
auto [fnType, functionScope] = checkFunctionSignature(scope, a);
checkFunctionBody(functionScope, a);
return fnType;
FunctionSignature sig = checkFunctionSignature(scope, a);
checkFunctionBody(sig.bodyScope, a);
return sig.signature;
}
else if (auto indexName = expr->as<AstExprIndexName>())
{
result = check(scope, indexName);
}
else if (auto indexExpr = expr->as<AstExprIndexExpr>())
result = check(scope, indexExpr);
else if (auto table = expr->as<AstExprTable>())
{
result = checkExprTable(scope, table);
else if (auto unary = expr->as<AstExprUnary>())
result = check(scope, unary);
else if (auto binary = expr->as<AstExprBinary>())
result = check(scope, binary);
else if (auto err = expr->as<AstExprError>())
{
// Open question: Should we traverse into this?
result = singletonTypes.errorRecoveryType();
}
else
{
@ -478,7 +525,7 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr)
return result;
}
TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName)
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprIndexName* indexName)
{
TypeId obj = check(scope, indexName->expr);
TypeId result = freshType(scope);
@ -494,7 +541,67 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName)
return result;
}
TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr)
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprIndexExpr* indexExpr)
{
TypeId obj = check(scope, indexExpr->expr);
TypeId indexType = check(scope, indexExpr->index);
TypeId result = freshType(scope);
TableIndexer indexer{indexType, result};
TypeId tableType = arena->addType(TableTypeVar{TableTypeVar::Props{}, TableIndexer{indexType, result}, TypeLevel{}, TableState::Free});
addConstraint(scope, SubtypeConstraint{obj, tableType});
return result;
}
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprUnary* unary)
{
TypeId operandType = check(scope, unary->expr);
switch (unary->op)
{
case AstExprUnary::Minus:
{
TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, UnaryConstraint{AstExprUnary::Minus, operandType, resultType});
return resultType;
}
default:
LUAU_ASSERT(0);
}
LUAU_UNREACHABLE();
return singletonTypes.errorRecoveryType();
}
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprBinary* binary)
{
TypeId leftType = check(scope, binary->left);
TypeId rightType = check(scope, binary->right);
switch (binary->op)
{
case AstExprBinary::Or:
{
addConstraint(scope, SubtypeConstraint{leftType, rightType});
return leftType;
}
case AstExprBinary::Sub:
{
TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, BinaryConstraint{AstExprBinary::Sub, leftType, rightType, resultType});
return resultType;
}
default:
LUAU_ASSERT(0);
}
LUAU_ASSERT(0);
return nullptr;
}
TypeId ConstraintGraphBuilder::checkExprTable(NotNull<Scope2> scope, AstExprTable* expr)
{
TypeId ty = arena->addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(ty);
@ -515,6 +622,8 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr)
for (const AstExprTable::Item& item : expr->items)
{
TypeId itemTy = check(scope, item.value);
if (get<ErrorTypeVar>(follow(itemTy)))
return ty;
if (item.key)
{
@ -542,47 +651,111 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr)
return ty;
}
std::pair<TypeId, Scope2*> ConstraintGraphBuilder::checkFunctionSignature(Scope2* parent, AstExprFunction* fn)
ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(NotNull<Scope2> parent, AstExprFunction* fn)
{
Scope2* innerScope = childScope(fn->body->location, parent);
TypePackId returnType = freshTypePack(innerScope);
innerScope->returnType = returnType;
Scope2* signatureScope = nullptr;
Scope2* bodyScope = nullptr;
TypePackId returnType = nullptr;
std::vector<TypeId> genericTypes;
std::vector<TypePackId> genericTypePacks;
bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0;
// If we don't have any generics, we can save some memory and compute by not
// creating the signatureScope, which is only used to scope the declared
// generics properly.
if (hasGenerics)
{
NotNull signatureBorrow = childScope(fn->location, parent);
signatureScope = signatureBorrow.get();
// We need to assign returnType before creating bodyScope so that the
// return type gets propogated to bodyScope.
returnType = freshTypePack(signatureBorrow);
signatureScope->returnType = returnType;
bodyScope = childScope(fn->body->location, signatureBorrow).get();
std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureBorrow, fn->generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks);
// We do not support default values on function generics, so we only
// care about the types involved.
for (const auto& [name, g] : genericDefinitions)
{
genericTypes.push_back(g.ty);
signatureScope->typeBindings[name] = g.ty;
}
for (const auto& [name, g] : genericPackDefinitions)
{
genericTypePacks.push_back(g.tp);
signatureScope->typePackBindings[name] = g.tp;
}
}
else
{
NotNull bodyBorrow = childScope(fn->body->location, parent);
bodyScope = bodyBorrow.get();
returnType = freshTypePack(bodyBorrow);
bodyBorrow->returnType = returnType;
// To eliminate the need to branch on hasGenerics below, we say that the
// signature scope is the body scope when there is no real signature
// scope.
signatureScope = bodyScope;
}
NotNull bodyBorrow = NotNull(bodyScope);
NotNull signatureBorrow = NotNull(signatureScope);
if (fn->returnAnnotation)
{
TypePackId annotatedRetType = resolveTypePack(innerScope, *fn->returnAnnotation);
addConstraint(innerScope, PackSubtypeConstraint{returnType, annotatedRetType});
TypePackId annotatedRetType = resolveTypePack(signatureBorrow, *fn->returnAnnotation);
addConstraint(signatureBorrow, PackSubtypeConstraint{returnType, annotatedRetType});
}
std::vector<TypeId> argTypes;
for (AstLocal* local : fn->args)
{
TypeId t = freshType(innerScope);
TypeId t = freshType(signatureBorrow);
argTypes.push_back(t);
innerScope->bindings[local] = t;
signatureScope->bindings[local] = t;
if (local->annotation)
{
TypeId argAnnotation = resolveType(innerScope, local->annotation);
addConstraint(innerScope, SubtypeConstraint{t, argAnnotation});
TypeId argAnnotation = resolveType(signatureBorrow, local->annotation);
addConstraint(signatureBorrow, SubtypeConstraint{t, argAnnotation});
}
}
// TODO: Vararg annotation.
// TODO: Preserve argument names in the function's type.
FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType};
actualFunction.hasNoGenerics = !hasGenerics;
actualFunction.generics = std::move(genericTypes);
actualFunction.genericPacks = std::move(genericTypePacks);
TypeId actualFunctionType = arena->addType(std::move(actualFunction));
LUAU_ASSERT(actualFunctionType);
astTypes[fn] = actualFunctionType;
return {actualFunctionType, innerScope};
return {
/* signature */ actualFunctionType,
// Undo the workaround we made above: if there's no signature scope,
// don't report it.
/* signatureScope */ hasGenerics ? signatureScope : nullptr,
/* bodyScope */ bodyBorrow,
};
}
void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* fn)
void ConstraintGraphBuilder::checkFunctionBody(NotNull<Scope2> scope, AstExprFunction* fn)
{
for (AstStat* stat : fn->body->body)
visit(scope, stat);
visitBlockWithoutChildScope(scope, fn->body);
// If it is possible for execution to reach the end of the function, the return type must be compatible with ()
@ -593,7 +766,7 @@ void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* f
}
}
TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty)
TypeId ConstraintGraphBuilder::resolveType(NotNull<Scope2> scope, AstType* ty)
{
TypeId result = nullptr;
@ -636,29 +809,73 @@ TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty)
}
else if (auto fn = ty->as<AstTypeFunction>())
{
// TODO: Generic functions.
// TODO: Scope (though it may not be needed).
// TODO: Recursion limit.
TypePackId argTypes = resolveTypePack(scope, fn->argTypes);
TypePackId returnTypes = resolveTypePack(scope, fn->returnTypes);
bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0;
Scope2* signatureScope = nullptr;
// TODO: Is this the right constructor to use?
result = arena->addType(FunctionTypeVar{argTypes, returnTypes});
std::vector<TypeId> genericTypes;
std::vector<TypePackId> genericTypePacks;
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result);
ftv->argNames.reserve(fn->argNames.size);
// If we don't have generics, we do not need to generate a child scope
// for the generic bindings to live on.
if (hasGenerics)
{
NotNull<Scope2> signatureBorrow = childScope(fn->location, scope);
signatureScope = signatureBorrow.get();
std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureBorrow, fn->generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks);
for (const auto& [name, g] : genericDefinitions)
{
genericTypes.push_back(g.ty);
signatureBorrow->typeBindings[name] = g.ty;
}
for (const auto& [name, g] : genericPackDefinitions)
{
genericTypePacks.push_back(g.tp);
signatureBorrow->typePackBindings[name] = g.tp;
}
}
else
{
// To eliminate the need to branch on hasGenerics below, we say that
// the signature scope is the parent scope if we don't have
// generics.
signatureScope = scope.get();
}
NotNull<Scope2> signatureBorrow(signatureScope);
TypePackId argTypes = resolveTypePack(signatureBorrow, fn->argTypes);
TypePackId returnTypes = resolveTypePack(signatureBorrow, fn->returnTypes);
// TODO: FunctionTypeVar needs a pointer to the scope so that we know
// how to quantify/instantiate it.
FunctionTypeVar ftv{argTypes, returnTypes};
// This replicates the behavior of the appropriate FunctionTypeVar
// constructors.
ftv.hasNoGenerics = !hasGenerics;
ftv.generics = std::move(genericTypes);
ftv.genericPacks = std::move(genericTypePacks);
ftv.argNames.reserve(fn->argNames.size);
for (const auto& el : fn->argNames)
{
if (el)
{
const auto& [name, location] = *el;
ftv->argNames.push_back(FunctionArgument{name.value, location});
ftv.argNames.push_back(FunctionArgument{name.value, location});
}
else
{
ftv->argNames.push_back(std::nullopt);
ftv.argNames.push_back(std::nullopt);
}
}
result = arena->addType(std::move(ftv));
}
else if (auto tof = ty->as<AstTypeTypeof>())
{
@ -710,7 +927,7 @@ TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty)
return result;
}
TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* tp)
TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull<Scope2> scope, AstTypePack* tp)
{
TypePackId result;
if (auto expl = tp->as<AstTypePackExplicit>())
@ -736,7 +953,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* t
return result;
}
TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeList& list)
TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull<Scope2> scope, const AstTypeList& list)
{
std::vector<TypeId> head;
@ -754,16 +971,108 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeL
return arena->addTypePack(TypePack{head, tail});
}
void collectConstraints(std::vector<NotNull<Constraint>>& result, Scope2* scope)
std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::createGenerics(NotNull<Scope2> scope, AstArray<AstGenericType> generics)
{
std::vector<std::pair<Name, GenericTypeDefinition>> result;
for (const auto& generic : generics)
{
TypeId genericTy = arena->addType(GenericTypeVar{scope, generic.name.value});
std::optional<TypeId> defaultTy = std::nullopt;
if (generic.defaultValue)
defaultTy = resolveType(scope, generic.defaultValue);
result.push_back({generic.name.value, GenericTypeDefinition{
genericTy,
defaultTy,
}});
}
return result;
}
std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::createGenericPacks(
NotNull<Scope2> scope, AstArray<AstGenericTypePack> generics)
{
std::vector<std::pair<Name, GenericTypePackDefinition>> result;
for (const auto& generic : generics)
{
TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope, generic.name.value}});
std::optional<TypePackId> defaultTy = std::nullopt;
if (generic.defaultValue)
defaultTy = resolveTypePack(scope, generic.defaultValue);
result.push_back({generic.name.value, GenericTypePackDefinition{
genericTy,
defaultTy,
}});
}
return result;
}
TypeId ConstraintGraphBuilder::flattenPack(NotNull<Scope2> scope, Location location, TypePackId tp)
{
if (auto f = first(tp))
return *f;
TypeId typeResult = freshType(scope);
TypePack onePack{{typeResult}, freshTypePack(scope)};
TypePackId oneTypePack = arena->addTypePack(std::move(onePack));
addConstraint(scope, PackSubtypeConstraint{tp, oneTypePack});
return typeResult;
}
void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err)
{
errors.push_back(TypeError{location, moduleName, std::move(err)});
}
void ConstraintGraphBuilder::reportCodeTooComplex(Location location)
{
errors.push_back(TypeError{location, moduleName, CodeTooComplex{}});
}
struct GlobalPrepopulator : AstVisitor
{
const NotNull<Scope2> globalScope;
const NotNull<TypeArena> arena;
GlobalPrepopulator(NotNull<Scope2> globalScope, NotNull<TypeArena> arena)
: globalScope(globalScope)
, arena(arena)
{
}
bool visit(AstStatFunction* function) override
{
if (AstExprGlobal* g = function->name->as<AstExprGlobal>())
globalScope->bindings[g->name] = arena->addType(BlockedTypeVar{});
return true;
}
};
void ConstraintGraphBuilder::prepopulateGlobalScope(NotNull<Scope2> globalScope, AstStatBlock* program)
{
GlobalPrepopulator gp{NotNull{globalScope}, arena};
program->visit(&gp);
}
void collectConstraints(std::vector<NotNull<Constraint>>& result, NotNull<Scope2> scope)
{
for (const auto& c : scope->constraints)
result.push_back(NotNull{c.get()});
for (Scope2* child : scope->children)
for (NotNull<Scope2> child : scope->children)
collectConstraints(result, child);
}
std::vector<NotNull<Constraint>> collectConstraints(Scope2* rootScope)
std::vector<NotNull<Constraint>> collectConstraints(NotNull<Scope2> rootScope)
{
std::vector<NotNull<Constraint>> result;
collectConstraints(result, rootScope);

View file

@ -13,7 +13,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false);
namespace Luau
{
[[maybe_unused]] static void dumpBindings(Scope2* scope, ToStringOptions& opts)
[[maybe_unused]] static void dumpBindings(NotNull<Scope2> scope, ToStringOptions& opts)
{
for (const auto& [k, v] : scope->bindings)
{
@ -22,22 +22,22 @@ namespace Luau
printf("\t%s : %s\n", k.c_str(), d.name.c_str());
}
for (Scope2* child : scope->children)
for (NotNull<Scope2> child : scope->children)
dumpBindings(child, opts);
}
static void dumpConstraints(Scope2* scope, ToStringOptions& opts)
static void dumpConstraints(NotNull<Scope2> scope, ToStringOptions& opts)
{
for (const ConstraintPtr& c : scope->constraints)
{
printf("\t%s\n", toString(*c, opts).c_str());
}
for (Scope2* child : scope->children)
for (NotNull<Scope2> child : scope->children)
dumpConstraints(child, opts);
}
void dump(Scope2* rootScope, ToStringOptions& opts)
void dump(NotNull<Scope2> rootScope, ToStringOptions& opts)
{
printf("constraints:\n");
dumpConstraints(rootScope, opts);
@ -55,7 +55,7 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts)
}
}
ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope)
ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull<Scope2> rootScope)
: arena(arena)
, constraints(collectConstraints(rootScope))
, rootScope(rootScope)
@ -180,6 +180,10 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*gc, constraint, force);
else if (auto ic = get<InstantiationConstraint>(*constraint))
success = tryDispatch(*ic, constraint, force);
else if (auto uc = get<UnaryConstraint>(*constraint))
success = tryDispatch(*uc, constraint, force);
else if (auto bc = get<BinaryConstraint>(*constraint))
success = tryDispatch(*bc, constraint, force);
else if (auto nc = get<NameConstraint>(*constraint))
success = tryDispatch(*nc, constraint);
else
@ -246,12 +250,65 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull<con
std::optional<TypeId> instantiated = inst.substitute(c.superType);
LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS
if (isBlocked(c.subType))
asMutable(c.subType)->ty.emplace<BoundTypeVar>(*instantiated);
else
unify(c.subType, *instantiated);
unblock(c.subType);
return true;
}
bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull<const Constraint> constraint, bool force)
{
TypeId operandType = follow(c.operandType);
if (isBlocked(operandType))
return block(operandType, constraint);
if (get<FreeTypeVar>(operandType))
return block(operandType, constraint);
LUAU_ASSERT(get<BlockedTypeVar>(c.resultType));
if (isNumber(operandType) || get<AnyTypeVar>(operandType) || get<ErrorTypeVar>(operandType))
{
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(c.operandType);
return true;
}
LUAU_ASSERT(0); // TODO metatable handling
return false;
}
bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull<const Constraint> constraint, bool force)
{
TypeId leftType = follow(c.leftType);
TypeId rightType = follow(c.rightType);
if (isBlocked(leftType) || isBlocked(rightType))
{
block(leftType, constraint);
block(rightType, constraint);
return false;
}
if (isNumber(leftType))
{
unify(leftType, rightType);
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(leftType);
return true;
}
if (get<FreeTypeVar>(leftType) && !force)
return block(leftType, constraint);
// TODO metatables, classes
return true;
}
bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint)
{
if (isBlocked(c.namedType))

View file

@ -1,6 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAG(LuauCheckLenMT)
namespace Luau
{
@ -202,7 +204,13 @@ declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
std::string getBuiltinDefinitionSource()
{
return kBuiltinDefinitionLuaSrc;
std::string result = kBuiltinDefinitionLuaSrc;
// TODO: move this into kBuiltinDefinitionLuaSrc
if (FFlag::LuauCheckLenMT)
result += "declare function rawlen<K, V>(obj: {[K]: V} | string): number\n";
return result;
}
} // namespace Luau

View file

@ -787,14 +787,32 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons
return const_cast<Frontend*>(this)->getSourceModule(moduleName);
}
NotNull<Scope2> Frontend::getGlobalScope2()
{
if (!globalScope2)
{
const SingletonTypes& singletonTypes = getSingletonTypes();
globalScope2 = std::make_unique<Scope2>();
globalScope2->typeBindings["nil"] = singletonTypes.nilType;
globalScope2->typeBindings["number"] = singletonTypes.numberType;
globalScope2->typeBindings["string"] = singletonTypes.stringType;
globalScope2->typeBindings["boolean"] = singletonTypes.booleanType;
globalScope2->typeBindings["thread"] = singletonTypes.threadType;
}
return NotNull(globalScope2.get());
}
ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope)
{
ModulePtr result = std::make_shared<Module>();
ConstraintGraphBuilder cgb{&result->internalTypes};
ConstraintGraphBuilder cgb{sourceModule.name, &result->internalTypes, NotNull(&iceHandler), getGlobalScope2()};
cgb.visit(sourceModule.root);
result->errors = std::move(cgb.errors);
ConstraintSolver cs{&result->internalTypes, cgb.rootScope};
ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)};
cs.run();
result->scope2s = std::move(cgb.scopes);

View file

@ -5,7 +5,6 @@
#include <algorithm>
#include "Luau/Clone.h"
#include "Luau/Substitution.h"
#include "Luau/Unifier.h"
#include "Luau/VisitTypeVar.h"
@ -16,7 +15,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false);
LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false);
LUAU_FASTFLAG(LuauQuantifyConstrained)
namespace Luau
@ -25,220 +23,24 @@ namespace Luau
namespace
{
struct Replacer : Substitution
struct Replacer
{
TypeArena* arena;
TypeId sourceType;
TypeId replacedType;
DenseHashMap<TypeId, TypeId> replacedTypes{nullptr};
DenseHashMap<TypePackId, TypePackId> replacedPacks{nullptr};
DenseHashMap<TypeId, TypeId> newTypes;
Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType)
: Substitution(TxnLog::empty(), arena)
: arena(arena)
, sourceType(sourceType)
, replacedType(replacedType)
, newTypes(nullptr)
{
}
bool isDirty(TypeId ty) override
{
if (!sourceType)
return false;
auto vecHasSourceType = [sourceType = sourceType](const auto& vec) {
return end(vec) != std::find(begin(vec), end(vec), sourceType);
};
// Walk every kind of TypeVar and find pointers to sourceType
if (auto t = get<FreeTypeVar>(ty))
return false;
else if (auto t = get<GenericTypeVar>(ty))
return false;
else if (auto t = get<ErrorTypeVar>(ty))
return false;
else if (auto t = get<PrimitiveTypeVar>(ty))
return false;
else if (auto t = get<ConstrainedTypeVar>(ty))
return vecHasSourceType(t->parts);
else if (auto t = get<SingletonTypeVar>(ty))
return false;
else if (auto t = get<FunctionTypeVar>(ty))
{
if (vecHasSourceType(t->generics))
return true;
return false;
}
else if (auto t = get<TableTypeVar>(ty))
{
if (t->boundTo)
return *t->boundTo == sourceType;
for (const auto& [_name, prop] : t->props)
{
if (prop.type == sourceType)
return true;
}
if (auto indexer = t->indexer)
{
if (indexer->indexType == sourceType || indexer->indexResultType == sourceType)
return true;
}
if (vecHasSourceType(t->instantiatedTypeParams))
return true;
return false;
}
else if (auto t = get<MetatableTypeVar>(ty))
return t->table == sourceType || t->metatable == sourceType;
else if (auto t = get<ClassTypeVar>(ty))
return false;
else if (auto t = get<AnyTypeVar>(ty))
return false;
else if (auto t = get<UnionTypeVar>(ty))
return vecHasSourceType(t->options);
else if (auto t = get<IntersectionTypeVar>(ty))
return vecHasSourceType(t->parts);
else if (auto t = get<LazyTypeVar>(ty))
return false;
LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type");
LUAU_UNREACHABLE();
}
bool isDirty(TypePackId tp) override
{
if (auto it = replacedPacks.find(tp))
return false;
if (auto pack = get<TypePack>(tp))
{
for (TypeId ty : pack->head)
{
if (ty == sourceType)
return true;
}
return false;
}
else if (auto vtp = get<VariadicTypePack>(tp))
return vtp->ty == sourceType;
else
return false;
}
TypeId clean(TypeId ty) override
{
LUAU_ASSERT(sourceType && replacedType);
// Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType
// Before returning, memoize the result for later use.
// Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This
// function returns the identity for things like primitives.
TypeId res = clone(ty);
if (auto t = get<FreeTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<GenericTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<ErrorTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<PrimitiveTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = getMutable<ConstrainedTypeVar>(res))
{
for (TypeId& part : t->parts)
{
if (part == sourceType)
part = replacedType;
}
}
else if (auto t = get<SingletonTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = getMutable<FunctionTypeVar>(res))
{
// The constituent typepacks are cleaned separately. We just need to walk the generics array.
for (TypeId& g : t->generics)
{
if (g == sourceType)
g = replacedType;
}
}
else if (auto t = getMutable<TableTypeVar>(res))
{
for (auto& [_key, prop] : t->props)
{
if (prop.type == sourceType)
prop.type = replacedType;
}
}
else if (auto t = getMutable<MetatableTypeVar>(res))
{
if (t->table == sourceType)
t->table = replacedType;
if (t->metatable == sourceType)
t->table = replacedType;
}
else if (auto t = get<ClassTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<AnyTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = getMutable<UnionTypeVar>(res))
{
for (TypeId& option : t->options)
{
if (option == sourceType)
option = replacedType;
}
}
else if (auto t = getMutable<IntersectionTypeVar>(res))
{
for (TypeId& part : t->parts)
{
if (part == sourceType)
part = replacedType;
}
}
else if (auto t = get<LazyTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else
LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type");
replacedTypes[ty] = res;
return res;
}
TypePackId clean(TypePackId tp) override
{
TypePackId res = clone(tp);
if (auto pack = getMutable<TypePack>(res))
{
for (TypeId& type : pack->head)
{
if (type == sourceType)
type = replacedType;
}
}
else if (auto vtp = getMutable<VariadicTypePack>(res))
{
if (vtp->ty == sourceType)
vtp->ty = replacedType;
}
replacedPacks[tp] = res;
return res;
}
TypeId smartClone(TypeId t)
{
if (FFlag::LuauReplaceReplacer)
{
// The new smartClone is just a memoized clone()
// TODO: Remove the Substitution base class and all other methods from this struct.
// Add DenseHashMap<TypeId, TypeId> newTypes;
t = log->follow(t);
t = follow(t);
TypeId* res = newTypes.find(t);
if (res)
return *res;
@ -249,15 +51,6 @@ struct Replacer : Substitution
return result;
}
else
{
std::optional<TypeId> res = replace(t);
LUAU_ASSERT(res.has_value()); // TODO think about this
if (*res == t)
return clone(t);
return *res;
}
}
};
} // anonymous namespace

View file

@ -8,6 +8,7 @@
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAG(LuauAlwaysQuantify);
LUAU_FASTFLAG(DebugLuauSharedSelf)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false)
@ -158,6 +159,45 @@ struct Quantifier final : TypeVarOnceVisitor
void quantify(TypeId ty, TypeLevel level)
{
if (FFlag::DebugLuauSharedSelf)
{
ty = follow(ty);
if (auto ttv = getTableType(ty); ttv && ttv->selfTy)
{
Quantifier selfQ{level};
selfQ.traverse(*ttv->selfTy);
Quantifier q{level};
q.traverse(ty);
for (const auto& [_, prop] : ttv->props)
{
auto ftv = getMutable<FunctionTypeVar>(follow(prop.type));
if (!ftv || !ftv->hasSelf)
continue;
if (Luau::first(ftv->argTypes) == ttv->selfTy)
{
ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end());
}
}
}
else if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
Quantifier q{level};
q.traverse(ty);
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoGenerics = true;
}
}
else
{
Quantifier q{level};
q.traverse(ty);
@ -173,9 +213,7 @@ void quantify(TypeId ty, TypeLevel level)
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
}
if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoGenerics = true;
}
}
void quantify(TypeId ty, Scope2* scope)
@ -206,8 +244,8 @@ struct PureQuantifier : Substitution
std::vector<TypeId> insertedGenerics;
std::vector<TypePackId> insertedGenericPacks;
PureQuantifier(const TxnLog* log, TypeArena* arena, Scope2* scope)
: Substitution(log, arena)
PureQuantifier(TypeArena* arena, Scope2* scope)
: Substitution(TxnLog::empty(), arena)
, scope(scope)
{
}
@ -286,7 +324,7 @@ struct PureQuantifier : Substitution
TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope)
{
PureQuantifier quantifier{TxnLog::empty(), arena, scope};
PureQuantifier quantifier{arena, scope};
std::optional<TypeId> result = quantifier.substitute(ty);
LUAU_ASSERT(result);
@ -294,8 +332,7 @@ TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope)
LUAU_ASSERT(ftv);
ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end());
// TODO: Set hasNoGenerics.
ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty();
return *result;
}

View file

@ -153,4 +153,19 @@ std::optional<TypeId> Scope2::lookupTypeBinding(const Name& name)
return std::nullopt;
}
std::optional<TypePackId> Scope2::lookupTypePackBinding(const Name& name)
{
Scope2* s = this;
while (s)
{
auto it = s->typePackBindings.find(name);
if (it != s->typePackBindings.end())
return it->second;
s = s->parent;
}
return std::nullopt;
}
} // namespace Luau

View file

@ -1377,51 +1377,74 @@ std::string generateName(size_t i)
return n;
}
std::string toString(const Constraint& c, ToStringOptions& opts)
std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
if (const SubtypeConstraint* sc = Luau::get_if<SubtypeConstraint>(&c.c))
auto go = [&opts](auto&& c) {
using T = std::decay_t<decltype(c)>;
if constexpr (std::is_same_v<T, SubtypeConstraint>)
{
ToStringResult subStr = toStringDetailed(sc->subType, opts);
ToStringResult subStr = toStringDetailed(c.subType, opts);
opts.nameMap = std::move(subStr.nameMap);
ToStringResult superStr = toStringDetailed(sc->superType, opts);
ToStringResult superStr = toStringDetailed(c.superType, opts);
opts.nameMap = std::move(superStr.nameMap);
return subStr.name + " <: " + superStr.name;
}
else if (const PackSubtypeConstraint* psc = Luau::get_if<PackSubtypeConstraint>(&c.c))
else if constexpr (std::is_same_v<T, PackSubtypeConstraint>)
{
ToStringResult subStr = toStringDetailed(psc->subPack, opts);
ToStringResult subStr = toStringDetailed(c.subPack, opts);
opts.nameMap = std::move(subStr.nameMap);
ToStringResult superStr = toStringDetailed(psc->superPack, opts);
ToStringResult superStr = toStringDetailed(c.superPack, opts);
opts.nameMap = std::move(superStr.nameMap);
return subStr.name + " <: " + superStr.name;
}
else if (const GeneralizationConstraint* gc = Luau::get_if<GeneralizationConstraint>(&c.c))
else if constexpr (std::is_same_v<T, GeneralizationConstraint>)
{
ToStringResult subStr = toStringDetailed(gc->generalizedType, opts);
ToStringResult subStr = toStringDetailed(c.generalizedType, opts);
opts.nameMap = std::move(subStr.nameMap);
ToStringResult superStr = toStringDetailed(gc->sourceType, opts);
ToStringResult superStr = toStringDetailed(c.sourceType, opts);
opts.nameMap = std::move(superStr.nameMap);
return subStr.name + " ~ gen " + superStr.name;
}
else if (const InstantiationConstraint* ic = Luau::get_if<InstantiationConstraint>(&c.c))
else if constexpr (std::is_same_v<T, InstantiationConstraint>)
{
ToStringResult subStr = toStringDetailed(ic->subType, opts);
ToStringResult subStr = toStringDetailed(c.subType, opts);
opts.nameMap = std::move(subStr.nameMap);
ToStringResult superStr = toStringDetailed(ic->superType, opts);
ToStringResult superStr = toStringDetailed(c.superType, opts);
opts.nameMap = std::move(superStr.nameMap);
return subStr.name + " ~ inst " + superStr.name;
}
else if (const NameConstraint* nc = Luau::get<NameConstraint>(c))
else if constexpr (std::is_same_v<T, UnaryConstraint>)
{
ToStringResult namedStr = toStringDetailed(nc->namedType, opts);
ToStringResult resultStr = toStringDetailed(c.resultType, opts);
opts.nameMap = std::move(resultStr.nameMap);
ToStringResult operandStr = toStringDetailed(c.operandType, opts);
opts.nameMap = std::move(operandStr.nameMap);
return resultStr.name + " ~ Unary<" + toString(c.op) + ", " + operandStr.name + ">";
}
else if constexpr (std::is_same_v<T, BinaryConstraint>)
{
ToStringResult resultStr = toStringDetailed(c.resultType);
opts.nameMap = std::move(resultStr.nameMap);
ToStringResult leftStr = toStringDetailed(c.leftType);
opts.nameMap = std::move(leftStr.nameMap);
ToStringResult rightStr = toStringDetailed(c.rightType);
opts.nameMap = std::move(rightStr.nameMap);
return resultStr.name + " ~ Binary<" + toString(c.op) + ", " + leftStr.name + ", " + rightStr.name + ">";
}
else if constexpr (std::is_same_v<T, NameConstraint>)
{
ToStringResult namedStr = toStringDetailed(c.namedType, opts);
opts.nameMap = std::move(namedStr.nameMap);
return "@name(" + namedStr.name + ") = " + nc->name;
return "@name(" + namedStr.name + ") = " + c.name;
}
else
{
LUAU_ASSERT(false);
return "";
}
static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
};
return visit(go, constraint.c);
}
std::string dump(const Constraint& c)

View file

@ -6,8 +6,10 @@
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Clone.h"
#include "Luau/Instantiation.h"
#include "Luau/Normalize.h"
#include "Luau/ConstraintGraphBuilder.h" // FIXME move Scope2 into its own header
#include "Luau/TxnLog.h"
#include "Luau/TypeUtils.h"
#include "Luau/Unifier.h"
#include "Luau/ToString.h"
@ -19,10 +21,12 @@ struct TypeChecker2 : public AstVisitor
const SourceModule* sourceModule;
Module* module;
InternalErrorReporter ice; // FIXME accept a pointer from Frontend
SingletonTypes& singletonTypes;
TypeChecker2(const SourceModule* sourceModule, Module* module)
: sourceModule(sourceModule)
, module(module)
, singletonTypes(getSingletonTypes())
{
}
@ -30,16 +34,30 @@ struct TypeChecker2 : public AstVisitor
TypePackId lookupPack(AstExpr* expr)
{
// If a type isn't in the type graph, it probably means that a recursion limit was exceeded.
// We'll just return anyType in these cases. Typechecking against any is very fast and this
// allows us not to think about this very much in the actual typechecking logic.
TypePackId* tp = module->astTypePacks.find(expr);
LUAU_ASSERT(tp);
if (tp)
return follow(*tp);
else
return singletonTypes.anyTypePack;
}
TypeId lookupType(AstExpr* expr)
{
// If a type isn't in the type graph, it probably means that a recursion limit was exceeded.
// We'll just return anyType in these cases. Typechecking against any is very fast and this
// allows us not to think about this very much in the actual typechecking logic.
TypeId* ty = module->astTypes.find(expr);
LUAU_ASSERT(ty);
if (ty)
return follow(*ty);
TypePackId* tp = module->astTypePacks.find(expr);
if (tp)
return flattenPack(*tp);
return singletonTypes.anyType;
}
TypeId lookupAnnotation(AstType* annotation)
@ -78,7 +96,7 @@ struct TypeChecker2 : public AstVisitor
bestLocation = scopeBounds;
}
}
else
else if (scopeBounds.begin > location.end)
{
// TODO: Is this sound? This relies on the fact that scopes are inserted
// into the scope list in the order that they appear in the AST.
@ -147,16 +165,14 @@ struct TypeChecker2 : public AstVisitor
for (size_t i = 0; i < count; ++i)
{
AstExpr* lhs = assign->vars.data[i];
TypeId* lhsType = module->astTypes.find(lhs);
LUAU_ASSERT(lhsType);
TypeId lhsType = lookupType(lhs);
AstExpr* rhs = assign->values.data[i];
TypeId* rhsType = module->astTypes.find(rhs);
LUAU_ASSERT(rhsType);
TypeId rhsType = lookupType(rhs);
if (!isSubtype(*rhsType, *lhsType, ice))
if (!isSubtype(rhsType, lhsType, ice))
{
reportError(TypeMismatch{*lhsType, *rhsType}, rhs->location);
reportError(TypeMismatch{lhsType, rhsType}, rhs->location);
}
}
@ -181,7 +197,7 @@ struct TypeChecker2 : public AstVisitor
if (!ok)
{
for (const TypeError& e : u.errors)
module->errors.push_back(e);
reportError(e);
}
return true;
@ -189,10 +205,14 @@ struct TypeChecker2 : public AstVisitor
bool visit(AstExprCall* call) override
{
TypeArena arena;
Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}};
TypePackId expectedRetType = lookupPack(call);
TypeId functionType = lookupType(call->func);
TypeId instantiatedFunctionType = instantiation.substitute(functionType).value_or(nullptr);
LUAU_ASSERT(functionType);
TypeArena arena;
TypePack args;
for (const auto& arg : call->args)
{
@ -204,7 +224,7 @@ struct TypeChecker2 : public AstVisitor
TypePackId argsTp = arena.addTypePack(args);
FunctionTypeVar ftv{argsTp, expectedRetType};
TypeId expectedType = arena.addType(ftv);
if (!isSubtype(expectedType, functionType, ice))
if (!isSubtype(expectedType, instantiatedFunctionType, ice))
{
unfreeze(module->interfaceTypes);
CloneState cloneState;
@ -252,16 +272,12 @@ struct TypeChecker2 : public AstVisitor
// leftType must have a property called indexName->index
if (auto ttv = get<TableTypeVar>(leftType))
std::optional<TypeId> t = findTablePropertyRespectingMeta(module->errors, leftType, indexName->index.value, indexName->location);
if (t)
{
auto it = ttv->props.find(indexName->index.value);
if (it == ttv->props.end())
if (!isSubtype(resultType, *t, ice))
{
reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location);
}
else if (!isSubtype(resultType, it->second.type, ice))
{
reportError(TypeMismatch{resultType, it->second.type}, indexName->location);
reportError(TypeMismatch{resultType, *t}, indexName->location);
}
}
else
@ -277,7 +293,7 @@ struct TypeChecker2 : public AstVisitor
TypeId actualType = lookupType(number);
TypeId numberType = getSingletonTypes().numberType;
if (!isSubtype(actualType, numberType, ice))
if (!isSubtype(numberType, actualType, ice))
{
reportError(TypeMismatch{actualType, numberType}, number->location);
}
@ -290,7 +306,7 @@ struct TypeChecker2 : public AstVisitor
TypeId actualType = lookupType(string);
TypeId stringType = getSingletonTypes().stringType;
if (!isSubtype(actualType, stringType, ice))
if (!isSubtype(stringType, actualType, ice))
{
reportError(TypeMismatch{actualType, stringType}, string->location);
}
@ -298,6 +314,41 @@ struct TypeChecker2 : public AstVisitor
return true;
}
/** Extract a TypeId for the first type of the provided pack.
*
* Note that this may require modifying some types. I hope this doesn't cause problems!
*/
TypeId flattenPack(TypePackId pack)
{
pack = follow(pack);
while (auto tp = get<TypePack>(pack))
{
if (tp->head.empty() && tp->tail)
pack = *tp->tail;
}
if (auto ty = first(pack))
return *ty;
else if (auto vtp = get<VariadicTypePack>(pack))
return vtp->ty;
else if (auto ftp = get<FreeTypePack>(pack))
{
TypeId result = module->internalTypes.addType(FreeTypeVar{ftp->scope});
TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope});
TypePack& resultPack = asMutable(pack)->ty.emplace<TypePack>();
resultPack.head.assign(1, result);
resultPack.tail = freeTail;
return result;
}
else if (get<Unifiable::Error>(pack))
return singletonTypes.errorRecoveryType();
else
ice.ice("flattenPack got a weird pack!");
}
bool visit(AstType* ty) override
{
return true;
@ -321,6 +372,11 @@ struct TypeChecker2 : public AstVisitor
{
module->errors.emplace_back(location, sourceModule->name, std::move(data));
}
void reportError(TypeError e)
{
module->errors.emplace_back(std::move(e));
}
};
void check(const SourceModule& sourceModule, Module* module)

View file

@ -38,11 +38,13 @@ LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false)
LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false.
LUAU_FASTFLAG(LuauNormalizeFlagIsConservative)
LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false)
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false);
LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false);
LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false)
LUAU_FASTFLAG(LuauQuantifyConstrained)
LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false)
LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false)
LUAU_FASTFLAGVARIABLE(LuauCheckLenMT, false)
namespace Luau
{
@ -238,7 +240,7 @@ static bool isMetamethod(const Name& name)
{
return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" ||
name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" ||
name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode";
name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len";
}
size_t HashBoolNamePair::operator()(const std::pair<bool, Name>& pair) const
@ -327,10 +329,19 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
currentModule->timeout = true;
}
if (FFlag::DebugLuauSharedSelf)
{
for (auto& [ty, scope] : deferredQuantification)
Luau::quantify(ty, scope->level);
deferredQuantification.clear();
}
if (get<FreeTypePack>(follow(moduleScope->returnType)))
moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt});
else
{
moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{});
}
for (auto& [_, typeFun] : moduleScope->exportedTypeBindings)
typeFun.type = anyify(moduleScope, typeFun.type, Location{});
@ -537,8 +548,32 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A
}
else if (auto fun = (*protoIter)->as<AstStatFunction>())
{
std::optional<TypeId> selfType;
std::optional<TypeId> expectedType;
if (FFlag::DebugLuauSharedSelf)
{
if (auto name = fun->name->as<AstExprIndexName>())
{
TypeId baseTy = checkExpr(scope, *name->expr).type;
tablify(baseTy);
if (!fun->func->self)
expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, false);
else if (auto ttv = getMutableTableType(baseTy))
{
if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy)
{
ttv->selfTy = anyIfNonstrict(freshType(ttv->level));
deferredQuantification.push_back({baseTy, scope});
}
selfType = ttv->selfTy;
}
}
}
else
{
if (!fun->func->self)
{
if (auto name = fun->name->as<AstExprIndexName>())
@ -547,8 +582,9 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A
expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false);
}
}
}
auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, expectedType);
auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, selfType, expectedType);
auto [funTy, funScope] = pair;
functionDecls[*protoIter] = pair;
@ -560,7 +596,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A
}
else if (auto fun = (*protoIter)->as<AstStatLocalFunction>())
{
auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt);
auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt, std::nullopt);
auto [funTy, funScope] = pair;
functionDecls[*protoIter] = pair;
@ -2076,7 +2112,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional<TypeId> expectedType)
{
auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType);
auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, std::nullopt, expectedType);
checkFunctionBody(funScope, funTy, expr);
@ -2296,6 +2332,8 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true);
state.log.commit();
reportErrors(state.errors);
TypeId retType = first(retTypePack).value_or(nilType);
if (!state.errors.empty())
retType = errorRecoveryType(retType);
@ -2322,6 +2360,23 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
DenseHashSet<TypeId> seen{nullptr};
if (FFlag::LuauCheckLenMT && typeCouldHaveMetatable(operandType))
{
if (auto fnt = findMetatableEntry(operandType, "__len", expr.location))
{
TypeId actualFunctionType = instantiate(scope, *fnt, expr.location);
TypePackId arguments = addTypePack({operandType});
TypePackId retTypePack = addTypePack({numberType});
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack));
Unifier state = mkUnifier(expr.location);
state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true);
state.log.commit();
reportErrors(state.errors);
}
}
if (!hasLength(operandType, seen, &recursionCount))
reportError(TypeError{expr.location, NotATable{operandType}});
@ -2530,7 +2585,6 @@ TypeId TypeChecker::checkRelationalOperation(
}
}
if (!matches)
{
reportError(
@ -2540,7 +2594,6 @@ TypeId TypeChecker::checkRelationalOperation(
}
}
if (leftMetatable)
{
std::optional<TypeId> metamethod = findMetatableEntry(lhsType, metamethodName, expr.location);
@ -3139,8 +3192,8 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T
// `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X`
// to get type `(X) -> X`, then we quantify the free types to get the final
// generic type `<a>(a) -> a`.
std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional<Location> originalName, std::optional<TypeId> expectedType)
std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr,
std::optional<Location> originalName, std::optional<TypeId> selfType, std::optional<TypeId> expectedType)
{
ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel);
@ -3241,6 +3294,18 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
funScope->returnType = retPack;
if (FFlag::DebugLuauSharedSelf)
{
if (expr.self)
{
// TODO: generic self types: CLI-39906
TypeId selfTy = anyIfNonstrict(selfType ? *selfType : freshType(funScope));
funScope->bindings[expr.self] = {selfTy, expr.self->location};
argTypes.push_back(selfTy);
}
}
else
{
if (expr.self)
{
// TODO: generic self types: CLI-39906
@ -3248,6 +3313,7 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
funScope->bindings[expr.self] = {selfType, expr.self->location};
argTypes.push_back(selfType);
}
}
// Prepare expected argument type iterators if we have an expected function type
TypePackIterator expectedArgsCurr, expectedArgsEnd;
@ -4457,6 +4523,23 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
{
ty = follow(ty);
if (FFlag::DebugLuauSharedSelf)
{
if (auto ftv = get<FunctionTypeVar>(ty))
Luau::quantify(ty, scope->level);
else if (auto ttv = getTableType(ty); ttv && ttv->selfTy)
Luau::quantify(ty, scope->level);
if (FFlag::LuauLowerBoundsCalculation)
{
auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler);
if (!ok)
reportError(location, NormalizationTooComplex{});
return t;
}
}
else
{
const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty);
if (FFlag::LuauAlwaysQuantify)
@ -4477,6 +4560,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
reportError(location, NormalizationTooComplex{});
return t;
}
}
return ty;
}

View file

@ -740,7 +740,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I
std::optional<TypeError> unificationTooComplex;
std::optional<TypeError> firstFailedOption;
// T <: A & B if A <: T and B <: T
// T <: A & B if T <: A and T <: B
for (TypeId type : uv->parts)
{
Unifier innerState = makeChildUnifier();
@ -765,7 +765,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I
void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall)
{
// A & B <: T if T <: A or T <: B
// A & B <: T if A <: T or B <: T
bool found = false;
std::optional<TypeError> unificationTooComplex;

View file

@ -5,6 +5,9 @@
#include <algorithm>
#include <errno.h>
#include <limits.h>
// Warning: If you are introducing new syntax, ensure that it is behind a separate
// flag so that we don't break production games by reverting syntax changes.
// See docs/SyntaxChanges.md for an explanation.
@ -14,6 +17,18 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false)
LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false)
LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false)
bool lua_telemetry_parsed_named_non_function_type = false;
LUAU_FASTFLAGVARIABLE(LuauErrorParseIntegerIssues, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false)
bool lua_telemetry_parsed_out_of_range_bin_integer = false;
bool lua_telemetry_parsed_out_of_range_hex_integer = false;
bool lua_telemetry_parsed_double_prefix_hex_integer = false;
namespace Luau
{
@ -1330,7 +1345,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
{
incrementRecursionCounter("type annotation");
bool monomorphic = lexer.current().type != '<';
bool forceFunctionType = lexer.current().type == '<';
Lexeme begin = lexer.current();
@ -1355,21 +1370,33 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
AstArray<AstType*> paramTypes = copy(params);
if (FFlag::LuauFixNamedFunctionParse && !names.empty())
forceFunctionType = true;
bool returnTypeIntroducer =
FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false;
// Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element
if (params.size() == 1 && !varargAnnotation && monomorphic &&
if (params.size() == 1 && !varargAnnotation && !forceFunctionType &&
(FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow))
{
if (DFFlag::LuaReportParseWrongNamedType && !names.empty())
lua_telemetry_parsed_named_non_function_type = true;
if (allowPack)
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, nullptr})};
else
return {params[0], {}};
}
if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && monomorphic && allowPack)
if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && !forceFunctionType &&
allowPack)
{
if (DFFlag::LuaReportParseWrongNamedType && !names.empty())
lua_telemetry_parsed_named_non_function_type = true;
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, varargAnnotation})};
}
AstArray<std::optional<AstArgumentName>> paramNames = copy(names);
@ -2010,7 +2037,63 @@ AstExpr* Parser::parseAssertionExpr()
return expr;
}
static bool parseNumber(double& result, const char* data)
static const char* parseInteger(double& result, const char* data, int base)
{
char* end = nullptr;
unsigned long long value = strtoull(data, &end, base);
if (value == ULLONG_MAX && errno == ERANGE)
{
// 'errno' might have been set before we called 'strtoull', but we don't want the overhead of resetting a TLS variable on each call
// so we only reset it when we get a result that might be an out-of-range error and parse again to make sure
errno = 0;
value = strtoull(data, &end, base);
if (errno == ERANGE)
{
if (DFFlag::LuaReportParseIntegerIssues)
{
if (base == 2)
lua_telemetry_parsed_out_of_range_bin_integer = true;
else
lua_telemetry_parsed_out_of_range_hex_integer = true;
}
if (FFlag::LuauErrorParseIntegerIssues)
return "Integer number value is out of range";
}
}
result = double(value);
return *end == 0 ? nullptr : "Malformed number";
}
static const char* parseNumber(double& result, const char* data)
{
// binary literal
if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2])
return parseInteger(result, data + 2, 2);
// hexadecimal literal
if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2])
{
if (DFFlag::LuaReportParseIntegerIssues && data[2] == '0' && (data[3] == 'x' || data[3] == 'X'))
lua_telemetry_parsed_double_prefix_hex_integer = true;
if (FFlag::LuauErrorParseIntegerIssues)
return parseInteger(result, data, 16); // keep prefix, it's handled by 'strtoull'
else
return parseInteger(result, data + 2, 16);
}
char* end = nullptr;
double value = strtod(data, &end);
result = value;
return *end == 0 ? nullptr : "Malformed number";
}
static bool parseNumber_DEPRECATED(double& result, const char* data)
{
// binary literal
if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2])
@ -2080,8 +2163,26 @@ AstExpr* Parser::parseSimpleExpr()
scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end());
}
if (DFFlag::LuaReportParseIntegerIssues || FFlag::LuauErrorParseIntegerIssues)
{
double value = 0;
if (parseNumber(value, scratchData.c_str()))
if (const char* error = parseNumber(value, scratchData.c_str()))
{
nextLexeme();
return reportExprError(start, {}, "%s", error);
}
else
{
nextLexeme();
return allocator.alloc<AstExprConstantNumber>(start, value);
}
}
else
{
double value = 0;
if (parseNumber_DEPRECATED(value, scratchData.c_str()))
{
nextLexeme();
@ -2094,6 +2195,7 @@ AstExpr* Parser::parseSimpleExpr()
return reportExprError(start, {}, "Malformed number");
}
}
}
else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)
{
return parseString();

View file

@ -276,7 +276,7 @@ enum LuauOpcode
// FORGLOOP: adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue
// A: target register; generic for loops assume a register layout [generator, state, index, variables...]
// D: jump offset (-32768..32767)
// AUX: variable count (1..255)
// AUX: variable count (1..255) in the low 8 bits, high bit indicates whether to use ipairs-style traversal in the fast path
// loop variables are adjusted by calling generator(state, index) and expecting it to return a tuple that's copied to the user variables
// the first variable is then copied into index; generator/state are immutable, index isn't visible to user code
LOP_FORGLOOP,
@ -490,6 +490,9 @@ enum LuauBuiltinFunction
// select(_, ...)
LBF_SELECT_VARARG,
// rawlen
LBF_RAWLEN,
};
// Capture type, used in LOP_CAPTURE

View file

@ -4,6 +4,8 @@
#include "Luau/Bytecode.h"
#include "Luau/Compiler.h"
LUAU_FASTFLAGVARIABLE(LuauCompileRawlen, false)
namespace Luau
{
namespace Compile
@ -58,6 +60,8 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options)
return LBF_RAWGET;
if (builtin.isGlobal("rawequal"))
return LBF_RAWEQUAL;
if (FFlag::LuauCompileRawlen && builtin.isGlobal("rawlen"))
return LBF_RAWLEN;
if (builtin.isGlobal("unpack"))
return LBF_TABLE_UNPACK;

View file

@ -1302,20 +1302,22 @@ void BytecodeBuilder::validate() const
case LOP_FORNPREP:
case LOP_FORNLOOP:
VREG(LUAU_INSN_A(insn) + 2); // for loop protocol: A, A+1, A+2 are used for iteration
// for loop protocol: A, A+1, A+2 are used for iteration
VREG(LUAU_INSN_A(insn) + 2);
VJUMP(LUAU_INSN_D(insn));
break;
case LOP_FORGPREP:
VREG(LUAU_INSN_A(insn) + 2 + 1); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables
// forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables
VREG(LUAU_INSN_A(insn) + 2 + 1);
VJUMP(LUAU_INSN_D(insn));
break;
case LOP_FORGLOOP:
VREG(
LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables
// forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables
VREG(LUAU_INSN_A(insn) + 2 + uint8_t(insns[i + 1]));
VJUMP(LUAU_INSN_D(insn));
LUAU_ASSERT(insns[i + 1] >= 1);
LUAU_ASSERT(uint8_t(insns[i + 1]) >= 1);
break;
case LOP_FORGPREP_INEXT:
@ -1679,7 +1681,8 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result,
break;
case LOP_FORGLOOP:
formatAppend(result, "FORGLOOP R%d L%d %d\n", LUAU_INSN_A(insn), targetLabel, *code++);
formatAppend(result, "FORGLOOP R%d L%d %d%s\n", LUAU_INSN_A(insn), targetLabel, uint8_t(*code), int(*code) < 0 ? " [inext]" : "");
code++;
break;
case LOP_FORGPREP_INEXT:

View file

@ -23,6 +23,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25)
LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300)
LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false)
namespace Luau
{
@ -2665,7 +2667,7 @@ struct Compiler
if (builtin.isGlobal("ipairs")) // for .. in ipairs(t)
{
skipOp = LOP_FORGPREP_INEXT;
loopOp = LOP_FORGLOOP_INEXT;
loopOp = FFlag::LuauCompileNoIpairs ? LOP_FORGLOOP : LOP_FORGLOOP_INEXT;
}
else if (builtin.isGlobal("pairs")) // for .. in pairs(t)
{
@ -2709,8 +2711,16 @@ struct Compiler
bytecode.emitAD(loopOp, regs, 0);
if (FFlag::LuauCompileNoIpairs)
{
// TODO: remove loopOp as it's a constant now
LUAU_ASSERT(loopOp == LOP_FORGLOOP);
// FORGLOOP uses aux to encode variable count and fast path flag for ipairs traversal in the high bit
bytecode.emitAux((skipOp == LOP_FORGPREP_INEXT ? 0x80000000 : 0) | uint32_t(stat->vars.size));
}
// note: FORGLOOP needs variable count encoded in AUX field, other loop instructions assume a fixed variable count
if (loopOp == LOP_FORGLOOP)
else if (loopOp == LOP_FORGLOOP)
bytecode.emitAux(uint32_t(stat->vars.size));
size_t endLabel = bytecode.emitLabel();
@ -3341,7 +3351,7 @@ struct Compiler
std::vector<AstLocal*> upvals;
};
struct ReturnVisitor: AstVisitor
struct ReturnVisitor : AstVisitor
{
Compiler* self;
bool returnsOne = true;

View file

@ -11,6 +11,8 @@
#include <stdio.h>
#include <stdlib.h>
LUAU_FASTFLAG(LuauLenTM)
static void writestring(const char* s, size_t l)
{
fwrite(s, 1, l, stdout);
@ -178,6 +180,18 @@ static int luaB_rawset(lua_State* L)
return 1;
}
static int luaB_rawlen(lua_State* L)
{
if (!FFlag::LuauLenTM)
luaL_error(L, "'rawlen' is not available");
int tt = lua_type(L, 1);
luaL_argcheck(L, tt == LUA_TTABLE || tt == LUA_TSTRING, 1, "table or string expected");
int len = lua_objlen(L, 1);
lua_pushinteger(L, len);
return 1;
}
static int luaB_gcinfo(lua_State* L)
{
lua_pushinteger(L, lua_gc(L, LUA_GCCOUNT, 0));
@ -428,6 +442,7 @@ static const luaL_Reg base_funcs[] = {
{"rawequal", luaB_rawequal},
{"rawget", luaB_rawget},
{"rawset", luaB_rawset},
{"rawlen", luaB_rawlen},
{"select", luaB_select},
{"setfenv", luaB_setfenv},
{"setmetatable", luaB_setmetatable},

View file

@ -1117,6 +1117,27 @@ static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, Stk
return -1;
}
static int luauF_rawlen(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
if (nparams >= 1 && nresults <= 1)
{
if (ttistable(arg0))
{
Table* h = hvalue(arg0);
setnvalue(res, double(luaH_getn(h)));
return 1;
}
else if (ttisstring(arg0))
{
TString* ts = tsvalue(arg0);
setnvalue(res, double(ts->len));
return 1;
}
}
return -1;
}
luau_FastFunction luauF_table[256] = {
NULL,
luauF_assert,
@ -1188,4 +1209,6 @@ luau_FastFunction luauF_table[256] = {
luauF_countrz,
luauF_select,
luauF_rawlen,
};

View file

@ -19,6 +19,7 @@ LUAI_FUNC l_noret luaG_concaterror(lua_State* L, StkId p1, StkId p2);
LUAI_FUNC l_noret luaG_aritherror(lua_State* L, const TValue* p1, const TValue* p2, TMS op);
LUAI_FUNC l_noret luaG_ordererror(lua_State* L, const TValue* p1, const TValue* p2, TMS op);
LUAI_FUNC l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2);
LUAI_FUNC LUA_PRINTF_ATTR(2, 3) l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...);
LUAI_FUNC void luaG_pusherror(lua_State* L, const char* error);

View file

@ -39,6 +39,7 @@ const char* const luaT_eventname[] = {
"__namecall",
"__call",
"__iter",
"__len",
"__eq",
@ -52,7 +53,6 @@ const char* const luaT_eventname[] = {
"__unm",
"__len",
"__lt",
"__le",
"__concat",

View file

@ -18,6 +18,7 @@ typedef enum
TM_NAMECALL,
TM_CALL,
TM_ITER,
TM_LEN,
TM_EQ, /* last tag method with `fast' access */
@ -31,7 +32,6 @@ typedef enum
TM_UNM,
TM_LEN,
TM_LT,
TM_LE,
TM_CONCAT,

View file

@ -16,6 +16,8 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauLenTM, false)
// Disable c99-designator to avoid the warning in CGOTO dispatch table
#ifdef __clang__
#if __has_warning("-Wc99-designator")
@ -2082,13 +2084,25 @@ static void luau_execute(lua_State* L)
// fast-path #1: tables
if (ttistable(rb))
{
setnvalue(ra, cast_num(luaH_getn(hvalue(rb))));
Table* h = hvalue(rb);
if (!FFlag::LuauLenTM || fastnotm(h->metatable, TM_LEN))
{
setnvalue(ra, cast_num(luaH_getn(h)));
VM_NEXT();
}
else
{
// slow-path, may invoke C/Lua via metamethods
VM_PROTECT(luaV_dolen(L, ra, rb));
VM_NEXT();
}
}
// fast-path #2: strings (not very important but easy to do)
else if (ttisstring(rb))
{
setnvalue(ra, cast_num(tsvalue(rb)->len));
TString* ts = tsvalue(rb);
setnvalue(ra, cast_num(ts->len));
VM_NEXT();
}
else
@ -2226,6 +2240,15 @@ static void luau_execute(lua_State* L)
VM_PROTECT(luaD_call(L, ra, 3));
L->top = L->ci->top;
/* recompute ra since stack might have been reallocated */
ra = VM_REG(LUAU_INSN_A(insn));
/* protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP */
if (ttisnil(ra))
{
VM_PROTECT(luaG_typeerror(L, ra, "call"));
}
}
else if (fasttm(L, mt, TM_CALL))
{
@ -2258,27 +2281,38 @@ static void luau_execute(lua_State* L)
uint32_t aux = *pc;
// fast-path: builtin table iteration
if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2))
// note: ra=nil guarantees ra+1=table and ra+2=userdata because of the setup by FORGPREP* opcodes
// TODO: remove the table check per guarantee above
if (ttisnil(ra) && ttistable(ra + 1))
{
Table* h = hvalue(ra + 1);
int index = int(reinterpret_cast<uintptr_t>(pvalue(ra + 2)));
int sizearray = h->sizearray;
int sizenode = 1 << h->lsizenode;
// clear extra variables since we might have more than two
if (LUAU_UNLIKELY(aux > 2))
// note: while aux encodes ipairs bit, when set we always use 2 variables, so it's safe to check this via a signed comparison
if (LUAU_UNLIKELY(int(aux) > 2))
for (int i = 2; i < int(aux); ++i)
setnilvalue(ra + 3 + i);
// terminate ipairs-style traversal early when encountering nil
if (int(aux) < 0 && (unsigned(index) >= unsigned(sizearray) || ttisnil(&h->array[index])))
{
pc++;
VM_NEXT();
}
// first we advance index through the array portion
while (unsigned(index) < unsigned(sizearray))
{
if (!ttisnil(&h->array[index]))
TValue* e = &h->array[index];
if (!ttisnil(e))
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
setnvalue(ra + 3, double(index + 1));
setobj2s(L, ra + 4, &h->array[index]);
setobj2s(L, ra + 4, e);
pc += LUAU_INSN_D(insn);
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
@ -2288,6 +2322,8 @@ static void luau_execute(lua_State* L)
index++;
}
int sizenode = 1 << h->lsizenode;
// then we advance index through the hash portion
while (unsigned(index - sizearray) < unsigned(sizenode))
{
@ -2321,7 +2357,7 @@ static void luau_execute(lua_State* L)
L->top = ra + 3 + 3; /* func + 2 args (state and index) */
LUAU_ASSERT(L->top <= L->stack_last);
VM_PROTECT(luaD_call(L, ra + 3, aux));
VM_PROTECT(luaD_call(L, ra + 3, uint8_t(aux)));
L->top = L->ci->top;
// recompute ra since stack might have been reallocated

View file

@ -10,6 +10,9 @@
#include "lnumutils.h"
#include <string.h>
#include <stdio.h>
LUAU_FASTFLAG(LuauLenTM)
/* limit for table tag-method chains (to avoid loops) */
#define MAXTAGLOOP 100
@ -51,7 +54,7 @@ const float* luaV_tovector(const TValue* obj)
return nullptr;
}
static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2)
static StkId callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2)
{
ptrdiff_t result = savestack(L, res);
// using stack room beyond top is technically safe here, but for very complicated reasons:
@ -71,6 +74,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1
res = restorestack(L, result);
L->top--;
setobjs2s(L, res, L->top);
return res;
}
static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3)
@ -472,6 +476,8 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM
void luaV_dolen(lua_State* L, StkId ra, const TValue* rb)
{
if (!FFlag::LuauLenTM)
{
switch (ttype(rb))
{
case LUA_TTABLE:
@ -490,4 +496,36 @@ void luaV_dolen(lua_State* L, StkId ra, const TValue* rb)
luaG_typeerror(L, rb, "get length of");
}
}
return;
}
const TValue* tm = NULL;
switch (ttype(rb))
{
case LUA_TTABLE:
{
Table* h = hvalue(rb);
if ((tm = fasttm(L, h->metatable, TM_LEN)) == NULL)
{
setnvalue(ra, cast_num(luaH_getn(h)));
return;
}
break;
}
case LUA_TSTRING:
{
TString* ts = tsvalue(rb);
setnvalue(ra, cast_num(ts->len));
return;
}
default:
tm = luaT_gettmbyobj(L, rb, TM_LEN);
}
if (ttisnil(tm))
luaG_typeerror(L, rb, "get length of");
StkId res = callTMres(L, ra, tm, rb, luaO_nilobject);
if (!ttisnumber(res))
luaG_runerror(L, "'__len' must return a number"); /* note, we can't access rb since stack may have been reallocated */
}

View file

@ -11,6 +11,7 @@ static const std::string kNames[] = {
"__div",
"__eq",
"__index",
"__iter",
"__le",
"__len",
"__lt",
@ -41,13 +42,18 @@ static const std::string kNames[] = {
"ceil",
"char",
"charpattern",
"clamp",
"clock",
"clone",
"close",
"codepoint",
"codes",
"concat",
"coroutine",
"cos",
"cosh",
"countlz",
"countrz",
"create",
"date",
"debug",
@ -63,6 +69,7 @@ static const std::string kNames[] = {
"foreachi",
"format",
"frexp",
"freeze",
"function",
"gcinfo",
"getfenv",
@ -72,8 +79,10 @@ static const std::string kNames[] = {
"gmatch",
"gsub",
"huge",
"info",
"insert",
"ipairs",
"isfrozen",
"isyieldable",
"ldexp",
"len",
@ -93,6 +102,7 @@ static const std::string kNames[] = {
"newproxy",
"next",
"nil",
"noise",
"number",
"offset",
"os",
@ -121,6 +131,7 @@ static const std::string kNames[] = {
"select",
"setfenv",
"setmetatable",
"sign",
"sin",
"sinh",
"sort",

View file

@ -261,6 +261,8 @@ L1: RETURN R0 0
TEST_CASE("ForBytecode")
{
ScopedFastFlag sff("LuauCompileNoIpairs", true);
// basic for loop: variable directly refers to internal iteration index (R2)
CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"(
LOADN R2 1
@ -313,7 +315,7 @@ L0: GETIMPORT R5 3
MOVE R6 R3
MOVE R7 R4
CALL R5 2 0
L1: FORGLOOP_INEXT R0 L0
L1: FORGLOOP R0 L0 2 [inext]
RETURN R0 0
)");
@ -347,13 +349,15 @@ RETURN R0 0
TEST_CASE("ForBytecodeBuiltin")
{
ScopedFastFlag sff("LuauCompileNoIpairs", true);
// we generally recognize builtins like pairs/ipairs and emit special opcodes
CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"(
GETIMPORT R0 1
NEWTABLE R1 0 0
CALL R0 1 3
FORGPREP_INEXT R0 L0
L0: FORGLOOP_INEXT R0 L0
L0: FORGLOOP R0 L0 2 [inext]
RETURN R0 0
)");
@ -364,7 +368,7 @@ MOVE R1 R0
NEWTABLE R2 0 0
CALL R1 1 3
FORGPREP_INEXT R1 L0
L0: FORGLOOP_INEXT R1 L0
L0: FORGLOOP R1 L0 2 [inext]
RETURN R0 0
)");
@ -374,7 +378,7 @@ GETUPVAL R0 0
NEWTABLE R1 0 0
CALL R0 1 3
FORGPREP_INEXT R0 L0
L0: FORGLOOP_INEXT R0 L0
L0: FORGLOOP R0 L0 2 [inext]
RETURN R0 0
)");
@ -2107,6 +2111,8 @@ RETURN R3 -1
TEST_CASE("UpvaluesLoopsBytecode")
{
ScopedFastFlag sff("LuauCompileNoIpairs", true);
CHECK_EQ("\n" + compileFunction(R"(
function test()
for i=1,10 do
@ -2169,7 +2175,7 @@ JUMPIFNOT R5 L1
CLOSEUPVALS R3
JUMP L3
L1: CLOSEUPVALS R3
L2: FORGLOOP_INEXT R0 L0
L2: FORGLOOP R0 L0 1 [inext]
L3: LOADN R0 0
RETURN R0 1
)");

View file

@ -231,6 +231,8 @@ TEST_CASE("Assert")
TEST_CASE("Basic")
{
ScopedFastFlag sff("LuauLenTM", true);
runConformance("basic.lua");
}
@ -301,6 +303,8 @@ TEST_CASE("Errors")
TEST_CASE("Events")
{
ScopedFastFlag sff("LuauLenTM", true);
runConformance("events.lua");
}
@ -475,6 +479,8 @@ static void populateRTTI(lua_State* L, Luau::TypeId type)
TEST_CASE("Types")
{
ScopedFastFlag sff("LuauCheckLenMT", true);
runConformance("types.lua", [](lua_State* L) {
Luau::NullModuleResolver moduleResolver;
Luau::InternalErrorReporter iceHandler;

View file

@ -17,7 +17,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world")
)");
cgb.visit(block);
auto constraints = collectConstraints(cgb.rootScope);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(2 == constraints.size());
@ -36,7 +36,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives")
)");
cgb.visit(block);
auto constraints = collectConstraints(cgb.rootScope);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(3 == constraints.size());
@ -54,15 +54,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive")
)");
cgb.visit(block);
auto constraints = collectConstraints(cgb.rootScope);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
ToStringOptions opts;
REQUIRE(5 <= constraints.size());
CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts));
CHECK("b ~ inst *blocked-1*" == toString(*constraints[1], opts));
CHECK("() -> (c...) <: b" == toString(*constraints[2], opts));
CHECK("c... <: d" == toString(*constraints[3], opts));
CHECK("*blocked-2* ~ inst *blocked-1*" == toString(*constraints[1], opts));
CHECK("() -> (b...) <: *blocked-2*" == toString(*constraints[2], opts));
CHECK("b... <: c" == toString(*constraints[3], opts));
CHECK("nil <: a..." == toString(*constraints[4], opts));
}
@ -74,15 +74,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application")
)");
cgb.visit(block);
auto constraints = collectConstraints(cgb.rootScope);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(4 == constraints.size());
ToStringOptions opts;
CHECK("string <: a" == toString(*constraints[0], opts));
CHECK("b ~ inst a" == toString(*constraints[1], opts));
CHECK("(string) -> (c...) <: b" == toString(*constraints[2], opts));
CHECK("c... <: d" == toString(*constraints[3], opts));
CHECK("*blocked-1* ~ inst a" == toString(*constraints[1], opts));
CHECK("(string) -> (b...) <: *blocked-1*" == toString(*constraints[2], opts));
CHECK("b... <: c" == toString(*constraints[3], opts));
}
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition")
@ -94,7 +94,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition")
)");
cgb.visit(block);
auto constraints = collectConstraints(cgb.rootScope);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(2 == constraints.size());
@ -112,15 +112,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function")
)");
cgb.visit(block);
auto constraints = collectConstraints(cgb.rootScope);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(4 == constraints.size());
ToStringOptions opts;
CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts));
CHECK("c ~ inst (a) -> (b...)" == toString(*constraints[1], opts));
CHECK("(a) -> (d...) <: c" == toString(*constraints[2], opts));
CHECK("d... <: b..." == toString(*constraints[3], opts));
CHECK("*blocked-2* ~ inst (a) -> (b...)" == toString(*constraints[1], opts));
CHECK("(a) -> (c...) <: *blocked-2*" == toString(*constraints[2], opts));
CHECK("c... <: b..." == toString(*constraints[3], opts));
}
TEST_SUITE_END();

View file

@ -9,7 +9,7 @@
using namespace Luau;
static TypeId requireBinding(Scope2* scope, const char* name)
static TypeId requireBinding(NotNull<Scope2> scope, const char* name)
{
auto b = linearSearchForBinding(scope, name);
LUAU_ASSERT(b.has_value());
@ -26,12 +26,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello")
)");
cgb.visit(block);
NotNull<Scope2> rootScope = NotNull(cgb.rootScope);
ConstraintSolver cs{&arena, cgb.rootScope};
ConstraintSolver cs{&arena, rootScope};
cs.run();
TypeId bType = requireBinding(cgb.rootScope, "b");
TypeId bType = requireBinding(rootScope, "b");
CHECK("number" == toString(bType));
}
@ -45,12 +46,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function")
)");
cgb.visit(block);
NotNull<Scope2> rootScope = NotNull(cgb.rootScope);
ConstraintSolver cs{&arena, cgb.rootScope};
ConstraintSolver cs{&arena, rootScope};
cs.run();
TypeId idType = requireBinding(cgb.rootScope, "id");
TypeId idType = requireBinding(rootScope, "id");
CHECK("<a>(a) -> a" == toString(idType));
}
@ -71,14 +73,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization")
)");
cgb.visit(block);
NotNull<Scope2> rootScope = NotNull(cgb.rootScope);
ToStringOptions opts;
ConstraintSolver cs{&arena, cgb.rootScope};
ConstraintSolver cs{&arena, rootScope};
cs.run();
TypeId idType = requireBinding(cgb.rootScope, "b");
TypeId idType = requireBinding(rootScope, "b");
CHECK("<a>(a) -> number" == toString(idType, opts));
}

View file

@ -195,12 +195,15 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin
sourceModule.reset(new SourceModule);
ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options);
REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'");
CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'");
if (!result.errors.empty())
{
CHECK_EQ(result.errors.front().getMessage(), message);
if (location)
CHECK_EQ(result.errors.front().getLocation(), *location);
}
return result;
}
@ -213,11 +216,14 @@ ParseResult Fixture::matchParseErrorPrefix(const std::string& source, const std:
sourceModule.reset(new SourceModule);
ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options);
REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'");
CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'");
if (!result.errors.empty())
{
const std::string& message = result.errors.front().getMessage();
CHECK_GE(message.length(), prefix.length());
CHECK_EQ(prefix, message.substr(0, prefix.size()));
}
return result;
}
@ -428,6 +434,7 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete)
ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture()
: Fixture()
, cgb(mainModuleName, &arena, NotNull(&ice), frontend.getGlobalScope2())
, forceTheFlag{"DebugLuauDeferredConstraintResolution", true}
{
BlockedTypeVar::nextIndex = 0;

View file

@ -133,6 +133,7 @@ struct Fixture
TestConfigResolver configResolver;
std::unique_ptr<SourceModule> sourceModule;
Frontend frontend;
InternalErrorReporter ice;
TypeChecker& typeChecker;
std::string decorateWithTypes(const std::string& code);
@ -160,7 +161,7 @@ struct BuiltinsFixture : Fixture
struct ConstraintGraphBuilderFixture : Fixture
{
TypeArena arena;
ConstraintGraphBuilder cgb{&arena};
ConstraintGraphBuilder cgb;
ScopedFastFlag forceTheFlag;

View file

@ -30,15 +30,14 @@ struct Test
int Test::count = 0;
}
} // namespace
int foo(NotNull<int> p)
{
return *p;
}
void bar(int* q)
{}
void bar(int* q) {}
TEST_SUITE_BEGIN("NotNull");
@ -46,7 +45,8 @@ TEST_CASE("basic_stuff")
{
NotNull<int> a = NotNull{new int(55)}; // Does runtime test
NotNull<int> b{new int(55)}; // As above
// NotNull<int> c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull<T> in the general case is not good.
// NotNull<int> c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull<T> in the general case is not
// good.
// a = nullptr; // nope

View file

@ -6,6 +6,8 @@
#include "doctest.h"
#include <limits.h>
using namespace Luau;
namespace
@ -786,33 +788,46 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_decimal")
TEST_CASE_FIXTURE(Fixture, "parse_numbers_hexadecimal")
{
AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff");
AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff, 0xffffffffffffffff");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
CHECK(str->list.size == 3);
CHECK(str->list.size == 4);
CHECK_EQ(str->list.data[0]->as<AstExprConstantNumber>()->value, 0xab);
CHECK_EQ(str->list.data[1]->as<AstExprConstantNumber>()->value, 0xAB05);
CHECK_EQ(str->list.data[2]->as<AstExprConstantNumber>()->value, 0xFFFF);
CHECK_EQ(str->list.data[3]->as<AstExprConstantNumber>()->value, double(ULLONG_MAX));
}
TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary")
{
AstStat* stat = parse("return 0b1, 0b0, 0b101010");
AstStat* stat = parse("return 0b1, 0b0, 0b101010, 0b1111111111111111111111111111111111111111111111111111111111111111");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
CHECK(str->list.size == 3);
CHECK(str->list.size == 4);
CHECK_EQ(str->list.data[0]->as<AstExprConstantNumber>()->value, 1);
CHECK_EQ(str->list.data[1]->as<AstExprConstantNumber>()->value, 0);
CHECK_EQ(str->list.data[2]->as<AstExprConstantNumber>()->value, 42);
CHECK_EQ(str->list.data[3]->as<AstExprConstantNumber>()->value, double(ULLONG_MAX));
}
TEST_CASE_FIXTURE(Fixture, "parse_numbers_error")
{
ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true};
CHECK_EQ(getParseError("return 0b123"), "Malformed number");
CHECK_EQ(getParseError("return 123x"), "Malformed number");
CHECK_EQ(getParseError("return 0xg"), "Malformed number");
CHECK_EQ(getParseError("return 0x0x123"), "Malformed number");
}
TEST_CASE_FIXTURE(Fixture, "parse_numbers_range_error")
{
ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true};
CHECK_EQ(getParseError("return 0x10000000000000000"), "Integer number value is out of range");
CHECK_EQ(getParseError("return 0b10000000000000000000000000000000000000000000000000000000000000000"), "Integer number value is out of range");
}
TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error")
@ -2111,6 +2126,15 @@ type C<X...> = Packed<(number, X...)>
REQUIRE(stat != nullptr);
}
TEST_CASE_FIXTURE(Fixture, "invalid_type_forms")
{
ScopedFastFlag luauFixNamedFunctionParse{"LuauFixNamedFunctionParse", true};
matchParseError("type A = (b: number)", "Expected '->' when parsing function type, got <eof>");
matchParseError("type P<T...> = () -> T... type B = P<(x: number, y: string)>", "Expected '->' when parsing function type, got '>'");
matchParseError("type F<T... = (a: string)> = (T...) -> ()", "Expected '->' when parsing function type, got '>'");
}
TEST_SUITE_END();
TEST_SUITE_BEGIN("ParseErrorRecovery");

View file

@ -409,6 +409,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed")
TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2")
{
ScopedFastFlag sff2{"DebugLuauSharedSelf", true};
CheckResult result = check(R"(
local base = {}
function base:one() return 1 end
@ -424,7 +426,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2")
TypeId tType = requireType("inst");
ToStringResult r = toStringDetailed(tType);
CHECK_EQ("{ @metatable { __index: { @metatable { __index: base }, child } }, inst }", r.name);
CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name);
CHECK_EQ(0, r.nameMap.typeVars.size());
ToStringOptions opts;
@ -455,11 +457,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2")
std::string twoResult = toString(tMeta6->props["two"].type, opts);
REQUIRE_EQ("<a>(a) -> number", oneResult.name);
REQUIRE_EQ("<b>(b) -> number", twoResult);
CHECK_EQ("<a>(a) -> number", oneResult.name);
CHECK_EQ("<b>(b) -> number", twoResult);
}
TEST_CASE_FIXTURE(Fixture, "toStringErrorPack")
{
CheckResult result = check(R"(
@ -688,6 +689,10 @@ TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param")
{
ScopedFastFlag sff[]{
{"DebugLuauSharedSelf", true},
};
CheckResult result = check(R"(
local foo = {}
function foo:method(arg: string): ()
@ -701,9 +706,12 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param")
CHECK_EQ("foo:method<a>(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param")
{
ScopedFastFlag sff[]{
{"DebugLuauSharedSelf", true},
};
CheckResult result = check(R"(
local foo = {}
function foo:method(arg: string): ()

View file

@ -716,6 +716,10 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni
TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2")
{
ScopedFastFlag sff[] = {
{"DebugLuauSharedSelf", true},
};
CheckResult result = check(R"(
local B = {}
B.bar = 4
@ -737,7 +741,8 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni
type FutureIntersection = A & B
)");
LUAU_REQUIRE_NO_ERRORS(result);
// TODO: shared self causes this test to break in bizarre ways.
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok")

View file

@ -37,6 +37,27 @@ TEST_CASE_FIXTURE(Fixture, "check_function_bodies")
}}));
}
TEST_CASE_FIXTURE(Fixture, "cannot_hoist_interior_defns_into_signature")
{
// This test verifies that the signature does not have access to types
// declared within the body. Under DCR, if the function's inner scope
// encompasses the entire function expression, it would be possible for this
// to type check (but the solver output is somewhat undefined). This test
// ensures that this isn't the case.
CheckResult result = check(R"(
local function f(x: T)
type T = number
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(result.errors[0] == TypeError{Location{{1, 28}, {1, 29}}, getMainSourceModule()->name,
UnknownSymbol{
"T",
UnknownSymbol::Context::Type,
}});
}
TEST_CASE_FIXTURE(Fixture, "infer_return_type")
{
CheckResult result = check("function take_five() return 5 end");

View file

@ -271,13 +271,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function")
TEST_CASE_FIXTURE(Fixture, "infer_generic_methods")
{
ScopedFastFlag sff{"DebugLuauSharedSelf", true};
CheckResult result = check(R"(
local x = {}
function x:id(x) return x end
function x:f(): string return self:id("hello") end
function x:g(): number return self:id(37) end
)");
LUAU_REQUIRE_NO_ERRORS(result);
// TODO: Quantification should be doing the conversion, not normalization.
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods")

View file

@ -461,6 +461,61 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus")
REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error")
{
CheckResult result = check(R"(
--!strict
local foo = {
value = 10
}
local mt = {}
setmetatable(foo, mt)
mt.__unm = function(val: boolean): string
return "test"
end
local a = -foo
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("string", toString(requireType("a")));
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE_EQ(*tm->wantedType, *typeChecker.booleanType);
// given type is the typeof(foo) which is complex to compare against
}
TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error")
{
ScopedFastFlag sff("LuauCheckLenMT", true);
CheckResult result = check(R"(
--!strict
local foo = {
value = 10
}
local mt = {}
setmetatable(foo, mt)
mt.__len = function(val: any): string
return "test"
end
local a = #foo
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("number", toString(requireType("a")));
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType);
REQUIRE_EQ(*tm->givenType, *typeChecker.stringType);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean")
{
CheckResult result = check(R"(

View file

@ -499,4 +499,26 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent")
CHECK_EQ("<a...>(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f")));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_function_with_no_returns")
{
ScopedFastFlag sff{"DebugLuauSharedSelf", true};
CheckResult result = check(R"(
local T = {}
T.__index = T
function T.new()
local self = setmetatable({}, T)
return self:ctor() or self
end
function T:ctor()
-- oops, no return!
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Not all codepaths in this function return '{ @metatable T, {| |} }, a...'.", toString(result.errors[0]));
}
TEST_SUITE_END();

View file

@ -1863,6 +1863,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works")
TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please")
{
ScopedFastFlag sff{"DebugLuauSharedSelf", true};
CheckResult result = check(R"(
--!strict
@ -1890,7 +1892,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please")
newData:First()
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
LUAU_REQUIRE_ERROR_COUNT(2, result);
}
TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call")
@ -2868,6 +2870,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", true},
{"DebugLuauSharedSelf", true},
};
check(R"(
@ -2887,7 +2890,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table")
end
)");
CHECK_EQ("<a...>(t1) -> {| Byte: <b>(b) -> (a...), PeekByte: <c>(c) -> (a...) |} where t1 = {+ byte: (t1, number) -> (a...) +}",
CHECK_EQ("<a, b...>(t1) -> {| Byte: (a) -> (b...), PeekByte: (a) -> (b...) |} where t1 = {+ byte: (t1, number) -> (b...) +}",
toString(requireType("Base64FileReader")));
}
@ -2904,6 +2907,66 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys")
CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2]));
}
TEST_CASE_FIXTURE(Fixture, "shared_selfs")
{
ScopedFastFlag sff{"DebugLuauSharedSelf", true};
CheckResult result = check(R"(
local t = {}
t.x = 5
function t:m1() return self.x end
function t:m2() return self.y end
return t
)");
LUAU_REQUIRE_NO_ERRORS(result);
ToStringOptions opts;
opts.exhaustive = true;
CHECK_EQ("{| m1: <a, b>({+ x: a, y: b +}) -> a, m2: <a, b>({+ x: a, y: b +}) -> b, x: number |}", toString(requireType("t"), opts));
}
TEST_CASE_FIXTURE(Fixture, "shared_selfs_from_free_param")
{
ScopedFastFlag sff{"DebugLuauSharedSelf", true};
CheckResult result = check(R"(
local function f(t)
function t:m1() return self.x end
function t:m2() return self.y end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("<a, b>({+ m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b +}) -> ()", toString(requireType("f")));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "shared_selfs_through_metatables")
{
ScopedFastFlag sff{"DebugLuauSharedSelf", true};
CheckResult result = check(R"(
local t = {}
t.__index = t
setmetatable({}, t)
function t:m1() return self.x end
function t:m2() return self.y end
return t
)");
LUAU_REQUIRE_NO_ERRORS(result);
ToStringOptions opts;
opts.exhaustive = true;
CHECK_EQ(
toString(requireType("t"), opts), "t1 where t1 = {| __index: t1, m1: <a, b>({+ x: a, y: b +}) -> a, m2: <a, b>({+ x: a, y: b +}) -> b |}");
}
TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra")
{
CheckResult result = check(R"(
@ -2953,4 +3016,58 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_ty
CHECK_EQ("Type '{number} | {| [boolean]: number |}' does not have key 'x'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "quantify_metatables_of_metatables_of_table")
{
ScopedFastFlag sff[]{
{"DebugLuauSharedSelf", true},
};
CheckResult result = check(R"(
local T = {}
function T:m()
return self.x, self.y
end
function T:n()
end
local U = setmetatable({}, {__index = T})
local V = setmetatable({}, {__index = U})
return V
)");
LUAU_REQUIRE_NO_ERRORS(result);
ToStringOptions opts;
opts.exhaustive = true;
CHECK_EQ(toString(requireType("V"), opts), "{ @metatable { __index: { @metatable { __index: {| m: <a, b>({+ x: a, y: b +}) -> (a, b), n: <a, "
"b>({+ x: a, y: b +}) -> () |} }, { } } }, { } }");
}
TEST_CASE_FIXTURE(Fixture, "quantify_even_that_table_was_never_exported_at_all")
{
ScopedFastFlag sff{"DebugLuauSharedSelf", true};
CheckResult result = check(R"(
local T = {}
function T:m()
return self.x
end
function T:n()
return self.y
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
ToStringOptions opts;
opts.exhaustive = true;
CHECK_EQ("{| m: <a, b>({+ x: a, y: b +}) -> a, n: <a, b>({+ x: a, y: b +}) -> b |}", toString(requireType("T"), opts));
}
TEST_SUITE_END();

View file

@ -369,14 +369,14 @@ TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode")
CHECK_EQ("foo", us->name);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do")
TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do")
{
CheckResult result = check(R"(
do
local a = 1
end
print(a) -- oops!
local b = a -- oops!
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);

View file

@ -118,10 +118,12 @@ assert((function() return #_G end)() == 0)
assert((function() return #{1,2} end)() == 2)
assert((function() return #'g' end)() == 1)
assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42)
assert((function() local a = 1 a = -a return a end)() == -1)
-- __len metamethod
assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42)
assert((function() local t = {} setmetatable(t, { __len = function() return 42 end }) return #t end)() == 42)
-- while/repeat
assert((function() local a = 10 local b = 1 while a > 1 do b = b * 2 a = a - 1 end return b end)() == 512)
assert((function() local a = 10 local b = 1 repeat b = b * 2 a = a - 1 until a == 1 return b end)() == 512)
@ -889,6 +891,10 @@ assert((function()
return table.concat(res, ',')
end)() == "6,8,10")
-- typeof and type require an argument
assert(pcall(typeof) == false)
assert(pcall(type) == false)
-- typeof == type in absence of custom userdata
assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata")

View file

@ -386,4 +386,42 @@ do
assert(t.X) -- fails if table flags are set incorrectly
end
do
-- verify __len behavior & error handling
local t = {1}
setmetatable(t, {})
assert(#t == 1)
setmetatable(t, { __len = rawlen })
assert(#t == 1)
setmetatable(t, { __len = function() return 42 end })
assert(#t == 42)
setmetatable(t, { __len = 42 })
local ok, err = pcall(function() return #t end)
assert(not ok and err:match("attempt to call a number value"))
setmetatable(t, { __len = function() end })
local ok, err = pcall(function() return #t end)
assert(not ok and err:match("'__len' must return a number"))
setmetatable(t, { __len = error })
local ok, err = pcall(function() return #t end)
assert(not ok and err == t)
end
-- verify rawlen behavior
do
local t = {1}
setmetatable(t, { __len = 42 })
assert(rawlen(t) == 1)
assert(rawlen("foo") == 3)
local ok, err = pcall(function() return rawlen(42) end)
assert(not ok and err:match("table or string expected"))
end
return 'OK'