This commit is contained in:
Roni Nevalainen 2021-11-05 21:22:46 +02:00
commit f13cc5486e
No known key found for this signature in database
GPG key ID: 222116D3E5A8F2A1
97 changed files with 8008 additions and 3491 deletions

View file

@ -1,7 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "TypeInfer.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
namespace Luau
{
@ -33,7 +34,6 @@ TypeId makeFunction( // Polymorphic
std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes);
void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachFunctionTag(TypeId ty, std::string constraint);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName);

View file

@ -120,6 +120,7 @@ struct IncorrectGenericParameterCount
Name name;
TypeFun typeFun;
size_t actualParameters;
size_t actualPackParameters;
bool operator==(const IncorrectGenericParameterCount& rhs) const;
};

View file

@ -25,51 +25,39 @@ struct SourceCode
Type type;
};
struct ModuleInfo
{
ModuleName name;
bool optional = false;
};
struct FileResolver
{
virtual ~FileResolver() {}
/** Fetch the source code associated with the provided ModuleName.
*
* FIXME: This requires a string copy!
*
* @returns The actual Lua code on success.
* @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error.
*/
virtual std::optional<SourceCode> readSource(const ModuleName& name) = 0;
/** Does the module exist?
*
* Saves a string copy over reading the source and throwing it away.
*/
virtual bool moduleExists(const ModuleName& name) const = 0;
virtual std::optional<ModuleInfo> resolveModule(const ModuleInfo* context, AstExpr* expr)
{
return std::nullopt;
}
virtual std::optional<ModuleName> fromAstFragment(AstExpr* expr) const = 0;
/** Given a valid module name and a string of arbitrary data, figure out the concatenation.
*/
virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0;
/** Goes "up" a level in the hierarchy that the ModuleName represents.
*
* For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last
* element of the path. Other ModuleName representations may have other ways of doing this.
*
* @returns The parent ModuleName, if one exists.
* @returns std::nullopt if there is no parent for this module name.
*/
virtual std::optional<ModuleName> getParentModuleName(const ModuleName& name) const = 0;
virtual std::optional<std::string> getHumanReadableModuleName_(const ModuleName& name) const
virtual std::string getHumanReadableModuleName(const ModuleName& name) const
{
return name;
}
virtual std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const = 0;
virtual std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const
{
return std::nullopt;
}
/** LanguageService only:
* std::optional<ModuleName> fromInstance(Instance* inst)
*/
// DEPRECATED APIS
// These are going to be removed with LuauNewRequireTracer
virtual bool moduleExists(const ModuleName& name) const = 0;
virtual std::optional<ModuleName> fromAstFragment(AstExpr* expr) const = 0;
virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0;
virtual std::optional<ModuleName> getParentModuleName(const ModuleName& name) const = 0;
};
struct NullFileResolver : FileResolver
@ -94,10 +82,6 @@ struct NullFileResolver : FileResolver
{
return std::nullopt;
}
std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const override
{
return std::nullopt;
}
};
} // namespace Luau

View file

@ -90,10 +90,12 @@ struct Module
TypeArena internalTypes;
std::vector<std::pair<Location, ScopePtr>> scopes; // never empty
std::unordered_map<const AstExpr*, TypeId> astTypes;
std::unordered_map<const AstExpr*, TypeId> astExpectedTypes;
std::unordered_map<const AstExpr*, TypeId> astOriginalCallTypes;
std::unordered_map<const AstExpr*, TypeId> astOverloadResolvedTypes;
DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astExpectedTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOriginalCallTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOverloadResolvedTypes{nullptr};
std::unordered_map<Name, TypeId> declaredGlobals;
ErrorVec errors;
Mode mode;

View file

@ -15,12 +15,6 @@ struct Module;
using ModulePtr = std::shared_ptr<Module>;
struct ModuleInfo
{
ModuleName name;
bool optional = false;
};
struct ModuleResolver
{
virtual ~ModuleResolver() {}

View file

@ -0,0 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeVar.h"
namespace Luau
{
struct Module;
using ModulePtr = std::shared_ptr<Module>;
void quantify(ModulePtr module, TypeId ty, TypeLevel level);
} // namespace Luau

View file

@ -17,12 +17,11 @@ struct AstLocal;
struct RequireTraceResult
{
DenseHashMap<const AstExpr*, ModuleName> exprs{0};
DenseHashMap<const AstExpr*, bool> optional{0};
DenseHashMap<const AstExpr*, ModuleInfo> exprs{nullptr};
std::vector<std::pair<ModuleName, Location>> requires;
};
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName);
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName);
} // namespace Luau

View file

@ -0,0 +1,67 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Location.h"
#include "Luau/TypeVar.h"
#include <unordered_map>
#include <optional>
#include <memory>
namespace Luau
{
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct Binding
{
TypeId typeId;
Location location;
bool deprecated = false;
std::string deprecatedSuggestion;
std::optional<std::string> documentationSymbol;
};
struct Scope
{
explicit Scope(TypePackId returnType); // root scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root
std::unordered_map<Symbol, Binding> bindings;
TypePackId returnType;
bool breakOk = false;
std::optional<TypePackId> varargPack;
TypeLevel level;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
std::unordered_map<Name, TypeFun> privateTypeBindings;
std::unordered_map<Name, Location> typeAliasLocations;
std::unordered_map<Name, std::unordered_map<Name, TypeFun>> importedTypeBindings;
std::optional<TypeId> lookup(const Symbol& name);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
std::unordered_map<Name, TypePackId> privateTypePackBindings;
std::optional<TypePackId> lookupPack(const Name& name);
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true);
RefinementMap refinements;
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasTypeParameters;
std::unordered_map<Name, TypePackId> typeAliasTypePackParameters;
};
} // namespace Luau

View file

@ -52,8 +52,6 @@
// `T`, and the type of `f` are in the same SCC, which is why `f` gets
// replaced.
LUAU_FASTFLAG(DebugLuauTrackOwningArena)
namespace Luau
{
@ -188,20 +186,12 @@ struct Substitution : FindDirty
template<typename T>
TypeId addType(const T& tv)
{
TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv);
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
return currentModule->internalTypes.addType(tv);
}
template<typename T>
TypePackId addTypePack(const T& tp)
{
TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp);
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
return currentModule->internalTypes.addTypePack(TypePackVar{tp});
}
};

View file

@ -69,4 +69,6 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {});
void dump(TypeId ty);
void dump(TypePackId ty);
std::string generateName(size_t n);
} // namespace Luau

View file

@ -12,6 +12,7 @@ struct AstArray;
class AstStat;
bool containsFunctionCall(const AstStat& stat);
bool containsFunctionCallOrReturn(const AstStat& stat);
bool isFunction(const AstStat& stat);
void toposort(std::vector<AstStat*>& stats);

View file

@ -3,19 +3,37 @@
#include "Luau/TypeVar.h"
LUAU_FASTFLAG(LuauShareTxnSeen);
namespace Luau
{
// Log of where what TypeIds we are rebinding and what they used to be
struct TxnLog
{
TxnLog() = default;
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& seen)
: seen(seen)
TxnLog()
: originalSeenSize(0)
, ownedSeen()
, sharedSeen(&ownedSeen)
{
}
explicit TxnLog(std::vector<std::pair<TypeId, TypeId>>* sharedSeen)
: originalSeenSize(sharedSeen->size())
, ownedSeen()
, sharedSeen(sharedSeen)
{
}
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& ownedSeen)
: originalSeenSize(ownedSeen.size())
, ownedSeen(ownedSeen)
, sharedSeen(nullptr)
{
// This is deprecated!
LUAU_ASSERT(!FFlag::LuauShareTxnSeen);
}
TxnLog(const TxnLog&) = delete;
TxnLog& operator=(const TxnLog&) = delete;
@ -38,9 +56,11 @@ private:
std::vector<std::pair<TypeId, TypeVar>> typeVarChanges;
std::vector<std::pair<TypePackId, TypePackVar>> typePackChanges;
std::vector<std::pair<TableTypeVar*, std::optional<TypeId>>> tableChanges;
size_t originalSeenSize;
public:
std::vector<std::pair<TypeId, TypeId>> seen; // used to avoid infinite recursion when types are cyclic
std::vector<std::pair<TypeId, TypeId>> ownedSeen; // used to avoid infinite recursion when types are cyclic
std::vector<std::pair<TypeId, TypeId>>* sharedSeen; // shared with all the descendent logs
};
} // namespace Luau

View file

@ -11,6 +11,7 @@
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
#include "Luau/Unifier.h"
#include "Luau/UnifierSharedState.h"
#include <memory>
#include <unordered_map>
@ -86,7 +87,10 @@ struct ApplyTypeFunction : Substitution
{
TypeLevel level;
bool encounteredForwardedType;
std::unordered_map<TypeId, TypeId> arguments;
std::unordered_map<TypeId, TypeId> typeArguments;
std::unordered_map<TypePackId, TypePackId> typePackArguments;
bool ignoreChildren(TypeId ty) override;
bool ignoreChildren(TypePackId tp) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
@ -118,7 +122,7 @@ struct TypeChecker
void check(const ScopePtr& scope, const AstStatForIn& forin);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false);
void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
@ -328,11 +332,12 @@ private:
TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams, const Location& location);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams, const Location& location);
// Note: `scope` must be a fresh scope.
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
@ -379,6 +384,8 @@ public:
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
InternalErrorReporter* iceHandler;
UnifierSharedState unifierState;
public:
const TypeId nilType;
const TypeId numberType;
@ -398,54 +405,6 @@ private:
int recursionCount = 0;
};
struct Binding
{
TypeId typeId;
Location location;
bool deprecated = false;
std::string deprecatedSuggestion;
std::optional<std::string> documentationSymbol;
};
struct Scope
{
explicit Scope(TypePackId returnType); // root scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root
std::unordered_map<Symbol, Binding> bindings;
TypePackId returnType;
bool breakOk = false;
std::optional<TypePackId> varargPack;
TypeLevel level;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
std::unordered_map<Name, TypeFun> privateTypeBindings;
std::unordered_map<Name, Location> typeAliasLocations;
std::unordered_map<Name, std::unordered_map<Name, TypeFun>> importedTypeBindings;
std::optional<TypeId> lookup(const Symbol& name);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
std::unordered_map<Name, TypePackId> privateTypePackBindings;
std::optional<TypePackId> lookupPack(const Name& name);
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true);
RefinementMap refinements;
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasParameters;
};
// Unit test hook
void setPrintLine(void (*pl)(const std::string& s));
void resetPrintLine();

View file

@ -117,7 +117,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs);
TypePackId follow(TypePackId tp);
size_t size(const TypePackId tp);
size_t size(TypePackId tp);
bool finite(TypePackId tp);
size_t size(const TypePack& tp);
std::optional<TypeId> first(TypePackId tp);

View file

@ -228,6 +228,7 @@ struct TableTypeVar
std::map<Name, Location> methodDefinitionLocations;
std::vector<TypeId> instantiatedTypeParams;
std::vector<TypePackId> instantiatedTypePackParams;
ModuleName definitionModuleName;
std::optional<TypeId> boundTo;
@ -284,8 +285,9 @@ struct ClassTypeVar
struct TypeFun
{
/// These should all be generic
// These should all be generic
std::vector<TypeId> typeParams;
std::vector<TypePackId> typePackParams;
/** The underlying type.
*
@ -293,6 +295,20 @@ struct TypeFun
* You must first use TypeChecker::instantiateTypeFun to turn it into a real type.
*/
TypeId type;
TypeFun() = default;
TypeFun(std::vector<TypeId> typeParams, TypeId type)
: typeParams(std::move(typeParams))
, type(type)
{
}
TypeFun(std::vector<TypeId> typeParams, std::vector<TypePackId> typePackParams, TypeId type)
: typeParams(std::move(typeParams))
, typePackParams(std::move(typePackParams))
, type(type)
{
}
};
// Anything! All static checking is off.
@ -524,8 +540,11 @@ UnionTypeVarIterator end(const UnionTypeVar* utv);
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
// TEMP: Clip this prototype with FFlag::LuauStringMetatable
std::optional<ExprResult<TypePackId>> magicFunctionFormat(
struct TypeChecker& typechecker, const std::shared_ptr<struct Scope>& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
void attachTag(TypeId ty, const std::string& tagName);
void attachTag(Property& prop, const std::string& tagName);
bool hasTag(TypeId ty, const std::string& tagName);
bool hasTag(const Property& prop, const std::string& tagName);
bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work.
} // namespace Luau

View file

@ -6,6 +6,7 @@
#include "Luau/TxnLog.h"
#include "Luau/TypeInfer.h"
#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header.
#include "Luau/UnifierSharedState.h"
#include <unordered_set>
@ -36,12 +37,20 @@ struct Unifier
Variance variance = Covariant;
CountMismatch::Context ctx = CountMismatch::Arg;
std::shared_ptr<UnifierCounters> counters;
InternalErrorReporter* iceHandler;
UnifierCounters* counters;
UnifierCounters countersData;
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters = nullptr);
std::shared_ptr<UnifierCounters> counters_DEPRECATED;
UnifierSharedState& sharedState;
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& ownedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
UnifierCounters* counters = nullptr);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector<std::pair<TypeId, TypeId>>* sharedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
UnifierCounters* counters = nullptr);
// Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId superTy, TypeId subTy);
@ -58,11 +67,14 @@ private:
void tryUnifyPrimitives(TypeId superTy, TypeId subTy);
void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);
void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);
void tryUnifyFreeTable(TypeId free, TypeId other);
void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection);
void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed);
void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed);
void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer);
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen = {});
void cacheResult(TypeId superTy, TypeId subTy);
public:
void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false);
@ -80,9 +92,9 @@ private:
public:
// Report an "infinite type error" if the type "needle" already occurs within "haystack"
void occursCheck(TypeId needle, TypeId haystack);
void occursCheck(std::unordered_set<TypeId>& seen, TypeId needle, TypeId haystack);
void occursCheck(std::unordered_set<TypeId>& seen_DEPRECATED, DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
void occursCheck(TypePackId needle, TypePackId haystack);
void occursCheck(std::unordered_set<TypePackId>& seen, TypePackId needle, TypePackId haystack);
void occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
Unifier makeChildUnifier();
@ -93,6 +105,10 @@ private:
[[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message);
// Remove with FFlagLuauCacheUnifyTableResults
DenseHashSet<TypeId> tempSeenTy_DEPRECATED{nullptr};
DenseHashSet<TypePackId> tempSeenTp_DEPRECATED{nullptr};
};
} // namespace Luau

View file

@ -0,0 +1,44 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
#include <utility>
namespace Luau
{
struct InternalErrorReporter;
struct TypeIdPairHash
{
size_t hashOne(Luau::TypeId key) const
{
return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9);
}
size_t operator()(const std::pair<Luau::TypeId, Luau::TypeId>& x) const
{
return hashOne(x.first) ^ (hashOne(x.second) << 1);
}
};
struct UnifierSharedState
{
UnifierSharedState(InternalErrorReporter* iceHandler)
: iceHandler(iceHandler)
{
}
InternalErrorReporter* iceHandler;
DenseHashSet<void*> seenAny{nullptr};
DenseHashMap<TypeId, bool> skipCacheForType{nullptr};
DenseHashSet<std::pair<TypeId, TypeId>, TypeIdPairHash> cachedUnify{{nullptr, nullptr}};
DenseHashSet<TypeId> tempSeenTy{nullptr};
DenseHashSet<TypePackId> tempSeenTp{nullptr};
};
} // namespace Luau

View file

@ -1,9 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
LUAU_FASTFLAG(LuauCacheUnifyTableResults)
namespace Luau
{
@ -32,17 +35,33 @@ inline bool hasSeen(std::unordered_set<void*>& seen, const void* tv)
return !seen.insert(ttv).second;
}
inline bool hasSeen(DenseHashSet<void*>& seen, const void* tv)
{
void* ttv = const_cast<void*>(tv);
if (seen.contains(ttv))
return true;
seen.insert(ttv);
return false;
}
inline void unsee(std::unordered_set<void*>& seen, const void* tv)
{
void* ttv = const_cast<void*>(tv);
seen.erase(ttv);
}
template<typename F>
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen);
inline void unsee(DenseHashSet<void*>& seen, const void* tv)
{
// When DenseHashSet is used for 'visitOnce', where don't forget visited elements
}
template<typename F>
void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen);
template<typename F, typename Set>
void visit(TypeId ty, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, ty))
{
@ -79,15 +98,23 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
else if (auto ttv = get<TableTypeVar>(ty))
{
// Some visitors want to see bound tables, that's why we visit the original type
if (apply(ty, *ttv, seen, f))
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
visit(*ttv->boundTo, f, seen);
}
else
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
}
}
}
}
@ -140,8 +167,8 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
visit_detail::unsee(seen, ty);
}
template<typename F>
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, tp))
{
@ -182,6 +209,7 @@ void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
visit_detail::unsee(seen, tp);
}
} // namespace visit_detail
template<typename TID, typename F>
@ -197,4 +225,11 @@ void visitTypeVar(TID ty, F& f)
visit_detail::visit(ty, f, seen);
}
template<typename TID, typename F>
void visitTypeVarOnce(TID ty, F& f, DenseHashSet<void*>& seen)
{
seen.clear();
visit_detail::visit(ty, f, seen);
}
} // namespace Luau

View file

@ -2,6 +2,7 @@
#include "Luau/AstQuery.h"
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"
#include "Luau/ToString.h"
@ -143,8 +144,8 @@ std::optional<TypeId> findTypeAtPosition(const Module& module, const SourceModul
{
if (auto expr = findExprAtPosition(sourceModule, pos))
{
if (auto it = module.astTypes.find(expr); it != module.astTypes.end())
return it->second;
if (auto it = module.astTypes.find(expr))
return *it;
}
return std::nullopt;
@ -154,8 +155,8 @@ std::optional<TypeId> findExpectedTypeAtPosition(const Module& module, const Sou
{
if (auto expr = findExprAtPosition(sourceModule, pos))
{
if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end())
return it->second;
if (auto it = module.astExpectedTypes.find(expr))
return *it;
}
return std::nullopt;
@ -322,9 +323,9 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end())
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = it->second;
matchingOverload = *it;
}
}
@ -345,9 +346,9 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
{
if (AstExprIndexName* indexName = targetExpr->as<AstExprIndexName>())
{
if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end())
if (auto it = module.astTypes.find(indexName->expr))
{
TypeId parentTy = follow(it->second);
TypeId parentTy = follow(*it);
if (const TableTypeVar* ttv = get<TableTypeVar>(parentTy))
{
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())

View file

@ -196,7 +196,8 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) {
InternalErrorReporter iceReporter;
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter);
UnifierSharedState unifierState(&iceReporter);
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState);
unifier.tryUnify(expectedType, actualType);
@ -210,10 +211,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
return TypeCorrectKind::None;
auto it = module.astExpectedTypes.find(expr);
if (it == module.astExpectedTypes.end())
if (!it)
return TypeCorrectKind::None;
TypeId expectedType = follow(it->second);
TypeId expectedType = follow(*it);
if (canUnify(expectedType, ty))
return TypeCorrectKind::Correct;
@ -682,10 +683,10 @@ static std::optional<bool> functionIsExpectedAt(const Module& module, AstNode* n
return std::nullopt;
auto it = module.astExpectedTypes.find(expr);
if (it == module.astExpectedTypes.end())
if (!it)
return std::nullopt;
TypeId expectedType = follow(it->second);
TypeId expectedType = follow(*it);
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(expectedType))
return true;
@ -784,9 +785,9 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
if (AstExprCall* exprCall = expr->as<AstExprCall>())
{
if (auto it = module.astTypes.find(exprCall->func); it != module.astTypes.end())
if (auto it = module.astTypes.find(exprCall->func))
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(it->second)))
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*it)))
{
if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos))
inferredType = *ty;
@ -798,8 +799,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
if (tailPos != 0)
break;
if (auto it = module.astTypes.find(expr); it != module.astTypes.end())
inferredType = it->second;
if (auto it = module.astTypes.find(expr))
inferredType = *it;
}
if (inferredType)
@ -815,10 +816,10 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* {
auto it = module.astExpectedTypes.find(expr);
if (it == module.astExpectedTypes.end())
if (!it)
return nullptr;
TypeId ty = follow(it->second);
TypeId ty = follow(*it);
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
return ftv;
@ -1129,9 +1130,8 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul
if (node->is<AstExprIndexName>())
{
auto it = module.astTypes.find(node->asExpr());
if (it != module.astTypes.end())
autocompleteProps(module, typeArena, it->second, PropIndexType::Point, ancestry, result);
if (auto it = module.astTypes.find(node->asExpr()))
autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result);
}
else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result))
return;
@ -1203,13 +1203,13 @@ static std::optional<const ClassTypeVar*> getMethodContainingClass(const ModuleP
return std::nullopt;
}
auto parentIter = module->astTypes.find(parentExpr);
if (parentIter == module->astTypes.end())
auto parentIt = module->astTypes.find(parentExpr);
if (!parentIt)
{
return std::nullopt;
}
Luau::TypeId parentType = Luau::follow(parentIter->second);
Luau::TypeId parentType = Luau::follow(*parentIt);
if (auto parentClass = Luau::get<ClassTypeVar>(parentType))
{
@ -1250,8 +1250,8 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
return std::nullopt;
}
auto iter = module->astTypes.find(candidate->func);
if (iter == module->astTypes.end())
auto it = module->astTypes.find(candidate->func);
if (!it)
{
return std::nullopt;
}
@ -1267,7 +1267,7 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
return std::nullopt;
};
auto followedId = Luau::follow(iter->second);
auto followedId = Luau::follow(*it);
if (auto functionType = Luau::get<FunctionTypeVar>(followedId))
{
return performCallback(functionType);
@ -1316,10 +1316,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
if (auto indexName = node->as<AstExprIndexName>())
{
auto it = module->astTypes.find(indexName->expr);
if (it == module->astTypes.end())
if (!it)
return {};
TypeId ty = follow(it->second);
TypeId ty = follow(*it);
PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point;
if (isString(ty))
@ -1447,9 +1447,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
// If item doesn't have a key, maybe the value is actually the key
if (key ? key == node : node->is<AstExprGlobal>() && value == node)
{
if (auto it = module->astExpectedTypes.find(exprTable); it != module->astExpectedTypes.end())
if (auto it = module->astExpectedTypes.find(exprTable))
{
auto result = autocompleteProps(*module, typeArena, it->second, PropIndexType::Key, finder.ancestry);
auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry);
// Remove keys that are already completed
for (const auto& item : exprTable->items)
@ -1485,9 +1485,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
{
if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as<AstExprIndexExpr>())
{
if (auto it = module->astTypes.find(idxExpr->expr); it != module->astTypes.end())
if (auto it = module->astTypes.find(idxExpr->expr))
{
return {autocompleteProps(*module, typeArena, follow(it->second), PropIndexType::Point, finder.ancestry), finder.ancestry};
return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry};
}
}
}

View file

@ -11,7 +11,7 @@
LUAU_FASTFLAG(LuauParseGenericFunctions)
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAG(LuauStringMetatable)
LUAU_FASTFLAG(LuauNewRequireTrace)
/** FIXME: Many of these type definitions are not quite completely accurate.
*
@ -106,18 +106,6 @@ void attachMagicFunction(TypeId ty, MagicFunction fn)
LUAU_ASSERT(!"Got a non functional type");
}
void attachFunctionTag(TypeId ty, std::string tag)
{
if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->tags.emplace_back(std::move(tag));
}
else
{
LUAU_ASSERT(!"Got a non functional type");
}
}
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol)
{
return {
@ -218,7 +206,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
TypePackId anyTypePack = typeChecker.anyTypePack;
TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}});
TypePackId stringVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{stringType}});
TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList});
TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{
@ -255,85 +242,18 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
TypeId genericV = arena.addType(GenericTypeVar{"V"});
TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level});
if (FFlag::LuauStringMetatable)
std::optional<TypeId> stringMetatableTy = getMetatable(singletonTypes.stringType);
LUAU_ASSERT(stringMetatableTy);
const TableTypeVar* stringMetatableTable = get<TableTypeVar>(follow(*stringMetatableTy));
LUAU_ASSERT(stringMetatableTable);
auto it = stringMetatableTable->props.find("__index");
LUAU_ASSERT(it != stringMetatableTable->props.end());
addGlobalBinding(typeChecker, "string", it->second.type, "@luau");
if (!FFlag::LuauParseGenericFunctions || !FFlag::LuauGenericFunctions)
{
std::optional<TypeId> stringMetatableTy = getMetatable(singletonTypes.stringType);
LUAU_ASSERT(stringMetatableTy);
const TableTypeVar* stringMetatableTable = get<TableTypeVar>(follow(*stringMetatableTy));
LUAU_ASSERT(stringMetatableTable);
auto it = stringMetatableTable->props.find("__index");
LUAU_ASSERT(it != stringMetatableTable->props.end());
TypeId stringLib = it->second.type;
addGlobalBinding(typeChecker, "string", stringLib, "@luau");
}
if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions)
{
if (!FFlag::LuauStringMetatable)
{
TypeId stringLibTy = getGlobalBinding(typeChecker, "string");
TableTypeVar* stringLib = getMutable<TableTypeVar>(stringLibTy);
TypeId replArgType = makeUnion(
arena, {stringType,
arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)),
makeFunction(arena, std::nullopt, {stringType}, {stringType})});
TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType});
stringLib->props["gsub"] = makeProperty(gsubFunc, "@luau/global/string.gsub");
}
}
else
{
if (!FFlag::LuauStringMetatable)
{
TypeId stringToStringType = makeFunction(arena, std::nullopt, {stringType}, {stringType});
TypeId gmatchFunc = makeFunction(arena, stringType, {stringType}, {arena.addType(FunctionTypeVar{emptyPack, stringVariadicList})});
TypeId replArgType = makeUnion(
arena, {stringType,
arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)),
makeFunction(arena, std::nullopt, {stringType}, {stringType})});
TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType});
TypeId formatFn = arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack});
TableTypeVar::Props stringLib = {
// FIXME string.byte "can" return a pack of numbers, but only if 2nd or 3rd arguments were supplied
{"byte", {makeFunction(arena, stringType, {optionalNumber, optionalNumber}, {optionalNumber})}},
// FIXME char takes a variadic pack of numbers
{"char", {makeFunction(arena, std::nullopt, {numberType, optionalNumber, optionalNumber, optionalNumber}, {stringType})}},
{"find", {makeFunction(arena, stringType, {stringType, optionalNumber, optionalBoolean}, {optionalNumber, optionalNumber})}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(arena, stringType, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {makeFunction(arena, stringType, {stringType, optionalNumber}, {optionalString})}},
{"rep", {makeFunction(arena, stringType, {numberType}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(arena, stringType, {numberType, optionalNumber}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(arena, stringType, {stringType, optionalString},
{arena.addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, typeChecker.globalScope->level})})}},
{"pack", {arena.addType(FunctionTypeVar{
arena.addTypePack(TypePack{{stringType}, anyTypePack}),
oneStringPack,
})}},
{"packsize", {makeFunction(arena, stringType, {}, {numberType})}},
{"unpack", {arena.addType(FunctionTypeVar{
arena.addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
anyTypePack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
addGlobalBinding(typeChecker, "string",
arena.addType(TableTypeVar{stringLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau");
}
TableTypeVar::Props debugLib{
{"info", {makeIntersection(arena,
{
@ -601,9 +521,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
auto tableLib = getMutable<TableTypeVar>(getGlobalBinding(typeChecker, "table"));
attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack);
auto stringLib = getMutable<TableTypeVar>(getGlobalBinding(typeChecker, "string"));
attachMagicFunction(stringLib->props["format"].type, magicFunctionFormat);
attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire);
}
@ -791,11 +708,11 @@ static std::optional<ExprResult<TypePackId>> magicFunctionRequire(
return std::nullopt;
}
AstExpr* require = expr.args.data[0];
if (!checkRequirePath(typechecker, require))
if (!checkRequirePath(typechecker, expr.args.data[0]))
return std::nullopt;
const AstExpr* require = FFlag::LuauNewRequireTrace ? &expr : expr.args.data[0];
if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require))
return ExprResult<TypePackId>{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})};

View file

@ -206,27 +206,6 @@ std::string getBuiltinDefinitionSource()
graphemes: (string, number?, number?) -> (() -> (number, number)),
}
declare string: {
byte: (string, number?, number?) -> ...number,
char: (number, ...number) -> string,
find: (string, string, number?, boolean?) -> (number?, number?),
-- `string.format` has a magic function attached that will provide more type information for literal format strings.
format: <A...>(string, A...) -> string,
gmatch: (string, string) -> () -> (...string),
-- gsub is defined in C++ because we don't have syntax for describing a generic table.
len: (string) -> number,
lower: (string) -> string,
match: (string, string, number?) -> string?,
rep: (string, number) -> string,
reverse: (string) -> string,
sub: (string, number, number?) -> string,
upper: (string) -> string,
split: (string, string, string?) -> {string},
pack: <A...>(string, A...) -> string,
packsize: (string) -> number,
unpack: <R...>(string, string, number?) -> R...,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)";

View file

@ -7,9 +7,9 @@
#include <stdexcept>
LUAU_FASTFLAG(LuauFasterStringifier)
LUAU_FASTFLAG(LuauTypeAliasPacks)
static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false)
static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false)
{
std::string s = "expects " + std::to_string(expectedCount) + " ";
@ -41,6 +41,52 @@ static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCo
return s;
}
static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
{
std::string s;
if (FFlag::LuauTypeAliasPacks)
{
s = "expects ";
if (isVariadic)
s += "at least ";
s += std::to_string(expectedCount) + " ";
}
else
{
s = "expects " + std::to_string(expectedCount) + " ";
}
if (argPrefix)
s += std::string(argPrefix) + " ";
s += "argument";
if (expectedCount != 1)
s += "s";
s += ", but ";
if (actualCount == 0)
{
s += "none";
}
else
{
if (actualCount < expectedCount)
s += "only ";
s += std::to_string(actualCount);
}
s += (actualCount == 1) ? " is" : " are";
s += " specified";
return s;
}
namespace Luau
{
@ -127,7 +173,10 @@ struct ErrorConverter
return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " +
std::to_string(e.actual) + " are required here";
case CountMismatch::Arg:
return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual);
if (FFlag::LuauTypeAliasPacks)
return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual);
else
return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual);
}
LUAU_ASSERT(!"Unknown context");
@ -159,13 +208,16 @@ struct ErrorConverter
std::string operator()(const Luau::UnknownRequire& e) const
{
return "Unknown require: " + e.modulePath;
if (e.modulePath.empty())
return "Unknown require: unsupported path";
else
return "Unknown require: " + e.modulePath;
}
std::string operator()(const Luau::IncorrectGenericParameterCount& e) const
{
std::string name = e.name;
if (!e.typeFun.typeParams.empty())
if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty()))
{
name += "<";
bool first = true;
@ -178,10 +230,37 @@ struct ErrorConverter
name += toString(t);
}
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId t : e.typeFun.typePackParams)
{
if (first)
first = false;
else
name += ", ";
name += toString(t);
}
}
name += ">";
}
return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true);
if (FFlag::LuauTypeAliasPacks)
{
if (e.typeFun.typeParams.size() != e.actualParameters)
return "Generic type '" + name + "' " +
wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty());
return "Generic type '" + name + "' " +
wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false);
}
else
{
return "Generic type '" + name + "' " +
wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true);
}
}
std::string operator()(const Luau::SyntaxError& e) const
@ -470,9 +549,26 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC
if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size())
return false;
if (FFlag::LuauTypeAliasPacks)
{
if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size())
return false;
}
for (size_t i = 0; i < typeFun.typeParams.size(); ++i)
{
if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i])
return false;
}
if (FFlag::LuauTypeAliasPacks)
{
for (size_t i = 0; i < typeFun.typePackParams.size(); ++i)
{
if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i])
return false;
}
}
return true;
}

View file

@ -1,9 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Frontend.h"
#include "Luau/Common.h"
#include "Luau/Config.h"
#include "Luau/FileResolver.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/TypeInfer.h"
#include "Luau/Variant.h"
#include "Luau/Common.h"
@ -19,6 +22,8 @@ LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false)
LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false)
LUAU_FASTFLAG(LuauTraceRequireLookupChild)
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
LUAU_FASTFLAG(LuauNewRequireTrace)
LUAU_FASTFLAGVARIABLE(LuauClearScopes, false)
namespace Luau
{
@ -69,6 +74,8 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName)
LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName)
{
LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend");
Luau::Allocator allocator;
Luau::AstNameTable names(allocator);
@ -242,7 +249,7 @@ struct RequireCycle
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true)
std::vector<RequireCycle> getRequireCycles(
const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
const FileResolver* resolver, const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
{
std::vector<RequireCycle> result;
@ -276,9 +283,9 @@ std::vector<RequireCycle> getRequireCycles(
if (top == start)
{
for (const SourceNode* node : path)
cycle.push_back(node->name);
cycle.push_back(resolver->getHumanReadableModuleName(node->name));
cycle.push_back(top->name);
cycle.push_back(resolver->getHumanReadableModuleName(top->name));
break;
}
}
@ -350,6 +357,9 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend)
CheckResult Frontend::check(const ModuleName& name)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
CheckResult checkResult;
auto it = sourceNodes.find(name);
@ -395,7 +405,7 @@ CheckResult Frontend::check(const ModuleName& name)
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely
if (cycleDetected)
requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck);
requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck);
// This is used by the type checker to replace the resulting type of cyclic modules with any
sourceModule.cyclic = !requireCycles.empty();
@ -449,6 +459,8 @@ CheckResult Frontend::check(const ModuleName& name)
module->astTypes.clear();
module->astExpectedTypes.clear();
module->astOriginalCallTypes.clear();
if (FFlag::LuauClearScopes)
module->scopes.resize(1);
}
if (mode != Mode::NoCheck)
@ -479,6 +491,9 @@ CheckResult Frontend::check(const ModuleName& name)
bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root)
{
LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend");
LUAU_TIMETRACE_ARGUMENT("root", root.c_str());
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
enum Mark
{
@ -597,6 +612,9 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config
LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
CheckResult checkResult;
auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name);
@ -608,6 +626,8 @@ LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOption
std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view source, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend");
const Config& config = configResolver->getConfig("");
SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions);
@ -627,6 +647,9 @@ std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view sour
CheckResult Frontend::check(const SourceModule& module)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
const Config& config = configResolver->getConfig(module.name);
Mode mode = module.mode.value_or(config.mode);
@ -648,6 +671,9 @@ CheckResult Frontend::check(const SourceModule& module)
LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
const Config& config = configResolver->getConfig(module.name);
LintOptions options = enabledLintWarnings.value_or(config.enabledLint);
@ -746,6 +772,9 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons
// Read AST into sourceModules if necessary. Trace require()s. Report parse errors.
std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name)
{
LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
auto it = sourceNodes.find(name);
if (it != sourceNodes.end() && !it->second.dirty)
{
@ -815,6 +844,9 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& check
*/
SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions)
{
LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
SourceModule sourceModule;
double timestamp = getTimestamp();
@ -864,20 +896,11 @@ std::optional<ModuleInfo> FrontendModuleResolver::resolveModuleInfo(const Module
const auto& exprs = it->second.exprs;
const ModuleName* relativeName = exprs.find(&pathExpr);
if (!relativeName || relativeName->empty())
const ModuleInfo* info = exprs.find(&pathExpr);
if (!info || (!FFlag::LuauNewRequireTrace && info->name.empty()))
return std::nullopt;
if (FFlag::LuauTraceRequireLookupChild)
{
const bool* optional = it->second.optional.find(&pathExpr);
return {{*relativeName, optional ? *optional : false}};
}
else
{
return {{*relativeName, false}};
}
return *info;
}
const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const
@ -891,12 +914,15 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName)
bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const
{
return frontend->fileResolver->moduleExists(moduleName);
if (FFlag::LuauNewRequireTrace)
return frontend->sourceNodes.count(moduleName) != 0;
else
return frontend->fileResolver->moduleExists(moduleName);
}
std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const
{
return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName);
return frontend->fileResolver->getHumanReadableModuleName(moduleName);
}
ScopePtr Frontend::addEnvironment(const std::string& environmentName)

