Sync to upstream/release/502 (#134)

Changes:
- Support for time tracing for analysis/compiler (not currently exposed
  through CLI)
- Support for type pack arguments in type aliases (#83)
- Basic support for require(path) in luau-analyze
- Add a lint warning for table.move with 0 index as part of
  TableOperation lint
- Remove last STL dependency from Luau.VM
- Minor VS2022 performance tuning

Co-authored-by: Rodactor <rodactor@roblox.com>
This commit is contained in:
Arseny Kapoulkine 2021-11-04 19:34:35 -07:00 committed by GitHub
parent adacdcdf4e
commit 49b0c59eec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
69 changed files with 4478 additions and 2962 deletions

View file

@ -1,7 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "TypeInfer.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
namespace Luau
{

View file

@ -120,6 +120,7 @@ struct IncorrectGenericParameterCount
Name name;
TypeFun typeFun;
size_t actualParameters;
size_t actualPackParameters;
bool operator==(const IncorrectGenericParameterCount& rhs) const;
};

View file

@ -25,51 +25,39 @@ struct SourceCode
Type type;
};
struct ModuleInfo
{
ModuleName name;
bool optional = false;
};
struct FileResolver
{
virtual ~FileResolver() {}
/** Fetch the source code associated with the provided ModuleName.
*
* FIXME: This requires a string copy!
*
* @returns The actual Lua code on success.
* @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error.
*/
virtual std::optional<SourceCode> readSource(const ModuleName& name) = 0;
/** Does the module exist?
*
* Saves a string copy over reading the source and throwing it away.
*/
virtual bool moduleExists(const ModuleName& name) const = 0;
virtual std::optional<ModuleInfo> resolveModule(const ModuleInfo* context, AstExpr* expr)
{
return std::nullopt;
}
virtual std::optional<ModuleName> fromAstFragment(AstExpr* expr) const = 0;
/** Given a valid module name and a string of arbitrary data, figure out the concatenation.
*/
virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0;
/** Goes "up" a level in the hierarchy that the ModuleName represents.
*
* For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last
* element of the path. Other ModuleName representations may have other ways of doing this.
*
* @returns The parent ModuleName, if one exists.
* @returns std::nullopt if there is no parent for this module name.
*/
virtual std::optional<ModuleName> getParentModuleName(const ModuleName& name) const = 0;
virtual std::optional<std::string> getHumanReadableModuleName_(const ModuleName& name) const
virtual std::string getHumanReadableModuleName(const ModuleName& name) const
{
return name;
}
virtual std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const = 0;
virtual std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const
{
return std::nullopt;
}
/** LanguageService only:
* std::optional<ModuleName> fromInstance(Instance* inst)
*/
// DEPRECATED APIS
// These are going to be removed with LuauNewRequireTracer
virtual bool moduleExists(const ModuleName& name) const = 0;
virtual std::optional<ModuleName> fromAstFragment(AstExpr* expr) const = 0;
virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0;
virtual std::optional<ModuleName> getParentModuleName(const ModuleName& name) const = 0;
};
struct NullFileResolver : FileResolver
@ -94,10 +82,6 @@ struct NullFileResolver : FileResolver
{
return std::nullopt;
}
std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const override
{
return std::nullopt;
}
};
} // namespace Luau

View file

@ -90,10 +90,12 @@ struct Module
TypeArena internalTypes;
std::vector<std::pair<Location, ScopePtr>> scopes; // never empty
std::unordered_map<const AstExpr*, TypeId> astTypes;
std::unordered_map<const AstExpr*, TypeId> astExpectedTypes;
std::unordered_map<const AstExpr*, TypeId> astOriginalCallTypes;
std::unordered_map<const AstExpr*, TypeId> astOverloadResolvedTypes;
DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astExpectedTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOriginalCallTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOverloadResolvedTypes{nullptr};
std::unordered_map<Name, TypeId> declaredGlobals;
ErrorVec errors;
Mode mode;

View file

@ -15,12 +15,6 @@ struct Module;
using ModulePtr = std::shared_ptr<Module>;
struct ModuleInfo
{
ModuleName name;
bool optional = false;
};
struct ModuleResolver
{
virtual ~ModuleResolver() {}

View file

@ -17,12 +17,11 @@ struct AstLocal;
struct RequireTraceResult
{
DenseHashMap<const AstExpr*, ModuleName> exprs{0};
DenseHashMap<const AstExpr*, bool> optional{0};
DenseHashMap<const AstExpr*, ModuleInfo> exprs{nullptr};
std::vector<std::pair<ModuleName, Location>> requires;
};
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName);
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName);
} // namespace Luau

View file

@ -0,0 +1,67 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Location.h"
#include "Luau/TypeVar.h"
#include <unordered_map>
#include <optional>
#include <memory>
namespace Luau
{
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct Binding
{
TypeId typeId;
Location location;
bool deprecated = false;
std::string deprecatedSuggestion;
std::optional<std::string> documentationSymbol;
};
struct Scope
{
explicit Scope(TypePackId returnType); // root scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root
std::unordered_map<Symbol, Binding> bindings;
TypePackId returnType;
bool breakOk = false;
std::optional<TypePackId> varargPack;
TypeLevel level;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
std::unordered_map<Name, TypeFun> privateTypeBindings;
std::unordered_map<Name, Location> typeAliasLocations;
std::unordered_map<Name, std::unordered_map<Name, TypeFun>> importedTypeBindings;
std::optional<TypeId> lookup(const Symbol& name);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
std::unordered_map<Name, TypePackId> privateTypePackBindings;
std::optional<TypePackId> lookupPack(const Name& name);
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true);
RefinementMap refinements;
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasTypeParameters;
std::unordered_map<Name, TypePackId> typeAliasTypePackParameters;
};
} // namespace Luau

View file

@ -52,8 +52,6 @@
// `T`, and the type of `f` are in the same SCC, which is why `f` gets
// replaced.
LUAU_FASTFLAG(DebugLuauTrackOwningArena)
namespace Luau
{
@ -188,20 +186,12 @@ struct Substitution : FindDirty
template<typename T>
TypeId addType(const T& tv)
{
TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv);
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
return currentModule->internalTypes.addType(tv);
}
template<typename T>
TypePackId addTypePack(const T& tp)
{
TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp);
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
return currentModule->internalTypes.addTypePack(TypePackVar{tp});
}
};

View file

@ -86,7 +86,10 @@ struct ApplyTypeFunction : Substitution
{
TypeLevel level;
bool encounteredForwardedType;
std::unordered_map<TypeId, TypeId> arguments;
std::unordered_map<TypeId, TypeId> typeArguments;
std::unordered_map<TypePackId, TypePackId> typePackArguments;
bool ignoreChildren(TypeId ty) override;
bool ignoreChildren(TypePackId tp) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
@ -328,7 +331,8 @@ private:
TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams, const Location& location);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams, const Location& location);
// Note: `scope` must be a fresh scope.
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(
@ -398,54 +402,6 @@ private:
int recursionCount = 0;
};
struct Binding
{
TypeId typeId;
Location location;
bool deprecated = false;
std::string deprecatedSuggestion;
std::optional<std::string> documentationSymbol;
};
struct Scope
{
explicit Scope(TypePackId returnType); // root scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root
std::unordered_map<Symbol, Binding> bindings;
TypePackId returnType;
bool breakOk = false;
std::optional<TypePackId> varargPack;
TypeLevel level;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
std::unordered_map<Name, TypeFun> privateTypeBindings;
std::unordered_map<Name, Location> typeAliasLocations;
std::unordered_map<Name, std::unordered_map<Name, TypeFun>> importedTypeBindings;
std::optional<TypeId> lookup(const Symbol& name);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
std::unordered_map<Name, TypePackId> privateTypePackBindings;
std::optional<TypePackId> lookupPack(const Name& name);
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true);
RefinementMap refinements;
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasParameters;
};
// Unit test hook
void setPrintLine(void (*pl)(const std::string& s));
void resetPrintLine();

View file

@ -117,7 +117,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs);
TypePackId follow(TypePackId tp);
size_t size(const TypePackId tp);
size_t size(TypePackId tp);
bool finite(TypePackId tp);
size_t size(const TypePack& tp);
std::optional<TypeId> first(TypePackId tp);

View file

@ -228,6 +228,7 @@ struct TableTypeVar
std::map<Name, Location> methodDefinitionLocations;
std::vector<TypeId> instantiatedTypeParams;
std::vector<TypePackId> instantiatedTypePackParams;
ModuleName definitionModuleName;
std::optional<TypeId> boundTo;
@ -284,8 +285,9 @@ struct ClassTypeVar
struct TypeFun
{
/// These should all be generic
// These should all be generic
std::vector<TypeId> typeParams;
std::vector<TypePackId> typePackParams;
/** The underlying type.
*
@ -293,6 +295,20 @@ struct TypeFun
* You must first use TypeChecker::instantiateTypeFun to turn it into a real type.
*/
TypeId type;
TypeFun() = default;
TypeFun(std::vector<TypeId> typeParams, TypeId type)
: typeParams(std::move(typeParams))
, type(type)
{
}
TypeFun(std::vector<TypeId> typeParams, std::vector<TypePackId> typePackParams, TypeId type)
: typeParams(std::move(typeParams))
, typePackParams(std::move(typePackParams))
, type(type)
{
}
};
// Anything! All static checking is off.
@ -524,8 +540,4 @@ UnionTypeVarIterator end(const UnionTypeVar* utv);
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
// TEMP: Clip this prototype with FFlag::LuauStringMetatable
std::optional<ExprResult<TypePackId>> magicFunctionFormat(
struct TypeChecker& typechecker, const std::shared_ptr<struct Scope>& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
} // namespace Luau

View file

@ -36,12 +36,17 @@ struct Unifier
Variance variance = Covariant;
CountMismatch::Context ctx = CountMismatch::Arg;
std::shared_ptr<UnifierCounters> counters;
UnifierCounters* counters;
UnifierCounters countersData;
std::shared_ptr<UnifierCounters> counters_DEPRECATED;
InternalErrorReporter* iceHandler;
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters = nullptr);
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
UnifierCounters* counters = nullptr);
// Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId superTy, TypeId subTy);
@ -58,11 +63,13 @@ private:
void tryUnifyPrimitives(TypeId superTy, TypeId subTy);
void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);
void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);
void tryUnifyFreeTable(TypeId free, TypeId other);
void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection);
void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed);
void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed);
void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer);
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId,TypeId> seen = {});
public:
void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false);
@ -80,9 +87,9 @@ private:
public:
// Report an "infinite type error" if the type "needle" already occurs within "haystack"
void occursCheck(TypeId needle, TypeId haystack);
void occursCheck(std::unordered_set<TypeId>& seen, TypeId needle, TypeId haystack);
void occursCheck(std::unordered_set<TypeId>& seen_DEPRECATED, DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
void occursCheck(TypePackId needle, TypePackId haystack);
void occursCheck(std::unordered_set<TypePackId>& seen, TypePackId needle, TypePackId haystack);
void occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
Unifier makeChildUnifier();
@ -93,6 +100,9 @@ private:
[[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message);
DenseHashSet<TypeId> tempSeenTy{nullptr};
DenseHashSet<TypePackId> tempSeenTp{nullptr};
};
} // namespace Luau

View file

@ -2,6 +2,7 @@
#include "Luau/AstQuery.h"
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"
#include "Luau/ToString.h"
@ -143,8 +144,8 @@ std::optional<TypeId> findTypeAtPosition(const Module& module, const SourceModul
{
if (auto expr = findExprAtPosition(sourceModule, pos))
{
if (auto it = module.astTypes.find(expr); it != module.astTypes.end())
return it->second;
if (auto it = module.astTypes.find(expr))
return *it;
}
return std::nullopt;
@ -154,8 +155,8 @@ std::optional<TypeId> findExpectedTypeAtPosition(const Module& module, const Sou
{
if (auto expr = findExprAtPosition(sourceModule, pos))
{
if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end())
return it->second;
if (auto it = module.astExpectedTypes.find(expr))
return *it;
}
return std::nullopt;
@ -322,9 +323,9 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end())
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = it->second;
matchingOverload = *it;
}
}
@ -345,9 +346,9 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
{
if (AstExprIndexName* indexName = targetExpr->as<AstExprIndexName>())
{
if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end())
if (auto it = module.astTypes.find(indexName->expr))
{
TypeId parentTy = follow(it->second);
TypeId parentTy = follow(*it);
if (const TableTypeVar* ttv = get<TableTypeVar>(parentTy))
{
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())

View file

@ -210,10 +210,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
return TypeCorrectKind::None;
auto it = module.astExpectedTypes.find(expr);
if (it == module.astExpectedTypes.end())
if (!it)
return TypeCorrectKind::None;
TypeId expectedType = follow(it->second);
TypeId expectedType = follow(*it);
if (canUnify(expectedType, ty))
return TypeCorrectKind::Correct;
@ -682,10 +682,10 @@ static std::optional<bool> functionIsExpectedAt(const Module& module, AstNode* n
return std::nullopt;
auto it = module.astExpectedTypes.find(expr);
if (it == module.astExpectedTypes.end())
if (!it)
return std::nullopt;
TypeId expectedType = follow(it->second);
TypeId expectedType = follow(*it);
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(expectedType))
return true;
@ -784,9 +784,9 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
if (AstExprCall* exprCall = expr->as<AstExprCall>())
{
if (auto it = module.astTypes.find(exprCall->func); it != module.astTypes.end())
if (auto it = module.astTypes.find(exprCall->func))
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(it->second)))
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*it)))
{
if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos))
inferredType = *ty;
@ -798,8 +798,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
if (tailPos != 0)
break;
if (auto it = module.astTypes.find(expr); it != module.astTypes.end())
inferredType = it->second;
if (auto it = module.astTypes.find(expr))
inferredType = *it;
}
if (inferredType)
@ -815,10 +815,10 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* {
auto it = module.astExpectedTypes.find(expr);
if (it == module.astExpectedTypes.end())
if (!it)
return nullptr;
TypeId ty = follow(it->second);
TypeId ty = follow(*it);
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
return ftv;
@ -1129,9 +1129,8 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul
if (node->is<AstExprIndexName>())
{
auto it = module.astTypes.find(node->asExpr());
if (it != module.astTypes.end())
autocompleteProps(module, typeArena, it->second, PropIndexType::Point, ancestry, result);
if (auto it = module.astTypes.find(node->asExpr()))
autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result);
}
else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result))
return;
@ -1203,13 +1202,13 @@ static std::optional<const ClassTypeVar*> getMethodContainingClass(const ModuleP
return std::nullopt;
}
auto parentIter = module->astTypes.find(parentExpr);
if (parentIter == module->astTypes.end())
auto parentIt = module->astTypes.find(parentExpr);
if (!parentIt)
{
return std::nullopt;
}
Luau::TypeId parentType = Luau::follow(parentIter->second);
Luau::TypeId parentType = Luau::follow(*parentIt);
if (auto parentClass = Luau::get<ClassTypeVar>(parentType))
{
@ -1250,8 +1249,8 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
return std::nullopt;
}
auto iter = module->astTypes.find(candidate->func);
if (iter == module->astTypes.end())
auto it = module->astTypes.find(candidate->func);
if (!it)
{
return std::nullopt;
}
@ -1267,7 +1266,7 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
return std::nullopt;
};
auto followedId = Luau::follow(iter->second);
auto followedId = Luau::follow(*it);
if (auto functionType = Luau::get<FunctionTypeVar>(followedId))
{
return performCallback(functionType);
@ -1316,10 +1315,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
if (auto indexName = node->as<AstExprIndexName>())
{
auto it = module->astTypes.find(indexName->expr);
if (it == module->astTypes.end())
if (!it)
return {};
TypeId ty = follow(it->second);
TypeId ty = follow(*it);
PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point;
if (isString(ty))
@ -1447,9 +1446,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
// If item doesn't have a key, maybe the value is actually the key
if (key ? key == node : node->is<AstExprGlobal>() && value == node)
{
if (auto it = module->astExpectedTypes.find(exprTable); it != module->astExpectedTypes.end())
if (auto it = module->astExpectedTypes.find(exprTable))
{
auto result = autocompleteProps(*module, typeArena, it->second, PropIndexType::Key, finder.ancestry);
auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry);
// Remove keys that are already completed
for (const auto& item : exprTable->items)
@ -1485,9 +1484,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
{
if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as<AstExprIndexExpr>())
{
if (auto it = module->astTypes.find(idxExpr->expr); it != module->astTypes.end())
if (auto it = module->astTypes.find(idxExpr->expr))
{
return {autocompleteProps(*module, typeArena, follow(it->second), PropIndexType::Point, finder.ancestry), finder.ancestry};
return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry};
}
}
}

View file

@ -11,7 +11,7 @@
LUAU_FASTFLAG(LuauParseGenericFunctions)
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAG(LuauStringMetatable)
LUAU_FASTFLAG(LuauNewRequireTrace)
/** FIXME: Many of these type definitions are not quite completely accurate.
*
@ -218,7 +218,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
TypePackId anyTypePack = typeChecker.anyTypePack;
TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}});
TypePackId stringVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{stringType}});
TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList});
TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{
@ -255,8 +254,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
TypeId genericV = arena.addType(GenericTypeVar{"V"});
TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level});
if (FFlag::LuauStringMetatable)
{
std::optional<TypeId> stringMetatableTy = getMetatable(singletonTypes.stringType);
LUAU_ASSERT(stringMetatableTy);
const TableTypeVar* stringMetatableTable = get<TableTypeVar>(follow(*stringMetatableTy));
@ -265,75 +262,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
auto it = stringMetatableTable->props.find("__index");
LUAU_ASSERT(it != stringMetatableTable->props.end());
TypeId stringLib = it->second.type;
addGlobalBinding(typeChecker, "string", stringLib, "@luau");
}
addGlobalBinding(typeChecker, "string", it->second.type, "@luau");
if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions)
if (!FFlag::LuauParseGenericFunctions || !FFlag::LuauGenericFunctions)
{
if (!FFlag::LuauStringMetatable)
{
TypeId stringLibTy = getGlobalBinding(typeChecker, "string");
TableTypeVar* stringLib = getMutable<TableTypeVar>(stringLibTy);
TypeId replArgType = makeUnion(
arena, {stringType,
arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)),
makeFunction(arena, std::nullopt, {stringType}, {stringType})});
TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType});
stringLib->props["gsub"] = makeProperty(gsubFunc, "@luau/global/string.gsub");
}
}
else
{
if (!FFlag::LuauStringMetatable)
{
TypeId stringToStringType = makeFunction(arena, std::nullopt, {stringType}, {stringType});
TypeId gmatchFunc = makeFunction(arena, stringType, {stringType}, {arena.addType(FunctionTypeVar{emptyPack, stringVariadicList})});
TypeId replArgType = makeUnion(
arena, {stringType,
arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)),
makeFunction(arena, std::nullopt, {stringType}, {stringType})});
TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType});
TypeId formatFn = arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack});
TableTypeVar::Props stringLib = {
// FIXME string.byte "can" return a pack of numbers, but only if 2nd or 3rd arguments were supplied
{"byte", {makeFunction(arena, stringType, {optionalNumber, optionalNumber}, {optionalNumber})}},
// FIXME char takes a variadic pack of numbers
{"char", {makeFunction(arena, std::nullopt, {numberType, optionalNumber, optionalNumber, optionalNumber}, {stringType})}},
{"find", {makeFunction(arena, stringType, {stringType, optionalNumber, optionalBoolean}, {optionalNumber, optionalNumber})}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(arena, stringType, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {makeFunction(arena, stringType, {stringType, optionalNumber}, {optionalString})}},
{"rep", {makeFunction(arena, stringType, {numberType}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(arena, stringType, {numberType, optionalNumber}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(arena, stringType, {stringType, optionalString},
{arena.addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, typeChecker.globalScope->level})})}},
{"pack", {arena.addType(FunctionTypeVar{
arena.addTypePack(TypePack{{stringType}, anyTypePack}),
oneStringPack,
})}},
{"packsize", {makeFunction(arena, stringType, {}, {numberType})}},
{"unpack", {arena.addType(FunctionTypeVar{
arena.addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
anyTypePack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
addGlobalBinding(typeChecker, "string",
arena.addType(TableTypeVar{stringLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau");
}
TableTypeVar::Props debugLib{
{"info", {makeIntersection(arena,
{
@ -601,9 +533,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
auto tableLib = getMutable<TableTypeVar>(getGlobalBinding(typeChecker, "table"));
attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack);
auto stringLib = getMutable<TableTypeVar>(getGlobalBinding(typeChecker, "string"));
attachMagicFunction(stringLib->props["format"].type, magicFunctionFormat);
attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire);
}
@ -791,11 +720,11 @@ static std::optional<ExprResult<TypePackId>> magicFunctionRequire(
return std::nullopt;
}
AstExpr* require = expr.args.data[0];
if (!checkRequirePath(typechecker, require))
if (!checkRequirePath(typechecker, expr.args.data[0]))
return std::nullopt;
const AstExpr* require = FFlag::LuauNewRequireTrace ? &expr : expr.args.data[0];
if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require))
return ExprResult<TypePackId>{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})};

View file

@ -206,27 +206,6 @@ std::string getBuiltinDefinitionSource()
graphemes: (string, number?, number?) -> (() -> (number, number)),
}
declare string: {
byte: (string, number?, number?) -> ...number,
char: (number, ...number) -> string,
find: (string, string, number?, boolean?) -> (number?, number?),
-- `string.format` has a magic function attached that will provide more type information for literal format strings.
format: <A...>(string, A...) -> string,
gmatch: (string, string) -> () -> (...string),
-- gsub is defined in C++ because we don't have syntax for describing a generic table.
len: (string) -> number,
lower: (string) -> string,
match: (string, string, number?) -> string?,
rep: (string, number) -> string,
reverse: (string) -> string,
sub: (string, number, number?) -> string,
upper: (string) -> string,
split: (string, string, string?) -> {string},
pack: <A...>(string, A...) -> string,
packsize: (string) -> number,
unpack: <R...>(string, string, number?) -> R...,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)";

View file

@ -7,9 +7,9 @@
#include <stdexcept>
LUAU_FASTFLAG(LuauFasterStringifier)
LUAU_FASTFLAG(LuauTypeAliasPacks)
static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false)
static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false)
{
std::string s = "expects " + std::to_string(expectedCount) + " ";
@ -41,6 +41,52 @@ static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCo
return s;
}
static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
{
std::string s;
if (FFlag::LuauTypeAliasPacks)
{
s = "expects ";
if (isVariadic)
s += "at least ";
s += std::to_string(expectedCount) + " ";
}
else
{
s = "expects " + std::to_string(expectedCount) + " ";
}
if (argPrefix)
s += std::string(argPrefix) + " ";
s += "argument";
if (expectedCount != 1)
s += "s";
s += ", but ";
if (actualCount == 0)
{
s += "none";
}
else
{
if (actualCount < expectedCount)
s += "only ";
s += std::to_string(actualCount);
}
s += (actualCount == 1) ? " is" : " are";
s += " specified";
return s;
}
namespace Luau
{
@ -127,7 +173,10 @@ struct ErrorConverter
return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " +
std::to_string(e.actual) + " are required here";
case CountMismatch::Arg:
if (FFlag::LuauTypeAliasPacks)
return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual);
else
return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual);
}
LUAU_ASSERT(!"Unknown context");
@ -159,13 +208,16 @@ struct ErrorConverter
std::string operator()(const Luau::UnknownRequire& e) const
{
if (e.modulePath.empty())
return "Unknown require: unsupported path";
else
return "Unknown require: " + e.modulePath;
}
std::string operator()(const Luau::IncorrectGenericParameterCount& e) const
{
std::string name = e.name;
if (!e.typeFun.typeParams.empty())
if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty()))
{
name += "<";
bool first = true;
@ -178,10 +230,37 @@ struct ErrorConverter
name += toString(t);
}
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId t : e.typeFun.typePackParams)
{
if (first)
first = false;
else
name += ", ";
name += toString(t);
}
}
name += ">";
}
return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true);
if (FFlag::LuauTypeAliasPacks)
{
if (e.typeFun.typeParams.size() != e.actualParameters)
return "Generic type '" + name + "' " +
wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty());
return "Generic type '" + name + "' " +
wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false);
}
else
{
return "Generic type '" + name + "' " +
wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true);
}
}
std::string operator()(const Luau::SyntaxError& e) const
@ -470,9 +549,26 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC
if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size())
return false;
if (FFlag::LuauTypeAliasPacks)
{
if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size())
return false;
}
for (size_t i = 0; i < typeFun.typeParams.size(); ++i)
{
if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i])
return false;
}
if (FFlag::LuauTypeAliasPacks)
{
for (size_t i = 0; i < typeFun.typePackParams.size(); ++i)
{
if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i])
return false;
}
}
return true;
}

View file

@ -1,9 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Frontend.h"
#include "Luau/Common.h"
#include "Luau/Config.h"
#include "Luau/FileResolver.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/TypeInfer.h"
#include "Luau/Variant.h"
#include "Luau/Common.h"
@ -19,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false)
LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false)
LUAU_FASTFLAG(LuauTraceRequireLookupChild)
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
LUAU_FASTFLAG(LuauNewRequireTrace)
namespace Luau
{
@ -69,6 +73,8 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName)
LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName)
{
LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend");
Luau::Allocator allocator;
Luau::AstNameTable names(allocator);
@ -350,6 +356,9 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend)
CheckResult Frontend::check(const ModuleName& name)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
CheckResult checkResult;
auto it = sourceNodes.find(name);
@ -479,6 +488,9 @@ CheckResult Frontend::check(const ModuleName& name)
bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root)
{
LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend");
LUAU_TIMETRACE_ARGUMENT("root", root.c_str());
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
enum Mark
{
@ -597,6 +609,9 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config
LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
CheckResult checkResult;
auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name);
@ -608,6 +623,8 @@ LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOption
std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view source, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend");
const Config& config = configResolver->getConfig("");
SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions);
@ -627,6 +644,9 @@ std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view sour
CheckResult Frontend::check(const SourceModule& module)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
const Config& config = configResolver->getConfig(module.name);
Mode mode = module.mode.value_or(config.mode);
@ -648,6 +668,9 @@ CheckResult Frontend::check(const SourceModule& module)
LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
const Config& config = configResolver->getConfig(module.name);
LintOptions options = enabledLintWarnings.value_or(config.enabledLint);
@ -746,6 +769,9 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons
// Read AST into sourceModules if necessary. Trace require()s. Report parse errors.
std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name)
{
LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
auto it = sourceNodes.find(name);
if (it != sourceNodes.end() && !it->second.dirty)
{
@ -815,6 +841,9 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& check
*/
SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions)
{
LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
SourceModule sourceModule;
double timestamp = getTimestamp();
@ -864,20 +893,11 @@ std::optional<ModuleInfo> FrontendModuleResolver::resolveModuleInfo(const Module
const auto& exprs = it->second.exprs;
const ModuleName* relativeName = exprs.find(&pathExpr);
if (!relativeName || relativeName->empty())
const ModuleInfo* info = exprs.find(&pathExpr);
if (!info || (!FFlag::LuauNewRequireTrace && info->name.empty()))
return std::nullopt;
if (FFlag::LuauTraceRequireLookupChild)
{
const bool* optional = it->second.optional.find(&pathExpr);
return {{*relativeName, optional ? *optional : false}};
}
else
{
return {{*relativeName, false}};
}
return *info;
}
const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const
@ -891,12 +911,15 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName)
bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const
{
if (FFlag::LuauNewRequireTrace)
return frontend->sourceNodes.count(moduleName) != 0;
else
return frontend->fileResolver->moduleExists(moduleName);
}
std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const
{
return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName);
return frontend->fileResolver->getHumanReadableModuleName(moduleName);
}
ScopePtr Frontend::addEnvironment(const std::string& environmentName)

