mirror of
https://github.com/luau-lang/luau.git
synced 2024-12-12 13:00:38 +00:00
Sync to upstream/release/503 (#135)
- A series of major optimizations to type checking performance on complex programs/types (up to two orders of magnitude speedup for programs involving huge tagged unions) - Fix a few issues encountered by UBSAN (and maybe fix s390x builds) - Fix gcc-11 test builds - Fix a rare corner case where luau_load wouldn't wake inactive threads which could result in a use-after-free due to GC - Fix CLI crash when error object that's not a string escapes to top level - Fix Makefile suffixes on macOS Co-authored-by: Rodactor <rodactor@roblox.com>
This commit is contained in:
parent
c0b95b8961
commit
279855df91
54 changed files with 2201 additions and 623 deletions
|
@ -34,7 +34,6 @@ TypeId makeFunction( // Polymorphic
|
|||
std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes);
|
||||
|
||||
void attachMagicFunction(TypeId ty, MagicFunction fn);
|
||||
void attachFunctionTag(TypeId ty, std::string constraint);
|
||||
|
||||
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
|
||||
void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName);
|
||||
|
|
14
Analysis/include/Luau/Quantify.h
Normal file
14
Analysis/include/Luau/Quantify.h
Normal file
|
@ -0,0 +1,14 @@
|
|||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include "Luau/TypeVar.h"
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
||||
struct Module;
|
||||
using ModulePtr = std::shared_ptr<Module>;
|
||||
|
||||
void quantify(ModulePtr module, TypeId ty, TypeLevel level);
|
||||
|
||||
} // namespace Luau
|
|
@ -69,4 +69,6 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {});
|
|||
void dump(TypeId ty);
|
||||
void dump(TypePackId ty);
|
||||
|
||||
std::string generateName(size_t n);
|
||||
|
||||
} // namespace Luau
|
||||
|
|
|
@ -12,6 +12,7 @@ struct AstArray;
|
|||
class AstStat;
|
||||
|
||||
bool containsFunctionCall(const AstStat& stat);
|
||||
bool containsFunctionCallOrReturn(const AstStat& stat);
|
||||
bool isFunction(const AstStat& stat);
|
||||
void toposort(std::vector<AstStat*>& stats);
|
||||
|
||||
|
|
|
@ -3,19 +3,37 @@
|
|||
|
||||
#include "Luau/TypeVar.h"
|
||||
|
||||
LUAU_FASTFLAG(LuauShareTxnSeen);
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
||||
// Log of where what TypeIds we are rebinding and what they used to be
|
||||
struct TxnLog
|
||||
{
|
||||
TxnLog() = default;
|
||||
|
||||
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& seen)
|
||||
: seen(seen)
|
||||
TxnLog()
|
||||
: originalSeenSize(0)
|
||||
, ownedSeen()
|
||||
, sharedSeen(&ownedSeen)
|
||||
{
|
||||
}
|
||||
|
||||
explicit TxnLog(std::vector<std::pair<TypeId, TypeId>>* sharedSeen)
|
||||
: originalSeenSize(sharedSeen->size())
|
||||
, ownedSeen()
|
||||
, sharedSeen(sharedSeen)
|
||||
{
|
||||
}
|
||||
|
||||
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& ownedSeen)
|
||||
: originalSeenSize(ownedSeen.size())
|
||||
, ownedSeen(ownedSeen)
|
||||
, sharedSeen(nullptr)
|
||||
{
|
||||
// This is deprecated!
|
||||
LUAU_ASSERT(!FFlag::LuauShareTxnSeen);
|
||||
}
|
||||
|
||||
TxnLog(const TxnLog&) = delete;
|
||||
TxnLog& operator=(const TxnLog&) = delete;
|
||||
|
||||
|
@ -38,9 +56,11 @@ private:
|
|||
std::vector<std::pair<TypeId, TypeVar>> typeVarChanges;
|
||||
std::vector<std::pair<TypePackId, TypePackVar>> typePackChanges;
|
||||
std::vector<std::pair<TableTypeVar*, std::optional<TypeId>>> tableChanges;
|
||||
size_t originalSeenSize;
|
||||
|
||||
public:
|
||||
std::vector<std::pair<TypeId, TypeId>> seen; // used to avoid infinite recursion when types are cyclic
|
||||
std::vector<std::pair<TypeId, TypeId>> ownedSeen; // used to avoid infinite recursion when types are cyclic
|
||||
std::vector<std::pair<TypeId, TypeId>>* sharedSeen; // shared with all the descendent logs
|
||||
};
|
||||
|
||||
} // namespace Luau
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include "Luau/TypePack.h"
|
||||
#include "Luau/TypeVar.h"
|
||||
#include "Luau/Unifier.h"
|
||||
#include "Luau/UnifierSharedState.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
@ -121,7 +122,7 @@ struct TypeChecker
|
|||
void check(const ScopePtr& scope, const AstStatForIn& forin);
|
||||
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
|
||||
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
|
||||
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false);
|
||||
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false);
|
||||
void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
|
||||
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
|
||||
|
||||
|
@ -336,7 +337,7 @@ private:
|
|||
|
||||
// Note: `scope` must be a fresh scope.
|
||||
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(
|
||||
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
|
||||
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
|
||||
|
||||
public:
|
||||
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
|
||||
|
@ -383,6 +384,8 @@ public:
|
|||
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
|
||||
InternalErrorReporter* iceHandler;
|
||||
|
||||
UnifierSharedState unifierState;
|
||||
|
||||
public:
|
||||
const TypeId nilType;
|
||||
const TypeId numberType;
|
||||
|
|
|
@ -540,4 +540,11 @@ UnionTypeVarIterator end(const UnionTypeVar* utv);
|
|||
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
|
||||
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
|
||||
|
||||
void attachTag(TypeId ty, const std::string& tagName);
|
||||
void attachTag(Property& prop, const std::string& tagName);
|
||||
|
||||
bool hasTag(TypeId ty, const std::string& tagName);
|
||||
bool hasTag(const Property& prop, const std::string& tagName);
|
||||
bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work.
|
||||
|
||||
} // namespace Luau
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include "Luau/TxnLog.h"
|
||||
#include "Luau/TypeInfer.h"
|
||||
#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header.
|
||||
#include "Luau/UnifierSharedState.h"
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
|
@ -41,11 +42,14 @@ struct Unifier
|
|||
|
||||
std::shared_ptr<UnifierCounters> counters_DEPRECATED;
|
||||
|
||||
InternalErrorReporter* iceHandler;
|
||||
UnifierSharedState& sharedState;
|
||||
|
||||
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler);
|
||||
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
|
||||
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
|
||||
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState);
|
||||
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& ownedSeen, const Location& location,
|
||||
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
|
||||
UnifierCounters* counters = nullptr);
|
||||
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector<std::pair<TypeId, TypeId>>* sharedSeen, const Location& location,
|
||||
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
|
||||
UnifierCounters* counters = nullptr);
|
||||
|
||||
// Test whether the two type vars unify. Never commits the result.
|
||||
|
@ -69,7 +73,8 @@ private:
|
|||
void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed);
|
||||
void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed);
|
||||
void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer);
|
||||
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId,TypeId> seen = {});
|
||||
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen = {});
|
||||
void cacheResult(TypeId superTy, TypeId subTy);
|
||||
|
||||
public:
|
||||
void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false);
|
||||
|
@ -101,8 +106,9 @@ private:
|
|||
[[noreturn]] void ice(const std::string& message, const Location& location);
|
||||
[[noreturn]] void ice(const std::string& message);
|
||||
|
||||
DenseHashSet<TypeId> tempSeenTy{nullptr};
|
||||
DenseHashSet<TypePackId> tempSeenTp{nullptr};
|
||||
// Remove with FFlagLuauCacheUnifyTableResults
|
||||
DenseHashSet<TypeId> tempSeenTy_DEPRECATED{nullptr};
|
||||
DenseHashSet<TypePackId> tempSeenTp_DEPRECATED{nullptr};
|
||||
};
|
||||
|
||||
} // namespace Luau
|
||||
|
|
44
Analysis/include/Luau/UnifierSharedState.h
Normal file
44
Analysis/include/Luau/UnifierSharedState.h
Normal file
|
@ -0,0 +1,44 @@
|
|||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include "Luau/DenseHash.h"
|
||||
#include "Luau/TypeVar.h"
|
||||
#include "Luau/TypePack.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
struct InternalErrorReporter;
|
||||
|
||||
struct TypeIdPairHash
|
||||
{
|
||||
size_t hashOne(Luau::TypeId key) const
|
||||
{
|
||||
return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9);
|
||||
}
|
||||
|
||||
size_t operator()(const std::pair<Luau::TypeId, Luau::TypeId>& x) const
|
||||
{
|
||||
return hashOne(x.first) ^ (hashOne(x.second) << 1);
|
||||
}
|
||||
};
|
||||
|
||||
struct UnifierSharedState
|
||||
{
|
||||
UnifierSharedState(InternalErrorReporter* iceHandler)
|
||||
: iceHandler(iceHandler)
|
||||
{
|
||||
}
|
||||
|
||||
InternalErrorReporter* iceHandler;
|
||||
|
||||
DenseHashSet<void*> seenAny{nullptr};
|
||||
DenseHashMap<TypeId, bool> skipCacheForType{nullptr};
|
||||
DenseHashSet<std::pair<TypeId, TypeId>, TypeIdPairHash> cachedUnify{{nullptr, nullptr}};
|
||||
|
||||
DenseHashSet<TypeId> tempSeenTy{nullptr};
|
||||
DenseHashSet<TypePackId> tempSeenTp{nullptr};
|
||||
};
|
||||
|
||||
} // namespace Luau
|
|
@ -1,9 +1,12 @@
|
|||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include "Luau/DenseHash.h"
|
||||
#include "Luau/TypeVar.h"
|
||||
#include "Luau/TypePack.h"
|
||||
|
||||
LUAU_FASTFLAG(LuauCacheUnifyTableResults)
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
||||
|
@ -32,17 +35,33 @@ inline bool hasSeen(std::unordered_set<void*>& seen, const void* tv)
|
|||
return !seen.insert(ttv).second;
|
||||
}
|
||||
|
||||
inline bool hasSeen(DenseHashSet<void*>& seen, const void* tv)
|
||||
{
|
||||
void* ttv = const_cast<void*>(tv);
|
||||
|
||||
if (seen.contains(ttv))
|
||||
return true;
|
||||
|
||||
seen.insert(ttv);
|
||||
return false;
|
||||
}
|
||||
|
||||
inline void unsee(std::unordered_set<void*>& seen, const void* tv)
|
||||
{
|
||||
void* ttv = const_cast<void*>(tv);
|
||||
seen.erase(ttv);
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen);
|
||||
inline void unsee(DenseHashSet<void*>& seen, const void* tv)
|
||||
{
|
||||
// When DenseHashSet is used for 'visitOnce', where don't forget visited elements
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
|
||||
template<typename F, typename Set>
|
||||
void visit(TypePackId tp, F& f, Set& seen);
|
||||
|
||||
template<typename F, typename Set>
|
||||
void visit(TypeId ty, F& f, Set& seen)
|
||||
{
|
||||
if (visit_detail::hasSeen(seen, ty))
|
||||
{
|
||||
|
@ -79,15 +98,23 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
|
|||
|
||||
else if (auto ttv = get<TableTypeVar>(ty))
|
||||
{
|
||||
// Some visitors want to see bound tables, that's why we visit the original type
|
||||
if (apply(ty, *ttv, seen, f))
|
||||
{
|
||||
for (auto& [_name, prop] : ttv->props)
|
||||
visit(prop.type, f, seen);
|
||||
|
||||
if (ttv->indexer)
|
||||
if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo)
|
||||
{
|
||||
visit(ttv->indexer->indexType, f, seen);
|
||||
visit(ttv->indexer->indexResultType, f, seen);
|
||||
visit(*ttv->boundTo, f, seen);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto& [_name, prop] : ttv->props)
|
||||
visit(prop.type, f, seen);
|
||||
|
||||
if (ttv->indexer)
|
||||
{
|
||||
visit(ttv->indexer->indexType, f, seen);
|
||||
visit(ttv->indexer->indexResultType, f, seen);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -140,8 +167,8 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
|
|||
visit_detail::unsee(seen, ty);
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
|
||||
template<typename F, typename Set>
|
||||
void visit(TypePackId tp, F& f, Set& seen)
|
||||
{
|
||||
if (visit_detail::hasSeen(seen, tp))
|
||||
{
|
||||
|
@ -182,6 +209,7 @@ void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
|
|||
|
||||
visit_detail::unsee(seen, tp);
|
||||
}
|
||||
|
||||
} // namespace visit_detail
|
||||
|
||||
template<typename TID, typename F>
|
||||
|
@ -197,4 +225,11 @@ void visitTypeVar(TID ty, F& f)
|
|||
visit_detail::visit(ty, f, seen);
|
||||
}
|
||||
|
||||
template<typename TID, typename F>
|
||||
void visitTypeVarOnce(TID ty, F& f, DenseHashSet<void*>& seen)
|
||||
{
|
||||
seen.clear();
|
||||
visit_detail::visit(ty, f, seen);
|
||||
}
|
||||
|
||||
} // namespace Luau
|
||||
|
|
|
@ -196,7 +196,8 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
|
|||
|
||||
auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) {
|
||||
InternalErrorReporter iceReporter;
|
||||
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter);
|
||||
UnifierSharedState unifierState(&iceReporter);
|
||||
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState);
|
||||
|
||||
unifier.tryUnify(expectedType, actualType);
|
||||
|
||||
|
|
|
@ -106,18 +106,6 @@ void attachMagicFunction(TypeId ty, MagicFunction fn)
|
|||
LUAU_ASSERT(!"Got a non functional type");
|
||||
}
|
||||
|
||||
void attachFunctionTag(TypeId ty, std::string tag)
|
||||
{
|
||||
if (auto ftv = getMutable<FunctionTypeVar>(ty))
|
||||
{
|
||||
ftv->tags.emplace_back(std::move(tag));
|
||||
}
|
||||
else
|
||||
{
|
||||
LUAU_ASSERT(!"Got a non functional type");
|
||||
}
|
||||
}
|
||||
|
||||
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol)
|
||||
{
|
||||
return {
|
||||
|
|
|
@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false)
|
|||
LUAU_FASTFLAG(LuauTraceRequireLookupChild)
|
||||
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
|
||||
LUAU_FASTFLAG(LuauNewRequireTrace)
|
||||
LUAU_FASTFLAGVARIABLE(LuauClearScopes, false)
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
@ -248,7 +249,7 @@ struct RequireCycle
|
|||
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
|
||||
// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true)
|
||||
std::vector<RequireCycle> getRequireCycles(
|
||||
const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
|
||||
const FileResolver* resolver, const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
|
||||
{
|
||||
std::vector<RequireCycle> result;
|
||||
|
||||
|
@ -282,9 +283,9 @@ std::vector<RequireCycle> getRequireCycles(
|
|||
if (top == start)
|
||||
{
|
||||
for (const SourceNode* node : path)
|
||||
cycle.push_back(node->name);
|
||||
cycle.push_back(resolver->getHumanReadableModuleName(node->name));
|
||||
|
||||
cycle.push_back(top->name);
|
||||
cycle.push_back(resolver->getHumanReadableModuleName(top->name));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -404,7 +405,7 @@ CheckResult Frontend::check(const ModuleName& name)
|
|||
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
|
||||
// all correct programs must be acyclic so this code triggers rarely
|
||||
if (cycleDetected)
|
||||
requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck);
|
||||
requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck);
|
||||
|
||||
// This is used by the type checker to replace the resulting type of cyclic modules with any
|
||||
sourceModule.cyclic = !requireCycles.empty();
|
||||
|
@ -458,6 +459,8 @@ CheckResult Frontend::check(const ModuleName& name)
|
|||
module->astTypes.clear();
|
||||
module->astExpectedTypes.clear();
|
||||
module->astOriginalCallTypes.clear();
|
||||
if (FFlag::LuauClearScopes)
|
||||
module->scopes.resize(1);
|
||||
}
|
||||
|
||||
if (mode != Mode::NoCheck)
|
||||
|
|
|
@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false)
|
|||
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
|
||||
LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans)
|
||||
LUAU_FASTFLAG(LuauTypeAliasPacks)
|
||||
LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false)
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
@ -299,6 +300,14 @@ void TypeCloner::operator()(const FunctionTypeVar& t)
|
|||
|
||||
void TypeCloner::operator()(const TableTypeVar& t)
|
||||
{
|
||||
// If table is now bound to another one, we ignore the content of the original
|
||||
if (FFlag::LuauCloneBoundTables && t.boundTo)
|
||||
{
|
||||
TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
|
||||
seenTypes[typeId] = boundTo;
|
||||
return;
|
||||
}
|
||||
|
||||
TypeId result = dest.addType(TableTypeVar{});
|
||||
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
|
||||
LUAU_ASSERT(ttv != nullptr);
|
||||
|
@ -321,8 +330,11 @@ void TypeCloner::operator()(const TableTypeVar& t)
|
|||
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType),
|
||||
clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)};
|
||||
|
||||
if (t.boundTo)
|
||||
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
|
||||
if (!FFlag::LuauCloneBoundTables)
|
||||
{
|
||||
if (t.boundTo)
|
||||
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
|
||||
}
|
||||
|
||||
for (TypeId& arg : ttv->instantiatedTypeParams)
|
||||
arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType);
|
||||
|
@ -335,7 +347,7 @@ void TypeCloner::operator()(const TableTypeVar& t)
|
|||
|
||||
if (ttv->state == TableState::Free)
|
||||
{
|
||||
if (!t.boundTo)
|
||||
if (FFlag::LuauCloneBoundTables || !t.boundTo)
|
||||
{
|
||||
if (encounteredFreeType)
|
||||
*encounteredFreeType = true;
|
||||
|
|
90
Analysis/src/Quantify.cpp
Normal file
90
Analysis/src/Quantify.cpp
Normal file
|
@ -0,0 +1,90 @@
|
|||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
|
||||
#include "Luau/Quantify.h"
|
||||
|
||||
#include "Luau/VisitTypeVar.h"
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
||||
struct Quantifier
|
||||
{
|
||||
ModulePtr module;
|
||||
TypeLevel level;
|
||||
std::vector<TypeId> generics;
|
||||
std::vector<TypePackId> genericPacks;
|
||||
|
||||
Quantifier(ModulePtr module, TypeLevel level)
|
||||
: module(module)
|
||||
, level(level)
|
||||
{
|
||||
}
|
||||
|
||||
void cycle(TypeId) {}
|
||||
void cycle(TypePackId) {}
|
||||
|
||||
bool operator()(TypeId ty, const FreeTypeVar& ftv)
|
||||
{
|
||||
if (!level.subsumes(ftv.level))
|
||||
return false;
|
||||
|
||||
*asMutable(ty) = GenericTypeVar{level};
|
||||
generics.push_back(ty);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
bool operator()(TypeId ty, const T& t)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
bool operator()(TypePackId, const T&)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
bool operator()(TypeId ty, const TableTypeVar&)
|
||||
{
|
||||
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
|
||||
|
||||
if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic)
|
||||
return false;
|
||||
if (!level.subsumes(ttv.level))
|
||||
return false;
|
||||
|
||||
if (ttv.state == TableState::Free)
|
||||
ttv.state = TableState::Generic;
|
||||
else if (ttv.state == TableState::Unsealed)
|
||||
ttv.state = TableState::Sealed;
|
||||
|
||||
ttv.level = level;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool operator()(TypePackId tp, const FreeTypePack& ftp)
|
||||
{
|
||||
if (!level.subsumes(ftp.level))
|
||||
return false;
|
||||
|
||||
*asMutable(tp) = GenericTypePack{level};
|
||||
genericPacks.push_back(tp);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
void quantify(ModulePtr module, TypeId ty, TypeLevel level)
|
||||
{
|
||||
Quantifier q{std::move(module), level};
|
||||
visitTypeVar(ty, q);
|
||||
|
||||
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
|
||||
LUAU_ASSERT(ftv);
|
||||
ftv->generics = q.generics;
|
||||
ftv->genericPacks = q.genericPacks;
|
||||
}
|
||||
|
||||
} // namespace Luau
|
|
@ -182,7 +182,7 @@ struct RequireTracerOld : AstVisitor
|
|||
|
||||
struct RequireTracer : AstVisitor
|
||||
{
|
||||
RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName)
|
||||
RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName)
|
||||
: result(result)
|
||||
, fileResolver(fileResolver)
|
||||
, currentModuleName(currentModuleName)
|
||||
|
@ -260,7 +260,7 @@ struct RequireTracer : AstVisitor
|
|||
// seed worklist with require arguments
|
||||
work.reserve(requires.size());
|
||||
|
||||
for (AstExprCall* require: requires)
|
||||
for (AstExprCall* require : requires)
|
||||
work.push_back(require->args.data[0]);
|
||||
|
||||
// push all dependent expressions to the work stack; note that the vector is modified during traversal
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
|
||||
LUAU_FASTFLAG(LuauExtraNilRecovery)
|
||||
LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions)
|
||||
LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false)
|
||||
LUAU_FASTFLAG(LuauTypeAliasPacks)
|
||||
|
@ -159,15 +158,6 @@ struct StringifierState
|
|||
seen.erase(iter);
|
||||
}
|
||||
|
||||
static std::string generateName(size_t i)
|
||||
{
|
||||
std::string n;
|
||||
n = char('a' + i % 26);
|
||||
if (i >= 26)
|
||||
n += std::to_string(i / 26);
|
||||
return n;
|
||||
}
|
||||
|
||||
std::string getName(TypeId ty)
|
||||
{
|
||||
const size_t s = result.nameMap.typeVars.size();
|
||||
|
@ -584,8 +574,7 @@ struct TypeVarStringifier
|
|||
std::vector<std::string> results = {};
|
||||
for (auto el : &uv)
|
||||
{
|
||||
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
|
||||
el = follow(el);
|
||||
el = follow(el);
|
||||
|
||||
if (isNil(el))
|
||||
{
|
||||
|
@ -649,8 +638,7 @@ struct TypeVarStringifier
|
|||
std::vector<std::string> results = {};
|
||||
for (auto el : uv.parts)
|
||||
{
|
||||
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
|
||||
el = follow(el);
|
||||
el = follow(el);
|
||||
|
||||
std::string saved = std::move(state.result.name);
|
||||
|
||||
|
@ -1204,4 +1192,13 @@ void dump(TypePackId ty)
|
|||
printf("%s\n", toString(ty, opts).c_str());
|
||||
}
|
||||
|
||||
std::string generateName(size_t i)
|
||||
{
|
||||
std::string n;
|
||||
n = char('a' + i % 26);
|
||||
if (i >= 26)
|
||||
n += std::to_string(i / 26);
|
||||
return n;
|
||||
}
|
||||
|
||||
} // namespace Luau
|
||||
|
|
|
@ -298,8 +298,15 @@ struct ArcCollector : public AstVisitor
|
|||
|
||||
struct ContainsFunctionCall : public AstVisitor
|
||||
{
|
||||
bool alsoReturn = false;
|
||||
bool result = false;
|
||||
|
||||
ContainsFunctionCall() = default;
|
||||
explicit ContainsFunctionCall(bool alsoReturn)
|
||||
: alsoReturn(alsoReturn)
|
||||
{
|
||||
}
|
||||
|
||||
bool visit(AstExpr*) override
|
||||
{
|
||||
return !result; // short circuit if result is true
|
||||
|
@ -318,6 +325,17 @@ struct ContainsFunctionCall : public AstVisitor
|
|||
return false;
|
||||
}
|
||||
|
||||
bool visit(AstStatReturn* stat) override
|
||||
{
|
||||
if (alsoReturn)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
else
|
||||
return AstVisitor::visit(stat);
|
||||
}
|
||||
|
||||
bool visit(AstExprFunction*) override
|
||||
{
|
||||
return false;
|
||||
|
@ -479,6 +497,13 @@ bool containsFunctionCall(const AstStat& stat)
|
|||
return cfc.result;
|
||||
}
|
||||
|
||||
bool containsFunctionCallOrReturn(const AstStat& stat)
|
||||
{
|
||||
detail::ContainsFunctionCall cfc{true};
|
||||
const_cast<AstStat&>(stat).visit(&cfc);
|
||||
return cfc.result;
|
||||
}
|
||||
|
||||
bool isFunction(const AstStat& stat)
|
||||
{
|
||||
return stat.is<AstStatFunction>() || stat.is<AstStatLocalFunction>();
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
#include <algorithm>
|
||||
|
||||
LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false)
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
||||
|
@ -33,6 +35,12 @@ void TxnLog::rollback()
|
|||
|
||||
for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it)
|
||||
std::swap(it->first->boundTo, it->second);
|
||||
|
||||
if (FFlag::LuauShareTxnSeen)
|
||||
{
|
||||
LUAU_ASSERT(originalSeenSize <= sharedSeen->size());
|
||||
sharedSeen->resize(originalSeenSize);
|
||||
}
|
||||
}
|
||||
|
||||
void TxnLog::concat(TxnLog rhs)
|
||||
|
@ -46,27 +54,44 @@ void TxnLog::concat(TxnLog rhs)
|
|||
tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end());
|
||||
rhs.tableChanges.clear();
|
||||
|
||||
seen.swap(rhs.seen);
|
||||
rhs.seen.clear();
|
||||
if (!FFlag::LuauShareTxnSeen)
|
||||
{
|
||||
ownedSeen.swap(rhs.ownedSeen);
|
||||
rhs.ownedSeen.clear();
|
||||
}
|
||||
}
|
||||
|
||||
bool TxnLog::haveSeen(TypeId lhs, TypeId rhs)
|
||||
{
|
||||
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
|
||||
return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair));
|
||||
if (FFlag::LuauShareTxnSeen)
|
||||
return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair));
|
||||
else
|
||||
return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair));
|
||||
}
|
||||
|
||||
void TxnLog::pushSeen(TypeId lhs, TypeId rhs)
|
||||
{
|
||||
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
|
||||
seen.push_back(sortedPair);
|
||||
if (FFlag::LuauShareTxnSeen)
|
||||
sharedSeen->push_back(sortedPair);
|
||||
else
|
||||
ownedSeen.push_back(sortedPair);
|
||||
}
|
||||
|
||||
void TxnLog::popSeen(TypeId lhs, TypeId rhs)
|
||||
{
|
||||
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
|
||||
LUAU_ASSERT(sortedPair == seen.back());
|
||||
seen.pop_back();
|
||||
if (FFlag::LuauShareTxnSeen)
|
||||
{
|
||||
LUAU_ASSERT(sortedPair == sharedSeen->back());
|
||||
sharedSeen->pop_back();
|
||||
}
|
||||
else
|
||||
{
|
||||
LUAU_ASSERT(sortedPair == ownedSeen.back());
|
||||
ownedSeen.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Luau
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include "Luau/Parser.h"
|
||||
#include "Luau/RecursionCounter.h"
|
||||
#include "Luau/Scope.h"
|
||||
#include "Luau/ToString.h"
|
||||
#include "Luau/TypeInfer.h"
|
||||
#include "Luau/TypePack.h"
|
||||
#include "Luau/TypeVar.h"
|
||||
|
@ -33,14 +34,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data
|
|||
return result;
|
||||
}
|
||||
|
||||
using SyntheticNames = std::unordered_map<const void*, char*>;
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
||||
static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen)
|
||||
{
|
||||
size_t s = syntheticNames->size();
|
||||
char*& n = (*syntheticNames)[&gen];
|
||||
if (!n)
|
||||
{
|
||||
std::string str = gen.explicitName ? gen.name : generateName(s);
|
||||
n = static_cast<char*>(allocator->allocate(str.size() + 1));
|
||||
strcpy(n, str.c_str());
|
||||
}
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
class TypeRehydrationVisitor
|
||||
{
|
||||
mutable std::map<void*, int> seen;
|
||||
mutable int count = 0;
|
||||
std::map<void*, int> seen;
|
||||
int count = 0;
|
||||
|
||||
bool hasSeen(const void* tv) const
|
||||
bool hasSeen(const void* tv)
|
||||
{
|
||||
void* ttv = const_cast<void*>(tv);
|
||||
auto it = seen.find(ttv);
|
||||
|
@ -52,15 +70,16 @@ class TypeRehydrationVisitor
|
|||
}
|
||||
|
||||
public:
|
||||
TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions())
|
||||
TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions())
|
||||
: allocator(alloc)
|
||||
, syntheticNames(syntheticNames)
|
||||
, options(options)
|
||||
{
|
||||
}
|
||||
|
||||
AstTypePack* rehydrate(TypePackId tp) const;
|
||||
AstTypePack* rehydrate(TypePackId tp);
|
||||
|
||||
AstType* operator()(const PrimitiveTypeVar& ptv) const
|
||||
AstType* operator()(const PrimitiveTypeVar& ptv)
|
||||
{
|
||||
switch (ptv.type)
|
||||
{
|
||||
|
@ -78,11 +97,11 @@ public:
|
|||
return nullptr;
|
||||
}
|
||||
}
|
||||
AstType* operator()(const AnyTypeVar&) const
|
||||
AstType* operator()(const AnyTypeVar&)
|
||||
{
|
||||
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"));
|
||||
}
|
||||
AstType* operator()(const TableTypeVar& ttv) const
|
||||
AstType* operator()(const TableTypeVar& ttv)
|
||||
{
|
||||
RecursionCounter counter(&count);
|
||||
|
||||
|
@ -144,12 +163,12 @@ public:
|
|||
return allocator->alloc<AstTypeTable>(Location(), props, indexer);
|
||||
}
|
||||
|
||||
AstType* operator()(const MetatableTypeVar& mtv) const
|
||||
AstType* operator()(const MetatableTypeVar& mtv)
|
||||
{
|
||||
return Luau::visit(*this, mtv.table->ty);
|
||||
}
|
||||
|
||||
AstType* operator()(const ClassTypeVar& ctv) const
|
||||
AstType* operator()(const ClassTypeVar& ctv)
|
||||
{
|
||||
RecursionCounter counter(&count);
|
||||
|
||||
|
@ -176,7 +195,7 @@ public:
|
|||
return allocator->alloc<AstTypeTable>(Location(), props);
|
||||
}
|
||||
|
||||
AstType* operator()(const FunctionTypeVar& ftv) const
|
||||
AstType* operator()(const FunctionTypeVar& ftv)
|
||||
{
|
||||
RecursionCounter counter(&count);
|
||||
|
||||
|
@ -253,10 +272,12 @@ public:
|
|||
size_t i = 0;
|
||||
for (const auto& el : ftv.argNames)
|
||||
{
|
||||
std::optional<AstArgumentName>* arg = &argNames.data[i++];
|
||||
|
||||
if (el)
|
||||
argNames.data[i++] = {AstName(el->name.c_str()), el->location};
|
||||
new (arg) std::optional<AstArgumentName>(AstArgumentName(AstName(el->name.c_str()), el->location));
|
||||
else
|
||||
argNames.data[i++] = {};
|
||||
new (arg) std::optional<AstArgumentName>();
|
||||
}
|
||||
|
||||
AstArray<AstType*> returnTypes;
|
||||
|
@ -290,23 +311,23 @@ public:
|
|||
return allocator->alloc<AstTypeFunction>(
|
||||
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation});
|
||||
}
|
||||
AstType* operator()(const Unifiable::Error&) const
|
||||
AstType* operator()(const Unifiable::Error&)
|
||||
{
|
||||
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>"));
|
||||
}
|
||||
AstType* operator()(const GenericTypeVar& gtv) const
|
||||
AstType* operator()(const GenericTypeVar& gtv)
|
||||
{
|
||||
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(gtv.name.c_str()));
|
||||
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)));
|
||||
}
|
||||
AstType* operator()(const Unifiable::Bound<TypeId>& bound) const
|
||||
AstType* operator()(const Unifiable::Bound<TypeId>& bound)
|
||||
{
|
||||
return Luau::visit(*this, bound.boundTo->ty);
|
||||
}
|
||||
AstType* operator()(Unifiable::Free ftv) const
|
||||
AstType* operator()(const FreeTypeVar& ftv)
|
||||
{
|
||||
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"));
|
||||
}
|
||||
AstType* operator()(const UnionTypeVar& uv) const
|
||||
AstType* operator()(const UnionTypeVar& uv)
|
||||
{
|
||||
AstArray<AstType*> unionTypes;
|
||||
unionTypes.size = uv.options.size();
|
||||
|
@ -317,7 +338,7 @@ public:
|
|||
}
|
||||
return allocator->alloc<AstTypeUnion>(Location(), unionTypes);
|
||||
}
|
||||
AstType* operator()(const IntersectionTypeVar& uv) const
|
||||
AstType* operator()(const IntersectionTypeVar& uv)
|
||||
{
|
||||
AstArray<AstType*> intersectionTypes;
|
||||
intersectionTypes.size = uv.parts.size();
|
||||
|
@ -328,23 +349,28 @@ public:
|
|||
}
|
||||
return allocator->alloc<AstTypeIntersection>(Location(), intersectionTypes);
|
||||
}
|
||||
AstType* operator()(const LazyTypeVar& ltv) const
|
||||
AstType* operator()(const LazyTypeVar& ltv)
|
||||
{
|
||||
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Lazy?>"));
|
||||
}
|
||||
|
||||
private:
|
||||
Allocator* allocator;
|
||||
SyntheticNames* syntheticNames;
|
||||
const TypeRehydrationOptions& options;
|
||||
};
|
||||
|
||||
class TypePackRehydrationVisitor
|
||||
{
|
||||
public:
|
||||
TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor)
|
||||
TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor)
|
||||
: allocator(allocator)
|
||||
, syntheticNames(syntheticNames)
|
||||
, typeVisitor(typeVisitor)
|
||||
{
|
||||
LUAU_ASSERT(allocator);
|
||||
LUAU_ASSERT(syntheticNames);
|
||||
LUAU_ASSERT(typeVisitor);
|
||||
}
|
||||
|
||||
AstTypePack* operator()(const BoundTypePack& btp) const
|
||||
|
@ -359,7 +385,7 @@ public:
|
|||
head.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * tp.head.size()));
|
||||
|
||||
for (size_t i = 0; i < tp.head.size(); i++)
|
||||
head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty);
|
||||
head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty);
|
||||
|
||||
AstTypePack* tail = nullptr;
|
||||
|
||||
|
@ -371,12 +397,12 @@ public:
|
|||
|
||||
AstTypePack* operator()(const VariadicTypePack& vtp) const
|
||||
{
|
||||
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(typeVisitor, vtp.ty->ty));
|
||||
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*typeVisitor, vtp.ty->ty));
|
||||
}
|
||||
|
||||
AstTypePack* operator()(const GenericTypePack& gtp) const
|
||||
{
|
||||
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(gtp.name.c_str()));
|
||||
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(getName(allocator, syntheticNames, gtp)));
|
||||
}
|
||||
|
||||
AstTypePack* operator()(const FreeTypePack& gtp) const
|
||||
|
@ -391,12 +417,13 @@ public:
|
|||
|
||||
private:
|
||||
Allocator* allocator;
|
||||
const TypeRehydrationVisitor& typeVisitor;
|
||||
SyntheticNames* syntheticNames;
|
||||
TypeRehydrationVisitor* typeVisitor;
|
||||
};
|
||||
|
||||
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const
|
||||
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp)
|
||||
{
|
||||
TypePackRehydrationVisitor tprv(allocator, *this);
|
||||
TypePackRehydrationVisitor tprv(allocator, syntheticNames, this);
|
||||
return Luau::visit(tprv, tp->ty);
|
||||
}
|
||||
|
||||
|
@ -431,7 +458,7 @@ public:
|
|||
{
|
||||
if (!type)
|
||||
return nullptr;
|
||||
return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty);
|
||||
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty);
|
||||
}
|
||||
|
||||
AstArray<Luau::AstType*> typeAstPack(TypePackId type)
|
||||
|
@ -443,7 +470,7 @@ public:
|
|||
result.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * v.size()));
|
||||
for (size_t i = 0; i < v.size(); ++i)
|
||||
{
|
||||
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty);
|
||||
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -495,7 +522,7 @@ public:
|
|||
{
|
||||
if (FFlag::LuauTypeAliasPacks)
|
||||
{
|
||||
variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail);
|
||||
variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -515,6 +542,7 @@ public:
|
|||
private:
|
||||
Module& module;
|
||||
Allocator* allocator;
|
||||
SyntheticNames syntheticNames;
|
||||
};
|
||||
|
||||
void attachTypeData(SourceModule& source, Module& result)
|
||||
|
@ -525,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result)
|
|||
|
||||
AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options)
|
||||
{
|
||||
return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty);
|
||||
SyntheticNames syntheticNames;
|
||||
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty);
|
||||
}
|
||||
|
||||
} // namespace Luau
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
#include "Luau/Common.h"
|
||||
#include "Luau/ModuleResolver.h"
|
||||
#include "Luau/Parser.h"
|
||||
#include "Luau/Quantify.h"
|
||||
#include "Luau/RecursionCounter.h"
|
||||
#include "Luau/Scope.h"
|
||||
#include "Luau/Substitution.h"
|
||||
|
@ -33,18 +34,16 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false)
|
|||
LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false)
|
||||
LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes)
|
||||
LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false)
|
||||
LUAU_FASTFLAG(LuauNewRequireTrace)
|
||||
LUAU_FASTFLAG(LuauTypeAliasPacks)
|
||||
|
||||
|
@ -215,6 +214,7 @@ static bool isMetamethod(const Name& name)
|
|||
TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler)
|
||||
: resolver(resolver)
|
||||
, iceHandler(iceHandler)
|
||||
, unifierState(iceHandler)
|
||||
, nilType(singletonTypes.nilType)
|
||||
, numberType(singletonTypes.numberType)
|
||||
, stringType(singletonTypes.stringType)
|
||||
|
@ -370,13 +370,18 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
|
|||
return;
|
||||
}
|
||||
|
||||
int subLevel = 0;
|
||||
|
||||
std::vector<AstStat*> sorted(block.body.data, block.body.data + block.body.size);
|
||||
toposort(sorted);
|
||||
|
||||
for (const auto& stat : sorted)
|
||||
{
|
||||
if (const auto& typealias = stat->as<AstStatTypeAlias>())
|
||||
check(scope, *typealias, true);
|
||||
{
|
||||
check(scope, *typealias, subLevel, true);
|
||||
++subLevel;
|
||||
}
|
||||
}
|
||||
|
||||
auto protoIter = sorted.begin();
|
||||
|
@ -399,8 +404,6 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
|
|||
}
|
||||
};
|
||||
|
||||
int subLevel = 0;
|
||||
|
||||
while (protoIter != sorted.end())
|
||||
{
|
||||
// protoIter walks forward
|
||||
|
@ -433,7 +436,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
|
|||
// function f<a>(x:a):a local x: number = g(37) return x end
|
||||
// function g(x:number):number return f(x) end
|
||||
// ```
|
||||
if (containsFunctionCall(**protoIter))
|
||||
if (FFlag::LuauQuantifyInPlace2 ? containsFunctionCallOrReturn(**protoIter) : containsFunctionCall(**protoIter))
|
||||
{
|
||||
while (checkIter != protoIter)
|
||||
{
|
||||
|
@ -1161,7 +1164,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco
|
|||
scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location};
|
||||
}
|
||||
|
||||
void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare)
|
||||
void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare)
|
||||
{
|
||||
// This function should be called at most twice for each type alias.
|
||||
// Once with forwardDeclare, and once without.
|
||||
|
@ -1189,11 +1192,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
|
|||
}
|
||||
else
|
||||
{
|
||||
ScopePtr aliasScope = childScope(scope, typealias.location);
|
||||
ScopePtr aliasScope =
|
||||
FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location);
|
||||
|
||||
if (FFlag::LuauTypeAliasPacks)
|
||||
{
|
||||
auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks);
|
||||
auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks);
|
||||
|
||||
TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true));
|
||||
FreeTypeVar* ftv = getMutable<FreeTypeVar>(ty);
|
||||
|
@ -1418,7 +1422,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo
|
|||
{
|
||||
ScopePtr funScope = childFunctionScope(scope, global.location);
|
||||
|
||||
auto [generics, genericPacks] = createGenericTypes(funScope, global, global.generics, global.genericPacks);
|
||||
auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks);
|
||||
|
||||
TypePackId argPack = resolveTypePack(funScope, global.params);
|
||||
TypePackId retPack = resolveTypePack(funScope, global.retTypes);
|
||||
|
@ -1610,25 +1614,11 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn
|
|||
if (std::optional<TypeId> ty = resolveLValue(scope, *lvalue))
|
||||
return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}};
|
||||
|
||||
if (FFlag::LuauExtraNilRecovery)
|
||||
lhsType = stripFromNilAndReport(lhsType, expr.expr->location);
|
||||
lhsType = stripFromNilAndReport(lhsType, expr.expr->location);
|
||||
|
||||
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true))
|
||||
return {*ty};
|
||||
|
||||
if (!FFlag::LuauMissingUnionPropertyError)
|
||||
reportError(expr.indexLocation, UnknownProperty{lhsType, expr.index.value});
|
||||
|
||||
if (!FFlag::LuauExtraNilRecovery)
|
||||
{
|
||||
// Try to recover using a union without 'nil' options
|
||||
if (std::optional<TypeId> strippedUnion = tryStripUnionFromNil(lhsType))
|
||||
{
|
||||
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, *strippedUnion, name, expr.location, false))
|
||||
return {*ty};
|
||||
}
|
||||
}
|
||||
|
||||
return {errorType};
|
||||
}
|
||||
|
||||
|
@ -1694,61 +1684,37 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
|
|||
}
|
||||
else if (const UnionTypeVar* utv = get<UnionTypeVar>(type))
|
||||
{
|
||||
if (FFlag::LuauMissingUnionPropertyError)
|
||||
std::vector<TypeId> goodOptions;
|
||||
std::vector<TypeId> badOptions;
|
||||
|
||||
for (TypeId t : utv)
|
||||
{
|
||||
std::vector<TypeId> goodOptions;
|
||||
std::vector<TypeId> badOptions;
|
||||
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
|
||||
|
||||
for (TypeId t : utv)
|
||||
{
|
||||
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
|
||||
|
||||
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
|
||||
goodOptions.push_back(*ty);
|
||||
else
|
||||
badOptions.push_back(t);
|
||||
}
|
||||
|
||||
if (!badOptions.empty())
|
||||
{
|
||||
if (addErrors)
|
||||
{
|
||||
if (goodOptions.empty())
|
||||
reportError(location, UnknownProperty{type, name});
|
||||
else
|
||||
reportError(location, MissingUnionProperty{type, badOptions, name});
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::vector<TypeId> result = reduceUnion(goodOptions);
|
||||
|
||||
if (result.size() == 1)
|
||||
return result[0];
|
||||
|
||||
return addType(UnionTypeVar{std::move(result)});
|
||||
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
|
||||
goodOptions.push_back(*ty);
|
||||
else
|
||||
badOptions.push_back(t);
|
||||
}
|
||||
else
|
||||
|
||||
if (!badOptions.empty())
|
||||
{
|
||||
std::vector<TypeId> options;
|
||||
|
||||
for (TypeId t : utv->options)
|
||||
if (addErrors)
|
||||
{
|
||||
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
|
||||
|
||||
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
|
||||
options.push_back(*ty);
|
||||
if (goodOptions.empty())
|
||||
reportError(location, UnknownProperty{type, name});
|
||||
else
|
||||
return std::nullopt;
|
||||
reportError(location, MissingUnionProperty{type, badOptions, name});
|
||||
}
|
||||
|
||||
std::vector<TypeId> result = reduceUnion(options);
|
||||
|
||||
if (result.size() == 1)
|
||||
return result[0];
|
||||
|
||||
return addType(UnionTypeVar{std::move(result)});
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::vector<TypeId> result = reduceUnion(goodOptions);
|
||||
|
||||
if (result.size() == 1)
|
||||
return result[0];
|
||||
|
||||
return addType(UnionTypeVar{std::move(result)});
|
||||
}
|
||||
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type))
|
||||
{
|
||||
|
@ -1765,7 +1731,7 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
|
|||
// If no parts of the intersection had the property we looked up for, it never existed at all.
|
||||
if (parts.empty())
|
||||
{
|
||||
if (FFlag::LuauMissingUnionPropertyError && addErrors)
|
||||
if (addErrors)
|
||||
reportError(location, UnknownProperty{type, name});
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -1779,7 +1745,7 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
|
|||
return addType(IntersectionTypeVar{result});
|
||||
}
|
||||
|
||||
if (FFlag::LuauMissingUnionPropertyError && addErrors)
|
||||
if (addErrors)
|
||||
reportError(location, UnknownProperty{type, name});
|
||||
|
||||
return std::nullopt;
|
||||
|
@ -2062,8 +2028,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn
|
|||
case AstExprUnary::Len:
|
||||
tablify(operandType);
|
||||
|
||||
if (FFlag::LuauExtraNilRecovery)
|
||||
operandType = stripFromNilAndReport(operandType, expr.location);
|
||||
operandType = stripFromNilAndReport(operandType, expr.location);
|
||||
|
||||
if (get<ErrorTypeVar>(operandType))
|
||||
return {errorType};
|
||||
|
@ -2635,8 +2600,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
|
|||
|
||||
Name name = expr.index.value;
|
||||
|
||||
if (FFlag::LuauExtraNilRecovery)
|
||||
lhs = stripFromNilAndReport(lhs, expr.expr->location);
|
||||
lhs = stripFromNilAndReport(lhs, expr.expr->location);
|
||||
|
||||
if (TableTypeVar* lhsTable = getMutableTableType(lhs))
|
||||
{
|
||||
|
@ -2710,8 +2674,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
|
|||
TypeId exprType = checkExpr(scope, *expr.expr).type;
|
||||
tablify(exprType);
|
||||
|
||||
if (FFlag::LuauExtraNilRecovery)
|
||||
exprType = stripFromNilAndReport(exprType, expr.expr->location);
|
||||
exprType = stripFromNilAndReport(exprType, expr.expr->location);
|
||||
|
||||
TypeId indexType = checkExpr(scope, *expr.index).type;
|
||||
|
||||
|
@ -2738,10 +2701,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
|
|||
|
||||
if (!exprTable)
|
||||
{
|
||||
if (FFlag::LuauExtraNilRecovery)
|
||||
reportError(TypeError{expr.expr->location, NotATable{exprType}});
|
||||
else
|
||||
reportError(TypeError{expr.location, NotATable{exprType}});
|
||||
reportError(TypeError{expr.expr->location, NotATable{exprType}});
|
||||
return std::pair(errorType, nullptr);
|
||||
}
|
||||
|
||||
|
@ -2910,7 +2870,7 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
|
|||
|
||||
if (FFlag::LuauGenericFunctions)
|
||||
{
|
||||
std::tie(generics, genericPacks) = createGenericTypes(funScope, expr, expr.generics, expr.genericPacks);
|
||||
std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks);
|
||||
}
|
||||
|
||||
TypePackId retPack;
|
||||
|
@ -3016,9 +2976,6 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
|
|||
if (expectedArgsCurr != expectedArgsEnd)
|
||||
{
|
||||
argType = *expectedArgsCurr;
|
||||
|
||||
if (!FFlag::LuauInferFunctionArgsFix)
|
||||
++expectedArgsCurr;
|
||||
}
|
||||
else if (auto expectedArgsTail = expectedArgsCurr.tail())
|
||||
{
|
||||
|
@ -3034,7 +2991,7 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
|
|||
funScope->bindings[local] = {argType, local->location};
|
||||
argTypes.push_back(argType);
|
||||
|
||||
if (FFlag::LuauInferFunctionArgsFix && expectedArgsCurr != expectedArgsEnd)
|
||||
if (expectedArgsCurr != expectedArgsEnd)
|
||||
++expectedArgsCurr;
|
||||
}
|
||||
|
||||
|
@ -3402,8 +3359,7 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
|
|||
if (!FFlag::LuauRankNTypes)
|
||||
instantiate(scope, selfType, expr.func->location);
|
||||
|
||||
if (FFlag::LuauExtraNilRecovery)
|
||||
selfType = stripFromNilAndReport(selfType, expr.func->location);
|
||||
selfType = stripFromNilAndReport(selfType, expr.func->location);
|
||||
|
||||
if (std::optional<TypeId> propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true))
|
||||
{
|
||||
|
@ -3412,34 +3368,8 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
|
|||
}
|
||||
else
|
||||
{
|
||||
if (!FFlag::LuauMissingUnionPropertyError)
|
||||
reportError(indexExpr->indexLocation, UnknownProperty{selfType, indexExpr->index.value});
|
||||
|
||||
if (!FFlag::LuauExtraNilRecovery)
|
||||
{
|
||||
// Try to recover using a union without 'nil' options
|
||||
if (std::optional<TypeId> strippedUnion = tryStripUnionFromNil(selfType))
|
||||
{
|
||||
if (std::optional<TypeId> propTy = getIndexTypeFromType(scope, *strippedUnion, indexExpr->index.value, expr.location, false))
|
||||
{
|
||||
selfType = *strippedUnion;
|
||||
|
||||
functionType = *propTy;
|
||||
actualFunctionType = instantiate(scope, functionType, expr.func->location);
|
||||
}
|
||||
}
|
||||
|
||||
if (!actualFunctionType)
|
||||
{
|
||||
functionType = errorType;
|
||||
actualFunctionType = errorType;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
functionType = errorType;
|
||||
actualFunctionType = errorType;
|
||||
}
|
||||
functionType = errorType;
|
||||
actualFunctionType = errorType;
|
||||
}
|
||||
}
|
||||
else
|
||||
|
@ -3555,8 +3485,7 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
|
|||
TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult,
|
||||
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<OverloadErrorEntry>& errors)
|
||||
{
|
||||
if (FFlag::LuauExtraNilRecovery)
|
||||
fn = stripFromNilAndReport(fn, expr.func->location);
|
||||
fn = stripFromNilAndReport(fn, expr.func->location);
|
||||
|
||||
if (get<AnyTypeVar>(fn))
|
||||
{
|
||||
|
@ -4283,6 +4212,12 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
|
|||
if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty())
|
||||
return ty;
|
||||
|
||||
if (FFlag::LuauQuantifyInPlace2)
|
||||
{
|
||||
Luau::quantify(currentModule, ty, scope->level);
|
||||
return ty;
|
||||
}
|
||||
|
||||
quantification.level = scope->level;
|
||||
quantification.generics.clear();
|
||||
quantification.genericPacks.clear();
|
||||
|
@ -4491,12 +4426,12 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r)
|
|||
|
||||
Unifier TypeChecker::mkUnifier(const Location& location)
|
||||
{
|
||||
return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, iceHandler};
|
||||
return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState};
|
||||
}
|
||||
|
||||
Unifier TypeChecker::mkUnifier(const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location)
|
||||
{
|
||||
return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, iceHandler};
|
||||
return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState};
|
||||
}
|
||||
|
||||
TypeId TypeChecker::freshType(const ScopePtr& scope)
|
||||
|
@ -4753,7 +4688,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
|
|||
|
||||
if (FFlag::LuauGenericFunctions)
|
||||
{
|
||||
std::tie(generics, genericPacks) = createGenericTypes(funcScope, annotation, func->generics, func->genericPacks);
|
||||
std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks);
|
||||
}
|
||||
|
||||
// TODO: better error message CLI-39912
|
||||
|
@ -5041,10 +4976,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
|
|||
}
|
||||
|
||||
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGenericTypes(
|
||||
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
|
||||
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
|
||||
{
|
||||
LUAU_ASSERT(scope->parent);
|
||||
|
||||
const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level;
|
||||
|
||||
std::vector<TypeId> generics;
|
||||
for (const AstName& generic : genericNames)
|
||||
{
|
||||
|
@ -5063,12 +5000,12 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
|
|||
{
|
||||
TypeId& cached = scope->parent->typeAliasTypeParameters[n];
|
||||
if (!cached)
|
||||
cached = addType(GenericTypeVar{scope->level, n});
|
||||
cached = addType(GenericTypeVar{level, n});
|
||||
g = cached;
|
||||
}
|
||||
else
|
||||
{
|
||||
g = addType(Unifiable::Generic{scope->level, n});
|
||||
g = addType(Unifiable::Generic{level, n});
|
||||
}
|
||||
|
||||
generics.push_back(g);
|
||||
|
@ -5093,12 +5030,12 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
|
|||
{
|
||||
TypePackId& cached = scope->parent->typeAliasTypePackParameters[n];
|
||||
if (!cached)
|
||||
cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}});
|
||||
cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}});
|
||||
g = cached;
|
||||
}
|
||||
else
|
||||
{
|
||||
g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}});
|
||||
g = addTypePack(TypePackVar{Unifiable::Generic{level, n}});
|
||||
}
|
||||
|
||||
genericPacks.push_back(g);
|
||||
|
|
|
@ -22,6 +22,7 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
|
|||
LUAU_FASTFLAG(LuauRankNTypes)
|
||||
LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses)
|
||||
LUAU_FASTFLAG(LuauTypeAliasPacks)
|
||||
LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false)
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
|
@ -217,8 +218,7 @@ std::optional<TypeId> getMetatable(TypeId type)
|
|||
return mtType->metatable;
|
||||
else if (const ClassTypeVar* classType = get<ClassTypeVar>(type))
|
||||
return classType->metatable;
|
||||
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type);
|
||||
primitiveType && primitiveType->metatable)
|
||||
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type); primitiveType && primitiveType->metatable)
|
||||
{
|
||||
LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String);
|
||||
return primitiveType->metatable;
|
||||
|
@ -1490,4 +1490,86 @@ std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate)
|
|||
return {};
|
||||
}
|
||||
|
||||
static Tags* getTags(TypeId ty)
|
||||
{
|
||||
ty = follow(ty);
|
||||
|
||||
if (auto ftv = getMutable<FunctionTypeVar>(ty))
|
||||
return &ftv->tags;
|
||||
else if (auto ttv = getMutable<TableTypeVar>(ty))
|
||||
return &ttv->tags;
|
||||
else if (auto ctv = getMutable<ClassTypeVar>(ty))
|
||||
return &ctv->tags;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void attachTag(TypeId ty, const std::string& tagName)
|
||||
{
|
||||
if (!FFlag::LuauRefactorTagging)
|
||||
{
|
||||
if (auto ftv = getMutable<FunctionTypeVar>(ty))
|
||||
{
|
||||
ftv->tags.emplace_back(tagName);
|
||||
}
|
||||
else
|
||||
{
|
||||
LUAU_ASSERT(!"Got a non functional type");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (auto tags = getTags(ty))
|
||||
tags->push_back(tagName);
|
||||
else
|
||||
LUAU_ASSERT(!"This TypeId does not support tags");
|
||||
}
|
||||
}
|
||||
|
||||
void attachTag(Property& prop, const std::string& tagName)
|
||||
{
|
||||
LUAU_ASSERT(FFlag::LuauRefactorTagging);
|
||||
|
||||
prop.tags.push_back(tagName);
|
||||
}
|
||||
|
||||
// We would ideally not expose this because it could cause a footgun.
|
||||
// If the Base class has a tag and you ask if Derived has that tag, it would return false.
|
||||
// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it.
|
||||
bool hasTag(const Tags& tags, const std::string& tagName)
|
||||
{
|
||||
LUAU_ASSERT(FFlag::LuauRefactorTagging);
|
||||
return std::find(tags.begin(), tags.end(), tagName) != tags.end();
|
||||
}
|
||||
|
||||
bool hasTag(TypeId ty, const std::string& tagName)
|
||||
{
|
||||
ty = follow(ty);
|
||||
|
||||
// We special case classes because getTags only returns a pointer to one vector of tags.
|
||||
// But classes has multiple vector of tags, represented throughout the hierarchy.
|
||||
if (auto ctv = get<ClassTypeVar>(ty))
|
||||
{
|
||||
while (ctv)
|
||||
{
|
||||
if (hasTag(ctv->tags, tagName))
|
||||
return true;
|
||||
else if (!ctv->parent)
|
||||
return false;
|
||||
|
||||
ctv = get<ClassTypeVar>(*ctv->parent);
|
||||
LUAU_ASSERT(ctv);
|
||||
}
|
||||
}
|
||||
else if (auto tags = getTags(ty))
|
||||
return hasTag(*tags, tagName);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool hasTag(const Property& prop, const std::string& tagName)
|
||||
{
|
||||
return hasTag(prop.tags, tagName);
|
||||
}
|
||||
|
||||
} // namespace Luau
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include "Luau/TypePack.h"
|
||||
#include "Luau/TypeUtils.h"
|
||||
#include "Luau/TimeTrace.h"
|
||||
#include "Luau/VisitTypeVar.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
|
@ -22,9 +23,99 @@ LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false)
|
|||
LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false)
|
||||
LUAU_FASTFLAG(LuauShareTxnSeen);
|
||||
LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false)
|
||||
|
||||
namespace Luau
|
||||
{
|
||||
struct SkipCacheForType
|
||||
{
|
||||
SkipCacheForType(const DenseHashMap<TypeId, bool>& skipCacheForType)
|
||||
: skipCacheForType(skipCacheForType)
|
||||
{
|
||||
}
|
||||
|
||||
void cycle(TypeId) {}
|
||||
void cycle(TypePackId) {}
|
||||
|
||||
bool operator()(TypeId ty, const FreeTypeVar& ftv)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator()(TypeId ty, const BoundTypeVar& btv)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator()(TypeId ty, const GenericTypeVar& btv)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator()(TypeId ty, const TableTypeVar&)
|
||||
{
|
||||
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
|
||||
|
||||
if (ttv.boundTo)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ttv.state != TableState::Sealed)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
bool operator()(TypeId ty, const T& t)
|
||||
{
|
||||
const bool* prev = skipCacheForType.find(ty);
|
||||
|
||||
if (prev && *prev)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
bool operator()(TypePackId, const T&)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
bool operator()(TypePackId tp, const FreeTypePack& ftp)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator()(TypePackId tp, const BoundTypePack& ftp)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator()(TypePackId tp, const GenericTypePack& ftp)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
const DenseHashMap<TypeId, bool>& skipCacheForType;
|
||||
bool result = false;
|
||||
};
|
||||
|
||||
static std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
|
||||
{
|
||||
|
@ -39,7 +130,7 @@ static std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
|
|||
return *it;
|
||||
}
|
||||
|
||||
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler)
|
||||
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState)
|
||||
: types(types)
|
||||
, mode(mode)
|
||||
, globalScope(std::move(globalScope))
|
||||
|
@ -47,24 +138,39 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati
|
|||
, variance(variance)
|
||||
, counters(&countersData)
|
||||
, counters_DEPRECATED(std::make_shared<UnifierCounters>())
|
||||
, iceHandler(iceHandler)
|
||||
, sharedState(sharedState)
|
||||
{
|
||||
LUAU_ASSERT(iceHandler);
|
||||
LUAU_ASSERT(sharedState.iceHandler);
|
||||
}
|
||||
|
||||
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
|
||||
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
|
||||
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& ownedSeen, const Location& location,
|
||||
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
|
||||
: types(types)
|
||||
, mode(mode)
|
||||
, globalScope(std::move(globalScope))
|
||||
, log(seen)
|
||||
, log(ownedSeen)
|
||||
, location(location)
|
||||
, variance(variance)
|
||||
, counters(counters ? counters : &countersData)
|
||||
, counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared<UnifierCounters>())
|
||||
, iceHandler(iceHandler)
|
||||
, sharedState(sharedState)
|
||||
{
|
||||
LUAU_ASSERT(iceHandler);
|
||||
LUAU_ASSERT(sharedState.iceHandler);
|
||||
}
|
||||
|
||||
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector<std::pair<TypeId, TypeId>>* sharedSeen, const Location& location,
|
||||
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
|
||||
: types(types)
|
||||
, mode(mode)
|
||||
, globalScope(std::move(globalScope))
|
||||
, log(sharedSeen)
|
||||
, location(location)
|
||||
, variance(variance)
|
||||
, counters(counters ? counters : &countersData)
|
||||
, counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared<UnifierCounters>())
|
||||
, sharedState(sharedState)
|
||||
{
|
||||
LUAU_ASSERT(sharedState.iceHandler);
|
||||
}
|
||||
|
||||
void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection)
|
||||
|
@ -74,7 +180,7 @@ void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool i
|
|||
else
|
||||
counters_DEPRECATED->iterationCount = 0;
|
||||
|
||||
return tryUnify_(superTy, subTy, isFunctionCall, isIntersection);
|
||||
tryUnify_(superTy, subTy, isFunctionCall, isIntersection);
|
||||
}
|
||||
|
||||
void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection)
|
||||
|
@ -206,6 +312,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
|
|||
if (get<ErrorTypeVar>(subTy) || get<AnyTypeVar>(subTy))
|
||||
return tryUnifyWithAny(subTy, superTy);
|
||||
|
||||
bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection;
|
||||
auto& cache = sharedState.cachedUnify;
|
||||
|
||||
// What if the types are immutable and we proved their relation before
|
||||
if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy})))
|
||||
return;
|
||||
|
||||
// If we have seen this pair of types before, we are currently recursing into cyclic types.
|
||||
// Here, we assume that the types unify. If they do not, we will find out as we roll back
|
||||
// the stack.
|
||||
|
@ -257,6 +370,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
|
|||
|
||||
if (FFlag::LuauUnionHeuristic)
|
||||
{
|
||||
bool found = false;
|
||||
|
||||
const std::string* subName = getName(subTy);
|
||||
if (subName)
|
||||
{
|
||||
|
@ -264,6 +379,21 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
|
|||
{
|
||||
const std::string* optionName = getName(uv->options[i]);
|
||||
if (optionName && *optionName == *subName)
|
||||
{
|
||||
found = true;
|
||||
startIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!found && cacheEnabled)
|
||||
{
|
||||
for (size_t i = 0; i < uv->options.size(); ++i)
|
||||
{
|
||||
TypeId type = uv->options[i];
|
||||
|
||||
if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type})))
|
||||
{
|
||||
startIndex = i;
|
||||
break;
|
||||
|
@ -311,8 +441,25 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
|
|||
bool found = false;
|
||||
std::optional<TypeError> unificationTooComplex;
|
||||
|
||||
for (TypeId type : uv->parts)
|
||||
size_t startIndex = 0;
|
||||
|
||||
if (cacheEnabled)
|
||||
{
|
||||
for (size_t i = 0; i < uv->parts.size(); ++i)
|
||||
{
|
||||
TypeId type = uv->parts[i];
|
||||
|
||||
if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy})))
|
||||
{
|
||||
startIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < uv->parts.size(); ++i)
|
||||
{
|
||||
TypeId type = uv->parts[(i + startIndex) % uv->parts.size()];
|
||||
Unifier innerState = makeChildUnifier();
|
||||
innerState.tryUnify_(superTy, type, isFunctionCall);
|
||||
|
||||
|
@ -342,8 +489,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
|
|||
tryUnifyFunctions(superTy, subTy, isFunctionCall);
|
||||
|
||||
else if (get<TableTypeVar>(superTy) && get<TableTypeVar>(subTy))
|
||||
{
|
||||
tryUnifyTables(superTy, subTy, isIntersection);
|
||||
|
||||
if (cacheEnabled && errors.empty())
|
||||
cacheResult(superTy, subTy);
|
||||
}
|
||||
|
||||
// tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical.
|
||||
else if (get<MetatableTypeVar>(superTy))
|
||||
tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false);
|
||||
|
@ -364,6 +516,41 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
|
|||
log.popSeen(superTy, subTy);
|
||||
}
|
||||
|
||||
void Unifier::cacheResult(TypeId superTy, TypeId subTy)
|
||||
{
|
||||
LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults);
|
||||
|
||||
bool* superTyInfo = sharedState.skipCacheForType.find(superTy);
|
||||
|
||||
if (superTyInfo && *superTyInfo)
|
||||
return;
|
||||
|
||||
bool* subTyInfo = sharedState.skipCacheForType.find(subTy);
|
||||
|
||||
if (subTyInfo && *subTyInfo)
|
||||
return;
|
||||
|
||||
auto skipCacheFor = [this](TypeId ty) {
|
||||
SkipCacheForType visitor{sharedState.skipCacheForType};
|
||||
visitTypeVarOnce(ty, visitor, sharedState.seenAny);
|
||||
|
||||
sharedState.skipCacheForType[ty] = visitor.result;
|
||||
|
||||
return visitor.result;
|
||||
};
|
||||
|
||||
if (!superTyInfo && skipCacheFor(superTy))
|
||||
return;
|
||||
|
||||
if (!subTyInfo && skipCacheFor(subTy))
|
||||
return;
|
||||
|
||||
sharedState.cachedUnify.insert({superTy, subTy});
|
||||
|
||||
if (variance == Invariant)
|
||||
sharedState.cachedUnify.insert({subTy, superTy});
|
||||
}
|
||||
|
||||
struct WeirdIter
|
||||
{
|
||||
TypePackId packId;
|
||||
|
@ -459,7 +646,7 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall
|
|||
else
|
||||
counters_DEPRECATED->iterationCount = 0;
|
||||
|
||||
return tryUnify_(superTp, subTp, isFunctionCall);
|
||||
tryUnify_(superTp, subTp, isFunctionCall);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -797,6 +984,40 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
|
|||
std::vector<std::string> missingProperties;
|
||||
std::vector<std::string> extraProperties;
|
||||
|
||||
// Optimization: First test that the property sets are compatible without doing any recursive unification
|
||||
if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free)
|
||||
{
|
||||
for (const auto& [propName, superProp] : lt->props)
|
||||
{
|
||||
auto subIter = rt->props.find(propName);
|
||||
if (subIter == rt->props.end() && !isOptional(superProp.type) && !get<AnyTypeVar>(follow(superProp.type)))
|
||||
missingProperties.push_back(propName);
|
||||
}
|
||||
|
||||
if (!missingProperties.empty())
|
||||
{
|
||||
errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// And vice versa if we're invariant
|
||||
if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free)
|
||||
{
|
||||
for (const auto& [propName, subProp] : rt->props)
|
||||
{
|
||||
auto superIter = lt->props.find(propName);
|
||||
if (superIter == lt->props.end() && !isOptional(subProp.type) && !get<AnyTypeVar>(follow(subProp.type)))
|
||||
extraProperties.push_back(propName);
|
||||
}
|
||||
|
||||
if (!extraProperties.empty())
|
||||
{
|
||||
errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Reminder: left is the supertype, right is the subtype.
|
||||
// Width subtyping: any property in the supertype must be in the subtype,
|
||||
// and the types must agree.
|
||||
|
@ -833,9 +1054,10 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
|
|||
innerState.log.rollback();
|
||||
}
|
||||
else if (isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type)))
|
||||
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
|
||||
// TODO: should isOptional(anyType) be true?
|
||||
{}
|
||||
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
|
||||
// TODO: should isOptional(anyType) be true?
|
||||
{
|
||||
}
|
||||
else if (rt->state == TableState::Free)
|
||||
{
|
||||
log(rt);
|
||||
|
@ -878,11 +1100,13 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
|
|||
lt->props[name] = clone;
|
||||
}
|
||||
else if (variance == Covariant)
|
||||
{}
|
||||
{
|
||||
}
|
||||
else if (isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type)))
|
||||
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
|
||||
// TODO: should isOptional(anyType) be true?
|
||||
{}
|
||||
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
|
||||
// TODO: should isOptional(anyType) be true?
|
||||
{
|
||||
}
|
||||
else if (lt->state == TableState::Free)
|
||||
{
|
||||
log(lt);
|
||||
|
@ -980,10 +1204,10 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> see
|
|||
TableTypeVar* resultTtv = getMutable<TableTypeVar>(result);
|
||||
for (auto& [name, prop] : resultTtv->props)
|
||||
prop.type = deeplyOptional(prop.type, seen);
|
||||
return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});;
|
||||
return types->addType(UnionTypeVar{{singletonTypes.nilType, result}});
|
||||
}
|
||||
else
|
||||
return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }});
|
||||
return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}});
|
||||
}
|
||||
|
||||
void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
|
||||
|
@ -1697,10 +1921,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty)
|
|||
{
|
||||
std::vector<TypeId> queue = {ty};
|
||||
|
||||
tempSeenTy.clear();
|
||||
tempSeenTp.clear();
|
||||
if (FFlag::LuauCacheUnifyTableResults)
|
||||
{
|
||||
sharedState.tempSeenTy.clear();
|
||||
sharedState.tempSeenTp.clear();
|
||||
|
||||
Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP);
|
||||
Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP);
|
||||
}
|
||||
else
|
||||
{
|
||||
tempSeenTy_DEPRECATED.clear();
|
||||
tempSeenTp_DEPRECATED.clear();
|
||||
|
||||
Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1721,12 +1955,24 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty)
|
|||
{
|
||||
std::vector<TypeId> queue;
|
||||
|
||||
tempSeenTy.clear();
|
||||
tempSeenTp.clear();
|
||||
if (FFlag::LuauCacheUnifyTableResults)
|
||||
{
|
||||
sharedState.tempSeenTy.clear();
|
||||
sharedState.tempSeenTp.clear();
|
||||
|
||||
queueTypePack(queue, tempSeenTp, *this, ty, any);
|
||||
queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any);
|
||||
|
||||
Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any);
|
||||
Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any);
|
||||
}
|
||||
else
|
||||
{
|
||||
tempSeenTy_DEPRECATED.clear();
|
||||
tempSeenTp_DEPRECATED.clear();
|
||||
|
||||
queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any);
|
||||
|
||||
Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1775,10 +2021,20 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack)
|
|||
{
|
||||
std::unordered_set<TypeId> seen_DEPRECATED;
|
||||
|
||||
if (FFlag::LuauTypecheckOpts)
|
||||
tempSeenTy.clear();
|
||||
if (FFlag::LuauCacheUnifyTableResults)
|
||||
{
|
||||
if (FFlag::LuauTypecheckOpts)
|
||||
sharedState.tempSeenTy.clear();
|
||||
|
||||
return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack);
|
||||
return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (FFlag::LuauTypecheckOpts)
|
||||
tempSeenTy_DEPRECATED.clear();
|
||||
|
||||
return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack);
|
||||
}
|
||||
}
|
||||
|
||||
void Unifier::occursCheck(std::unordered_set<TypeId>& seen_DEPRECATED, DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack)
|
||||
|
@ -1851,10 +2107,20 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack)
|
|||
{
|
||||
std::unordered_set<TypePackId> seen_DEPRECATED;
|
||||
|
||||
if (FFlag::LuauTypecheckOpts)
|
||||
tempSeenTp.clear();
|
||||
if (FFlag::LuauCacheUnifyTableResults)
|
||||
{
|
||||
if (FFlag::LuauTypecheckOpts)
|
||||
sharedState.tempSeenTp.clear();
|
||||
|
||||
return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack);
|
||||
return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (FFlag::LuauTypecheckOpts)
|
||||
tempSeenTp_DEPRECATED.clear();
|
||||
|
||||
return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack);
|
||||
}
|
||||
}
|
||||
|
||||
void Unifier::occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack)
|
||||
|
@ -1922,7 +2188,10 @@ void Unifier::occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, Dense
|
|||
|
||||
Unifier Unifier::makeChildUnifier()
|
||||
{
|
||||
return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters};
|
||||
if (FFlag::LuauShareTxnSeen)
|
||||
return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters};
|
||||
else
|
||||
return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters};
|
||||
}
|
||||
|
||||
bool Unifier::isNonstrictMode() const
|
||||
|
@ -1940,12 +2209,12 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId
|
|||
|
||||
void Unifier::ice(const std::string& message, const Location& location)
|
||||
{
|
||||
iceHandler->ice(message, location);
|
||||
sharedState.iceHandler->ice(message, location);
|
||||
}
|
||||
|
||||
void Unifier::ice(const std::string& message)
|
||||
{
|
||||
iceHandler->ice(message);
|
||||
sharedState.iceHandler->ice(message);
|
||||
}
|
||||
|
||||
} // namespace Luau
|
||||
|
|
|
@ -194,20 +194,20 @@ LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeDa
|
|||
} // namespace Luau
|
||||
|
||||
// Regular scope
|
||||
#define LUAU_TIMETRACE_SCOPE(name, category) \
|
||||
#define LUAU_TIMETRACE_SCOPE(name, category) \
|
||||
static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \
|
||||
Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first)
|
||||
|
||||
// A scope without nested scopes that may be skipped if the time it took is less than the threshold
|
||||
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
|
||||
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
|
||||
static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \
|
||||
Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec)
|
||||
|
||||
// Extra key/value data can be added to regular scopes
|
||||
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
|
||||
do \
|
||||
{ \
|
||||
if (FFlag::DebugLuauTimeTracing) \
|
||||
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
|
||||
do \
|
||||
{ \
|
||||
if (FFlag::DebugLuauTimeTracing) \
|
||||
lttScopeStatic.second.eventArgument(name, value); \
|
||||
} while (false)
|
||||
|
||||
|
@ -216,8 +216,8 @@ LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeDa
|
|||
#define LUAU_TIMETRACE_SCOPE(name, category)
|
||||
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec)
|
||||
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
|
||||
do \
|
||||
{ \
|
||||
do \
|
||||
{ \
|
||||
} while (false)
|
||||
|
||||
#endif
|
||||
|
|
|
@ -77,7 +77,10 @@ struct GlobalContext
|
|||
// Ideally we would want all ThreadContext destructors to run
|
||||
// But in VS, not all thread_local object instances are destroyed
|
||||
for (ThreadContext* context : threads)
|
||||
context->flushEvents();
|
||||
{
|
||||
if (!context->events.empty())
|
||||
context->flushEvents();
|
||||
}
|
||||
|
||||
if (traceFile)
|
||||
fclose(traceFile);
|
||||
|
|
1
Makefile
1
Makefile
|
@ -1,4 +1,5 @@
|
|||
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
.SUFFIXES:
|
||||
MAKEFLAGS+=-r -j8
|
||||
COMMA=,
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ target_sources(Luau.Analysis PRIVATE
|
|||
Analysis/include/Luau/Module.h
|
||||
Analysis/include/Luau/ModuleResolver.h
|
||||
Analysis/include/Luau/Predicate.h
|
||||
Analysis/include/Luau/Quantify.h
|
||||
Analysis/include/Luau/RecursionCounter.h
|
||||
Analysis/include/Luau/RequireTracer.h
|
||||
Analysis/include/Luau/Scope.h
|
||||
|
@ -63,6 +64,7 @@ target_sources(Luau.Analysis PRIVATE
|
|||
Analysis/include/Luau/TypeVar.h
|
||||
Analysis/include/Luau/Unifiable.h
|
||||
Analysis/include/Luau/Unifier.h
|
||||
Analysis/include/Luau/UnifierSharedState.h
|
||||
Analysis/include/Luau/Variant.h
|
||||
Analysis/include/Luau/VisitTypeVar.h
|
||||
|
||||
|
@ -77,6 +79,7 @@ target_sources(Luau.Analysis PRIVATE
|
|||
Analysis/src/Linter.cpp
|
||||
Analysis/src/Module.cpp
|
||||
Analysis/src/Predicate.cpp
|
||||
Analysis/src/Quantify.cpp
|
||||
Analysis/src/RequireTracer.cpp
|
||||
Analysis/src/Scope.cpp
|
||||
Analysis/src/Substitution.cpp
|
||||
|
|
|
@ -13,8 +13,6 @@
|
|||
|
||||
#include <string.h>
|
||||
|
||||
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
|
||||
|
||||
const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n"
|
||||
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
|
||||
"$URL: www.lua.org $\n";
|
||||
|
@ -1153,7 +1151,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*))
|
|||
luaC_checkGC(L);
|
||||
luaC_checkthreadsleep(L);
|
||||
Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR);
|
||||
memcpy(u->data + sz, &dtor, sizeof(dtor));
|
||||
memcpy(&u->data + sz, &dtor, sizeof(dtor));
|
||||
setuvalue(L, L->top, u);
|
||||
api_incr_top(L);
|
||||
return u->data;
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
#include "lstate.h"
|
||||
#include "lvm.h"
|
||||
|
||||
LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
|
||||
|
||||
#define CO_RUN 0 /* running */
|
||||
#define CO_SUS 1 /* suspended */
|
||||
#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */
|
||||
|
@ -17,7 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
|
|||
|
||||
static const char* const statnames[] = {"running", "suspended", "normal", "dead"};
|
||||
|
||||
static int costatus(lua_State* L, lua_State* co)
|
||||
static int auxstatus(lua_State* L, lua_State* co)
|
||||
{
|
||||
if (co == L)
|
||||
return CO_RUN;
|
||||
|
@ -34,11 +32,11 @@ static int costatus(lua_State* L, lua_State* co)
|
|||
return CO_SUS; /* initial state */
|
||||
}
|
||||
|
||||
static int luaB_costatus(lua_State* L)
|
||||
static int costatus(lua_State* L)
|
||||
{
|
||||
lua_State* co = lua_tothread(L, 1);
|
||||
luaL_argexpected(L, co, 1, "thread");
|
||||
lua_pushstring(L, statnames[costatus(L, co)]);
|
||||
lua_pushstring(L, statnames[auxstatus(L, co)]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
@ -47,7 +45,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg)
|
|||
// error handling for edge cases
|
||||
if (co->status != LUA_YIELD)
|
||||
{
|
||||
int status = costatus(L, co);
|
||||
int status = auxstatus(L, co);
|
||||
if (status != CO_SUS)
|
||||
{
|
||||
lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]);
|
||||
|
@ -115,7 +113,7 @@ static int auxresumecont(lua_State* L, lua_State* co)
|
|||
}
|
||||
}
|
||||
|
||||
static int luaB_coresumefinish(lua_State* L, int r)
|
||||
static int coresumefinish(lua_State* L, int r)
|
||||
{
|
||||
if (r < 0)
|
||||
{
|
||||
|
@ -131,7 +129,7 @@ static int luaB_coresumefinish(lua_State* L, int r)
|
|||
}
|
||||
}
|
||||
|
||||
static int luaB_coresumey(lua_State* L)
|
||||
static int coresumey(lua_State* L)
|
||||
{
|
||||
lua_State* co = lua_tothread(L, 1);
|
||||
luaL_argexpected(L, co, 1, "thread");
|
||||
|
@ -141,10 +139,10 @@ static int luaB_coresumey(lua_State* L)
|
|||
if (r == CO_STATUS_BREAK)
|
||||
return interruptThread(L, co);
|
||||
|
||||
return luaB_coresumefinish(L, r);
|
||||
return coresumefinish(L, r);
|
||||
}
|
||||
|
||||
static int luaB_coresumecont(lua_State* L, int status)
|
||||
static int coresumecont(lua_State* L, int status)
|
||||
{
|
||||
lua_State* co = lua_tothread(L, 1);
|
||||
luaL_argexpected(L, co, 1, "thread");
|
||||
|
@ -155,10 +153,10 @@ static int luaB_coresumecont(lua_State* L, int status)
|
|||
|
||||
int r = auxresumecont(L, co);
|
||||
|
||||
return luaB_coresumefinish(L, r);
|
||||
return coresumefinish(L, r);
|
||||
}
|
||||
|
||||
static int luaB_auxwrapfinish(lua_State* L, int r)
|
||||
static int auxwrapfinish(lua_State* L, int r)
|
||||
{
|
||||
if (r < 0)
|
||||
{
|
||||
|
@ -173,7 +171,7 @@ static int luaB_auxwrapfinish(lua_State* L, int r)
|
|||
return r;
|
||||
}
|
||||
|
||||
static int luaB_auxwrapy(lua_State* L)
|
||||
static int auxwrapy(lua_State* L)
|
||||
{
|
||||
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
|
||||
int narg = cast_int(L->top - L->base);
|
||||
|
@ -182,10 +180,10 @@ static int luaB_auxwrapy(lua_State* L)
|
|||
if (r == CO_STATUS_BREAK)
|
||||
return interruptThread(L, co);
|
||||
|
||||
return luaB_auxwrapfinish(L, r);
|
||||
return auxwrapfinish(L, r);
|
||||
}
|
||||
|
||||
static int luaB_auxwrapcont(lua_State* L, int status)
|
||||
static int auxwrapcont(lua_State* L, int status)
|
||||
{
|
||||
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
|
||||
|
||||
|
@ -195,62 +193,52 @@ static int luaB_auxwrapcont(lua_State* L, int status)
|
|||
|
||||
int r = auxresumecont(L, co);
|
||||
|
||||
return luaB_auxwrapfinish(L, r);
|
||||
return auxwrapfinish(L, r);
|
||||
}
|
||||
|
||||
static int luaB_cocreate(lua_State* L)
|
||||
static int cocreate(lua_State* L)
|
||||
{
|
||||
luaL_checktype(L, 1, LUA_TFUNCTION);
|
||||
lua_State* NL = lua_newthread(L);
|
||||
|
||||
if (FFlag::LuauPreferXpush)
|
||||
{
|
||||
lua_xpush(L, NL, 1); // push function on top of NL
|
||||
}
|
||||
else
|
||||
{
|
||||
lua_pushvalue(L, 1); /* move function to top */
|
||||
lua_xmove(L, NL, 1); /* move function from L to NL */
|
||||
}
|
||||
|
||||
lua_xpush(L, NL, 1); // push function on top of NL
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int luaB_cowrap(lua_State* L)
|
||||
static int cowrap(lua_State* L)
|
||||
{
|
||||
luaB_cocreate(L);
|
||||
cocreate(L);
|
||||
|
||||
lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont);
|
||||
lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int luaB_yield(lua_State* L)
|
||||
static int coyield(lua_State* L)
|
||||
{
|
||||
int nres = cast_int(L->top - L->base);
|
||||
return lua_yield(L, nres);
|
||||
}
|
||||
|
||||
static int luaB_corunning(lua_State* L)
|
||||
static int corunning(lua_State* L)
|
||||
{
|
||||
if (lua_pushthread(L))
|
||||
lua_pushnil(L); /* main thread is not a coroutine */
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int luaB_yieldable(lua_State* L)
|
||||
static int coyieldable(lua_State* L)
|
||||
{
|
||||
lua_pushboolean(L, lua_isyieldable(L));
|
||||
return 1;
|
||||
}
|
||||
|
||||
static const luaL_Reg co_funcs[] = {
|
||||
{"create", luaB_cocreate},
|
||||
{"running", luaB_corunning},
|
||||
{"status", luaB_costatus},
|
||||
{"wrap", luaB_cowrap},
|
||||
{"yield", luaB_yield},
|
||||
{"isyieldable", luaB_yieldable},
|
||||
{"create", cocreate},
|
||||
{"running", corunning},
|
||||
{"status", costatus},
|
||||
{"wrap", cowrap},
|
||||
{"yield", coyield},
|
||||
{"isyieldable", coyieldable},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
|
@ -258,7 +246,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L)
|
|||
{
|
||||
luaL_register(L, LUA_COLIBNAME, co_funcs);
|
||||
|
||||
lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont);
|
||||
lua_pushcfunction(L, coresumey, "resume", 0, coresumecont);
|
||||
lua_setfield(L, -2, "resume");
|
||||
|
||||
return 1;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <string.h>
|
||||
|
||||
LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false)
|
||||
|
||||
/*
|
||||
** {======================================================
|
||||
|
@ -536,6 +537,12 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
|
|||
status = LUA_ERRERR;
|
||||
}
|
||||
|
||||
if (FFlag::LuauCcallRestoreFix)
|
||||
{
|
||||
// Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored.
|
||||
L->nCcalls = oldnCcalls;
|
||||
}
|
||||
|
||||
// an error occurred, check if we have a protected error callback
|
||||
if (L->global->cb.debugprotectederror)
|
||||
{
|
||||
|
@ -549,7 +556,10 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
|
|||
StkId oldtop = restorestack(L, old_top);
|
||||
luaF_close(L, oldtop); /* close eventual pending closures */
|
||||
seterrorobj(L, status, oldtop);
|
||||
L->nCcalls = oldnCcalls;
|
||||
if (!FFlag::LuauCcallRestoreFix)
|
||||
{
|
||||
L->nCcalls = oldnCcalls;
|
||||
}
|
||||
L->ci = restoreci(L, old_ci);
|
||||
L->base = L->ci->base;
|
||||
restore_stack_limit(L);
|
||||
|
|
325
VM/src/lgc.cpp
325
VM/src/lgc.cpp
|
@ -12,11 +12,9 @@
|
|||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
|
||||
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false)
|
||||
LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false)
|
||||
|
||||
LUAU_FASTFLAG(LuauArrayBoundary)
|
||||
|
||||
|
@ -66,13 +64,18 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds,
|
|||
g->gcstats.currcycle.marktime += seconds;
|
||||
|
||||
// atomic step had to be performed during the switch and it's tracked separately
|
||||
if (g->gcstate == GCSsweepstring)
|
||||
if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring)
|
||||
g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime;
|
||||
break;
|
||||
case GCSatomic:
|
||||
g->gcstats.currcycle.atomictime += seconds;
|
||||
break;
|
||||
case GCSsweepstring:
|
||||
case GCSsweep:
|
||||
g->gcstats.currcycle.sweeptime += seconds;
|
||||
break;
|
||||
default:
|
||||
LUAU_ASSERT(!"Unexpected GC state");
|
||||
}
|
||||
|
||||
if (assist)
|
||||
|
@ -183,33 +186,15 @@ static int traversetable(global_State* g, Table* h)
|
|||
if (h->metatable)
|
||||
markobject(g, cast_to(Table*, h->metatable));
|
||||
|
||||
if (FFlag::LuauShrinkWeakTables)
|
||||
/* is there a weak mode? */
|
||||
if (const char* modev = gettablemode(g, h))
|
||||
{
|
||||
/* is there a weak mode? */
|
||||
if (const char* modev = gettablemode(g, h))
|
||||
{
|
||||
weakkey = (strchr(modev, 'k') != NULL);
|
||||
weakvalue = (strchr(modev, 'v') != NULL);
|
||||
if (weakkey || weakvalue)
|
||||
{ /* is really weak? */
|
||||
h->gclist = g->weak; /* must be cleared after GC, ... */
|
||||
g->weak = obj2gco(h); /* ... so put in the appropriate list */
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const TValue* mode = gfasttm(g, h->metatable, TM_MODE);
|
||||
if (mode && ttisstring(mode))
|
||||
{ /* is there a weak mode? */
|
||||
const char* modev = svalue(mode);
|
||||
weakkey = (strchr(modev, 'k') != NULL);
|
||||
weakvalue = (strchr(modev, 'v') != NULL);
|
||||
if (weakkey || weakvalue)
|
||||
{ /* is really weak? */
|
||||
h->gclist = g->weak; /* must be cleared after GC, ... */
|
||||
g->weak = obj2gco(h); /* ... so put in the appropriate list */
|
||||
}
|
||||
weakkey = (strchr(modev, 'k') != NULL);
|
||||
weakvalue = (strchr(modev, 'v') != NULL);
|
||||
if (weakkey || weakvalue)
|
||||
{ /* is really weak? */
|
||||
h->gclist = g->weak; /* must be cleared after GC, ... */
|
||||
g->weak = obj2gco(h); /* ... so put in the appropriate list */
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -297,7 +282,7 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack)
|
|||
for (StkId o = l->stack; o < l->top; o++)
|
||||
markvalue(g, o);
|
||||
/* final traversal? */
|
||||
if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack))
|
||||
if (g->gcstate == GCSatomic || clearstack)
|
||||
{
|
||||
StkId stack_end = l->stack + l->stacksize;
|
||||
for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */
|
||||
|
@ -336,28 +321,16 @@ static size_t propagatemark(global_State* g)
|
|||
lua_State* th = gco2th(o);
|
||||
g->gray = th->gclist;
|
||||
|
||||
if (FFlag::LuauGcFullSkipInactiveThreads)
|
||||
LUAU_ASSERT(!luaC_threadsleeping(th));
|
||||
|
||||
// threads that are executing and the main thread are not deactivated
|
||||
bool active = luaC_threadactive(th) || th == th->global->mainthread;
|
||||
|
||||
if (!active && g->gcstate == GCSpropagate)
|
||||
{
|
||||
LUAU_ASSERT(!luaC_threadsleeping(th));
|
||||
traversestack(g, th, /* clearstack= */ true);
|
||||
|
||||
// threads that are executing and the main thread are not deactivated
|
||||
bool active = luaC_threadactive(th) || th == th->global->mainthread;
|
||||
|
||||
if (!active && g->gcstate == GCSpropagate)
|
||||
{
|
||||
traversestack(g, th, /* clearstack= */ true);
|
||||
|
||||
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
|
||||
}
|
||||
else
|
||||
{
|
||||
th->gclist = g->grayagain;
|
||||
g->grayagain = o;
|
||||
|
||||
black2gray(o);
|
||||
|
||||
traversestack(g, th, /* clearstack= */ false);
|
||||
}
|
||||
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -385,12 +358,14 @@ static size_t propagatemark(global_State* g)
|
|||
}
|
||||
}
|
||||
|
||||
static void propagateall(global_State* g)
|
||||
static size_t propagateall(global_State* g)
|
||||
{
|
||||
size_t work = 0;
|
||||
while (g->gray)
|
||||
{
|
||||
propagatemark(g);
|
||||
work += propagatemark(g);
|
||||
}
|
||||
return work;
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -415,11 +390,14 @@ static int isobjcleared(GCObject* o)
|
|||
/*
|
||||
** clear collected entries from weaktables
|
||||
*/
|
||||
static void cleartable(lua_State* L, GCObject* l)
|
||||
static size_t cleartable(lua_State* L, GCObject* l)
|
||||
{
|
||||
size_t work = 0;
|
||||
while (l)
|
||||
{
|
||||
Table* h = gco2h(l);
|
||||
work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h);
|
||||
|
||||
int i = h->sizearray;
|
||||
while (i--)
|
||||
{
|
||||
|
@ -433,50 +411,36 @@ static void cleartable(lua_State* L, GCObject* l)
|
|||
{
|
||||
LuaNode* n = gnode(h, i);
|
||||
|
||||
if (FFlag::LuauShrinkWeakTables)
|
||||
// non-empty entry?
|
||||
if (!ttisnil(gval(n)))
|
||||
{
|
||||
// non-empty entry?
|
||||
if (!ttisnil(gval(n)))
|
||||
{
|
||||
// can we clear key or value?
|
||||
if (iscleared(gkey(n)) || iscleared(gval(n)))
|
||||
{
|
||||
setnilvalue(gval(n)); /* remove value ... */
|
||||
removeentry(n); /* remove entry from table */
|
||||
}
|
||||
else
|
||||
{
|
||||
activevalues++;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!ttisnil(gval(n)) && /* non-empty entry? */
|
||||
(iscleared(gkey(n)) || iscleared(gval(n))))
|
||||
// can we clear key or value?
|
||||
if (iscleared(gkey(n)) || iscleared(gval(n)))
|
||||
{
|
||||
setnilvalue(gval(n)); /* remove value ... */
|
||||
removeentry(n); /* remove entry from table */
|
||||
}
|
||||
else
|
||||
{
|
||||
activevalues++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (FFlag::LuauShrinkWeakTables)
|
||||
if (const char* modev = gettablemode(L->global, h))
|
||||
{
|
||||
if (const char* modev = gettablemode(L->global, h))
|
||||
// are we allowed to shrink this weak table?
|
||||
if (strchr(modev, 's'))
|
||||
{
|
||||
// are we allowed to shrink this weak table?
|
||||
if (strchr(modev, 's'))
|
||||
{
|
||||
// shrink at 37.5% occupancy
|
||||
if (activevalues < sizenode(h) * 3 / 8)
|
||||
luaH_resizehash(L, h, activevalues);
|
||||
}
|
||||
// shrink at 37.5% occupancy
|
||||
if (activevalues < sizenode(h) * 3 / 8)
|
||||
luaH_resizehash(L, h, activevalues);
|
||||
}
|
||||
}
|
||||
|
||||
l = h->gclist;
|
||||
}
|
||||
return work;
|
||||
}
|
||||
|
||||
static void shrinkstack(lua_State* L)
|
||||
|
@ -655,37 +619,49 @@ static void markroot(lua_State* L)
|
|||
g->gcstate = GCSpropagate;
|
||||
}
|
||||
|
||||
static void remarkupvals(global_State* g)
|
||||
static size_t remarkupvals(global_State* g)
|
||||
{
|
||||
UpVal* uv;
|
||||
for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
|
||||
size_t work = 0;
|
||||
for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
|
||||
{
|
||||
work += sizeof(UpVal);
|
||||
LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv);
|
||||
if (isgray(obj2gco(uv)))
|
||||
markvalue(g, uv->v);
|
||||
}
|
||||
return work;
|
||||
}
|
||||
|
||||
static void atomic(lua_State* L)
|
||||
static size_t atomic(lua_State* L)
|
||||
{
|
||||
global_State* g = L->global;
|
||||
g->gcstate = GCSatomic;
|
||||
size_t work = 0;
|
||||
|
||||
if (FFlag::LuauSeparateAtomic)
|
||||
{
|
||||
LUAU_ASSERT(g->gcstate == GCSatomic);
|
||||
}
|
||||
else
|
||||
{
|
||||
g->gcstate = GCSatomic;
|
||||
}
|
||||
|
||||
/* remark occasional upvalues of (maybe) dead threads */
|
||||
remarkupvals(g);
|
||||
work += remarkupvals(g);
|
||||
/* traverse objects caught by write barrier and by 'remarkupvals' */
|
||||
propagateall(g);
|
||||
work += propagateall(g);
|
||||
/* remark weak tables */
|
||||
g->gray = g->weak;
|
||||
g->weak = NULL;
|
||||
LUAU_ASSERT(!iswhite(obj2gco(g->mainthread)));
|
||||
markobject(g, L); /* mark running thread */
|
||||
markmt(g); /* mark basic metatables (again) */
|
||||
propagateall(g);
|
||||
work += propagateall(g);
|
||||
/* remark gray again */
|
||||
g->gray = g->grayagain;
|
||||
g->grayagain = NULL;
|
||||
propagateall(g);
|
||||
cleartable(L, g->weak); /* remove collected objects from weak tables */
|
||||
work += propagateall(g);
|
||||
work += cleartable(L, g->weak); /* remove collected objects from weak tables */
|
||||
g->weak = NULL;
|
||||
/* flip current white */
|
||||
g->currentwhite = cast_byte(otherwhite(g));
|
||||
|
@ -693,7 +669,12 @@ static void atomic(lua_State* L)
|
|||
g->sweepgc = &g->rootgc;
|
||||
g->gcstate = GCSsweepstring;
|
||||
|
||||
GC_INTERRUPT(GCSatomic);
|
||||
if (!FFlag::LuauSeparateAtomic)
|
||||
{
|
||||
GC_INTERRUPT(GCSatomic);
|
||||
}
|
||||
|
||||
return work;
|
||||
}
|
||||
|
||||
static size_t singlestep(lua_State* L)
|
||||
|
@ -705,46 +686,24 @@ static size_t singlestep(lua_State* L)
|
|||
case GCSpause:
|
||||
{
|
||||
markroot(L); /* start a new collection */
|
||||
LUAU_ASSERT(g->gcstate == GCSpropagate);
|
||||
break;
|
||||
}
|
||||
case GCSpropagate:
|
||||
{
|
||||
if (FFlag::LuauRescanGrayAgain)
|
||||
if (g->gray)
|
||||
{
|
||||
if (g->gray)
|
||||
{
|
||||
g->gcstats.currcycle.markitems++;
|
||||
g->gcstats.currcycle.markitems++;
|
||||
|
||||
cost = propagatemark(g);
|
||||
}
|
||||
else
|
||||
{
|
||||
// perform one iteration over 'gray again' list
|
||||
g->gray = g->grayagain;
|
||||
g->grayagain = NULL;
|
||||
|
||||
g->gcstate = GCSpropagateagain;
|
||||
}
|
||||
cost = propagatemark(g);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (g->gray)
|
||||
{
|
||||
g->gcstats.currcycle.markitems++;
|
||||
// perform one iteration over 'gray again' list
|
||||
g->gray = g->grayagain;
|
||||
g->grayagain = NULL;
|
||||
|
||||
cost = propagatemark(g);
|
||||
}
|
||||
else /* no more `gray' objects */
|
||||
{
|
||||
double starttimestamp = lua_clock();
|
||||
|
||||
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
|
||||
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
|
||||
|
||||
atomic(L); /* finish mark phase */
|
||||
|
||||
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
|
||||
}
|
||||
g->gcstate = GCSpropagateagain;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -758,17 +717,34 @@ static size_t singlestep(lua_State* L)
|
|||
}
|
||||
else /* no more `gray' objects */
|
||||
{
|
||||
double starttimestamp = lua_clock();
|
||||
if (FFlag::LuauSeparateAtomic)
|
||||
{
|
||||
g->gcstate = GCSatomic;
|
||||
}
|
||||
else
|
||||
{
|
||||
double starttimestamp = lua_clock();
|
||||
|
||||
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
|
||||
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
|
||||
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
|
||||
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
|
||||
|
||||
atomic(L); /* finish mark phase */
|
||||
atomic(L); /* finish mark phase */
|
||||
LUAU_ASSERT(g->gcstate == GCSsweepstring);
|
||||
|
||||
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
|
||||
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case GCSatomic:
|
||||
{
|
||||
g->gcstats.currcycle.atomicstarttimestamp = lua_clock();
|
||||
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
|
||||
|
||||
cost = atomic(L); /* finish mark phase */
|
||||
LUAU_ASSERT(g->gcstate == GCSsweepstring);
|
||||
break;
|
||||
}
|
||||
case GCSsweepstring:
|
||||
{
|
||||
size_t traversedcount = 0;
|
||||
|
@ -806,7 +782,7 @@ static size_t singlestep(lua_State* L)
|
|||
break;
|
||||
}
|
||||
default:
|
||||
LUAU_ASSERT(0);
|
||||
LUAU_ASSERT(!"Unexpected GC state");
|
||||
}
|
||||
|
||||
return cost;
|
||||
|
@ -821,48 +797,25 @@ static size_t gcstep(lua_State* L, size_t limit)
|
|||
case GCSpause:
|
||||
{
|
||||
markroot(L); /* start a new collection */
|
||||
LUAU_ASSERT(g->gcstate == GCSpropagate);
|
||||
break;
|
||||
}
|
||||
case GCSpropagate:
|
||||
{
|
||||
if (FFlag::LuauRescanGrayAgain)
|
||||
while (g->gray && cost < limit)
|
||||
{
|
||||
while (g->gray && cost < limit)
|
||||
{
|
||||
g->gcstats.currcycle.markitems++;
|
||||
g->gcstats.currcycle.markitems++;
|
||||
|
||||
cost += propagatemark(g);
|
||||
}
|
||||
|
||||
if (!g->gray)
|
||||
{
|
||||
// perform one iteration over 'gray again' list
|
||||
g->gray = g->grayagain;
|
||||
g->grayagain = NULL;
|
||||
|
||||
g->gcstate = GCSpropagateagain;
|
||||
}
|
||||
cost += propagatemark(g);
|
||||
}
|
||||
else
|
||||
|
||||
if (!g->gray)
|
||||
{
|
||||
while (g->gray && cost < limit)
|
||||
{
|
||||
g->gcstats.currcycle.markitems++;
|
||||
// perform one iteration over 'gray again' list
|
||||
g->gray = g->grayagain;
|
||||
g->grayagain = NULL;
|
||||
|
||||
cost += propagatemark(g);
|
||||
}
|
||||
|
||||
if (!g->gray) /* no more `gray' objects */
|
||||
{
|
||||
double starttimestamp = lua_clock();
|
||||
|
||||
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
|
||||
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
|
||||
|
||||
atomic(L); /* finish mark phase */
|
||||
|
||||
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
|
||||
}
|
||||
g->gcstate = GCSpropagateagain;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -877,17 +830,34 @@ static size_t gcstep(lua_State* L, size_t limit)
|
|||
|
||||
if (!g->gray) /* no more `gray' objects */
|
||||
{
|
||||
double starttimestamp = lua_clock();
|
||||
if (FFlag::LuauSeparateAtomic)
|
||||
{
|
||||
g->gcstate = GCSatomic;
|
||||
}
|
||||
else
|
||||
{
|
||||
double starttimestamp = lua_clock();
|
||||
|
||||
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
|
||||
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
|
||||
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
|
||||
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
|
||||
|
||||
atomic(L); /* finish mark phase */
|
||||
atomic(L); /* finish mark phase */
|
||||
LUAU_ASSERT(g->gcstate == GCSsweepstring);
|
||||
|
||||
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
|
||||
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case GCSatomic:
|
||||
{
|
||||
g->gcstats.currcycle.atomicstarttimestamp = lua_clock();
|
||||
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
|
||||
|
||||
cost = atomic(L); /* finish mark phase */
|
||||
LUAU_ASSERT(g->gcstate == GCSsweepstring);
|
||||
break;
|
||||
}
|
||||
case GCSsweepstring:
|
||||
{
|
||||
while (g->sweepstrgc < g->strt.size && cost < limit)
|
||||
|
@ -934,7 +904,7 @@ static size_t gcstep(lua_State* L, size_t limit)
|
|||
break;
|
||||
}
|
||||
default:
|
||||
LUAU_ASSERT(0);
|
||||
LUAU_ASSERT(!"Unexpected GC state");
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
@ -1084,7 +1054,7 @@ void luaC_fullgc(lua_State* L)
|
|||
if (g->gcstate == GCSpause)
|
||||
startGcCycleStats(g);
|
||||
|
||||
if (g->gcstate <= GCSpropagateagain)
|
||||
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
|
||||
{
|
||||
/* reset sweep marks to sweep all elements (returning them to white) */
|
||||
g->sweepstrgc = 0;
|
||||
|
@ -1095,7 +1065,7 @@ void luaC_fullgc(lua_State* L)
|
|||
g->weak = NULL;
|
||||
g->gcstate = GCSsweepstring;
|
||||
}
|
||||
LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain);
|
||||
LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep);
|
||||
/* finish any pending sweep phase */
|
||||
while (g->gcstate != GCSpause)
|
||||
{
|
||||
|
@ -1143,14 +1113,11 @@ void luaC_fullgc(lua_State* L)
|
|||
|
||||
void luaC_barrierupval(lua_State* L, GCObject* v)
|
||||
{
|
||||
if (FFlag::LuauGcFullSkipInactiveThreads)
|
||||
{
|
||||
global_State* g = L->global;
|
||||
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
|
||||
global_State* g = L->global;
|
||||
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
|
||||
|
||||
if (keepinvariant(g))
|
||||
reallymarkobject(g, v);
|
||||
}
|
||||
if (keepinvariant(g))
|
||||
reallymarkobject(g, v);
|
||||
}
|
||||
|
||||
void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v)
|
||||
|
@ -1778,7 +1745,7 @@ int64_t luaC_allocationrate(lua_State* L)
|
|||
global_State* g = L->global;
|
||||
const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms
|
||||
|
||||
if (g->gcstate <= GCSpropagateagain)
|
||||
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
|
||||
{
|
||||
double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp;
|
||||
|
||||
|
|
|
@ -6,8 +6,6 @@
|
|||
#include "lobject.h"
|
||||
#include "lstate.h"
|
||||
|
||||
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
|
||||
|
||||
/*
|
||||
** Possible states of the Garbage Collector
|
||||
*/
|
||||
|
@ -25,7 +23,7 @@ LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
|
|||
** still-black objects. The invariant is restored when sweep ends and
|
||||
** all objects are white again.
|
||||
*/
|
||||
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain)
|
||||
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic)
|
||||
|
||||
/*
|
||||
** some useful bit tricks
|
||||
|
@ -147,4 +145,4 @@ LUAI_FUNC void luaC_validate(lua_State* L);
|
|||
LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat));
|
||||
LUAI_FUNC int64_t luaC_allocationrate(lua_State* L);
|
||||
LUAI_FUNC void luaC_wakethread(lua_State* L);
|
||||
LUAI_FUNC const char* luaC_statename(int state);
|
||||
LUAI_FUNC const char* luaC_statename(int state);
|
||||
|
|
|
@ -199,7 +199,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass)
|
|||
|
||||
if (page->freeNext >= 0)
|
||||
{
|
||||
block = page->data + page->freeNext;
|
||||
block = &page->data + page->freeNext;
|
||||
ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize);
|
||||
|
||||
page->freeNext -= page->blockSize;
|
||||
|
|
|
@ -226,7 +226,7 @@ void luaS_freeudata(lua_State* L, Udata* u)
|
|||
|
||||
void (*dtor)(void*) = nullptr;
|
||||
if (u->tag == UTAG_IDTOR)
|
||||
memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor));
|
||||
memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor));
|
||||
else if (u->tag)
|
||||
dtor = L->global->udatagc[u->tag];
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
#include <string.h>
|
||||
|
||||
// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens
|
||||
template <typename T>
|
||||
template<typename T>
|
||||
struct TempBuffer
|
||||
{
|
||||
lua_State* L;
|
||||
|
@ -346,6 +346,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
|
|||
uint32_t mainid = readVarInt(data, size, offset);
|
||||
Proto* main = protos[mainid];
|
||||
|
||||
luaC_checkthreadsleep(L);
|
||||
|
||||
Closure* cl = luaF_newLclosure(L, 0, envt, main);
|
||||
setclvalue(L, L->top, cl);
|
||||
incr_top(L);
|
||||
|
|
849
bench/tests/chess.lua
Normal file
849
bench/tests/chess.lua
Normal file
|
@ -0,0 +1,849 @@
|
|||
|
||||
local bench = script and require(script.Parent.bench_support) or require("bench_support")
|
||||
|
||||
local RANKS = "12345678"
|
||||
local FILES = "abcdefgh"
|
||||
local PieceSymbols = "PpRrNnBbQqKk"
|
||||
local UnicodePieces = {"♙", "♟", "♖", "♜", "♘", "♞", "♗", "♝", "♕", "♛", "♔", "♚"}
|
||||
local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
||||
|
||||
--
|
||||
-- Lua 5.2 Compat
|
||||
--
|
||||
|
||||
if not table.create then
|
||||
function table.create(n, v)
|
||||
local result = {}
|
||||
for i=1,n do result[i] = v end
|
||||
return result
|
||||
end
|
||||
end
|
||||
|
||||
if not table.move then
|
||||
function table.move(a, from, to, start, target)
|
||||
local dx = start - from
|
||||
for i=from,to do
|
||||
target[i+dx] = a[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
--
|
||||
-- Utils
|
||||
--
|
||||
|
||||
local function square(s)
|
||||
return RANKS:find(s:sub(2,2)) * 8 + FILES:find(s:sub(1,1)) - 9
|
||||
end
|
||||
|
||||
local function squareName(n)
|
||||
local file = n % 8
|
||||
local rank = (n-file)/8
|
||||
return FILES:sub(file+1,file+1) .. RANKS:sub(rank+1,rank+1)
|
||||
end
|
||||
|
||||
local function moveName(v )
|
||||
local from = bit32.extract(v, 6, 6)
|
||||
local to = bit32.extract(v, 0, 6)
|
||||
local piece = bit32.extract(v, 20, 4)
|
||||
local captured = bit32.extract(v, 25, 4)
|
||||
|
||||
local move = PieceSymbols:sub(piece,piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to)
|
||||
|
||||
if bit32.extract(v,14) == 1 then
|
||||
if to > from then
|
||||
return "O-O"
|
||||
else
|
||||
return "O-O-O"
|
||||
end
|
||||
end
|
||||
|
||||
local promote = bit32.extract(v,15,4)
|
||||
if promote ~= 0 then
|
||||
move = move .. "=" .. PieceSymbols:sub(promote,promote)
|
||||
end
|
||||
return move
|
||||
end
|
||||
|
||||
local function ucimove(m)
|
||||
local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6))
|
||||
local promote = bit32.extract(m,15,4)
|
||||
if promote > 0 then
|
||||
mm = mm .. PieceSymbols:sub(promote,promote):lower()
|
||||
end
|
||||
return mm
|
||||
end
|
||||
|
||||
local _utils = {squareName, moveName}
|
||||
|
||||
--
|
||||
-- Bitboards
|
||||
--
|
||||
|
||||
local Bitboard = {}
|
||||
|
||||
|
||||
function Bitboard:toString()
|
||||
local out = {}
|
||||
|
||||
local src = self.h
|
||||
for x=7,0,-1 do
|
||||
table.insert(out, RANKS:sub(x+1,x+1))
|
||||
table.insert(out, " ")
|
||||
local bit = bit32.lshift(1,(x%4) * 8)
|
||||
for x=0,7 do
|
||||
if bit32.band(src, bit) ~= 0 then
|
||||
table.insert(out, "x ")
|
||||
else
|
||||
table.insert(out, "- ")
|
||||
end
|
||||
bit = bit32.lshift(bit, 1)
|
||||
end
|
||||
if x == 4 then
|
||||
src = self.l
|
||||
end
|
||||
table.insert(out, "\n")
|
||||
end
|
||||
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
|
||||
table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h)
|
||||
return table.concat(out)
|
||||
end
|
||||
|
||||
|
||||
function Bitboard.from(l ,h )
|
||||
return setmetatable({l=l, h=h}, Bitboard)
|
||||
end
|
||||
|
||||
Bitboard.zero = Bitboard.from(0,0)
|
||||
Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF)
|
||||
|
||||
local Rank1 = Bitboard.from(0x000000FF, 0)
|
||||
local Rank3 = Bitboard.from(0x00FF0000, 0)
|
||||
local Rank6 = Bitboard.from(0, 0x0000FF00)
|
||||
local Rank8 = Bitboard.from(0, 0xFF000000)
|
||||
local FileA = Bitboard.from(0x01010101, 0x01010101)
|
||||
local FileB = Bitboard.from(0x02020202, 0x02020202)
|
||||
local FileC = Bitboard.from(0x04040404, 0x04040404)
|
||||
local FileD = Bitboard.from(0x08080808, 0x08080808)
|
||||
local FileE = Bitboard.from(0x10101010, 0x10101010)
|
||||
local FileF = Bitboard.from(0x20202020, 0x20202020)
|
||||
local FileG = Bitboard.from(0x40404040, 0x40404040)
|
||||
local FileH = Bitboard.from(0x80808080, 0x80808080)
|
||||
|
||||
local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH}
|
||||
|
||||
-- These masks are filled out below for all files
|
||||
local RightMasks = {FileH}
|
||||
local LeftMasks = {FileA}
|
||||
|
||||
|
||||
|
||||
local function popcnt32(i)
|
||||
i = i - bit32.band(bit32.rshift(i,1), 0x55555555)
|
||||
i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i,2), 0x33333333)
|
||||
return bit32.rshift(bit32.band(i + bit32.rshift(i,4), 0x0F0F0F0F) * 0x01010101, 24)
|
||||
end
|
||||
|
||||
function Bitboard:up()
|
||||
return self:lshift(8)
|
||||
end
|
||||
|
||||
function Bitboard:down()
|
||||
return self:rshift(8)
|
||||
end
|
||||
|
||||
function Bitboard:right()
|
||||
return self:band(FileH:inverse()):lshift(1)
|
||||
end
|
||||
|
||||
function Bitboard:left()
|
||||
return self:band(FileA:inverse()):rshift(1)
|
||||
end
|
||||
|
||||
function Bitboard:move(x,y)
|
||||
local out = self
|
||||
|
||||
if x < 0 then out = out:bandnot(RightMasks[-x]):lshift(-x) end
|
||||
if x > 0 then out = out:bandnot(LeftMasks[x]):rshift(x) end
|
||||
|
||||
if y < 0 then out = out:rshift(-8 * y) end
|
||||
if y > 0 then out = out:lshift(8 * y) end
|
||||
return out
|
||||
end
|
||||
|
||||
|
||||
function Bitboard:popcnt()
|
||||
return popcnt32(self.l) + popcnt32(self.h)
|
||||
end
|
||||
|
||||
function Bitboard:band(other )
|
||||
return Bitboard.from(bit32.band(self.l,other.l), bit32.band(self.h, other.h))
|
||||
end
|
||||
|
||||
function Bitboard:bandnot(other )
|
||||
return Bitboard.from(bit32.band(self.l,bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h)))
|
||||
end
|
||||
|
||||
function Bitboard:bandempty(other )
|
||||
return bit32.band(self.l,other.l) == 0 and bit32.band(self.h, other.h) == 0
|
||||
end
|
||||
|
||||
function Bitboard:bor(other )
|
||||
return Bitboard.from(bit32.bor(self.l,other.l), bit32.bor(self.h, other.h))
|
||||
end
|
||||
|
||||
function Bitboard:bxor(other )
|
||||
return Bitboard.from(bit32.bxor(self.l,other.l), bit32.bxor(self.h, other.h))
|
||||
end
|
||||
|
||||
function Bitboard:inverse()
|
||||
return Bitboard.from(bit32.bxor(self.l,0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF))
|
||||
end
|
||||
|
||||
function Bitboard:empty()
|
||||
return self.h == 0 and self.l == 0
|
||||
end
|
||||
|
||||
function Bitboard:ctz()
|
||||
local target = self.l
|
||||
local offset = 0
|
||||
local result = 0
|
||||
|
||||
if target == 0 then
|
||||
target = self.h
|
||||
result = 32
|
||||
end
|
||||
|
||||
if target == 0 then
|
||||
return 64
|
||||
end
|
||||
|
||||
while bit32.extract(target, offset) == 0 do
|
||||
offset = offset + 1
|
||||
end
|
||||
|
||||
return result + offset
|
||||
end
|
||||
|
||||
function Bitboard:ctzafter(start)
|
||||
start = start + 1
|
||||
if start < 32 then
|
||||
for i=start,31 do
|
||||
if bit32.extract(self.l, i) == 1 then return i end
|
||||
end
|
||||
end
|
||||
for i=math.max(32,start),63 do
|
||||
if bit32.extract(self.h, i-32) == 1 then return i end
|
||||
end
|
||||
return 64
|
||||
end
|
||||
|
||||
|
||||
function Bitboard:lshift(amt)
|
||||
assert(amt >= 0)
|
||||
if amt == 0 then return self end
|
||||
|
||||
if amt > 31 then
|
||||
return Bitboard.from(0, bit32.lshift(self.l, amt-31))
|
||||
end
|
||||
|
||||
local l = bit32.lshift(self.l, amt)
|
||||
local h = bit32.bor(
|
||||
bit32.lshift(self.h, amt),
|
||||
bit32.extract(self.l, 32-amt, amt)
|
||||
)
|
||||
return Bitboard.from(l, h)
|
||||
end
|
||||
|
||||
function Bitboard:rshift(amt)
|
||||
assert(amt >= 0)
|
||||
if amt == 0 then return self end
|
||||
local h = bit32.rshift(self.h, amt)
|
||||
local l = bit32.bor(
|
||||
bit32.rshift(self.l, amt),
|
||||
bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt)
|
||||
)
|
||||
return Bitboard.from(l, h)
|
||||
end
|
||||
|
||||
function Bitboard:index(i)
|
||||
if i > 31 then
|
||||
return bit32.extract(self.h, i - 32)
|
||||
else
|
||||
return bit32.extract(self.l, i)
|
||||
end
|
||||
end
|
||||
|
||||
function Bitboard:set(i , v)
|
||||
if i > 31 then
|
||||
return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32))
|
||||
else
|
||||
return Bitboard.from(bit32.replace(self.l, v, i), self.h)
|
||||
end
|
||||
end
|
||||
|
||||
function Bitboard:isolate(i)
|
||||
return self:band(Bitboard.some(i))
|
||||
end
|
||||
|
||||
function Bitboard.some(idx )
|
||||
return Bitboard.zero:set(idx, 1)
|
||||
end
|
||||
|
||||
Bitboard.__index = Bitboard
|
||||
Bitboard.__tostring = Bitboard.toString
|
||||
|
||||
for i=2,8 do
|
||||
RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH)
|
||||
LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA)
|
||||
end
|
||||
--
|
||||
-- Board
|
||||
--
|
||||
|
||||
local Board = {}
|
||||
|
||||
|
||||
function Board.new()
|
||||
local boards = table.create(12, Bitboard.zero)
|
||||
boards.ocupied = Bitboard.zero
|
||||
boards.white = Bitboard.zero
|
||||
boards.black = Bitboard.zero
|
||||
boards.unocupied = Bitboard.full
|
||||
boards.ep = Bitboard.zero
|
||||
boards.castle = Bitboard.zero
|
||||
boards.toMove = 1
|
||||
boards.hm = 0
|
||||
boards.moves = 0
|
||||
boards.material = 0
|
||||
|
||||
return setmetatable(boards, Board)
|
||||
end
|
||||
|
||||
function Board.fromFen(fen )
|
||||
local b = Board.new()
|
||||
local i = 0
|
||||
local rank = 7
|
||||
local file = 0
|
||||
|
||||
while true do
|
||||
i = i + 1
|
||||
local p = fen:sub(i,i)
|
||||
if p == '/' then
|
||||
rank = rank - 1
|
||||
file = 0
|
||||
elseif tonumber(p) ~= nil then
|
||||
file = file + tonumber(p)
|
||||
else
|
||||
local pidx = PieceSymbols:find(p)
|
||||
if pidx == nil then break end
|
||||
b[pidx] = b[pidx]:set(rank*8+file, 1)
|
||||
file = file + 1
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i)
|
||||
if move == nil then print(fen:sub(i)) end
|
||||
b.toMove = move == 'w' and 1 or 2
|
||||
|
||||
if ep ~= "-" then
|
||||
b.ep = Bitboard.some(square(ep))
|
||||
end
|
||||
|
||||
if castle ~= "-" then
|
||||
local oo = Bitboard.zero
|
||||
if castle:find("K") then
|
||||
oo = oo:set(7, 1)
|
||||
end
|
||||
if castle:find("Q") then
|
||||
oo = oo:set(0, 1)
|
||||
end
|
||||
if castle:find("k") then
|
||||
oo = oo:set(63, 1)
|
||||
end
|
||||
if castle:find("q") then
|
||||
oo = oo:set(56, 1)
|
||||
end
|
||||
|
||||
b.castle = oo
|
||||
end
|
||||
|
||||
b.hm = hm
|
||||
b.moves = m
|
||||
|
||||
b:updateCache()
|
||||
return b
|
||||
|
||||
end
|
||||
|
||||
function Board:index(idx )
|
||||
if self.white:index(idx) == 1 then
|
||||
for p=1,12,2 do
|
||||
if self[p]:index(idx) == 1 then
|
||||
return p
|
||||
end
|
||||
end
|
||||
else
|
||||
for p=2,12,2 do
|
||||
if self[p]:index(idx) == 1 then
|
||||
return p
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return 0
|
||||
end
|
||||
|
||||
function Board:updateCache()
|
||||
for i=1,11,2 do
|
||||
self.white = self.white:bor(self[i])
|
||||
self.black = self.black:bor(self[i+1])
|
||||
end
|
||||
|
||||
self.ocupied = self.black:bor(self.white)
|
||||
self.unocupied = self.ocupied:inverse()
|
||||
self.material =
|
||||
100*self[1]:popcnt() - 100*self[2]:popcnt() +
|
||||
500*self[3]:popcnt() - 500*self[4]:popcnt() +
|
||||
300*self[5]:popcnt() - 300*self[6]:popcnt() +
|
||||
300*self[7]:popcnt() - 300*self[8]:popcnt() +
|
||||
900*self[9]:popcnt() - 900*self[10]:popcnt()
|
||||
|
||||
end
|
||||
|
||||
function Board:fen()
|
||||
local out = {}
|
||||
local s = 0
|
||||
local idx = 56
|
||||
for i=0,63 do
|
||||
if i % 8 == 0 and i > 0 then
|
||||
idx = idx - 16
|
||||
if s > 0 then
|
||||
table.insert(out, '' .. s)
|
||||
s = 0
|
||||
end
|
||||
table.insert(out, '/')
|
||||
end
|
||||
local p = self:index(idx)
|
||||
if p == 0 then
|
||||
s = s + 1
|
||||
else
|
||||
if s > 0 then
|
||||
table.insert(out, '' .. s)
|
||||
s = 0
|
||||
end
|
||||
table.insert(out, PieceSymbols:sub(p,p))
|
||||
end
|
||||
|
||||
idx = idx + 1
|
||||
end
|
||||
if s > 0 then
|
||||
table.insert(out, '' .. s)
|
||||
end
|
||||
|
||||
table.insert(out, self.toMove == 1 and ' w ' or ' b ')
|
||||
if self.castle:empty() then
|
||||
table.insert(out, '-')
|
||||
else
|
||||
if self.castle:index(7) == 1 then table.insert(out, 'K') end
|
||||
if self.castle:index(0) == 1 then table.insert(out, 'Q') end
|
||||
if self.castle:index(63) == 1 then table.insert(out, 'k') end
|
||||
if self.castle:index(56) == 1 then table.insert(out, 'q') end
|
||||
end
|
||||
|
||||
table.insert(out, ' ')
|
||||
if self.ep:empty() then
|
||||
table.insert(out, '-')
|
||||
else
|
||||
table.insert(out, squareName(self.ep:ctz()))
|
||||
end
|
||||
|
||||
table.insert(out, ' ' .. self.hm)
|
||||
table.insert(out, ' ' .. self.moves)
|
||||
|
||||
return table.concat(out)
|
||||
end
|
||||
|
||||
function Board:pmoves(idx)
|
||||
return self:generate(idx)
|
||||
end
|
||||
|
||||
function Board:pcaptures(idx)
|
||||
return self:generate(idx):band(self.ocupied)
|
||||
end
|
||||
|
||||
local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}}
|
||||
local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}}
|
||||
local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}}
|
||||
local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}}
|
||||
|
||||
function Board:generate(idx)
|
||||
local piece = self:index(idx)
|
||||
local r = Bitboard.some(idx)
|
||||
local out = Bitboard.zero
|
||||
local type = bit32.rshift(piece - 1, 1)
|
||||
local cancapture = piece % 2 == 1 and self.black or self.white
|
||||
|
||||
if piece == 0 then return Bitboard.zero end
|
||||
|
||||
if type == 0 then
|
||||
-- Pawn
|
||||
local d = -(piece*2 - 3)
|
||||
local movetwo = piece == 1 and Rank3 or Rank6
|
||||
|
||||
out = out:bor(r:move(0,d):band(self.unocupied))
|
||||
out = out:bor(out:band(movetwo):move(0,d):band(self.unocupied))
|
||||
|
||||
local captures = r:move(0,d)
|
||||
captures = captures:right():bor(captures:left())
|
||||
|
||||
if not captures:bandempty(self.ep) then
|
||||
out = out:bor(self.ep)
|
||||
end
|
||||
|
||||
captures = captures:band(cancapture)
|
||||
out = out:bor(captures)
|
||||
|
||||
return out
|
||||
elseif type == 5 then
|
||||
-- King
|
||||
for x=-1,1,1 do
|
||||
for y = -1,1,1 do
|
||||
local w = r:move(x,y)
|
||||
if self.ocupied:bandempty(w) then
|
||||
out = out:bor(w)
|
||||
else
|
||||
if not cancapture:bandempty(w) then
|
||||
out = out:bor(w)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif type == 2 then
|
||||
-- Knight
|
||||
for _,j in ipairs(KNIGHT_MOVES) do
|
||||
local w = r:move(j[1],j[2])
|
||||
|
||||
if self.ocupied:bandempty(w) then
|
||||
out = out:bor(w)
|
||||
else
|
||||
if not cancapture:bandempty(w) then
|
||||
out = out:bor(w)
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
-- Sliders (Rook, Bishop, Queen)
|
||||
local slides
|
||||
if type == 1 then
|
||||
slides = ROOK_SLIDES
|
||||
elseif type == 3 then
|
||||
slides = BISHOP_SLIDES
|
||||
else
|
||||
slides = QUEEN_SLIDES
|
||||
end
|
||||
|
||||
for _, op in ipairs(slides) do
|
||||
local w = r
|
||||
for i=1,7 do
|
||||
w = w:move(op[1], op[2])
|
||||
if w:empty() then break end
|
||||
|
||||
if self.ocupied:bandempty(w) then
|
||||
out = out:bor(w)
|
||||
else
|
||||
if not cancapture:bandempty(w) then
|
||||
out = out:bor(w)
|
||||
end
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
return out
|
||||
end
|
||||
|
||||
-- 0-5 - From Square
|
||||
-- 6-11 - To Square
|
||||
-- 12 - is Check
|
||||
-- 13 - Is EnPassent
|
||||
-- 14 - Is Castle
|
||||
-- 15-19 - Promotion Piece
|
||||
-- 20-24 - Moved Pice
|
||||
-- 25-29 - Captured Piece
|
||||
|
||||
|
||||
function Board:toString(mark )
|
||||
local out = {}
|
||||
for x=8,1,-1 do
|
||||
table.insert(out, RANKS:sub(x,x) .. " ")
|
||||
|
||||
for y=1,8 do
|
||||
local n = 8*x+y-9
|
||||
local i = self:index(n)
|
||||
if i == 0 then
|
||||
table.insert(out, '-')
|
||||
else
|
||||
-- out = out .. PieceSymbols:sub(i,i)
|
||||
table.insert(out, UnicodePieces[i])
|
||||
end
|
||||
if mark ~= nil and mark:index(n) ~= 0 then
|
||||
table.insert(out, ')')
|
||||
elseif mark ~= nil and n < 63 and y < 8 and mark:index(n+1) ~= 0 then
|
||||
table.insert(out, '(')
|
||||
else
|
||||
table.insert(out, ' ')
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(out, "\n")
|
||||
end
|
||||
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
|
||||
table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n")
|
||||
return table.concat(out)
|
||||
end
|
||||
|
||||
function Board:moveList()
|
||||
local tm = self.toMove == 1 and self.white or self.black
|
||||
local castle_rank = self.toMove == 1 and Rank1 or Rank8
|
||||
local out = {}
|
||||
local function emit(id)
|
||||
if not self:applyMove(id):illegalyChecked() then
|
||||
table.insert(out, id)
|
||||
end
|
||||
end
|
||||
|
||||
local cr = tm:band(self.castle):band(castle_rank)
|
||||
if not cr:empty() then
|
||||
local p = self.toMove == 1 and 11 or 12
|
||||
local tcolor = self.toMove == 1 and self.black or self.white
|
||||
local kidx = self[p]:ctz()
|
||||
|
||||
|
||||
local castle = bit32.replace(0, p, 20, 4)
|
||||
castle = bit32.replace(castle, kidx, 6, 6)
|
||||
castle = bit32.replace(castle, 1, 14)
|
||||
|
||||
|
||||
local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank)
|
||||
local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p])
|
||||
if
|
||||
not cr:bandempty(FileA) and
|
||||
mustbeemptyl:bandempty(self.ocupied) and
|
||||
not self:isSquareThreatened(cantbethreatened, tcolor)
|
||||
then
|
||||
emit(bit32.replace(castle, kidx - 2, 0, 6))
|
||||
end
|
||||
|
||||
|
||||
local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank)
|
||||
if
|
||||
not cr:bandempty(FileH) and
|
||||
mustbeemptyr:bandempty(self.ocupied) and
|
||||
not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor)
|
||||
then
|
||||
emit(bit32.replace(castle, kidx + 2, 0, 6))
|
||||
end
|
||||
end
|
||||
|
||||
local sq = tm:ctz()
|
||||
repeat
|
||||
local p = self:index(sq)
|
||||
local moves = self:pmoves(sq)
|
||||
|
||||
while not moves:empty() do
|
||||
local m = moves:ctz()
|
||||
moves = moves:set(m, 0)
|
||||
local id = bit32.replace(m, sq, 6, 6)
|
||||
id = bit32.replace(id, p, 20, 4)
|
||||
local mbb = Bitboard.some(m)
|
||||
if not self.ocupied:bandempty(mbb) then
|
||||
id = bit32.replace(id, self:index(m), 25, 4)
|
||||
end
|
||||
|
||||
-- Check if pawn needs to be promoted
|
||||
if p == 1 and m >= 8*7 then
|
||||
for i=3,9,2 do
|
||||
emit(bit32.replace(id, i, 15, 4))
|
||||
end
|
||||
elseif p == 2 and m < 8 then
|
||||
for i=4,10,2 do
|
||||
emit(bit32.replace(id, i, 15, 4))
|
||||
end
|
||||
else
|
||||
emit(id)
|
||||
end
|
||||
end
|
||||
sq = tm:ctzafter(sq)
|
||||
until sq == 64
|
||||
return out
|
||||
end
|
||||
|
||||
function Board:illegalyChecked()
|
||||
local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")]
|
||||
return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black)
|
||||
end
|
||||
|
||||
function Board:isSquareThreatened(target , color )
|
||||
local tm = color
|
||||
local sq = tm:ctz()
|
||||
repeat
|
||||
local moves = self:pmoves(sq)
|
||||
if not moves:bandempty(target) then
|
||||
return true
|
||||
end
|
||||
sq = color:ctzafter(sq)
|
||||
until sq == 64
|
||||
return false
|
||||
end
|
||||
|
||||
function Board:perft(depth )
|
||||
if depth == 0 then return 1 end
|
||||
if depth == 1 then
|
||||
return #self:moveList()
|
||||
end
|
||||
local result = 0
|
||||
for k,m in ipairs(self:moveList()) do
|
||||
local c = self:applyMove(m):perft(depth - 1)
|
||||
if c == 0 then
|
||||
-- Perft only counts leaf nodes at target depth
|
||||
-- result = result + 1
|
||||
else
|
||||
result = result + c
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
|
||||
function Board:applyMove(move )
|
||||
local out = Board.new()
|
||||
table.move(self, 1, 12, 1, out)
|
||||
local from = bit32.extract(move, 6, 6)
|
||||
local to = bit32.extract(move, 0, 6)
|
||||
local promote = bit32.extract(move, 15, 4)
|
||||
local piece = self:index(from)
|
||||
local captured = self:index(to)
|
||||
local tom = Bitboard.some(to)
|
||||
local isCastle = bit32.extract(move, 14)
|
||||
|
||||
if piece % 2 == 0 then
|
||||
out.moves = self.moves + 1
|
||||
end
|
||||
|
||||
if captured == 1 or piece < 3 then
|
||||
out.hm = 0
|
||||
else
|
||||
out.hm = self.hm + 1
|
||||
end
|
||||
out.castle = self.castle
|
||||
out.toMove = self.toMove == 1 and 2 or 1
|
||||
|
||||
if isCastle == 1 then
|
||||
local rank = piece == 11 and Rank1 or Rank8
|
||||
local colorOffset = piece - 11
|
||||
|
||||
out[3 + colorOffset] = out[3 + colorOffset]:bandnot(from < to and FileH or FileA)
|
||||
out[3 + colorOffset] = out[3 + colorOffset]:bor((from < to and FileF or FileD):band(rank))
|
||||
|
||||
out[piece] = (from < to and FileG or FileC):band(rank)
|
||||
out.castle = out.castle:bandnot(rank)
|
||||
out:updateCache()
|
||||
return out
|
||||
end
|
||||
|
||||
if piece < 3 then
|
||||
local dist = math.abs(to - from)
|
||||
-- Pawn moved two squares, set ep square
|
||||
if dist == 16 then
|
||||
out.ep = Bitboard.some((from + to) / 2)
|
||||
end
|
||||
|
||||
-- Remove enpasent capture
|
||||
if not tom:bandempty(self.ep) then
|
||||
if piece == 1 then
|
||||
out[2] = out[2]:bandnot(self.ep:down())
|
||||
end
|
||||
if piece == 2 then
|
||||
out[1] = out[1]:bandnot(self.ep:up())
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if piece == 3 or piece == 4 then
|
||||
out.castle = out.castle:set(from, 0)
|
||||
end
|
||||
|
||||
if piece > 10 then
|
||||
local rank = piece == 11 and Rank1 or Rank8
|
||||
out.castle = out.castle:bandnot(rank)
|
||||
end
|
||||
|
||||
out[piece] = out[piece]:set(from, 0)
|
||||
if promote == 0 then
|
||||
out[piece] = out[piece]:set(to, 1)
|
||||
else
|
||||
out[promote] = out[promote]:set(to, 1)
|
||||
end
|
||||
if captured ~= 0 then
|
||||
out[captured] = out[captured]:set(to, 0)
|
||||
end
|
||||
|
||||
out:updateCache()
|
||||
return out
|
||||
end
|
||||
|
||||
Board.__index = Board
|
||||
Board.__tostring = Board.toString
|
||||
--
|
||||
-- Main
|
||||
--
|
||||
|
||||
local failures = 0
|
||||
local function test(fen, ply, target)
|
||||
local b = Board.fromFen(fen)
|
||||
if b:fen() ~= fen then
|
||||
print("FEN MISMATCH", fen, b:fen())
|
||||
failures = failures + 1
|
||||
return
|
||||
end
|
||||
|
||||
local found = b:perft(ply)
|
||||
if found ~= target then
|
||||
print(fen, "Found", found, "target", target)
|
||||
failures = failures + 1
|
||||
for k,v in pairs(b:moveList()) do
|
||||
print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1'))
|
||||
end
|
||||
--error("Test Failure")
|
||||
else
|
||||
print("OK", found, fen)
|
||||
end
|
||||
end
|
||||
|
||||
-- From https://www.chessprogramming.org/Perft_Results
|
||||
-- If interpreter, computers, or algorithm gets too fast
|
||||
-- feel free to go deeper
|
||||
|
||||
local testCases = {}
|
||||
local function addTest(...) table.insert(testCases, {...}) end
|
||||
|
||||
addTest(StartingFen, 3, 8902)
|
||||
addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039)
|
||||
addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812)
|
||||
addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467)
|
||||
addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486)
|
||||
addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079)
|
||||
|
||||
|
||||
local function chess()
|
||||
for k,v in ipairs(testCases) do
|
||||
test(v[1],v[2],v[3])
|
||||
end
|
||||
end
|
||||
|
||||
bench.runCode(chess, "chess")
|
|
@ -1596,7 +1596,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion")
|
|||
local function target(a: number, b: string) return a + #b end
|
||||
|
||||
local function d(a: n@1, b)
|
||||
return target(a, b)
|
||||
return target(a, b)
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1609,7 +1609,7 @@ end
|
|||
local function target(a: number, b: string) return a + #b end
|
||||
|
||||
local function d(a, b: s@1)
|
||||
return target(a, b)
|
||||
return target(a, b)
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1622,7 +1622,7 @@ end
|
|||
local function target(a: number, b: string) return a + #b end
|
||||
|
||||
local function d(a:@1 @2, b)
|
||||
return target(a, b)
|
||||
return target(a, b)
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1640,7 +1640,7 @@ end
|
|||
local function target(a: number, b: string) return a + #b end
|
||||
|
||||
local function d(a, b: @1)@2: number
|
||||
return target(a, b)
|
||||
return target(a, b)
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1682,7 +1682,7 @@ local x = target(function(a: n@1
|
|||
local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end
|
||||
|
||||
local x = target(function(a: n@1, b: @2)
|
||||
return a + #b
|
||||
return a + #b
|
||||
end)
|
||||
)");
|
||||
|
||||
|
@ -1700,7 +1700,7 @@ end)
|
|||
local function target(callback: (...number) -> number) return callback(1, 2, 3) end
|
||||
|
||||
local x = target(function(a: n@1)
|
||||
return a
|
||||
return a
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1716,7 +1716,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio
|
|||
local function target(callback: (...number) -> number) return callback(1, 2, 3) end
|
||||
|
||||
local x = target(function(...:n@1)
|
||||
return a
|
||||
return a
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1729,7 +1729,7 @@ end
|
|||
local function target(callback: (...number) -> number) return callback(1, 2, 3) end
|
||||
|
||||
local x = target(function(a:number, b:number, ...:@1)
|
||||
return a + b
|
||||
return a + b
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1745,7 +1745,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion")
|
|||
local function target(callback: () -> number) return callback() end
|
||||
|
||||
local x = target(function(): n@1
|
||||
return 1
|
||||
return 1
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1758,7 +1758,7 @@ end
|
|||
local function target(callback: () -> (number, number)) return callback() end
|
||||
|
||||
local x = target(function(): (number, n@1
|
||||
return 1, 2
|
||||
return 1, 2
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1774,7 +1774,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion"
|
|||
local function target(callback: () -> ...number) return callback() end
|
||||
|
||||
local x = target(function(): ...n@1
|
||||
return 1, 2, 3
|
||||
return 1, 2, 3
|
||||
end
|
||||
)");
|
||||
|
||||
|
@ -1787,7 +1787,7 @@ end
|
|||
local function target(callback: () -> ...number) return callback() end
|
||||
|
||||
local x = target(function(): (number, number, ...n@1
|
||||
return 1, 2, 3
|
||||
return 1, 2, 3
|
||||
end
|
||||
)");
|
||||
|
||||
|
|
|
@ -768,11 +768,11 @@ TEST_CASE("CaptureSelf")
|
|||
local MaterialsListClass = {}
|
||||
|
||||
function MaterialsListClass:_MakeToolTip(guiElement, text)
|
||||
local function updateTooltipPosition()
|
||||
self._tweakingTooltipFrame = 5
|
||||
end
|
||||
local function updateTooltipPosition()
|
||||
self._tweakingTooltipFrame = 5
|
||||
end
|
||||
|
||||
updateTooltipPosition()
|
||||
updateTooltipPosition()
|
||||
end
|
||||
|
||||
return MaterialsListClass
|
||||
|
@ -2001,14 +2001,14 @@ TEST_CASE("UpvaluesLoopsBytecode")
|
|||
{
|
||||
CHECK_EQ("\n" + compileFunction(R"(
|
||||
function test()
|
||||
for i=1,10 do
|
||||
for i=1,10 do
|
||||
i = i
|
||||
foo(function() return i end)
|
||||
if bar then
|
||||
break
|
||||
end
|
||||
end
|
||||
return 0
|
||||
foo(function() return i end)
|
||||
if bar then
|
||||
break
|
||||
end
|
||||
end
|
||||
return 0
|
||||
end
|
||||
)",
|
||||
1),
|
||||
|
@ -2035,14 +2035,14 @@ RETURN R0 1
|
|||
|
||||
CHECK_EQ("\n" + compileFunction(R"(
|
||||
function test()
|
||||
for i in ipairs(data) do
|
||||
for i in ipairs(data) do
|
||||
i = i
|
||||
foo(function() return i end)
|
||||
if bar then
|
||||
break
|
||||
end
|
||||
end
|
||||
return 0
|
||||
foo(function() return i end)
|
||||
if bar then
|
||||
break
|
||||
end
|
||||
end
|
||||
return 0
|
||||
end
|
||||
)",
|
||||
1),
|
||||
|
@ -2068,17 +2068,17 @@ RETURN R0 1
|
|||
|
||||
CHECK_EQ("\n" + compileFunction(R"(
|
||||
function test()
|
||||
local i = 0
|
||||
while i < 5 do
|
||||
local j
|
||||
local i = 0
|
||||
while i < 5 do
|
||||
local j
|
||||
j = i
|
||||
foo(function() return j end)
|
||||
i = i + 1
|
||||
if bar then
|
||||
break
|
||||
end
|
||||
end
|
||||
return 0
|
||||
foo(function() return j end)
|
||||
i = i + 1
|
||||
if bar then
|
||||
break
|
||||
end
|
||||
end
|
||||
return 0
|
||||
end
|
||||
)",
|
||||
1),
|
||||
|
@ -2105,17 +2105,17 @@ RETURN R1 1
|
|||
|
||||
CHECK_EQ("\n" + compileFunction(R"(
|
||||
function test()
|
||||
local i = 0
|
||||
repeat
|
||||
local j
|
||||
local i = 0
|
||||
repeat
|
||||
local j
|
||||
j = i
|
||||
foo(function() return j end)
|
||||
i = i + 1
|
||||
if bar then
|
||||
break
|
||||
end
|
||||
until i < 5
|
||||
return 0
|
||||
foo(function() return j end)
|
||||
i = i + 1
|
||||
if bar then
|
||||
break
|
||||
end
|
||||
until i < 5
|
||||
return 0
|
||||
end
|
||||
)",
|
||||
1),
|
||||
|
@ -2304,10 +2304,10 @@ local Value1, Value2, Value3 = ...
|
|||
local Table = {}
|
||||
|
||||
Table.SubTable["Key"] = {
|
||||
Key1 = Value1,
|
||||
Key2 = Value2,
|
||||
Key3 = Value3,
|
||||
Key4 = true,
|
||||
Key1 = Value1,
|
||||
Key2 = Value2,
|
||||
Key3 = Value3,
|
||||
Key4 = true,
|
||||
}
|
||||
)");
|
||||
|
||||
|
|
|
@ -801,4 +801,17 @@ TEST_CASE("IfElseExpression")
|
|||
runConformance("ifelseexpr.lua");
|
||||
}
|
||||
|
||||
TEST_CASE("TagMethodError")
|
||||
{
|
||||
ScopedFastFlag sff{"LuauCcallRestoreFix", true};
|
||||
|
||||
runConformance("tmerror.lua", [](lua_State* L) {
|
||||
auto* cb = lua_callbacks(L);
|
||||
|
||||
cb->debugprotectederror = [](lua_State* L) {
|
||||
CHECK(lua_isyieldable(L));
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <optional>
|
||||
|
||||
namespace std {
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
|
||||
{
|
||||
|
@ -9,10 +12,12 @@ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
|
|||
}
|
||||
|
||||
template<typename T>
|
||||
std::ostream& operator<<(std::ostream& lhs, const std::optional<T>& t)
|
||||
auto operator<<(std::ostream& lhs, const std::optional<T>& t) -> decltype(lhs << *t) // SFINAE to only instantiate << for supported types
|
||||
{
|
||||
if (t)
|
||||
return lhs << *t;
|
||||
else
|
||||
return lhs << "none";
|
||||
}
|
||||
|
||||
} // namespace std
|
||||
|
|
|
@ -791,13 +791,13 @@ TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings")
|
|||
{
|
||||
LintResult result = lint(R"(--!strict
|
||||
type InputData = {
|
||||
id: number,
|
||||
inputType: EnumItem,
|
||||
inputState: EnumItem,
|
||||
updated: number,
|
||||
position: Vector3,
|
||||
keyCode: EnumItem,
|
||||
name: string
|
||||
id: number,
|
||||
inputType: EnumItem,
|
||||
inputState: EnumItem,
|
||||
updated: number,
|
||||
position: Vector3,
|
||||
keyCode: EnumItem,
|
||||
name: string
|
||||
}
|
||||
)");
|
||||
|
||||
|
|
|
@ -554,4 +554,54 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name")
|
|||
CHECK_EQ("{number | string}", toString(requireType("p"), {true}));
|
||||
}
|
||||
|
||||
/*
|
||||
* We had a problem where all type aliases would be prototyped into a child scope that happened
|
||||
* to have the same level. This caused a problem where, if a sibling function referred to that
|
||||
* type alias in its type signature, it would erroneously be quantified away, even though it doesn't
|
||||
* actually belong to the function.
|
||||
*
|
||||
* We solved this by ascribing a unique subLevel to each prototyped alias.
|
||||
*/
|
||||
TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases")
|
||||
{
|
||||
CheckResult result = check(R"(
|
||||
--!strict
|
||||
|
||||
local KeyPool = {}
|
||||
|
||||
local function newkey(pool: KeyPool, index)
|
||||
return {}
|
||||
end
|
||||
|
||||
function newKeyPool()
|
||||
local pool = {
|
||||
available = {} :: {Key},
|
||||
}
|
||||
|
||||
return setmetatable(pool, KeyPool)
|
||||
end
|
||||
|
||||
export type KeyPool = typeof(newKeyPool())
|
||||
export type Key = typeof(newkey(newKeyPool(), 1))
|
||||
)");
|
||||
|
||||
LUAU_REQUIRE_NO_ERRORS(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* We keep a cache of type alias onto TypeVar to prevent infinite types from
|
||||
* being constructed via recursive or corecursive aliases. We have to adjust
|
||||
* the TypeLevels of those generic TypeVars so that the unifier doesn't think
|
||||
* they have improperly leaked out of their scope.
|
||||
*/
|
||||
TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases")
|
||||
{
|
||||
CheckResult result = check(R"(
|
||||
type Array<T> = {T}
|
||||
type Exclude<T, V> = T
|
||||
)");
|
||||
|
||||
LUAU_REQUIRE_NO_ERRORS(result);
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
||||
|
|
|
@ -437,8 +437,6 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde
|
|||
|
||||
TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
local b: Vector2? = nil
|
||||
local a = b.X + b.Z
|
||||
|
|
|
@ -695,4 +695,25 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types")
|
|||
CHECK(requireType("y1") == requireType("y2"));
|
||||
}
|
||||
|
||||
TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields")
|
||||
{
|
||||
ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true};
|
||||
ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true};
|
||||
|
||||
CheckResult result = check(R"(
|
||||
local exports = {}
|
||||
local nested = {}
|
||||
|
||||
nested.name = function(t, k)
|
||||
local a = t.x.y
|
||||
return rawget(t, k)
|
||||
end
|
||||
|
||||
exports.nested = nested
|
||||
return exports
|
||||
)");
|
||||
|
||||
LUAU_REQUIRE_NO_ERRORS(result);
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include <algorithm>
|
||||
|
||||
LUAU_FASTFLAG(LuauEqConstraint)
|
||||
LUAU_FASTFLAG(LuauQuantifyInPlace2)
|
||||
|
||||
using namespace Luau;
|
||||
|
||||
|
@ -42,7 +43,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
|
|||
end
|
||||
)";
|
||||
|
||||
const std::string expected = R"(
|
||||
const std::string old_expected = R"(
|
||||
function f(a:{fn:()->(free,free...)}): ()
|
||||
if type(a) == 'boolean'then
|
||||
local a1:boolean=a
|
||||
|
@ -51,7 +52,21 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
|
|||
end
|
||||
end
|
||||
)";
|
||||
CHECK_EQ(expected, decorateWithTypes(code));
|
||||
|
||||
const std::string expected = R"(
|
||||
function f(a:{fn:()->(a,b...)}): ()
|
||||
if type(a) == 'boolean'then
|
||||
local a1:boolean=a
|
||||
elseif a.fn()then
|
||||
local a2:{fn:()->(a,b...)}=a
|
||||
end
|
||||
end
|
||||
)";
|
||||
|
||||
if (FFlag::LuauQuantifyInPlace2)
|
||||
CHECK_EQ(expected, decorateWithTypes(code));
|
||||
else
|
||||
CHECK_EQ(old_expected, decorateWithTypes(code));
|
||||
}
|
||||
|
||||
TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns")
|
||||
|
@ -263,8 +278,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap")
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5))
|
||||
{
|
||||
ScopedFastInt sffi{"LuauTarjanChildLimit", 50};
|
||||
ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50};
|
||||
ScopedFastInt sffi{"LuauTarjanChildLimit", 1};
|
||||
ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1};
|
||||
|
||||
CheckResult result = check(R"LUA(
|
||||
local Result
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
LUAU_FASTFLAG(LuauWeakEqConstraint)
|
||||
LUAU_FASTFLAG(LuauOrPredicate)
|
||||
LUAU_FASTFLAG(LuauQuantifyInPlace2)
|
||||
|
||||
using namespace Luau;
|
||||
|
||||
|
@ -698,10 +699,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
|
|||
|
||||
CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector"
|
||||
|
||||
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
|
||||
if (FFlag::LuauQuantifyInPlace2)
|
||||
CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0]));
|
||||
else
|
||||
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
|
||||
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance"
|
||||
|
||||
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
|
||||
if (FFlag::LuauQuantifyInPlace2)
|
||||
CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
|
||||
else
|
||||
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
|
||||
}
|
||||
|
||||
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector")
|
||||
|
|
|
@ -617,7 +617,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too")
|
|||
|
||||
REQUIRE_EQ(indexer.indexType, typeChecker.numberType);
|
||||
|
||||
REQUIRE(nullptr != get<GenericTypeVar>(indexer.indexResultType));
|
||||
REQUIRE(nullptr != get<GenericTypeVar>(follow(indexer.indexResultType)));
|
||||
}
|
||||
|
||||
TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2")
|
||||
|
|
|
@ -180,7 +180,12 @@ TEST_CASE_FIXTURE(Fixture, "expr_statement")
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "generic_function")
|
||||
{
|
||||
CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)");
|
||||
CheckResult result = check(R"(
|
||||
function id(x) return x end
|
||||
local a = id(55)
|
||||
local b = id(nil)
|
||||
)");
|
||||
|
||||
LUAU_REQUIRE_NO_ERRORS(result);
|
||||
|
||||
CHECK_EQ(*typeChecker.numberType, *requireType("a"));
|
||||
|
@ -1889,7 +1894,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function")
|
|||
|
||||
REQUIRE_EQ(2, argVec.size());
|
||||
|
||||
const FunctionTypeVar* fType = get<FunctionTypeVar>(argVec[0]);
|
||||
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(argVec[0]));
|
||||
REQUIRE(fType != nullptr);
|
||||
|
||||
std::vector<TypeId> fArgs = flatten(fType->argTypes).first;
|
||||
|
@ -1926,7 +1931,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_2")
|
|||
|
||||
REQUIRE_EQ(6, argVec.size());
|
||||
|
||||
const FunctionTypeVar* fType = get<FunctionTypeVar>(argVec[0]);
|
||||
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(argVec[0]));
|
||||
REQUIRE(fType != nullptr);
|
||||
}
|
||||
|
||||
|
@ -3842,10 +3847,10 @@ local T: any
|
|||
T = {}
|
||||
T.__index = T
|
||||
function T.new(...)
|
||||
local self = {}
|
||||
setmetatable(self, T)
|
||||
self:construct(...)
|
||||
return self
|
||||
local self = {}
|
||||
setmetatable(self, T)
|
||||
self:construct(...)
|
||||
return self
|
||||
end
|
||||
function T:construct(index)
|
||||
end
|
||||
|
@ -4068,11 +4073,11 @@ function n:Clone() end
|
|||
local m = {}
|
||||
|
||||
function m.a(x)
|
||||
x:Clone()
|
||||
x:Clone()
|
||||
end
|
||||
|
||||
function m.b()
|
||||
m.a(n)
|
||||
m.a(n)
|
||||
end
|
||||
|
||||
return m
|
||||
|
@ -4393,8 +4398,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload")
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments")
|
||||
{
|
||||
ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true);
|
||||
|
||||
// Simple direct arg to arg propagation
|
||||
CheckResult result = check(R"(
|
||||
type Table = { x: number, y: number }
|
||||
|
@ -4681,7 +4684,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early")
|
|||
{
|
||||
ScopedFastFlag sffs[] = {
|
||||
{"LuauSlightlyMoreFlexibleBinaryPredicates", true},
|
||||
{"LuauExtraNilRecovery", true},
|
||||
};
|
||||
|
||||
CheckResult result = check(R"(
|
||||
|
@ -4698,7 +4700,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch")
|
|||
{
|
||||
ScopedFastFlag sffs[] = {
|
||||
{"LuauSlightlyMoreFlexibleBinaryPredicates", true},
|
||||
{"LuauExtraNilRecovery", true},
|
||||
};
|
||||
|
||||
CheckResult result = check(R"(
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
|
||||
#include "doctest.h"
|
||||
|
||||
LUAU_FASTFLAG(LuauQuantifyInPlace2);
|
||||
|
||||
using namespace Luau;
|
||||
|
||||
struct TryUnifyFixture : Fixture
|
||||
|
@ -15,7 +17,8 @@ struct TryUnifyFixture : Fixture
|
|||
TypeArena arena;
|
||||
ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}};
|
||||
InternalErrorReporter iceHandler;
|
||||
Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler};
|
||||
UnifierSharedState unifierState{&iceHandler};
|
||||
Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState};
|
||||
};
|
||||
|
||||
TEST_SUITE_BEGIN("TryUnifyTests");
|
||||
|
@ -139,7 +142,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails"
|
|||
)");
|
||||
|
||||
LUAU_REQUIRE_ERROR_COUNT(1, result);
|
||||
CHECK_EQ("(number) -> (boolean)", toString(requireType("f")));
|
||||
if (FFlag::LuauQuantifyInPlace2)
|
||||
CHECK_EQ("(number) -> boolean", toString(requireType("f")));
|
||||
else
|
||||
CHECK_EQ("(number) -> (boolean)", toString(requireType("f")));
|
||||
}
|
||||
|
||||
TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification")
|
||||
|
|
|
@ -98,10 +98,10 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function")
|
|||
std::vector<TypeId> applyArgs = flatten(applyType->argTypes).first;
|
||||
REQUIRE_EQ(3, applyArgs.size());
|
||||
|
||||
const FunctionTypeVar* fType = get<FunctionTypeVar>(applyArgs[0]);
|
||||
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(applyArgs[0]));
|
||||
REQUIRE(fType != nullptr);
|
||||
|
||||
const FunctionTypeVar* gType = get<FunctionTypeVar>(applyArgs[1]);
|
||||
const FunctionTypeVar* gType = get<FunctionTypeVar>(follow(applyArgs[1]));
|
||||
REQUIRE(gType != nullptr);
|
||||
|
||||
std::vector<TypeId> gArgs = flatten(gType->argTypes).first;
|
||||
|
@ -285,7 +285,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail")
|
|||
{
|
||||
CheckResult result = check(R"(
|
||||
local _ = function():((...any)->(...any),()->())
|
||||
return function() end, function() end
|
||||
return function() end, function() end
|
||||
end
|
||||
for y in _() do
|
||||
end
|
||||
|
|
|
@ -181,8 +181,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property")
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property")
|
||||
{
|
||||
ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
type A = {x: number}
|
||||
type B = {}
|
||||
|
@ -242,8 +240,6 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons")
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "optional_union_members")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
local a = { a = { x = 1, y = 2 }, b = 3 }
|
||||
type A = typeof(a)
|
||||
|
@ -259,8 +255,6 @@ local c = bf.a.y
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "optional_union_functions")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
local a = {}
|
||||
function a.foo(x:number, y:number) return x + y end
|
||||
|
@ -276,8 +270,6 @@ local c = b.foo(1, 2)
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "optional_union_methods")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
local a = {}
|
||||
function a:foo(x:number, y:number) return x + y end
|
||||
|
@ -310,8 +302,6 @@ return f()
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "optional_field_access_error")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
type A = { x: number }
|
||||
local b: A? = { x = 2 }
|
||||
|
@ -327,8 +317,6 @@ local d = b.y
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "optional_index_error")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
type A = {number}
|
||||
local a: A? = {1, 2, 3}
|
||||
|
@ -341,8 +329,6 @@ local b = a[1]
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "optional_call_error")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
type A = (number) -> number
|
||||
local a: A? = function(a) return -a end
|
||||
|
@ -355,8 +341,6 @@ local b = a(4)
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
type A = { x: number }
|
||||
local a: A? = { x = 2 }
|
||||
|
@ -378,8 +362,6 @@ a.x = 2
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "optional_length_error")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
type A = {number}
|
||||
local a: A? = {1, 2, 3}
|
||||
|
@ -392,9 +374,6 @@ local b = #a
|
|||
|
||||
TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details")
|
||||
{
|
||||
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
|
||||
ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true);
|
||||
|
||||
CheckResult result = check(R"(
|
||||
type A = { x: number, y: number }
|
||||
type B = { x: number, y: number }
|
||||
|
|
|
@ -265,4 +265,64 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure")
|
|||
CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result));
|
||||
}
|
||||
|
||||
TEST_CASE("tagging_tables")
|
||||
{
|
||||
ScopedFastFlag sff{"LuauRefactorTagging", true};
|
||||
|
||||
TypeVar ttv{TableTypeVar{}};
|
||||
CHECK(!Luau::hasTag(&ttv, "foo"));
|
||||
Luau::attachTag(&ttv, "foo");
|
||||
CHECK(Luau::hasTag(&ttv, "foo"));
|
||||
}
|
||||
|
||||
TEST_CASE("tagging_classes")
|
||||
{
|
||||
ScopedFastFlag sff{"LuauRefactorTagging", true};
|
||||
|
||||
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
|
||||
CHECK(!Luau::hasTag(&base, "foo"));
|
||||
Luau::attachTag(&base, "foo");
|
||||
CHECK(Luau::hasTag(&base, "foo"));
|
||||
}
|
||||
|
||||
TEST_CASE("tagging_subclasses")
|
||||
{
|
||||
ScopedFastFlag sff{"LuauRefactorTagging", true};
|
||||
|
||||
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
|
||||
TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}};
|
||||
|
||||
CHECK(!Luau::hasTag(&base, "foo"));
|
||||
CHECK(!Luau::hasTag(&derived, "foo"));
|
||||
|
||||
Luau::attachTag(&base, "foo");
|
||||
CHECK(Luau::hasTag(&base, "foo"));
|
||||
CHECK(Luau::hasTag(&derived, "foo"));
|
||||
|
||||
Luau::attachTag(&derived, "bar");
|
||||
CHECK(!Luau::hasTag(&base, "bar"));
|
||||
CHECK(Luau::hasTag(&derived, "bar"));
|
||||
}
|
||||
|
||||
TEST_CASE("tagging_functions")
|
||||
{
|
||||
ScopedFastFlag sff{"LuauRefactorTagging", true};
|
||||
|
||||
TypePackVar empty{TypePack{}};
|
||||
TypeVar ftv{FunctionTypeVar{&empty, &empty}};
|
||||
CHECK(!Luau::hasTag(&ftv, "foo"));
|
||||
Luau::attachTag(&ftv, "foo");
|
||||
CHECK(Luau::hasTag(&ftv, "foo"));
|
||||
}
|
||||
|
||||
TEST_CASE("tagging_props")
|
||||
{
|
||||
ScopedFastFlag sff{"LuauRefactorTagging", true};
|
||||
|
||||
Property prop{};
|
||||
CHECK(!Luau::hasTag(prop, "foo"));
|
||||
Luau::attachTag(prop, "foo");
|
||||
CHECK(Luau::hasTag(prop, "foo"));
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
||||
|
|
15
tests/conformance/tmerror.lua
Normal file
15
tests/conformance/tmerror.lua
Normal file
|
@ -0,0 +1,15 @@
|
|||
-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes
|
||||
|
||||
-- Generate an error (i.e. throw an exception) inside a tag method which is indirectly
|
||||
-- called via pcall.
|
||||
-- This test is meant to detect a regression in handling errors inside a tag method
|
||||
|
||||
local testtable = {}
|
||||
setmetatable(testtable, { __index = function() error("Error") end })
|
||||
|
||||
pcall(function()
|
||||
testtable.missingmethod()
|
||||
end)
|
||||
|
||||
return('OK')
|
|
@ -11,9 +11,9 @@ class VariantPrinter:
|
|||
return type.name + " [" + str(value) + "]"
|
||||
|
||||
def match_printer(val):
|
||||
type = val.type.strip_typedefs()
|
||||
if type.name and type.name.startswith('Luau::Variant<'):
|
||||
return VariantPrinter(val)
|
||||
return None
|
||||
type = val.type.strip_typedefs()
|
||||
if type.name and type.name.startswith('Luau::Variant<'):
|
||||
return VariantPrinter(val)
|
||||
return None
|
||||
|
||||
gdb.pretty_printers.append(match_printer)
|
||||
|
|
Loading…
Reference in a new issue