View file

@ -2,6 +2,8 @@
#include "Luau/IostreamHelpers.h"
#include "Luau/ToString.h"
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -92,7 +94,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo
{
stream << "IncorrectGenericParameterCount { name = " << error.name;
if (!error.typeFun.typeParams.empty())
if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty()))
{
stream << "<";
bool first = true;
@ -105,6 +107,20 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo
stream << toString(t);
}
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId t : error.typeFun.typePackParams)
{
if (first)
first = false;
else
stream << ", ";
stream << toString(t);
}
}
stream << ">";
}

View file

@ -3,6 +3,9 @@
#include "Luau/Ast.h"
#include "Luau/StringUtils.h"
#include "Luau/Common.h"
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -612,6 +615,12 @@ struct AstJsonEncoder : public AstVisitor
writeNode(node, "AstStatTypeAlias", [&]() {
PROP(name);
PROP(generics);
if (FFlag::LuauTypeAliasPacks)
{
PROP(genericPacks);
}
PROP(type);
PROP(exported);
});
@ -664,13 +673,21 @@ struct AstJsonEncoder : public AstVisitor
});
}
void write(struct AstTypeOrPack node)
{
if (node.type)
write(node.type);
else
write(node.typePack);
}
void write(class AstTypeReference* node)
{
writeNode(node, "AstTypeReference", [&]() {
if (node->hasPrefix)
PROP(prefix);
PROP(name);
PROP(generics);
PROP(parameters);
});
}
@ -734,6 +751,13 @@ struct AstJsonEncoder : public AstVisitor
});
}
void write(class AstTypePackExplicit* node)
{
writeNode(node, "AstTypePackExplicit", [&]() {
PROP(typeList);
});
}
void write(class AstTypePackVariadic* node)
{
writeNode(node, "AstTypePackVariadic", [&]() {
@ -1018,6 +1042,12 @@ struct AstJsonEncoder : public AstVisitor
return false;
}
bool visit(class AstTypePackExplicit* node) override
{
write(node);
return false;
}
bool visit(class AstTypePackVariadic* node) override
{
write(node);

View file

@ -3,6 +3,7 @@
#include "Luau/AstQuery.h"
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/StringUtils.h"
#include "Luau/Common.h"
@ -12,6 +13,7 @@
#include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false)
LUAU_FASTFLAGVARIABLE(LuauLinterTableMoveZero, false)
namespace Luau
{
@ -85,10 +87,10 @@ struct LintContext
return std::nullopt;
auto it = module->astTypes.find(expr);
if (it == module->astTypes.end())
if (!it)
return std::nullopt;
return it->second;
return *it;
}
};
@ -2144,6 +2146,19 @@ private:
"wrap it in parentheses to silence");
}
if (FFlag::LuauLinterTableMoveZero && func->index == "move" && node->args.size >= 4)
{
// table.move(t, 0, _, _)
if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
// table.move(t, _, _, 0)
else if (isConstant(args[3], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
}
return true;
}

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -13,6 +14,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false)
LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false)
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false)
namespace Luau
{
@ -188,7 +191,7 @@ struct TypePackCloner
template<typename T>
void defaultClone(const T& t)
{
TypePackId cloned = dest.typePacks.allocate(t);
TypePackId cloned = dest.addTypePack(TypePackVar{t});
seenTypePacks[typePackId] = cloned;
}
@ -197,7 +200,7 @@ struct TypePackCloner
if (encounteredFreeType)
*encounteredFreeType = true;
seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}});
seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}});
}
void operator()(const Unifiable::Generic& t)
@ -219,13 +222,13 @@ struct TypePackCloner
void operator()(const VariadicTypePack& t)
{
TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)});
TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}});
seenTypePacks[typePackId] = cloned;
}
void operator()(const TypePack& t)
{
TypePackId cloned = dest.typePacks.allocate(TypePack{});
TypePackId cloned = dest.addTypePack(TypePack{});
TypePack* destTp = getMutable<TypePack>(cloned);
LUAU_ASSERT(destTp != nullptr);
seenTypePacks[typePackId] = cloned;
@ -241,7 +244,7 @@ struct TypePackCloner
template<typename T>
void TypeCloner::defaultClone(const T& t)
{
TypeId cloned = dest.typeVars.allocate(t);
TypeId cloned = dest.addType(t);
seenTypes[typeId] = cloned;
}
@ -250,7 +253,7 @@ void TypeCloner::operator()(const Unifiable::Free& t)
if (encounteredFreeType)
*encounteredFreeType = true;
seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{});
seenTypes[typeId] = dest.addType(ErrorTypeVar{});
}
void TypeCloner::operator()(const Unifiable::Generic& t)
@ -275,7 +278,7 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t)
void TypeCloner::operator()(const FunctionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result);
LUAU_ASSERT(ftv != nullptr);
@ -297,7 +300,15 @@ void TypeCloner::operator()(const FunctionTypeVar& t)
void TypeCloner::operator()(const TableTypeVar& t)
{
TypeId result = dest.typeVars.allocate(TableTypeVar{});
// If table is now bound to another one, we ignore the content of the original
if (FFlag::LuauCloneBoundTables && t.boundTo)
{
TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
seenTypes[typeId] = boundTo;
return;
}
TypeId result = dest.addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
LUAU_ASSERT(ttv != nullptr);
@ -319,15 +330,24 @@ void TypeCloner::operator()(const TableTypeVar& t)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType),
clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)};
if (t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
if (!FFlag::LuauCloneBoundTables)
{
if (t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType));
arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId& arg : ttv->instantiatedTypePackParams)
arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
if (ttv->state == TableState::Free)
{
if (!t.boundTo)
if (FFlag::LuauCloneBoundTables || !t.boundTo)
{
if (encounteredFreeType)
*encounteredFreeType = true;
@ -343,7 +363,7 @@ void TypeCloner::operator()(const TableTypeVar& t)
void TypeCloner::operator()(const MetatableTypeVar& t)
{
TypeId result = dest.typeVars.allocate(MetatableTypeVar{});
TypeId result = dest.addType(MetatableTypeVar{});
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(result);
seenTypes[typeId] = result;
@ -353,7 +373,7 @@ void TypeCloner::operator()(const MetatableTypeVar& t)
void TypeCloner::operator()(const ClassTypeVar& t)
{
TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData});
TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData});
ClassTypeVar* ctv = getMutable<ClassTypeVar>(result);
seenTypes[typeId] = result;
@ -378,7 +398,7 @@ void TypeCloner::operator()(const AnyTypeVar& t)
void TypeCloner::operator()(const UnionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(UnionTypeVar{});
TypeId result = dest.addType(UnionTypeVar{});
seenTypes[typeId] = result;
UnionTypeVar* option = getMutable<UnionTypeVar>(result);
@ -390,7 +410,7 @@ void TypeCloner::operator()(const UnionTypeVar& t)
void TypeCloner::operator()(const IntersectionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(IntersectionTypeVar{});
TypeId result = dest.addType(IntersectionTypeVar{});
seenTypes[typeId] = result;
IntersectionTypeVar* option = getMutable<IntersectionTypeVar>(result);
@ -451,8 +471,14 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType)
{
TypeFun result;
for (TypeId param : typeFun.typeParams)
result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType));
for (TypeId ty : typeFun.typeParams)
result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType));
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId tp : typeFun.typePackParams)
result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType));
}
result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType);

90
Analysis/src/Quantify.cpp Normal file
View file

@ -0,0 +1,90 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Quantify.h"
#include "Luau/VisitTypeVar.h"
namespace Luau
{
struct Quantifier
{
ModulePtr module;
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
Quantifier(ModulePtr module, TypeLevel level)
: module(module)
, level(level)
{
}
void cycle(TypeId) {}
void cycle(TypePackId) {}
bool operator()(TypeId ty, const FreeTypeVar& ftv)
{
if (!level.subsumes(ftv.level))
return false;
*asMutable(ty) = GenericTypeVar{level};
generics.push_back(ty);
return false;
}
template<typename T>
bool operator()(TypeId ty, const T& t)
{
return true;
}
template<typename T>
bool operator()(TypePackId, const T&)
{
return true;
}
bool operator()(TypeId ty, const TableTypeVar&)
{
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic)
return false;
if (!level.subsumes(ttv.level))
return false;
if (ttv.state == TableState::Free)
ttv.state = TableState::Generic;
else if (ttv.state == TableState::Unsealed)
ttv.state = TableState::Sealed;
ttv.level = level;
return true;
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
if (!level.subsumes(ftp.level))
return false;
*asMutable(tp) = GenericTypePack{level};
genericPacks.push_back(tp);
return true;
}
};
void quantify(ModulePtr module, TypeId ty, TypeLevel level)
{
Quantifier q{std::move(module), level};
visitTypeVar(ty, q);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
}
} // namespace Luau

View file

@ -5,6 +5,7 @@
#include "Luau/Module.h"
LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false)
LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace, false)
namespace Luau
{
@ -12,17 +13,18 @@ namespace Luau
namespace
{
struct RequireTracer : AstVisitor
struct RequireTracerOld : AstVisitor
{
explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName)
explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName)
: fileResolver(fileResolver)
, currentModuleName(std::move(currentModuleName))
, currentModuleName(currentModuleName)
{
LUAU_ASSERT(!FFlag::LuauNewRequireTrace);
}
FileResolver* const fileResolver;
ModuleName currentModuleName;
DenseHashMap<AstLocal*, ModuleName> locals{0};
DenseHashMap<AstLocal*, ModuleName> locals{nullptr};
RequireTraceResult result;
std::optional<ModuleName> fromAstFragment(AstExpr* expr)
@ -50,9 +52,9 @@ struct RequireTracer : AstVisitor
AstExpr* expr = stat->values.data[i];
expr->visit(this);
const ModuleName* name = result.exprs.find(expr);
if (name)
locals[local] = *name;
const ModuleInfo* info = result.exprs.find(expr);
if (info)
locals[local] = info->name;
}
}
@ -63,7 +65,7 @@ struct RequireTracer : AstVisitor
{
std::optional<ModuleName> name = fromAstFragment(global);
if (name)
result.exprs[global] = *name;
result.exprs[global] = {*name};
return false;
}
@ -72,7 +74,7 @@ struct RequireTracer : AstVisitor
{
const ModuleName* name = locals.find(local->local);
if (name)
result.exprs[local] = *name;
result.exprs[local] = {*name};
return false;
}
@ -81,16 +83,16 @@ struct RequireTracer : AstVisitor
{
indexName->expr->visit(this);
const ModuleName* name = result.exprs.find(indexName->expr);
if (name)
const ModuleInfo* info = result.exprs.find(indexName->expr);
if (info)
{
if (indexName->index == "parent" || indexName->index == "Parent")
{
if (auto parent = fileResolver->getParentModuleName(*name))
result.exprs[indexName] = *parent;
if (auto parent = fileResolver->getParentModuleName(info->name))
result.exprs[indexName] = {*parent};
}
else
result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value);
result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)};
}
return false;
@ -100,11 +102,11 @@ struct RequireTracer : AstVisitor
{
indexExpr->expr->visit(this);
const ModuleName* name = result.exprs.find(indexExpr->expr);
const ModuleInfo* info = result.exprs.find(indexExpr->expr);
const AstExprConstantString* str = indexExpr->index->as<AstExprConstantString>();
if (name && str)
if (info && str)
{
result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size));
result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))};
}
indexExpr->index->visit(this);
@ -129,8 +131,8 @@ struct RequireTracer : AstVisitor
AstExprGlobal* globalName = call->func->as<AstExprGlobal>();
if (globalName && globalName->name == "require" && call->args.size >= 1)
{
if (const ModuleName* moduleName = result.exprs.find(call->args.data[0]))
result.requires.push_back({*moduleName, call->location});
if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0]))
result.requires.push_back({moduleInfo->name, call->location});
return false;
}
@ -143,8 +145,8 @@ struct RequireTracer : AstVisitor
if (FFlag::LuauTraceRequireLookupChild && !rootName)
{
if (const ModuleName* moduleName = result.exprs.find(indexName->expr))
rootName = *moduleName;
if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr))
rootName = moduleInfo->name;
}
if (!rootName)
@ -167,24 +169,183 @@ struct RequireTracer : AstVisitor
if (v.end() != std::find(v.begin(), v.end(), '/'))
return false;
result.exprs[call] = fileResolver->concat(*rootName, v);
result.exprs[call] = {fileResolver->concat(*rootName, v)};
// 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime
// If we fail to find such module, we will not report an UnknownRequire error
if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild")
result.optional[call] = true;
result.exprs[call].optional = true;
return false;
}
};
struct RequireTracer : AstVisitor
{
RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName)
: result(result)
, fileResolver(fileResolver)
, currentModuleName(currentModuleName)
, locals(nullptr)
{
LUAU_ASSERT(FFlag::LuauNewRequireTrace);
}
bool visit(AstExprTypeAssertion* expr) override
{
// suppress `require() :: any`
return false;
}
bool visit(AstExprCall* expr) override
{
AstExprGlobal* global = expr->func->as<AstExprGlobal>();
if (global && global->name == "require" && expr->args.size >= 1)
requires.push_back(expr);
return true;
}
bool visit(AstStatLocal* stat) override
{
for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i)
{
AstLocal* local = stat->vars.data[i];
AstExpr* expr = stat->values.data[i];
// track initializing expression to be able to trace modules through locals
locals[local] = expr;
}
return true;
}
bool visit(AstStatAssign* stat) override
{
for (size_t i = 0; i < stat->vars.size; ++i)
{
// locals that are assigned don't have a known expression
if (AstExprLocal* expr = stat->vars.data[i]->as<AstExprLocal>())
locals[expr->local] = nullptr;
}
return true;
}
bool visit(AstType* node) override
{
// allow resolving require inside `typeof` annotations
return true;
}
AstExpr* getDependent(AstExpr* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local];
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return expr->expr;
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return expr->expr;
else if (AstExprCall* expr = node->as<AstExprCall>(); expr && expr->self)
return expr->func->as<AstExprIndexName>()->expr;
else
return nullptr;
}
void process()
{
ModuleInfo moduleContext{currentModuleName};
// seed worklist with require arguments
work.reserve(requires.size());
for (AstExprCall* require : requires)
work.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work.size(); ++i)
if (AstExpr* dep = getDependent(work[i]))
work.push_back(dep);
// resolve all expressions to a module info
for (size_t i = work.size(); i > 0; --i)
{
AstExpr* expr = work[i - 1];
// when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr))
continue;
std::optional<ModuleInfo> info;
if (AstExpr* dep = getDependent(expr))
{
const ModuleInfo* context = result.exprs.find(dep);
// locals just inherit their dependent context, no resolution required
if (expr->is<AstExprLocal>())
info = context ? std::optional<ModuleInfo>(*context) : std::nullopt;
else
info = fileResolver->resolveModule(context, expr);
}
else
{
info = fileResolver->resolveModule(&moduleContext, expr);
}
if (info)
result.exprs[expr] = std::move(*info);
}
// resolve all requires according to their argument
result.requires.reserve(requires.size());
for (AstExprCall* require : requires)
{
AstExpr* arg = require->args.data[0];
if (const ModuleInfo* info = result.exprs.find(arg))
{
result.requires.push_back({info->name, require->location});
ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info!
result.exprs[require] = std::move(infoCopy);
}
else
{
result.exprs[require] = {}; // mark require as unresolved
}
}
}
RequireTraceResult& result;
FileResolver* fileResolver;
ModuleName currentModuleName;
DenseHashMap<AstLocal*, AstExpr*> locals;
std::vector<AstExpr*> work;
std::vector<AstExprCall*> requires;
};
} // anonymous namespace
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName)
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName)
{
RequireTracer tracer{fileResolver, std::move(currentModuleName)};
root->visit(&tracer);
return tracer.result;
if (FFlag::LuauNewRequireTrace)
{
RequireTraceResult result;
RequireTracer tracer{result, fileResolver, currentModuleName};
root->visit(&tracer);
tracer.process();
return result;
}
else
{
RequireTracerOld tracer{fileResolver, currentModuleName};
root->visit(&tracer);
return tracer.result;
}
}
} // namespace Luau

123
Analysis/src/Scope.cpp Normal file
View file

@ -0,0 +1,123 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Scope.h"
namespace Luau
{
Scope::Scope(TypePackId returnType)
: parent(nullptr)
, returnType(returnType)
, level(TypeLevel())
{
}
Scope::Scope(const ScopePtr& parent, int subLevel)
: parent(parent)
, returnType(parent->returnType)
, level(parent->level.incr())
{
level.subLevel = subLevel;
}
std::optional<TypeId> Scope::lookup(const Symbol& name)
{
Scope* scope = this;
while (scope)
{
auto it = scope->bindings.find(name);
if (it != scope->bindings.end())
return it->second.typeId;
scope = scope->parent.get();
}
return std::nullopt;
}
std::optional<TypeFun> Scope::lookupType(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->exportedTypeBindings.find(name);
if (it != scope->exportedTypeBindings.end())
return it->second;
it = scope->privateTypeBindings.find(name);
if (it != scope->privateTypeBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<TypeFun> Scope::lookupImportedType(const Name& moduleAlias, const Name& name)
{
const Scope* scope = this;
while (scope)
{
auto it = scope->importedTypeBindings.find(moduleAlias);
if (it == scope->importedTypeBindings.end())
{
scope = scope->parent.get();
continue;
}
auto it2 = it->second.find(name);
if (it2 == it->second.end())
{
scope = scope->parent.get();
continue;
}
return it2->second;
}
return std::nullopt;
}
std::optional<TypePackId> Scope::lookupPack(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->privateTypePackBindings.find(name);
if (it != scope->privateTypePackBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain)
{
Scope* scope = this;
while (scope)
{
for (const auto& [n, binding] : scope->bindings)
{
if (n.local && n.local->name == name.c_str())
return binding;
else if (n.global.value && n.global == name.c_str())
return binding;
}
scope = scope->parent.get();
if (!traverseScopeChain)
break;
}
return std::nullopt;
}
} // namespace Luau

View file

@ -6,9 +6,11 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0)
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000)
LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false)
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -35,8 +37,15 @@ void Tarjan::visitChildren(TypeId ty, int index)
visitChild(ttv->indexer->indexType);
visitChild(ttv->indexer->indexResultType);
}
for (TypeId itp : ttv->instantiatedTypeParams)
visitChild(itp);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp);
}
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
@ -332,9 +341,11 @@ std::optional<TypeId> Substitution::substitute(TypeId ty)
return std::nullopt;
for (auto [oldTy, newTy] : newTypes)
replaceChildren(newTy);
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy))
replaceChildren(newTy);
for (auto [oldTp, newTp] : newPacks)
replaceChildren(newTp);
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp))
replaceChildren(newTp);
TypeId newTy = replace(ty);
return newTy;
}
@ -350,9 +361,11 @@ std::optional<TypePackId> Substitution::substitute(TypePackId tp)
return std::nullopt;
for (auto [oldTy, newTy] : newTypes)
replaceChildren(newTy);
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy))
replaceChildren(newTy);
for (auto [oldTp, newTp] : newPacks)
replaceChildren(newTp);
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp))
replaceChildren(newTp);
TypePackId newTp = replace(tp);
return newTp;
}
@ -382,6 +395,10 @@ TypeId Substitution::clone(TypeId ty)
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.instantiatedTypeParams = ttv->instantiatedTypeParams;
if (FFlag::LuauTypeAliasPacks)
clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams;
if (FFlag::LuauSecondTypecheckKnowsTheDataModel)
clone.tags = ttv->tags;
result = addType(std::move(clone));
@ -487,8 +504,15 @@ void Substitution::replaceChildren(TypeId ty)
ttv->indexer->indexType = replace(ttv->indexer->indexType);
ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType);
}
for (TypeId& itp : ttv->instantiatedTypeParams)
itp = replace(itp);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId& itp : ttv->instantiatedTypePackParams)
itp = replace(itp);
}
}
else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(ty))
{

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ToString.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -9,10 +10,9 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauToStringFollowsBoundTo)
LUAU_FASTFLAG(LuauExtraNilRecovery)
LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions)
LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -59,6 +59,13 @@ struct FindCyclicTypes
{
for (TypeId itp : ttv.instantiatedTypeParams)
visitTypeVar(itp, *this, seen);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId itp : ttv.instantiatedTypePackParams)
visitTypeVar(itp, *this, seen);
}
return exhaustive;
}
@ -151,15 +158,6 @@ struct StringifierState
seen.erase(iter);
}
static std::string generateName(size_t i)
{
std::string n;
n = char('a' + i % 26);
if (i >= 26)
n += std::to_string(i / 26);
return n;
}
std::string getName(TypeId ty)
{
const size_t s = result.nameMap.typeVars.size();
@ -258,23 +256,60 @@ struct TypeVarStringifier
void stringify(TypePackId tp);
void stringify(TypePackId tpid, const std::vector<std::optional<FunctionArgument>>& names);
void stringify(const std::vector<TypeId>& types)
void stringify(const std::vector<TypeId>& types, const std::vector<TypePackId>& typePacks)
{
if (types.size() == 0)
if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0))
return;
if (types.size())
if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size()))
state.emit("<");
for (size_t i = 0; i < types.size(); ++i)
if (FFlag::LuauTypeAliasPacks)
{
if (i > 0)
state.emit(", ");
bool first = true;
stringify(types[i]);
for (TypeId ty : types)
{
if (!first)
state.emit(", ");
first = false;
stringify(ty);
}
bool singleTp = typePacks.size() == 1;
for (TypePackId tp : typePacks)
{
if (isEmpty(tp) && singleTp)
continue;
if (!first)
state.emit(", ");
else
first = false;
if (!singleTp)
state.emit("(");
stringify(tp);
if (!singleTp)
state.emit(")");
}
}
else
{
for (size_t i = 0; i < types.size(); ++i)
{
if (i > 0)
state.emit(", ");
stringify(types[i]);
}
}
if (types.size())
if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size()))
state.emit(">");
}
@ -388,7 +423,7 @@ struct TypeVarStringifier
void operator()(TypeId, const TableTypeVar& ttv)
{
if (FFlag::LuauToStringFollowsBoundTo && ttv.boundTo)
if (ttv.boundTo)
return stringify(*ttv.boundTo);
if (!state.exhaustive)
@ -411,14 +446,14 @@ struct TypeVarStringifier
}
state.emit(*ttv.name);
stringify(ttv.instantiatedTypeParams);
stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams);
return;
}
if (ttv.syntheticName)
{
state.result.invalid = true;
state.emit(*ttv.syntheticName);
stringify(ttv.instantiatedTypeParams);
stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams);
return;
}
}
@ -539,8 +574,7 @@ struct TypeVarStringifier
std::vector<std::string> results = {};
for (auto el : &uv)
{
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
el = follow(el);
el = follow(el);
if (isNil(el))
{
@ -604,8 +638,7 @@ struct TypeVarStringifier
std::vector<std::string> results = {};
for (auto el : uv.parts)
{
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
el = follow(el);
el = follow(el);
std::string saved = std::move(state.result.name);
@ -900,13 +933,26 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
result.name += ttv->name ? *ttv->name : *ttv->syntheticName;
if (ttv->instantiatedTypeParams.empty())
if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty()))
return result;
std::vector<std::string> params;
for (TypeId tp : ttv->instantiatedTypeParams)
params.push_back(toString(tp));
if (FFlag::LuauTypeAliasPacks)
{
// Doesn't preserve grouping of multiple type packs
// But this is under a parent block of code that is being removed later
for (TypePackId tp : ttv->instantiatedTypePackParams)
{
std::string content = toString(tp);
if (!content.empty())
params.push_back(std::move(content));
}
}
result.name += "<" + join(params, ", ") + ">";
return result;
}
@ -950,30 +996,37 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
result.name += ttv->name ? *ttv->name : *ttv->syntheticName;
if (ttv->instantiatedTypeParams.empty())
return result;
result.name += "<";
bool first = true;
for (TypeId ty : ttv->instantiatedTypeParams)
if (FFlag::LuauTypeAliasPacks)
{
if (!first)
result.name += ", ";
else
first = false;
tvs.stringify(ty);
}
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
{
result.truncated = true;
result.name += "... <TRUNCATED>";
tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams);
}
else
{
result.name += ">";
if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty()))
return result;
result.name += "<";
bool first = true;
for (TypeId ty : ttv->instantiatedTypeParams)
{
if (!first)
result.name += ", ";
else
first = false;
tvs.stringify(ty);
}
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
{
result.truncated = true;
result.name += "... <TRUNCATED>";
}
else
{
result.name += ">";
}
}
return result;
@ -1139,4 +1192,13 @@ void dump(TypePackId ty)
printf("%s\n", toString(ty, opts).c_str());
}
std::string generateName(size_t i)
{
std::string n;
n = char('a' + i % 26);
if (i >= 26)
n += std::to_string(i / 26);
return n;
}
} // namespace Luau