View file

@ -2,6 +2,8 @@
#include "Luau/IostreamHelpers.h"
#include "Luau/ToString.h"
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -92,7 +94,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo
{
stream << "IncorrectGenericParameterCount { name = " << error.name;
if (!error.typeFun.typeParams.empty())
if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty()))
{
stream << "<";
bool first = true;
@ -105,6 +107,20 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo
stream << toString(t);
}
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId t : error.typeFun.typePackParams)
{
if (first)
first = false;
else
stream << ", ";
stream << toString(t);
}
}
stream << ">";
}

View file

@ -3,6 +3,9 @@
#include "Luau/Ast.h"
#include "Luau/StringUtils.h"
#include "Luau/Common.h"
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -612,6 +615,12 @@ struct AstJsonEncoder : public AstVisitor
writeNode(node, "AstStatTypeAlias", [&]() {
PROP(name);
PROP(generics);
if (FFlag::LuauTypeAliasPacks)
{
PROP(genericPacks);
}
PROP(type);
PROP(exported);
});
@ -664,13 +673,21 @@ struct AstJsonEncoder : public AstVisitor
});
}
void write(struct AstTypeOrPack node)
{
if (node.type)
write(node.type);
else
write(node.typePack);
}
void write(class AstTypeReference* node)
{
writeNode(node, "AstTypeReference", [&]() {
if (node->hasPrefix)
PROP(prefix);
PROP(name);
PROP(generics);
PROP(parameters);
});
}
@ -734,6 +751,13 @@ struct AstJsonEncoder : public AstVisitor
});
}
void write(class AstTypePackExplicit* node)
{
writeNode(node, "AstTypePackExplicit", [&]() {
PROP(typeList);
});
}
void write(class AstTypePackVariadic* node)
{
writeNode(node, "AstTypePackVariadic", [&]() {
@ -1018,6 +1042,12 @@ struct AstJsonEncoder : public AstVisitor
return false;
}
bool visit(class AstTypePackExplicit* node) override
{
write(node);
return false;
}
bool visit(class AstTypePackVariadic* node) override
{
write(node);

View file

@ -3,6 +3,7 @@
#include "Luau/AstQuery.h"
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/StringUtils.h"
#include "Luau/Common.h"
@ -12,6 +13,7 @@
#include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false)
LUAU_FASTFLAGVARIABLE(LuauLinterTableMoveZero, false)
namespace Luau
{
@ -85,10 +87,10 @@ struct LintContext
return std::nullopt;
auto it = module->astTypes.find(expr);
if (it == module->astTypes.end())
if (!it)
return std::nullopt;
return it->second;
return *it;
}
};
@ -2144,6 +2146,19 @@ private:
"wrap it in parentheses to silence");
}
if (FFlag::LuauLinterTableMoveZero && func->index == "move" && node->args.size >= 4)
{
// table.move(t, 0, _, _)
if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
// table.move(t, _, _, 0)
else if (isConstant(args[3], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
}
return true;
}

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -13,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false)
LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false)
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -188,7 +190,7 @@ struct TypePackCloner
template<typename T>
void defaultClone(const T& t)
{
TypePackId cloned = dest.typePacks.allocate(t);
TypePackId cloned = dest.addTypePack(TypePackVar{t});
seenTypePacks[typePackId] = cloned;
}
@ -197,7 +199,7 @@ struct TypePackCloner
if (encounteredFreeType)
*encounteredFreeType = true;
seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}});
seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}});
}
void operator()(const Unifiable::Generic& t)
@ -219,13 +221,13 @@ struct TypePackCloner
void operator()(const VariadicTypePack& t)
{
TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)});
TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}});
seenTypePacks[typePackId] = cloned;
}
void operator()(const TypePack& t)
{
TypePackId cloned = dest.typePacks.allocate(TypePack{});
TypePackId cloned = dest.addTypePack(TypePack{});
TypePack* destTp = getMutable<TypePack>(cloned);
LUAU_ASSERT(destTp != nullptr);
seenTypePacks[typePackId] = cloned;
@ -241,7 +243,7 @@ struct TypePackCloner
template<typename T>
void TypeCloner::defaultClone(const T& t)
{
TypeId cloned = dest.typeVars.allocate(t);
TypeId cloned = dest.addType(t);
seenTypes[typeId] = cloned;
}
@ -250,7 +252,7 @@ void TypeCloner::operator()(const Unifiable::Free& t)
if (encounteredFreeType)
*encounteredFreeType = true;
seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{});
seenTypes[typeId] = dest.addType(ErrorTypeVar{});
}
void TypeCloner::operator()(const Unifiable::Generic& t)
@ -275,7 +277,7 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t)
void TypeCloner::operator()(const FunctionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result);
LUAU_ASSERT(ftv != nullptr);
@ -297,7 +299,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t)
void TypeCloner::operator()(const TableTypeVar& t)
{
TypeId result = dest.typeVars.allocate(TableTypeVar{});
TypeId result = dest.addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
LUAU_ASSERT(ttv != nullptr);
@ -323,7 +325,13 @@ void TypeCloner::operator()(const TableTypeVar& t)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType));
arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId& arg : ttv->instantiatedTypePackParams)
arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
if (ttv->state == TableState::Free)
{
@ -343,7 +351,7 @@ void TypeCloner::operator()(const TableTypeVar& t)
void TypeCloner::operator()(const MetatableTypeVar& t)
{
TypeId result = dest.typeVars.allocate(MetatableTypeVar{});
TypeId result = dest.addType(MetatableTypeVar{});
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(result);
seenTypes[typeId] = result;
@ -353,7 +361,7 @@ void TypeCloner::operator()(const MetatableTypeVar& t)
void TypeCloner::operator()(const ClassTypeVar& t)
{
TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData});
TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData});
ClassTypeVar* ctv = getMutable<ClassTypeVar>(result);
seenTypes[typeId] = result;
@ -378,7 +386,7 @@ void TypeCloner::operator()(const AnyTypeVar& t)
void TypeCloner::operator()(const UnionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(UnionTypeVar{});
TypeId result = dest.addType(UnionTypeVar{});
seenTypes[typeId] = result;
UnionTypeVar* option = getMutable<UnionTypeVar>(result);
@ -390,7 +398,7 @@ void TypeCloner::operator()(const UnionTypeVar& t)
void TypeCloner::operator()(const IntersectionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(IntersectionTypeVar{});
TypeId result = dest.addType(IntersectionTypeVar{});
seenTypes[typeId] = result;
IntersectionTypeVar* option = getMutable<IntersectionTypeVar>(result);
@ -451,8 +459,14 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType)
{
TypeFun result;
for (TypeId param : typeFun.typeParams)
result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType));
for (TypeId ty : typeFun.typeParams)
result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType));
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId tp : typeFun.typePackParams)
result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType));
}
result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType);

View file

@ -5,6 +5,7 @@
#include "Luau/Module.h"
LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false)
LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace, false)
namespace Luau
{
@ -12,17 +13,18 @@ namespace Luau
namespace
{
struct RequireTracer : AstVisitor
struct RequireTracerOld : AstVisitor
{
explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName)
explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName)
: fileResolver(fileResolver)
, currentModuleName(std::move(currentModuleName))
, currentModuleName(currentModuleName)
{
LUAU_ASSERT(!FFlag::LuauNewRequireTrace);
}
FileResolver* const fileResolver;
ModuleName currentModuleName;
DenseHashMap<AstLocal*, ModuleName> locals{0};
DenseHashMap<AstLocal*, ModuleName> locals{nullptr};
RequireTraceResult result;
std::optional<ModuleName> fromAstFragment(AstExpr* expr)
@ -50,9 +52,9 @@ struct RequireTracer : AstVisitor
AstExpr* expr = stat->values.data[i];
expr->visit(this);
const ModuleName* name = result.exprs.find(expr);
if (name)
locals[local] = *name;
const ModuleInfo* info = result.exprs.find(expr);
if (info)
locals[local] = info->name;
}
}
@ -63,7 +65,7 @@ struct RequireTracer : AstVisitor
{
std::optional<ModuleName> name = fromAstFragment(global);
if (name)
result.exprs[global] = *name;
result.exprs[global] = {*name};
return false;
}
@ -72,7 +74,7 @@ struct RequireTracer : AstVisitor
{
const ModuleName* name = locals.find(local->local);
if (name)
result.exprs[local] = *name;
result.exprs[local] = {*name};
return false;
}
@ -81,16 +83,16 @@ struct RequireTracer : AstVisitor
{
indexName->expr->visit(this);
const ModuleName* name = result.exprs.find(indexName->expr);
if (name)
const ModuleInfo* info = result.exprs.find(indexName->expr);
if (info)
{
if (indexName->index == "parent" || indexName->index == "Parent")
{
if (auto parent = fileResolver->getParentModuleName(*name))
result.exprs[indexName] = *parent;
if (auto parent = fileResolver->getParentModuleName(info->name))
result.exprs[indexName] = {*parent};
}
else
result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value);
result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)};
}
return false;
@ -100,11 +102,11 @@ struct RequireTracer : AstVisitor
{
indexExpr->expr->visit(this);
const ModuleName* name = result.exprs.find(indexExpr->expr);
const ModuleInfo* info = result.exprs.find(indexExpr->expr);
const AstExprConstantString* str = indexExpr->index->as<AstExprConstantString>();
if (name && str)
if (info && str)
{
result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size));
result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))};
}
indexExpr->index->visit(this);
@ -129,8 +131,8 @@ struct RequireTracer : AstVisitor
AstExprGlobal* globalName = call->func->as<AstExprGlobal>();
if (globalName && globalName->name == "require" && call->args.size >= 1)
{
if (const ModuleName* moduleName = result.exprs.find(call->args.data[0]))
result.requires.push_back({*moduleName, call->location});
if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0]))
result.requires.push_back({moduleInfo->name, call->location});
return false;
}
@ -143,8 +145,8 @@ struct RequireTracer : AstVisitor
if (FFlag::LuauTraceRequireLookupChild && !rootName)
{
if (const ModuleName* moduleName = result.exprs.find(indexName->expr))
rootName = *moduleName;
if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr))
rootName = moduleInfo->name;
}
if (!rootName)
@ -167,24 +169,183 @@ struct RequireTracer : AstVisitor
if (v.end() != std::find(v.begin(), v.end(), '/'))
return false;
result.exprs[call] = fileResolver->concat(*rootName, v);
result.exprs[call] = {fileResolver->concat(*rootName, v)};
// 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime
// If we fail to find such module, we will not report an UnknownRequire error
if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild")
result.optional[call] = true;
result.exprs[call].optional = true;
return false;
}
};
struct RequireTracer : AstVisitor
{
RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName)
: result(result)
, fileResolver(fileResolver)
, currentModuleName(currentModuleName)
, locals(nullptr)
{
LUAU_ASSERT(FFlag::LuauNewRequireTrace);
}
bool visit(AstExprTypeAssertion* expr) override
{
// suppress `require() :: any`
return false;
}
bool visit(AstExprCall* expr) override
{
AstExprGlobal* global = expr->func->as<AstExprGlobal>();
if (global && global->name == "require" && expr->args.size >= 1)
requires.push_back(expr);
return true;
}
bool visit(AstStatLocal* stat) override
{
for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i)
{
AstLocal* local = stat->vars.data[i];
AstExpr* expr = stat->values.data[i];
// track initializing expression to be able to trace modules through locals
locals[local] = expr;
}
return true;
}
bool visit(AstStatAssign* stat) override
{
for (size_t i = 0; i < stat->vars.size; ++i)
{
// locals that are assigned don't have a known expression
if (AstExprLocal* expr = stat->vars.data[i]->as<AstExprLocal>())
locals[expr->local] = nullptr;
}
return true;
}
bool visit(AstType* node) override
{
// allow resolving require inside `typeof` annotations
return true;
}
AstExpr* getDependent(AstExpr* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local];
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return expr->expr;
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return expr->expr;
else if (AstExprCall* expr = node->as<AstExprCall>(); expr && expr->self)
return expr->func->as<AstExprIndexName>()->expr;
else
return nullptr;
}
void process()
{
ModuleInfo moduleContext{currentModuleName};
// seed worklist with require arguments
work.reserve(requires.size());
for (AstExprCall* require: requires)
work.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work.size(); ++i)
if (AstExpr* dep = getDependent(work[i]))
work.push_back(dep);
// resolve all expressions to a module info
for (size_t i = work.size(); i > 0; --i)
{
AstExpr* expr = work[i - 1];
// when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr))
continue;
std::optional<ModuleInfo> info;
if (AstExpr* dep = getDependent(expr))
{
const ModuleInfo* context = result.exprs.find(dep);
// locals just inherit their dependent context, no resolution required
if (expr->is<AstExprLocal>())
info = context ? std::optional<ModuleInfo>(*context) : std::nullopt;
else
info = fileResolver->resolveModule(context, expr);
}
else
{
info = fileResolver->resolveModule(&moduleContext, expr);
}
if (info)
result.exprs[expr] = std::move(*info);
}
// resolve all requires according to their argument
result.requires.reserve(requires.size());
for (AstExprCall* require : requires)
{
AstExpr* arg = require->args.data[0];
if (const ModuleInfo* info = result.exprs.find(arg))
{
result.requires.push_back({info->name, require->location});
ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info!
result.exprs[require] = std::move(infoCopy);
}
else
{
result.exprs[require] = {}; // mark require as unresolved
}
}
}
RequireTraceResult& result;
FileResolver* fileResolver;
ModuleName currentModuleName;
DenseHashMap<AstLocal*, AstExpr*> locals;
std::vector<AstExpr*> work;
std::vector<AstExprCall*> requires;
};
} // anonymous namespace
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName)
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName)
{
RequireTracer tracer{fileResolver, std::move(currentModuleName)};
if (FFlag::LuauNewRequireTrace)
{
RequireTraceResult result;
RequireTracer tracer{result, fileResolver, currentModuleName};
root->visit(&tracer);
tracer.process();
return result;
}
else
{
RequireTracerOld tracer{fileResolver, currentModuleName};
root->visit(&tracer);
return tracer.result;
}
}
} // namespace Luau

123
Analysis/src/Scope.cpp Normal file
View file

@ -0,0 +1,123 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Scope.h"
namespace Luau
{
Scope::Scope(TypePackId returnType)
: parent(nullptr)
, returnType(returnType)
, level(TypeLevel())
{
}
Scope::Scope(const ScopePtr& parent, int subLevel)
: parent(parent)
, returnType(parent->returnType)
, level(parent->level.incr())
{
level.subLevel = subLevel;
}
std::optional<TypeId> Scope::lookup(const Symbol& name)
{
Scope* scope = this;
while (scope)
{
auto it = scope->bindings.find(name);
if (it != scope->bindings.end())
return it->second.typeId;
scope = scope->parent.get();
}
return std::nullopt;
}
std::optional<TypeFun> Scope::lookupType(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->exportedTypeBindings.find(name);
if (it != scope->exportedTypeBindings.end())
return it->second;
it = scope->privateTypeBindings.find(name);
if (it != scope->privateTypeBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<TypeFun> Scope::lookupImportedType(const Name& moduleAlias, const Name& name)
{
const Scope* scope = this;
while (scope)
{
auto it = scope->importedTypeBindings.find(moduleAlias);
if (it == scope->importedTypeBindings.end())
{
scope = scope->parent.get();
continue;
}
auto it2 = it->second.find(name);
if (it2 == it->second.end())
{
scope = scope->parent.get();
continue;
}
return it2->second;
}
return std::nullopt;
}
std::optional<TypePackId> Scope::lookupPack(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->privateTypePackBindings.find(name);
if (it != scope->privateTypePackBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain)
{
Scope* scope = this;
while (scope)
{
for (const auto& [n, binding] : scope->bindings)
{
if (n.local && n.local->name == name.c_str())
return binding;
else if (n.global.value && n.global == name.c_str())
return binding;
}
scope = scope->parent.get();
if (!traverseScopeChain)
break;
}
return std::nullopt;
}
} // namespace Luau

View file

@ -6,9 +6,11 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0)
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000)
LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false)
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -35,8 +37,15 @@ void Tarjan::visitChildren(TypeId ty, int index)
visitChild(ttv->indexer->indexType);
visitChild(ttv->indexer->indexResultType);
}
for (TypeId itp : ttv->instantiatedTypeParams)
visitChild(itp);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp);
}
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
@ -332,8 +341,10 @@ std::optional<TypeId> Substitution::substitute(TypeId ty)
return std::nullopt;
for (auto [oldTy, newTy] : newTypes)
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy))
replaceChildren(newTy);
for (auto [oldTp, newTp] : newPacks)
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp))
replaceChildren(newTp);
TypeId newTy = replace(ty);
return newTy;
@ -350,8 +361,10 @@ std::optional<TypePackId> Substitution::substitute(TypePackId tp)
return std::nullopt;
for (auto [oldTy, newTy] : newTypes)
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy))
replaceChildren(newTy);
for (auto [oldTp, newTp] : newPacks)
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp))
replaceChildren(newTp);
TypePackId newTp = replace(tp);
return newTp;
@ -382,6 +395,10 @@ TypeId Substitution::clone(TypeId ty)
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.instantiatedTypeParams = ttv->instantiatedTypeParams;
if (FFlag::LuauTypeAliasPacks)
clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams;
if (FFlag::LuauSecondTypecheckKnowsTheDataModel)
clone.tags = ttv->tags;
result = addType(std::move(clone));
@ -487,8 +504,15 @@ void Substitution::replaceChildren(TypeId ty)
ttv->indexer->indexType = replace(ttv->indexer->indexType);
ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType);
}
for (TypeId& itp : ttv->instantiatedTypeParams)
itp = replace(itp);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId& itp : ttv->instantiatedTypePackParams)
itp = replace(itp);
}
}
else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(ty))
{

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ToString.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -9,10 +10,10 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauToStringFollowsBoundTo)
LUAU_FASTFLAG(LuauExtraNilRecovery)
LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions)
LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -59,6 +60,13 @@ struct FindCyclicTypes
{
for (TypeId itp : ttv.instantiatedTypeParams)
visitTypeVar(itp, *this, seen);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId itp : ttv.instantiatedTypePackParams)
visitTypeVar(itp, *this, seen);
}
return exhaustive;
}
@ -258,14 +266,50 @@ struct TypeVarStringifier
void stringify(TypePackId tp);
void stringify(TypePackId tpid, const std::vector<std::optional<FunctionArgument>>& names);
void stringify(const std::vector<TypeId>& types)
void stringify(const std::vector<TypeId>& types, const std::vector<TypePackId>& typePacks)
{
if (types.size() == 0)
if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0))
return;
if (types.size())
if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size()))
state.emit("<");
if (FFlag::LuauTypeAliasPacks)
{
bool first = true;
for (TypeId ty : types)
{
if (!first)
state.emit(", ");
first = false;
stringify(ty);
}
bool singleTp = typePacks.size() == 1;
for (TypePackId tp : typePacks)
{
if (isEmpty(tp) && singleTp)
continue;
if (!first)
state.emit(", ");
else
first = false;
if (!singleTp)
state.emit("(");
stringify(tp);
if (!singleTp)
state.emit(")");
}
}
else
{
for (size_t i = 0; i < types.size(); ++i)
{
if (i > 0)
@ -273,8 +317,9 @@ struct TypeVarStringifier
stringify(types[i]);
}
}
if (types.size())
if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size()))
state.emit(">");
}
@ -388,7 +433,7 @@ struct TypeVarStringifier
void operator()(TypeId, const TableTypeVar& ttv)
{
if (FFlag::LuauToStringFollowsBoundTo && ttv.boundTo)
if (ttv.boundTo)
return stringify(*ttv.boundTo);
if (!state.exhaustive)
@ -411,14 +456,14 @@ struct TypeVarStringifier
}
state.emit(*ttv.name);
stringify(ttv.instantiatedTypeParams);
stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams);
return;
}
if (ttv.syntheticName)
{
state.result.invalid = true;
state.emit(*ttv.syntheticName);
stringify(ttv.instantiatedTypeParams);
stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams);
return;
}
}
@ -900,13 +945,26 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
result.name += ttv->name ? *ttv->name : *ttv->syntheticName;
if (ttv->instantiatedTypeParams.empty())
if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty()))
return result;
std::vector<std::string> params;
for (TypeId tp : ttv->instantiatedTypeParams)
params.push_back(toString(tp));
if (FFlag::LuauTypeAliasPacks)
{
// Doesn't preserve grouping of multiple type packs
// But this is under a parent block of code that is being removed later
for (TypePackId tp : ttv->instantiatedTypePackParams)
{
std::string content = toString(tp);
if (!content.empty())
params.push_back(std::move(content));
}
}
result.name += "<" + join(params, ", ") + ">";
return result;
}
@ -950,7 +1008,13 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
result.name += ttv->name ? *ttv->name : *ttv->syntheticName;
if (ttv->instantiatedTypeParams.empty())
if (FFlag::LuauTypeAliasPacks)
{
tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams);
}
else
{
if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty()))
return result;
result.name += "<";
@ -975,6 +1039,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
{
result.name += ">";
}
}
return result;
}

View file

@ -11,6 +11,7 @@
#include <math.h>
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace
{
@ -280,10 +281,19 @@ struct Printer
void visualizeTypePackAnnotation(const AstTypePack& annotation)
{
if (const AstTypePackVariadic* variadic = annotation.as<AstTypePackVariadic>())
if (const AstTypePackVariadic* variadicTp = annotation.as<AstTypePackVariadic>())
{
writer.symbol("...");
visualizeTypeAnnotation(*variadic->variadicType);
visualizeTypeAnnotation(*variadicTp->variadicType);
}
else if (const AstTypePackGeneric* genericTp = annotation.as<AstTypePackGeneric>())
{
writer.symbol(genericTp->genericName.value);
writer.symbol("...");
}
else if (const AstTypePackExplicit* explicitTp = annotation.as<AstTypePackExplicit>())
{
visualizeTypeList(explicitTp->typeList, true);
}
else
{
@ -807,7 +817,7 @@ struct Printer
writer.keyword("type");
writer.identifier(a->name.value);
if (a->generics.size > 0)
if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0))
{
writer.symbol("<");
CommaSeparatorInserter comma(writer);
@ -817,6 +827,17 @@ struct Printer
comma();
writer.identifier(o.value);
}
if (FFlag::LuauTypeAliasPacks)
{
for (auto o : a->genericPacks)
{
comma();
writer.identifier(o.value);
writer.symbol("...");
}
}
writer.symbol(">");
}
writer.maybeSpace(a->type->location.begin, 2);
@ -960,15 +981,20 @@ struct Printer
if (const auto& a = typeAnnotation.as<AstTypeReference>())
{
writer.write(a->name.value);
if (a->generics.size > 0)
if (a->parameters.size > 0)
{
CommaSeparatorInserter comma(writer);
writer.symbol("<");
for (auto o : a->generics)
for (auto o : a->parameters)
{
comma();
visualizeTypeAnnotation(*o);
if (o.type)
visualizeTypeAnnotation(*o.type);
else
visualizeTypePackAnnotation(*o.typePack);
}
writer.symbol(">");
}
}

View file

@ -5,6 +5,7 @@
#include "Luau/Module.h"
#include "Luau/Parser.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -12,6 +13,7 @@
#include <string>
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAG(LuauTypeAliasPacks)
static char* allocateString(Luau::Allocator& allocator, std::string_view contents)
{
@ -33,7 +35,6 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data
namespace Luau
{
class TypeRehydrationVisitor
{
mutable std::map<void*, int> seen;
@ -57,6 +58,8 @@ public:
{
}
AstTypePack* rehydrate(TypePackId tp) const;
AstType* operator()(const PrimitiveTypeVar& ptv) const
{
switch (ptv.type)
@ -85,16 +88,24 @@ public:
if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end())
{
AstArray<AstType*> generics;
generics.size = ttv.instantiatedTypeParams.size();
generics.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * generics.size));
AstArray<AstTypeOrPack> parameters;
parameters.size = ttv.instantiatedTypeParams.size();
parameters.data = static_cast<AstTypeOrPack*>(allocator->allocate(sizeof(AstTypeOrPack) * parameters.size));
for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i)
{
generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty);
parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}};
}
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(ttv.name->c_str()), generics);
if (FFlag::LuauTypeAliasPacks)
{
for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i)
{
parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])};
}
}
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters);
}
if (hasSeen(&ttv))
@ -221,6 +232,12 @@ public:
AstTypePack* argTailAnnotation = nullptr;
if (argTail)
{
if (FFlag::LuauTypeAliasPacks)
{
argTailAnnotation = rehydrate(*argTail);
}
else
{
TypePackId tail = *argTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
@ -228,6 +245,7 @@ public:
argTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
}
}
}
AstArray<std::optional<AstArgumentName>> argNames;
argNames.size = ftv.argNames.size();
@ -254,6 +272,12 @@ public:
AstTypePack* retTailAnnotation = nullptr;
if (retTail)
{
if (FFlag::LuauTypeAliasPacks)
{
retTailAnnotation = rehydrate(*retTail);
}
else
{
TypePackId tail = *retTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
@ -261,6 +285,7 @@ public:
retTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
}
}
}
return allocator->alloc<AstTypeFunction>(
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation});
@ -313,6 +338,68 @@ private:
const TypeRehydrationOptions& options;
};
class TypePackRehydrationVisitor
{
public:
TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor)
: allocator(allocator)
, typeVisitor(typeVisitor)
{
}
AstTypePack* operator()(const BoundTypePack& btp) const
{
return Luau::visit(*this, btp.boundTo->ty);
}
AstTypePack* operator()(const TypePack& tp) const
{
AstArray<AstType*> head;
head.size = tp.head.size();
head.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * tp.head.size()));
for (size_t i = 0; i < tp.head.size(); i++)
head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty);
AstTypePack* tail = nullptr;
if (tp.tail)
tail = Luau::visit(*this, (*tp.tail)->ty);
return allocator->alloc<AstTypePackExplicit>(Location(), AstTypeList{head, tail});
}
AstTypePack* operator()(const VariadicTypePack& vtp) const
{
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(typeVisitor, vtp.ty->ty));
}
AstTypePack* operator()(const GenericTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(gtp.name.c_str()));
}
AstTypePack* operator()(const FreeTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("free"));
}
AstTypePack* operator()(const Unifiable::Error&) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("Unifiable<Error>"));
}
private:
Allocator* allocator;
const TypeRehydrationVisitor& typeVisitor;
};
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const
{
TypePackRehydrationVisitor tprv(allocator, *this);
return Luau::visit(tprv, tp->ty);
}
class TypeAttacher : public AstVisitor
{
public:
@ -405,11 +492,18 @@ public:
const auto& [v, tail] = flatten(ret);
if (tail)
{
if (FFlag::LuauTypeAliasPacks)
{
variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail);
}
else
{
TypePackId tailPack = *tail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tailPack))
variadicAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), typeAst(vtp->ty));
}
}
fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation};
}

View file

@ -5,21 +5,22 @@
#include "Luau/ModuleResolver.h"
#include "Luau/Parser.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/Substitution.h"
#include "Luau/TopoSortStatements.h"
#include "Luau/ToString.h"
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/TypeVar.h"
#include "Luau/TimeTrace.h"
#include <deque>
#include <algorithm>
LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false)
LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 0)
LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 0)
LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500)
LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000)
LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500)
LUAU_FASTFLAGVARIABLE(LuauIndexTablesWithIndexers, false)
LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false)
LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false)
LUAU_FASTFLAG(LuauKnowsTheDataModel3)
@ -27,14 +28,11 @@ LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false)
LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false)
LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false.
LUAU_FASTFLAGVARIABLE(LuauImprovedTypeGuardPredicate2, false)
LUAU_FASTFLAG(LuauTraceRequireLookupChild)
LUAU_FASTFLAG(DebugLuauTrackOwningArena)
LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false)
LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false)
LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false)
LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false)
LUAU_FASTFLAGVARIABLE(LuauFixTableTypeAliasClone, false)
LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false)
LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false)
LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false)
@ -45,6 +43,10 @@ LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false)
LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false)
LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false)
LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false)
LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false)
LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes)
LUAU_FASTFLAG(LuauNewRequireTrace)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -216,9 +218,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan
, nilType(singletonTypes.nilType)
, numberType(singletonTypes.numberType)
, stringType(singletonTypes.stringType)
, booleanType(
FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.booleanType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Boolean)))
, threadType(FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.threadType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Thread)))
, booleanType(singletonTypes.booleanType)
, threadType(singletonTypes.threadType)
, anyType(singletonTypes.anyType)
, errorType(singletonTypes.errorType)
, optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}}))
@ -237,6 +238,9 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan
ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope)
{
LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
currentModule.reset(new Module());
currentModule->type = module.type;
@ -1177,12 +1181,28 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
{
Location location = scope->typeAliasLocations[name];
reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}});
if (FFlag::LuauTypeAliasPacks)
bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType};
else
bindingsMap[name] = TypeFun{binding->typeParams, errorType};
}
else
{
ScopePtr aliasScope = childScope(scope, typealias.location);
if (FFlag::LuauTypeAliasPacks)
{
auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks);
TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true));
FreeTypeVar* ftv = getMutable<FreeTypeVar>(ty);
LUAU_ASSERT(ftv);
ftv->forwardedTypeAlias = true;
bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty};
}
else
{
std::vector<TypeId> generics;
for (AstName generic : typealias.generics)
{
@ -1199,7 +1219,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
TypeId g;
if (FFlag::LuauRecursiveTypeParameterRestriction)
{
TypeId& cached = scope->typeAliasParameters[n];
TypeId& cached = scope->typeAliasTypeParameters[n];
if (!cached)
cached = addType(GenericTypeVar{aliasScope->level, n});
g = cached;
@ -1217,6 +1237,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
bindingsMap[name] = {std::move(generics), ty};
}
}
}
else
{
if (!binding)
@ -1231,6 +1252,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty};
}
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId tp : binding->typePackParams)
{
auto generic = get<GenericTypePack>(tp);
LUAU_ASSERT(generic);
aliasScope->privateTypePackBindings[generic->name] = tp;
}
}
TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true));
if (auto ttv = getMutable<TableTypeVar>(follow(ty)))
{
@ -1238,7 +1269,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
if (ttv->name)
{
// Copy can be skipped if this is an identical alias
if (!FFlag::LuauFixTableTypeAliasClone || ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams)
if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams ||
(FFlag::LuauTypeAliasPacks && ttv->instantiatedTypePackParams != binding->typePackParams))
{
// This is a shallow clone, original recursive links to self are not updated
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state};
@ -1249,6 +1281,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
clone.name = name;
clone.instantiatedTypeParams = binding->typeParams;
if (FFlag::LuauTypeAliasPacks)
clone.instantiatedTypePackParams = binding->typePackParams;
ty = addType(std::move(clone));
}
}
@ -1256,6 +1291,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
{
ttv->name = name;
ttv->instantiatedTypeParams = binding->typeParams;
if (FFlag::LuauTypeAliasPacks)
ttv->instantiatedTypePackParams = binding->typePackParams;
}
}
else if (auto mtv = getMutable<MetatableTypeVar>(follow(ty)))
@ -1280,7 +1318,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
}
// We don't have generic classes, so this assertion _should_ never be hit.
LUAU_ASSERT(lookupType->typeParams.size() == 0);
LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0));
superTy = lookupType->type;
if (FFlag::LuauAddMissingFollow)
@ -1465,7 +1503,8 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr&
if (FFlag::LuauStoreMatchingOverloadFnType)
{
currentModule->astTypes.try_emplace(&expr, result.type);
if (!currentModule->astTypes.find(&expr))
currentModule->astTypes[&expr] = result.type;
}
else
{
@ -2193,7 +2232,7 @@ TypeId TypeChecker::checkRelationalOperation(
* have a better, more descriptive error teed up.
*/
Unifier state = mkUnifier(expr.location);
if (!FFlag::LuauEqConstraint || !isEquality)
if (!isEquality)
state.tryUnify(lhsType, rhsType);
bool needsMetamethod = !isEquality;
@ -2262,7 +2301,7 @@ TypeId TypeChecker::checkRelationalOperation(
}
}
if (get<FreeTypeVar>(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && (!FFlag::LuauEqConstraint || !isEquality))
if (get<FreeTypeVar>(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && !isEquality)
{
auto name = getIdentifierOfBaseVar(expr.left);
reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison});
@ -2276,18 +2315,6 @@ TypeId TypeChecker::checkRelationalOperation(
return errorType;
}
if (!FFlag::LuauEqConstraint)
{
if (isEquality)
{
ErrorVec errVec = tryUnify(rhsType, lhsType, expr.location);
if (!state.errors.empty() && !errVec.empty())
reportError(expr.location, TypeMismatch{lhsType, rhsType});
}
else
reportErrors(state.errors);
}
return booleanType;
}
@ -2443,7 +2470,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi
TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates);
return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}};
}
else if (FFlag::LuauEqConstraint && (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe))
else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)
{
if (auto predicate = tryGetTypeGuardPredicate(expr))
return {booleanType, {std::move(*predicate)}};
@ -2466,14 +2493,6 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi
}
else
{
// Once we have EqPredicate, we should break this else branch into its' own branch.
// For now, fall through is intentional.
if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)
{
if (auto predicate = tryGetTypeGuardPredicate(expr))
return {booleanType, {std::move(*predicate)}};
}
ExprResult<TypeId> lhs = checkExpr(scope, *expr.left);
ExprResult<TypeId> rhs = checkExpr(scope, *expr.right);
@ -2755,12 +2774,6 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)};
return std::pair(resultType, nullptr);
}
else if (FFlag::LuauIndexTablesWithIndexers)
{
// We allow t[x] where x:string for tables without an indexer
unify(indexType, stringType, expr.location);
return std::pair(anyType, nullptr);
}
else
{
TypeId resultType = freshType(scope);
@ -3076,6 +3089,13 @@ static Location getEndLocation(const AstExprFunction& function)
void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstExprFunction& function)
{
LUAU_TIMETRACE_SCOPE("TypeChecker::checkFunctionBody", "TypeChecker");
if (function.debugname.value)
LUAU_TIMETRACE_ARGUMENT("name", function.debugname.value);
else
LUAU_TIMETRACE_ARGUMENT("line", std::to_string(function.location.begin.line).c_str());
if (FunctionTypeVar* funTy = getMutable<FunctionTypeVar>(ty))
{
check(scope, *function.body);
@ -3885,6 +3905,20 @@ std::optional<AstExpr*> TypeChecker::matchRequire(const AstExprCall& call)
TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location)
{
LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker");
LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str());
if (FFlag::LuauNewRequireTrace && moduleInfo.name.empty())
{
if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict)
{
reportError(TypeError{location, UnknownRequire{}});
return errorType;
}
return anyType;
}
ModulePtr module = resolver->getModule(moduleInfo.name);
if (!module)
{
@ -4472,7 +4506,7 @@ TypeId TypeChecker::freshType(const ScopePtr& scope)
TypeId TypeChecker::freshType(TypeLevel level)
{
return currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level)));
return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level)));
}
TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric)
@ -4482,11 +4516,7 @@ TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneri
TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric)
{
TypeId allocated = currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level, canBeGeneric)));
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level, canBeGeneric)));
}
std::optional<TypeId> TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate)
@ -4506,20 +4536,12 @@ TypeId TypeChecker::addType(const UnionTypeVar& utv)
TypeId TypeChecker::addTV(TypeVar&& tv)
{
TypeId allocated = currentModule->internalTypes.typeVars.allocate(std::move(tv));
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
return currentModule->internalTypes.addType(std::move(tv));
}
TypePackId TypeChecker::addTypePack(TypePackVar&& tv)
{
TypePackId allocated = currentModule->internalTypes.typePacks.allocate(std::move(tv));
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
return currentModule->internalTypes.addTypePack(std::move(tv));
}
TypePackId TypeChecker::addTypePack(TypePack&& tp)
@ -4578,7 +4600,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_print")
{
if (lit->generics.size != 1)
if (lit->parameters.size != 1 || !lit->parameters.data[0].type)
{
reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}});
return addType(ErrorTypeVar{});
@ -4588,7 +4610,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
opts.exhaustive = true;
opts.maxTableLength = 0;
TypeId param = resolveType(scope, *lit->generics.data[0]);
TypeId param = resolveType(scope, *lit->parameters.data[0].type);
luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str()));
return param;
}
@ -4614,18 +4636,86 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
return addType(ErrorTypeVar{});
}
if (lit->generics.size == 0 && tf->typeParams.empty())
return tf->type;
else if (lit->generics.size != tf->typeParams.size())
if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty()))
{
reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->generics.size}});
return tf->type;
}
else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size())
{
reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}});
return addType(ErrorTypeVar{});
}
else if (FFlag::LuauTypeAliasPacks)
{
if (!lit->hasParameterList && !tf->typePackParams.empty())
{
reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}});
return addType(ErrorTypeVar{});
}
std::vector<TypeId> typeParams;
std::vector<TypeId> extraTypes;
std::vector<TypePackId> typePackParams;
for (size_t i = 0; i < lit->parameters.size; ++i)
{
if (AstType* type = lit->parameters.data[i].type)
{
TypeId ty = resolveType(scope, *type);
if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty())
typeParams.push_back(ty);
else if (typePackParams.empty())
extraTypes.push_back(ty);
else
reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}});
}
else if (AstTypePack* typePack = lit->parameters.data[i].typePack)
{
TypePackId tp = resolveTypePack(scope, *typePack);
// If we have collected an implicit type pack, materialize it
if (typePackParams.empty() && !extraTypes.empty())
typePackParams.push_back(addTypePack(extraTypes));
// If we need more regular types, we can use single element type packs to fill those in
if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp))
typeParams.push_back(*first(tp));
else
typePackParams.push_back(tp);
}
}
// If we still haven't meterialized an implicit type pack, do it now
if (typePackParams.empty() && !extraTypes.empty())
typePackParams.push_back(addTypePack(extraTypes));
// If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack
if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size())
typePackParams.push_back(addTypePack({}));
if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size())
{
reportError(
TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}});
return addType(ErrorTypeVar{});
}
if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams)
{
// If the generic parameters and the type arguments are the same, we are about to
// perform an identity substitution, which we can just short-circuit.
return tf->type;
}
return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location);
}
else
{
std::vector<TypeId> typeParams;
for (AstType* paramAnnot : lit->generics)
typeParams.push_back(resolveType(scope, *paramAnnot));
for (const auto& param : lit->parameters)
typeParams.push_back(resolveType(scope, *param.type));
if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams)
{
@ -4634,7 +4724,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
return tf->type;
}
return instantiateTypeFun(scope, *tf, typeParams, annotation.location);
return instantiateTypeFun(scope, *tf, typeParams, {}, annotation.location);
}
}
else if (const auto& table = annotation.as<AstTypeTable>())
@ -4765,6 +4855,18 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack
return *genericTy;
}
else if (const AstTypePackExplicit* explicitTp = annotation.as<AstTypePackExplicit>())
{
std::vector<TypeId> types;
for (auto type : explicitTp->typeList.types)
types.push_back(resolveType(scope, *type));
if (auto tailType = explicitTp->typeList.tailType)
return addTypePack(types, resolveTypePack(scope, *tailType));
return addTypePack(types);
}
else
{
ice("Unknown AstTypePack kind");
@ -4799,12 +4901,28 @@ bool ApplyTypeFunction::isDirty(TypePackId tp)
return false;
}
bool ApplyTypeFunction::ignoreChildren(TypeId ty)
{
if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get<GenericTypeVar>(ty))
return true;
else
return false;
}
bool ApplyTypeFunction::ignoreChildren(TypePackId tp)
{
if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get<GenericTypePack>(tp))
return true;
else
return false;
}
TypeId ApplyTypeFunction::clean(TypeId ty)
{
// Really this should just replace the arguments,
// but for bug-compatibility with existing code, we replace
// all generics by free type variables.
TypeId& arg = arguments[ty];
TypeId& arg = typeArguments[ty];
if (arg)
return arg;
else
@ -4816,17 +4934,37 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp)
// Really this should just replace the arguments,
// but for bug-compatibility with existing code, we replace
// all generics by free type variables.
if (FFlag::LuauTypeAliasPacks)
{
TypePackId& arg = typePackArguments[tp];
if (arg)
return arg;
else
return addTypePack(FreeTypePack{level});
}
else
{
return addTypePack(FreeTypePack{level});
}
}
TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams, const Location& location)
TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams, const Location& location)
{
if (tf.typeParams.empty())
if (tf.typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf.typePackParams.empty()))
return tf.type;
applyTypeFunction.arguments.clear();
applyTypeFunction.typeArguments.clear();
for (size_t i = 0; i < tf.typeParams.size(); ++i)
applyTypeFunction.arguments[tf.typeParams[i]] = typeParams[i];
applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i];
if (FFlag::LuauTypeAliasPacks)
{
applyTypeFunction.typePackArguments.clear();
for (size_t i = 0; i < tf.typePackParams.size(); ++i)
applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i];
}
applyTypeFunction.currentModule = currentModule;
applyTypeFunction.level = scope->level;
applyTypeFunction.encounteredForwardedType = false;
@ -4875,6 +5013,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
if (ttv)
{
ttv->instantiatedTypeParams = typeParams;
if (FFlag::LuauTypeAliasPacks)
ttv->instantiatedTypePackParams = typePackParams;
}
}
else
@ -4890,6 +5031,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
}
ttv->instantiatedTypeParams = typeParams;
if (FFlag::LuauTypeAliasPacks)
ttv->instantiatedTypePackParams = typePackParams;
}
}
@ -4899,6 +5043,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGenericTypes(
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
{
LUAU_ASSERT(scope->parent);
std::vector<TypeId> generics;
for (const AstName& generic : genericNames)
{
@ -4912,7 +5058,19 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
reportError(TypeError{node.location, DuplicateGenericParameter{n}});
}
TypeId g = addType(Unifiable::Generic{scope->level, n});
TypeId g;
if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks)
{
TypeId& cached = scope->parent->typeAliasTypeParameters[n];
if (!cached)
cached = addType(GenericTypeVar{scope->level, n});
g = cached;
}
else
{
g = addType(Unifiable::Generic{scope->level, n});
}
generics.push_back(g);
scope->privateTypeBindings[n] = TypeFun{{}, g};
}
@ -4930,7 +5088,19 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
reportError(TypeError{node.location, DuplicateGenericParameter{n}});
}
TypePackId g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}});
TypePackId g;
if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks)
{
TypePackId& cached = scope->parent->typeAliasTypePackParameters[n];
if (!cached)
cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}});
g = cached;
}
else
{
g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}});
}
genericPacks.push_back(g);
scope->privateTypePackBindings[n] = g;
}
@ -5013,13 +5183,8 @@ void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, Refineme
else if (auto isaP = get<IsAPredicate>(predicate))
resolve(*isaP, errVec, refis, scope, sense);
else if (auto typeguardP = get<TypeGuardPredicate>(predicate))
{
if (FFlag::LuauImprovedTypeGuardPredicate2)
resolve(*typeguardP, errVec, refis, scope, sense);
else
DEPRECATED_resolve(*typeguardP, errVec, refis, scope, sense);
}
else if (auto eqP = get<EqPredicate>(predicate); eqP && FFlag::LuauEqConstraint)
else if (auto eqP = get<EqPredicate>(predicate))
resolve(*eqP, errVec, refis, scope, sense);
else
ice("Unhandled predicate kind");
@ -5145,7 +5310,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement
return isaP.ty;
}
}
else if (FFlag::LuauImprovedTypeGuardPredicate2)
else
{
auto lctv = get<ClassTypeVar>(option);
auto rctv = get<ClassTypeVar>(isaP.ty);
@ -5159,19 +5324,6 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement
if (canUnify(option, isaP.ty, isaP.location).empty() == sense)
return isaP.ty;
}
else
{
auto lctv = get<ClassTypeVar>(option);
auto rctv = get<ClassTypeVar>(isaP.ty);
if (lctv && rctv)
{
if (isSubclass(lctv, rctv) == sense)
return option;
else if (isSubclass(rctv, lctv) == sense)
return isaP.ty;
}
}
return std::nullopt;
};
@ -5266,7 +5418,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec
return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type});
auto typeFun = globalScope->lookupType(typeguardP.kind);
if (!typeFun || !typeFun->typeParams.empty())
if (!typeFun || !typeFun->typeParams.empty() || (FFlag::LuauTypeAliasPacks && !typeFun->typePackParams.empty()))
return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type});
TypeId type = follow(typeFun->type);
@ -5292,7 +5444,8 @@ void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, Error
"userdata", // no op. Requires special handling.
};
if (auto typeFun = globalScope->lookupType(typeguardP.kind); typeFun && typeFun->typeParams.empty())
if (auto typeFun = globalScope->lookupType(typeguardP.kind);
typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty()))
{
if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end())
addRefinement(refis, typeguardP.lvalue, typeFun->type);
@ -5319,6 +5472,8 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa
return;
}
if (FFlag::LuauEqConstraint)
{
std::optional<TypeId> ty = resolveLValue(refis, scope, eqP.lvalue);
if (!ty)
return;
@ -5351,6 +5506,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa
std::vector<TypeId> viable(set.begin(), set.end());
TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)});
addRefinement(refis, eqP.lvalue, result);
}
}
bool TypeChecker::isNonstrictMode() const
@ -5379,119 +5535,4 @@ std::vector<std::pair<Location, ScopePtr>> TypeChecker::getScopes() const
return currentModule->scopes;
}
Scope::Scope(TypePackId returnType)
: parent(nullptr)
, returnType(returnType)
, level(TypeLevel())
{
}
Scope::Scope(const ScopePtr& parent, int subLevel)
: parent(parent)
, returnType(parent->returnType)
, level(parent->level.incr())
{
level.subLevel = subLevel;
}
std::optional<TypeId> Scope::lookup(const Symbol& name)
{
Scope* scope = this;
while (scope)
{
auto it = scope->bindings.find(name);
if (it != scope->bindings.end())
return it->second.typeId;
scope = scope->parent.get();
}
return std::nullopt;
}
std::optional<TypeFun> Scope::lookupType(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->exportedTypeBindings.find(name);
if (it != scope->exportedTypeBindings.end())
return it->second;
it = scope->privateTypeBindings.find(name);
if (it != scope->privateTypeBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<TypeFun> Scope::lookupImportedType(const Name& moduleAlias, const Name& name)
{
const Scope* scope = this;
while (scope)
{
auto it = scope->importedTypeBindings.find(moduleAlias);
if (it == scope->importedTypeBindings.end())
{
scope = scope->parent.get();
continue;
}
auto it2 = it->second.find(name);
if (it2 == it->second.end())
{
scope = scope->parent.get();
continue;
}
return it2->second;
}
return std::nullopt;
}
std::optional<TypePackId> Scope::lookupPack(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->privateTypePackBindings.find(name);
if (it != scope->privateTypePackBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain)
{
Scope* scope = this;
while (scope)
{
for (const auto& [n, binding] : scope->bindings)
{
if (n.local && n.local->name == name.c_str())
return binding;
else if (n.global.value && n.global == name.c_str())
return binding;
}
scope = scope->parent.get();
if (!traverseScopeChain)
break;
}
return std::nullopt;
}
} // namespace Luau

View file

@ -209,6 +209,19 @@ size_t size(TypePackId tp)
return 0;
}
bool finite(TypePackId tp)
{
tp = follow(tp);
if (auto pack = get<TypePack>(tp))
return pack->tail ? finite(*pack->tail) : true;
if (auto pack = get<VariadicTypePack>(tp))
return false;
return true;
}
size_t size(const TypePack& tp)
{
size_t result = tp.head.size();

View file

@ -1,11 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeUtils.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
LUAU_FASTFLAG(LuauStringMetatable)
namespace Luau
{
@ -13,21 +12,6 @@ std::optional<TypeId> findMetatableEntry(ErrorVec& errors, const ScopePtr& globa
{
type = follow(type);
if (!FFlag::LuauStringMetatable)
{
if (const PrimitiveTypeVar* primType = get<PrimitiveTypeVar>(type))
{
if (primType->type != PrimitiveTypeVar::String || "__index" != entry)
return std::nullopt;
auto it = globalScope->bindings.find(AstName{"string"});
if (it != globalScope->bindings.end())
return it->second.typeId;
else
return std::nullopt;
}
}
std::optional<TypeId> metatable = getMetatable(type);
if (!metatable)
return std::nullopt;

View file

@ -19,11 +19,9 @@
LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2)
LUAU_FASTFLAGVARIABLE(LuauToStringFollowsBoundTo, false)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAGVARIABLE(LuauStringMetatable, false)
LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -193,27 +191,11 @@ bool isOptional(TypeId ty)
bool isTableIntersection(TypeId ty)
{
if (FFlag::LuauImprovedTypeGuardPredicate2)
{
if (!get<IntersectionTypeVar>(follow(ty)))
return false;
std::vector<TypeId> parts = flattenIntersection(ty);
return std::all_of(parts.begin(), parts.end(), getTableType);
}
else
{
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
for (TypeId part : itv->parts)
{
if (getTableType(follow(part)))
return true;
}
}
return false;
}
}
bool isOverloadedFunction(TypeId ty)
@ -236,7 +218,7 @@ std::optional<TypeId> getMetatable(TypeId type)
else if (const ClassTypeVar* classType = get<ClassTypeVar>(type))
return classType->metatable;
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type);
FFlag::LuauStringMetatable && primitiveType && primitiveType->metatable)
primitiveType && primitiveType->metatable)
{
LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String);
return primitiveType->metatable;
@ -871,6 +853,12 @@ void StateDot::visitChildren(TypeId ty, int index)
}
for (TypeId itp : ttv->instantiatedTypeParams)
visitChild(itp, index, "typeParam");
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp, index, "typePackParam");
}
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{

View file

@ -3,23 +3,25 @@
#include "Luau/Common.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/TimeTrace.h"
#include <algorithm>
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferTypePackLoopLimit);
LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 0);
LUAU_FASTFLAGVARIABLE(LuauLogTableTypeVarBoundTo, false)
LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000);
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false);
LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAG(LuauStringMetatable)
LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false)
LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false)
LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false)
LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false)
namespace Luau
{
@ -43,21 +45,23 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati
, globalScope(std::move(globalScope))
, location(location)
, variance(variance)
, counters(std::make_shared<UnifierCounters>())
, counters(&countersData)
, counters_DEPRECATED(std::make_shared<UnifierCounters>())
, iceHandler(iceHandler)
{
LUAU_ASSERT(iceHandler);
}
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters)
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
: types(types)
, mode(mode)
, globalScope(std::move(globalScope))
, log(seen)
, location(location)
, variance(variance)
, counters(counters ? counters : std::make_shared<UnifierCounters>())
, counters(counters ? counters : &countersData)
, counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared<UnifierCounters>())
, iceHandler(iceHandler)
{
LUAU_ASSERT(iceHandler);
@ -65,16 +69,26 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::v
void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection)
{
if (FFlag::LuauTypecheckOpts)
counters->iterationCount = 0;
else
counters_DEPRECATED->iterationCount = 0;
return tryUnify_(superTy, subTy, isFunctionCall, isIntersection);
}
void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection)
{
RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(
FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit);
if (FFlag::LuauTypecheckOpts)
++counters->iterationCount;
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount)
else
++counters_DEPRECATED->iterationCount;
if (FInt::LuauTypeInferIterationLimit > 0 &&
FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount))
{
errors.push_back(TypeError{location, UnificationTooComplex{}});
return;
@ -440,7 +454,11 @@ ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunction
void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall)
{
if (FFlag::LuauTypecheckOpts)
counters->iterationCount = 0;
else
counters_DEPRECATED->iterationCount = 0;
return tryUnify_(superTp, subTp, isFunctionCall);
}
@ -450,10 +468,16 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall
*/
void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall)
{
RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(
FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit);
if (FFlag::LuauTypecheckOpts)
++counters->iterationCount;
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount)
else
++counters_DEPRECATED->iterationCount;
if (FInt::LuauTypeInferIterationLimit > 0 &&
FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount))
{
errors.push_back(TypeError{location, UnificationTooComplex{}});
return;
@ -762,9 +786,210 @@ struct Resetter
void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
{
std::unique_ptr<Resetter> resetter;
if (!FFlag::LuauTableSubtypingVariance)
return DEPRECATED_tryUnifyTables(left, right, isIntersection);
resetter.reset(new Resetter{&variance});
TableTypeVar* lt = getMutable<TableTypeVar>(left);
TableTypeVar* rt = getMutable<TableTypeVar>(right);
if (!lt || !rt)
ice("passed non-table types to unifyTables");
std::vector<std::string> missingProperties;
std::vector<std::string> extraProperties;
// Reminder: left is the supertype, right is the subtype.
// Width subtyping: any property in the supertype must be in the subtype,
// and the types must agree.
for (const auto& [name, prop] : lt->props)
{
const auto& r = rt->props.find(name);
if (r != rt->props.end())
{
// TODO: read-only properties don't need invariance
Resetter resetter{&variance};
variance = Invariant;
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(prop.type, r->second.type);
checkChildUnifierTypeMismatch(innerState.errors, left, right);
if (innerState.errors.empty())
log.concat(std::move(innerState.log));
else
innerState.log.rollback();
}
else if (rt->indexer && isString(rt->indexer->indexType))
{
// TODO: read-only indexers don't need invariance
// TODO: really we should only allow this if prop.type is optional.
Resetter resetter{&variance};
variance = Invariant;
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(prop.type, rt->indexer->indexResultType);
checkChildUnifierTypeMismatch(innerState.errors, left, right);
if (innerState.errors.empty())
log.concat(std::move(innerState.log));
else
innerState.log.rollback();
}
else if (isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type)))
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{}
else if (rt->state == TableState::Free)
{
log(rt);
rt->props[name] = prop;
}
else
missingProperties.push_back(name);
}
for (const auto& [name, prop] : rt->props)
{
if (lt->props.count(name))
{
// If both lt and rt contain the property, then
// we're done since we already unified them above
}
else if (lt->indexer && isString(lt->indexer->indexType))
{
// TODO: read-only indexers don't need invariance
// TODO: really we should only allow this if prop.type is optional.
Resetter resetter{&variance};
variance = Invariant;
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(prop.type, lt->indexer->indexResultType);
checkChildUnifierTypeMismatch(innerState.errors, left, right);
if (innerState.errors.empty())
log.concat(std::move(innerState.log));
else
innerState.log.rollback();
}
else if (lt->state == TableState::Unsealed)
{
// TODO: this case is unsound when variance is Invariant, but without it lua-apps fails to typecheck.
// TODO: file a JIRA
// TODO: hopefully readonly/writeonly properties will fix this.
Property clone = prop;
clone.type = deeplyOptional(clone.type);
log(lt);
lt->props[name] = clone;
}
else if (variance == Covariant)
{}
else if (isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type)))
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{}
else if (lt->state == TableState::Free)
{
log(lt);
lt->props[name] = prop;
}
else
extraProperties.push_back(name);
}
// Unify indexers
if (lt->indexer && rt->indexer)
{
// TODO: read-only indexers don't need invariance
Resetter resetter{&variance};
variance = Invariant;
Unifier innerState = makeChildUnifier();
innerState.tryUnify(*lt->indexer, *rt->indexer);
checkChildUnifierTypeMismatch(innerState.errors, left, right);
if (innerState.errors.empty())
log.concat(std::move(innerState.log));
else
innerState.log.rollback();
}
else if (lt->indexer)
{
if (rt->state == TableState::Unsealed || rt->state == TableState::Free)
{
// passing/assigning a table without an indexer to something that has one
// e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer.
// TODO: we only need to do this if the supertype's indexer is read/write
// since that can add indexed elements.
log(rt);
rt->indexer = lt->indexer;
}
}
else if (rt->indexer && variance == Invariant)
{
// Symmetric if we are invariant
if (lt->state == TableState::Unsealed || lt->state == TableState::Free)
{
log(lt);
lt->indexer = rt->indexer;
}
}
if (!missingProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}});
return;
}
if (!extraProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}});
return;
}
/*
* TypeVars are commonly cyclic, so it is entirely possible
* for unifying a property of a table to change the table itself!
* We need to check for this and start over if we notice this occurring.
*
* I believe this is guaranteed to terminate eventually because this will
* only happen when a free table is bound to another table.
*/
if (lt->boundTo || rt->boundTo)
return tryUnify_(left, right);
if (lt->state == TableState::Free)
{
log(lt);
lt->boundTo = right;
}
else if (rt->state == TableState::Free)
{
log(rt);
rt->boundTo = left;
}
}
TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen)
{
ty = follow(ty);
if (get<AnyTypeVar>(ty))
return ty;
else if (isOptional(ty))
return ty;
else if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
{
TypeId& result = seen[ty];
if (result)
return result;
result = types->addType(*ttv);
TableTypeVar* resultTtv = getMutable<TableTypeVar>(result);
for (auto& [name, prop] : resultTtv->props)
prop.type = deeplyOptional(prop.type, seen);
return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});;
}
else
return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }});
}
void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
{
LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance);
Resetter resetter{&variance};
variance = Invariant;
TableTypeVar* lt = getMutable<TableTypeVar>(left);
@ -894,10 +1119,7 @@ void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId)
if (!freeTable->boundTo && otherTable->state != TableState::Free)
{
if (FFlag::LuauLogTableTypeVarBoundTo)
log(freeTable);
else
log(freeTypeId);
freeTable->boundTo = otherTypeId;
}
}
@ -1196,9 +1418,11 @@ void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& sub
tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType);
}
static void queueTypePack(
static void queueTypePack_DEPRECATED(
std::vector<TypeId>& queue, std::unordered_set<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)
{
LUAU_ASSERT(!FFlag::LuauTypecheckOpts);
while (true)
{
if (FFlag::LuauAddMissingFollow)
@ -1244,6 +1468,55 @@ static void queueTypePack(
}
}
static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)
{
LUAU_ASSERT(FFlag::LuauTypecheckOpts);
while (true)
{
if (FFlag::LuauAddMissingFollow)
a = follow(a);
if (seenTypePacks.find(a))
break;
seenTypePacks.insert(a);
if (FFlag::LuauAddMissingFollow)
{
if (get<Unifiable::Free>(a))
{
state.log(a);
*asMutable(a) = Unifiable::Bound{anyTypePack};
}
else if (auto tp = get<TypePack>(a))
{
queue.insert(queue.end(), tp->head.begin(), tp->head.end());
if (tp->tail)
a = *tp->tail;
else
break;
}
}
else
{
if (get<Unifiable::Free>(a))
{
state.log(a);
*asMutable(a) = Unifiable::Bound{anyTypePack};
}
if (auto tp = get<TypePack>(a))
{
queue.insert(queue.end(), tp->head.begin(), tp->head.end());
if (tp->tail)
a = *tp->tail;
else
break;
}
}
}
}
void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset)
{
const VariadicTypePack* lv = get<VariadicTypePack>(superTp);
@ -1297,9 +1570,11 @@ void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool rever
}
}
static void tryUnifyWithAny(
static void tryUnifyWithAny_DEPRECATED(
std::vector<TypeId>& queue, Unifier& state, std::unordered_set<TypePackId>& seenTypePacks, TypeId anyType, TypePackId anyTypePack)
{
LUAU_ASSERT(!FFlag::LuauTypecheckOpts);
std::unordered_set<TypeId> seen;
while (!queue.empty())
@ -1310,6 +1585,59 @@ static void tryUnifyWithAny(
continue;
seen.insert(ty);
if (get<FreeTypeVar>(ty))
{
state.log(ty);
*asMutable(ty) = BoundTypeVar{anyType};
}
else if (auto fun = get<FunctionTypeVar>(ty))
{
queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->argTypes, anyTypePack);
queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->retType, anyTypePack);
}
else if (auto table = get<TableTypeVar>(ty))
{
for (const auto& [_name, prop] : table->props)
queue.push_back(prop.type);
if (table->indexer)
{
queue.push_back(table->indexer->indexType);
queue.push_back(table->indexer->indexResultType);
}
}
else if (auto mt = get<MetatableTypeVar>(ty))
{
queue.push_back(mt->table);
queue.push_back(mt->metatable);
}
else if (get<ClassTypeVar>(ty))
{
// ClassTypeVars never contain free typevars.
}
else if (auto union_ = get<UnionTypeVar>(ty))
queue.insert(queue.end(), union_->options.begin(), union_->options.end());
else if (auto intersection = get<IntersectionTypeVar>(ty))
queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end());
else
{
} // Primitives, any, errors, and generics are left untouched.
}
}
static void tryUnifyWithAny(std::vector<TypeId>& queue, Unifier& state, DenseHashSet<TypeId>& seen, DenseHashSet<TypePackId>& seenTypePacks,
TypeId anyType, TypePackId anyTypePack)
{
LUAU_ASSERT(FFlag::LuauTypecheckOpts);
while (!queue.empty())
{
TypeId ty = follow(queue.back());
queue.pop_back();
if (seen.find(ty))
continue;
seen.insert(ty);
if (get<FreeTypeVar>(ty))
{
state.log(ty);
@ -1354,14 +1682,33 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty)
{
LUAU_ASSERT(get<AnyTypeVar>(any) || get<ErrorTypeVar>(any));
if (FFlag::LuauTypecheckOpts)
{
// These types are not visited in general loop below
if (get<PrimitiveTypeVar>(ty) || get<AnyTypeVar>(ty) || get<ClassTypeVar>(ty))
return;
}
const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}});
const TypePackId anyTP = get<AnyTypeVar>(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}});
if (FFlag::LuauTypecheckOpts)
{
std::vector<TypeId> queue = {ty};
tempSeenTy.clear();
tempSeenTp.clear();
Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP);
}
else
{
std::unordered_set<TypePackId> seenTypePacks;
std::vector<TypeId> queue = {ty};
Luau::tryUnifyWithAny(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP);
Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP);
}
}
void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty)
@ -1370,12 +1717,26 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty)
const TypeId anyTy = singletonTypes.errorType;
if (FFlag::LuauTypecheckOpts)
{
std::vector<TypeId> queue;
tempSeenTy.clear();
tempSeenTp.clear();
queueTypePack(queue, tempSeenTp, *this, ty, any);
Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any);
}
else
{
std::unordered_set<TypePackId> seenTypePacks;
std::vector<TypeId> queue;
queueTypePack(queue, seenTypePacks, *this, ty, any);
queueTypePack_DEPRECATED(queue, seenTypePacks, *this, ty, any);
Luau::tryUnifyWithAny(queue, *this, seenTypePacks, anyTy, any);
Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, anyTy, any);
}
}
std::optional<TypeId> Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name)
@ -1387,21 +1748,6 @@ std::optional<TypeId> Unifier::findMetatableEntry(TypeId type, std::string entry
{
type = follow(type);
if (!FFlag::LuauStringMetatable)
{
if (const PrimitiveTypeVar* primType = get<PrimitiveTypeVar>(type))
{
if (primType->type != PrimitiveTypeVar::String || "__index" != entry)
return std::nullopt;
auto found = globalScope->bindings.find(AstName{"string"});
if (found == globalScope->bindings.end())
return std::nullopt;
else
return found->second.typeId;
}
}
std::optional<TypeId> metatable = getMetatable(type);
if (!metatable)
return std::nullopt;
@ -1427,21 +1773,36 @@ std::optional<TypeId> Unifier::findMetatableEntry(TypeId type, std::string entry
void Unifier::occursCheck(TypeId needle, TypeId haystack)
{
std::unordered_set<TypeId> seen;
return occursCheck(seen, needle, haystack);
std::unordered_set<TypeId> seen_DEPRECATED;
if (FFlag::LuauTypecheckOpts)
tempSeenTy.clear();
return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack);
}
void Unifier::occursCheck(std::unordered_set<TypeId>& seen, TypeId needle, TypeId haystack)
void Unifier::occursCheck(std::unordered_set<TypeId>& seen_DEPRECATED, DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack)
{
RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(
FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit);
needle = follow(needle);
haystack = follow(haystack);
if (seen.end() != seen.find(haystack))
if (FFlag::LuauTypecheckOpts)
{
if (seen.find(haystack))
return;
seen.insert(haystack);
}
else
{
if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack))
return;
seen_DEPRECATED.insert(haystack);
}
if (get<Unifiable::Error>(needle))
return;
@ -1458,7 +1819,7 @@ void Unifier::occursCheck(std::unordered_set<TypeId>& seen, TypeId needle, TypeI
}
auto check = [&](TypeId tv) {
occursCheck(seen, needle, tv);
occursCheck(seen_DEPRECATED, seen, needle, tv);
};
if (get<FreeTypeVar>(haystack))
@ -1488,19 +1849,33 @@ void Unifier::occursCheck(std::unordered_set<TypeId>& seen, TypeId needle, TypeI
void Unifier::occursCheck(TypePackId needle, TypePackId haystack)
{
std::unordered_set<TypePackId> seen;
return occursCheck(seen, needle, haystack);
std::unordered_set<TypePackId> seen_DEPRECATED;
if (FFlag::LuauTypecheckOpts)
tempSeenTp.clear();
return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack);
}
void Unifier::occursCheck(std::unordered_set<TypePackId>& seen, TypePackId needle, TypePackId haystack)
void Unifier::occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack)
{
needle = follow(needle);
haystack = follow(haystack);
if (seen.find(haystack) != seen.end())
if (FFlag::LuauTypecheckOpts)
{
if (seen.find(haystack))
return;
seen.insert(haystack);
}
else
{
if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack))
return;
seen_DEPRECATED.insert(haystack);
}
if (get<Unifiable::Error>(needle))
return;
@ -1508,7 +1883,8 @@ void Unifier::occursCheck(std::unordered_set<TypePackId>& seen, TypePackId needl
if (!get<Unifiable::Free>(needle))
ice("Expected needle pack to be free");
RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(
FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit);
while (!get<ErrorTypeVar>(haystack))
{
@ -1528,8 +1904,8 @@ void Unifier::occursCheck(std::unordered_set<TypePackId>& seen, TypePackId needl
{
if (auto f = get<FunctionTypeVar>(FFlag::LuauAddMissingFollow ? follow(ty) : ty))
{
occursCheck(seen, needle, f->argTypes);
occursCheck(seen, needle, f->retType);
occursCheck(seen_DEPRECATED, seen, needle, f->argTypes);
occursCheck(seen_DEPRECATED, seen, needle, f->retType);
}
}
}
@ -1546,7 +1922,7 @@ void Unifier::occursCheck(std::unordered_set<TypePackId>& seen, TypePackId needl
Unifier Unifier::makeChildUnifier()
{
return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters};
return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters};
}
bool Unifier::isNonstrictMode() const

View file

@ -264,6 +264,10 @@ public:
{
return false;
}
virtual bool visit(class AstTypePackExplicit* node)
{
return visit((class AstTypePack*)node);
}
virtual bool visit(class AstTypePackVariadic* node)
{
return visit((class AstTypePack*)node);
@ -930,12 +934,14 @@ class AstStatTypeAlias : public AstStat
public:
LUAU_RTTI(AstStatTypeAlias)
AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, AstType* type, bool exported);
AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, const AstArray<AstName>& genericPacks,
AstType* type, bool exported);
void visit(AstVisitor* visitor) override;
AstName name;
AstArray<AstName> generics;
AstArray<AstName> genericPacks;
AstType* type;
bool exported;
};
@ -1007,19 +1013,28 @@ public:
}
};
// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use
struct AstTypeOrPack
{
AstType* type = nullptr;
AstTypePack* typePack = nullptr;
};
class AstTypeReference : public AstType
{
public:
LUAU_RTTI(AstTypeReference)
AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, const AstArray<AstType*>& generics = {});
AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, bool hasParameterList = false,
const AstArray<AstTypeOrPack>& parameters = {});
void visit(AstVisitor* visitor) override;
bool hasPrefix;
bool hasParameterList;
AstName prefix;
AstName name;
AstArray<AstType*> generics;
AstArray<AstTypeOrPack> parameters;
};
struct AstTableProp
@ -1152,6 +1167,18 @@ public:
}
};
class AstTypePackExplicit : public AstTypePack
{
public:
LUAU_RTTI(AstTypePackExplicit)
AstTypePackExplicit(const Location& location, AstTypeList typeList);
void visit(AstVisitor* visitor) override;
AstTypeList typeList;
};
class AstTypePackVariadic : public AstTypePack
{
public:

View file

@ -136,7 +136,10 @@ public:
const Key& key = ItemInterface::getKey(data[i]);
if (!eq(key, empty_key))
*newtable.insert_unsafe(key) = data[i];
{
Item* item = newtable.insert_unsafe(key);
*item = std::move(data[i]);
}
}
LUAU_ASSERT(count == newtable.count);