View file

@ -298,8 +298,15 @@ struct ArcCollector : public AstVisitor
struct ContainsFunctionCall : public AstVisitor
{
bool alsoReturn = false;
bool result = false;
ContainsFunctionCall() = default;
explicit ContainsFunctionCall(bool alsoReturn)
: alsoReturn(alsoReturn)
{
}
bool visit(AstExpr*) override
{
return !result; // short circuit if result is true
@ -318,6 +325,17 @@ struct ContainsFunctionCall : public AstVisitor
return false;
}
bool visit(AstStatReturn* stat) override
{
if (alsoReturn)
{
result = true;
return false;
}
else
return AstVisitor::visit(stat);
}
bool visit(AstExprFunction*) override
{
return false;
@ -479,6 +497,13 @@ bool containsFunctionCall(const AstStat& stat)
return cfc.result;
}
bool containsFunctionCallOrReturn(const AstStat& stat)
{
detail::ContainsFunctionCall cfc{true};
const_cast<AstStat&>(stat).visit(&cfc);
return cfc.result;
}
bool isFunction(const AstStat& stat)
{
return stat.is<AstStatFunction>() || stat.is<AstStatLocalFunction>();

View file

@ -11,6 +11,7 @@
#include <math.h>
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace
{
@ -280,10 +281,19 @@ struct Printer
void visualizeTypePackAnnotation(const AstTypePack& annotation)
{
if (const AstTypePackVariadic* variadic = annotation.as<AstTypePackVariadic>())
if (const AstTypePackVariadic* variadicTp = annotation.as<AstTypePackVariadic>())
{
writer.symbol("...");
visualizeTypeAnnotation(*variadic->variadicType);
visualizeTypeAnnotation(*variadicTp->variadicType);
}
else if (const AstTypePackGeneric* genericTp = annotation.as<AstTypePackGeneric>())
{
writer.symbol(genericTp->genericName.value);
writer.symbol("...");
}
else if (const AstTypePackExplicit* explicitTp = annotation.as<AstTypePackExplicit>())
{
visualizeTypeList(explicitTp->typeList, true);
}
else
{
@ -807,7 +817,7 @@ struct Printer
writer.keyword("type");
writer.identifier(a->name.value);
if (a->generics.size > 0)
if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0))
{
writer.symbol("<");
CommaSeparatorInserter comma(writer);
@ -817,6 +827,17 @@ struct Printer
comma();
writer.identifier(o.value);
}
if (FFlag::LuauTypeAliasPacks)
{
for (auto o : a->genericPacks)
{
comma();
writer.identifier(o.value);
writer.symbol("...");
}
}
writer.symbol(">");
}
writer.maybeSpace(a->type->location.begin, 2);
@ -960,15 +981,20 @@ struct Printer
if (const auto& a = typeAnnotation.as<AstTypeReference>())
{
writer.write(a->name.value);
if (a->generics.size > 0)
if (a->parameters.size > 0)
{
CommaSeparatorInserter comma(writer);
writer.symbol("<");
for (auto o : a->generics)
for (auto o : a->parameters)
{
comma();
visualizeTypeAnnotation(*o);
if (o.type)
visualizeTypeAnnotation(*o.type);
else
visualizeTypePackAnnotation(*o.typePack);
}
writer.symbol(">");
}
}

View file

@ -5,6 +5,8 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false)
namespace Luau
{
@ -33,6 +35,12 @@ void TxnLog::rollback()
for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it)
std::swap(it->first->boundTo, it->second);
if (FFlag::LuauShareTxnSeen)
{
LUAU_ASSERT(originalSeenSize <= sharedSeen->size());
sharedSeen->resize(originalSeenSize);
}
}
void TxnLog::concat(TxnLog rhs)
@ -46,27 +54,44 @@ void TxnLog::concat(TxnLog rhs)
tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end());
rhs.tableChanges.clear();
seen.swap(rhs.seen);
rhs.seen.clear();
if (!FFlag::LuauShareTxnSeen)
{
ownedSeen.swap(rhs.ownedSeen);
rhs.ownedSeen.clear();
}
}
bool TxnLog::haveSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair));
if (FFlag::LuauShareTxnSeen)
return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair));
else
return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair));
}
void TxnLog::pushSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
seen.push_back(sortedPair);
if (FFlag::LuauShareTxnSeen)
sharedSeen->push_back(sortedPair);
else
ownedSeen.push_back(sortedPair);
}
void TxnLog::popSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
LUAU_ASSERT(sortedPair == seen.back());
seen.pop_back();
if (FFlag::LuauShareTxnSeen)
{
LUAU_ASSERT(sortedPair == sharedSeen->back());
sharedSeen->pop_back();
}
else
{
LUAU_ASSERT(sortedPair == ownedSeen.back());
ownedSeen.pop_back();
}
}
} // namespace Luau

View file

@ -5,6 +5,8 @@
#include "Luau/Module.h"
#include "Luau/Parser.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -12,6 +14,7 @@
#include <string>
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAG(LuauTypeAliasPacks)
static char* allocateString(Luau::Allocator& allocator, std::string_view contents)
{
@ -31,15 +34,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data
return result;
}
using SyntheticNames = std::unordered_map<const void*, char*>;
namespace Luau
{
static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen)
{
size_t s = syntheticNames->size();
char*& n = (*syntheticNames)[&gen];
if (!n)
{
std::string str = gen.explicitName ? gen.name : generateName(s);
n = static_cast<char*>(allocator->allocate(str.size() + 1));
strcpy(n, str.c_str());
}
return n;
}
class TypeRehydrationVisitor
{
mutable std::map<void*, int> seen;
mutable int count = 0;
std::map<void*, int> seen;
int count = 0;
bool hasSeen(const void* tv) const
bool hasSeen(const void* tv)
{
void* ttv = const_cast<void*>(tv);
auto it = seen.find(ttv);
@ -51,13 +70,16 @@ class TypeRehydrationVisitor
}
public:
TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions())
TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions())
: allocator(alloc)
, syntheticNames(syntheticNames)
, options(options)
{
}
AstType* operator()(const PrimitiveTypeVar& ptv) const
AstTypePack* rehydrate(TypePackId tp);
AstType* operator()(const PrimitiveTypeVar& ptv)
{
switch (ptv.type)
{
@ -75,26 +97,34 @@ public:
return nullptr;
}
}
AstType* operator()(const AnyTypeVar&) const
AstType* operator()(const AnyTypeVar&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"));
}
AstType* operator()(const TableTypeVar& ttv) const
AstType* operator()(const TableTypeVar& ttv)
{
RecursionCounter counter(&count);
if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end())
{
AstArray<AstType*> generics;
generics.size = ttv.instantiatedTypeParams.size();
generics.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * generics.size));
AstArray<AstTypeOrPack> parameters;
parameters.size = ttv.instantiatedTypeParams.size();
parameters.data = static_cast<AstTypeOrPack*>(allocator->allocate(sizeof(AstTypeOrPack) * parameters.size));
for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i)
{
generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty);
parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}};
}
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(ttv.name->c_str()), generics);
if (FFlag::LuauTypeAliasPacks)
{
for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i)
{
parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])};
}
}
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters);
}
if (hasSeen(&ttv))
@ -133,12 +163,12 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props, indexer);
}
AstType* operator()(const MetatableTypeVar& mtv) const
AstType* operator()(const MetatableTypeVar& mtv)
{
return Luau::visit(*this, mtv.table->ty);
}
AstType* operator()(const ClassTypeVar& ctv) const
AstType* operator()(const ClassTypeVar& ctv)
{
RecursionCounter counter(&count);
@ -165,7 +195,7 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props);
}
AstType* operator()(const FunctionTypeVar& ftv) const
AstType* operator()(const FunctionTypeVar& ftv)
{
RecursionCounter counter(&count);
@ -222,10 +252,17 @@ public:
AstTypePack* argTailAnnotation = nullptr;
if (argTail)
{
TypePackId tail = *argTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
if (FFlag::LuauTypeAliasPacks)
{
argTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
argTailAnnotation = rehydrate(*argTail);
}
else
{
TypePackId tail = *argTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
{
argTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
}
}
}
@ -235,10 +272,12 @@ public:
size_t i = 0;
for (const auto& el : ftv.argNames)
{
std::optional<AstArgumentName>* arg = &argNames.data[i++];
if (el)
argNames.data[i++] = {AstName(el->name.c_str()), el->location};
new (arg) std::optional<AstArgumentName>(AstArgumentName(AstName(el->name.c_str()), el->location));
else
argNames.data[i++] = {};
new (arg) std::optional<AstArgumentName>();
}
AstArray<AstType*> returnTypes;
@ -255,33 +294,40 @@ public:
AstTypePack* retTailAnnotation = nullptr;
if (retTail)
{
TypePackId tail = *retTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
if (FFlag::LuauTypeAliasPacks)
{
retTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
retTailAnnotation = rehydrate(*retTail);
}
else
{
TypePackId tail = *retTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
{
retTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
}
}
}
return allocator->alloc<AstTypeFunction>(
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation});
}
AstType* operator()(const Unifiable::Error&) const
AstType* operator()(const Unifiable::Error&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>"));
}
AstType* operator()(const GenericTypeVar& gtv) const
AstType* operator()(const GenericTypeVar& gtv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(gtv.name.c_str()));
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)));
}
AstType* operator()(const Unifiable::Bound<TypeId>& bound) const
AstType* operator()(const Unifiable::Bound<TypeId>& bound)
{
return Luau::visit(*this, bound.boundTo->ty);
}
AstType* operator()(Unifiable::Free ftv) const
AstType* operator()(const FreeTypeVar& ftv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"));
}
AstType* operator()(const UnionTypeVar& uv) const
AstType* operator()(const UnionTypeVar& uv)
{
AstArray<AstType*> unionTypes;
unionTypes.size = uv.options.size();
@ -292,7 +338,7 @@ public:
}
return allocator->alloc<AstTypeUnion>(Location(), unionTypes);
}
AstType* operator()(const IntersectionTypeVar& uv) const
AstType* operator()(const IntersectionTypeVar& uv)
{
AstArray<AstType*> intersectionTypes;
intersectionTypes.size = uv.parts.size();
@ -303,16 +349,84 @@ public:
}
return allocator->alloc<AstTypeIntersection>(Location(), intersectionTypes);
}
AstType* operator()(const LazyTypeVar& ltv) const
AstType* operator()(const LazyTypeVar& ltv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Lazy?>"));
}
private:
Allocator* allocator;
SyntheticNames* syntheticNames;
const TypeRehydrationOptions& options;
};
class TypePackRehydrationVisitor
{
public:
TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor)
: allocator(allocator)
, syntheticNames(syntheticNames)
, typeVisitor(typeVisitor)
{
LUAU_ASSERT(allocator);
LUAU_ASSERT(syntheticNames);
LUAU_ASSERT(typeVisitor);
}
AstTypePack* operator()(const BoundTypePack& btp) const
{
return Luau::visit(*this, btp.boundTo->ty);
}
AstTypePack* operator()(const TypePack& tp) const
{
AstArray<AstType*> head;
head.size = tp.head.size();
head.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * tp.head.size()));
for (size_t i = 0; i < tp.head.size(); i++)
head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty);
AstTypePack* tail = nullptr;
if (tp.tail)
tail = Luau::visit(*this, (*tp.tail)->ty);
return allocator->alloc<AstTypePackExplicit>(Location(), AstTypeList{head, tail});
}
AstTypePack* operator()(const VariadicTypePack& vtp) const
{
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*typeVisitor, vtp.ty->ty));
}
AstTypePack* operator()(const GenericTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(getName(allocator, syntheticNames, gtp)));
}
AstTypePack* operator()(const FreeTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("free"));
}
AstTypePack* operator()(const Unifiable::Error&) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("Unifiable<Error>"));
}
private:
Allocator* allocator;
SyntheticNames* syntheticNames;
TypeRehydrationVisitor* typeVisitor;
};
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp)
{
TypePackRehydrationVisitor tprv(allocator, syntheticNames, this);
return Luau::visit(tprv, tp->ty);
}
class TypeAttacher : public AstVisitor
{
public:
@ -344,7 +458,7 @@ public:
{
if (!type)
return nullptr;
return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty);
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty);
}
AstArray<Luau::AstType*> typeAstPack(TypePackId type)
@ -356,7 +470,7 @@ public:
result.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * v.size()));
for (size_t i = 0; i < v.size(); ++i)
{
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty);
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty);
}
return result;
}
@ -406,9 +520,16 @@ public:
if (tail)
{
TypePackId tailPack = *tail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tailPack))
variadicAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), typeAst(vtp->ty));
if (FFlag::LuauTypeAliasPacks)
{
variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail);
}
else
{
TypePackId tailPack = *tail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tailPack))
variadicAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), typeAst(vtp->ty));
}
}
fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation};
@ -421,6 +542,7 @@ public:
private:
Module& module;
Allocator* allocator;
SyntheticNames syntheticNames;
};
void attachTypeData(SourceModule& source, Module& result)
@ -431,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result)
AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options)
{
return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty);
SyntheticNames syntheticNames;
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty);
}
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -209,6 +209,19 @@ size_t size(TypePackId tp)
return 0;
}
bool finite(TypePackId tp)
{
tp = follow(tp);
if (auto pack = get<TypePack>(tp))
return pack->tail ? finite(*pack->tail) : true;
if (auto pack = get<VariadicTypePack>(tp))
return false;
return true;
}
size_t size(const TypePack& tp)
{
size_t result = tp.head.size();

View file

@ -1,11 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeUtils.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
LUAU_FASTFLAG(LuauStringMetatable)
namespace Luau
{
@ -13,21 +12,6 @@ std::optional<TypeId> findMetatableEntry(ErrorVec& errors, const ScopePtr& globa
{
type = follow(type);
if (!FFlag::LuauStringMetatable)
{
if (const PrimitiveTypeVar* primType = get<PrimitiveTypeVar>(type))
{
if (primType->type != PrimitiveTypeVar::String || "__index" != entry)
return std::nullopt;
auto it = globalScope->bindings.find(AstName{"string"});
if (it != globalScope->bindings.end())
return it->second.typeId;
else
return std::nullopt;
}
}
std::optional<TypeId> metatable = getMetatable(type);
if (!metatable)
return std::nullopt;

View file

@ -19,11 +19,10 @@
LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2)
LUAU_FASTFLAGVARIABLE(LuauToStringFollowsBoundTo, false)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAGVARIABLE(LuauStringMetatable, false)
LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false)
namespace Luau
{
@ -193,27 +192,11 @@ bool isOptional(TypeId ty)
bool isTableIntersection(TypeId ty)
{
if (FFlag::LuauImprovedTypeGuardPredicate2)
{
if (!get<IntersectionTypeVar>(follow(ty)))
return false;
std::vector<TypeId> parts = flattenIntersection(ty);
return std::all_of(parts.begin(), parts.end(), getTableType);
}
else
{
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
for (TypeId part : itv->parts)
{
if (getTableType(follow(part)))
return true;
}
}
if (!get<IntersectionTypeVar>(follow(ty)))
return false;
}
std::vector<TypeId> parts = flattenIntersection(ty);
return std::all_of(parts.begin(), parts.end(), getTableType);
}
bool isOverloadedFunction(TypeId ty)
@ -235,8 +218,7 @@ std::optional<TypeId> getMetatable(TypeId type)
return mtType->metatable;
else if (const ClassTypeVar* classType = get<ClassTypeVar>(type))
return classType->metatable;
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type);
FFlag::LuauStringMetatable && primitiveType && primitiveType->metatable)
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type); primitiveType && primitiveType->metatable)
{
LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String);
return primitiveType->metatable;
@ -871,6 +853,12 @@ void StateDot::visitChildren(TypeId ty, int index)
}
for (TypeId itp : ttv->instantiatedTypeParams)
visitChild(itp, index, "typeParam");
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp, index, "typePackParam");
}
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
@ -1502,4 +1490,86 @@ std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate)
return {};
}
static Tags* getTags(TypeId ty)
{
ty = follow(ty);
if (auto ftv = getMutable<FunctionTypeVar>(ty))
return &ftv->tags;
else if (auto ttv = getMutable<TableTypeVar>(ty))
return &ttv->tags;
else if (auto ctv = getMutable<ClassTypeVar>(ty))
return &ctv->tags;
return nullptr;
}
void attachTag(TypeId ty, const std::string& tagName)
{
if (!FFlag::LuauRefactorTagging)
{
if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->tags.emplace_back(tagName);
}
else
{
LUAU_ASSERT(!"Got a non functional type");
}
}
else
{
if (auto tags = getTags(ty))
tags->push_back(tagName);
else
LUAU_ASSERT(!"This TypeId does not support tags");
}
}
void attachTag(Property& prop, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
prop.tags.push_back(tagName);
}
// We would ideally not expose this because it could cause a footgun.
// If the Base class has a tag and you ask if Derived has that tag, it would return false.
// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it.
bool hasTag(const Tags& tags, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
return std::find(tags.begin(), tags.end(), tagName) != tags.end();
}
bool hasTag(TypeId ty, const std::string& tagName)
{
ty = follow(ty);
// We special case classes because getTags only returns a pointer to one vector of tags.
// But classes has multiple vector of tags, represented throughout the hierarchy.
if (auto ctv = get<ClassTypeVar>(ty))
{
while (ctv)
{
if (hasTag(ctv->tags, tagName))
return true;
else if (!ctv->parent)
return false;
ctv = get<ClassTypeVar>(*ctv->parent);
LUAU_ASSERT(ctv);
}
}
else if (auto tags = getTags(ty))
return hasTag(*tags, tagName);
return false;
}
bool hasTag(const Property& prop, const std::string& tagName)
{
return hasTag(prop.tags, tagName);
}
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -264,6 +264,10 @@ public:
{
return false;
}
virtual bool visit(class AstTypePackExplicit* node)
{
return visit((class AstTypePack*)node);
}
virtual bool visit(class AstTypePackVariadic* node)
{
return visit((class AstTypePack*)node);
@ -930,12 +934,14 @@ class AstStatTypeAlias : public AstStat
public:
LUAU_RTTI(AstStatTypeAlias)
AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, AstType* type, bool exported);
AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, const AstArray<AstName>& genericPacks,
AstType* type, bool exported);
void visit(AstVisitor* visitor) override;
AstName name;
AstArray<AstName> generics;
AstArray<AstName> genericPacks;
AstType* type;
bool exported;
};
@ -1007,19 +1013,28 @@ public:
}
};
// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use
struct AstTypeOrPack
{
AstType* type = nullptr;
AstTypePack* typePack = nullptr;
};
class AstTypeReference : public AstType
{
public:
LUAU_RTTI(AstTypeReference)
AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, const AstArray<AstType*>& generics = {});
AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, bool hasParameterList = false,
const AstArray<AstTypeOrPack>& parameters = {});
void visit(AstVisitor* visitor) override;
bool hasPrefix;
bool hasParameterList;
AstName prefix;
AstName name;
AstArray<AstType*> generics;
AstArray<AstTypeOrPack> parameters;
};
struct AstTableProp
@ -1152,6 +1167,18 @@ public:
}
};
class AstTypePackExplicit : public AstTypePack
{
public:
LUAU_RTTI(AstTypePackExplicit)
AstTypePackExplicit(const Location& location, AstTypeList typeList);
void visit(AstVisitor* visitor) override;
AstTypeList typeList;
};
class AstTypePackVariadic : public AstTypePack
{
public:

View file

@ -136,7 +136,10 @@ public:
const Key& key = ItemInterface::getKey(data[i]);
if (!eq(key, empty_key))
*newtable.insert_unsafe(key) = data[i];
{
Item* item = newtable.insert_unsafe(key);
*item = std::move(data[i]);
}
}
LUAU_ASSERT(count == newtable.count);

View file

@ -218,13 +218,14 @@ private:
AstTableIndexer* parseTableIndexerAnnotation();
AstType* parseFunctionTypeAnnotation();
AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack);
AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstName> generics, AstArray<AstName> genericPacks,
AstArray<AstType*>& params, AstArray<std::optional<AstArgumentName>>& paramNames, AstTypePack* varargAnnotation);
AstType* parseTableTypeAnnotation();
AstType* parseSimpleTypeAnnotation();
AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack);
AstTypeOrPack parseTypeOrPackAnnotation();
AstType* parseTypeAnnotation(TempVector<AstType*>& parts, const Location& begin);
AstType* parseTypeAnnotation();
@ -284,7 +285,7 @@ private:
std::pair<AstArray<AstName>, AstArray<AstName>> parseGenericTypeListIfFFlagParseGenericFunctions();
// `<' typeAnnotation[, ...] `>'
AstArray<AstType*> parseTypeParams();
AstArray<AstTypeOrPack> parseTypeParams();
AstExpr* parseString();
@ -413,6 +414,7 @@ private:
std::vector<AstLocal*> scratchLocal;
std::vector<AstTableProp> scratchTableTypeProps;
std::vector<AstType*> scratchAnnotation;
std::vector<AstTypeOrPack> scratchTypeOrPackAnnotation;
std::vector<AstDeclaredClassProp> scratchDeclaredClassProps;
std::vector<AstExprTable::Item> scratchItem;
std::vector<AstArgumentName> scratchArgName;

View file

@ -0,0 +1,223 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Common.h"
#include <vector>
#include <stdint.h>
LUAU_FASTFLAG(DebugLuauTimeTracing)
#if defined(LUAU_ENABLE_TIME_TRACE)
namespace Luau
{
namespace TimeTrace
{
uint32_t getClockMicroseconds();
struct Token
{
const char* name;
const char* category;
};
enum class EventType : uint8_t
{
Enter,
Leave,
ArgName,
ArgValue,
};
struct Event
{
EventType type;
uint16_t token;
union
{
uint32_t microsec; // 1 hour trace limit
uint32_t dataPos;
} data;
};
struct GlobalContext;
struct ThreadContext;
GlobalContext& getGlobalContext();
uint16_t createToken(GlobalContext& context, const char* name, const char* category);
uint32_t createThread(GlobalContext& context, ThreadContext* threadContext);
void releaseThread(GlobalContext& context, ThreadContext* threadContext);
void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Event>& events, const std::vector<char>& data);
struct ThreadContext
{
ThreadContext()
: globalContext(getGlobalContext())
{
threadId = createThread(globalContext, this);
}
~ThreadContext()
{
if (!events.empty())
flushEvents();
releaseThread(globalContext, this);
}
void flushEvents()
{
static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace");
events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}});
TimeTrace::flushEvents(globalContext, threadId, events, data);
events.clear();
data.clear();
events.push_back({EventType::Leave, 0, {getClockMicroseconds()}});
}
void eventEnter(uint16_t token)
{
eventEnter(token, getClockMicroseconds());
}
void eventEnter(uint16_t token, uint32_t microsec)
{
events.push_back({EventType::Enter, token, {microsec}});
}
void eventLeave()
{
eventLeave(getClockMicroseconds());
}
void eventLeave(uint32_t microsec)
{
events.push_back({EventType::Leave, 0, {microsec}});
if (events.size() > kEventFlushLimit)
flushEvents();
}
void eventArgument(const char* name, const char* value)
{
uint32_t pos = uint32_t(data.size());
data.insert(data.end(), name, name + strlen(name) + 1);
events.push_back({EventType::ArgName, 0, {pos}});
pos = uint32_t(data.size());
data.insert(data.end(), value, value + strlen(value) + 1);
events.push_back({EventType::ArgValue, 0, {pos}});
}
GlobalContext& globalContext;
uint32_t threadId;
std::vector<Event> events;
std::vector<char> data;
static constexpr size_t kEventFlushLimit = 8192;
};
ThreadContext& getThreadContext();
struct Scope
{
explicit Scope(ThreadContext& context, uint16_t token)
: context(context)
{
if (!FFlag::DebugLuauTimeTracing)
return;
context.eventEnter(token);
}
~Scope()
{
if (!FFlag::DebugLuauTimeTracing)
return;
context.eventLeave();
}
ThreadContext& context;
};
struct OptionalTailScope
{
explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold)
: context(context)
, token(token)
, threshold(threshold)
{
if (!FFlag::DebugLuauTimeTracing)
return;
pos = uint32_t(context.events.size());
microsec = getClockMicroseconds();
}
~OptionalTailScope()
{
if (!FFlag::DebugLuauTimeTracing)
return;
if (pos == context.events.size())
{
uint32_t curr = getClockMicroseconds();
if (curr - microsec > threshold)
{
context.eventEnter(token, microsec);
context.eventLeave(curr);
}
}
}
ThreadContext& context;
uint16_t token;
uint32_t threshold;
uint32_t microsec;
uint32_t pos;
};
LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeData(const char* name, const char* category);
} // namespace TimeTrace
} // namespace Luau
// Regular scope
#define LUAU_TIMETRACE_SCOPE(name, category) \
static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first)
// A scope without nested scopes that may be skipped if the time it took is less than the threshold
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec)
// Extra key/value data can be added to regular scopes
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
if (FFlag::DebugLuauTimeTracing) \
lttScopeStatic.second.eventArgument(name, value); \
} while (false)
#else
#define LUAU_TIMETRACE_SCOPE(name, category)
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec)
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
} while (false)
#endif

View file

@ -641,10 +641,12 @@ void AstStatLocalFunction::visit(AstVisitor* visitor)
func->visit(visitor);
}
AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, AstType* type, bool exported)
AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics,
const AstArray<AstName>& genericPacks, AstType* type, bool exported)
: AstStat(ClassIndex(), location)
, name(name)
, generics(generics)
, genericPacks(genericPacks)
, type(type)
, exported(exported)
{
@ -729,12 +731,14 @@ void AstStatError::visit(AstVisitor* visitor)
}
}
AstTypeReference::AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, const AstArray<AstType*>& generics)
AstTypeReference::AstTypeReference(
const Location& location, std::optional<AstName> prefix, AstName name, bool hasParameterList, const AstArray<AstTypeOrPack>& parameters)
: AstType(ClassIndex(), location)
, hasPrefix(bool(prefix))
, hasParameterList(hasParameterList)
, prefix(prefix ? *prefix : AstName())
, name(name)
, generics(generics)
, parameters(parameters)
{
}
@ -742,8 +746,13 @@ void AstTypeReference::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
{
for (AstType* generic : generics)
generic->visit(visitor);
for (const AstTypeOrPack& param : parameters)
{
if (param.type)
param.type->visit(visitor);
else
param.typePack->visit(visitor);
}
}
}
@ -849,6 +858,24 @@ void AstTypeError::visit(AstVisitor* visitor)
}
}
AstTypePackExplicit::AstTypePackExplicit(const Location& location, AstTypeList typeList)
: AstTypePack(ClassIndex(), location)
, typeList(typeList)
{
}
void AstTypePackExplicit::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
{
for (AstType* type : typeList.types)
type->visit(visitor);
if (typeList.tailType)
typeList.tailType->visit(visitor);
}
}
AstTypePackVariadic::AstTypePackVariadic(const Location& location, AstType* variadicType)
: AstTypePack(ClassIndex(), location)
, variadicType(variadicType)

View file