View file

@ -218,13 +218,14 @@ private:
AstTableIndexer* parseTableIndexerAnnotation();
AstType* parseFunctionTypeAnnotation();
AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack);
AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstName> generics, AstArray<AstName> genericPacks,
AstArray<AstType*>& params, AstArray<std::optional<AstArgumentName>>& paramNames, AstTypePack* varargAnnotation);
AstType* parseTableTypeAnnotation();
AstType* parseSimpleTypeAnnotation();
AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack);
AstTypeOrPack parseTypeOrPackAnnotation();
AstType* parseTypeAnnotation(TempVector<AstType*>& parts, const Location& begin);
AstType* parseTypeAnnotation();
@ -284,7 +285,7 @@ private:
std::pair<AstArray<AstName>, AstArray<AstName>> parseGenericTypeListIfFFlagParseGenericFunctions();
// `<' typeAnnotation[, ...] `>'
AstArray<AstType*> parseTypeParams();
AstArray<AstTypeOrPack> parseTypeParams();
AstExpr* parseString();
@ -413,6 +414,7 @@ private:
std::vector<AstLocal*> scratchLocal;
std::vector<AstTableProp> scratchTableTypeProps;
std::vector<AstType*> scratchAnnotation;
std::vector<AstTypeOrPack> scratchTypeOrPackAnnotation;
std::vector<AstDeclaredClassProp> scratchDeclaredClassProps;
std::vector<AstExprTable::Item> scratchItem;
std::vector<AstArgumentName> scratchArgName;

View file

@ -0,0 +1,223 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Common.h"
#include <vector>
#include <stdint.h>
LUAU_FASTFLAG(DebugLuauTimeTracing)
#if defined(LUAU_ENABLE_TIME_TRACE)
namespace Luau
{
namespace TimeTrace
{
uint32_t getClockMicroseconds();
struct Token
{
const char* name;
const char* category;
};
enum class EventType : uint8_t
{
Enter,
Leave,
ArgName,
ArgValue,
};
struct Event
{
EventType type;
uint16_t token;
union
{
uint32_t microsec; // 1 hour trace limit
uint32_t dataPos;
} data;
};
struct GlobalContext;
struct ThreadContext;
GlobalContext& getGlobalContext();
uint16_t createToken(GlobalContext& context, const char* name, const char* category);
uint32_t createThread(GlobalContext& context, ThreadContext* threadContext);
void releaseThread(GlobalContext& context, ThreadContext* threadContext);
void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Event>& events, const std::vector<char>& data);
struct ThreadContext
{
ThreadContext()
: globalContext(getGlobalContext())
{
threadId = createThread(globalContext, this);
}
~ThreadContext()
{
if (!events.empty())
flushEvents();
releaseThread(globalContext, this);
}
void flushEvents()
{
static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace");
events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}});
TimeTrace::flushEvents(globalContext, threadId, events, data);
events.clear();
data.clear();
events.push_back({EventType::Leave, 0, {getClockMicroseconds()}});
}
void eventEnter(uint16_t token)
{
eventEnter(token, getClockMicroseconds());
}
void eventEnter(uint16_t token, uint32_t microsec)
{
events.push_back({EventType::Enter, token, {microsec}});
}
void eventLeave()
{
eventLeave(getClockMicroseconds());
}
void eventLeave(uint32_t microsec)
{
events.push_back({EventType::Leave, 0, {microsec}});
if (events.size() > kEventFlushLimit)
flushEvents();
}
void eventArgument(const char* name, const char* value)
{
uint32_t pos = uint32_t(data.size());
data.insert(data.end(), name, name + strlen(name) + 1);
events.push_back({EventType::ArgName, 0, {pos}});
pos = uint32_t(data.size());
data.insert(data.end(), value, value + strlen(value) + 1);
events.push_back({EventType::ArgValue, 0, {pos}});
}
GlobalContext& globalContext;
uint32_t threadId;
std::vector<Event> events;
std::vector<char> data;
static constexpr size_t kEventFlushLimit = 8192;
};
ThreadContext& getThreadContext();
struct Scope
{
explicit Scope(ThreadContext& context, uint16_t token)
: context(context)
{
if (!FFlag::DebugLuauTimeTracing)
return;
context.eventEnter(token);
}
~Scope()
{
if (!FFlag::DebugLuauTimeTracing)
return;
context.eventLeave();
}
ThreadContext& context;
};
struct OptionalTailScope
{
explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold)
: context(context)
, token(token)
, threshold(threshold)
{
if (!FFlag::DebugLuauTimeTracing)
return;
pos = uint32_t(context.events.size());
microsec = getClockMicroseconds();
}
~OptionalTailScope()
{
if (!FFlag::DebugLuauTimeTracing)
return;
if (pos == context.events.size())
{
uint32_t curr = getClockMicroseconds();
if (curr - microsec > threshold)
{
context.eventEnter(token, microsec);
context.eventLeave(curr);
}
}
}
ThreadContext& context;
uint16_t token;
uint32_t threshold;
uint32_t microsec;
uint32_t pos;
};
LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeData(const char* name, const char* category);
} // namespace TimeTrace
} // namespace Luau
// Regular scope
#define LUAU_TIMETRACE_SCOPE(name, category) \
static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first)
// A scope without nested scopes that may be skipped if the time it took is less than the threshold
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec)
// Extra key/value data can be added to regular scopes
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
if (FFlag::DebugLuauTimeTracing) \
lttScopeStatic.second.eventArgument(name, value); \
} while (false)
#else
#define LUAU_TIMETRACE_SCOPE(name, category)
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec)
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
} while (false)
#endif

View file

@ -641,10 +641,12 @@ void AstStatLocalFunction::visit(AstVisitor* visitor)
func->visit(visitor);
}
AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, AstType* type, bool exported)
AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics,
const AstArray<AstName>& genericPacks, AstType* type, bool exported)
: AstStat(ClassIndex(), location)
, name(name)
, generics(generics)
, genericPacks(genericPacks)
, type(type)
, exported(exported)
{
@ -729,12 +731,14 @@ void AstStatError::visit(AstVisitor* visitor)
}
}
AstTypeReference::AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, const AstArray<AstType*>& generics)
AstTypeReference::AstTypeReference(
const Location& location, std::optional<AstName> prefix, AstName name, bool hasParameterList, const AstArray<AstTypeOrPack>& parameters)
: AstType(ClassIndex(), location)
, hasPrefix(bool(prefix))
, hasParameterList(hasParameterList)
, prefix(prefix ? *prefix : AstName())
, name(name)
, generics(generics)
, parameters(parameters)
{
}
@ -742,8 +746,13 @@ void AstTypeReference::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
{
for (AstType* generic : generics)
generic->visit(visitor);
for (const AstTypeOrPack& param : parameters)
{
if (param.type)
param.type->visit(visitor);
else
param.typePack->visit(visitor);
}
}
}
@ -849,6 +858,24 @@ void AstTypeError::visit(AstVisitor* visitor)
}
}
AstTypePackExplicit::AstTypePackExplicit(const Location& location, AstTypeList typeList)
: AstTypePack(ClassIndex(), location)
, typeList(typeList)
{
}
void AstTypePackExplicit::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
{
for (AstType* type : typeList.types)
type->visit(visitor);
if (typeList.tailType)
typeList.tailType->visit(visitor);
}
}
AstTypePackVariadic::AstTypePackVariadic(const Location& location, AstType* variadicType)
: AstTypePack(ClassIndex(), location)
, variadicType(variadicType)

View file

@ -1,6 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "Luau/TimeTrace.h"
#include <algorithm>
// Warning: If you are introducing new syntax, ensure that it is behind a separate
@ -13,6 +15,8 @@ LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false)
LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false)
LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false)
LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false)
namespace Luau
{
@ -148,6 +152,8 @@ static bool shouldParseTypePackAnnotation(Lexer& lexer)
ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options)
{
LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser");
Parser p(buffer, bufferSize, names, allocator);
try
@ -769,14 +775,14 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
if (!name)
name = Name(nameError, lexer.current().location);
// TODO: support generic type pack parameters in type aliases CLI-39907
auto [generics, genericPacks] = parseGenericTypeList();
expectAndConsume('=', "type alias");
AstType* type = parseTypeAnnotation();
return allocator.alloc<AstStatTypeAlias>(Location(start, type->location), name->name, generics, type, exported);
return allocator.alloc<AstStatTypeAlias>(
Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray<AstName>{}, type, exported);
}
AstDeclaredClassProp Parser::parseDeclaredClassMethod()
@ -1333,7 +1339,7 @@ AstType* Parser::parseTableTypeAnnotation()
// ReturnType ::= TypeAnnotation | `(' TypeList `)'
// FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstType* Parser::parseFunctionTypeAnnotation()
AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
{
incrementRecursionCounter("type annotation");
@ -1364,14 +1370,23 @@ AstType* Parser::parseFunctionTypeAnnotation()
matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--;
// Not a function at all. Just a parenthesized type.
if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow)
return params[0];
AstArray<AstType*> paramTypes = copy(params);
// Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element
if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow)
{
if (allowPack)
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, nullptr})};
else
return {params[0], {}};
}
if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack)
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, varargAnnotation})};
AstArray<std::optional<AstArgumentName>> paramNames = copy(names);
return parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation);
return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
}
AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstName> generics, AstArray<AstName> genericPacks,
@ -1421,7 +1436,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
if (c == '|')
{
nextLexeme();
parts.push_back(parseSimpleTypeAnnotation());
parts.push_back(parseSimpleTypeAnnotation(false).type);
isUnion = true;
}
else if (c == '?')
@ -1434,7 +1449,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
else if (c == '&')
{
nextLexeme();
parts.push_back(parseSimpleTypeAnnotation());
parts.push_back(parseSimpleTypeAnnotation(false).type);
isIntersection = true;
}
else
@ -1462,6 +1477,30 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
ParseError::raise(begin, "Composite type was not an intersection or union.");
}
AstTypeOrPack Parser::parseTypeOrPackAnnotation()
{
unsigned int oldRecursionCount = recursionCounter;
incrementRecursionCounter("type annotation");
Location begin = lexer.current().location;
TempVector<AstType*> parts(scratchAnnotation);
auto [type, typePack] = parseSimpleTypeAnnotation(true);
if (typePack)
{
LUAU_ASSERT(!type);
return {{}, typePack};
}
parts.push_back(type);
recursionCounter = oldRecursionCount;
return {parseTypeAnnotation(parts, begin), {}};
}
AstType* Parser::parseTypeAnnotation()
{
unsigned int oldRecursionCount = recursionCounter;
@ -1470,7 +1509,7 @@ AstType* Parser::parseTypeAnnotation()
Location begin = lexer.current().location;
TempVector<AstType*> parts(scratchAnnotation);
parts.push_back(parseSimpleTypeAnnotation());
parts.push_back(parseSimpleTypeAnnotation(false).type);
recursionCounter = oldRecursionCount;
@ -1479,7 +1518,7 @@ AstType* Parser::parseTypeAnnotation()
// typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}'
// | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstType* Parser::parseSimpleTypeAnnotation()
AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
{
incrementRecursionCounter("type annotation");
@ -1488,7 +1527,7 @@ AstType* Parser::parseSimpleTypeAnnotation()
if (lexer.current().type == Lexeme::ReservedNil)
{
nextLexeme();
return allocator.alloc<AstTypeReference>(begin, std::nullopt, nameNil);
return {allocator.alloc<AstTypeReference>(begin, std::nullopt, nameNil), {}};
}
else if (lexer.current().type == Lexeme::Name)
{
@ -1514,22 +1553,41 @@ AstType* Parser::parseSimpleTypeAnnotation()
expectMatchAndConsume(')', typeofBegin);
return allocator.alloc<AstTypeTypeof>(Location(begin, end), expr);
return {allocator.alloc<AstTypeTypeof>(Location(begin, end), expr), {}};
}
AstArray<AstType*> generics = parseTypeParams();
if (FFlag::LuauParseTypePackTypeParameters)
{
bool hasParameters = false;
AstArray<AstTypeOrPack> parameters{};
if (lexer.current().type == '<')
{
hasParameters = true;
parameters = parseTypeParams();
}
Location end = lexer.previousLocation();
return allocator.alloc<AstTypeReference>(Location(begin, end), prefix, name.name, generics);
return {allocator.alloc<AstTypeReference>(Location(begin, end), prefix, name.name, hasParameters, parameters), {}};
}
else
{
AstArray<AstTypeOrPack> generics = parseTypeParams();
Location end = lexer.previousLocation();
// false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks
return {allocator.alloc<AstTypeReference>(Location(begin, end), prefix, name.name, false, generics), {}};
}
}
else if (lexer.current().type == '{')
{
return parseTableTypeAnnotation();
return {parseTableTypeAnnotation(), {}};
}
else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<'))
{
return parseFunctionTypeAnnotation();
return parseFunctionTypeAnnotation(allowPack);
}
else
{
@ -1538,7 +1596,7 @@ AstType* Parser::parseSimpleTypeAnnotation()
// For a missing type annotation, capture 'space' between last token and the next one
location = Location(lexer.previousLocation().end, lexer.current().location.begin);
return reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str());
return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}};
}
}
@ -2312,18 +2370,59 @@ std::pair<AstArray<AstName>, AstArray<AstName>> Parser::parseGenericTypeList()
return {generics, genericPacks};
}
AstArray<AstType*> Parser::parseTypeParams()
AstArray<AstTypeOrPack> Parser::parseTypeParams()
{
TempVector<AstType*> result{scratchAnnotation};
TempVector<AstTypeOrPack> parameters{scratchTypeOrPackAnnotation};
if (lexer.current().type == '<')
{
Lexeme begin = lexer.current();
nextLexeme();
bool seenPack = false;
while (true)
{
result.push_back(parseTypeAnnotation());
if (FFlag::LuauParseTypePackTypeParameters)
{
if (shouldParseTypePackAnnotation(lexer))
{
seenPack = true;
auto typePack = parseTypePackAnnotation();
if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them
parameters.push_back({{}, typePack});
}
else if (lexer.current().type == '(')
{
auto [type, typePack] = parseTypeOrPackAnnotation();
if (typePack)
{
seenPack = true;
if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them
parameters.push_back({{}, typePack});
}
else
{
parameters.push_back({type, {}});
}
}
else if (lexer.current().type == '>' && parameters.empty())
{
break;
}
else
{
parameters.push_back({parseTypeAnnotation(), {}});
}
}
else
{
parameters.push_back({parseTypeAnnotation(), {}});
}
if (lexer.current().type == ',')
nextLexeme();
else
@ -2333,7 +2432,7 @@ AstArray<AstType*> Parser::parseTypeParams()
expectMatchAndConsume('>', begin);
}
return copy(result);
return copy(parameters);
}
AstExpr* Parser::parseString()

248
Ast/src/TimeTrace.cpp Normal file
View file

@ -0,0 +1,248 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TimeTrace.h"
#include "Luau/StringUtils.h"
#include <mutex>
#include <string>
#include <stdlib.h>
#ifdef _WIN32
#include <Windows.h>
#endif
#ifdef __APPLE__
#include <mach/mach.h>
#include <mach/mach_time.h>
#endif
#include <time.h>
LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false)
#if defined(LUAU_ENABLE_TIME_TRACE)
namespace Luau
{
namespace TimeTrace
{
static double getClockPeriod()
{
#if defined(_WIN32)
LARGE_INTEGER result = {};
QueryPerformanceFrequency(&result);
return 1.0 / double(result.QuadPart);
#elif defined(__APPLE__)
mach_timebase_info_data_t result = {};
mach_timebase_info(&result);
return double(result.numer) / double(result.denom) * 1e-9;
#elif defined(__linux__)
return 1e-9;
#else
return 1.0 / double(CLOCKS_PER_SEC);
#endif
}
static double getClockTimestamp()
{
#if defined(_WIN32)
LARGE_INTEGER result = {};
QueryPerformanceCounter(&result);
return double(result.QuadPart);
#elif defined(__APPLE__)
return double(mach_absolute_time());
#elif defined(__linux__)
timespec now;
clock_gettime(CLOCK_MONOTONIC, &now);
return now.tv_sec * 1e9 + now.tv_nsec;
#else
return double(clock());
#endif
}
uint32_t getClockMicroseconds()
{
static double period = getClockPeriod() * 1e6;
static double start = getClockTimestamp();
return uint32_t((getClockTimestamp() - start) * period);
}
struct GlobalContext
{
GlobalContext() = default;
~GlobalContext()
{
// Ideally we would want all ThreadContext destructors to run
// But in VS, not all thread_local object instances are destroyed
for (ThreadContext* context : threads)
context->flushEvents();
if (traceFile)
fclose(traceFile);
}
std::mutex mutex;
std::vector<ThreadContext*> threads;
uint32_t nextThreadId = 0;
std::vector<Token> tokens;
FILE* traceFile = nullptr;
};
GlobalContext& getGlobalContext()
{
static GlobalContext context;
return context;
}
uint16_t createToken(GlobalContext& context, const char* name, const char* category)
{
std::scoped_lock lock(context.mutex);
LUAU_ASSERT(context.tokens.size() < 64 * 1024);
context.tokens.push_back({name, category});
return uint16_t(context.tokens.size() - 1);
}
uint32_t createThread(GlobalContext& context, ThreadContext* threadContext)
{
std::scoped_lock lock(context.mutex);
context.threads.push_back(threadContext);
return ++context.nextThreadId;
}
void releaseThread(GlobalContext& context, ThreadContext* threadContext)
{
std::scoped_lock lock(context.mutex);
if (auto it = std::find(context.threads.begin(), context.threads.end(), threadContext); it != context.threads.end())
context.threads.erase(it);
}
void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Event>& events, const std::vector<char>& data)
{
std::scoped_lock lock(context.mutex);
if (!context.traceFile)
{
context.traceFile = fopen("trace.json", "w");
if (!context.traceFile)
return;
fprintf(context.traceFile, "[\n");
}
std::string temp;
const unsigned tempReserve = 64 * 1024;
temp.reserve(tempReserve);
const char* rawData = data.data();
// Formatting state
bool unfinishedEnter = false;
bool unfinishedArgs = false;
for (const Event& ev : events)
{
switch (ev.type)
{
case EventType::Enter:
{
if (unfinishedArgs)
{
formatAppend(temp, "}");
unfinishedArgs = false;
}
if (unfinishedEnter)
{
formatAppend(temp, "},\n");
unfinishedEnter = false;
}
Token& token = context.tokens[ev.token];
formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category,
ev.data.microsec, threadId);
unfinishedEnter = true;
}
break;
case EventType::Leave:
if (unfinishedArgs)
{
formatAppend(temp, "}");
unfinishedArgs = false;
}
if (unfinishedEnter)
{
formatAppend(temp, "},\n");
unfinishedEnter = false;
}
formatAppend(temp,
R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)"
"\n",
ev.data.microsec, threadId);
break;
case EventType::ArgName:
LUAU_ASSERT(unfinishedEnter);
if (!unfinishedArgs)
{
formatAppend(temp, R"(, "args": { "%s": )", rawData + ev.data.dataPos);
unfinishedArgs = true;
}
else
{
formatAppend(temp, R"(, "%s": )", rawData + ev.data.dataPos);
}
break;
case EventType::ArgValue:
LUAU_ASSERT(unfinishedArgs);
formatAppend(temp, R"("%s")", rawData + ev.data.dataPos);
break;
}
// Don't want to hit the string capacity and reallocate
if (temp.size() > tempReserve - 1024)
{
fwrite(temp.data(), 1, temp.size(), context.traceFile);
temp.clear();
}
}
if (unfinishedArgs)
{
formatAppend(temp, "}");
unfinishedArgs = false;
}
if (unfinishedEnter)
{
formatAppend(temp, "},\n");
unfinishedEnter = false;
}
fwrite(temp.data(), 1, temp.size(), context.traceFile);
fflush(context.traceFile);
}
ThreadContext& getThreadContext()
{
thread_local ThreadContext context;
return context;
}
std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeData(const char* name, const char* category)
{
uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category);
return {token, Luau::TimeTrace::getThreadContext()};
}
} // namespace TimeTrace
} // namespace Luau
#endif

View file

@ -111,11 +111,24 @@ struct CliFileResolver : Luau::FileResolver
return Luau::SourceCode{*source, Luau::SourceCode::Module};
}
std::optional<Luau::ModuleInfo> resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override
{
if (Luau::AstExprConstantString* expr = node->as<Luau::AstExprConstantString>())
{
Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua";
return {{name}};
}
return std::nullopt;
}
bool moduleExists(const Luau::ModuleName& name) const override
{
return !!readFile(name);
}
std::optional<Luau::ModuleName> fromAstFragment(Luau::AstExpr* expr) const override
{
return std::nullopt;
@ -130,11 +143,6 @@ struct CliFileResolver : Luau::FileResolver
{
return std::nullopt;
}
std::optional<std::string> getEnvironmentForModule(const Luau::ModuleName& name) const override
{
return std::nullopt;
}
};
struct CliConfigResolver : Luau::ConfigResolver

View file

@ -4,6 +4,7 @@
#include "Luau/Parser.h"
#include "Luau/BytecodeBuilder.h"
#include "Luau/Common.h"
#include "Luau/TimeTrace.h"
#include <algorithm>
#include <bitset>
@ -137,6 +138,11 @@ struct Compiler
uint32_t compileFunction(AstExprFunction* func)
{
LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler");
if (func->debugname.value)
LUAU_TIMETRACE_ARGUMENT("name", func->debugname.value);
LUAU_ASSERT(!functions.contains(func));
LUAU_ASSERT(regTop == 0 && stackSize == 0 && localStack.empty() && upvals.empty());
@ -3686,6 +3692,8 @@ struct Compiler
void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options)
{
LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler");
Compiler compiler(bytecode, options);
// since access to some global objects may result in values that change over time, we block table imports
@ -3748,6 +3756,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const
std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder)
{
LUAU_TIMETRACE_SCOPE("compile", "Compiler");
Allocator allocator;
AstNameTable names(allocator);
ParseResult result = Parser::parse(source.c_str(), source.size(), names, allocator, parseOptions);

View file

@ -9,6 +9,7 @@ target_sources(Luau.Ast PRIVATE
Ast/include/Luau/ParseOptions.h
Ast/include/Luau/Parser.h
Ast/include/Luau/StringUtils.h
Ast/include/Luau/TimeTrace.h
Ast/src/Ast.cpp
Ast/src/Confusables.cpp
@ -16,6 +17,7 @@ target_sources(Luau.Ast PRIVATE
Ast/src/Location.cpp
Ast/src/Parser.cpp
Ast/src/StringUtils.cpp
Ast/src/TimeTrace.cpp
)
# Luau.Compiler Sources
@ -46,6 +48,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Predicate.h
Analysis/include/Luau/RecursionCounter.h
Analysis/include/Luau/RequireTracer.h
Analysis/include/Luau/Scope.h
Analysis/include/Luau/Substitution.h
Analysis/include/Luau/Symbol.h
Analysis/include/Luau/TopoSortStatements.h
@ -75,6 +78,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Module.cpp
Analysis/src/Predicate.cpp
Analysis/src/RequireTracer.cpp
Analysis/src/Scope.cpp
Analysis/src/Substitution.cpp
Analysis/src/Symbol.cpp
Analysis/src/TopoSortStatements.cpp
@ -188,6 +192,7 @@ if(TARGET Luau.UnitTest)
tests/TopoSort.test.cpp
tests/ToString.test.cpp
tests/Transpiler.test.cpp
tests/TypeInfer.aliases.test.cpp
tests/TypeInfer.annotations.test.cpp
tests/TypeInfer.builtins.test.cpp
tests/TypeInfer.classes.test.cpp

View file

@ -8,9 +8,13 @@
#include "lmem.h"
#include "lvm.h"
#include <stdexcept>
#if LUA_USE_LONGJMP
#include <setjmp.h>
#include <stdlib.h>
#else
#include <stdexcept>
#endif
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false)
@ -51,8 +55,8 @@ l_noret luaD_throw(lua_State* L, int errcode)
longjmp(jb->buf, 1);
}
if (L->global->panic)
L->global->panic(L, errcode);
if (L->global->cb.panic)
L->global->cb.panic(L, errcode);
abort();
}

View file

@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false)
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false)
LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false)
LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false)
LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false)
LUAU_FASTFLAG(LuauArrayBoundary)
#define GC_SWEEPMAX 40
@ -810,6 +812,133 @@ static size_t singlestep(lua_State* L)
return cost;
}
static size_t gcstep(lua_State* L, size_t limit)
{
size_t cost = 0;
global_State* g = L->global;
switch (g->gcstate)
{
case GCSpause:
{
markroot(L); /* start a new collection */
break;
}
case GCSpropagate:
{
if (FFlag::LuauRescanGrayAgain)
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
if (!g->gray)
{
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
g->gcstate = GCSpropagateagain;
}
}
else
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
if (!g->gray) /* no more `gray' objects */
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
}
break;
}
case GCSpropagateagain:
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
if (!g->gray) /* no more `gray' objects */
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
break;
}
case GCSsweepstring:
{
while (g->sweepstrgc < g->strt.size && cost < limit)
{
size_t traversedcount = 0;
sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount);
g->gcstats.currcycle.sweepitems += traversedcount;
cost += GC_SWEEPCOST;
}
// nothing more to sweep?
if (g->sweepstrgc >= g->strt.size)
{
// sweep string buffer list and preserve used string count
uint32_t nuse = L->global->strt.nuse;
size_t traversedcount = 0;
sweepwholelist(L, &g->strbufgc, &traversedcount);
L->global->strt.nuse = nuse;
g->gcstats.currcycle.sweepitems += traversedcount;
g->gcstate = GCSsweep; // end sweep-string phase
}
break;
}
case GCSsweep:
{
while (*g->sweepgc && cost < limit)
{
size_t traversedcount = 0;
g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount);
g->gcstats.currcycle.sweepitems += traversedcount;
cost += GC_SWEEPMAX * GC_SWEEPCOST;
}
if (*g->sweepgc == NULL)
{ /* nothing more to sweep? */
shrinkbuffers(L);
g->gcstate = GCSpause; /* end collection */
}
break;
}
default:
LUAU_ASSERT(0);
}
return cost;
}
static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCycleStats* cyclestats)
{
// adjust for error using Proportional-Integral controller
@ -878,14 +1007,20 @@ void luaC_step(lua_State* L, bool assist)
if (g->gcstate == GCSpause)
startGcCycleStats(g);
if (assist)
g->gcstats.currcycle.assistwork += lim;
else
g->gcstats.currcycle.explicitwork += lim;
int lastgcstate = g->gcstate;
double lasttimestamp = lua_clock();
if (FFlag::LuauConsolidatedStep)
{
size_t work = gcstep(L, lim);
if (assist)
g->gcstats.currcycle.assistwork += work;
else
g->gcstats.currcycle.explicitwork += work;
}
else
{
// always perform at least one single step
do
{
@ -905,6 +1040,7 @@ void luaC_step(lua_State* L, bool assist)
lastgcstate = g->gcstate;
}
} while (lim > 0 && g->gcstate != GCSpause);
}
recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist);
@ -931,7 +1067,14 @@ void luaC_step(lua_State* L, bool assist)
g->GCthreshold -= debt;
}
if (FFlag::LuauConsolidatedStep)
{
GC_INTERRUPT(lastgcstate);
}
else
{
GC_INTERRUPT(g->gcstate);
}
}
void luaC_fullgc(lua_State* L)
@ -957,6 +1100,9 @@ void luaC_fullgc(lua_State* L)
while (g->gcstate != GCSpause)
{
LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep);
if (FFlag::LuauConsolidatedStep)
gcstep(L, SIZE_MAX);
else
singlestep(L);
}
@ -968,6 +1114,9 @@ void luaC_fullgc(lua_State* L)
markroot(L);
while (g->gcstate != GCSpause)
{
if (FFlag::LuauConsolidatedStep)
gcstep(L, SIZE_MAX);
else
singlestep(L);
}
/* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */

View file

@ -9,14 +9,8 @@
#include "ldebug.h"
#include "lvm.h"
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry, false)
LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false)
bool lua_telemetry_table_move_oob_src_from = false;
bool lua_telemetry_table_move_oob_src_to = false;
bool lua_telemetry_table_move_oob_dst = false;
static int foreachi(lua_State* L)
{
luaL_checktype(L, 1, LUA_TTABLE);
@ -202,22 +196,6 @@ static int tmove(lua_State* L)
int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */
luaL_checktype(L, tt, LUA_TTABLE);
if (DFFlag::LuauTableMoveTelemetry)
{
int nf = lua_objlen(L, 1);
int nt = lua_objlen(L, tt);
// source index range must be in bounds in source table unless the table is empty (permits 1..#t moves)
if (!(f == 1 || (f >= 1 && f <= nf)))
lua_telemetry_table_move_oob_src_from = true;
if (!(e == nf || (e >= 1 && e <= nf)))
lua_telemetry_table_move_oob_src_to = true;
// destination index must be in bounds in dest table or be exactly at the first empty element (permits concats)
if (!(t == nt + 1 || (t >= 1 && t <= nt + 1)))
lua_telemetry_table_move_oob_dst = true;
}
if (e >= f)
{ /* otherwise, nothing to move */
luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move");

View file

@ -16,8 +16,6 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauLoopUseSafeenv, false)
// Disable c99-designator to avoid the warning in CGOTO dispatch table
#ifdef __clang__
#if __has_warning("-Wc99-designator")
@ -292,10 +290,6 @@ inline bool luau_skipstep(uint8_t op)
return op == LOP_PREPVARARGS || op == LOP_BREAK;
}
// declared in lbaselib.cpp, needed to support cases when pairs/ipairs have been replaced via setfenv
LUAI_FUNC int luaB_inext(lua_State* L);
LUAI_FUNC int luaB_next(lua_State* L);
template<bool SingleStep>
static void luau_execute(lua_State* L)
{
@ -2223,8 +2217,7 @@ static void luau_execute(lua_State* L)
StkId ra = VM_REG(LUAU_INSN_A(insn));
// fast-path: ipairs/inext
bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_inext;
if (safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0)
if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0)
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
}
@ -2304,8 +2297,7 @@ static void luau_execute(lua_State* L)
StkId ra = VM_REG(LUAU_INSN_A(insn));
// fast-path: pairs/next
bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_next;
if (safeenv && ttistable(ra + 1) && ttisnil(ra + 2))
if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2))
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
}

View file