@ -1,6 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "Luau/TimeTrace.h"
#include <algorithm>
// Warning: If you are introducing new syntax, ensure that it is behind a separate
@ -13,6 +15,8 @@ LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false)
LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false)
LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false)
LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false)
namespace Luau
{
@ -148,6 +152,8 @@ static bool shouldParseTypePackAnnotation(Lexer& lexer)
ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options)
{
LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser");
Parser p(buffer, bufferSize, names, allocator);
try
@ -769,14 +775,14 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
if (!name)
name = Name(nameError, lexer.current().location);
// TODO: support generic type pack parameters in type aliases CLI-39907
auto [generics, genericPacks] = parseGenericTypeList();
expectAndConsume('=', "type alias");
AstType* type = parseTypeAnnotation();
return allocator.alloc<AstStatTypeAlias>(Location(start, type->location), name->name, generics, type, exported);
return allocator.alloc<AstStatTypeAlias>(
Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray<AstName>{}, type, exported);
}
AstDeclaredClassProp Parser::parseDeclaredClassMethod()
@ -1333,7 +1339,7 @@ AstType* Parser::parseTableTypeAnnotation()
// ReturnType ::= TypeAnnotation | `(' TypeList `)'
// FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstType* Parser::parseFunctionTypeAnnotation()
AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
{
incrementRecursionCounter("type annotation");
@ -1364,14 +1370,23 @@ AstType* Parser::parseFunctionTypeAnnotation()
matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--;
// Not a function at all. Just a parenthesized type.
if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow)
return params[0];
AstArray<AstType*> paramTypes = copy(params);
// Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element
if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow)
{
if (allowPack)
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, nullptr})};
else
return {params[0], {}};
}
if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack)
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, varargAnnotation})};
AstArray<std::optional<AstArgumentName>> paramNames = copy(names);
return parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation);
return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
}
AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstName> generics, AstArray<AstName> genericPacks,
@ -1421,7 +1436,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
if (c == '|')
{
nextLexeme();
parts.push_back(parseSimpleTypeAnnotation());
parts.push_back(parseSimpleTypeAnnotation(false).type);
isUnion = true;
}
else if (c == '?')
@ -1434,7 +1449,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
else if (c == '&')
{
nextLexeme();
parts.push_back(parseSimpleTypeAnnotation());
parts.push_back(parseSimpleTypeAnnotation(false).type);
isIntersection = true;
}
else
@ -1462,6 +1477,30 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
ParseError::raise(begin, "Composite type was not an intersection or union.");
}
AstTypeOrPack Parser::parseTypeOrPackAnnotation()
{
unsigned int oldRecursionCount = recursionCounter;
incrementRecursionCounter("type annotation");
Location begin = lexer.current().location;
TempVector<AstType*> parts(scratchAnnotation);
auto [type, typePack] = parseSimpleTypeAnnotation(true);
if (typePack)
{
LUAU_ASSERT(!type);
return {{}, typePack};
}
parts.push_back(type);
recursionCounter = oldRecursionCount;
return {parseTypeAnnotation(parts, begin), {}};
}
AstType* Parser::parseTypeAnnotation()
{
unsigned int oldRecursionCount = recursionCounter;
@ -1470,7 +1509,7 @@ AstType* Parser::parseTypeAnnotation()
Location begin = lexer.current().location;
TempVector<AstType*> parts(scratchAnnotation);
parts.push_back(parseSimpleTypeAnnotation());
parts.push_back(parseSimpleTypeAnnotation(false).type);
recursionCounter = oldRecursionCount;
@ -1479,7 +1518,7 @@ AstType* Parser::parseTypeAnnotation()
// typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}'
// | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstType* Parser::parseSimpleTypeAnnotation()
AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
{
incrementRecursionCounter("type annotation");
@ -1488,7 +1527,7 @@ AstType* Parser::parseSimpleTypeAnnotation()
if (lexer.current().type == Lexeme::ReservedNil)
{
nextLexeme();
return allocator.alloc<AstTypeReference>(begin, std::nullopt, nameNil);
return {allocator.alloc<AstTypeReference>(begin, std::nullopt, nameNil), {}};
}
else if (lexer.current().type == Lexeme::Name)
{
@ -1514,22 +1553,41 @@ AstType* Parser::parseSimpleTypeAnnotation()
expectMatchAndConsume(')', typeofBegin);
return allocator.alloc<AstTypeTypeof>(Location(begin, end), expr);
return {allocator.alloc<AstTypeTypeof>(Location(begin, end), expr), {}};
}
AstArray<AstType*> generics = parseTypeParams();
if (FFlag::LuauParseTypePackTypeParameters)
{
bool hasParameters = false;
AstArray<AstTypeOrPack> parameters{};
Location end = lexer.previousLocation();
if (lexer.current().type == '<')
{
hasParameters = true;
parameters = parseTypeParams();
}
return allocator.alloc<AstTypeReference>(Location(begin, end), prefix, name.name, generics);
Location end = lexer.previousLocation();
return {allocator.alloc<AstTypeReference>(Location(begin, end), prefix, name.name, hasParameters, parameters), {}};
}
else
{
AstArray<AstTypeOrPack> generics = parseTypeParams();
Location end = lexer.previousLocation();
// false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks
return {allocator.alloc<AstTypeReference>(Location(begin, end), prefix, name.name, false, generics), {}};
}
}
else if (lexer.current().type == '{')
{
return parseTableTypeAnnotation();
return {parseTableTypeAnnotation(), {}};
}
else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<'))
{
return parseFunctionTypeAnnotation();
return parseFunctionTypeAnnotation(allowPack);
}
else
{
@ -1538,7 +1596,7 @@ AstType* Parser::parseSimpleTypeAnnotation()
// For a missing type annotation, capture 'space' between last token and the next one
location = Location(lexer.previousLocation().end, lexer.current().location.begin);
return reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str());
return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}};
}
}
@ -2312,18 +2370,59 @@ std::pair<AstArray<AstName>, AstArray<AstName>> Parser::parseGenericTypeList()
return {generics, genericPacks};
}
AstArray<AstType*> Parser::parseTypeParams()
AstArray<AstTypeOrPack> Parser::parseTypeParams()
{
TempVector<AstType*> result{scratchAnnotation};
TempVector<AstTypeOrPack> parameters{scratchTypeOrPackAnnotation};
if (lexer.current().type == '<')
{
Lexeme begin = lexer.current();
nextLexeme();
bool seenPack = false;
while (true)
{
result.push_back(parseTypeAnnotation());
if (FFlag::LuauParseTypePackTypeParameters)
{
if (shouldParseTypePackAnnotation(lexer))
{
seenPack = true;
auto typePack = parseTypePackAnnotation();
if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them
parameters.push_back({{}, typePack});
}
else if (lexer.current().type == '(')
{
auto [type, typePack] = parseTypeOrPackAnnotation();
if (typePack)
{
seenPack = true;
if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them
parameters.push_back({{}, typePack});
}
else
{
parameters.push_back({type, {}});
}
}
else if (lexer.current().type == '>' && parameters.empty())
{
break;
}
else
{
parameters.push_back({parseTypeAnnotation(), {}});
}
}
else
{
parameters.push_back({parseTypeAnnotation(), {}});
}
if (lexer.current().type == ',')
nextLexeme();
else
@ -2333,7 +2432,7 @@ AstArray<AstType*> Parser::parseTypeParams()
expectMatchAndConsume('>', begin);
}
return copy(result);
return copy(parameters);
}
AstExpr* Parser::parseString()

251
Ast/src/TimeTrace.cpp Normal file
View file

@ -0,0 +1,251 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TimeTrace.h"
#include "Luau/StringUtils.h"
#include <mutex>
#include <string>
#include <stdlib.h>
#ifdef _WIN32
#include <Windows.h>
#endif
#ifdef __APPLE__
#include <mach/mach.h>
#include <mach/mach_time.h>
#endif
#include <time.h>
LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false)
#if defined(LUAU_ENABLE_TIME_TRACE)
namespace Luau
{
namespace TimeTrace
{
static double getClockPeriod()
{
#if defined(_WIN32)
LARGE_INTEGER result = {};
QueryPerformanceFrequency(&result);
return 1.0 / double(result.QuadPart);
#elif defined(__APPLE__)
mach_timebase_info_data_t result = {};
mach_timebase_info(&result);
return double(result.numer) / double(result.denom) * 1e-9;
#elif defined(__linux__)
return 1e-9;
#else
return 1.0 / double(CLOCKS_PER_SEC);
#endif
}
static double getClockTimestamp()
{
#if defined(_WIN32)
LARGE_INTEGER result = {};
QueryPerformanceCounter(&result);
return double(result.QuadPart);
#elif defined(__APPLE__)
return double(mach_absolute_time());
#elif defined(__linux__)
timespec now;
clock_gettime(CLOCK_MONOTONIC, &now);
return now.tv_sec * 1e9 + now.tv_nsec;
#else
return double(clock());
#endif
}
uint32_t getClockMicroseconds()
{
static double period = getClockPeriod() * 1e6;
static double start = getClockTimestamp();
return uint32_t((getClockTimestamp() - start) * period);
}
struct GlobalContext
{
GlobalContext() = default;
~GlobalContext()
{
// Ideally we would want all ThreadContext destructors to run
// But in VS, not all thread_local object instances are destroyed
for (ThreadContext* context : threads)
{
if (!context->events.empty())
context->flushEvents();
}
if (traceFile)
fclose(traceFile);
}
std::mutex mutex;
std::vector<ThreadContext*> threads;
uint32_t nextThreadId = 0;
std::vector<Token> tokens;
FILE* traceFile = nullptr;
};
GlobalContext& getGlobalContext()
{
static GlobalContext context;
return context;
}
uint16_t createToken(GlobalContext& context, const char* name, const char* category)
{
std::scoped_lock lock(context.mutex);
LUAU_ASSERT(context.tokens.size() < 64 * 1024);
context.tokens.push_back({name, category});
return uint16_t(context.tokens.size() - 1);
}
uint32_t createThread(GlobalContext& context, ThreadContext* threadContext)
{
std::scoped_lock lock(context.mutex);
context.threads.push_back(threadContext);
return ++context.nextThreadId;
}
void releaseThread(GlobalContext& context, ThreadContext* threadContext)
{
std::scoped_lock lock(context.mutex);
if (auto it = std::find(context.threads.begin(), context.threads.end(), threadContext); it != context.threads.end())
context.threads.erase(it);
}
void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Event>& events, const std::vector<char>& data)
{
std::scoped_lock lock(context.mutex);
if (!context.traceFile)
{
context.traceFile = fopen("trace.json", "w");
if (!context.traceFile)
return;
fprintf(context.traceFile, "[\n");
}
std::string temp;
const unsigned tempReserve = 64 * 1024;
temp.reserve(tempReserve);
const char* rawData = data.data();
// Formatting state
bool unfinishedEnter = false;
bool unfinishedArgs = false;
for (const Event& ev : events)
{
switch (ev.type)
{
case EventType::Enter:
{
if (unfinishedArgs)
{
formatAppend(temp, "}");
unfinishedArgs = false;
}
if (unfinishedEnter)
{
formatAppend(temp, "},\n");
unfinishedEnter = false;
}
Token& token = context.tokens[ev.token];
formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category,
ev.data.microsec, threadId);
unfinishedEnter = true;
}
break;
case EventType::Leave:
if (unfinishedArgs)
{
formatAppend(temp, "}");
unfinishedArgs = false;
}
if (unfinishedEnter)
{
formatAppend(temp, "},\n");
unfinishedEnter = false;
}
formatAppend(temp,
R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)"
"\n",
ev.data.microsec, threadId);
break;
case EventType::ArgName:
LUAU_ASSERT(unfinishedEnter);
if (!unfinishedArgs)
{
formatAppend(temp, R"(, "args": { "%s": )", rawData + ev.data.dataPos);
unfinishedArgs = true;
}
else
{
formatAppend(temp, R"(, "%s": )", rawData + ev.data.dataPos);
}
break;
case EventType::ArgValue:
LUAU_ASSERT(unfinishedArgs);
formatAppend(temp, R"("%s")", rawData + ev.data.dataPos);
break;
}
// Don't want to hit the string capacity and reallocate
if (temp.size() > tempReserve - 1024)
{
fwrite(temp.data(), 1, temp.size(), context.traceFile);
temp.clear();
}
}
if (unfinishedArgs)
{
formatAppend(temp, "}");
unfinishedArgs = false;
}
if (unfinishedEnter)
{
formatAppend(temp, "},\n");
unfinishedEnter = false;
}
fwrite(temp.data(), 1, temp.size(), context.traceFile);
fflush(context.traceFile);
}
ThreadContext& getThreadContext()
{
thread_local ThreadContext context;
return context;
}
std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeData(const char* name, const char* category)
{
uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category);
return {token, Luau::TimeTrace::getThreadContext()};
}
} // namespace TimeTrace
} // namespace Luau
#endif

View file

@ -111,11 +111,24 @@ struct CliFileResolver : Luau::FileResolver
return Luau::SourceCode{*source, Luau::SourceCode::Module};
}
std::optional<Luau::ModuleInfo> resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override
{
if (Luau::AstExprConstantString* expr = node->as<Luau::AstExprConstantString>())
{
Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua";
return {{name}};
}
return std::nullopt;
}
bool moduleExists(const Luau::ModuleName& name) const override
{
return !!readFile(name);
}
std::optional<Luau::ModuleName> fromAstFragment(Luau::AstExpr* expr) const override
{
return std::nullopt;
@ -130,11 +143,6 @@ struct CliFileResolver : Luau::FileResolver
{
return std::nullopt;
}
std::optional<std::string> getEnvironmentForModule(const Luau::ModuleName& name) const override
{
return std::nullopt;
}
};
struct CliConfigResolver : Luau::ConfigResolver

View file

@ -179,7 +179,7 @@ static std::string runCode(lua_State* L, const std::string& source)
{
error = "thread yielded unexpectedly";
}
else if (const char* str = lua_tostring(L, -1))
else if (const char* str = lua_tostring(T, -1))
{
error = str;
}

View file

@ -4,6 +4,7 @@
#include "Luau/Parser.h"
#include "Luau/BytecodeBuilder.h"
#include "Luau/Common.h"
#include "Luau/TimeTrace.h"
#include <algorithm>
#include <bitset>
@ -137,6 +138,11 @@ struct Compiler
uint32_t compileFunction(AstExprFunction* func)
{
LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler");
if (func->debugname.value)
LUAU_TIMETRACE_ARGUMENT("name", func->debugname.value);
LUAU_ASSERT(!functions.contains(func));
LUAU_ASSERT(regTop == 0 && stackSize == 0 && localStack.empty() && upvals.empty());
@ -3686,6 +3692,8 @@ struct Compiler
void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options)
{
LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler");
Compiler compiler(bytecode, options);
// since access to some global objects may result in values that change over time, we block table imports
@ -3748,6 +3756,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const
std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder)
{
LUAU_TIMETRACE_SCOPE("compile", "Compiler");
Allocator allocator;
AstNameTable names(allocator);
ParseResult result = Parser::parse(source.c_str(), source.size(), names, allocator, parseOptions);

View file

@ -1,4 +1,5 @@
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
.SUFFIXES:
MAKEFLAGS+=-r -j8
COMMA=,

View file

@ -22,7 +22,7 @@ You can download the binaries from [a recent release](https://github.com/Roblox/
# Building
To build Luau tools or tests yourself, you can use CMake on all platforms, or alternatively make (on Linux/macOS). For example:
To build Luau tools or tests yourself, you can use CMake on all platforms:
```sh
mkdir cmake && cd cmake
@ -31,6 +31,12 @@ cmake --build . --target Luau.Repl.CLI --config RelWithDebInfo
cmake --build . --target Luau.Analyze.CLI --config RelWithDebInfo
```
Alternatively, on Linus/macOS you can use make:
```sh
make config=release luau luau-analyze
```
To integrate Luau into your CMake application projects, at the minimum you'll need to depend on `Luau.Compiler` and `Luau.VM` projects. From there you need to create a new Luau state (using Lua 5.x API such as `lua_newstate`), compile source to bytecode and load it into the VM like this:
```cpp

View file

@ -9,6 +9,7 @@ target_sources(Luau.Ast PRIVATE
Ast/include/Luau/ParseOptions.h
Ast/include/Luau/Parser.h
Ast/include/Luau/StringUtils.h
Ast/include/Luau/TimeTrace.h
Ast/src/Ast.cpp
Ast/src/Confusables.cpp
@ -16,6 +17,7 @@ target_sources(Luau.Ast PRIVATE
Ast/src/Location.cpp
Ast/src/Parser.cpp
Ast/src/StringUtils.cpp
Ast/src/TimeTrace.cpp
)
# Luau.Compiler Sources
@ -44,8 +46,10 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Module.h
Analysis/include/Luau/ModuleResolver.h
Analysis/include/Luau/Predicate.h
Analysis/include/Luau/Quantify.h
Analysis/include/Luau/RecursionCounter.h
Analysis/include/Luau/RequireTracer.h
Analysis/include/Luau/Scope.h
Analysis/include/Luau/Substitution.h
Analysis/include/Luau/Symbol.h
Analysis/include/Luau/TopoSortStatements.h
@ -60,6 +64,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/TypeVar.h
Analysis/include/Luau/Unifiable.h
Analysis/include/Luau/Unifier.h
Analysis/include/Luau/UnifierSharedState.h
Analysis/include/Luau/Variant.h
Analysis/include/Luau/VisitTypeVar.h
@ -74,7 +79,9 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Linter.cpp
Analysis/src/Module.cpp
Analysis/src/Predicate.cpp
Analysis/src/Quantify.cpp
Analysis/src/RequireTracer.cpp
Analysis/src/Scope.cpp
Analysis/src/Substitution.cpp
Analysis/src/Symbol.cpp
Analysis/src/TopoSortStatements.cpp
@ -188,6 +195,7 @@ if(TARGET Luau.UnitTest)
tests/TopoSort.test.cpp
tests/ToString.test.cpp
tests/Transpiler.test.cpp
tests/TypeInfer.aliases.test.cpp
tests/TypeInfer.annotations.test.cpp
tests/TypeInfer.builtins.test.cpp
tests/TypeInfer.classes.test.cpp

View file

@ -13,8 +13,6 @@
#include <string.h>
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n"
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
"$URL: www.lua.org $\n";
@ -1153,7 +1151,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*))
luaC_checkGC(L);
luaC_checkthreadsleep(L);
Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR);
memcpy(u->data + sz, &dtor, sizeof(dtor));
memcpy(&u->data + sz, &dtor, sizeof(dtor));
setuvalue(L, L->top, u);
api_incr_top(L);
return u->data;

View file

@ -5,8 +5,6 @@
#include "lstate.h"
#include "lvm.h"
LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
#define CO_RUN 0 /* running */
#define CO_SUS 1 /* suspended */
#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */
@ -17,7 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
static const char* const statnames[] = {"running", "suspended", "normal", "dead"};
static int costatus(lua_State* L, lua_State* co)
static int auxstatus(lua_State* L, lua_State* co)
{
if (co == L)
return CO_RUN;
@ -34,11 +32,11 @@ static int costatus(lua_State* L, lua_State* co)
return CO_SUS; /* initial state */
}
static int luaB_costatus(lua_State* L)
static int costatus(lua_State* L)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
lua_pushstring(L, statnames[costatus(L, co)]);
lua_pushstring(L, statnames[auxstatus(L, co)]);
return 1;
}
@ -47,7 +45,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg)
// error handling for edge cases
if (co->status != LUA_YIELD)
{
int status = costatus(L, co);
int status = auxstatus(L, co);
if (status != CO_SUS)
{
lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]);
@ -115,7 +113,7 @@ static int auxresumecont(lua_State* L, lua_State* co)
}
}
static int luaB_coresumefinish(lua_State* L, int r)
static int coresumefinish(lua_State* L, int r)
{
if (r < 0)
{
@ -131,7 +129,7 @@ static int luaB_coresumefinish(lua_State* L, int r)
}
}
static int luaB_coresumey(lua_State* L)
static int coresumey(lua_State* L)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
@ -141,10 +139,10 @@ static int luaB_coresumey(lua_State* L)
if (r == CO_STATUS_BREAK)
return interruptThread(L, co);
return luaB_coresumefinish(L, r);
return coresumefinish(L, r);
}
static int luaB_coresumecont(lua_State* L, int status)
static int coresumecont(lua_State* L, int status)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
@ -155,10 +153,10 @@ static int luaB_coresumecont(lua_State* L, int status)
int r = auxresumecont(L, co);
return luaB_coresumefinish(L, r);
return coresumefinish(L, r);
}
static int luaB_auxwrapfinish(lua_State* L, int r)
static int auxwrapfinish(lua_State* L, int r)
{
if (r < 0)
{
@ -173,7 +171,7 @@ static int luaB_auxwrapfinish(lua_State* L, int r)
return r;
}
static int luaB_auxwrapy(lua_State* L)
static int auxwrapy(lua_State* L)
{
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
int narg = cast_int(L->top - L->base);
@ -182,10 +180,10 @@ static int luaB_auxwrapy(lua_State* L)
if (r == CO_STATUS_BREAK)
return interruptThread(L, co);
return luaB_auxwrapfinish(L, r);
return auxwrapfinish(L, r);
}
static int luaB_auxwrapcont(lua_State* L, int status)
static int auxwrapcont(lua_State* L, int status)
{
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
@ -195,62 +193,52 @@ static int luaB_auxwrapcont(lua_State* L, int status)
int r = auxresumecont(L, co);
return luaB_auxwrapfinish(L, r);
return auxwrapfinish(L, r);
}
static int luaB_cocreate(lua_State* L)
static int cocreate(lua_State* L)
{
luaL_checktype(L, 1, LUA_TFUNCTION);
lua_State* NL = lua_newthread(L);
if (FFlag::LuauPreferXpush)
{
lua_xpush(L, NL, 1); // push function on top of NL
}
else
{
lua_pushvalue(L, 1); /* move function to top */
lua_xmove(L, NL, 1); /* move function from L to NL */
}
lua_xpush(L, NL, 1); // push function on top of NL
return 1;
}
static int luaB_cowrap(lua_State* L)
static int cowrap(lua_State* L)
{
luaB_cocreate(L);
cocreate(L);
lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont);
lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont);
return 1;
}
static int luaB_yield(lua_State* L)
static int coyield(lua_State* L)
{
int nres = cast_int(L->top - L->base);
return lua_yield(L, nres);
}
static int luaB_corunning(lua_State* L)
static int corunning(lua_State* L)
{
if (lua_pushthread(L))
lua_pushnil(L); /* main thread is not a coroutine */
return 1;
}
static int luaB_yieldable(lua_State* L)
static int coyieldable(lua_State* L)
{
lua_pushboolean(L, lua_isyieldable(L));
return 1;
}
static const luaL_Reg co_funcs[] = {
{"create", luaB_cocreate},
{"running", luaB_corunning},
{"status", luaB_costatus},
{"wrap", luaB_cowrap},
{"yield", luaB_yield},
{"isyieldable", luaB_yieldable},
{"create", cocreate},
{"running", corunning},
{"status", costatus},
{"wrap", cowrap},
{"yield", coyield},
{"isyieldable", coyieldable},
{NULL, NULL},
};
@ -258,7 +246,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L)
{
luaL_register(L, LUA_COLIBNAME, co_funcs);
lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont);
lua_pushcfunction(L, coresumey, "resume", 0, coresumecont);
lua_setfield(L, -2, "resume");
return 1;

View file

@ -8,12 +8,17 @@
#include "lmem.h"
#include "lvm.h"
#include <stdexcept>
#if LUA_USE_LONGJMP
#include <setjmp.h>
#include <stdlib.h>
#else
#include <stdexcept>
#endif
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false)
LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false)
/*
** {======================================================
@ -51,8 +56,8 @@ l_noret luaD_throw(lua_State* L, int errcode)
longjmp(jb->buf, 1);
}
if (L->global->panic)
L->global->panic(L, errcode);
if (L->global->cb.panic)
L->global->cb.panic(L, errcode);
abort();
}
@ -532,6 +537,12 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
status = LUA_ERRERR;
}
if (FFlag::LuauCcallRestoreFix)
{
// Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored.
L->nCcalls = oldnCcalls;
}
// an error occurred, check if we have a protected error callback
if (L->global->cb.debugprotectederror)
{
@ -545,7 +556,10 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
StkId oldtop = restorestack(L, old_top);
luaF_close(L, oldtop); /* close eventual pending closures */
seterrorobj(L, status, oldtop);
L->nCcalls = oldnCcalls;
if (!FFlag::LuauCcallRestoreFix)
{
L->nCcalls = oldnCcalls;
}
L->ci = restoreci(L, old_ci);
L->base = L->ci->base;
restore_stack_limit(L);

View file

@ -12,10 +12,10 @@
#include <string.h>
#include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false)
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false)
LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false)
LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false)
LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false)
LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false)
LUAU_FASTFLAG(LuauArrayBoundary)
#define GC_SWEEPMAX 40
@ -64,13 +64,18 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds,
g->gcstats.currcycle.marktime += seconds;
// atomic step had to be performed during the switch and it's tracked separately
if (g->gcstate == GCSsweepstring)
if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring)
g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime;
break;
case GCSatomic:
g->gcstats.currcycle.atomictime += seconds;
break;
case GCSsweepstring:
case GCSsweep:
g->gcstats.currcycle.sweeptime += seconds;
break;
default:
LUAU_ASSERT(!"Unexpected GC state");
}
if (assist)
@ -181,33 +186,15 @@ static int traversetable(global_State* g, Table* h)
if (h->metatable)
markobject(g, cast_to(Table*, h->metatable));
if (FFlag::LuauShrinkWeakTables)
/* is there a weak mode? */
if (const char* modev = gettablemode(g, h))
{
/* is there a weak mode? */
if (const char* modev = gettablemode(g, h))
{
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
}
}
else
{
const TValue* mode = gfasttm(g, h->metatable, TM_MODE);
if (mode && ttisstring(mode))
{ /* is there a weak mode? */
const char* modev = svalue(mode);
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
}
@ -295,7 +282,7 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack)
for (StkId o = l->stack; o < l->top; o++)
markvalue(g, o);
/* final traversal? */
if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack))
if (g->gcstate == GCSatomic || clearstack)
{
StkId stack_end = l->stack + l->stacksize;
for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */
@ -334,28 +321,16 @@ static size_t propagatemark(global_State* g)
lua_State* th = gco2th(o);
g->gray = th->gclist;
if (FFlag::LuauGcFullSkipInactiveThreads)
LUAU_ASSERT(!luaC_threadsleeping(th));
// threads that are executing and the main thread are not deactivated
bool active = luaC_threadactive(th) || th == th->global->mainthread;
if (!active && g->gcstate == GCSpropagate)
{
LUAU_ASSERT(!luaC_threadsleeping(th));
traversestack(g, th, /* clearstack= */ true);
// threads that are executing and the main thread are not deactivated
bool active = luaC_threadactive(th) || th == th->global->mainthread;
if (!active && g->gcstate == GCSpropagate)
{
traversestack(g, th, /* clearstack= */ true);
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
}
else
{
th->gclist = g->grayagain;
g->grayagain = o;
black2gray(o);
traversestack(g, th, /* clearstack= */ false);
}
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
}
else
{
@ -383,12 +358,14 @@ static size_t propagatemark(global_State* g)
}
}
static void propagateall(global_State* g)
static size_t propagateall(global_State* g)
{
size_t work = 0;
while (g->gray)
{
propagatemark(g);
work += propagatemark(g);
}
return work;
}
/*
@ -413,11 +390,14 @@ static int isobjcleared(GCObject* o)
/*
** clear collected entries from weaktables
*/
static void cleartable(lua_State* L, GCObject* l)
static size_t cleartable(lua_State* L, GCObject* l)
{
size_t work = 0;
while (l)
{
Table* h = gco2h(l);
work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h);
int i = h->sizearray;
while (i--)
{
@ -431,50 +411,36 @@ static void cleartable(lua_State* L, GCObject* l)
{
LuaNode* n = gnode(h, i);
if (FFlag::LuauShrinkWeakTables)
// non-empty entry?
if (!ttisnil(gval(n)))
{
// non-empty entry?
if (!ttisnil(gval(n)))
{
// can we clear key or value?
if (iscleared(gkey(n)) || iscleared(gval(n)))
{
setnilvalue(gval(n)); /* remove value ... */
removeentry(n); /* remove entry from table */
}
else
{
activevalues++;
}
}
}
else
{
if (!ttisnil(gval(n)) && /* non-empty entry? */
(iscleared(gkey(n)) || iscleared(gval(n))))
// can we clear key or value?
if (iscleared(gkey(n)) || iscleared(gval(n)))
{
setnilvalue(gval(n)); /* remove value ... */
removeentry(n); /* remove entry from table */
}
else
{
activevalues++;
}
}
}
if (FFlag::LuauShrinkWeakTables)
if (const char* modev = gettablemode(L->global, h))
{
if (const char* modev = gettablemode(L->global, h))
// are we allowed to shrink this weak table?
if (strchr(modev, 's'))
{
// are we allowed to shrink this weak table?
if (strchr(modev, 's'))
{
// shrink at 37.5% occupancy
if (activevalues < sizenode(h) * 3 / 8)
luaH_resizehash(L, h, activevalues);
}
// shrink at 37.5% occupancy
if (activevalues < sizenode(h) * 3 / 8)
luaH_resizehash(L, h, activevalues);
}
}
l = h->gclist;
}
return work;
}
static void shrinkstack(lua_State* L)
@ -653,37 +619,49 @@ static void markroot(lua_State* L)
g->gcstate = GCSpropagate;
}
static void remarkupvals(global_State* g)
static size_t remarkupvals(global_State* g)
{
UpVal* uv;
for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
size_t work = 0;
for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
{
work += sizeof(UpVal);
LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv);
if (isgray(obj2gco(uv)))
markvalue(g, uv->v);
}
return work;
}
static void atomic(lua_State* L)
static size_t atomic(lua_State* L)
{
global_State* g = L->global;
g->gcstate = GCSatomic;
size_t work = 0;
if (FFlag::LuauSeparateAtomic)
{
LUAU_ASSERT(g->gcstate == GCSatomic);
}
else
{
g->gcstate = GCSatomic;
}
/* remark occasional upvalues of (maybe) dead threads */
remarkupvals(g);
work += remarkupvals(g);
/* traverse objects caught by write barrier and by 'remarkupvals' */
propagateall(g);
work += propagateall(g);
/* remark weak tables */
g->gray = g->weak;
g->weak = NULL;
LUAU_ASSERT(!iswhite(obj2gco(g->mainthread)));
markobject(g, L); /* mark running thread */
markmt(g); /* mark basic metatables (again) */
propagateall(g);
work += propagateall(g);
/* remark gray again */
g->gray = g->grayagain;
g->grayagain = NULL;
propagateall(g);
cleartable(L, g->weak); /* remove collected objects from weak tables */
work += propagateall(g);
work += cleartable(L, g->weak); /* remove collected objects from weak tables */
g->weak = NULL;
/* flip current white */
g->currentwhite = cast_byte(otherwhite(g));
@ -691,7 +669,12 @@ static void atomic(lua_State* L)
g->sweepgc = &g->rootgc;
g->gcstate = GCSsweepstring;
GC_INTERRUPT(GCSatomic);
if (!FFlag::LuauSeparateAtomic)
{
GC_INTERRUPT(GCSatomic);
}
return work;
}
static size_t singlestep(lua_State* L)
@ -703,46 +686,24 @@ static size_t singlestep(lua_State* L)
case GCSpause:
{
markroot(L); /* start a new collection */
LUAU_ASSERT(g->gcstate == GCSpropagate);
break;
}
case GCSpropagate:
{
if (FFlag::LuauRescanGrayAgain)
if (g->gray)
{
if (g->gray)
{
g->gcstats.currcycle.markitems++;
g->gcstats.currcycle.markitems++;
cost = propagatemark(g);
}
else
{
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
g->gcstate = GCSpropagateagain;
}
cost = propagatemark(g);
}
else
{
if (g->gray)
{
g->gcstats.currcycle.markitems++;
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
cost = propagatemark(g);
}
else /* no more `gray' objects */
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
g->gcstate = GCSpropagateagain;
}
break;
}
@ -756,17 +717,34 @@ static size_t singlestep(lua_State* L)
}
else /* no more `gray' objects */
{
double starttimestamp = lua_clock();
if (FFlag::LuauSeparateAtomic)
{
g->gcstate = GCSatomic;
}
else
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
}
break;
}
case GCSatomic:
{
g->gcstats.currcycle.atomicstarttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
cost = atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
break;
}
case GCSsweepstring:
{
size_t traversedcount = 0;
@ -804,12 +782,133 @@ static size_t singlestep(lua_State* L)
break;
}
default:
LUAU_ASSERT(0);
LUAU_ASSERT(!"Unexpected GC state");
}
return cost;
}
static size_t gcstep(lua_State* L, size_t limit)
{
size_t cost = 0;
global_State* g = L->global;
switch (g->gcstate)
{
case GCSpause:
{
markroot(L); /* start a new collection */
LUAU_ASSERT(g->gcstate == GCSpropagate);
break;
}
case GCSpropagate:
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
if (!g->gray)
{
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
g->gcstate = GCSpropagateagain;
}
break;
}
case GCSpropagateagain:
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
if (!g->gray) /* no more `gray' objects */
{
if (FFlag::LuauSeparateAtomic)
{
g->gcstate = GCSatomic;
}
else
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
}
break;
}
case GCSatomic:
{
g->gcstats.currcycle.atomicstarttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
cost = atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
break;
}
case GCSsweepstring:
{
while (g->sweepstrgc < g->strt.size && cost < limit)
{
size_t traversedcount = 0;
sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount);
g->gcstats.currcycle.sweepitems += traversedcount;
cost += GC_SWEEPCOST;
}
// nothing more to sweep?
if (g->sweepstrgc >= g->strt.size)
{
// sweep string buffer list and preserve used string count
uint32_t nuse = L->global->strt.nuse;
size_t traversedcount = 0;
sweepwholelist(L, &g->strbufgc, &traversedcount);
L->global->strt.nuse = nuse;
g->gcstats.currcycle.sweepitems += traversedcount;
g->gcstate = GCSsweep; // end sweep-string phase
}
break;
}
case GCSsweep:
{
while (*g->sweepgc && cost < limit)
{
size_t traversedcount = 0;
g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount);
g->gcstats.currcycle.sweepitems += traversedcount;
cost += GC_SWEEPMAX * GC_SWEEPCOST;
}
if (*g->sweepgc == NULL)
{ /* nothing more to sweep? */
shrinkbuffers(L);
g->gcstate = GCSpause; /* end collection */
}
break;
}
default:
LUAU_ASSERT(!"Unexpected GC state");
}
return cost;
}
static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCycleStats* cyclestats)
{
// adjust for error using Proportional-Integral controller
@ -878,33 +977,40 @@ void luaC_step(lua_State* L, bool assist)
if (g->gcstate == GCSpause)
startGcCycleStats(g);
if (assist)
g->gcstats.currcycle.assistwork += lim;
else
g->gcstats.currcycle.explicitwork += lim;
int lastgcstate = g->gcstate;
double lasttimestamp = lua_clock();
// always perform at least one single step
do
if (FFlag::LuauConsolidatedStep)
{
lim -= singlestep(L);
size_t work = gcstep(L, lim);
// if we have switched to a different state, capture the duration of last stage
// this way we reduce the number of timer calls we make
if (lastgcstate != g->gcstate)
if (assist)
g->gcstats.currcycle.assistwork += work;
else
g->gcstats.currcycle.explicitwork += work;
}
else
{
// always perform at least one single step
do
{
GC_INTERRUPT(lastgcstate);
lim -= singlestep(L);
double now = lua_clock();
// if we have switched to a different state, capture the duration of last stage
// this way we reduce the number of timer calls we make
if (lastgcstate != g->gcstate)
{
GC_INTERRUPT(lastgcstate);
recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist);
double now = lua_clock();
lasttimestamp = now;
lastgcstate = g->gcstate;
}
} while (lim > 0 && g->gcstate != GCSpause);
recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist);
lasttimestamp = now;
lastgcstate = g->gcstate;
}
} while (lim > 0 && g->gcstate != GCSpause);
}
recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist);
@ -931,7 +1037,14 @@ void luaC_step(lua_State* L, bool assist)
g->GCthreshold -= debt;
}
GC_INTERRUPT(g->gcstate);
if (FFlag::LuauConsolidatedStep)
{
GC_INTERRUPT(lastgcstate);
}
else
{
GC_INTERRUPT(g->gcstate);
}
}
void luaC_fullgc(lua_State* L)
@ -941,7 +1054,7 @@ void luaC_fullgc(lua_State* L)
if (g->gcstate == GCSpause)
startGcCycleStats(g);
if (g->gcstate <= GCSpropagateagain)
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
{
/* reset sweep marks to sweep all elements (returning them to white) */
g->sweepstrgc = 0;
@ -952,12 +1065,15 @@ void luaC_fullgc(lua_State* L)
g->weak = NULL;
g->gcstate = GCSsweepstring;
}
LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain);
LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep);
/* finish any pending sweep phase */
while (g->gcstate != GCSpause)
{
LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep);
singlestep(L);
if (FFlag::LuauConsolidatedStep)
gcstep(L, SIZE_MAX);
else
singlestep(L);
}
finishGcCycleStats(g);
@ -968,7 +1084,10 @@ void luaC_fullgc(lua_State* L)
markroot(L);
while (g->gcstate != GCSpause)
{
singlestep(L);
if (FFlag::LuauConsolidatedStep)
gcstep(L, SIZE_MAX);
else
singlestep(L);
}
/* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */
shrinkbuffersfull(L);
@ -994,14 +1113,11 @@ void luaC_fullgc(lua_State* L)
void luaC_barrierupval(lua_State* L, GCObject* v)
{
if (FFlag::LuauGcFullSkipInactiveThreads)
{
global_State* g = L->global;
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
global_State* g = L->global;
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
if (keepinvariant(g))
reallymarkobject(g, v);
}
if (keepinvariant(g))
reallymarkobject(g, v);
}
void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v)
@ -1629,7 +1745,7 @@ int64_t luaC_allocationrate(lua_State* L)
global_State* g = L->global;
const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms
if (g->gcstate <= GCSpropagateagain)
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
{
double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp;

View file

@ -6,8 +6,6 @@
#include "lobject.h"
#include "lstate.h"
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
/*
** Possible states of the Garbage Collector
*/
@ -25,7 +23,7 @@ LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
** still-black objects. The invariant is restored when sweep ends and
** all objects are white again.
*/
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain)
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic)
/*
** some useful bit tricks
@ -147,4 +145,4 @@ LUAI_FUNC void luaC_validate(lua_State* L);
LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat));
LUAI_FUNC int64_t luaC_allocationrate(lua_State* L);
LUAI_FUNC void luaC_wakethread(lua_State* L);
LUAI_FUNC const char* luaC_statename(int state);
LUAI_FUNC const char* luaC_statename(int state);

View file

@ -199,7 +199,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass)
if (page->freeNext >= 0)
{
block = page->data + page->freeNext;
block = &page->data + page->freeNext;
ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize);
page->freeNext -= page->blockSize;

View file

@ -226,7 +226,7 @@ void luaS_freeudata(lua_State* L, Udata* u)
void (*dtor)(void*) = nullptr;
if (u->tag == UTAG_IDTOR)
memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor));
memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor));
else if (u->tag)
dtor = L->global->udatagc[u->tag];

View file

@ -9,14 +9,8 @@
#include "ldebug.h"
#include "lvm.h"
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry, false)
LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false)
bool lua_telemetry_table_move_oob_src_from = false;
bool lua_telemetry_table_move_oob_src_to = false;
bool lua_telemetry_table_move_oob_dst = false;
static int foreachi(lua_State* L)
{
luaL_checktype(L, 1, LUA_TTABLE);
@ -202,22 +196,6 @@ static int tmove(lua_State* L)
int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */
luaL_checktype(L, tt, LUA_TTABLE);
if (DFFlag::LuauTableMoveTelemetry)
{
int nf = lua_objlen(L, 1);
int nt = lua_objlen(L, tt);
// source index range must be in bounds in source table unless the table is empty (permits 1..#t moves)
if (!(f == 1 || (f >= 1 && f <= nf)))
lua_telemetry_table_move_oob_src_from = true;
if (!(e == nf || (e >= 1 && e <= nf)))
lua_telemetry_table_move_oob_src_to = true;
// destination index must be in bounds in dest table or be exactly at the first empty element (permits concats)
if (!(t == nt + 1 || (t >= 1 && t <= nt + 1)))
lua_telemetry_table_move_oob_dst = true;
}
if (e >= f)
{ /* otherwise, nothing to move */
luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move");

View file

@ -16,8 +16,6 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauLoopUseSafeenv, false)
// Disable c99-designator to avoid the warning in CGOTO dispatch table
#ifdef __clang__
#if __has_warning("-Wc99-designator")
@ -292,10 +290,6 @@ inline bool luau_skipstep(uint8_t op)
return op == LOP_PREPVARARGS || op == LOP_BREAK;
}
// declared in lbaselib.cpp, needed to support cases when pairs/ipairs have been replaced via setfenv
LUAI_FUNC int luaB_inext(lua_State* L);
LUAI_FUNC int luaB_next(lua_State* L);
template<bool SingleStep>
static void luau_execute(lua_State* L)
{
@ -2223,8 +2217,7 @@ static void luau_execute(lua_State* L)
StkId ra = VM_REG(LUAU_INSN_A(insn));
// fast-path: ipairs/inext
bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_inext;
if (safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0)
if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0)
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
}
@ -2304,8 +2297,7 @@ static void luau_execute(lua_State* L)
StkId ra = VM_REG(LUAU_INSN_A(insn));
// fast-path: pairs/next
bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_next;
if (safeenv && ttistable(ra + 1) && ttisnil(ra + 2))
if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2))
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
}