@ -12,7 +12,32 @@
#include <string.h>
#include <vector>
// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens
template <typename T>
struct TempBuffer
{
lua_State* L;
T* data;
size_t count;
TempBuffer(lua_State* L, size_t count)
: L(L)
, data(luaM_newarray(L, count, T, 0))
, count(count)
{
}
~TempBuffer()
{
luaM_freearray(L, data, count, T, 0);
}
T& operator[](size_t index)
{
LUAU_ASSERT(index < count);
return data[index];
}
};
void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil)
{
@ -67,7 +92,7 @@ static unsigned int readVarInt(const char* data, size_t size, size_t& offset)
return result;
}
static TString* readString(std::vector<TString*>& strings, const char* data, size_t size, size_t& offset)
static TString* readString(TempBuffer<TString*>& strings, const char* data, size_t size, size_t& offset)
{
unsigned int id = readVarInt(data, size, offset);
@ -133,6 +158,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
}
// pause GC for the duration of deserialization - some objects we're creating aren't rooted
// TODO: if an allocation error happens mid-load, we do not unpause GC!
size_t GCthreshold = L->global->GCthreshold;
L->global->GCthreshold = SIZE_MAX;
@ -144,7 +170,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
// string table
unsigned int stringCount = readVarInt(data, size, offset);
std::vector<TString*> strings(stringCount);
TempBuffer<TString*> strings(L, stringCount);
for (unsigned int i = 0; i < stringCount; ++i)
{
@ -156,7 +182,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
// proto table
unsigned int protoCount = readVarInt(data, size, offset);
std::vector<Proto*> protos(protoCount);
TempBuffer<Proto*> protos(L, protoCount);
for (unsigned int i = 0; i < protoCount; ++i)
{

View file

@ -1,934 +0,0 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
-- Copyright 2008 the V8 project authors. All rights reserved.
-- Copyright 1996 John Maloney and Mario Wolczko.
-- This program is free software; you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation; either version 2 of the License, or
-- (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program; if not, write to the Free Software
-- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
-- This implementation of the DeltaBlue benchmark is derived
-- from the Smalltalk implementation by John Maloney and Mario
-- Wolczko. Some parts have been translated directly, whereas
-- others have been modified more aggressively to make it feel
-- more like a JavaScript program.
--
-- A JavaScript implementation of the DeltaBlue constraint-solving
-- algorithm, as described in:
--
-- "The DeltaBlue Algorithm: An Incremental Constraint Hierarchy Solver"
-- Bjorn N. Freeman-Benson and John Maloney
-- January 1990 Communications of the ACM,
-- also available as University of Washington TR 89-08-06.
--
-- Beware: this benchmark is written in a grotesque style where
-- the constraint model is built by side-effects from constructors.
-- I've kept it this way to avoid deviating too much from the original
-- implementation.
--
function class(base)
local T = {}
T.__index = T
if base then
T.super = base
setmetatable(T, base)
end
function T.new(...)
local O = {}
setmetatable(O, T)
O:constructor(...)
return O
end
return T
end
local planner
--- O b j e c t M o d e l ---
local function alert (...) print(...) end
local OrderedCollection = class()
function OrderedCollection:constructor()
self.elms = {}
end
function OrderedCollection:add(elm)
self.elms[#self.elms + 1] = elm
end
function OrderedCollection:at (index)
return self.elms[index]
end
function OrderedCollection:size ()
return #self.elms
end
function OrderedCollection:removeFirst ()
local e = self.elms[#self.elms]
self.elms[#self.elms] = nil
return e
end
function OrderedCollection:remove (elm)
local index = 0
local skipped = 0
for i = 1, #self.elms do
local value = self.elms[i]
if value ~= elm then
self.elms[index] = value
index = index + 1
else
skipped = skipped + 1
end
end
local l = #self.elms
for i = 1, skipped do self.elms[l - i + 1] = nil end
end
--
-- S t r e n g t h
--
--
-- Strengths are used to measure the relative importance of constraints.
-- New strengths may be inserted in the strength hierarchy without
-- disrupting current constraints. Strengths cannot be created outside
-- this class, so pointer comparison can be used for value comparison.
--
local Strength = class()
function Strength:constructor(strengthValue, name)
self.strengthValue = strengthValue
self.name = name
end
function Strength.stronger (s1, s2)
return s1.strengthValue < s2.strengthValue
end
function Strength.weaker (s1, s2)
return s1.strengthValue > s2.strengthValue
end
function Strength.weakestOf (s1, s2)
return Strength.weaker(s1, s2) and s1 or s2
end
function Strength.strongest (s1, s2)
return Strength.stronger(s1, s2) and s1 or s2
end
function Strength:nextWeaker ()
local v = self.strengthValue
if v == 0 then return Strength.WEAKEST
elseif v == 1 then return Strength.WEAK_DEFAULT
elseif v == 2 then return Strength.NORMAL
elseif v == 3 then return Strength.STRONG_DEFAULT
elseif v == 4 then return Strength.PREFERRED
elseif v == 5 then return Strength.REQUIRED
end
end
-- Strength constants.
Strength.REQUIRED = Strength.new(0, "required");
Strength.STONG_PREFERRED = Strength.new(1, "strongPreferred");
Strength.PREFERRED = Strength.new(2, "preferred");
Strength.STRONG_DEFAULT = Strength.new(3, "strongDefault");
Strength.NORMAL = Strength.new(4, "normal");
Strength.WEAK_DEFAULT = Strength.new(5, "weakDefault");
Strength.WEAKEST = Strength.new(6, "weakest");
--
-- C o n s t r a i n t
--
--
-- An abstract class representing a system-maintainable relationship
-- (or "constraint") between a set of variables. A constraint supplies
-- a strength instance variable; concrete subclasses provide a means
-- of storing the constrained variables and other information required
-- to represent a constraint.
--
local Constraint = class ()
function Constraint:constructor(strength)
self.strength = strength
end
--
-- Activate this constraint and attempt to satisfy it.
--
function Constraint:addConstraint ()
self:addToGraph()
planner:incrementalAdd(self)
end
--
-- Attempt to find a way to enforce this constraint. If successful,
-- record the solution, perhaps modifying the current dataflow
-- graph. Answer the constraint that this constraint overrides, if
-- there is one, or nil, if there isn't.
-- Assume: I am not already satisfied.
--
function Constraint:satisfy (mark)
self:chooseMethod(mark)
if not self:isSatisfied() then
if self.strength == Strength.REQUIRED then
alert("Could not satisfy a required constraint!")
end
return nil
end
self:markInputs(mark)
local out = self:output()
local overridden = out.determinedBy
if overridden ~= nil then overridden:markUnsatisfied() end
out.determinedBy = self
if not planner:addPropagate(self, mark) then alert("Cycle encountered") end
out.mark = mark
return overridden
end
function Constraint:destroyConstraint ()
if self:isSatisfied()
then planner:incrementalRemove(self)
else self:removeFromGraph()
end
end
--
-- Normal constraints are not input constraints. An input constraint
-- is one that depends on external state, such as the mouse, the
-- keyboard, a clock, or some arbitrary piece of imperative code.
--
function Constraint:isInput ()
return false
end
--
-- U n a r y C o n s t r a i n t
--
--
-- Abstract superclass for constraints having a single possible output
-- variable.
--
local UnaryConstraint = class(Constraint)
function UnaryConstraint:constructor (v, strength)
UnaryConstraint.super.constructor(self, strength)
self.myOutput = v
self.satisfied = false
self:addConstraint()
end
--
-- Adds this constraint to the constraint graph
--
function UnaryConstraint:addToGraph ()
self.myOutput:addConstraint(self)
self.satisfied = false
end
--
-- Decides if this constraint can be satisfied and records that
-- decision.
--
function UnaryConstraint:chooseMethod (mark)
self.satisfied = (self.myOutput.mark ~= mark)
and Strength.stronger(self.strength, self.myOutput.walkStrength);
end
--
-- Returns true if this constraint is satisfied in the current solution.
--
function UnaryConstraint:isSatisfied ()
return self.satisfied;
end
function UnaryConstraint:markInputs (mark)
-- has no inputs
end
--
-- Returns the current output variable.
--
function UnaryConstraint:output ()
return self.myOutput
end
--
-- Calculate the walkabout strength, the stay flag, and, if it is
-- 'stay', the value for the current output of this constraint. Assume
-- this constraint is satisfied.
--
function UnaryConstraint:recalculate ()
self.myOutput.walkStrength = self.strength
self.myOutput.stay = not self:isInput()
if self.myOutput.stay then
self:execute() -- Stay optimization
end
end
--
-- Records that this constraint is unsatisfied
--
function UnaryConstraint:markUnsatisfied ()
self.satisfied = false
end
function UnaryConstraint:inputsKnown ()
return true
end
function UnaryConstraint:removeFromGraph ()
if self.myOutput ~= nil then
self.myOutput:removeConstraint(self)
end
self.satisfied = false
end
--
-- S t a y C o n s t r a i n t
--
--
-- Variables that should, with some level of preference, stay the same.
-- Planners may exploit the fact that instances, if satisfied, will not
-- change their output during plan execution. This is called "stay
-- optimization".
--
local StayConstraint = class(UnaryConstraint)
function StayConstraint:constructor(v, str)
StayConstraint.super.constructor(self, v, str)
end
function StayConstraint:execute ()
-- Stay constraints do nothing
end
--
-- E d i t C o n s t r a i n t
--
--
-- A unary input constraint used to mark a variable that the client
-- wishes to change.
--
local EditConstraint = class (UnaryConstraint)
function EditConstraint:constructor(v, str)
EditConstraint.super.constructor(self, v, str)
end
--
-- Edits indicate that a variable is to be changed by imperative code.
--
function EditConstraint:isInput ()
return true
end
function EditConstraint:execute ()
-- Edit constraints do nothing
end
--
-- B i n a r y C o n s t r a i n t
--
local Direction = {}
Direction.NONE = 0
Direction.FORWARD = 1
Direction.BACKWARD = -1
--
-- Abstract superclass for constraints having two possible output
-- variables.
--
local BinaryConstraint = class(Constraint)
function BinaryConstraint:constructor(var1, var2, strength)
BinaryConstraint.super.constructor(self, strength);
self.v1 = var1
self.v2 = var2
self.direction = Direction.NONE
self:addConstraint()
end
--
-- Decides if this constraint can be satisfied and which way it
-- should flow based on the relative strength of the variables related,
-- and record that decision.
--
function BinaryConstraint:chooseMethod (mark)
if self.v1.mark == mark then
self.direction = (self.v2.mark ~= mark and Strength.stronger(self.strength, self.v2.walkStrength)) and Direction.FORWARD or Direction.NONE
end
if self.v2.mark == mark then
self.direction = (self.v1.mark ~= mark and Strength.stronger(self.strength, self.v1.walkStrength)) and Direction.BACKWARD or Direction.NONE
end
if Strength.weaker(self.v1.walkStrength, self.v2.walkStrength) then
self.direction = Strength.stronger(self.strength, self.v1.walkStrength) and Direction.BACKWARD or Direction.NONE
else
self.direction = Strength.stronger(self.strength, self.v2.walkStrength) and Direction.FORWARD or Direction.BACKWARD
end
end
--
-- Add this constraint to the constraint graph
--
function BinaryConstraint:addToGraph ()
self.v1:addConstraint(self)
self.v2:addConstraint(self)
self.direction = Direction.NONE
end
--
-- Answer true if this constraint is satisfied in the current solution.
--
function BinaryConstraint:isSatisfied ()
return self.direction ~= Direction.NONE
end
--
-- Mark the input variable with the given mark.
--
function BinaryConstraint:markInputs (mark)
self:input().mark = mark
end
--
-- Returns the current input variable
--
function BinaryConstraint:input ()
return (self.direction == Direction.FORWARD) and self.v1 or self.v2
end
--
-- Returns the current output variable
--
function BinaryConstraint:output ()
return (self.direction == Direction.FORWARD) and self.v2 or self.v1
end
--
-- Calculate the walkabout strength, the stay flag, and, if it is
-- 'stay', the value for the current output of this
-- constraint. Assume this constraint is satisfied.
--
function BinaryConstraint:recalculate ()
local ihn = self:input()
local out = self:output()
out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength);
out.stay = ihn.stay
if out.stay then self:execute() end
end
--
-- Record the fact that self constraint is unsatisfied.
--
function BinaryConstraint:markUnsatisfied ()
self.direction = Direction.NONE
end
function BinaryConstraint:inputsKnown (mark)
local i = self:input()
return i.mark == mark or i.stay or i.determinedBy == nil
end
function BinaryConstraint:removeFromGraph ()
if (self.v1 ~= nil) then self.v1:removeConstraint(self) end
if (self.v2 ~= nil) then self.v2:removeConstraint(self) end
self.direction = Direction.NONE
end
--
-- S c a l e C o n s t r a i n t
--
--
-- Relates two variables by the linear scaling relationship: "v2 =
-- (v1 * scale) + offset". Either v1 or v2 may be changed to maintain
-- this relationship but the scale factor and offset are considered
-- read-only.
--
local ScaleConstraint = class (BinaryConstraint)
function ScaleConstraint:constructor(src, scale, offset, dest, strength)
self.direction = Direction.NONE
self.scale = scale
self.offset = offset
ScaleConstraint.super.constructor(self, src, dest, strength)
end
--
-- Adds this constraint to the constraint graph.
--
function ScaleConstraint:addToGraph ()
ScaleConstraint.super.addToGraph(self)
self.scale:addConstraint(self)
self.offset:addConstraint(self)
end
function ScaleConstraint:removeFromGraph ()
ScaleConstraint.super.removeFromGraph(self)
if (self.scale ~= nil) then self.scale:removeConstraint(self) end
if (self.offset ~= nil) then self.offset:removeConstraint(self) end
end
function ScaleConstraint:markInputs (mark)
ScaleConstraint.super.markInputs(self, mark);
self.offset.mark = mark
self.scale.mark = mark
end
--
-- Enforce this constraint. Assume that it is satisfied.
--
function ScaleConstraint:execute ()
if self.direction == Direction.FORWARD then
self.v2.value = self.v1.value * self.scale.value + self.offset.value
else
self.v1.value = (self.v2.value - self.offset.value) / self.scale.value
end
end
--
-- Calculate the walkabout strength, the stay flag, and, if it is
-- 'stay', the value for the current output of this constraint. Assume
-- this constraint is satisfied.
--
function ScaleConstraint:recalculate ()
local ihn = self:input()
local out = self:output()
out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength)
out.stay = ihn.stay and self.scale.stay and self.offset.stay
if out.stay then self:execute() end
end
--
-- E q u a l i t y C o n s t r a i n t
--
--
-- Constrains two variables to have the same value.
--
local EqualityConstraint = class (BinaryConstraint)
function EqualityConstraint:constructor(var1, var2, strength)
EqualityConstraint.super.constructor(self, var1, var2, strength)
end
--
-- Enforce this constraint. Assume that it is satisfied.
--
function EqualityConstraint:execute ()
self:output().value = self:input().value
end
--
-- V a r i a b l e
--
--
-- A constrained variable. In addition to its value, it maintain the
-- structure of the constraint graph, the current dataflow graph, and
-- various parameters of interest to the DeltaBlue incremental
-- constraint solver.
--
local Variable = class ()
function Variable:constructor(name, initialValue)
self.value = initialValue or 0
self.constraints = OrderedCollection.new()
self.determinedBy = nil
self.mark = 0
self.walkStrength = Strength.WEAKEST
self.stay = true
self.name = name
end
--
-- Add the given constraint to the set of all constraints that refer
-- this variable.
--
function Variable:addConstraint (c)
self.constraints:add(c)
end
--
-- Removes all traces of c from this variable.
--
function Variable:removeConstraint (c)
self.constraints:remove(c)
if self.determinedBy == c then
self.determinedBy = nil
end
end
--
-- P l a n n e r
--
--
-- The DeltaBlue planner
--
local Planner = class()
function Planner:constructor()
self.currentMark = 0
end
--
-- Attempt to satisfy the given constraint and, if successful,
-- incrementally update the dataflow graph. Details: If satisfying
-- the constraint is successful, it may override a weaker constraint
-- on its output. The algorithm attempts to resatisfy that
-- constraint using some other method. This process is repeated
-- until either a) it reaches a variable that was not previously
-- determined by any constraint or b) it reaches a constraint that
-- is too weak to be satisfied using any of its methods. The
-- variables of constraints that have been processed are marked with
-- a unique mark value so that we know where we've been. This allows
-- the algorithm to avoid getting into an infinite loop even if the
-- constraint graph has an inadvertent cycle.
--
function Planner:incrementalAdd (c)
local mark = self:newMark()
local overridden = c:satisfy(mark)
while overridden ~= nil do
overridden = overridden:satisfy(mark)
end
end
--
-- Entry point for retracting a constraint. Remove the given
-- constraint and incrementally update the dataflow graph.
-- Details: Retracting the given constraint may allow some currently
-- unsatisfiable downstream constraint to be satisfied. We therefore collect
-- a list of unsatisfied downstream constraints and attempt to
-- satisfy each one in turn. This list is traversed by constraint
-- strength, strongest first, as a heuristic for avoiding
-- unnecessarily adding and then overriding weak constraints.
-- Assume: c is satisfied.
--
function Planner:incrementalRemove (c)
local out = c:output()
c:markUnsatisfied()
c:removeFromGraph()
local unsatisfied = self:removePropagateFrom(out)
local strength = Strength.REQUIRED
repeat
for i = 1, unsatisfied:size() do
local u = unsatisfied:at(i)
if u.strength == strength then
self:incrementalAdd(u)
end
end
strength = strength:nextWeaker()
until strength == Strength.WEAKEST
end
--
-- Select a previously unused mark value.
--
function Planner:newMark ()
self.currentMark = self.currentMark + 1
return self.currentMark
end
--
-- Extract a plan for resatisfaction starting from the given source
-- constraints, usually a set of input constraints. This method
-- assumes that stay optimization is desired; the plan will contain
-- only constraints whose output variables are not stay. Constraints
-- that do no computation, such as stay and edit constraints, are
-- not included in the plan.
-- Details: The outputs of a constraint are marked when it is added
-- to the plan under construction. A constraint may be appended to
-- the plan when all its input variables are known. A variable is
-- known if either a) the variable is marked (indicating that has
-- been computed by a constraint appearing earlier in the plan), b)
-- the variable is 'stay' (i.e. it is a constant at plan execution
-- time), or c) the variable is not determined by any
-- constraint. The last provision is for past states of history
-- variables, which are not stay but which are also not computed by
-- any constraint.
-- Assume: sources are all satisfied.
--
local Plan -- FORWARD DECLARATION
function Planner:makePlan (sources)
local mark = self:newMark()
local plan = Plan.new()
local todo = sources
while todo:size() > 0 do
local c = todo:removeFirst()
if c:output().mark ~= mark and c:inputsKnown(mark) then
plan:addConstraint(c)
c:output().mark = mark
self:addConstraintsConsumingTo(c:output(), todo)
end
end
return plan
end
--
-- Extract a plan for resatisfying starting from the output of the
-- given constraints, usually a set of input constraints.
--
function Planner:extractPlanFromConstraints (constraints)
local sources = OrderedCollection.new()
for i = 1, constraints:size() do
local c = constraints:at(i)
if c:isInput() and c:isSatisfied() then
-- not in plan already and eligible for inclusion
sources:add(c)
end
end
return self:makePlan(sources)
end
--
-- Recompute the walkabout strengths and stay flags of all variables
-- downstream of the given constraint and recompute the actual
-- values of all variables whose stay flag is true. If a cycle is
-- detected, remove the given constraint and answer
-- false. Otherwise, answer true.
-- Details: Cycles are detected when a marked variable is
-- encountered downstream of the given constraint. The sender is
-- assumed to have marked the inputs of the given constraint with
-- the given mark. Thus, encountering a marked node downstream of
-- the output constraint means that there is a path from the
-- constraint's output to one of its inputs.
--
function Planner:addPropagate (c, mark)
local todo = OrderedCollection.new()
todo:add(c)
while todo:size() > 0 do
local d = todo:removeFirst()
if d:output().mark == mark then
self:incrementalRemove(c)
return false
end
d:recalculate()
self:addConstraintsConsumingTo(d:output(), todo)
end
return true
end
--
-- Update the walkabout strengths and stay flags of all variables
-- downstream of the given constraint. Answer a collection of
-- unsatisfied constraints sorted in order of decreasing strength.
--
function Planner:removePropagateFrom (out)
out.determinedBy = nil
out.walkStrength = Strength.WEAKEST
out.stay = true
local unsatisfied = OrderedCollection.new()
local todo = OrderedCollection.new()
todo:add(out)
while todo:size() > 0 do
local v = todo:removeFirst()
for i = 1, v.constraints:size() do
local c = v.constraints:at(i)
if not c:isSatisfied() then unsatisfied:add(c) end
end
local determining = v.determinedBy
for i = 1, v.constraints:size() do
local next = v.constraints:at(i);
if next ~= determining and next:isSatisfied() then
next:recalculate()
todo:add(next:output())
end
end
end
return unsatisfied
end
function Planner:addConstraintsConsumingTo (v, coll)
local determining = v.determinedBy
local cc = v.constraints
for i = 1, cc:size() do
local c = cc:at(i)
if c ~= determining and c:isSatisfied() then
coll:add(c)
end
end
end
--
-- P l a n
--
--
-- A Plan is an ordered list of constraints to be executed in sequence
-- to resatisfy all currently satisfiable constraints in the face of
-- one or more changing inputs.
--
Plan = class()
function Plan:constructor()
self.v = OrderedCollection.new()
end
function Plan:addConstraint (c)
self.v:add(c)
end
function Plan:size ()
return self.v:size()
end
function Plan:constraintAt (index)
return self.v:at(index)
end
function Plan:execute ()
for i = 1, self:size() do
local c = self:constraintAt(i)
c:execute()
end
end
--
-- M a i n
--
--
-- This is the standard DeltaBlue benchmark. A long chain of equality
-- constraints is constructed with a stay constraint on one end. An
-- edit constraint is then added to the opposite end and the time is
-- measured for adding and removing this constraint, and extracting
-- and executing a constraint satisfaction plan. There are two cases.
-- In case 1, the added constraint is stronger than the stay
-- constraint and values must propagate down the entire length of the
-- chain. In case 2, the added constraint is weaker than the stay
-- constraint so it cannot be accommodated. The cost in this case is,
-- of course, very low. Typical situations lie somewhere between these
-- two extremes.
--
local function chainTest(n)
planner = Planner.new()
local prev = nil
local first = nil
local last = nil
-- Build chain of n equality constraints
for i = 0, n do
local name = "v" .. i;
local v = Variable.new(name)
if prev ~= nil then EqualityConstraint.new(prev, v, Strength.REQUIRED) end
if i == 0 then first = v end
if i == n then last = v end
prev = v
end
StayConstraint.new(last, Strength.STRONG_DEFAULT)
local edit = EditConstraint.new(first, Strength.PREFERRED)
local edits = OrderedCollection.new()
edits:add(edit)
local plan = planner:extractPlanFromConstraints(edits)
for i = 0, 99 do
first.value = i
plan:execute()
if last.value ~= i then
alert("Chain test failed.")
end
end
end
local function change(v, newValue)
local edit = EditConstraint.new(v, Strength.PREFERRED)
local edits = OrderedCollection.new()
edits:add(edit)
local plan = planner:extractPlanFromConstraints(edits)
for i = 1, 10 do
v.value = newValue
plan:execute()
end
edit:destroyConstraint()
end
--
-- This test constructs a two sets of variables related to each
-- other by a simple linear transformation (scale and offset). The
-- time is measured to change a variable on either side of the
-- mapping and to change the scale and offset factors.
--
local function projectionTest(n)
planner = Planner.new();
local scale = Variable.new("scale", 10);
local offset = Variable.new("offset", 1000);
local src = nil
local dst = nil;
local dests = OrderedCollection.new();
for i = 0, n - 1 do
src = Variable.new("src" .. i, i);
dst = Variable.new("dst" .. i, i);
dests:add(dst);
StayConstraint.new(src, Strength.NORMAL);
ScaleConstraint.new(src, scale, offset, dst, Strength.REQUIRED);
end
change(src, 17)
if dst.value ~= 1170 then alert("Projection 1 failed") end
change(dst, 1050)
if src.value ~= 5 then alert("Projection 2 failed") end
change(scale, 5)
for i = 0, n - 2 do
if dests:at(i + 1).value ~= i * 5 + 1000 then
alert("Projection 3 failed")
end
end
change(offset, 2000)
for i = 0, n - 2 do
if dests:at(i + 1).value ~= i * 5 + 2000 then
alert("Projection 4 failed")
end
end
end
function test()
local t0 = os.clock()
chainTest(1000);
projectionTest(1000);
local t1 = os.clock()
return t1-t0
end
bench.runCode(test, "deltablue")

File diff suppressed because it is too large Load diff

View file

@ -32,6 +32,55 @@ std::optional<ModuleName> TestFileResolver::fromAstFragment(AstExpr* expr) const
return std::nullopt;
}
std::optional<ModuleInfo> TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr)
{
if (AstExprGlobal* g = expr->as<AstExprGlobal>())
{
if (g->name == "game")
return ModuleInfo{"game"};
if (g->name == "workspace")
return ModuleInfo{"workspace"};
if (g->name == "script")
return context ? std::optional<ModuleInfo>(*context) : std::nullopt;
}
else if (AstExprIndexName* i = expr->as<AstExprIndexName>(); i && context)
{
if (i->index == "Parent")
{
std::string_view view = context->name;
size_t lastSeparatorIndex = view.find_last_of('/');
if (lastSeparatorIndex == std::string_view::npos)
return std::nullopt;
return ModuleInfo{ModuleName(view.substr(0, lastSeparatorIndex)), context->optional};
}
else
{
return ModuleInfo{context->name + '/' + i->index.value, context->optional};
}
}
else if (AstExprIndexExpr* i = expr->as<AstExprIndexExpr>(); i && context)
{
if (AstExprConstantString* index = i->index->as<AstExprConstantString>())
{
return ModuleInfo{context->name + '/' + std::string(index->value.data, index->value.size), context->optional};
}
}
else if (AstExprCall* call = expr->as<AstExprCall>(); call && call->self && call->args.size >= 1 && context)
{
if (AstExprConstantString* index = call->args.data[0]->as<AstExprConstantString>())
{
AstName func = call->func->as<AstExprIndexName>()->index;
if (func == "GetService" && context->name == "game")
return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)};
}
}
return std::nullopt;
}
ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const
{
return lhs + "/" + ModuleName(rhs);

View file

@ -65,6 +65,8 @@ struct TestFileResolver
}
std::optional<ModuleName> fromAstFragment(AstExpr* expr) const override;
std::optional<ModuleInfo> resolveModule(const ModuleInfo* context, AstExpr* expr) override;
ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override;
std::optional<ModuleName> getParentModuleName(const ModuleName& name) const override;

View file