View file

@ -12,7 +12,32 @@
#include <string.h>
#include <vector>
// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens
template<typename T>
struct TempBuffer
{
lua_State* L;
T* data;
size_t count;
TempBuffer(lua_State* L, size_t count)
: L(L)
, data(luaM_newarray(L, count, T, 0))
, count(count)
{
}
~TempBuffer()
{
luaM_freearray(L, data, count, T, 0);
}
T& operator[](size_t index)
{
LUAU_ASSERT(index < count);
return data[index];
}
};
void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil)
{
@ -67,7 +92,7 @@ static unsigned int readVarInt(const char* data, size_t size, size_t& offset)
return result;
}
static TString* readString(std::vector<TString*>& strings, const char* data, size_t size, size_t& offset)
static TString* readString(TempBuffer<TString*>& strings, const char* data, size_t size, size_t& offset)
{
unsigned int id = readVarInt(data, size, offset);
@ -133,6 +158,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
}
// pause GC for the duration of deserialization - some objects we're creating aren't rooted
// TODO: if an allocation error happens mid-load, we do not unpause GC!
size_t GCthreshold = L->global->GCthreshold;
L->global->GCthreshold = SIZE_MAX;
@ -144,7 +170,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
// string table
unsigned int stringCount = readVarInt(data, size, offset);
std::vector<TString*> strings(stringCount);
TempBuffer<TString*> strings(L, stringCount);
for (unsigned int i = 0; i < stringCount; ++i)
{
@ -156,7 +182,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
// proto table
unsigned int protoCount = readVarInt(data, size, offset);
std::vector<Proto*> protos(protoCount);
TempBuffer<Proto*> protos(L, protoCount);
for (unsigned int i = 0; i < protoCount; ++i)
{
@ -320,6 +346,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
uint32_t mainid = readVarInt(data, size, offset);
Proto* main = protos[mainid];
luaC_checkthreadsleep(L);
Closure* cl = luaF_newLclosure(L, 0, envt, main);
setclvalue(L, L->top, cl);
incr_top(L);

849
bench/tests/chess.lua Normal file
View file

@ -0,0 +1,849 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
local RANKS = "12345678"
local FILES = "abcdefgh"
local PieceSymbols = "PpRrNnBbQqKk"
local UnicodePieces = {"", "", "", "", "", "", "", "", "", "", "", ""}
local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
--
-- Lua 5.2 Compat
--
if not table.create then
function table.create(n, v)
local result = {}
for i=1,n do result[i] = v end
return result
end
end
if not table.move then
function table.move(a, from, to, start, target)
local dx = start - from
for i=from,to do
target[i+dx] = a[i]
end
end
end
--
-- Utils
--
local function square(s)
return RANKS:find(s:sub(2,2)) * 8 + FILES:find(s:sub(1,1)) - 9
end
local function squareName(n)
local file = n % 8
local rank = (n-file)/8
return FILES:sub(file+1,file+1) .. RANKS:sub(rank+1,rank+1)
end
local function moveName(v )
local from = bit32.extract(v, 6, 6)
local to = bit32.extract(v, 0, 6)
local piece = bit32.extract(v, 20, 4)
local captured = bit32.extract(v, 25, 4)
local move = PieceSymbols:sub(piece,piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to)
if bit32.extract(v,14) == 1 then
if to > from then
return "O-O"
else
return "O-O-O"
end
end
local promote = bit32.extract(v,15,4)
if promote ~= 0 then
move = move .. "=" .. PieceSymbols:sub(promote,promote)
end
return move
end
local function ucimove(m)
local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6))
local promote = bit32.extract(m,15,4)
if promote > 0 then
mm = mm .. PieceSymbols:sub(promote,promote):lower()
end
return mm
end
local _utils = {squareName, moveName}
--
-- Bitboards
--
local Bitboard = {}
function Bitboard:toString()
local out = {}
local src = self.h
for x=7,0,-1 do
table.insert(out, RANKS:sub(x+1,x+1))
table.insert(out, " ")
local bit = bit32.lshift(1,(x%4) * 8)
for x=0,7 do
if bit32.band(src, bit) ~= 0 then
table.insert(out, "x ")
else
table.insert(out, "- ")
end
bit = bit32.lshift(bit, 1)
end
if x == 4 then
src = self.l
end
table.insert(out, "\n")
end
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h)
return table.concat(out)
end
function Bitboard.from(l ,h )
return setmetatable({l=l, h=h}, Bitboard)
end
Bitboard.zero = Bitboard.from(0,0)
Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF)
local Rank1 = Bitboard.from(0x000000FF, 0)
local Rank3 = Bitboard.from(0x00FF0000, 0)
local Rank6 = Bitboard.from(0, 0x0000FF00)
local Rank8 = Bitboard.from(0, 0xFF000000)
local FileA = Bitboard.from(0x01010101, 0x01010101)
local FileB = Bitboard.from(0x02020202, 0x02020202)
local FileC = Bitboard.from(0x04040404, 0x04040404)
local FileD = Bitboard.from(0x08080808, 0x08080808)
local FileE = Bitboard.from(0x10101010, 0x10101010)
local FileF = Bitboard.from(0x20202020, 0x20202020)
local FileG = Bitboard.from(0x40404040, 0x40404040)
local FileH = Bitboard.from(0x80808080, 0x80808080)
local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH}
-- These masks are filled out below for all files
local RightMasks = {FileH}
local LeftMasks = {FileA}
local function popcnt32(i)
i = i - bit32.band(bit32.rshift(i,1), 0x55555555)
i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i,2), 0x33333333)
return bit32.rshift(bit32.band(i + bit32.rshift(i,4), 0x0F0F0F0F) * 0x01010101, 24)
end
function Bitboard:up()
return self:lshift(8)
end
function Bitboard:down()
return self:rshift(8)
end
function Bitboard:right()
return self:band(FileH:inverse()):lshift(1)
end
function Bitboard:left()
return self:band(FileA:inverse()):rshift(1)
end
function Bitboard:move(x,y)
local out = self
if x < 0 then out = out:bandnot(RightMasks[-x]):lshift(-x) end
if x > 0 then out = out:bandnot(LeftMasks[x]):rshift(x) end
if y < 0 then out = out:rshift(-8 * y) end
if y > 0 then out = out:lshift(8 * y) end
return out
end
function Bitboard:popcnt()
return popcnt32(self.l) + popcnt32(self.h)
end
function Bitboard:band(other )
return Bitboard.from(bit32.band(self.l,other.l), bit32.band(self.h, other.h))
end
function Bitboard:bandnot(other )
return Bitboard.from(bit32.band(self.l,bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h)))
end
function Bitboard:bandempty(other )
return bit32.band(self.l,other.l) == 0 and bit32.band(self.h, other.h) == 0
end
function Bitboard:bor(other )
return Bitboard.from(bit32.bor(self.l,other.l), bit32.bor(self.h, other.h))
end
function Bitboard:bxor(other )
return Bitboard.from(bit32.bxor(self.l,other.l), bit32.bxor(self.h, other.h))
end
function Bitboard:inverse()
return Bitboard.from(bit32.bxor(self.l,0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF))
end
function Bitboard:empty()
return self.h == 0 and self.l == 0
end
function Bitboard:ctz()
local target = self.l
local offset = 0
local result = 0
if target == 0 then
target = self.h
result = 32
end
if target == 0 then
return 64
end
while bit32.extract(target, offset) == 0 do
offset = offset + 1
end
return result + offset
end
function Bitboard:ctzafter(start)
start = start + 1
if start < 32 then
for i=start,31 do
if bit32.extract(self.l, i) == 1 then return i end
end
end
for i=math.max(32,start),63 do
if bit32.extract(self.h, i-32) == 1 then return i end
end
return 64
end
function Bitboard:lshift(amt)
assert(amt >= 0)
if amt == 0 then return self end
if amt > 31 then
return Bitboard.from(0, bit32.lshift(self.l, amt-31))
end
local l = bit32.lshift(self.l, amt)
local h = bit32.bor(
bit32.lshift(self.h, amt),
bit32.extract(self.l, 32-amt, amt)
)
return Bitboard.from(l, h)
end
function Bitboard:rshift(amt)
assert(amt >= 0)
if amt == 0 then return self end
local h = bit32.rshift(self.h, amt)
local l = bit32.bor(
bit32.rshift(self.l, amt),
bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt)
)
return Bitboard.from(l, h)
end
function Bitboard:index(i)
if i > 31 then
return bit32.extract(self.h, i - 32)
else
return bit32.extract(self.l, i)
end
end
function Bitboard:set(i , v)
if i > 31 then
return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32))
else
return Bitboard.from(bit32.replace(self.l, v, i), self.h)
end
end
function Bitboard:isolate(i)
return self:band(Bitboard.some(i))
end
function Bitboard.some(idx )
return Bitboard.zero:set(idx, 1)
end
Bitboard.__index = Bitboard
Bitboard.__tostring = Bitboard.toString
for i=2,8 do
RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH)
LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA)
end
--
-- Board
--
local Board = {}
function Board.new()
local boards = table.create(12, Bitboard.zero)
boards.ocupied = Bitboard.zero
boards.white = Bitboard.zero
boards.black = Bitboard.zero
boards.unocupied = Bitboard.full
boards.ep = Bitboard.zero
boards.castle = Bitboard.zero
boards.toMove = 1
boards.hm = 0
boards.moves = 0
boards.material = 0
return setmetatable(boards, Board)
end
function Board.fromFen(fen )
local b = Board.new()
local i = 0
local rank = 7
local file = 0
while true do
i = i + 1
local p = fen:sub(i,i)
if p == '/' then
rank = rank - 1
file = 0
elseif tonumber(p) ~= nil then
file = file + tonumber(p)
else
local pidx = PieceSymbols:find(p)
if pidx == nil then break end
b[pidx] = b[pidx]:set(rank*8+file, 1)
file = file + 1
end
end
local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i)
if move == nil then print(fen:sub(i)) end
b.toMove = move == 'w' and 1 or 2
if ep ~= "-" then
b.ep = Bitboard.some(square(ep))
end
if castle ~= "-" then
local oo = Bitboard.zero
if castle:find("K") then
oo = oo:set(7, 1)
end
if castle:find("Q") then
oo = oo:set(0, 1)
end
if castle:find("k") then
oo = oo:set(63, 1)
end
if castle:find("q") then
oo = oo:set(56, 1)
end
b.castle = oo
end
b.hm = hm
b.moves = m
b:updateCache()
return b
end
function Board:index(idx )
if self.white:index(idx) == 1 then
for p=1,12,2 do
if self[p]:index(idx) == 1 then
return p
end
end
else
for p=2,12,2 do
if self[p]:index(idx) == 1 then
return p
end
end
end
return 0
end
function Board:updateCache()
for i=1,11,2 do
self.white = self.white:bor(self[i])
self.black = self.black:bor(self[i+1])
end
self.ocupied = self.black:bor(self.white)
self.unocupied = self.ocupied:inverse()
self.material =
100*self[1]:popcnt() - 100*self[2]:popcnt() +
500*self[3]:popcnt() - 500*self[4]:popcnt() +
300*self[5]:popcnt() - 300*self[6]:popcnt() +
300*self[7]:popcnt() - 300*self[8]:popcnt() +
900*self[9]:popcnt() - 900*self[10]:popcnt()
end
function Board:fen()
local out = {}
local s = 0
local idx = 56
for i=0,63 do
if i % 8 == 0 and i > 0 then
idx = idx - 16
if s > 0 then
table.insert(out, '' .. s)
s = 0
end
table.insert(out, '/')
end
local p = self:index(idx)
if p == 0 then
s = s + 1
else
if s > 0 then
table.insert(out, '' .. s)
s = 0
end
table.insert(out, PieceSymbols:sub(p,p))
end
idx = idx + 1
end
if s > 0 then
table.insert(out, '' .. s)
end
table.insert(out, self.toMove == 1 and ' w ' or ' b ')
if self.castle:empty() then
table.insert(out, '-')
else
if self.castle:index(7) == 1 then table.insert(out, 'K') end
if self.castle:index(0) == 1 then table.insert(out, 'Q') end
if self.castle:index(63) == 1 then table.insert(out, 'k') end
if self.castle:index(56) == 1 then table.insert(out, 'q') end
end
table.insert(out, ' ')
if self.ep:empty() then
table.insert(out, '-')
else
table.insert(out, squareName(self.ep:ctz()))
end
table.insert(out, ' ' .. self.hm)
table.insert(out, ' ' .. self.moves)
return table.concat(out)
end
function Board:pmoves(idx)
return self:generate(idx)
end
function Board:pcaptures(idx)
return self:generate(idx):band(self.ocupied)
end
local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}}
local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}}
local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}}
local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}}
function Board:generate(idx)
local piece = self:index(idx)
local r = Bitboard.some(idx)
local out = Bitboard.zero
local type = bit32.rshift(piece - 1, 1)
local cancapture = piece % 2 == 1 and self.black or self.white
if piece == 0 then return Bitboard.zero end
if type == 0 then
-- Pawn
local d = -(piece*2 - 3)
local movetwo = piece == 1 and Rank3 or Rank6
out = out:bor(r:move(0,d):band(self.unocupied))
out = out:bor(out:band(movetwo):move(0,d):band(self.unocupied))
local captures = r:move(0,d)
captures = captures:right():bor(captures:left())
if not captures:bandempty(self.ep) then
out = out:bor(self.ep)
end
captures = captures:band(cancapture)
out = out:bor(captures)
return out
elseif type == 5 then
-- King
for x=-1,1,1 do
for y = -1,1,1 do
local w = r:move(x,y)
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
end
end
end
elseif type == 2 then
-- Knight
for _,j in ipairs(KNIGHT_MOVES) do
local w = r:move(j[1],j[2])
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
end
end
else
-- Sliders (Rook, Bishop, Queen)
local slides
if type == 1 then
slides = ROOK_SLIDES
elseif type == 3 then
slides = BISHOP_SLIDES
else
slides = QUEEN_SLIDES
end
for _, op in ipairs(slides) do
local w = r
for i=1,7 do
w = w:move(op[1], op[2])
if w:empty() then break end
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
break
end
end
end
end
return out
end
-- 0-5 - From Square
-- 6-11 - To Square
-- 12 - is Check
-- 13 - Is EnPassent
-- 14 - Is Castle
-- 15-19 - Promotion Piece
-- 20-24 - Moved Pice
-- 25-29 - Captured Piece
function Board:toString(mark )
local out = {}
for x=8,1,-1 do
table.insert(out, RANKS:sub(x,x) .. " ")
for y=1,8 do
local n = 8*x+y-9
local i = self:index(n)
if i == 0 then
table.insert(out, '-')
else
-- out = out .. PieceSymbols:sub(i,i)
table.insert(out, UnicodePieces[i])
end
if mark ~= nil and mark:index(n) ~= 0 then
table.insert(out, ')')
elseif mark ~= nil and n < 63 and y < 8 and mark:index(n+1) ~= 0 then
table.insert(out, '(')
else
table.insert(out, ' ')
end
end
table.insert(out, "\n")
end
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n")
return table.concat(out)
end
function Board:moveList()
local tm = self.toMove == 1 and self.white or self.black
local castle_rank = self.toMove == 1 and Rank1 or Rank8
local out = {}
local function emit(id)
if not self:applyMove(id):illegalyChecked() then
table.insert(out, id)
end
end
local cr = tm:band(self.castle):band(castle_rank)
if not cr:empty() then
local p = self.toMove == 1 and 11 or 12
local tcolor = self.toMove == 1 and self.black or self.white
local kidx = self[p]:ctz()
local castle = bit32.replace(0, p, 20, 4)
castle = bit32.replace(castle, kidx, 6, 6)
castle = bit32.replace(castle, 1, 14)
local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank)
local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p])
if
not cr:bandempty(FileA) and
mustbeemptyl:bandempty(self.ocupied) and
not self:isSquareThreatened(cantbethreatened, tcolor)
then
emit(bit32.replace(castle, kidx - 2, 0, 6))
end
local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank)
if
not cr:bandempty(FileH) and
mustbeemptyr:bandempty(self.ocupied) and
not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor)
then
emit(bit32.replace(castle, kidx + 2, 0, 6))
end
end
local sq = tm:ctz()
repeat
local p = self:index(sq)
local moves = self:pmoves(sq)
while not moves:empty() do
local m = moves:ctz()
moves = moves:set(m, 0)
local id = bit32.replace(m, sq, 6, 6)
id = bit32.replace(id, p, 20, 4)
local mbb = Bitboard.some(m)
if not self.ocupied:bandempty(mbb) then
id = bit32.replace(id, self:index(m), 25, 4)
end
-- Check if pawn needs to be promoted
if p == 1 and m >= 8*7 then
for i=3,9,2 do
emit(bit32.replace(id, i, 15, 4))
end
elseif p == 2 and m < 8 then
for i=4,10,2 do
emit(bit32.replace(id, i, 15, 4))
end
else
emit(id)
end
end
sq = tm:ctzafter(sq)
until sq == 64
return out
end
function Board:illegalyChecked()
local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")]
return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black)
end
function Board:isSquareThreatened(target , color )
local tm = color
local sq = tm:ctz()
repeat
local moves = self:pmoves(sq)
if not moves:bandempty(target) then
return true
end
sq = color:ctzafter(sq)
until sq == 64
return false
end
function Board:perft(depth )
if depth == 0 then return 1 end
if depth == 1 then
return #self:moveList()
end
local result = 0
for k,m in ipairs(self:moveList()) do
local c = self:applyMove(m):perft(depth - 1)
if c == 0 then
-- Perft only counts leaf nodes at target depth
-- result = result + 1
else
result = result + c
end
end
return result
end
function Board:applyMove(move )
local out = Board.new()
table.move(self, 1, 12, 1, out)
local from = bit32.extract(move, 6, 6)
local to = bit32.extract(move, 0, 6)
local promote = bit32.extract(move, 15, 4)
local piece = self:index(from)
local captured = self:index(to)
local tom = Bitboard.some(to)
local isCastle = bit32.extract(move, 14)
if piece % 2 == 0 then
out.moves = self.moves + 1
end
if captured == 1 or piece < 3 then
out.hm = 0
else
out.hm = self.hm + 1
end
out.castle = self.castle
out.toMove = self.toMove == 1 and 2 or 1
if isCastle == 1 then
local rank = piece == 11 and Rank1 or Rank8
local colorOffset = piece - 11
out[3 + colorOffset] = out[3 + colorOffset]:bandnot(from < to and FileH or FileA)
out[3 + colorOffset] = out[3 + colorOffset]:bor((from < to and FileF or FileD):band(rank))
out[piece] = (from < to and FileG or FileC):band(rank)
out.castle = out.castle:bandnot(rank)
out:updateCache()
return out
end
if piece < 3 then
local dist = math.abs(to - from)
-- Pawn moved two squares, set ep square
if dist == 16 then
out.ep = Bitboard.some((from + to) / 2)
end
-- Remove enpasent capture
if not tom:bandempty(self.ep) then
if piece == 1 then
out[2] = out[2]:bandnot(self.ep:down())
end
if piece == 2 then
out[1] = out[1]:bandnot(self.ep:up())
end
end
end
if piece == 3 or piece == 4 then
out.castle = out.castle:set(from, 0)
end
if piece > 10 then
local rank = piece == 11 and Rank1 or Rank8
out.castle = out.castle:bandnot(rank)
end
out[piece] = out[piece]:set(from, 0)
if promote == 0 then
out[piece] = out[piece]:set(to, 1)
else
out[promote] = out[promote]:set(to, 1)
end
if captured ~= 0 then
out[captured] = out[captured]:set(to, 0)
end
out:updateCache()
return out
end
Board.__index = Board
Board.__tostring = Board.toString
--
-- Main
--
local failures = 0
local function test(fen, ply, target)
local b = Board.fromFen(fen)
if b:fen() ~= fen then
print("FEN MISMATCH", fen, b:fen())
failures = failures + 1
return
end
local found = b:perft(ply)
if found ~= target then
print(fen, "Found", found, "target", target)
failures = failures + 1
for k,v in pairs(b:moveList()) do
print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1'))
end
--error("Test Failure")
else
print("OK", found, fen)
end
end
-- From https://www.chessprogramming.org/Perft_Results
-- If interpreter, computers, or algorithm gets too fast
-- feel free to go deeper
local testCases = {}
local function addTest(...) table.insert(testCases, {...}) end
addTest(StartingFen, 3, 8902)
addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039)
addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812)
addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467)
addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486)
addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079)
local function chess()
for k,v in ipairs(testCases) do
test(v[1],v[2],v[3])
end
end
bench.runCode(chess, "chess")

View file

@ -1,934 +0,0 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
-- Copyright 2008 the V8 project authors. All rights reserved.
-- Copyright 1996 John Maloney and Mario Wolczko.
-- This program is free software; you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation; either version 2 of the License, or
-- (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program; if not, write to the Free Software
-- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
-- This implementation of the DeltaBlue benchmark is derived
-- from the Smalltalk implementation by John Maloney and Mario
-- Wolczko. Some parts have been translated directly, whereas
-- others have been modified more aggressively to make it feel
-- more like a JavaScript program.
--
-- A JavaScript implementation of the DeltaBlue constraint-solving
-- algorithm, as described in:
--
-- "The DeltaBlue Algorithm: An Incremental Constraint Hierarchy Solver"
-- Bjorn N. Freeman-Benson and John Maloney
-- January 1990 Communications of the ACM,
-- also available as University of Washington TR 89-08-06.
--
-- Beware: this benchmark is written in a grotesque style where
-- the constraint model is built by side-effects from constructors.
-- I've kept it this way to avoid deviating too much from the original
-- implementation.
--
function class(base)
local T = {}
T.__index = T
if base then
T.super = base
setmetatable(T, base)
end
function T.new(...)
local O = {}
setmetatable(O, T)
O:constructor(...)
return O
end
return T
end
local planner
--- O b j e c t M o d e l ---
local function alert (...) print(...) end
local OrderedCollection = class()
function OrderedCollection:constructor()
self.elms = {}
end
function OrderedCollection:add(elm)
self.elms[#self.elms + 1] = elm
end
function OrderedCollection:at (index)
return self.elms[index]
end
function OrderedCollection:size ()
return #self.elms
end
function OrderedCollection:removeFirst ()
local e = self.elms[#self.elms]
self.elms[#self.elms] = nil
return e
end
function OrderedCollection:remove (elm)
local index = 0
local skipped = 0
for i = 1, #self.elms do
local value = self.elms[i]
if value ~= elm then
self.elms[index] = value
index = index + 1
else
skipped = skipped + 1
end
end
local l = #self.elms
for i = 1, skipped do self.elms[l - i + 1] = nil end
end
--
-- S t r e n g t h
--
--
-- Strengths are used to measure the relative importance of constraints.
-- New strengths may be inserted in the strength hierarchy without
-- disrupting current constraints. Strengths cannot be created outside
-- this class, so pointer comparison can be used for value comparison.
--
local Strength = class()
function Strength:constructor(strengthValue, name)
self.strengthValue = strengthValue
self.name = name
end
function Strength.stronger (s1, s2)
return s1.strengthValue < s2.strengthValue
end
function Strength.weaker (s1, s2)
return s1.strengthValue > s2.strengthValue
end
function Strength.weakestOf (s1, s2)
return Strength.weaker(s1, s2) and s1 or s2
end
function Strength.strongest (s1, s2)
return Strength.stronger(s1, s2) and s1 or s2
end
function Strength:nextWeaker ()
local v = self.strengthValue
if v == 0 then return Strength.WEAKEST
elseif v == 1 then return Strength.WEAK_DEFAULT
elseif v == 2 then return Strength.NORMAL
elseif v == 3 then return Strength.STRONG_DEFAULT
elseif v == 4 then return Strength.PREFERRED
elseif v == 5 then return Strength.REQUIRED
end
end
-- Strength constants.
Strength.REQUIRED = Strength.new(0, "required");
Strength.STONG_PREFERRED = Strength.new(1, "strongPreferred");
Strength.PREFERRED = Strength.new(2, "preferred");
Strength.STRONG_DEFAULT = Strength.new(3, "strongDefault");
Strength.NORMAL = Strength.new(4, "normal");
Strength.WEAK_DEFAULT = Strength.new(5, "weakDefault");
Strength.WEAKEST = Strength.new(6, "weakest");
--
-- C o n s t r a i n t
--
--
-- An abstract class representing a system-maintainable relationship
-- (or "constraint") between a set of variables. A constraint supplies
-- a strength instance variable; concrete subclasses provide a means
-- of storing the constrained variables and other information required
-- to represent a constraint.
--
local Constraint = class ()
function Constraint:constructor(strength)
self.strength = strength
end
--
-- Activate this constraint and attempt to satisfy it.
--
function Constraint:addConstraint ()
self:addToGraph()
planner:incrementalAdd(self)
end
--
-- Attempt to find a way to enforce this constraint. If successful,
-- record the solution, perhaps modifying the current dataflow
-- graph. Answer the constraint that this constraint overrides, if
-- there is one, or nil, if there isn't.
-- Assume: I am not already satisfied.
--
function Constraint:satisfy (mark)
self:chooseMethod(mark)
if not self:isSatisfied() then
if self.strength == Strength.REQUIRED then
alert("Could not satisfy a required constraint!")
end
return nil
end
self:markInputs(mark)
local out = self:output()
local overridden = out.determinedBy
if overridden ~= nil then overridden:markUnsatisfied() end
out.determinedBy = self
if not planner:addPropagate(self, mark) then alert("Cycle encountered") end
out.mark = mark
return overridden
end
function Constraint:destroyConstraint ()
if self:isSatisfied()
then planner:incrementalRemove(self)
else self:removeFromGraph()
end
end
--
-- Normal constraints are not input constraints. An input constraint
-- is one that depends on external state, such as the mouse, the
-- keyboard, a clock, or some arbitrary piece of imperative code.
--
function Constraint:isInput ()
return false
end
--
-- U n a r y C o n s t r a i n t
--
--
-- Abstract superclass for constraints having a single possible output
-- variable.
--
local UnaryConstraint = class(Constraint)
function UnaryConstraint:constructor (v, strength)
UnaryConstraint.super.constructor(self, strength)
self.myOutput = v
self.satisfied = false
self:addConstraint()
end
--
-- Adds this constraint to the constraint graph
--
function UnaryConstraint:addToGraph ()
self.myOutput:addConstraint(self)
self.satisfied = false
end
--
-- Decides if this constraint can be satisfied and records that
-- decision.
--
function UnaryConstraint:chooseMethod (mark)
self.satisfied = (self.myOutput.mark ~= mark)
and Strength.stronger(self.strength, self.myOutput.walkStrength);
end
--
-- Returns true if this constraint is satisfied in the current solution.
--
function UnaryConstraint:isSatisfied ()
return self.satisfied;
end
function UnaryConstraint:markInputs (mark)
-- has no inputs
end
--
-- Returns the current output variable.
--
function UnaryConstraint:output ()
return self.myOutput
end
--
-- Calculate the walkabout strength, the stay flag, and, if it is
-- 'stay', the value for the current output of this constraint. Assume
-- this constraint is satisfied.
--
function UnaryConstraint:recalculate ()
self.myOutput.walkStrength = self.strength
self.myOutput.stay = not self:isInput()
if self.myOutput.stay then
self:execute() -- Stay optimization
end
end
--
-- Records that this constraint is unsatisfied
--
function UnaryConstraint:markUnsatisfied ()
self.satisfied = false
end
function UnaryConstraint:inputsKnown ()
return true
end
function UnaryConstraint:removeFromGraph ()
if self.myOutput ~= nil then
self.myOutput:removeConstraint(self)
end
self.satisfied = false
end
--
-- S t a y C o n s t r a i n t
--
--
-- Variables that should, with some level of preference, stay the same.
-- Planners may exploit the fact that instances, if satisfied, will not
-- change their output during plan execution. This is called "stay
-- optimization".
--
local StayConstraint = class(UnaryConstraint)
function StayConstraint:constructor(v, str)
StayConstraint.super.constructor(self, v, str)
end
function StayConstraint:execute ()
-- Stay constraints do nothing
end
--
-- E d i t C o n s t r a i n t
--
--
-- A unary input constraint used to mark a variable that the client
-- wishes to change.
--
local EditConstraint = class (UnaryConstraint)
function EditConstraint:constructor(v, str)
EditConstraint.super.constructor(self, v, str)
end
--
-- Edits indicate that a variable is to be changed by imperative code.
--
function EditConstraint:isInput ()
return true
end
function EditConstraint:execute ()
-- Edit constraints do nothing
end
--
-- B i n a r y C o n s t r a i n t
--
local Direction = {}
Direction.NONE = 0
Direction.FORWARD = 1
Direction.BACKWARD = -1
--
-- Abstract superclass for constraints having two possible output
-- variables.
--
local BinaryConstraint = class(Constraint)
function BinaryConstraint:constructor(var1, var2, strength)
BinaryConstraint.super.constructor(self, strength);
self.v1 = var1
self.v2 = var2
self.direction = Direction.NONE
self:addConstraint()
end
--
-- Decides if this constraint can be satisfied and which way it
-- should flow based on the relative strength of the variables related,
-- and record that decision.
--
function BinaryConstraint:chooseMethod (mark)
if self.v1.mark == mark then
self.direction = (self.v2.mark ~= mark and Strength.stronger(self.strength, self.v2.walkStrength)) and Direction.FORWARD or Direction.NONE
end
if self.v2.mark == mark then
self.direction = (self.v1.mark ~= mark and Strength.stronger(self.strength, self.v1.walkStrength)) and Direction.BACKWARD or Direction.NONE
end
if Strength.weaker(self.v1.walkStrength, self.v2.walkStrength) then
self.direction = Strength.stronger(self.strength, self.v1.walkStrength) and Direction.BACKWARD or Direction.NONE
else
self.direction = Strength.stronger(self.strength, self.v2.walkStrength) and Direction.FORWARD or Direction.BACKWARD
end
end
--
-- Add this constraint to the constraint graph
--
function BinaryConstraint:addToGraph ()
self.v1:addConstraint(self)
self.v2:addConstraint(self)
self.direction = Direction.NONE
end
--
-- Answer true if this constraint is satisfied in the current solution.
--
function BinaryConstraint:isSatisfied ()
return self.direction ~= Direction.NONE
end
--
-- Mark the input variable with the given mark.
--
function BinaryConstraint:markInputs (mark)
self:input().mark = mark
end
--
-- Returns the current input variable
--
function BinaryConstraint:input ()
return (self.direction == Direction.FORWARD) and self.v1 or self.v2
end
--
-- Returns the current output variable
--
function BinaryConstraint:output ()
return (self.direction == Direction.FORWARD) and self.v2 or self.v1
end
--
-- Calculate the walkabout strength, the stay flag, and, if it is
-- 'stay', the value for the current output of this
-- constraint. Assume this constraint is satisfied.
--
function BinaryConstraint:recalculate ()
local ihn = self:input()
local out = self:output()
out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength);
out.stay = ihn.stay
if out.stay then self:execute() end
end
--
-- Record the fact that self constraint is unsatisfied.
--
function BinaryConstraint:markUnsatisfied ()
self.direction = Direction.NONE
end
function BinaryConstraint:inputsKnown (mark)
local i = self:input()
return i.mark == mark or i.stay or i.determinedBy == nil
end
function BinaryConstraint:removeFromGraph ()
if (self.v1 ~= nil) then self.v1:removeConstraint(self) end
if (self.v2 ~= nil) then self.v2:removeConstraint(self) end
self.direction = Direction.NONE
end
--
-- S c a l e C o n s t r a i n t
--
--
-- Relates two variables by the linear scaling relationship: "v2 =
-- (v1 * scale) + offset". Either v1 or v2 may be changed to maintain
-- this relationship but the scale factor and offset are considered
-- read-only.
--
local ScaleConstraint = class (BinaryConstraint)
function ScaleConstraint:constructor(src, scale, offset, dest, strength)
self.direction = Direction.NONE
self.scale = scale
self.offset = offset
ScaleConstraint.super.constructor(self, src, dest, strength)
end
--
-- Adds this constraint to the constraint graph.
--
function ScaleConstraint:addToGraph ()
ScaleConstraint.super.addToGraph(self)
self.scale:addConstraint(self)
self.offset:addConstraint(self)
end
function ScaleConstraint:removeFromGraph ()
ScaleConstraint.super.removeFromGraph(self)
if (self.scale ~= nil) then self.scale:removeConstraint(self) end
if (self.offset ~= nil) then self.offset:removeConstraint(self) end
end
function ScaleConstraint:markInputs (mark)
ScaleConstraint.super.markInputs(self, mark);
self.offset.mark = mark
self.scale.mark = mark
end
--
-- Enforce this constraint. Assume that it is satisfied.
--
function ScaleConstraint:execute ()
if self.direction == Direction.FORWARD then
self.v2.value = self.v1.value * self.scale.value + self.offset.value
else
self.v1.value = (self.v2.value - self.offset.value) / self.scale.value
end
end
--
-- Calculate the walkabout strength, the stay flag, and, if it is
-- 'stay', the value for the current output of this constraint. Assume
-- this constraint is satisfied.
--
function ScaleConstraint:recalculate ()
local ihn = self:input()
local out = self:output()
out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength)
out.stay = ihn.stay and self.scale.stay and self.offset.stay
if out.stay then self:execute() end
end
--
-- E q u a l i t y C o n s t r a i n t
--
--
-- Constrains two variables to have the same value.
--
local EqualityConstraint = class (BinaryConstraint)
function EqualityConstraint:constructor(var1, var2, strength)
EqualityConstraint.super.constructor(self, var1, var2, strength)
end
--
-- Enforce this constraint. Assume that it is satisfied.
--
function EqualityConstraint:execute ()
self:output().value = self:input().value
end
--
-- V a r i a b l e
--
--
-- A constrained variable. In addition to its value, it maintain the
-- structure of the constraint graph, the current dataflow graph, and
-- various parameters of interest to the DeltaBlue incremental
-- constraint solver.
--
local Variable = class ()
function Variable:constructor(name, initialValue)
self.value = initialValue or 0
self.constraints = OrderedCollection.new()
self.determinedBy = nil
self.mark = 0
self.walkStrength = Strength.WEAKEST
self.stay = true
self.name = name
end
--
-- Add the given constraint to the set of all constraints that refer
-- this variable.
--
function Variable:addConstraint (c)
self.constraints:add(c)
end
--
-- Removes all traces of c from this variable.
--
function Variable:removeConstraint (c)
self.constraints:remove(c)
if self.determinedBy == c then
self.determinedBy = nil
end
end
--
-- P l a n n e r
--
--
-- The DeltaBlue planner
--
local Planner = class()
function Planner:constructor()
self.currentMark = 0
end
--
-- Attempt to satisfy the given constraint and, if successful,
-- incrementally update the dataflow graph. Details: If satisfying
-- the constraint is successful, it may override a weaker constraint
-- on its output. The algorithm attempts to resatisfy that
-- constraint using some other method. This process is repeated
-- until either a) it reaches a variable that was not previously
-- determined by any constraint or b) it reaches a constraint that
-- is too weak to be satisfied using any of its methods. The
-- variables of constraints that have been processed are marked with
-- a unique mark value so that we know where we've been. This allows
-- the algorithm to avoid getting into an infinite loop even if the
-- constraint graph has an inadvertent cycle.
--
function Planner:incrementalAdd (c)
local mark = self:newMark()
local overridden = c:satisfy(mark)
while overridden ~= nil do
overridden = overridden:satisfy(mark)
end
end
--
-- Entry point for retracting a constraint. Remove the given
-- constraint and incrementally update the dataflow graph.
-- Details: Retracting the given constraint may allow some currently
-- unsatisfiable downstream constraint to be satisfied. We therefore collect
-- a list of unsatisfied downstream constraints and attempt to
-- satisfy each one in turn. This list is traversed by constraint
-- strength, strongest first, as a heuristic for avoiding
-- unnecessarily adding and then overriding weak constraints.
-- Assume: c is satisfied.
--
function Planner:incrementalRemove (c)
local out = c:output()
c:markUnsatisfied()
c:removeFromGraph()
local unsatisfied = self:removePropagateFrom(out)
local strength = Strength.REQUIRED
repeat
for i = 1, unsatisfied:size() do
local u = unsatisfied:at(i)
if u.strength == strength then
self:incrementalAdd(u)
end
end
strength = strength:nextWeaker()
until strength == Strength.WEAKEST
end
--
-- Select a previously unused mark value.
--
function Planner:newMark ()
self.currentMark = self.currentMark + 1
return self.currentMark
end
--
-- Extract a plan for resatisfaction starting from the given source
-- constraints, usually a set of input constraints. This method
-- assumes that stay optimization is desired; the plan will contain
-- only constraints whose output variables are not stay. Constraints
-- that do no computation, such as stay and edit constraints, are
-- not included in the plan.
-- Details: The outputs of a constraint are marked when it is added
-- to the plan under construction. A constraint may be appended to
-- the plan when all its input variables are known. A variable is
-- known if either a) the variable is marked (indicating that has
-- been computed by a constraint appearing earlier in the plan), b)
-- the variable is 'stay' (i.e. it is a constant at plan execution
-- time), or c) the variable is not determined by any
-- constraint. The last provision is for past states of history
-- variables, which are not stay but which are also not computed by
-- any constraint.
-- Assume: sources are all satisfied.
--
local Plan -- FORWARD DECLARATION
function Planner:makePlan (sources)
local mark = self:newMark()
local plan = Plan.new()
local todo = sources
while todo:size() > 0 do
local c = todo:removeFirst()
if c:output().mark ~= mark and c:inputsKnown(mark) then
plan:addConstraint(c)
c:output().mark = mark
self:addConstraintsConsumingTo(c:output(), todo)
end
end
return plan
end
--
-- Extract a plan for resatisfying starting from the output of the
-- given constraints, usually a set of input constraints.
--
function Planner:extractPlanFromConstraints (constraints)
local sources = OrderedCollection.new()
for i = 1, constraints:size() do
local c = constraints:at(i)
if c:isInput() and c:isSatisfied() then
-- not in plan already and eligible for inclusion
sources:add(c)
end
end
return self:makePlan(sources)
end
--
-- Recompute the walkabout strengths and stay flags of all variables
-- downstream of the given constraint and recompute the actual
-- values of all variables whose stay flag is true. If a cycle is
-- detected, remove the given constraint and answer
-- false. Otherwise, answer true.
-- Details: Cycles are detected when a marked variable is
-- encountered downstream of the given constraint. The sender is
-- assumed to have marked the inputs of the given constraint with
-- the given mark. Thus, encountering a marked node downstream of
-- the output constraint means that there is a path from the
-- constraint's output to one of its inputs.
--
function Planner:addPropagate (c, mark)
local todo = OrderedCollection.new()
todo:add(c)
while todo:size() > 0 do
local d = todo:removeFirst()
if d:output().mark == mark then
self:incrementalRemove(c)
return false
end
d:recalculate()
self:addConstraintsConsumingTo(d:output(), todo)
end
return true
end
--
-- Update the walkabout strengths and stay flags of all variables
-- downstream of the given constraint. Answer a collection of
-- unsatisfied constraints sorted in order of decreasing strength.
--
function Planner:removePropagateFrom (out)
out.determinedBy = nil
out.walkStrength = Strength.WEAKEST
out.stay = true
local unsatisfied = OrderedCollection.new()
local todo = OrderedCollection.new()
todo:add(out)
while todo:size() > 0 do
local v = todo:removeFirst()
for i = 1, v.constraints:size() do
local c = v.constraints:at(i)
if not c:isSatisfied() then unsatisfied:add(c) end
end
local determining = v.determinedBy
for i = 1, v.constraints:size() do
local next = v.constraints:at(i);
if next ~= determining and next:isSatisfied() then
next:recalculate()
todo:add(next:output())
end
end
end
return unsatisfied
end
function Planner:addConstraintsConsumingTo (v, coll)
local determining = v.determinedBy
local cc = v.constraints
for i = 1, cc:size() do
local c = cc:at(i)
if c ~= determining and c:isSatisfied() then
coll:add(c)
end
end
end
--
-- P l a n
--
--
-- A Plan is an ordered list of constraints to be executed in sequence
-- to resatisfy all currently satisfiable constraints in the face of
-- one or more changing inputs.
--
Plan = class()
function Plan:constructor()
self.v = OrderedCollection.new()
end
function Plan:addConstraint (c)
self.v:add(c)
end
function Plan:size ()
return self.v:size()
end
function Plan:constraintAt (index)
return self.v:at(index)
end
function Plan:execute ()
for i = 1, self:size() do
local c = self:constraintAt(i)
c:execute()
end
end
--
-- M a i n
--
--
-- This is the standard DeltaBlue benchmark. A long chain of equality
-- constraints is constructed with a stay constraint on one end. An
-- edit constraint is then added to the opposite end and the time is
-- measured for adding and removing this constraint, and extracting
-- and executing a constraint satisfaction plan. There are two cases.
-- In case 1, the added constraint is stronger than the stay
-- constraint and values must propagate down the entire length of the
-- chain. In case 2, the added constraint is weaker than the stay
-- constraint so it cannot be accommodated. The cost in this case is,
-- of course, very low. Typical situations lie somewhere between these
-- two extremes.
--
local function chainTest(n)
planner = Planner.new()
local prev = nil
local first = nil
local last = nil
-- Build chain of n equality constraints
for i = 0, n do
local name = "v" .. i;
local v = Variable.new(name)
if prev ~= nil then EqualityConstraint.new(prev, v, Strength.REQUIRED) end
if i == 0 then first = v end
if i == n then last = v end
prev = v
end
StayConstraint.new(last, Strength.STRONG_DEFAULT)
local edit = EditConstraint.new(first, Strength.PREFERRED)
local edits = OrderedCollection.new()
edits:add(edit)
local plan = planner:extractPlanFromConstraints(edits)
for i = 0, 99 do
first.value = i
plan:execute()
if last.value ~= i then
alert("Chain test failed.")
end
end
end
local function change(v, newValue)
local edit = EditConstraint.new(v, Strength.PREFERRED)
local edits = OrderedCollection.new()
edits:add(edit)
local plan = planner:extractPlanFromConstraints(edits)
for i = 1, 10 do
v.value = newValue
plan:execute()
end
edit:destroyConstraint()
end
--
-- This test constructs a two sets of variables related to each
-- other by a simple linear transformation (scale and offset). The
-- time is measured to change a variable on either side of the
-- mapping and to change the scale and offset factors.
--
local function projectionTest(n)
planner = Planner.new();
local scale = Variable.new("scale", 10);
local offset = Variable.new("offset", 1000);
local src = nil
local dst = nil;
local dests = OrderedCollection.new();
for i = 0, n - 1 do
src = Variable.new("src" .. i, i);
dst = Variable.new("dst" .. i, i);
dests:add(dst);
StayConstraint.new(src, Strength.NORMAL);
ScaleConstraint.new(src, scale, offset, dst, Strength.REQUIRED);
end
change(src, 17)
if dst.value ~= 1170 then alert("Projection 1 failed") end
change(dst, 1050)
if src.value ~= 5 then alert("Projection 2 failed") end
change(scale, 5)
for i = 0, n - 2 do
if dests:at(i + 1).value ~= i * 5 + 1000 then
alert("Projection 3 failed")
end
end
change(offset, 2000)
for i = 0, n - 2 do
if dests:at(i + 1).value ~= i * 5 + 2000 then
alert("Projection 4 failed")
end
end
end
function test()
local t0 = os.clock()
chainTest(1000);
projectionTest(1000);
local t1 = os.clock()
return t1-t0
end
bench.runCode(test, "deltablue")

View file

@ -21,5 +21,7 @@ pages:
url: /compatibility
- title: Typechecking
url: /typecheck
- title: Profiling
url: /profile
- title: Library
url: /library

59
docs/_pages/profile.md Normal file
View file

@ -0,0 +1,59 @@
---
permalink: /profile
title: Profiling
toc: true
---
One of main goals of Luau is to enable high performance code. To help with that goal, we are relentlessly optimizing the compiler and runtime - but ultimately, performance of their
code is in developers' hands, and is a combination of good algorithm design and implementation that adheres to the strengths of the language. To help write efficient code, Luau
provides a built-in profiler that samples the execution of the program and outputs a profiler dump that can be converted to an interactive flamegraph.
To run the profiler, make sure you have an optimized build of the intepreter (otherwise profiling results are going to be very skewed) and run it with `--profile` argument:
```
$ luau --profile tests/chess.lua
OK 8902 rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
OK 2039 r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0
OK 2812 8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0
OK 9467 r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1
OK 1486 rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8
OK 2079 r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10
Profiler dump written to profile.out (total runtime 2.034 seconds, 20344 samples, 374 stacks)
GC: 0.378 seconds (18.58%), mark 46.80%, remark 3.33%, atomic 1.93%, sweepstring 6.77%, sweep 41.16%
```
The resulting `profile.out` file can be converted to an SVG file by running `perfgraph.py` script that is part of Luau repository:
```
$ python tools/perfgraph.py profile.out >profile.svg
```
This produces an SVG file that can be opened in a browser (the image below is clickable):
[![profile.svg](/assets/images/chess-profile.svg)](/assets/images/chess-profile.svg)
In a flame graph visualization, the individual bars represent function calls, the width represents how much of the total program runtime they execute, and the nesting matches the call stack encountered during program execution. This is a fantastic visualization technique that allows you to hone in on the specific bottlenecks affecting
your program performance, optimize those exact bottlenecks, and then re-generate the profile data and visualizer, and look for the next set of true bottlenecks (if any).
Hovering your mouse cursor over individual sections will display detailed function information in the status bar and in a tooltip. If you want to Search for a specific named
function, use the Search field in the upper right, or press Ctrl+F.
Notice that some of the bars in the screenshot don't have any text. In some cases, there isn't enough room in the size of the bar to display the name.
You can hover your mouse over those bars to see the name and source location of the function in the tool tip, or double-click to zoom in on that part of the flame graph.
Some tooltips will have a source location for the function you're hovering over, but no name. Those are anonymous functions, or functions that were not declared in a way that
allows Luau compiler to track the name. To fill in more names, you may want to make these changes to your code:
`local myFunc = function() --[[ work ]] end` -> `local function myFunc() --[[ work ]] end`
Even without these changes, you can hover over a given bar with no visible name and see it's source location.
As any sampling profiler, this profiler relies on gathering enough information for the resulting output to be statistically meaningful. It may miss short functions if they
aren't called often enough. By default the profiler runs at 10 kHz, this can be customized by passing a different parameter to `--profile=`. Note that higher
frequencies result in higher profiling overhead and longer program execution, potentially skewing the results.
This profiler doesn't track leaf C functions and instead attributes the time spent there to calling Luau functions. As a result, when thinking about why a given function is
slow, consider not just the work it does immediately but also the library functions it calls.
This profiler tracks time consumed by Luau thread stacks; when a thread calls another thread via `coroutine.resume`, the time spent is not attributed to the parent thread that's
waiting for resume results. This limitation will be removed in the future.

File diff suppressed because it is too large Load diff

After

Width:  |  Height:  |  Size: 50 KiB

File diff suppressed because it is too large Load diff

View file

@ -768,11 +768,11 @@ TEST_CASE("CaptureSelf")
local MaterialsListClass = {}
function MaterialsListClass:_MakeToolTip(guiElement, text)
local function updateTooltipPosition()
self._tweakingTooltipFrame = 5
end
local function updateTooltipPosition()
self._tweakingTooltipFrame = 5
end
updateTooltipPosition()
updateTooltipPosition()
end
return MaterialsListClass
@ -2001,14 +2001,14 @@ TEST_CASE("UpvaluesLoopsBytecode")
{
CHECK_EQ("\n" + compileFunction(R"(
function test()
for i=1,10 do
for i=1,10 do
i = i
foo(function() return i end)
if bar then
break
end
end
return 0
foo(function() return i end)
if bar then
break
end
end
return 0
end
)",
1),
@ -2035,14 +2035,14 @@ RETURN R0 1
CHECK_EQ("\n" + compileFunction(R"(
function test()
for i in ipairs(data) do
for i in ipairs(data) do
i = i
foo(function() return i end)
if bar then
break
end
end
return 0
foo(function() return i end)
if bar then
break
end
end
return 0
end
)",
1),
@ -2068,17 +2068,17 @@ RETURN R0 1
CHECK_EQ("\n" + compileFunction(R"(
function test()
local i = 0
while i < 5 do
local j
local i = 0
while i < 5 do
local j
j = i
foo(function() return j end)
i = i + 1
if bar then
break
end
end
return 0
foo(function() return j end)
i = i + 1
if bar then
break
end
end
return 0
end
)",
1),
@ -2105,17 +2105,17 @@ RETURN R1 1
CHECK_EQ("\n" + compileFunction(R"(
function test()
local i = 0
repeat
local j
local i = 0
repeat
local j
j = i
foo(function() return j end)
i = i + 1
if bar then
break
end
until i < 5
return 0
foo(function() return j end)
i = i + 1
if bar then
break
end
until i < 5
return 0
end
)",
1),
@ -2304,10 +2304,10 @@ local Value1, Value2, Value3 = ...
local Table = {}
Table.SubTable["Key"] = {
Key1 = Value1,
Key2 = Value2,
Key3 = Value3,
Key4 = true,
Key1 = Value1,
Key2 = Value2,
Key3 = Value3,
Key4 = true,
}
)");

View file

@ -801,4 +801,17 @@ TEST_CASE("IfElseExpression")
runConformance("ifelseexpr.lua");
}
TEST_CASE("TagMethodError")
{
ScopedFastFlag sff{"LuauCcallRestoreFix", true};
runConformance("tmerror.lua", [](lua_State* L) {
auto* cb = lua_callbacks(L);
cb->debugprotectederror = [](lua_State* L) {
CHECK(lua_isyieldable(L));
};
});
}
TEST_SUITE_END();

View file

@ -32,6 +32,55 @@ std::optional<ModuleName> TestFileResolver::fromAstFragment(AstExpr* expr) const
return std::nullopt;
}
std::optional<ModuleInfo> TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr)
{
if (AstExprGlobal* g = expr->as<AstExprGlobal>())
{
if (g->name == "game")
return ModuleInfo{"game"};
if (g->name == "workspace")
return ModuleInfo{"workspace"};
if (g->name == "script")
return context ? std::optional<ModuleInfo>(*context) : std::nullopt;
}
else if (AstExprIndexName* i = expr->as<AstExprIndexName>(); i && context)
{
if (i->index == "Parent")
{
std::string_view view = context->name;
size_t lastSeparatorIndex = view.find_last_of('/');
if (lastSeparatorIndex == std::string_view::npos)
return std::nullopt;
return ModuleInfo{ModuleName(view.substr(0, lastSeparatorIndex)), context->optional};
}
else
{
return ModuleInfo{context->name + '/' + i->index.value, context->optional};
}
}
else if (AstExprIndexExpr* i = expr->as<AstExprIndexExpr>(); i && context)
{
if (AstExprConstantString* index = i->index->as<AstExprConstantString>())
{
return ModuleInfo{context->name + '/' + std::string(index->value.data, index->value.size), context->optional};
}
}
else if (AstExprCall* call = expr->as<AstExprCall>(); call && call->self && call->args.size >= 1 && context)
{
if (AstExprConstantString* index = call->args.data[0]->as<AstExprConstantString>())
{
AstName func = call->func->as<AstExprIndexName>()->index;
if (func == "GetService" && context->name == "game")
return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)};
}
}
return std::nullopt;
}
ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const
{
return lhs + "/" + ModuleName(rhs);

View file

@ -65,6 +65,8 @@ struct TestFileResolver
}
std::optional<ModuleName> fromAstFragment(AstExpr* expr) const override;
std::optional<ModuleInfo> resolveModule(const ModuleInfo* context, AstExpr* expr) override;
ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override;
std::optional<ModuleName> getParentModuleName(const ModuleName& name) const override;

View file

@ -58,6 +58,35 @@ struct NaiveFileResolver : NullFileResolver
return std::nullopt;
}
std::optional<ModuleInfo> resolveModule(const ModuleInfo* context, AstExpr* expr) override
{
if (AstExprGlobal* g = expr->as<AstExprGlobal>())
{
if (g->name == "Modules")
return ModuleInfo{"Modules"};
if (g->name == "game")
return ModuleInfo{"game"};
}
else if (AstExprIndexName* i = expr->as<AstExprIndexName>())
{
if (context)
return ModuleInfo{context->name + '/' + i->index.value, context->optional};
}
else if (AstExprCall* call = expr->as<AstExprCall>(); call && call->self && call->args.size >= 1 && context)
{
if (AstExprConstantString* index = call->args.data[0]->as<AstExprConstantString>())
{
AstName func = call->func->as<AstExprIndexName>()->index;
if (func == "GetService" && context->name == "game")
return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)};
}
}
return std::nullopt;
}
ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override
{
return lhs + "/" + ModuleName(rhs);
@ -528,7 +557,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "ignore_require_to_nonexistent_file")
{
fileResolver.source["Modules/A"] = R"(
local Modules = script
local B = require(Modules.B :: any)
local B = require(Modules.B) :: any
)";
CheckResult result = frontend.check("Modules/A");