@ -58,6 +58,35 @@ struct NaiveFileResolver : NullFileResolver
return std::nullopt;
}
std::optional<ModuleInfo> resolveModule(const ModuleInfo* context, AstExpr* expr) override
{
if (AstExprGlobal* g = expr->as<AstExprGlobal>())
{
if (g->name == "Modules")
return ModuleInfo{"Modules"};
if (g->name == "game")
return ModuleInfo{"game"};
}
else if (AstExprIndexName* i = expr->as<AstExprIndexName>())
{
if (context)
return ModuleInfo{context->name + '/' + i->index.value, context->optional};
}
else if (AstExprCall* call = expr->as<AstExprCall>(); call && call->self && call->args.size >= 1 && context)
{
if (AstExprConstantString* index = call->args.data[0]->as<AstExprConstantString>())
{
AstName func = call->func->as<AstExprIndexName>()->index;
if (func == "GetService" && context->name == "game")
return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)};
}
}
return std::nullopt;
}
ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override
{
return lhs + "/" + ModuleName(rhs);
@ -528,7 +557,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "ignore_require_to_nonexistent_file")
{
fileResolver.source["Modules/A"] = R"(
local Modules = script
local B = require(Modules.B :: any)
local B = require(Modules.B) :: any
)";
CheckResult result = frontend.check("Modules/A");

View file

@ -1400,6 +1400,8 @@ end
TEST_CASE_FIXTURE(Fixture, "TableOperations")
{
ScopedFastFlag sff("LuauLinterTableMoveZero", true);
LintResult result = lintTyped(R"(
local t = {}
local tt = {}
@ -1417,9 +1419,12 @@ table.remove(t, 0)
table.remove(t, #t-1)
table.insert(t, string.find("hello", "h"))
table.move(t, 0, #t, 1, tt)
table.move(t, 1, #t, 0, tt)
)");
REQUIRE_EQ(result.warnings.size(), 6);
REQUIRE_EQ(result.warnings.size(), 8);
CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the "
"second argument or wrap it in parentheses to silence");
CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency");
@ -1429,6 +1434,8 @@ table.insert(t, string.find("hello", "h"))
"second argument or wrap it in parentheses to silence");
CHECK_EQ(result.warnings[5].text,
"table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument");
CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
}
TEST_CASE_FIXTURE(Fixture, "DuplicateConditions")

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Fixture.h"

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"

View file

@ -2519,4 +2519,19 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression")
}
}
TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
AstStat* stat = parse(R"(
type Packed<T...> = () -> T...
type A<X...> = Packed<X...>
type B<X...> = Packed<...number>
type C<X...> = Packed<(number, X...)>
)");
REQUIRE(stat != nullptr);
}
TEST_SUITE_END();

View file