View file

@ -2,6 +2,9 @@
#pragma once
#include <ostream>
#include <optional>
namespace std {
inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
{
@ -9,10 +12,12 @@ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
}
template<typename T>
std::ostream& operator<<(std::ostream& lhs, const std::optional<T>& t)
auto operator<<(std::ostream& lhs, const std::optional<T>& t) -> decltype(lhs << *t) // SFINAE to only instantiate << for supported types
{
if (t)
return lhs << *t;
else
return lhs << "none";
}
} // namespace std

View file

@ -791,13 +791,13 @@ TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings")
{
LintResult result = lint(R"(--!strict
type InputData = {
id: number,
inputType: EnumItem,
inputState: EnumItem,
updated: number,
position: Vector3,
keyCode: EnumItem,
name: string
id: number,
inputType: EnumItem,
inputState: EnumItem,
updated: number,
position: Vector3,
keyCode: EnumItem,
name: string
}
)");
@ -1400,6 +1400,8 @@ end
TEST_CASE_FIXTURE(Fixture, "TableOperations")
{
ScopedFastFlag sff("LuauLinterTableMoveZero", true);
LintResult result = lintTyped(R"(
local t = {}
local tt = {}
@ -1417,9 +1419,12 @@ table.remove(t, 0)
table.remove(t, #t-1)
table.insert(t, string.find("hello", "h"))
table.move(t, 0, #t, 1, tt)
table.move(t, 1, #t, 0, tt)
)");
REQUIRE_EQ(result.warnings.size(), 6);
REQUIRE_EQ(result.warnings.size(), 8);
CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the "
"second argument or wrap it in parentheses to silence");
CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency");
@ -1429,6 +1434,8 @@ table.insert(t, string.find("hello", "h"))
"second argument or wrap it in parentheses to silence");
CHECK_EQ(result.warnings[5].text,
"table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument");
CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
}
TEST_CASE_FIXTURE(Fixture, "DuplicateConditions")

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Fixture.h"

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"

View file

@ -2519,4 +2519,19 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression")
}
}
TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
AstStat* stat = parse(R"(
type Packed<T...> = () -> T...
type A<X...> = Packed<X...>
type B<X...> = Packed<...number>
type C<X...> = Packed<(number, X...)>
)");
REQUIRE(stat != nullptr);
}
TEST_SUITE_END();

View file

@ -57,6 +57,7 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local")
{
AstStatBlock* block = parse(R"(
local m = workspace.Foo.Bar.Baz
require(m)
)");
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
@ -70,22 +71,22 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local")
AstExprIndexName* value = loc->values.data[0]->as<AstExprIndexName>();
REQUIRE(value);
REQUIRE(result.exprs.contains(value));
CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value]);
CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value].name);
value = value->expr->as<AstExprIndexName>();
REQUIRE(value);
REQUIRE(result.exprs.contains(value));
CHECK_EQ("workspace/Foo/Bar", result.exprs[value]);
CHECK_EQ("workspace/Foo/Bar", result.exprs[value].name);
value = value->expr->as<AstExprIndexName>();
REQUIRE(value);
REQUIRE(result.exprs.contains(value));
CHECK_EQ("workspace/Foo", result.exprs[value]);
CHECK_EQ("workspace/Foo", result.exprs[value].name);
AstExprGlobal* workspace = value->expr->as<AstExprGlobal>();
REQUIRE(workspace);
REQUIRE(result.exprs.contains(workspace));
CHECK_EQ("workspace", result.exprs[workspace]);
CHECK_EQ("workspace", result.exprs[workspace].name);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local")
@ -93,9 +94,10 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local")
AstStatBlock* block = parse(R"(
local m = workspace.Foo.Bar.Baz
local n = m.Quux
require(n)
)");
REQUIRE_EQ(2, block->body.size);
REQUIRE_EQ(3, block->body.size);
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
@ -104,13 +106,13 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local")
REQUIRE_EQ(1, local->vars.size);
REQUIRE(result.exprs.contains(local->values.data[0]));
CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]]);
CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]].name);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments")
{
AstStatBlock* block = parse(R"(
local M = require(workspace.Game.Thing, workspace.Something.Else)
local M = require(workspace.Game.Thing)
)");
REQUIRE_EQ(1, block->body.size);
@ -124,52 +126,9 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments")
AstExprCall* call = local->values.data[0]->as<AstExprCall>();
REQUIRE(call != nullptr);
REQUIRE_EQ(2, call->args.size);
CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]]);
CHECK_EQ("workspace/Something/Else", result.exprs[call->args.data[1]]);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "follow_GetService_calls")
{
AstStatBlock* block = parse(R"(
local R = game:GetService('ReplicatedStorage').Roact
local Roact = require(R)
)");
REQUIRE_EQ(2, block->body.size);
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
AstStatLocal* local = block->body.data[0]->as<AstStatLocal>();
REQUIRE(local != nullptr);
CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[local->values.data[0]]);
AstStatLocal* local2 = block->body.data[1]->as<AstStatLocal>();
REQUIRE(local2 != nullptr);
REQUIRE_EQ(1, local2->values.size);
AstExprCall* call = local2->values.data[0]->as<AstExprCall>();
REQUIRE(call != nullptr);
REQUIRE_EQ(1, call->args.size);
CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[call->args.data[0]]);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "follow_WaitForChild_calls")
{
ScopedFastFlag luauTraceRequireLookupChild("LuauTraceRequireLookupChild", true);
AstStatBlock* block = parse(R"(
local A = require(workspace:WaitForChild('ReplicatedStorage').Content)
local B = require(workspace:FindFirstChild('ReplicatedFirst').Data)
)");
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
REQUIRE_EQ(2, result.requires.size());
CHECK_EQ("workspace/ReplicatedStorage/Content", result.requires[0].first);
CHECK_EQ("workspace/ReplicatedFirst/Data", result.requires[1].first);
CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]].name);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof")
@ -200,22 +159,23 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof")
REQUIRE(call != nullptr);
REQUIRE_EQ(1, call->args.size);
CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]]);
CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]].name);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "follow_string_indexexpr")
{
AstStatBlock* block = parse(R"(
local R = game["Test"]
require(R)
)");
REQUIRE_EQ(1, block->body.size);
REQUIRE_EQ(2, block->body.size);
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
AstStatLocal* local = block->body.data[0]->as<AstStatLocal>();
REQUIRE(local != nullptr);
CHECK_EQ("game/Test", result.exprs[local->values.data[0]]);
CHECK_EQ("game/Test", result.exprs[local->values.data[0]].name);
}
TEST_SUITE_END();

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Fixture.h"
@ -416,8 +417,6 @@ function foo(a, b) return a(b) end
TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_TypePack")
{
ScopedFastFlag sff{"LuauToStringFollowsBoundTo", true};
TypeVar tv1{TableTypeVar{}};
TableTypeVar* ttv = getMutable<TableTypeVar>(&tv1);
ttv->state = TableState::Sealed;

View file

@ -0,0 +1,607 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "doctest.h"
#include "Luau/BuiltinDefinitions.h"
using namespace Luau;
TEST_SUITE_BEGIN("TypeAliases");
TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type F = () -> F?
local function f()
return f
end
local g: F = f
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g")));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified")
{
CheckResult result = check(R"(
--!strict
type Node = { Parent: Node?; }
local node: Node;
node.Parent = 1
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Node?", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types")
{
CheckResult result = check(R"(
--!strict
type T<a> = { f: a, g: U<a> }
type U<a> = { h: a, i: T<a>? }
local x: T<number> = { f = 37, g = { h = 5, i = nil } }
x.g.i = x
local y: T<string> = { f = "hi", g = { h = "lo", i = nil } }
y.g.i = y
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors")
{
CheckResult result = check(R"(
--!strict
type T<a> = { f: a, g: U<a> }
type U<b> = { h: b, i: T<b>? }
local x: T<number> = { f = 37, g = { h = 5, i = nil } }
x.g.i = x
local y: T<string> = { f = "hi", g = { h = 5, i = nil } }
y.g.i = y
)");
LUAU_REQUIRE_ERRORS(result);
// We had a UAF in this example caused by not cloning type function arguments
ModulePtr module = frontend.moduleResolver.getModule("MainModule");
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes);
freeze(module->interfaceTypes);
module->internalTypes.clear();
module->astTypes.clear();
// Make sure the error strings don't include "VALUELESS"
for (auto error : module->errors)
CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error));
}
TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors")
{
CheckResult result = check(R"(
type Pair<T, U> = {first: T, second: U}
local a: Pair<string, number>
local b: Pair<string, string>
a = b
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Pair<string, number>", toString(tm->wantedType));
CHECK_EQ("Pair<string, string>", toString(tm->givenType));
}
TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition")
{
CheckResult result = check(R"(
type A = number
type A = string -- Redefinition of type 'A', previously defined at line 1
local foo: string = 1 -- No "Type 'number' could not be converted into 'string'"
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
}
TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type")
{
CheckResult result = check(R"(
type Table<T> = { a: T }
type Wrapped = Table<Wrapped>
local l: Wrapped = 2
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Wrapped", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2")
{
CheckResult result = check(R"(
type Table<T> = { a: T }
type Wrapped = (Table<Wrapped>) -> string
local l: Wrapped = 2
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
// Check that recursive intersection type doesn't generate an OOM
TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom")
{
CheckResult result = check(R"(
function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any
end
type t0<t0> = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_))
_(_)
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise")
{
CheckResult result = check(R"(
local foo: Id<number> = 1
type Id<T> = T
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic")
{
const std::string code = R"(
type A<T> = {v:T, b:B<T>}
type B<T> = {v:T, a:A<T>}
local aa:A<number>
local bb = aa
)";
const std::string expected = R"(
type A<T> = {v:T, b:B<T>}
type B<T> = {v:T, a:A<T>}
local aa:A<number>
local bb:A<number>=aa
)";
CHECK_EQ(expected, decorateWithTypes(code));
CheckResult result = check(code);
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "corecursive_function_types")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type A = () -> (number, B)
type B = () -> (string, A)
local a: A
local b: B
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a")));
CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b")));
}
TEST_CASE_FIXTURE(Fixture, "generic_param_remap")
{
const std::string code = R"(
-- An example of a forwarded use of a type that has different type arguments than parameters
type A<T,U> = {t:T, u:U, next:A<U,T>?}
local aa:A<number,string> = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } }
local bb = aa
)";
const std::string expected = R"(
type A<T,U> = {t:T, u:U, next:A<U,T>?}
local aa:A<number,string> = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } }
local bb:A<number,string>=aa
)";
CHECK_EQ(expected, decorateWithTypes(code));
CheckResult result = check(code);
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates")
{
CheckResult result = check(R"(
export type Foo = number
type Foo = number
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
auto dtd = get<DuplicateTypeDefinition>(result.errors[0]);
REQUIRE(dtd);
CHECK_EQ(dtd->name, "Foo");
}
TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias")
{
ScopedFastFlag sffs3{"LuauGenericFunctions", true};
ScopedFastFlag sffs4{"LuauParseGenericFunctions", true};
CheckResult result = check(R"(
type Node<T> = { value: T, child: Node<T>? }
local function visitor<T>(node: Node<T>?)
local a: Node<T>
if node then
a = node.child -- Observe the output of the error message.
end
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
auto e = get<TypeMismatch>(result.errors[0]);
CHECK_EQ("Node<T>?", toString(e->givenType));
CHECK_EQ("Node<T>", toString(e->wantedType));
}
TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign")
{
fileResolver.source["workspace/A"] = R"(
export type myvec2 = {x: number, y: number}
return {}
)";
fileResolver.source["workspace/B"] = R"(
export type myvec3 = {x: number, y: number, z: number}
return {}
)";
fileResolver.source["workspace/C"] = R"(
local Foo, Bar = require(workspace.A), require(workspace.B)
local a: Foo.myvec2
local b: Bar.myvec3
)";
CheckResult result = frontend.check("workspace/C");
LUAU_REQUIRE_NO_ERRORS(result);
ModulePtr m = frontend.moduleResolver.modules["workspace/C"];
REQUIRE(m != nullptr);
std::optional<TypeId> aTypeId = lookupName(m->getModuleScope(), "a");
REQUIRE(aTypeId);
const Luau::TableTypeVar* aType = get<TableTypeVar>(follow(*aTypeId));
REQUIRE(aType);
REQUIRE(aType->props.size() == 2);
std::optional<TypeId> bTypeId = lookupName(m->getModuleScope(), "b");
REQUIRE(bTypeId);
const Luau::TableTypeVar* bType = get<TableTypeVar>(follow(*bTypeId));
REQUIRE(bType);
REQUIRE(bType->props.size() == 3);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation")
{
CheckResult result = check("type t10<x> = typeof(table)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId ty = getGlobalBinding(frontend.typeChecker, "table");
CHECK_EQ(toString(ty), "table");
const TableTypeVar* ttv = get<TableTypeVar>(ty);
REQUIRE(ttv);
CHECK(ttv->instantiatedTypeParams.empty());
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation")
{
CheckResult result = check(R"(
type Cool = { a: number, b: string }
local c: Cool = { a = 1, b = "s" }
type NotCool<x> = Cool
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "Cool");
const TableTypeVar* ttv = get<TableTypeVar>(*ty);
REQUIRE(ttv);
CHECK(ttv->instantiatedTypeParams.empty());
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename")
{
CheckResult result = check(R"(
type Cool = { a: number, b: string }
type NotCool = Cool
local c: Cool = { a = 1, b = "s" }
local d: NotCool = { a = 1, b = "s" }
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "Cool");
ty = requireType("d");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "NotCool");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation")
{
CheckResult result = check(R"(
local c = { a = 1, b = "s" }
type Cool = typeof(c)
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
const TableTypeVar* ttv = get<TableTypeVar>(*ty);
REQUIRE(ttv);
CHECK_EQ(ttv->name, "Cool");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type")
{
fileResolver.source["game/A"] = R"(
export type X = { a: number, b: X? }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
type X = Import.X
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
std::optional<TypeId> ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
std::optional<TypeId> ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(follow(*ty1), follow(*ty2));
}
TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type")
{
fileResolver.source["game/A"] = R"(
export type X<T, U> = { a: T, b: U, C: X<T, U>? }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
type X<T, U> = Import.X<T, U>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
std::optional<TypeId> ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
std::optional<TypeId> ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true}));
bResult = check(R"(
local Import = require(game.A)
type X<T, U> = Import.X<U, T>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}");
CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?");
}
TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak")
{
CheckResult result = check(R"(
function get()
return function(obj) return true end
end
export type f = typeof(get())
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak")
{
CheckResult result = check(R"(
function get()
return {a = 1, b = function(obj) return true end}
end
export type f = typeof(get())
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok")
{
CheckResult result = check(R"(
type Tree<T> = { data: T, children: Forest<T> }
type Forest<T> = {Tree<T>}
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
-- OK because forwarded types are used with their parameters.
type Tree<T> = { data: T, children: Forest<T> }
type Forest<T> = {Tree<{T}>}
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
-- Not OK because forwarded types are used with different types than their parameters.
type Forest<T> = {Tree<{T}>}
type Tree<T> = { data: T, children: Forest<T> }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok")
{
CheckResult result = check(R"(
type Tree1<T,U> = { data: T, children: {Tree2<U,T>} }
type Tree2<U,T> = { data: U, children: {Tree1<T,U>} }
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
type Tree1<T,U> = { data: T, children: {Tree2<U,T>} }
type Tree2<T,U> = { data: U, children: {Tree1<T,U>} }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases")
{
CheckResult result = check(R"(
function f(x) return x[1] end
-- x has type X? for a free type variable X
local x = f ({})
type ContainsFree<a> = { this: a, that: typeof(x) }
type ContainsContainsFree = { that: ContainsFree<number> }
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name")
{
ScopedFastFlag sff1{"LuauSubstitutionDontReplaceIgnoredTypes", true};
CheckResult result = check(R"(
type Array<T> = { [number]: T }
type Tuple<T, V> = Array<T | V>
local p: Tuple<number, string>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("{number | string}", toString(requireType("p"), {true}));
}
/*
* We had a problem where all type aliases would be prototyped into a child scope that happened
* to have the same level. This caused a problem where, if a sibling function referred to that
* type alias in its type signature, it would erroneously be quantified away, even though it doesn't
* actually belong to the function.
*
* We solved this by ascribing a unique subLevel to each prototyped alias.
*/
TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases")
{
CheckResult result = check(R"(
--!strict
local KeyPool = {}
local function newkey(pool: KeyPool, index)
return {}
end
function newKeyPool()
local pool = {
available = {} :: {Key},
}
return setmetatable(pool, KeyPool)
end
export type KeyPool = typeof(newKeyPool())
export type Key = typeof(newkey(newKeyPool(), 1))
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
/*
* We keep a cache of type alias onto TypeVar to prevent infinite types from
* being constructed via recursive or corecursive aliases. We have to adjust
* the TypeLevels of those generic TypeVars so that the unifier doesn't think
* they have improperly leaked out of their scope.
*/
TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases")
{
CheckResult result = check(R"(
type Array<T> = {T}
type Exclude<T, V> = T
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View file

@ -437,8 +437,6 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde
TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local b: Vector2? = nil
local a = b.X + b.Z

View file

@ -695,4 +695,25 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types")
CHECK(requireType("y1") == requireType("y2"));
}
TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields")
{
ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true};
ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true};
CheckResult result = check(R"(
local exports = {}
local nested = {}
nested.name = function(t, k)
local a = t.x.y
return rawget(t, k)
end
exports.nested = nested
return exports
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View file

@ -9,6 +9,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauEqConstraint)
LUAU_FASTFLAG(LuauQuantifyInPlace2)
using namespace Luau;
@ -30,6 +31,8 @@ TEST_SUITE_BEGIN("ProvisionalTests");
*/
TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
{
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
const std::string code = R"(
function f(a)
if type(a) == "boolean" then
@ -40,16 +43,30 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
end
)";
const std::string expected = R"(
function f(a:{fn:()->(free)}): ()
const std::string old_expected = R"(
function f(a:{fn:()->(free,free...)}): ()
if type(a) == 'boolean'then
local a1:boolean=a
elseif a.fn()then
local a2:{fn:()->(free)}=a
local a2:{fn:()->(free,free...)}=a
end
end
)";
CHECK_EQ(expected, decorateWithTypes(code));
const std::string expected = R"(
function f(a:{fn:()->(a,b...)}): ()
if type(a) == 'boolean'then
local a1:boolean=a
elseif a.fn()then
local a2:{fn:()->(a,b...)}=a
end
end
)";
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ(expected, decorateWithTypes(code));
else
CHECK_EQ(old_expected, decorateWithTypes(code));
}
TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns")
@ -231,16 +248,7 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible")
local r2 = b == a
)");
if (FFlag::LuauEqConstraint)
{
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK_EQ(toString(result.errors[0]), "Type '{| x: string |}?' could not be converted into 'number | string'");
CHECK_EQ(toString(result.errors[1]), "Type 'number | string' could not be converted into '{| x: string |}?'");
}
LUAU_REQUIRE_NO_ERRORS(result);
}
// Belongs in TypeInfer.refinements.test.cpp.
@ -270,8 +278,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap")
TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5))
{
ScopedFastInt sffi{"LuauTarjanChildLimit", 50};
ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50};
ScopedFastInt sffi{"LuauTarjanChildLimit", 1};
ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1};
CheckResult result = check(R"LUA(
local Result
@ -542,6 +550,25 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doct
}
}
TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables")
{
CheckResult result = check(R"(
--!strict
local function setNumber(t: { p: number? }, x:number) t.p = x end
local function getString(t: { p: string? }):string return t.p or "" end
-- This shouldn't type-check!
local function oh(x:number): string
local t: {} = {}
setNumber(t, x)
return getString(t)
end
local s: string = oh(37)
)");
// Really this should return an error, but it doesn't
LUAU_REQUIRE_NO_ERRORS(result);
}
// Should be in TypeInfer.tables.test.cpp
// It's unsound to instantiate tables containing generic methods,
// since mutating properties means table properties should be invariant.

View file

@ -1,4 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Fixture.h"
@ -6,8 +7,8 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauWeakEqConstraint)
LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2)
LUAU_FASTFLAG(LuauOrPredicate)
LUAU_FASTFLAG(LuauQuantifyInPlace2)
using namespace Luau;
@ -199,16 +200,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope")
end
)");
if (FFlag::LuauImprovedTypeGuardPredicate2)
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0]));
}
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'string' could not be converted into 'boolean'", toString(result.errors[0]));
}
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard")
@ -526,8 +519,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable")
TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x)
if type(x) == "vector" then
@ -544,8 +535,6 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector")
TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local t = {"hello"}
local v = t[2]
@ -573,8 +562,6 @@ TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true"
TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x: string | number | boolean)
if type(x) ~= "string" then
@ -593,8 +580,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string")
TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x: string | {x: number} | {y: boolean})
if type(x) == "table" then
@ -613,8 +598,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table")
TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function weird(x: string | ((number) -> string))
if type(x) == "function" then
@ -698,8 +681,6 @@ struct RefinementClassFixture : Fixture
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(vec)
local X, Y, Z = vec.X, vec.Y, vec.Z
@ -718,16 +699,20 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector"
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0]));
else
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance"
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
else
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
}
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x: Instance | Vector3)
if typeof(x) == "Vector3" then
@ -746,8 +731,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to
TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x: string | number | Instance | Vector3)
if type(x) == "userdata" then
@ -766,10 +749,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata")
TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance")
{
ScopedFastFlag sffs[] = {
{"LuauImprovedTypeGuardPredicate2", true},
{"LuauTypeGuardPeelsAwaySubclasses", true},
};
ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true};
CheckResult result = check(R"(
local function f(x: Part | Folder | string)
@ -789,10 +769,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance")
TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union")
{
ScopedFastFlag sffs[] = {
{"LuauImprovedTypeGuardPredicate2", true},
{"LuauTypeGuardPeelsAwaySubclasses", true},
};
ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true};
CheckResult result = check(R"(
local function f(x: Part | Folder | Instance | string | Vector3 | any)
@ -812,10 +789,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union")
TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
CheckResult result = check(R"(
--!nonstrict
@ -839,7 +813,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
{"LuauTypeGuardPeelsAwaySubclasses", true},
};
@ -861,8 +834,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part")
TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
type XYCoord = {x: number} & {y: number}
local function f(t: XYCoord?)
@ -882,8 +853,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables")
TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
type SomeOverloadedFunction = ((number) -> string) & ((string) -> number)
local function f(g: SomeOverloadedFunction?)
@ -903,8 +872,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function")
TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(t: {x: number})
if type(t) ~= "table" then
@ -999,10 +966,7 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2")
TEST_CASE_FIXTURE(Fixture, "either_number_or_string")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
CheckResult result = check(R"(
local function f(x: any)
@ -1036,10 +1000,7 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t")
TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
CheckResult result = check(R"(
local a: (number | string)?
@ -1057,10 +1018,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number")
TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
// This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result.
CheckResult result = check(R"(
@ -1081,10 +1039,7 @@ TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering")
TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
CheckResult result = check(R"(
local function f(a: string | number | boolean)

View file

@ -46,6 +46,21 @@ TEST_CASE_FIXTURE(Fixture, "augment_table")
CHECK(tType->props.find("foo") != tType->props.end());
}
TEST_CASE_FIXTURE(Fixture, "augment_nested_table")
{
CheckResult result = check("local t = { p = {} } t.p.foo = 'bar'");
LUAU_REQUIRE_NO_ERRORS(result);
TableTypeVar* tType = getMutable<TableTypeVar>(requireType("t"));
REQUIRE(tType != nullptr);
REQUIRE(tType->props.find("p") != tType->props.end());
const TableTypeVar* pType = get<TableTypeVar>(tType->props["p"].type);
REQUIRE(pType != nullptr);
CHECK(pType->props.find("foo") != pType->props.end());
}
TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table")
{
CheckResult result = check("local t = {prop=999} t.foo = 'bar'");
@ -260,6 +275,8 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification")
TEST_CASE_FIXTURE(Fixture, "open_table_unification_2")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local a = {}
a.x = 99
@ -272,10 +289,11 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_2")
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeError& err = result.errors[0];
UnknownProperty* error = get<UnknownProperty>(err);
MissingProperties* error = get<MissingProperties>(err);
REQUIRE(error != nullptr);
REQUIRE(error->properties.size() == 1);
CHECK_EQ(error->key, "y");
CHECK_EQ("y", error->properties[0]);
// TODO(rblanckaert): Revist when we can bind self at function creation time
// CHECK_EQ(err.location, Location(Position{5, 19}, Position{5, 25}));
@ -328,6 +346,8 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1")
TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
--!strict
function foo(o)
@ -340,14 +360,17 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2")
LUAU_REQUIRE_ERROR_COUNT(1, result);
UnknownProperty* error = get<UnknownProperty>(result.errors[0]);
MissingProperties* error = get<MissingProperties>(result.errors[0]);
REQUIRE(error != nullptr);
REQUIRE(error->properties.size() == 1);
CHECK_EQ("baz", error->key);
CHECK_EQ("baz", error->properties[0]);
}
TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local T = {}
T.bar = 'hello'
@ -359,8 +382,11 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3")
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeError& err = result.errors[0];
UnknownProperty* error = get<UnknownProperty>(err);
MissingProperties* error = get<MissingProperties>(err);
REQUIRE(error != nullptr);
REQUIRE(error->properties.size() == 1);
CHECK_EQ("baz", error->properties[0]);
// TODO(rblanckaert): Revist when we can bind self at function creation time
/*
@ -448,6 +474,73 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table")
dumpErrors(result);
}
TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
--!strict
local t = { u = {} }
t = { u = { p = 37 } }
t = { u = { q = "hi" } }
local x = t.u.p
local y = t.u.q
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("number?", toString(requireType("x")));
CHECK_EQ("string?", toString(requireType("y")));
}
TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_call")
{
CheckResult result = check(R"(
--!strict
function get(x) return x.opts["MYOPT"] end
function set(x,y) x.opts["MYOPT"] = y end
local t = { opts = {} }
set(t,37)
local x = get(t)
)");
// Currently this errors but it shouldn't, since set only needs write access
// TODO: file a JIRA for this
LUAU_REQUIRE_ERRORS(result);
// CHECK_EQ("number?", toString(requireType("x")));
}
TEST_CASE_FIXTURE(Fixture, "width_subtyping")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
--!strict
function f(x : { q : number })
x.q = 8
end
local t : { q : number, r : string } = { q = 8, r = "hi" }
f(t)
local x : string = t.r
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "width_subtyping_needs_covariance")
{
CheckResult result = check(R"(
--!strict
function f(x : { p : { q : number }})
x.p = { q = 8, r = 5 }
end
local t : { p : { q : number, r : string } } = { p = { q = 8, r = "hi" } }
f(t) -- Shouldn't typecheck
local x : string = t.p.r -- x is 5
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "infer_array")
{
CheckResult result = check(R"(
@ -524,7 +617,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too")
REQUIRE_EQ(indexer.indexType, typeChecker.numberType);
REQUIRE(nullptr != get<GenericTypeVar>(indexer.indexResultType));
REQUIRE(nullptr != get<GenericTypeVar>(follow(indexer.indexResultType)));
}
TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2")
@ -676,16 +769,27 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "sealed_table_value_must_not_infer_an_indexer")
TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local t: { a: string, [number]: string } = { a = "foo" }
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
LUAU_REQUIRE_NO_ERRORS(result);
}
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm != nullptr);
TEST_CASE_FIXTURE(Fixture, "array_factory_function")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
function empty() return {} end
local array: {string} = empty()
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "sealed_table_indexers_must_unify")
@ -756,37 +860,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_
CHECK_MESSAGE(nullptr != get<TypeMismatch>(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]);
}
TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_string")
{
ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true);
CheckResult result = check(R"(
local t: { a: string }
function f(x: string) return t[x] end
local a = f("a")
local b = f("b")
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(*typeChecker.anyType, *requireType("a"));
CHECK_EQ(*typeChecker.anyType, *requireType("b"));
}
TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_number")
{
ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true);
CheckResult result = check(R"(
local t = { a = true }
function f(x: number) return t[x] end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_MESSAGE(nullptr != get<TypeMismatch>(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]);
}
TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer")
{
CheckResult result = check(R"(
@ -1392,6 +1465,8 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2")
TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local function foo(a: {[string]: number, a: string}) end
foo({ a = 1 })
@ -1402,8 +1477,21 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3")
ToStringOptions o{/* exhaustive= */ true};
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("string", toString(tm->wantedType, o));
CHECK_EQ("number", toString(tm->givenType, o));
CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o));
CHECK_EQ("{| a: number |}", toString(tm->givenType, o));
}
TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4")
{
CheckResult result = check(R"(
local function foo(a: {[string]: number, a: string}, i: string)
return a[i]
end
local hi: number = foo({ a = "hi" }, "a") -- shouldn't typecheck since at runtime hi is "hi"
)");
// This typechecks but shouldn't
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors")
@ -1446,22 +1534,32 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi
TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors")
{
CheckResult result = check(R"(
local vec3 = {x = 1, y = 2, z = 3}
local vec1 = {x = 1}
local vec3 = {{x = 1, y = 2, z = 3}}
local vec1 = {{x = 1}}
vec1 = vec3
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
MissingProperties* mp = get<MissingProperties>(result.errors[0]);
REQUIRE(mp);
CHECK_EQ(mp->context, MissingProperties::Extra);
REQUIRE_EQ(2, mp->properties.size());
CHECK_EQ(mp->properties[0], "y");
CHECK_EQ(mp->properties[1], "z");
CHECK_EQ("vec1", toString(mp->superType));
CHECK_EQ("vec3", toString(mp->subType));
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("vec1", toString(tm->wantedType));
CHECK_EQ("vec3", toString(tm->givenType));
}
TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local vec3 = {x = 1, y = 2, z = 3}
local vec1 = {x = 1}
vec1 = vec3
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short")
@ -1824,4 +1922,32 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_nonstrict")
{
CheckResult result = check(R"(
--!nonstrict
local buttons = {}
table.insert(buttons, { a = 1 })
table.insert(buttons, { a = 2, b = true })
table.insert(buttons, { a = 3 })
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
--!strict
local buttons = {}
table.insert(buttons, { a = 1 })
table.insert(buttons, { a = 2, b = true })
table.insert(buttons, { a = 3 })
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View file

@ -3,6 +3,7 @@
#include "Luau/AstQuery.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h"
@ -179,7 +180,12 @@ TEST_CASE_FIXTURE(Fixture, "expr_statement")
TEST_CASE_FIXTURE(Fixture, "generic_function")
{
CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)");
CheckResult result = check(R"(
function id(x) return x end
local a = id(55)
local b = id(nil)
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(*typeChecker.numberType, *requireType("a"));
@ -978,23 +984,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args")
CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f")));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type F = () -> F?
local function f()
return f
end
local g: F = f
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g")));
}
// TODO: File a Jira about this
/*
TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack")
@ -1257,23 +1246,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function")
REQUIRE_EQ(follow(*methodArg), follow(arg));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified")
{
CheckResult result = check(R"(
--!strict
type Node = { Parent: Node?; }
local node: Node;
node.Parent = 1
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Node?", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free")
{
CheckResult result = check(R"(
@ -1922,7 +1894,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function")
REQUIRE_EQ(2, argVec.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(argVec[0]);
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(argVec[0]));
REQUIRE(fType != nullptr);
std::vector<TypeId> fArgs = flatten(fType->argTypes).first;
@ -1959,7 +1931,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_2")
REQUIRE_EQ(6, argVec.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(argVec[0]);
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(argVec[0]));
REQUIRE(fType != nullptr);
}
@ -2591,48 +2563,6 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion")
dumpErrors(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types")
{
CheckResult result = check(R"(
--!strict
type T<a> = { f: a, g: U<a> }
type U<a> = { h: a, i: T<a>? }
local x: T<number> = { f = 37, g = { h = 5, i = nil } }
x.g.i = x
local y: T<string> = { f = "hi", g = { h = "lo", i = nil } }
y.g.i = y
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors")
{
CheckResult result = check(R"(
--!strict
type T<a> = { f: a, g: U<a> }
type U<b> = { h: b, i: T<b>? }
local x: T<number> = { f = 37, g = { h = 5, i = nil } }
x.g.i = x
local y: T<string> = { f = "hi", g = { h = 5, i = nil } }
y.g.i = y
)");
LUAU_REQUIRE_ERRORS(result);
// We had a UAF in this example caused by not cloning type function arguments
ModulePtr module = frontend.moduleResolver.getModule("MainModule");
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes);
freeze(module->interfaceTypes);
module->internalTypes.clear();
module->astTypes.clear();
// Make sure the error strings don't include "VALUELESS"
for (auto error : module->errors)
CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error));
}
TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self")
{
// CLI-30902
@ -3388,16 +3318,7 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison")
end
)");
if (FFlag::LuauEqConstraint)
{
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
REQUIRE(get<CannotInferBinaryOperation>(result.errors[0]));
}
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable")
@ -3407,18 +3328,8 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable
print((x == true and (x .. "y")) .. 1)
)");
if (FFlag::LuauEqConstraint)
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
REQUIRE(get<CannotInferBinaryOperation>(result.errors[0]));
}
else
{
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK_EQ("Type 'boolean' could not be converted into 'number | string'", toString(result.errors[0]));
CHECK_EQ("Type 'boolean | string' could not be converted into 'number | string'", toString(result.errors[1]));
}
LUAU_REQUIRE_ERROR_COUNT(1, result);
REQUIRE(get<CannotInferBinaryOperation>(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs")
@ -3530,25 +3441,6 @@ _(...)(...,setfenv,_):_G()
)");
}
TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors")
{
CheckResult result = check(R"(
type Pair<T, U> = {first: T, second: U}
local a: Pair<string, number>
local b: Pair<string, string>
a = b
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Pair<string, number>", toString(tm->wantedType));
CHECK_EQ("Pair<string, string>", toString(tm->givenType));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs")
{
// this has a risk of creating cyclic type packs, causing infinite loops / OOMs
@ -3658,17 +3550,6 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free")
)");
}
TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition")
{
CheckResult result = check(R"(
type A = number
type A = string -- Redefinition of type 'A', previously defined at line 1
local foo: string = 1 -- No "Type 'number' could not be converted into 'string'"
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
}
TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery")
{
CheckResult result = check(R"(
@ -3771,38 +3652,6 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato
CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message);
}
TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type")
{
CheckResult result = check(R"(
type Table<T> = { a: T }
type Wrapped = Table<Wrapped>
local l: Wrapped = 2
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Wrapped", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2")
{
CheckResult result = check(R"(
type Table<T> = { a: T }
type Wrapped = (Table<Wrapped>) -> string
local l: Wrapped = 2
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked")
{
CheckResult result = check(R"(
@ -3928,19 +3777,6 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals")
CHECK_EQ("(boolean | number | string)?", toString(tm->givenType));
}
// Check that recursive intersection type doesn't generate an OOM
TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom")
{
CheckResult result = check(R"(
function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any
end
type t0<t0> = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_))
_(_)
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign")
{
// In non-strict mode, global definition is still allowed
@ -3993,16 +3829,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional")
LUAU_REQUIRE_ERROR_COUNT(2, result);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise")
{
CheckResult result = check(R"(
local foo: Id<number> = 1
type Id<T> = T
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods")
{
CheckResult result = check(R"(
@ -4021,10 +3847,10 @@ local T: any
T = {}
T.__index = T
function T.new(...)
local self = {}
setmetatable(self, T)
self:construct(...)
return self
local self = {}
setmetatable(self, T)
self:construct(...)
return self
end
function T:construct(index)
end
@ -4033,81 +3859,6 @@ end
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic")
{
const std::string code = R"(
type A<T> = {v:T, b:B<T>}
type B<T> = {v:T, a:A<T>}
local aa:A<number>
local bb = aa
)";
const std::string expected = R"(
type A<T> = {v:T, b:B<T>}
type B<T> = {v:T, a:A<T>}
local aa:A<number>
local bb:A<number>=aa
)";
CHECK_EQ(expected, decorateWithTypes(code));
CheckResult result = check(code);
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "corecursive_function_types")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type A = () -> (number, B)
type B = () -> (string, A)
local a: A
local b: B
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a")));
CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b")));
}
TEST_CASE_FIXTURE(Fixture, "generic_param_remap")
{
const std::string code = R"(
-- An example of a forwarded use of a type that has different type arguments than parameters
type A<T,U> = {t:T, u:U, next:A<U,T>?}
local aa:A<number,string> = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } }
local bb = aa
)";
const std::string expected = R"(
type A<T,U> = {t:T, u:U, next:A<U,T>?}
local aa:A<number,string> = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } }
local bb:A<number,string>=aa
)";
CHECK_EQ(expected, decorateWithTypes(code));
CheckResult result = check(code);
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates")
{
CheckResult result = check(R"(
export type Foo = number
type Foo = number
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
auto dtd = get<DuplicateTypeDefinition>(result.errors[0]);
REQUIRE(dtd);
CHECK_EQ(dtd->name, "Foo");
}
TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError")
{
CheckResult result = check(R"(
@ -4212,30 +3963,6 @@ TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does")
REQUIRE_MESSAGE(get<UnknownSymbol>(e) != nullptr, "Expected UnknownSymbol, but got " << e);
}
TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias")
{
ScopedFastFlag sffs3{"LuauGenericFunctions", true};
ScopedFastFlag sffs4{"LuauParseGenericFunctions", true};
CheckResult result = check(R"(
type Node<T> = { value: T, child: Node<T>? }
local function visitor<T>(node: Node<T>?)
local a: Node<T>
if node then
a = node.child -- Observe the output of the error message.
end
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
auto e = get<TypeMismatch>(result.errors[0]);
CHECK_EQ("Node<T>?", toString(e->givenType));
CHECK_EQ("Node<T>", toString(e->wantedType));
}
TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect")
{
CheckResult result = check(R"(
@ -4291,181 +4018,6 @@ local tbl: string = require(game.A)
CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign")
{
fileResolver.source["workspace/A"] = R"(
export type myvec2 = {x: number, y: number}
return {}
)";
fileResolver.source["workspace/B"] = R"(
export type myvec3 = {x: number, y: number, z: number}
return {}
)";
fileResolver.source["workspace/C"] = R"(
local Foo, Bar = require(workspace.A), require(workspace.B)
local a: Foo.myvec2
local b: Bar.myvec3
)";
CheckResult result = frontend.check("workspace/C");
LUAU_REQUIRE_NO_ERRORS(result);
ModulePtr m = frontend.moduleResolver.modules["workspace/C"];
REQUIRE(m != nullptr);
std::optional<TypeId> aTypeId = lookupName(m->getModuleScope(), "a");
REQUIRE(aTypeId);
const Luau::TableTypeVar* aType = get<TableTypeVar>(follow(*aTypeId));
REQUIRE(aType);
REQUIRE(aType->props.size() == 2);
std::optional<TypeId> bTypeId = lookupName(m->getModuleScope(), "b");
REQUIRE(bTypeId);
const Luau::TableTypeVar* bType = get<TableTypeVar>(follow(*bTypeId));
REQUIRE(bType);
REQUIRE(bType->props.size() == 3);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation")
{
CheckResult result = check("type t10<x> = typeof(table)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId ty = getGlobalBinding(frontend.typeChecker, "table");
CHECK_EQ(toString(ty), "table");
const TableTypeVar* ttv = get<TableTypeVar>(ty);
REQUIRE(ttv);
CHECK(ttv->instantiatedTypeParams.empty());
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation")
{
CheckResult result = check(R"(
type Cool = { a: number, b: string }
local c: Cool = { a = 1, b = "s" }
type NotCool<x> = Cool
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "Cool");
const TableTypeVar* ttv = get<TableTypeVar>(*ty);
REQUIRE(ttv);
CHECK(ttv->instantiatedTypeParams.empty());
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename")
{
CheckResult result = check(R"(
type Cool = { a: number, b: string }
type NotCool = Cool
local c: Cool = { a = 1, b = "s" }
local d: NotCool = { a = 1, b = "s" }
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "Cool");
ty = requireType("d");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "NotCool");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation")
{
CheckResult result = check(R"(
local c = { a = 1, b = "s" }
type Cool = typeof(c)
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
const TableTypeVar* ttv = get<TableTypeVar>(*ty);
REQUIRE(ttv);
CHECK_EQ(ttv->name, "Cool");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type")
{
ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true};
fileResolver.source["game/A"] = R"(
export type X = { a: number, b: X? }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
type X = Import.X
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
std::optional<TypeId> ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
std::optional<TypeId> ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(follow(*ty1), follow(*ty2));
}
TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type")
{
ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true};
fileResolver.source["game/A"] = R"(
export type X<T, U> = { a: T, b: U, C: X<T, U>? }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
type X<T, U> = Import.X<T, U>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
std::optional<TypeId> ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
std::optional<TypeId> ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true}));
bResult = check(R"(
local Import = require(game.A)
type X<T, U> = Import.X<U, T>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}");
CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?");
}
TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail")
{
CheckResult result = check(R"(
@ -4521,11 +4073,11 @@ function n:Clone() end
local m = {}
function m.a(x)
x:Clone()
x:Clone()
end
function m.b()
m.a(n)
m.a(n)
end
return m
@ -4579,32 +4131,6 @@ local c = a(2) -- too many arguments
CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak")
{
CheckResult result = check(R"(
function get()
return function(obj) return true end
end
export type f = typeof(get())
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak")
{
CheckResult result = check(R"(
function get()
return {a = 1, b = function(obj) return true end}
end
export type f = typeof(get())
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "custom_require_global")
{
CheckResult result = check(R"(
@ -4787,8 +4313,6 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error")
TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back")
{
ScopedFastFlag sff{"LuauLogTableTypeVarBoundTo", true};
fileResolver.source["Module/Backend/Types"] = R"(
export type Fiber = {
return_: Fiber?
@ -4868,14 +4392,12 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload")
ModulePtr module = getMainModule();
auto it = module->astOverloadResolvedTypes.find(parentExpr);
REQUIRE(it != module->astOverloadResolvedTypes.end());
CHECK_EQ(toString(it->second), "(number) -> number");
REQUIRE(it);
CHECK_EQ(toString(*it), "(number) -> number");
}
TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments")
{
ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true);
// Simple direct arg to arg propagation
CheckResult result = check(R"(
type Table = { x: number, y: number }
@ -5032,76 +4554,6 @@ g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end)
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok")
{
CheckResult result = check(R"(
type Tree<T> = { data: T, children: Forest<T> }
type Forest<T> = {Tree<T>}
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
-- OK because forwarded types are used with their parameters.
type Tree<T> = { data: T, children: Forest<T> }
type Forest<T> = {Tree<{T}>}
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
-- Not OK because forwarded types are used with different types than their parameters.
type Forest<T> = {Tree<{T}>}
type Tree<T> = { data: T, children: Forest<T> }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok")
{
CheckResult result = check(R"(
type Tree1<T,U> = { data: T, children: {Tree2<U,T>} }
type Tree2<U,T> = { data: U, children: {Tree1<T,U>} }
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
type Tree1<T,U> = { data: T, children: {Tree2<U,T>} }
type Tree2<T,U> = { data: U, children: {Tree1<T,U>} }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases")
{
CheckResult result = check(R"(
function f(x) return x[1] end
-- x has type X? for a free type variable X
local x = f ({})
type ContainsFree<a> = { this: a, that: typeof(x) }
type ContainsContainsFree = { that: ContainsFree<number> }
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument")
{
CheckResult result = check(R"(
@ -5232,7 +4684,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early")
{
ScopedFastFlag sffs[] = {
{"LuauSlightlyMoreFlexibleBinaryPredicates", true},
{"LuauExtraNilRecovery", true},
};
CheckResult result = check(R"(
@ -5249,7 +4700,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch")
{
ScopedFastFlag sffs[] = {
{"LuauSlightlyMoreFlexibleBinaryPredicates", true},
{"LuauExtraNilRecovery", true},
};
CheckResult result = check(R"(

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"
@ -7,6 +8,8 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauQuantifyInPlace2);
using namespace Luau;
struct TryUnifyFixture : Fixture
@ -14,7 +17,8 @@ struct TryUnifyFixture : Fixture
TypeArena arena;
ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}};
InternalErrorReporter iceHandler;
Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler};
UnifierSharedState unifierState{&iceHandler};
Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState};
};
TEST_SUITE_BEGIN("TryUnifyTests");
@ -138,7 +142,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails"
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("(number) -> (boolean)", toString(requireType("f")));
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("(number) -> boolean", toString(requireType("f")));
else
CHECK_EQ("(number) -> (boolean)", toString(requireType("f")));
}
TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification")

View file

@ -98,10 +98,10 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function")
std::vector<TypeId> applyArgs = flatten(applyType->argTypes).first;
REQUIRE_EQ(3, applyArgs.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(applyArgs[0]);
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(applyArgs[0]));
REQUIRE(fType != nullptr);
const FunctionTypeVar* gType = get<FunctionTypeVar>(applyArgs[1]);
const FunctionTypeVar* gType = get<FunctionTypeVar>(follow(applyArgs[1]));
REQUIRE(gType != nullptr);
std::vector<TypeId> gArgs = flatten(gType->argTypes).first;
@ -285,7 +285,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail")
{
CheckResult result = check(R"(
local _ = function():((...any)->(...any),()->())
return function() end, function() end
return function() end, function() end
end
for y in _() do
end
@ -294,4 +294,370 @@ end
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Packed<T...> = (T...) -> T...
local a: Packed<>
local b: Packed<number>
local c: Packed<string, number>
)");
LUAU_REQUIRE_NO_ERRORS(result);
auto tf = lookupType("Packed");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "(T...) -> (T...)");
CHECK_EQ(toString(requireType("a")), "() -> ()");
CHECK_EQ(toString(requireType("b")), "(number) -> number");
CHECK_EQ(toString(requireType("c")), "(string, number) -> (string, number)");
result = check(R"(
-- (U..., T) cannot be parsed right now
type Packed<T, U...> = { f: (a: T, U...) -> (T, U...) }
local a: Packed<number>
local b: Packed<string, number>
local c: Packed<string, number, boolean>
)");
LUAU_REQUIRE_NO_ERRORS(result);
tf = lookupType("Packed");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "Packed<T, U...>");
CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}");
auto ttvA = get<TableTypeVar>(requireType("a"));
REQUIRE(ttvA);
CHECK_EQ(toString(requireType("a")), "Packed<number>");
CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}");
REQUIRE(ttvA->instantiatedTypeParams.size() == 1);
REQUIRE(ttvA->instantiatedTypePackParams.size() == 1);
CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number");
CHECK_EQ(toString(ttvA->instantiatedTypePackParams[0], {true}), "");
auto ttvB = get<TableTypeVar>(requireType("b"));
REQUIRE(ttvB);
CHECK_EQ(toString(requireType("b")), "Packed<string, number>");
CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}");
REQUIRE(ttvB->instantiatedTypeParams.size() == 1);
REQUIRE(ttvB->instantiatedTypePackParams.size() == 1);
CHECK_EQ(toString(ttvB->instantiatedTypeParams[0], {true}), "string");
CHECK_EQ(toString(ttvB->instantiatedTypePackParams[0], {true}), "number");
auto ttvC = get<TableTypeVar>(requireType("c"));
REQUIRE(ttvC);
CHECK_EQ(toString(requireType("c")), "Packed<string, number, boolean>");
CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}");
REQUIRE(ttvC->instantiatedTypeParams.size() == 1);
REQUIRE(ttvC->instantiatedTypePackParams.size() == 1);
CHECK_EQ(toString(ttvC->instantiatedTypeParams[0], {true}), "string");
CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
fileResolver.source["game/A"] = R"(
export type Packed<T, U...> = { a: T, b: (U...) -> () }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
local a: Import.Packed<number>
local b: Import.Packed<string, number>
local c: Import.Packed<string, number, boolean>
local d: { a: typeof(c) }
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
auto tf = lookupImportedType("Import", "Packed");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "Packed<T, U...>");
CHECK_EQ(toString(*tf, {true}), "{| a: T, b: (U...) -> () |}");
CHECK_EQ(toString(requireType("a"), {true}), "{| a: number, b: () -> () |}");
CHECK_EQ(toString(requireType("b"), {true}), "{| a: string, b: (number) -> () |}");
CHECK_EQ(toString(requireType("c"), {true}), "{| a: string, b: (number, boolean) -> () |}");
CHECK_EQ(toString(requireType("d")), "{| a: Packed<string, number, boolean> |}");
}
TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
fileResolver.source["game/A"] = R"(
export type Packed<T, U...> = { a: T, b: (U...) -> () }
return {}
)";
CheckResult cResult = check(R"(
local Import = require(game.A)
type Alias<S, T, R...> = Import.Packed<S, (T, R...)>
local a: Alias<string, number, boolean>
type B<X...> = Import.Packed<string, X...>
type C<X...> = Import.Packed<string, (number, X...)>
)");
LUAU_REQUIRE_NO_ERRORS(cResult);
auto tf = lookupType("Alias");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "Alias<S, T, R...>");
CHECK_EQ(toString(*tf, {true}), "{| a: S, b: (T, R...) -> () |}");
CHECK_EQ(toString(requireType("a"), {true}), "{| a: string, b: (number, boolean) -> () |}");
tf = lookupType("B");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "B<X...>");
CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (X...) -> () |}");
tf = lookupType("C");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "C<X...>");
CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (number, X...) -> () |}");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Packed1<T...> = (T...) -> (T...)
type Packed2<T...> = (Packed1<T...>, T...) -> (Packed1<T...>, T...)
type Packed3<T...> = (Packed2<T...>, T...) -> (Packed2<T...>, T...)
type Packed4<T...> = (Packed3<T...>, T...) -> (Packed3<T...>, T...)
)");
LUAU_REQUIRE_NO_ERRORS(result);
auto tf = lookupType("Packed4");
REQUIRE(tf);
CHECK_EQ(toString(*tf),
"((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...) -> "
"((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type X<T...> = (T...) -> (string, T...)
type D = X<...number>
type E = X<(number, ...string)>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("D")), "(...number) -> (string, ...number)");
CHECK_EQ(toString(*lookupType("E")), "(number, ...string) -> (string, number, ...string)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Y<T..., U...> = (T...) -> (U...)
type A<S...> = Y<S..., S...>
type B<S...> = Y<(number, ...string), S...>
type Z<T, U...> = (T) -> (U...)
type E<S...> = Z<number, S...>
type F<S...> = Z<number, (string, S...)>
type W<T, U..., V...> = (T, U...) -> (T, V...)
type H<S..., R...> = W<number, S..., R...>
type I<S..., R...> = W<number, (string, S...), R...>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)");
CHECK_EQ(toString(*lookupType("B")), "(number, ...string) -> (S...)");
CHECK_EQ(toString(*lookupType("E")), "(number) -> (S...)");
CHECK_EQ(toString(*lookupType("F")), "(number) -> (string, S...)");
CHECK_EQ(toString(*lookupType("H")), "(number, S...) -> (number, R...)");
CHECK_EQ(toString(*lookupType("I")), "(number, string, S...) -> (number, R...)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type X<T...> = (T...) -> (T...)
type A<S...> = X<(S...)>
type B = X<()>
type C = X<(number)>
type D = X<(number, string)>
type E = X<(...number)>
type F = X<(string, ...number)>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)");
CHECK_EQ(toString(*lookupType("B")), "() -> ()");
CHECK_EQ(toString(*lookupType("C")), "(number) -> number");
CHECK_EQ(toString(*lookupType("D")), "(number, string) -> (number, string)");
CHECK_EQ(toString(*lookupType("E")), "(...number) -> (...number)");
CHECK_EQ(toString(*lookupType("F")), "(string, ...number) -> (string, ...number)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Y<T..., U...> = (T...) -> (U...)
type A = Y<(number, string), (boolean)>
type B = Y<(), ()>
type C<S...> = Y<...string, (number, S...)>
type D<X...> = Y<X..., (number, string, X...)>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("A")), "(number, string) -> boolean");
CHECK_EQ(toString(*lookupType("B")), "() -> ()");
CHECK_EQ(toString(*lookupType("C")), "(...string) -> (number, S...)");
CHECK_EQ(toString(*lookupType("D")), "(X...) -> (number, string, X...)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
ScopedFastFlag luauInstantiatedTypeParamRecursion("LuauInstantiatedTypeParamRecursion", true); // For correct toString block
CheckResult result = check(R"(
type Y<T..., U...> = { f: (T...) -> (U...) }
local a: Y<(number, string), (boolean)>
local b: Y<(), ()>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<(number, string), (boolean)>");
CHECK_EQ(toString(requireType("b")), "Y<(), ()>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type X<T> = () -> T
type Y<T, U> = (T) -> U
type A = X<(number)>
type B = Y<(number), (boolean)>
type C = Y<(number), boolean>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("A")), "() -> number");
CHECK_EQ(toString(*lookupType("B")), "(number) -> boolean");
CHECK_EQ(toString(*lookupType("C")), "(number) -> boolean");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Packed<T, U, V...> = (T, U) -> (V...)
local b: Packed<number>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T, U, V...>' expects at least 2 type arguments, but only 1 is specified");
result = check(R"(
type Packed<T, U> = (T, U) -> ()
type B<X...> = Packed<number, string, X...>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T, U>' expects 0 type pack arguments, but 1 is specified");
result = check(R"(
type Packed<T..., U...> = (T...) -> (U...)
type Other<S...> = Packed<S..., string>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Type parameters must come before type pack parameters");
result = check(R"(
type Packed<T, U> = (T) -> U
type Other<S...> = Packed<number, S...>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T, U>' expects 2 type arguments, but only 1 is specified");
result = check(R"(
type Packed<T...> = (T...) -> T...
local a: Packed
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Type parameter list is required");
result = check(R"(
type Packed<T..., U...> = (T...) -> (U...)
type Other = Packed<>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T..., U...>' expects 2 type pack arguments, but none are specified");
result = check(R"(
type Packed<T..., U...> = (T...) -> (U...)
type Other = Packed<number, string>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T..., U...>' expects 2 type pack arguments, but only 1 is specified");
}
TEST_SUITE_END();

View file

@ -181,8 +181,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property")
TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property")
{
ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true);
CheckResult result = check(R"(
type A = {x: number}
type B = {}
@ -237,27 +235,11 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons")
local z = a == c
)");
if (FFlag::LuauEqConstraint)
{
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(*typeChecker.booleanType, *requireType("x"));
CHECK_EQ(*typeChecker.booleanType, *requireType("y"));
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("(number | string)?", toString(*tm->wantedType));
CHECK_EQ("boolean | number", toString(*tm->givenType));
}
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "optional_union_members")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local a = { a = { x = 1, y = 2 }, b = 3 }
type A = typeof(a)
@ -273,8 +255,6 @@ local c = bf.a.y
TEST_CASE_FIXTURE(Fixture, "optional_union_functions")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local a = {}
function a.foo(x:number, y:number) return x + y end
@ -290,8 +270,6 @@ local c = b.foo(1, 2)
TEST_CASE_FIXTURE(Fixture, "optional_union_methods")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local a = {}
function a:foo(x:number, y:number) return x + y end
@ -324,8 +302,6 @@ return f()
TEST_CASE_FIXTURE(Fixture, "optional_field_access_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = { x: number }
local b: A? = { x = 2 }
@ -341,8 +317,6 @@ local d = b.y
TEST_CASE_FIXTURE(Fixture, "optional_index_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = {number}
local a: A? = {1, 2, 3}
@ -355,8 +329,6 @@ local b = a[1]
TEST_CASE_FIXTURE(Fixture, "optional_call_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = (number) -> number
local a: A? = function(a) return -a end
@ -369,8 +341,6 @@ local b = a(4)
TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = { x: number }
local a: A? = { x = 2 }
@ -392,8 +362,6 @@ a.x = 2
TEST_CASE_FIXTURE(Fixture, "optional_length_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = {number}
local a: A? = {1, 2, 3}
@ -406,9 +374,6 @@ local b = #a
TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true);
CheckResult result = check(R"(
type A = { x: number, y: number }
type B = { x: number, y: number }

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"
@ -264,4 +265,64 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure")
CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result));
}
TEST_CASE("tagging_tables")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar ttv{TableTypeVar{}};
CHECK(!Luau::hasTag(&ttv, "foo"));
Luau::attachTag(&ttv, "foo");
CHECK(Luau::hasTag(&ttv, "foo"));
}
TEST_CASE("tagging_classes")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
CHECK(!Luau::hasTag(&base, "foo"));
Luau::attachTag(&base, "foo");
CHECK(Luau::hasTag(&base, "foo"));
}
TEST_CASE("tagging_subclasses")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}};
CHECK(!Luau::hasTag(&base, "foo"));
CHECK(!Luau::hasTag(&derived, "foo"));
Luau::attachTag(&base, "foo");
CHECK(Luau::hasTag(&base, "foo"));
CHECK(Luau::hasTag(&derived, "foo"));
Luau::attachTag(&derived, "bar");
CHECK(!Luau::hasTag(&base, "bar"));
CHECK(Luau::hasTag(&derived, "bar"));
}
TEST_CASE("tagging_functions")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypePackVar empty{TypePack{}};
TypeVar ftv{FunctionTypeVar{&empty, &empty}};
CHECK(!Luau::hasTag(&ftv, "foo"));
Luau::attachTag(&ftv, "foo");
CHECK(Luau::hasTag(&ftv, "foo"));
}
TEST_CASE("tagging_props")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
Property prop{};
CHECK(!Luau::hasTag(prop, "foo"));
Luau::attachTag(prop, "foo");
CHECK(Luau::hasTag(prop, "foo"));
}
TEST_SUITE_END();

View file

@ -0,0 +1,15 @@
-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes
-- Generate an error (i.e. throw an exception) inside a tag method which is indirectly
-- called via pcall.
-- This test is meant to detect a regression in handling errors inside a tag method
local testtable = {}
setmetatable(testtable, { __index = function() error("Error") end })
pcall(function()
testtable.missingmethod()
end)
return('OK')

View file

@ -11,9 +11,9 @@ class VariantPrinter:
return type.name + " [" + str(value) + "]"
def match_printer(val):
type = val.type.strip_typedefs()
if type.name and type.name.startswith('Luau::Variant<'):
return VariantPrinter(val)
return None
type = val.type.strip_typedefs()
if type.name and type.name.startswith('Luau::Variant<'):
return VariantPrinter(val)
return None
gdb.pretty_printers.append(match_printer)

95
tools/tracegraph.py Normal file
View file

@ -0,0 +1,95 @@
#!/usr/bin/python
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
# Given a trace event file, this tool generates a flame graph based on the event scopes present in the file
# The result of analysis is a .svg file which can be viewed in a browser
import sys
import svg
import json
class Node(svg.Node):
def __init__(self):
svg.Node.__init__(self)
self.caption = ""
self.description = ""
self.ticks = 0
def text(self):
return self.caption
def title(self):
return self.caption
def details(self, root):
return "{} ({:,} usec, {:.1%}); self: {:,} usec".format(self.description, self.width, self.width / root.width, self.ticks)
with open(sys.argv[1]) as f:
dump = f.read()
root = Node()
# Finish the file
if not dump.endswith("]"):
dump += "{}]"
data = json.loads(dump)
stacks = {}
for l in data:
if len(l) == 0:
continue
# Track stack of each thread, but aggregate values together
tid = l["tid"]
if not tid in stacks:
stacks[tid] = []
stack = stacks[tid]
if l["ph"] == 'B':
stack.append(l)
elif l["ph"] == 'E':
node = root
for e in stack:
caption = e["name"]
description = ''
if "args" in e:
for arg in e["args"]:
if len(description) != 0:
description += ", "
description += "{}: {}".format(arg, e["args"][arg])
child = node.child(caption + description)
child.caption = caption
child.description = description
node = child
begin = stack[-1]
ticks = l["ts"] - begin["ts"]
rawticks = ticks
# Flame graph requires ticks without children duration
if "childts" in begin:
ticks -= begin["childts"]
node.ticks += int(ticks)
stack.pop()
if len(stack):
parent = stack[-1]
if "childts" in parent:
parent["childts"] += rawticks
else:
parent["childts"] = rawticks
svg.layout(root, lambda n: n.ticks)
svg.display(root, "Flame Graph", "hot", flip = True)