@ -57,6 +57,7 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local")
{
AstStatBlock* block = parse(R"(
local m = workspace.Foo.Bar.Baz
require(m)
)");
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
@ -70,22 +71,22 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local")
AstExprIndexName* value = loc->values.data[0]->as<AstExprIndexName>();
REQUIRE(value);
REQUIRE(result.exprs.contains(value));
CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value]);
CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value].name);
value = value->expr->as<AstExprIndexName>();
REQUIRE(value);
REQUIRE(result.exprs.contains(value));
CHECK_EQ("workspace/Foo/Bar", result.exprs[value]);
CHECK_EQ("workspace/Foo/Bar", result.exprs[value].name);
value = value->expr->as<AstExprIndexName>();
REQUIRE(value);
REQUIRE(result.exprs.contains(value));
CHECK_EQ("workspace/Foo", result.exprs[value]);
CHECK_EQ("workspace/Foo", result.exprs[value].name);
AstExprGlobal* workspace = value->expr->as<AstExprGlobal>();
REQUIRE(workspace);
REQUIRE(result.exprs.contains(workspace));
CHECK_EQ("workspace", result.exprs[workspace]);
CHECK_EQ("workspace", result.exprs[workspace].name);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local")
@ -93,9 +94,10 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local")
AstStatBlock* block = parse(R"(
local m = workspace.Foo.Bar.Baz
local n = m.Quux
require(n)
)");
REQUIRE_EQ(2, block->body.size);
REQUIRE_EQ(3, block->body.size);
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
@ -104,13 +106,13 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local")
REQUIRE_EQ(1, local->vars.size);
REQUIRE(result.exprs.contains(local->values.data[0]));
CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]]);
CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]].name);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments")
{
AstStatBlock* block = parse(R"(
local M = require(workspace.Game.Thing, workspace.Something.Else)
local M = require(workspace.Game.Thing)
)");
REQUIRE_EQ(1, block->body.size);
@ -124,52 +126,9 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments")
AstExprCall* call = local->values.data[0]->as<AstExprCall>();
REQUIRE(call != nullptr);
REQUIRE_EQ(2, call->args.size);
CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]]);
CHECK_EQ("workspace/Something/Else", result.exprs[call->args.data[1]]);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "follow_GetService_calls")
{
AstStatBlock* block = parse(R"(
local R = game:GetService('ReplicatedStorage').Roact
local Roact = require(R)
)");
REQUIRE_EQ(2, block->body.size);
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
AstStatLocal* local = block->body.data[0]->as<AstStatLocal>();
REQUIRE(local != nullptr);
CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[local->values.data[0]]);
AstStatLocal* local2 = block->body.data[1]->as<AstStatLocal>();
REQUIRE(local2 != nullptr);
REQUIRE_EQ(1, local2->values.size);
AstExprCall* call = local2->values.data[0]->as<AstExprCall>();
REQUIRE(call != nullptr);
REQUIRE_EQ(1, call->args.size);
CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[call->args.data[0]]);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "follow_WaitForChild_calls")
{
ScopedFastFlag luauTraceRequireLookupChild("LuauTraceRequireLookupChild", true);
AstStatBlock* block = parse(R"(
local A = require(workspace:WaitForChild('ReplicatedStorage').Content)
local B = require(workspace:FindFirstChild('ReplicatedFirst').Data)
)");
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
REQUIRE_EQ(2, result.requires.size());
CHECK_EQ("workspace/ReplicatedStorage/Content", result.requires[0].first);
CHECK_EQ("workspace/ReplicatedFirst/Data", result.requires[1].first);
CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]].name);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof")
@ -200,22 +159,23 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof")
REQUIRE(call != nullptr);
REQUIRE_EQ(1, call->args.size);
CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]]);
CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]].name);
}
TEST_CASE_FIXTURE(RequireTracerFixture, "follow_string_indexexpr")
{
AstStatBlock* block = parse(R"(
local R = game["Test"]
require(R)
)");
REQUIRE_EQ(1, block->body.size);
REQUIRE_EQ(2, block->body.size);
RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName");
AstStatLocal* local = block->body.data[0]->as<AstStatLocal>();
REQUIRE(local != nullptr);
CHECK_EQ("game/Test", result.exprs[local->values.data[0]]);
CHECK_EQ("game/Test", result.exprs[local->values.data[0]].name);
}
TEST_SUITE_END();

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Fixture.h"
@ -416,8 +417,6 @@ function foo(a, b) return a(b) end
TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_TypePack")
{
ScopedFastFlag sff{"LuauToStringFollowsBoundTo", true};
TypeVar tv1{TableTypeVar{}};
TableTypeVar* ttv = getMutable<TableTypeVar>(&tv1);
ttv->state = TableState::Sealed;

View file

@ -0,0 +1,557 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "doctest.h"
#include "Luau/BuiltinDefinitions.h"
using namespace Luau;
TEST_SUITE_BEGIN("TypeAliases");
TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type F = () -> F?
local function f()
return f
end
local g: F = f
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g")));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified")
{
CheckResult result = check(R"(
--!strict
type Node = { Parent: Node?; }
local node: Node;
node.Parent = 1
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Node?", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types")
{
CheckResult result = check(R"(
--!strict
type T<a> = { f: a, g: U<a> }
type U<a> = { h: a, i: T<a>? }
local x: T<number> = { f = 37, g = { h = 5, i = nil } }
x.g.i = x
local y: T<string> = { f = "hi", g = { h = "lo", i = nil } }
y.g.i = y
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors")
{
CheckResult result = check(R"(
--!strict
type T<a> = { f: a, g: U<a> }
type U<b> = { h: b, i: T<b>? }
local x: T<number> = { f = 37, g = { h = 5, i = nil } }
x.g.i = x
local y: T<string> = { f = "hi", g = { h = 5, i = nil } }
y.g.i = y
)");
LUAU_REQUIRE_ERRORS(result);
// We had a UAF in this example caused by not cloning type function arguments
ModulePtr module = frontend.moduleResolver.getModule("MainModule");
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes);
freeze(module->interfaceTypes);
module->internalTypes.clear();
module->astTypes.clear();
// Make sure the error strings don't include "VALUELESS"
for (auto error : module->errors)
CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error));
}
TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors")
{
CheckResult result = check(R"(
type Pair<T, U> = {first: T, second: U}
local a: Pair<string, number>
local b: Pair<string, string>
a = b
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Pair<string, number>", toString(tm->wantedType));
CHECK_EQ("Pair<string, string>", toString(tm->givenType));
}
TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition")
{
CheckResult result = check(R"(
type A = number
type A = string -- Redefinition of type 'A', previously defined at line 1
local foo: string = 1 -- No "Type 'number' could not be converted into 'string'"
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
}
TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type")
{
CheckResult result = check(R"(
type Table<T> = { a: T }
type Wrapped = Table<Wrapped>
local l: Wrapped = 2
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Wrapped", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2")
{
CheckResult result = check(R"(
type Table<T> = { a: T }
type Wrapped = (Table<Wrapped>) -> string
local l: Wrapped = 2
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
// Check that recursive intersection type doesn't generate an OOM
TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom")
{
CheckResult result = check(R"(
function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any
end
type t0<t0> = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_))
_(_)
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise")
{
CheckResult result = check(R"(
local foo: Id<number> = 1
type Id<T> = T
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic")
{
const std::string code = R"(
type A<T> = {v:T, b:B<T>}
type B<T> = {v:T, a:A<T>}
local aa:A<number>
local bb = aa
)";
const std::string expected = R"(
type A<T> = {v:T, b:B<T>}
type B<T> = {v:T, a:A<T>}
local aa:A<number>
local bb:A<number>=aa
)";
CHECK_EQ(expected, decorateWithTypes(code));
CheckResult result = check(code);
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "corecursive_function_types")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type A = () -> (number, B)
type B = () -> (string, A)
local a: A
local b: B
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a")));
CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b")));
}
TEST_CASE_FIXTURE(Fixture, "generic_param_remap")
{
const std::string code = R"(
-- An example of a forwarded use of a type that has different type arguments than parameters
type A<T,U> = {t:T, u:U, next:A<U,T>?}
local aa:A<number,string> = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } }
local bb = aa
)";
const std::string expected = R"(
type A<T,U> = {t:T, u:U, next:A<U,T>?}
local aa:A<number,string> = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } }
local bb:A<number,string>=aa
)";
CHECK_EQ(expected, decorateWithTypes(code));
CheckResult result = check(code);
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates")
{
CheckResult result = check(R"(
export type Foo = number
type Foo = number
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
auto dtd = get<DuplicateTypeDefinition>(result.errors[0]);
REQUIRE(dtd);
CHECK_EQ(dtd->name, "Foo");
}
TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias")
{
ScopedFastFlag sffs3{"LuauGenericFunctions", true};
ScopedFastFlag sffs4{"LuauParseGenericFunctions", true};
CheckResult result = check(R"(
type Node<T> = { value: T, child: Node<T>? }
local function visitor<T>(node: Node<T>?)
local a: Node<T>
if node then
a = node.child -- Observe the output of the error message.
end
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
auto e = get<TypeMismatch>(result.errors[0]);
CHECK_EQ("Node<T>?", toString(e->givenType));
CHECK_EQ("Node<T>", toString(e->wantedType));
}
TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign")
{
fileResolver.source["workspace/A"] = R"(
export type myvec2 = {x: number, y: number}
return {}
)";
fileResolver.source["workspace/B"] = R"(
export type myvec3 = {x: number, y: number, z: number}
return {}
)";
fileResolver.source["workspace/C"] = R"(
local Foo, Bar = require(workspace.A), require(workspace.B)
local a: Foo.myvec2
local b: Bar.myvec3
)";
CheckResult result = frontend.check("workspace/C");
LUAU_REQUIRE_NO_ERRORS(result);
ModulePtr m = frontend.moduleResolver.modules["workspace/C"];
REQUIRE(m != nullptr);
std::optional<TypeId> aTypeId = lookupName(m->getModuleScope(), "a");
REQUIRE(aTypeId);
const Luau::TableTypeVar* aType = get<TableTypeVar>(follow(*aTypeId));
REQUIRE(aType);
REQUIRE(aType->props.size() == 2);
std::optional<TypeId> bTypeId = lookupName(m->getModuleScope(), "b");
REQUIRE(bTypeId);
const Luau::TableTypeVar* bType = get<TableTypeVar>(follow(*bTypeId));
REQUIRE(bType);
REQUIRE(bType->props.size() == 3);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation")
{
CheckResult result = check("type t10<x> = typeof(table)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId ty = getGlobalBinding(frontend.typeChecker, "table");
CHECK_EQ(toString(ty), "table");
const TableTypeVar* ttv = get<TableTypeVar>(ty);
REQUIRE(ttv);
CHECK(ttv->instantiatedTypeParams.empty());
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation")
{
CheckResult result = check(R"(
type Cool = { a: number, b: string }
local c: Cool = { a = 1, b = "s" }
type NotCool<x> = Cool
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "Cool");
const TableTypeVar* ttv = get<TableTypeVar>(*ty);
REQUIRE(ttv);
CHECK(ttv->instantiatedTypeParams.empty());
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename")
{
CheckResult result = check(R"(
type Cool = { a: number, b: string }
type NotCool = Cool
local c: Cool = { a = 1, b = "s" }
local d: NotCool = { a = 1, b = "s" }
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "Cool");
ty = requireType("d");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "NotCool");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation")
{
CheckResult result = check(R"(
local c = { a = 1, b = "s" }
type Cool = typeof(c)
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
const TableTypeVar* ttv = get<TableTypeVar>(*ty);
REQUIRE(ttv);
CHECK_EQ(ttv->name, "Cool");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type")
{
fileResolver.source["game/A"] = R"(
export type X = { a: number, b: X? }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
type X = Import.X
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
std::optional<TypeId> ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
std::optional<TypeId> ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(follow(*ty1), follow(*ty2));
}
TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type")
{
fileResolver.source["game/A"] = R"(
export type X<T, U> = { a: T, b: U, C: X<T, U>? }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
type X<T, U> = Import.X<T, U>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
std::optional<TypeId> ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
std::optional<TypeId> ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true}));
bResult = check(R"(
local Import = require(game.A)
type X<T, U> = Import.X<U, T>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}");
CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?");
}
TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak")
{
CheckResult result = check(R"(
function get()
return function(obj) return true end
end
export type f = typeof(get())
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak")
{
CheckResult result = check(R"(
function get()
return {a = 1, b = function(obj) return true end}
end
export type f = typeof(get())
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok")
{
CheckResult result = check(R"(
type Tree<T> = { data: T, children: Forest<T> }
type Forest<T> = {Tree<T>}
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
-- OK because forwarded types are used with their parameters.
type Tree<T> = { data: T, children: Forest<T> }
type Forest<T> = {Tree<{T}>}
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
-- Not OK because forwarded types are used with different types than their parameters.
type Forest<T> = {Tree<{T}>}
type Tree<T> = { data: T, children: Forest<T> }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok")
{
CheckResult result = check(R"(
type Tree1<T,U> = { data: T, children: {Tree2<U,T>} }
type Tree2<U,T> = { data: U, children: {Tree1<T,U>} }
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
type Tree1<T,U> = { data: T, children: {Tree2<U,T>} }
type Tree2<T,U> = { data: U, children: {Tree1<T,U>} }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases")
{
CheckResult result = check(R"(
function f(x) return x[1] end
-- x has type X? for a free type variable X
local x = f ({})
type ContainsFree<a> = { this: a, that: typeof(x) }
type ContainsContainsFree = { that: ContainsFree<number> }
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name")
{
ScopedFastFlag sff1{"LuauSubstitutionDontReplaceIgnoredTypes", true};
CheckResult result = check(R"(
type Array<T> = { [number]: T }
type Tuple<T, V> = Array<T | V>
local p: Tuple<number, string>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("{number | string}", toString(requireType("p"), {true}));
}
TEST_SUITE_END();

View file

@ -30,6 +30,8 @@ TEST_SUITE_BEGIN("ProvisionalTests");
*/
TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
{
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
const std::string code = R"(
function f(a)
if type(a) == "boolean" then
@ -41,11 +43,11 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
)";
const std::string expected = R"(
function f(a:{fn:()->(free)}): ()
function f(a:{fn:()->(free,free...)}): ()
if type(a) == 'boolean'then
local a1:boolean=a
elseif a.fn()then
local a2:{fn:()->(free)}=a
local a2:{fn:()->(free,free...)}=a
end
end
)";
@ -231,16 +233,7 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible")
local r2 = b == a
)");
if (FFlag::LuauEqConstraint)
{
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK_EQ(toString(result.errors[0]), "Type '{| x: string |}?' could not be converted into 'number | string'");
CHECK_EQ(toString(result.errors[1]), "Type 'number | string' could not be converted into '{| x: string |}?'");
}
}
// Belongs in TypeInfer.refinements.test.cpp.
@ -542,6 +535,25 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doct
}
}
TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables")
{
CheckResult result = check(R"(
--!strict
local function setNumber(t: { p: number? }, x:number) t.p = x end
local function getString(t: { p: string? }):string return t.p or "" end
-- This shouldn't type-check!
local function oh(x:number): string
local t: {} = {}
setNumber(t, x)
return getString(t)
end
local s: string = oh(37)
)");
// Really this should return an error, but it doesn't
LUAU_REQUIRE_NO_ERRORS(result);
}
// Should be in TypeInfer.tables.test.cpp
// It's unsound to instantiate tables containing generic methods,
// since mutating properties means table properties should be invariant.

View file

@ -1,4 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Fixture.h"
@ -6,7 +7,6 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauWeakEqConstraint)
LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2)
LUAU_FASTFLAG(LuauOrPredicate)
using namespace Luau;
@ -199,16 +199,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope")
end
)");
if (FFlag::LuauImprovedTypeGuardPredicate2)
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0]));
}
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'string' could not be converted into 'boolean'", toString(result.errors[0]));
}
}
TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard")
@ -526,8 +518,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable")
TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x)
if type(x) == "vector" then
@ -544,8 +534,6 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector")
TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local t = {"hello"}
local v = t[2]
@ -573,8 +561,6 @@ TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true"
TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x: string | number | boolean)
if type(x) ~= "string" then
@ -593,8 +579,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string")
TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x: string | {x: number} | {y: boolean})
if type(x) == "table" then
@ -613,8 +597,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table")
TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function weird(x: string | ((number) -> string))
if type(x) == "function" then
@ -698,8 +680,6 @@ struct RefinementClassFixture : Fixture
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(vec)
local X, Y, Z = vec.X, vec.Y, vec.Z
@ -726,8 +706,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x: Instance | Vector3)
if typeof(x) == "Vector3" then
@ -746,8 +724,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to
TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(x: string | number | Instance | Vector3)
if type(x) == "userdata" then
@ -766,10 +742,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata")
TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance")
{
ScopedFastFlag sffs[] = {
{"LuauImprovedTypeGuardPredicate2", true},
{"LuauTypeGuardPeelsAwaySubclasses", true},
};
ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true};
CheckResult result = check(R"(
local function f(x: Part | Folder | string)
@ -789,10 +762,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance")
TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union")
{
ScopedFastFlag sffs[] = {
{"LuauImprovedTypeGuardPredicate2", true},
{"LuauTypeGuardPeelsAwaySubclasses", true},
};
ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true};
CheckResult result = check(R"(
local function f(x: Part | Folder | Instance | string | Vector3 | any)
@ -812,10 +782,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union")
TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
CheckResult result = check(R"(
--!nonstrict
@ -839,7 +806,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
{"LuauTypeGuardPeelsAwaySubclasses", true},
};
@ -861,8 +827,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part")
TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
type XYCoord = {x: number} & {y: number}
local function f(t: XYCoord?)
@ -882,8 +846,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables")
TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
type SomeOverloadedFunction = ((number) -> string) & ((string) -> number)
local function f(g: SomeOverloadedFunction?)
@ -903,8 +865,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function")
TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true")
{
ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true};
CheckResult result = check(R"(
local function f(t: {x: number})
if type(t) ~= "table" then
@ -999,10 +959,7 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2")
TEST_CASE_FIXTURE(Fixture, "either_number_or_string")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
CheckResult result = check(R"(
local function f(x: any)
@ -1036,10 +993,7 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t")
TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
CheckResult result = check(R"(
local a: (number | string)?
@ -1057,10 +1011,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number")
TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
// This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result.
CheckResult result = check(R"(
@ -1081,10 +1032,7 @@ TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering")
TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string")
{
ScopedFastFlag sffs[] = {
{"LuauOrPredicate", true},
{"LuauImprovedTypeGuardPredicate2", true},
};
ScopedFastFlag sff{"LuauOrPredicate", true};
CheckResult result = check(R"(
local function f(a: string | number | boolean)

View file

@ -46,6 +46,21 @@ TEST_CASE_FIXTURE(Fixture, "augment_table")
CHECK(tType->props.find("foo") != tType->props.end());
}
TEST_CASE_FIXTURE(Fixture, "augment_nested_table")
{
CheckResult result = check("local t = { p = {} } t.p.foo = 'bar'");
LUAU_REQUIRE_NO_ERRORS(result);
TableTypeVar* tType = getMutable<TableTypeVar>(requireType("t"));
REQUIRE(tType != nullptr);
REQUIRE(tType->props.find("p") != tType->props.end());
const TableTypeVar* pType = get<TableTypeVar>(tType->props["p"].type);
REQUIRE(pType != nullptr);
CHECK(pType->props.find("foo") != pType->props.end());
}
TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table")
{
CheckResult result = check("local t = {prop=999} t.foo = 'bar'");
@ -260,6 +275,8 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification")
TEST_CASE_FIXTURE(Fixture, "open_table_unification_2")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local a = {}
a.x = 99
@ -272,10 +289,11 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_2")
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeError& err = result.errors[0];
UnknownProperty* error = get<UnknownProperty>(err);
MissingProperties* error = get<MissingProperties>(err);
REQUIRE(error != nullptr);
REQUIRE(error->properties.size() == 1);
CHECK_EQ(error->key, "y");
CHECK_EQ("y", error->properties[0]);
// TODO(rblanckaert): Revist when we can bind self at function creation time
// CHECK_EQ(err.location, Location(Position{5, 19}, Position{5, 25}));
@ -328,6 +346,8 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1")
TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
--!strict
function foo(o)
@ -340,14 +360,17 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2")
LUAU_REQUIRE_ERROR_COUNT(1, result);
UnknownProperty* error = get<UnknownProperty>(result.errors[0]);
MissingProperties* error = get<MissingProperties>(result.errors[0]);
REQUIRE(error != nullptr);
REQUIRE(error->properties.size() == 1);
CHECK_EQ("baz", error->key);
CHECK_EQ("baz", error->properties[0]);
}
TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local T = {}
T.bar = 'hello'
@ -359,8 +382,11 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3")
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeError& err = result.errors[0];
UnknownProperty* error = get<UnknownProperty>(err);
MissingProperties* error = get<MissingProperties>(err);
REQUIRE(error != nullptr);
REQUIRE(error->properties.size() == 1);
CHECK_EQ("baz", error->properties[0]);
// TODO(rblanckaert): Revist when we can bind self at function creation time
/*
@ -448,6 +474,73 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table")
dumpErrors(result);
}
TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
--!strict
local t = { u = {} }
t = { u = { p = 37 } }
t = { u = { q = "hi" } }
local x = t.u.p
local y = t.u.q
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("number?", toString(requireType("x")));
CHECK_EQ("string?", toString(requireType("y")));
}
TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_call")
{
CheckResult result = check(R"(
--!strict
function get(x) return x.opts["MYOPT"] end
function set(x,y) x.opts["MYOPT"] = y end
local t = { opts = {} }
set(t,37)
local x = get(t)
)");
// Currently this errors but it shouldn't, since set only needs write access
// TODO: file a JIRA for this
LUAU_REQUIRE_ERRORS(result);
// CHECK_EQ("number?", toString(requireType("x")));
}
TEST_CASE_FIXTURE(Fixture, "width_subtyping")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
--!strict
function f(x : { q : number })
x.q = 8
end
local t : { q : number, r : string } = { q = 8, r = "hi" }
f(t)
local x : string = t.r
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "width_subtyping_needs_covariance")
{
CheckResult result = check(R"(
--!strict
function f(x : { p : { q : number }})
x.p = { q = 8, r = 5 }
end
local t : { p : { q : number, r : string } } = { p = { q = 8, r = "hi" } }
f(t) -- Shouldn't typecheck
local x : string = t.p.r -- x is 5
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "infer_array")
{
CheckResult result = check(R"(
@ -676,16 +769,27 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "sealed_table_value_must_not_infer_an_indexer")
TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local t: { a: string, [number]: string } = { a = "foo" }
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
LUAU_REQUIRE_NO_ERRORS(result);
}
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm != nullptr);
TEST_CASE_FIXTURE(Fixture, "array_factory_function")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
function empty() return {} end
local array: {string} = empty()
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "sealed_table_indexers_must_unify")
@ -756,37 +860,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_
CHECK_MESSAGE(nullptr != get<TypeMismatch>(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]);
}
TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_string")
{
ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true);
CheckResult result = check(R"(
local t: { a: string }
function f(x: string) return t[x] end
local a = f("a")
local b = f("b")
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(*typeChecker.anyType, *requireType("a"));
CHECK_EQ(*typeChecker.anyType, *requireType("b"));
}
TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_number")
{
ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true);
CheckResult result = check(R"(
local t = { a = true }
function f(x: number) return t[x] end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_MESSAGE(nullptr != get<TypeMismatch>(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]);
}
TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer")
{
CheckResult result = check(R"(
@ -1392,6 +1465,8 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2")
TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local function foo(a: {[string]: number, a: string}) end
foo({ a = 1 })
@ -1402,8 +1477,21 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3")
ToStringOptions o{/* exhaustive= */ true};
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("string", toString(tm->wantedType, o));
CHECK_EQ("number", toString(tm->givenType, o));
CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o));
CHECK_EQ("{| a: number |}", toString(tm->givenType, o));
}
TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4")
{
CheckResult result = check(R"(
local function foo(a: {[string]: number, a: string}, i: string)
return a[i]
end
local hi: number = foo({ a = "hi" }, "a") -- shouldn't typecheck since at runtime hi is "hi"
)");
// This typechecks but shouldn't
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors")
@ -1446,22 +1534,32 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi
TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors")
{
CheckResult result = check(R"(
local vec3 = {x = 1, y = 2, z = 3}
local vec1 = {x = 1}
local vec3 = {{x = 1, y = 2, z = 3}}
local vec1 = {{x = 1}}
vec1 = vec3
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
MissingProperties* mp = get<MissingProperties>(result.errors[0]);
REQUIRE(mp);
CHECK_EQ(mp->context, MissingProperties::Extra);
REQUIRE_EQ(2, mp->properties.size());
CHECK_EQ(mp->properties[0], "y");
CHECK_EQ(mp->properties[1], "z");
CHECK_EQ("vec1", toString(mp->superType));
CHECK_EQ("vec3", toString(mp->subType));
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("vec1", toString(tm->wantedType));
CHECK_EQ("vec3", toString(tm->givenType));
}
TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
local vec3 = {x = 1, y = 2, z = 3}
local vec1 = {x = 1}
vec1 = vec3
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short")
@ -1824,4 +1922,32 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_nonstrict")
{
CheckResult result = check(R"(
--!nonstrict
local buttons = {}
table.insert(buttons, { a = 1 })
table.insert(buttons, { a = 2, b = true })
table.insert(buttons, { a = 3 })
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict")
{
ScopedFastFlag sff{"LuauTableSubtypingVariance", true};
CheckResult result = check(R"(
--!strict
local buttons = {}
table.insert(buttons, { a = 1 })
table.insert(buttons, { a = 2, b = true })
table.insert(buttons, { a = 3 })
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View file

@ -3,6 +3,7 @@
#include "Luau/AstQuery.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h"
@ -978,23 +979,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args")
CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f")));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type F = () -> F?
local function f()
return f
end
local g: F = f
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g")));
}
// TODO: File a Jira about this
/*
TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack")
@ -1257,23 +1241,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function")
REQUIRE_EQ(follow(*methodArg), follow(arg));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified")
{
CheckResult result = check(R"(
--!strict
type Node = { Parent: Node?; }
local node: Node;
node.Parent = 1
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Node?", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free")
{
CheckResult result = check(R"(
@ -2591,48 +2558,6 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion")
dumpErrors(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types")
{
CheckResult result = check(R"(
--!strict
type T<a> = { f: a, g: U<a> }
type U<a> = { h: a, i: T<a>? }
local x: T<number> = { f = 37, g = { h = 5, i = nil } }
x.g.i = x
local y: T<string> = { f = "hi", g = { h = "lo", i = nil } }
y.g.i = y
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors")
{
CheckResult result = check(R"(
--!strict
type T<a> = { f: a, g: U<a> }
type U<b> = { h: b, i: T<b>? }
local x: T<number> = { f = 37, g = { h = 5, i = nil } }
x.g.i = x
local y: T<string> = { f = "hi", g = { h = 5, i = nil } }
y.g.i = y
)");
LUAU_REQUIRE_ERRORS(result);
// We had a UAF in this example caused by not cloning type function arguments
ModulePtr module = frontend.moduleResolver.getModule("MainModule");
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes);
freeze(module->interfaceTypes);
module->internalTypes.clear();
module->astTypes.clear();
// Make sure the error strings don't include "VALUELESS"
for (auto error : module->errors)
CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error));
}
TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self")
{
// CLI-30902
@ -3388,16 +3313,7 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison")
end
)");
if (FFlag::LuauEqConstraint)
{
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
REQUIRE(get<CannotInferBinaryOperation>(result.errors[0]));
}
}
TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable")
@ -3407,18 +3323,8 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable
print((x == true and (x .. "y")) .. 1)
)");
if (FFlag::LuauEqConstraint)
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
REQUIRE(get<CannotInferBinaryOperation>(result.errors[0]));
}
else
{
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK_EQ("Type 'boolean' could not be converted into 'number | string'", toString(result.errors[0]));
CHECK_EQ("Type 'boolean | string' could not be converted into 'number | string'", toString(result.errors[1]));
}
}
TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs")
@ -3530,25 +3436,6 @@ _(...)(...,setfenv,_):_G()
)");
}
TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors")
{
CheckResult result = check(R"(
type Pair<T, U> = {first: T, second: U}
local a: Pair<string, number>
local b: Pair<string, string>
a = b
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Pair<string, number>", toString(tm->wantedType));
CHECK_EQ("Pair<string, string>", toString(tm->givenType));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs")
{
// this has a risk of creating cyclic type packs, causing infinite loops / OOMs
@ -3658,17 +3545,6 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free")
)");
}
TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition")
{
CheckResult result = check(R"(
type A = number
type A = string -- Redefinition of type 'A', previously defined at line 1
local foo: string = 1 -- No "Type 'number' could not be converted into 'string'"
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
}
TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery")
{
CheckResult result = check(R"(
@ -3771,38 +3647,6 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato
CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message);
}
TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type")
{
CheckResult result = check(R"(
type Table<T> = { a: T }
type Wrapped = Table<Wrapped>
local l: Wrapped = 2
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("Wrapped", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2")
{
CheckResult result = check(R"(
type Table<T> = { a: T }
type Wrapped = (Table<Wrapped>) -> string
local l: Wrapped = 2
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType));
CHECK_EQ(typeChecker.numberType, tm->givenType);
}
TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked")
{
CheckResult result = check(R"(
@ -3928,19 +3772,6 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals")
CHECK_EQ("(boolean | number | string)?", toString(tm->givenType));
}
// Check that recursive intersection type doesn't generate an OOM
TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom")
{
CheckResult result = check(R"(
function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any
end
type t0<t0> = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_))
_(_)
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign")
{
// In non-strict mode, global definition is still allowed
@ -3993,16 +3824,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional")
LUAU_REQUIRE_ERROR_COUNT(2, result);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise")
{
CheckResult result = check(R"(
local foo: Id<number> = 1
type Id<T> = T
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods")
{
CheckResult result = check(R"(
@ -4033,81 +3854,6 @@ end
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic")
{
const std::string code = R"(
type A<T> = {v:T, b:B<T>}
type B<T> = {v:T, a:A<T>}
local aa:A<number>
local bb = aa
)";
const std::string expected = R"(
type A<T> = {v:T, b:B<T>}
type B<T> = {v:T, a:A<T>}
local aa:A<number>
local bb:A<number>=aa
)";
CHECK_EQ(expected, decorateWithTypes(code));
CheckResult result = check(code);
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "corecursive_function_types")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type A = () -> (number, B)
type B = () -> (string, A)
local a: A
local b: B
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a")));
CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b")));
}
TEST_CASE_FIXTURE(Fixture, "generic_param_remap")
{
const std::string code = R"(
-- An example of a forwarded use of a type that has different type arguments than parameters
type A<T,U> = {t:T, u:U, next:A<U,T>?}
local aa:A<number,string> = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } }
local bb = aa
)";
const std::string expected = R"(
type A<T,U> = {t:T, u:U, next:A<U,T>?}
local aa:A<number,string> = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } }
local bb:A<number,string>=aa
)";
CHECK_EQ(expected, decorateWithTypes(code));
CheckResult result = check(code);
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates")
{
CheckResult result = check(R"(
export type Foo = number
type Foo = number
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
auto dtd = get<DuplicateTypeDefinition>(result.errors[0]);
REQUIRE(dtd);
CHECK_EQ(dtd->name, "Foo");
}
TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError")
{
CheckResult result = check(R"(
@ -4212,30 +3958,6 @@ TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does")
REQUIRE_MESSAGE(get<UnknownSymbol>(e) != nullptr, "Expected UnknownSymbol, but got " << e);
}
TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias")
{
ScopedFastFlag sffs3{"LuauGenericFunctions", true};
ScopedFastFlag sffs4{"LuauParseGenericFunctions", true};
CheckResult result = check(R"(
type Node<T> = { value: T, child: Node<T>? }
local function visitor<T>(node: Node<T>?)
local a: Node<T>
if node then
a = node.child -- Observe the output of the error message.
end
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
auto e = get<TypeMismatch>(result.errors[0]);
CHECK_EQ("Node<T>?", toString(e->givenType));
CHECK_EQ("Node<T>", toString(e->wantedType));
}
TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect")
{
CheckResult result = check(R"(
@ -4291,181 +4013,6 @@ local tbl: string = require(game.A)
CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign")
{
fileResolver.source["workspace/A"] = R"(
export type myvec2 = {x: number, y: number}
return {}
)";
fileResolver.source["workspace/B"] = R"(
export type myvec3 = {x: number, y: number, z: number}
return {}
)";
fileResolver.source["workspace/C"] = R"(
local Foo, Bar = require(workspace.A), require(workspace.B)
local a: Foo.myvec2
local b: Bar.myvec3
)";
CheckResult result = frontend.check("workspace/C");
LUAU_REQUIRE_NO_ERRORS(result);
ModulePtr m = frontend.moduleResolver.modules["workspace/C"];
REQUIRE(m != nullptr);
std::optional<TypeId> aTypeId = lookupName(m->getModuleScope(), "a");
REQUIRE(aTypeId);
const Luau::TableTypeVar* aType = get<TableTypeVar>(follow(*aTypeId));
REQUIRE(aType);
REQUIRE(aType->props.size() == 2);
std::optional<TypeId> bTypeId = lookupName(m->getModuleScope(), "b");
REQUIRE(bTypeId);
const Luau::TableTypeVar* bType = get<TableTypeVar>(follow(*bTypeId));
REQUIRE(bType);
REQUIRE(bType->props.size() == 3);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation")
{
CheckResult result = check("type t10<x> = typeof(table)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId ty = getGlobalBinding(frontend.typeChecker, "table");
CHECK_EQ(toString(ty), "table");
const TableTypeVar* ttv = get<TableTypeVar>(ty);
REQUIRE(ttv);
CHECK(ttv->instantiatedTypeParams.empty());
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation")
{
CheckResult result = check(R"(
type Cool = { a: number, b: string }
local c: Cool = { a = 1, b = "s" }
type NotCool<x> = Cool
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "Cool");
const TableTypeVar* ttv = get<TableTypeVar>(*ty);
REQUIRE(ttv);
CHECK(ttv->instantiatedTypeParams.empty());
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename")
{
CheckResult result = check(R"(
type Cool = { a: number, b: string }
type NotCool = Cool
local c: Cool = { a = 1, b = "s" }
local d: NotCool = { a = 1, b = "s" }
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "Cool");
ty = requireType("d");
REQUIRE(ty);
CHECK_EQ(toString(*ty), "NotCool");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation")
{
CheckResult result = check(R"(
local c = { a = 1, b = "s" }
type Cool = typeof(c)
)");
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = requireType("c");
REQUIRE(ty);
const TableTypeVar* ttv = get<TableTypeVar>(*ty);
REQUIRE(ttv);
CHECK_EQ(ttv->name, "Cool");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type")
{
ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true};
fileResolver.source["game/A"] = R"(
export type X = { a: number, b: X? }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
type X = Import.X
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
std::optional<TypeId> ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
std::optional<TypeId> ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(follow(*ty1), follow(*ty2));
}
TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type")
{
ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true};
fileResolver.source["game/A"] = R"(
export type X<T, U> = { a: T, b: U, C: X<T, U>? }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
type X<T, U> = Import.X<T, U>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
std::optional<TypeId> ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
std::optional<TypeId> ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true}));
bResult = check(R"(
local Import = require(game.A)
type X<T, U> = Import.X<U, T>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
ty1 = lookupImportedType("Import", "X");
REQUIRE(ty1);
ty2 = lookupType("X");
REQUIRE(ty2);
CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}");
CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?");
}
TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail")
{
CheckResult result = check(R"(
@ -4579,32 +4126,6 @@ local c = a(2) -- too many arguments
CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak")
{
CheckResult result = check(R"(
function get()
return function(obj) return true end
end
export type f = typeof(get())
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak")
{
CheckResult result = check(R"(
function get()
return {a = 1, b = function(obj) return true end}
end
export type f = typeof(get())
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "custom_require_global")
{
CheckResult result = check(R"(
@ -4787,8 +4308,6 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error")
TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back")
{
ScopedFastFlag sff{"LuauLogTableTypeVarBoundTo", true};
fileResolver.source["Module/Backend/Types"] = R"(
export type Fiber = {
return_: Fiber?
@ -4868,8 +4387,8 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload")
ModulePtr module = getMainModule();
auto it = module->astOverloadResolvedTypes.find(parentExpr);
REQUIRE(it != module->astOverloadResolvedTypes.end());
CHECK_EQ(toString(it->second), "(number) -> number");
REQUIRE(it);
CHECK_EQ(toString(*it), "(number) -> number");
}
TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments")
@ -5032,76 +4551,6 @@ g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end)
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok")
{
CheckResult result = check(R"(
type Tree<T> = { data: T, children: Forest<T> }
type Forest<T> = {Tree<T>}
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
-- OK because forwarded types are used with their parameters.
type Tree<T> = { data: T, children: Forest<T> }
type Forest<T> = {Tree<{T}>}
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
-- Not OK because forwarded types are used with different types than their parameters.
type Forest<T> = {Tree<{T}>}
type Tree<T> = { data: T, children: Forest<T> }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok")
{
CheckResult result = check(R"(
type Tree1<T,U> = { data: T, children: {Tree2<U,T>} }
type Tree2<U,T> = { data: U, children: {Tree1<T,U>} }
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok")
{
ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true};
CheckResult result = check(R"(
type Tree1<T,U> = { data: T, children: {Tree2<U,T>} }
type Tree2<T,U> = { data: U, children: {Tree1<T,U>} }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases")
{
CheckResult result = check(R"(
function f(x) return x[1] end
-- x has type X? for a free type variable X
local x = f ({})
type ContainsFree<a> = { this: a, that: typeof(x) }
type ContainsContainsFree = { that: ContainsFree<number> }
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument")
{
CheckResult result = check(R"(

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"

View file

@ -294,4 +294,370 @@ end
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Packed<T...> = (T...) -> T...
local a: Packed<>
local b: Packed<number>
local c: Packed<string, number>
)");
LUAU_REQUIRE_NO_ERRORS(result);
auto tf = lookupType("Packed");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "(T...) -> (T...)");
CHECK_EQ(toString(requireType("a")), "() -> ()");
CHECK_EQ(toString(requireType("b")), "(number) -> number");
CHECK_EQ(toString(requireType("c")), "(string, number) -> (string, number)");
result = check(R"(
-- (U..., T) cannot be parsed right now
type Packed<T, U...> = { f: (a: T, U...) -> (T, U...) }
local a: Packed<number>
local b: Packed<string, number>
local c: Packed<string, number, boolean>
)");
LUAU_REQUIRE_NO_ERRORS(result);
tf = lookupType("Packed");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "Packed<T, U...>");
CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}");
auto ttvA = get<TableTypeVar>(requireType("a"));
REQUIRE(ttvA);
CHECK_EQ(toString(requireType("a")), "Packed<number>");
CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}");
REQUIRE(ttvA->instantiatedTypeParams.size() == 1);
REQUIRE(ttvA->instantiatedTypePackParams.size() == 1);
CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number");
CHECK_EQ(toString(ttvA->instantiatedTypePackParams[0], {true}), "");
auto ttvB = get<TableTypeVar>(requireType("b"));
REQUIRE(ttvB);
CHECK_EQ(toString(requireType("b")), "Packed<string, number>");
CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}");
REQUIRE(ttvB->instantiatedTypeParams.size() == 1);
REQUIRE(ttvB->instantiatedTypePackParams.size() == 1);
CHECK_EQ(toString(ttvB->instantiatedTypeParams[0], {true}), "string");
CHECK_EQ(toString(ttvB->instantiatedTypePackParams[0], {true}), "number");
auto ttvC = get<TableTypeVar>(requireType("c"));
REQUIRE(ttvC);
CHECK_EQ(toString(requireType("c")), "Packed<string, number, boolean>");
CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}");
REQUIRE(ttvC->instantiatedTypeParams.size() == 1);
REQUIRE(ttvC->instantiatedTypePackParams.size() == 1);
CHECK_EQ(toString(ttvC->instantiatedTypeParams[0], {true}), "string");
CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
fileResolver.source["game/A"] = R"(
export type Packed<T, U...> = { a: T, b: (U...) -> () }
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = check(R"(
local Import = require(game.A)
local a: Import.Packed<number>
local b: Import.Packed<string, number>
local c: Import.Packed<string, number, boolean>
local d: { a: typeof(c) }
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
auto tf = lookupImportedType("Import", "Packed");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "Packed<T, U...>");
CHECK_EQ(toString(*tf, {true}), "{| a: T, b: (U...) -> () |}");
CHECK_EQ(toString(requireType("a"), {true}), "{| a: number, b: () -> () |}");
CHECK_EQ(toString(requireType("b"), {true}), "{| a: string, b: (number) -> () |}");
CHECK_EQ(toString(requireType("c"), {true}), "{| a: string, b: (number, boolean) -> () |}");
CHECK_EQ(toString(requireType("d")), "{| a: Packed<string, number, boolean> |}");
}
TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
fileResolver.source["game/A"] = R"(
export type Packed<T, U...> = { a: T, b: (U...) -> () }
return {}
)";
CheckResult cResult = check(R"(
local Import = require(game.A)
type Alias<S, T, R...> = Import.Packed<S, (T, R...)>
local a: Alias<string, number, boolean>
type B<X...> = Import.Packed<string, X...>
type C<X...> = Import.Packed<string, (number, X...)>
)");
LUAU_REQUIRE_NO_ERRORS(cResult);
auto tf = lookupType("Alias");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "Alias<S, T, R...>");
CHECK_EQ(toString(*tf, {true}), "{| a: S, b: (T, R...) -> () |}");
CHECK_EQ(toString(requireType("a"), {true}), "{| a: string, b: (number, boolean) -> () |}");
tf = lookupType("B");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "B<X...>");
CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (X...) -> () |}");
tf = lookupType("C");
REQUIRE(tf);
CHECK_EQ(toString(*tf), "C<X...>");
CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (number, X...) -> () |}");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Packed1<T...> = (T...) -> (T...)
type Packed2<T...> = (Packed1<T...>, T...) -> (Packed1<T...>, T...)
type Packed3<T...> = (Packed2<T...>, T...) -> (Packed2<T...>, T...)
type Packed4<T...> = (Packed3<T...>, T...) -> (Packed3<T...>, T...)
)");
LUAU_REQUIRE_NO_ERRORS(result);
auto tf = lookupType("Packed4");
REQUIRE(tf);
CHECK_EQ(toString(*tf),
"((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...) -> "
"((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type X<T...> = (T...) -> (string, T...)
type D = X<...number>
type E = X<(number, ...string)>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("D")), "(...number) -> (string, ...number)");
CHECK_EQ(toString(*lookupType("E")), "(number, ...string) -> (string, number, ...string)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Y<T..., U...> = (T...) -> (U...)
type A<S...> = Y<S..., S...>
type B<S...> = Y<(number, ...string), S...>
type Z<T, U...> = (T) -> (U...)
type E<S...> = Z<number, S...>
type F<S...> = Z<number, (string, S...)>
type W<T, U..., V...> = (T, U...) -> (T, V...)
type H<S..., R...> = W<number, S..., R...>
type I<S..., R...> = W<number, (string, S...), R...>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)");
CHECK_EQ(toString(*lookupType("B")), "(number, ...string) -> (S...)");
CHECK_EQ(toString(*lookupType("E")), "(number) -> (S...)");
CHECK_EQ(toString(*lookupType("F")), "(number) -> (string, S...)");
CHECK_EQ(toString(*lookupType("H")), "(number, S...) -> (number, R...)");
CHECK_EQ(toString(*lookupType("I")), "(number, string, S...) -> (number, R...)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type X<T...> = (T...) -> (T...)
type A<S...> = X<(S...)>
type B = X<()>
type C = X<(number)>
type D = X<(number, string)>
type E = X<(...number)>
type F = X<(string, ...number)>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)");
CHECK_EQ(toString(*lookupType("B")), "() -> ()");
CHECK_EQ(toString(*lookupType("C")), "(number) -> number");
CHECK_EQ(toString(*lookupType("D")), "(number, string) -> (number, string)");
CHECK_EQ(toString(*lookupType("E")), "(...number) -> (...number)");
CHECK_EQ(toString(*lookupType("F")), "(string, ...number) -> (string, ...number)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Y<T..., U...> = (T...) -> (U...)
type A = Y<(number, string), (boolean)>
type B = Y<(), ()>
type C<S...> = Y<...string, (number, S...)>
type D<X...> = Y<X..., (number, string, X...)>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("A")), "(number, string) -> boolean");
CHECK_EQ(toString(*lookupType("B")), "() -> ()");
CHECK_EQ(toString(*lookupType("C")), "(...string) -> (number, S...)");
CHECK_EQ(toString(*lookupType("D")), "(X...) -> (number, string, X...)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
ScopedFastFlag luauInstantiatedTypeParamRecursion("LuauInstantiatedTypeParamRecursion", true); // For correct toString block
CheckResult result = check(R"(
type Y<T..., U...> = { f: (T...) -> (U...) }
local a: Y<(number, string), (boolean)>
local b: Y<(), ()>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<(number, string), (boolean)>");
CHECK_EQ(toString(requireType("b")), "Y<(), ()>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type X<T> = () -> T
type Y<T, U> = (T) -> U
type A = X<(number)>
type B = Y<(number), (boolean)>
type C = Y<(number), boolean>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("A")), "() -> number");
CHECK_EQ(toString(*lookupType("B")), "(number) -> boolean");
CHECK_EQ(toString(*lookupType("C")), "(number) -> boolean");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors")
{
ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true);
ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true);
ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true);
CheckResult result = check(R"(
type Packed<T, U, V...> = (T, U) -> (V...)
local b: Packed<number>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T, U, V...>' expects at least 2 type arguments, but only 1 is specified");
result = check(R"(
type Packed<T, U> = (T, U) -> ()
type B<X...> = Packed<number, string, X...>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T, U>' expects 0 type pack arguments, but 1 is specified");
result = check(R"(
type Packed<T..., U...> = (T...) -> (U...)
type Other<S...> = Packed<S..., string>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Type parameters must come before type pack parameters");
result = check(R"(
type Packed<T, U> = (T) -> U
type Other<S...> = Packed<number, S...>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T, U>' expects 2 type arguments, but only 1 is specified");
result = check(R"(
type Packed<T...> = (T...) -> T...
local a: Packed
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Type parameter list is required");
result = check(R"(
type Packed<T..., U...> = (T...) -> (U...)
type Other = Packed<>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T..., U...>' expects 2 type pack arguments, but none are specified");
result = check(R"(
type Packed<T..., U...> = (T...) -> (U...)
type Other = Packed<number, string>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T..., U...>' expects 2 type pack arguments, but only 1 is specified");
}
TEST_SUITE_END();

View file

@ -237,21 +237,7 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons")
local z = a == c
)");
if (FFlag::LuauEqConstraint)
{
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(*typeChecker.booleanType, *requireType("x"));
CHECK_EQ(*typeChecker.booleanType, *requireType("y"));
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("(number | string)?", toString(*tm->wantedType));
CHECK_EQ("boolean | number", toString(*tm->givenType));
}
}
TEST_CASE_FIXTURE(Fixture, "optional_union_members")

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"

95
tools/tracegraph.py Normal file
View file

@ -0,0 +1,95 @@
#!/usr/bin/python
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
# Given a trace event file, this tool generates a flame graph based on the event scopes present in the file
# The result of analysis is a .svg file which can be viewed in a browser
import sys
import svg
import json
class Node(svg.Node):
def __init__(self):
svg.Node.__init__(self)
self.caption = ""
self.description = ""
self.ticks = 0
def text(self):
return self.caption
def title(self):
return self.caption
def details(self, root):
return "{} ({:,} usec, {:.1%}); self: {:,} usec".format(self.description, self.width, self.width / root.width, self.ticks)
with open(sys.argv[1]) as f:
dump = f.read()
root = Node()
# Finish the file
if not dump.endswith("]"):
dump += "{}]"
data = json.loads(dump)
stacks = {}
for l in data:
if len(l) == 0:
continue
# Track stack of each thread, but aggregate values together
tid = l["tid"]
if not tid in stacks:
stacks[tid] = []
stack = stacks[tid]
if l["ph"] == 'B':
stack.append(l)
elif l["ph"] == 'E':
node = root
for e in stack:
caption = e["name"]
description = ''
if "args" in e:
for arg in e["args"]:
if len(description) != 0:
description += ", "
description += "{}: {}".format(arg, e["args"][arg])
child = node.child(caption + description)
child.caption = caption
child.description = description
node = child
begin = stack[-1]
ticks = l["ts"] - begin["ts"]
rawticks = ticks
# Flame graph requires ticks without children duration
if "childts" in begin:
ticks -= begin["childts"]
node.ticks += int(ticks)
stack.pop()
if len(stack):
parent = stack[-1]
if "childts" in parent:
parent["childts"] += rawticks
else:
parent["childts"] = rawticks
svg.layout(root, lambda n: n.ticks)
svg.display(root, "Flame Graph", "hot", flip = True)