Merge pull request #2 from MathematicalDessert/webrepl

Webrepl
This commit is contained in:
Pelanyo Kamara 2021-11-08 19:37:01 +00:00 committed by GitHub
commit 0085d2479a
Signed by: DevComp
GPG key ID: 4AEE18F83AFDEB23
66 changed files with 2414 additions and 664 deletions

View file

@ -83,3 +83,28 @@ jobs:
with:
name: coverage
path: coverage
web:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: setup emsdk
uses: mymindstorm/setup-emsdk@v9
with:
version: 1.38.40
actions-cache-folder: 'emsdk-cache'
- name: verify emsdk
run: |
emcc -v # verify emscripten installed successfully
- name: cmake configure
run: |
emcmake cmake .
- name: build
run: |
cmake --build . --target Luau.Repl.CLI --config Debug # as seen above but only Repl.CLI.
- name: upload-artifacts
- uses: actions/upload-artifact@v2
with:
name: luau-web-artifacts
path: docs/assets/luau/
if-no-files-found: error

1
.gitignore vendored
View file

@ -5,3 +5,4 @@
^default.prof*
^fuzz-*
^luau$
/.vs

View file

@ -34,7 +34,6 @@ TypeId makeFunction( // Polymorphic
std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes);
void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachFunctionTag(TypeId ty, std::string constraint);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName);

View file

@ -0,0 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeVar.h"
namespace Luau
{
struct Module;
using ModulePtr = std::shared_ptr<Module>;
void quantify(ModulePtr module, TypeId ty, TypeLevel level);
} // namespace Luau

View file

@ -69,4 +69,6 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {});
void dump(TypeId ty);
void dump(TypePackId ty);
std::string generateName(size_t n);
} // namespace Luau

View file

@ -12,6 +12,7 @@ struct AstArray;
class AstStat;
bool containsFunctionCall(const AstStat& stat);
bool containsFunctionCallOrReturn(const AstStat& stat);
bool isFunction(const AstStat& stat);
void toposort(std::vector<AstStat*>& stats);

View file

@ -3,19 +3,37 @@
#include "Luau/TypeVar.h"
LUAU_FASTFLAG(LuauShareTxnSeen);
namespace Luau
{
// Log of where what TypeIds we are rebinding and what they used to be
struct TxnLog
{
TxnLog() = default;
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& seen)
: seen(seen)
TxnLog()
: originalSeenSize(0)
, ownedSeen()
, sharedSeen(&ownedSeen)
{
}
explicit TxnLog(std::vector<std::pair<TypeId, TypeId>>* sharedSeen)
: originalSeenSize(sharedSeen->size())
, ownedSeen()
, sharedSeen(sharedSeen)
{
}
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& ownedSeen)
: originalSeenSize(ownedSeen.size())
, ownedSeen(ownedSeen)
, sharedSeen(nullptr)
{
// This is deprecated!
LUAU_ASSERT(!FFlag::LuauShareTxnSeen);
}
TxnLog(const TxnLog&) = delete;
TxnLog& operator=(const TxnLog&) = delete;
@ -38,9 +56,11 @@ private:
std::vector<std::pair<TypeId, TypeVar>> typeVarChanges;
std::vector<std::pair<TypePackId, TypePackVar>> typePackChanges;
std::vector<std::pair<TableTypeVar*, std::optional<TypeId>>> tableChanges;
size_t originalSeenSize;
public:
std::vector<std::pair<TypeId, TypeId>> seen; // used to avoid infinite recursion when types are cyclic
std::vector<std::pair<TypeId, TypeId>> ownedSeen; // used to avoid infinite recursion when types are cyclic
std::vector<std::pair<TypeId, TypeId>>* sharedSeen; // shared with all the descendent logs
};
} // namespace Luau

View file

@ -11,6 +11,7 @@
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
#include "Luau/Unifier.h"
#include "Luau/UnifierSharedState.h"
#include <memory>
#include <unordered_map>
@ -121,7 +122,7 @@ struct TypeChecker
void check(const ScopePtr& scope, const AstStatForIn& forin);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false);
void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
@ -336,7 +337,7 @@ private:
// Note: `scope` must be a fresh scope.
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
@ -383,6 +384,8 @@ public:
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
InternalErrorReporter* iceHandler;
UnifierSharedState unifierState;
public:
const TypeId nilType;
const TypeId numberType;

View file

@ -540,4 +540,11 @@ UnionTypeVarIterator end(const UnionTypeVar* utv);
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
void attachTag(TypeId ty, const std::string& tagName);
void attachTag(Property& prop, const std::string& tagName);
bool hasTag(TypeId ty, const std::string& tagName);
bool hasTag(const Property& prop, const std::string& tagName);
bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work.
} // namespace Luau

View file

@ -6,6 +6,7 @@
#include "Luau/TxnLog.h"
#include "Luau/TypeInfer.h"
#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header.
#include "Luau/UnifierSharedState.h"
#include <unordered_set>
@ -41,11 +42,14 @@ struct Unifier
std::shared_ptr<UnifierCounters> counters_DEPRECATED;
InternalErrorReporter* iceHandler;
UnifierSharedState& sharedState;
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& ownedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
UnifierCounters* counters = nullptr);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector<std::pair<TypeId, TypeId>>* sharedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
UnifierCounters* counters = nullptr);
// Test whether the two type vars unify. Never commits the result.
@ -69,7 +73,8 @@ private:
void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed);
void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed);
void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer);
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId,TypeId> seen = {});
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen = {});
void cacheResult(TypeId superTy, TypeId subTy);
public:
void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false);
@ -101,8 +106,9 @@ private:
[[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message);
DenseHashSet<TypeId> tempSeenTy{nullptr};
DenseHashSet<TypePackId> tempSeenTp{nullptr};
// Remove with FFlagLuauCacheUnifyTableResults
DenseHashSet<TypeId> tempSeenTy_DEPRECATED{nullptr};
DenseHashSet<TypePackId> tempSeenTp_DEPRECATED{nullptr};
};
} // namespace Luau

View file

@ -0,0 +1,44 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
#include <utility>
namespace Luau
{
struct InternalErrorReporter;
struct TypeIdPairHash
{
size_t hashOne(Luau::TypeId key) const
{
return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9);
}
size_t operator()(const std::pair<Luau::TypeId, Luau::TypeId>& x) const
{
return hashOne(x.first) ^ (hashOne(x.second) << 1);
}
};
struct UnifierSharedState
{
UnifierSharedState(InternalErrorReporter* iceHandler)
: iceHandler(iceHandler)
{
}
InternalErrorReporter* iceHandler;
DenseHashSet<void*> seenAny{nullptr};
DenseHashMap<TypeId, bool> skipCacheForType{nullptr};
DenseHashSet<std::pair<TypeId, TypeId>, TypeIdPairHash> cachedUnify{{nullptr, nullptr}};
DenseHashSet<TypeId> tempSeenTy{nullptr};
DenseHashSet<TypePackId> tempSeenTp{nullptr};
};
} // namespace Luau

View file

@ -1,9 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
LUAU_FASTFLAG(LuauCacheUnifyTableResults)
namespace Luau
{
@ -32,17 +35,33 @@ inline bool hasSeen(std::unordered_set<void*>& seen, const void* tv)
return !seen.insert(ttv).second;
}
inline bool hasSeen(DenseHashSet<void*>& seen, const void* tv)
{
void* ttv = const_cast<void*>(tv);
if (seen.contains(ttv))
return true;
seen.insert(ttv);
return false;
}
inline void unsee(std::unordered_set<void*>& seen, const void* tv)
{
void* ttv = const_cast<void*>(tv);
seen.erase(ttv);
}
template<typename F>
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen);
inline void unsee(DenseHashSet<void*>& seen, const void* tv)
{
// When DenseHashSet is used for 'visitOnce', where don't forget visited elements
}
template<typename F>
void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen);
template<typename F, typename Set>
void visit(TypeId ty, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, ty))
{
@ -79,15 +98,23 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
else if (auto ttv = get<TableTypeVar>(ty))
{
// Some visitors want to see bound tables, that's why we visit the original type
if (apply(ty, *ttv, seen, f))
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
visit(*ttv->boundTo, f, seen);
}
else
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
}
}
}
}
@ -140,8 +167,8 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
visit_detail::unsee(seen, ty);
}
template<typename F>
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, tp))
{
@ -182,6 +209,7 @@ void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
visit_detail::unsee(seen, tp);
}
} // namespace visit_detail
template<typename TID, typename F>
@ -197,4 +225,11 @@ void visitTypeVar(TID ty, F& f)
visit_detail::visit(ty, f, seen);
}
template<typename TID, typename F>
void visitTypeVarOnce(TID ty, F& f, DenseHashSet<void*>& seen)
{
seen.clear();
visit_detail::visit(ty, f, seen);
}
} // namespace Luau

View file

@ -196,7 +196,8 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) {
InternalErrorReporter iceReporter;
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter);
UnifierSharedState unifierState(&iceReporter);
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState);
unifier.tryUnify(expectedType, actualType);

View file

@ -106,18 +106,6 @@ void attachMagicFunction(TypeId ty, MagicFunction fn)
LUAU_ASSERT(!"Got a non functional type");
}
void attachFunctionTag(TypeId ty, std::string tag)
{
if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->tags.emplace_back(std::move(tag));
}
else
{
LUAU_ASSERT(!"Got a non functional type");
}
}
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol)
{
return {

View file

@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false)
LUAU_FASTFLAG(LuauTraceRequireLookupChild)
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
LUAU_FASTFLAG(LuauNewRequireTrace)
LUAU_FASTFLAGVARIABLE(LuauClearScopes, false)
namespace Luau
{
@ -248,7 +249,7 @@ struct RequireCycle
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true)
std::vector<RequireCycle> getRequireCycles(
const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
const FileResolver* resolver, const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
{
std::vector<RequireCycle> result;
@ -282,9 +283,9 @@ std::vector<RequireCycle> getRequireCycles(
if (top == start)
{
for (const SourceNode* node : path)
cycle.push_back(node->name);
cycle.push_back(resolver->getHumanReadableModuleName(node->name));
cycle.push_back(top->name);
cycle.push_back(resolver->getHumanReadableModuleName(top->name));
break;
}
}
@ -404,7 +405,7 @@ CheckResult Frontend::check(const ModuleName& name)
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely
if (cycleDetected)
requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck);
requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck);
// This is used by the type checker to replace the resulting type of cyclic modules with any
sourceModule.cyclic = !requireCycles.empty();
@ -458,6 +459,8 @@ CheckResult Frontend::check(const ModuleName& name)
module->astTypes.clear();
module->astExpectedTypes.clear();
module->astOriginalCallTypes.clear();
if (FFlag::LuauClearScopes)
module->scopes.resize(1);
}
if (mode != Mode::NoCheck)

View file

@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false)
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false)
namespace Luau
{
@ -299,6 +300,14 @@ void TypeCloner::operator()(const FunctionTypeVar& t)
void TypeCloner::operator()(const TableTypeVar& t)
{
// If table is now bound to another one, we ignore the content of the original
if (FFlag::LuauCloneBoundTables && t.boundTo)
{
TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
seenTypes[typeId] = boundTo;
return;
}
TypeId result = dest.addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
LUAU_ASSERT(ttv != nullptr);
@ -321,8 +330,11 @@ void TypeCloner::operator()(const TableTypeVar& t)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType),
clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)};
if (t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
if (!FFlag::LuauCloneBoundTables)
{
if (t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType);
@ -335,7 +347,7 @@ void TypeCloner::operator()(const TableTypeVar& t)
if (ttv->state == TableState::Free)
{
if (!t.boundTo)
if (FFlag::LuauCloneBoundTables || !t.boundTo)
{
if (encounteredFreeType)
*encounteredFreeType = true;

90
Analysis/src/Quantify.cpp Normal file
View file

@ -0,0 +1,90 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Quantify.h"
#include "Luau/VisitTypeVar.h"
namespace Luau
{
struct Quantifier
{
ModulePtr module;
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
Quantifier(ModulePtr module, TypeLevel level)
: module(module)
, level(level)
{
}
void cycle(TypeId) {}
void cycle(TypePackId) {}
bool operator()(TypeId ty, const FreeTypeVar& ftv)
{
if (!level.subsumes(ftv.level))
return false;
*asMutable(ty) = GenericTypeVar{level};
generics.push_back(ty);
return false;
}
template<typename T>
bool operator()(TypeId ty, const T& t)
{
return true;
}
template<typename T>
bool operator()(TypePackId, const T&)
{
return true;
}
bool operator()(TypeId ty, const TableTypeVar&)
{
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic)
return false;
if (!level.subsumes(ttv.level))
return false;
if (ttv.state == TableState::Free)
ttv.state = TableState::Generic;
else if (ttv.state == TableState::Unsealed)
ttv.state = TableState::Sealed;
ttv.level = level;
return true;
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
if (!level.subsumes(ftp.level))
return false;
*asMutable(tp) = GenericTypePack{level};
genericPacks.push_back(tp);
return true;
}
};
void quantify(ModulePtr module, TypeId ty, TypeLevel level)
{
Quantifier q{std::move(module), level};
visitTypeVar(ty, q);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
}
} // namespace Luau

View file

@ -182,7 +182,7 @@ struct RequireTracerOld : AstVisitor
struct RequireTracer : AstVisitor
{
RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName)
RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName)
: result(result)
, fileResolver(fileResolver)
, currentModuleName(currentModuleName)
@ -260,7 +260,7 @@ struct RequireTracer : AstVisitor
// seed worklist with require arguments
work.reserve(requires.size());
for (AstExprCall* require: requires)
for (AstExprCall* require : requires)
work.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal

View file

@ -10,7 +10,6 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauExtraNilRecovery)
LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions)
LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false)
LUAU_FASTFLAG(LuauTypeAliasPacks)
@ -159,15 +158,6 @@ struct StringifierState
seen.erase(iter);
}
static std::string generateName(size_t i)
{
std::string n;
n = char('a' + i % 26);
if (i >= 26)
n += std::to_string(i / 26);
return n;
}
std::string getName(TypeId ty)
{
const size_t s = result.nameMap.typeVars.size();
@ -584,8 +574,7 @@ struct TypeVarStringifier
std::vector<std::string> results = {};
for (auto el : &uv)
{
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
el = follow(el);
el = follow(el);
if (isNil(el))
{
@ -649,8 +638,7 @@ struct TypeVarStringifier
std::vector<std::string> results = {};
for (auto el : uv.parts)
{
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
el = follow(el);
el = follow(el);
std::string saved = std::move(state.result.name);
@ -1204,4 +1192,13 @@ void dump(TypePackId ty)
printf("%s\n", toString(ty, opts).c_str());
}
std::string generateName(size_t i)
{
std::string n;
n = char('a' + i % 26);
if (i >= 26)
n += std::to_string(i / 26);
return n;
}
} // namespace Luau

View file

@ -298,8 +298,15 @@ struct ArcCollector : public AstVisitor
struct ContainsFunctionCall : public AstVisitor
{
bool alsoReturn = false;
bool result = false;
ContainsFunctionCall() = default;
explicit ContainsFunctionCall(bool alsoReturn)
: alsoReturn(alsoReturn)
{
}
bool visit(AstExpr*) override
{
return !result; // short circuit if result is true
@ -318,6 +325,17 @@ struct ContainsFunctionCall : public AstVisitor
return false;
}
bool visit(AstStatReturn* stat) override
{
if (alsoReturn)
{
result = true;
return false;
}
else
return AstVisitor::visit(stat);
}
bool visit(AstExprFunction*) override
{
return false;
@ -479,6 +497,13 @@ bool containsFunctionCall(const AstStat& stat)
return cfc.result;
}
bool containsFunctionCallOrReturn(const AstStat& stat)
{
detail::ContainsFunctionCall cfc{true};
const_cast<AstStat&>(stat).visit(&cfc);
return cfc.result;
}
bool isFunction(const AstStat& stat)
{
return stat.is<AstStatFunction>() || stat.is<AstStatLocalFunction>();

View file

@ -5,6 +5,8 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false)
namespace Luau
{
@ -33,6 +35,12 @@ void TxnLog::rollback()
for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it)
std::swap(it->first->boundTo, it->second);
if (FFlag::LuauShareTxnSeen)
{
LUAU_ASSERT(originalSeenSize <= sharedSeen->size());
sharedSeen->resize(originalSeenSize);
}
}
void TxnLog::concat(TxnLog rhs)
@ -46,27 +54,44 @@ void TxnLog::concat(TxnLog rhs)
tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end());
rhs.tableChanges.clear();
seen.swap(rhs.seen);
rhs.seen.clear();
if (!FFlag::LuauShareTxnSeen)
{
ownedSeen.swap(rhs.ownedSeen);
rhs.ownedSeen.clear();
}
}
bool TxnLog::haveSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair));
if (FFlag::LuauShareTxnSeen)
return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair));
else
return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair));
}
void TxnLog::pushSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
seen.push_back(sortedPair);
if (FFlag::LuauShareTxnSeen)
sharedSeen->push_back(sortedPair);
else
ownedSeen.push_back(sortedPair);
}
void TxnLog::popSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
LUAU_ASSERT(sortedPair == seen.back());
seen.pop_back();
if (FFlag::LuauShareTxnSeen)
{
LUAU_ASSERT(sortedPair == sharedSeen->back());
sharedSeen->pop_back();
}
else
{
LUAU_ASSERT(sortedPair == ownedSeen.back());
ownedSeen.pop_back();
}
}
} // namespace Luau

View file

@ -6,6 +6,7 @@
#include "Luau/Parser.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -33,14 +34,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data
return result;
}
using SyntheticNames = std::unordered_map<const void*, char*>;
namespace Luau
{
static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen)
{
size_t s = syntheticNames->size();
char*& n = (*syntheticNames)[&gen];
if (!n)
{
std::string str = gen.explicitName ? gen.name : generateName(s);
n = static_cast<char*>(allocator->allocate(str.size() + 1));
strcpy(n, str.c_str());
}
return n;
}
class TypeRehydrationVisitor
{
mutable std::map<void*, int> seen;
mutable int count = 0;
std::map<void*, int> seen;
int count = 0;
bool hasSeen(const void* tv) const
bool hasSeen(const void* tv)
{
void* ttv = const_cast<void*>(tv);
auto it = seen.find(ttv);
@ -52,15 +70,16 @@ class TypeRehydrationVisitor
}
public:
TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions())
TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions())
: allocator(alloc)
, syntheticNames(syntheticNames)
, options(options)
{
}
AstTypePack* rehydrate(TypePackId tp) const;
AstTypePack* rehydrate(TypePackId tp);
AstType* operator()(const PrimitiveTypeVar& ptv) const
AstType* operator()(const PrimitiveTypeVar& ptv)
{
switch (ptv.type)
{
@ -78,11 +97,11 @@ public:
return nullptr;
}
}
AstType* operator()(const AnyTypeVar&) const
AstType* operator()(const AnyTypeVar&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"));
}
AstType* operator()(const TableTypeVar& ttv) const
AstType* operator()(const TableTypeVar& ttv)
{
RecursionCounter counter(&count);
@ -144,12 +163,12 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props, indexer);
}
AstType* operator()(const MetatableTypeVar& mtv) const
AstType* operator()(const MetatableTypeVar& mtv)
{
return Luau::visit(*this, mtv.table->ty);
}
AstType* operator()(const ClassTypeVar& ctv) const
AstType* operator()(const ClassTypeVar& ctv)
{
RecursionCounter counter(&count);
@ -176,7 +195,7 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props);
}
AstType* operator()(const FunctionTypeVar& ftv) const
AstType* operator()(const FunctionTypeVar& ftv)
{
RecursionCounter counter(&count);
@ -253,10 +272,12 @@ public:
size_t i = 0;
for (const auto& el : ftv.argNames)
{
std::optional<AstArgumentName>* arg = &argNames.data[i++];
if (el)
argNames.data[i++] = {AstName(el->name.c_str()), el->location};
new (arg) std::optional<AstArgumentName>(AstArgumentName(AstName(el->name.c_str()), el->location));
else
argNames.data[i++] = {};
new (arg) std::optional<AstArgumentName>();
}
AstArray<AstType*> returnTypes;
@ -290,23 +311,23 @@ public:
return allocator->alloc<AstTypeFunction>(
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation});
}
AstType* operator()(const Unifiable::Error&) const
AstType* operator()(const Unifiable::Error&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>"));
}
AstType* operator()(const GenericTypeVar& gtv) const
AstType* operator()(const GenericTypeVar& gtv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(gtv.name.c_str()));
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)));
}
AstType* operator()(const Unifiable::Bound<TypeId>& bound) const
AstType* operator()(const Unifiable::Bound<TypeId>& bound)
{
return Luau::visit(*this, bound.boundTo->ty);
}
AstType* operator()(Unifiable::Free ftv) const
AstType* operator()(const FreeTypeVar& ftv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"));
}
AstType* operator()(const UnionTypeVar& uv) const
AstType* operator()(const UnionTypeVar& uv)
{
AstArray<AstType*> unionTypes;
unionTypes.size = uv.options.size();
@ -317,7 +338,7 @@ public:
}
return allocator->alloc<AstTypeUnion>(Location(), unionTypes);
}
AstType* operator()(const IntersectionTypeVar& uv) const
AstType* operator()(const IntersectionTypeVar& uv)
{
AstArray<AstType*> intersectionTypes;
intersectionTypes.size = uv.parts.size();
@ -328,23 +349,28 @@ public:
}
return allocator->alloc<AstTypeIntersection>(Location(), intersectionTypes);
}
AstType* operator()(const LazyTypeVar& ltv) const
AstType* operator()(const LazyTypeVar& ltv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Lazy?>"));
}
private:
Allocator* allocator;
SyntheticNames* syntheticNames;
const TypeRehydrationOptions& options;
};
class TypePackRehydrationVisitor
{
public:
TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor)
TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor)
: allocator(allocator)
, syntheticNames(syntheticNames)
, typeVisitor(typeVisitor)
{
LUAU_ASSERT(allocator);
LUAU_ASSERT(syntheticNames);
LUAU_ASSERT(typeVisitor);
}
AstTypePack* operator()(const BoundTypePack& btp) const
@ -359,7 +385,7 @@ public:
head.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * tp.head.size()));
for (size_t i = 0; i < tp.head.size(); i++)
head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty);
head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty);
AstTypePack* tail = nullptr;
@ -371,12 +397,12 @@ public:
AstTypePack* operator()(const VariadicTypePack& vtp) const
{
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(typeVisitor, vtp.ty->ty));
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*typeVisitor, vtp.ty->ty));
}
AstTypePack* operator()(const GenericTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(gtp.name.c_str()));
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(getName(allocator, syntheticNames, gtp)));
}
AstTypePack* operator()(const FreeTypePack& gtp) const
@ -391,12 +417,13 @@ public:
private:
Allocator* allocator;
const TypeRehydrationVisitor& typeVisitor;
SyntheticNames* syntheticNames;
TypeRehydrationVisitor* typeVisitor;
};
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp)
{
TypePackRehydrationVisitor tprv(allocator, *this);
TypePackRehydrationVisitor tprv(allocator, syntheticNames, this);
return Luau::visit(tprv, tp->ty);
}
@ -431,7 +458,7 @@ public:
{
if (!type)
return nullptr;
return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty);
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty);
}
AstArray<Luau::AstType*> typeAstPack(TypePackId type)
@ -443,7 +470,7 @@ public:
result.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * v.size()));
for (size_t i = 0; i < v.size(); ++i)
{
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty);
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty);
}
return result;
}
@ -495,7 +522,7 @@ public:
{
if (FFlag::LuauTypeAliasPacks)
{
variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail);
variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail);
}
else
{
@ -515,6 +542,7 @@ public:
private:
Module& module;
Allocator* allocator;
SyntheticNames syntheticNames;
};
void attachTypeData(SourceModule& source, Module& result)
@ -525,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result)
AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options)
{
return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty);
SyntheticNames syntheticNames;
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty);
}
} // namespace Luau

View file

@ -4,6 +4,7 @@
#include "Luau/Common.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Parser.h"
#include "Luau/Quantify.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/Substitution.h"
@ -33,18 +34,16 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false)
LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false)
LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false)
LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false)
LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false)
LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false)
LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false)
LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false)
LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false)
LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false)
LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false)
LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false)
LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false)
LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false)
LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false)
LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes)
LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false)
LUAU_FASTFLAG(LuauNewRequireTrace)
LUAU_FASTFLAG(LuauTypeAliasPacks)
@ -215,6 +214,7 @@ static bool isMetamethod(const Name& name)
TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler)
: resolver(resolver)
, iceHandler(iceHandler)
, unifierState(iceHandler)
, nilType(singletonTypes.nilType)
, numberType(singletonTypes.numberType)
, stringType(singletonTypes.stringType)
@ -370,13 +370,18 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
return;
}
int subLevel = 0;
std::vector<AstStat*> sorted(block.body.data, block.body.data + block.body.size);
toposort(sorted);
for (const auto& stat : sorted)
{
if (const auto& typealias = stat->as<AstStatTypeAlias>())
check(scope, *typealias, true);
{
check(scope, *typealias, subLevel, true);
++subLevel;
}
}
auto protoIter = sorted.begin();
@ -399,8 +404,6 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
}
};
int subLevel = 0;
while (protoIter != sorted.end())
{
// protoIter walks forward
@ -433,7 +436,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
// function f<a>(x:a):a local x: number = g(37) return x end
// function g(x:number):number return f(x) end
// ```
if (containsFunctionCall(**protoIter))
if (FFlag::LuauQuantifyInPlace2 ? containsFunctionCallOrReturn(**protoIter) : containsFunctionCall(**protoIter))
{
while (checkIter != protoIter)
{
@ -1161,7 +1164,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco
scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location};
}
void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare)
void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare)
{
// This function should be called at most twice for each type alias.
// Once with forwardDeclare, and once without.
@ -1189,11 +1192,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
}
else
{
ScopePtr aliasScope = childScope(scope, typealias.location);
ScopePtr aliasScope =
FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location);
if (FFlag::LuauTypeAliasPacks)
{
auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks);
auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks);
TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true));
FreeTypeVar* ftv = getMutable<FreeTypeVar>(ty);
@ -1418,7 +1422,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo
{
ScopePtr funScope = childFunctionScope(scope, global.location);
auto [generics, genericPacks] = createGenericTypes(funScope, global, global.generics, global.genericPacks);
auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks);
TypePackId argPack = resolveTypePack(funScope, global.params);
TypePackId retPack = resolveTypePack(funScope, global.retTypes);
@ -1610,25 +1614,11 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn
if (std::optional<TypeId> ty = resolveLValue(scope, *lvalue))
return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}};
if (FFlag::LuauExtraNilRecovery)
lhsType = stripFromNilAndReport(lhsType, expr.expr->location);
lhsType = stripFromNilAndReport(lhsType, expr.expr->location);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true))
return {*ty};
if (!FFlag::LuauMissingUnionPropertyError)
reportError(expr.indexLocation, UnknownProperty{lhsType, expr.index.value});
if (!FFlag::LuauExtraNilRecovery)
{
// Try to recover using a union without 'nil' options
if (std::optional<TypeId> strippedUnion = tryStripUnionFromNil(lhsType))
{
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, *strippedUnion, name, expr.location, false))
return {*ty};
}
}
return {errorType};
}
@ -1694,61 +1684,37 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
}
else if (const UnionTypeVar* utv = get<UnionTypeVar>(type))
{
if (FFlag::LuauMissingUnionPropertyError)
std::vector<TypeId> goodOptions;
std::vector<TypeId> badOptions;
for (TypeId t : utv)
{
std::vector<TypeId> goodOptions;
std::vector<TypeId> badOptions;
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
for (TypeId t : utv)
{
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
goodOptions.push_back(*ty);
else
badOptions.push_back(t);
}
if (!badOptions.empty())
{
if (addErrors)
{
if (goodOptions.empty())
reportError(location, UnknownProperty{type, name});
else
reportError(location, MissingUnionProperty{type, badOptions, name});
}
return std::nullopt;
}
std::vector<TypeId> result = reduceUnion(goodOptions);
if (result.size() == 1)
return result[0];
return addType(UnionTypeVar{std::move(result)});
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
goodOptions.push_back(*ty);
else
badOptions.push_back(t);
}
else
if (!badOptions.empty())
{
std::vector<TypeId> options;
for (TypeId t : utv->options)
if (addErrors)
{
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
options.push_back(*ty);
if (goodOptions.empty())
reportError(location, UnknownProperty{type, name});
else
return std::nullopt;
reportError(location, MissingUnionProperty{type, badOptions, name});
}
std::vector<TypeId> result = reduceUnion(options);
if (result.size() == 1)
return result[0];
return addType(UnionTypeVar{std::move(result)});
return std::nullopt;
}
std::vector<TypeId> result = reduceUnion(goodOptions);
if (result.size() == 1)
return result[0];
return addType(UnionTypeVar{std::move(result)});
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type))
{
@ -1765,7 +1731,7 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
// If no parts of the intersection had the property we looked up for, it never existed at all.
if (parts.empty())
{
if (FFlag::LuauMissingUnionPropertyError && addErrors)
if (addErrors)
reportError(location, UnknownProperty{type, name});
return std::nullopt;
}
@ -1779,7 +1745,7 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
return addType(IntersectionTypeVar{result});
}
if (FFlag::LuauMissingUnionPropertyError && addErrors)
if (addErrors)
reportError(location, UnknownProperty{type, name});
return std::nullopt;
@ -2062,8 +2028,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn
case AstExprUnary::Len:
tablify(operandType);
if (FFlag::LuauExtraNilRecovery)
operandType = stripFromNilAndReport(operandType, expr.location);
operandType = stripFromNilAndReport(operandType, expr.location);
if (get<ErrorTypeVar>(operandType))
return {errorType};
@ -2635,8 +2600,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
Name name = expr.index.value;
if (FFlag::LuauExtraNilRecovery)
lhs = stripFromNilAndReport(lhs, expr.expr->location);
lhs = stripFromNilAndReport(lhs, expr.expr->location);
if (TableTypeVar* lhsTable = getMutableTableType(lhs))
{
@ -2710,8 +2674,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
TypeId exprType = checkExpr(scope, *expr.expr).type;
tablify(exprType);
if (FFlag::LuauExtraNilRecovery)
exprType = stripFromNilAndReport(exprType, expr.expr->location);
exprType = stripFromNilAndReport(exprType, expr.expr->location);
TypeId indexType = checkExpr(scope, *expr.index).type;
@ -2738,10 +2701,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!exprTable)
{
if (FFlag::LuauExtraNilRecovery)
reportError(TypeError{expr.expr->location, NotATable{exprType}});
else
reportError(TypeError{expr.location, NotATable{exprType}});
reportError(TypeError{expr.expr->location, NotATable{exprType}});
return std::pair(errorType, nullptr);
}
@ -2910,7 +2870,7 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
if (FFlag::LuauGenericFunctions)
{
std::tie(generics, genericPacks) = createGenericTypes(funScope, expr, expr.generics, expr.genericPacks);
std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks);
}
TypePackId retPack;
@ -3016,9 +2976,6 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
if (expectedArgsCurr != expectedArgsEnd)
{
argType = *expectedArgsCurr;
if (!FFlag::LuauInferFunctionArgsFix)
++expectedArgsCurr;
}
else if (auto expectedArgsTail = expectedArgsCurr.tail())
{
@ -3034,7 +2991,7 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
funScope->bindings[local] = {argType, local->location};
argTypes.push_back(argType);
if (FFlag::LuauInferFunctionArgsFix && expectedArgsCurr != expectedArgsEnd)
if (expectedArgsCurr != expectedArgsEnd)
++expectedArgsCurr;
}
@ -3402,8 +3359,7 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
if (!FFlag::LuauRankNTypes)
instantiate(scope, selfType, expr.func->location);
if (FFlag::LuauExtraNilRecovery)
selfType = stripFromNilAndReport(selfType, expr.func->location);
selfType = stripFromNilAndReport(selfType, expr.func->location);
if (std::optional<TypeId> propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true))
{
@ -3412,34 +3368,8 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
}
else
{
if (!FFlag::LuauMissingUnionPropertyError)
reportError(indexExpr->indexLocation, UnknownProperty{selfType, indexExpr->index.value});
if (!FFlag::LuauExtraNilRecovery)
{
// Try to recover using a union without 'nil' options
if (std::optional<TypeId> strippedUnion = tryStripUnionFromNil(selfType))
{
if (std::optional<TypeId> propTy = getIndexTypeFromType(scope, *strippedUnion, indexExpr->index.value, expr.location, false))
{
selfType = *strippedUnion;
functionType = *propTy;
actualFunctionType = instantiate(scope, functionType, expr.func->location);
}
}
if (!actualFunctionType)
{
functionType = errorType;
actualFunctionType = errorType;
}
}
else
{
functionType = errorType;
actualFunctionType = errorType;
}
functionType = errorType;
actualFunctionType = errorType;
}
}
else
@ -3555,8 +3485,7 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<OverloadErrorEntry>& errors)
{
if (FFlag::LuauExtraNilRecovery)
fn = stripFromNilAndReport(fn, expr.func->location);
fn = stripFromNilAndReport(fn, expr.func->location);
if (get<AnyTypeVar>(fn))
{
@ -4283,6 +4212,12 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty())
return ty;
if (FFlag::LuauQuantifyInPlace2)
{
Luau::quantify(currentModule, ty, scope->level);
return ty;
}
quantification.level = scope->level;
quantification.generics.clear();
quantification.genericPacks.clear();
@ -4491,12 +4426,12 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r)
Unifier TypeChecker::mkUnifier(const Location& location)
{
return Unifier{&currentModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, iceHandler};
return Unifier{&currentModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState};
}
Unifier TypeChecker::mkUnifier(const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location)
{
return Unifier{&currentModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, iceHandler};
return Unifier{&currentModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState};
}
TypeId TypeChecker::freshType(const ScopePtr& scope)
@ -4753,7 +4688,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (FFlag::LuauGenericFunctions)
{
std::tie(generics, genericPacks) = createGenericTypes(funcScope, annotation, func->generics, func->genericPacks);
std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks);
}
// TODO: better error message CLI-39912
@ -5041,10 +4976,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
}
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGenericTypes(
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
{
LUAU_ASSERT(scope->parent);
const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level;
std::vector<TypeId> generics;
for (const AstName& generic : genericNames)
{
@ -5063,12 +5000,12 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
{
TypeId& cached = scope->parent->typeAliasTypeParameters[n];
if (!cached)
cached = addType(GenericTypeVar{scope->level, n});
cached = addType(GenericTypeVar{level, n});
g = cached;
}
else
{
g = addType(Unifiable::Generic{scope->level, n});
g = addType(Unifiable::Generic{level, n});
}
generics.push_back(g);
@ -5093,12 +5030,12 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
{
TypePackId& cached = scope->parent->typeAliasTypePackParameters[n];
if (!cached)
cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}});
cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}});
g = cached;
}
else
{
g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}});
g = addTypePack(TypePackVar{Unifiable::Generic{level, n}});
}
genericPacks.push_back(g);

View file

@ -22,6 +22,7 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false)
namespace Luau
{
@ -217,8 +218,7 @@ std::optional<TypeId> getMetatable(TypeId type)
return mtType->metatable;
else if (const ClassTypeVar* classType = get<ClassTypeVar>(type))
return classType->metatable;
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type);
primitiveType && primitiveType->metatable)
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type); primitiveType && primitiveType->metatable)
{
LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String);
return primitiveType->metatable;
@ -1490,4 +1490,86 @@ std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate)
return {};
}
static Tags* getTags(TypeId ty)
{
ty = follow(ty);
if (auto ftv = getMutable<FunctionTypeVar>(ty))
return &ftv->tags;
else if (auto ttv = getMutable<TableTypeVar>(ty))
return &ttv->tags;
else if (auto ctv = getMutable<ClassTypeVar>(ty))
return &ctv->tags;
return nullptr;
}
void attachTag(TypeId ty, const std::string& tagName)
{
if (!FFlag::LuauRefactorTagging)
{
if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->tags.emplace_back(tagName);
}
else
{
LUAU_ASSERT(!"Got a non functional type");
}
}
else
{
if (auto tags = getTags(ty))
tags->push_back(tagName);
else
LUAU_ASSERT(!"This TypeId does not support tags");
}
}
void attachTag(Property& prop, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
prop.tags.push_back(tagName);
}
// We would ideally not expose this because it could cause a footgun.
// If the Base class has a tag and you ask if Derived has that tag, it would return false.
// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it.
bool hasTag(const Tags& tags, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
return std::find(tags.begin(), tags.end(), tagName) != tags.end();
}
bool hasTag(TypeId ty, const std::string& tagName)
{
ty = follow(ty);
// We special case classes because getTags only returns a pointer to one vector of tags.
// But classes has multiple vector of tags, represented throughout the hierarchy.
if (auto ctv = get<ClassTypeVar>(ty))
{
while (ctv)
{
if (hasTag(ctv->tags, tagName))
return true;
else if (!ctv->parent)
return false;
ctv = get<ClassTypeVar>(*ctv->parent);
LUAU_ASSERT(ctv);
}
}
else if (auto tags = getTags(ty))
return hasTag(*tags, tagName);
return false;
}
bool hasTag(const Property& prop, const std::string& tagName)
{
return hasTag(prop.tags, tagName);
}
} // namespace Luau

View file

@ -7,6 +7,7 @@
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/VisitTypeVar.h"
#include <algorithm>
@ -22,9 +23,99 @@ LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false)
LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false)
LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false)
LUAU_FASTFLAG(LuauShareTxnSeen);
LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false)
namespace Luau
{
struct SkipCacheForType
{
SkipCacheForType(const DenseHashMap<TypeId, bool>& skipCacheForType)
: skipCacheForType(skipCacheForType)
{
}
void cycle(TypeId) {}
void cycle(TypePackId) {}
bool operator()(TypeId ty, const FreeTypeVar& ftv)
{
result = true;
return false;
}
bool operator()(TypeId ty, const BoundTypeVar& btv)
{
result = true;
return false;
}
bool operator()(TypeId ty, const GenericTypeVar& btv)
{
result = true;
return false;
}
bool operator()(TypeId ty, const TableTypeVar&)
{
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (ttv.boundTo)
{
result = true;
return false;
}
if (ttv.state != TableState::Sealed)
{
result = true;
return false;
}
return true;
}
template<typename T>
bool operator()(TypeId ty, const T& t)
{
const bool* prev = skipCacheForType.find(ty);
if (prev && *prev)
{
result = true;
return false;
}
return true;
}
template<typename T>
bool operator()(TypePackId, const T&)
{
return true;
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
result = true;
return false;
}
bool operator()(TypePackId tp, const BoundTypePack& ftp)
{
result = true;
return false;
}
bool operator()(TypePackId tp, const GenericTypePack& ftp)
{
result = true;
return false;
}
const DenseHashMap<TypeId, bool>& skipCacheForType;
bool result = false;
};
static std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
{
@ -39,7 +130,7 @@ static std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
return *it;
}
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler)
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState)
: types(types)
, mode(mode)
, globalScope(std::move(globalScope))
@ -47,24 +138,39 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati
, variance(variance)
, counters(&countersData)
, counters_DEPRECATED(std::make_shared<UnifierCounters>())
, iceHandler(iceHandler)
, sharedState(sharedState)
{
LUAU_ASSERT(iceHandler);
LUAU_ASSERT(sharedState.iceHandler);
}
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& ownedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
: types(types)
, mode(mode)
, globalScope(std::move(globalScope))
, log(seen)
, log(ownedSeen)
, location(location)
, variance(variance)
, counters(counters ? counters : &countersData)
, counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared<UnifierCounters>())
, iceHandler(iceHandler)
, sharedState(sharedState)
{
LUAU_ASSERT(iceHandler);
LUAU_ASSERT(sharedState.iceHandler);
}
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector<std::pair<TypeId, TypeId>>* sharedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
: types(types)
, mode(mode)
, globalScope(std::move(globalScope))
, log(sharedSeen)
, location(location)
, variance(variance)
, counters(counters ? counters : &countersData)
, counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared<UnifierCounters>())
, sharedState(sharedState)
{
LUAU_ASSERT(sharedState.iceHandler);
}
void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection)
@ -74,7 +180,7 @@ void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool i
else
counters_DEPRECATED->iterationCount = 0;
return tryUnify_(superTy, subTy, isFunctionCall, isIntersection);
tryUnify_(superTy, subTy, isFunctionCall, isIntersection);
}
void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection)
@ -206,6 +312,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
if (get<ErrorTypeVar>(subTy) || get<AnyTypeVar>(subTy))
return tryUnifyWithAny(subTy, superTy);
bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection;
auto& cache = sharedState.cachedUnify;
// What if the types are immutable and we proved their relation before
if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy})))
return;
// If we have seen this pair of types before, we are currently recursing into cyclic types.
// Here, we assume that the types unify. If they do not, we will find out as we roll back
// the stack.
@ -257,6 +370,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
if (FFlag::LuauUnionHeuristic)
{
bool found = false;
const std::string* subName = getName(subTy);
if (subName)
{
@ -264,6 +379,21 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
{
const std::string* optionName = getName(uv->options[i]);
if (optionName && *optionName == *subName)
{
found = true;
startIndex = i;
break;
}
}
}
if (!found && cacheEnabled)
{
for (size_t i = 0; i < uv->options.size(); ++i)
{
TypeId type = uv->options[i];
if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type})))
{
startIndex = i;
break;
@ -311,8 +441,25 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
bool found = false;
std::optional<TypeError> unificationTooComplex;
for (TypeId type : uv->parts)
size_t startIndex = 0;
if (cacheEnabled)
{
for (size_t i = 0; i < uv->parts.size(); ++i)
{
TypeId type = uv->parts[i];
if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy})))
{
startIndex = i;
break;
}
}
}
for (size_t i = 0; i < uv->parts.size(); ++i)
{
TypeId type = uv->parts[(i + startIndex) % uv->parts.size()];
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(superTy, type, isFunctionCall);
@ -342,8 +489,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
tryUnifyFunctions(superTy, subTy, isFunctionCall);
else if (get<TableTypeVar>(superTy) && get<TableTypeVar>(subTy))
{
tryUnifyTables(superTy, subTy, isIntersection);
if (cacheEnabled && errors.empty())
cacheResult(superTy, subTy);
}
// tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical.
else if (get<MetatableTypeVar>(superTy))
tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false);
@ -364,6 +516,41 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
log.popSeen(superTy, subTy);
}
void Unifier::cacheResult(TypeId superTy, TypeId subTy)
{
LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults);
bool* superTyInfo = sharedState.skipCacheForType.find(superTy);
if (superTyInfo && *superTyInfo)
return;
bool* subTyInfo = sharedState.skipCacheForType.find(subTy);
if (subTyInfo && *subTyInfo)
return;
auto skipCacheFor = [this](TypeId ty) {
SkipCacheForType visitor{sharedState.skipCacheForType};
visitTypeVarOnce(ty, visitor, sharedState.seenAny);
sharedState.skipCacheForType[ty] = visitor.result;
return visitor.result;
};
if (!superTyInfo && skipCacheFor(superTy))
return;
if (!subTyInfo && skipCacheFor(subTy))
return;
sharedState.cachedUnify.insert({superTy, subTy});
if (variance == Invariant)
sharedState.cachedUnify.insert({subTy, superTy});
}
struct WeirdIter
{
TypePackId packId;
@ -459,7 +646,7 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall
else
counters_DEPRECATED->iterationCount = 0;
return tryUnify_(superTp, subTp, isFunctionCall);
tryUnify_(superTp, subTp, isFunctionCall);
}
/*
@ -797,6 +984,40 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
std::vector<std::string> missingProperties;
std::vector<std::string> extraProperties;
// Optimization: First test that the property sets are compatible without doing any recursive unification
if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free)
{
for (const auto& [propName, superProp] : lt->props)
{
auto subIter = rt->props.find(propName);
if (subIter == rt->props.end() && !isOptional(superProp.type) && !get<AnyTypeVar>(follow(superProp.type)))
missingProperties.push_back(propName);
}
if (!missingProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}});
return;
}
}
// And vice versa if we're invariant
if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free)
{
for (const auto& [propName, subProp] : rt->props)
{
auto superIter = lt->props.find(propName);
if (superIter == lt->props.end() && !isOptional(subProp.type) && !get<AnyTypeVar>(follow(subProp.type)))
extraProperties.push_back(propName);
}
if (!extraProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}});
return;
}
}
// Reminder: left is the supertype, right is the subtype.
// Width subtyping: any property in the supertype must be in the subtype,
// and the types must agree.
@ -833,9 +1054,10 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
innerState.log.rollback();
}
else if (isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type)))
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{}
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{
}
else if (rt->state == TableState::Free)
{
log(rt);
@ -878,11 +1100,13 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
lt->props[name] = clone;
}
else if (variance == Covariant)
{}
{
}
else if (isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type)))
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{}
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{
}
else if (lt->state == TableState::Free)
{
log(lt);
@ -980,10 +1204,10 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> see
TableTypeVar* resultTtv = getMutable<TableTypeVar>(result);
for (auto& [name, prop] : resultTtv->props)
prop.type = deeplyOptional(prop.type, seen);
return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});;
return types->addType(UnionTypeVar{{singletonTypes.nilType, result}});
}
else
return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }});
return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}});
}
void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
@ -1697,10 +1921,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty)
{
std::vector<TypeId> queue = {ty};
tempSeenTy.clear();
tempSeenTp.clear();
if (FFlag::LuauCacheUnifyTableResults)
{
sharedState.tempSeenTy.clear();
sharedState.tempSeenTp.clear();
Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP);
Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP);
}
else
{
tempSeenTy_DEPRECATED.clear();
tempSeenTp_DEPRECATED.clear();
Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP);
}
}
else
{
@ -1721,12 +1955,24 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty)
{
std::vector<TypeId> queue;
tempSeenTy.clear();
tempSeenTp.clear();
if (FFlag::LuauCacheUnifyTableResults)
{
sharedState.tempSeenTy.clear();
sharedState.tempSeenTp.clear();
queueTypePack(queue, tempSeenTp, *this, ty, any);
queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any);
Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any);
Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any);
}
else
{
tempSeenTy_DEPRECATED.clear();
tempSeenTp_DEPRECATED.clear();
queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any);
Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any);
}
}
else
{
@ -1775,10 +2021,20 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack)
{
std::unordered_set<TypeId> seen_DEPRECATED;
if (FFlag::LuauTypecheckOpts)
tempSeenTy.clear();
if (FFlag::LuauCacheUnifyTableResults)
{
if (FFlag::LuauTypecheckOpts)
sharedState.tempSeenTy.clear();
return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack);
return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack);
}
else
{
if (FFlag::LuauTypecheckOpts)
tempSeenTy_DEPRECATED.clear();
return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack);
}
}
void Unifier::occursCheck(std::unordered_set<TypeId>& seen_DEPRECATED, DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack)
@ -1851,10 +2107,20 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack)
{
std::unordered_set<TypePackId> seen_DEPRECATED;
if (FFlag::LuauTypecheckOpts)
tempSeenTp.clear();
if (FFlag::LuauCacheUnifyTableResults)
{
if (FFlag::LuauTypecheckOpts)
sharedState.tempSeenTp.clear();
return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack);
return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack);
}
else
{
if (FFlag::LuauTypecheckOpts)
tempSeenTp_DEPRECATED.clear();
return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack);
}
}
void Unifier::occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack)
@ -1922,7 +2188,10 @@ void Unifier::occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, Dense
Unifier Unifier::makeChildUnifier()
{
return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters};
if (FFlag::LuauShareTxnSeen)
return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters};
else
return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters};
}
bool Unifier::isNonstrictMode() const
@ -1940,12 +2209,12 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId
void Unifier::ice(const std::string& message, const Location& location)
{
iceHandler->ice(message, location);
sharedState.iceHandler->ice(message, location);
}
void Unifier::ice(const std::string& message)
{
iceHandler->ice(message);
sharedState.iceHandler->ice(message);
}
} // namespace Luau

View file

@ -194,20 +194,20 @@ LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeDa
} // namespace Luau
// Regular scope
#define LUAU_TIMETRACE_SCOPE(name, category) \
#define LUAU_TIMETRACE_SCOPE(name, category) \
static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first)
// A scope without nested scopes that may be skipped if the time it took is less than the threshold
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec)
// Extra key/value data can be added to regular scopes
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
if (FFlag::DebugLuauTimeTracing) \
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
if (FFlag::DebugLuauTimeTracing) \
lttScopeStatic.second.eventArgument(name, value); \
} while (false)
@ -216,8 +216,8 @@ LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeDa
#define LUAU_TIMETRACE_SCOPE(name, category)
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec)
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
do \
{ \
} while (false)
#endif

View file

@ -2379,15 +2379,12 @@ AstArray<AstTypeOrPack> Parser::parseTypeParams()
Lexeme begin = lexer.current();
nextLexeme();
bool seenPack = false;
while (true)
{
if (FFlag::LuauParseTypePackTypeParameters)
{
if (shouldParseTypePackAnnotation(lexer))
{
seenPack = true;
auto typePack = parseTypePackAnnotation();
if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them
@ -2399,8 +2396,6 @@ AstArray<AstTypeOrPack> Parser::parseTypeParams()
if (typePack)
{
seenPack = true;
if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them
parameters.push_back({{}, typePack});
}

View file

@ -77,7 +77,10 @@ struct GlobalContext
// Ideally we would want all ThreadContext destructors to run
// But in VS, not all thread_local object instances are destroyed
for (ThreadContext* context : threads)
context->flushEvents();
{
if (!context->events.empty())
context->flushEvents();
}
if (traceFile)
fclose(traceFile);

View file

@ -34,8 +34,10 @@ static void report(ReportFormat format, const char* name, const Luau::Location&
}
}
static void reportError(ReportFormat format, const char* name, const Luau::TypeError& error)
static void reportError(ReportFormat format, const Luau::TypeError& error)
{
const char* name = error.moduleName.c_str();
if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data))
report(format, name, error.location, "SyntaxError", syntaxError->message.c_str());
else
@ -49,7 +51,10 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin
static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate)
{
Luau::CheckResult cr = frontend.check(name);
Luau::CheckResult cr;
if (frontend.isDirty(name))
cr = frontend.check(name);
if (!frontend.getSourceModule(name))
{
@ -58,7 +63,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat
}
for (auto& error : cr.errors)
reportError(format, name, error);
reportError(format, error);
Luau::LintResult lr = frontend.lint(name);
@ -115,7 +120,12 @@ struct CliFileResolver : Luau::FileResolver
{
if (Luau::AstExprConstantString* expr = node->as<Luau::AstExprConstantString>())
{
Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua";
Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau";
if (!moduleExists(name))
{
// fall back to .lua if a module with .luau doesn't exist
name = std::string(expr->value.data, expr->value.size) + ".lua";
}
return {{name}};
}
@ -236,8 +246,15 @@ int main(int argc, char** argv)
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
// Look for .luau first and if absent, fall back to .lua
if (name.length() > 5 && name.rfind(".luau") == name.length() - 5)
{
failed += !analyzeFile(frontend, name.c_str(), format, annotate);
}
else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
{
failed += !analyzeFile(frontend, name.c_str(), format, annotate);
}
});
}
else
@ -256,5 +273,3 @@ int main(int argc, char** argv)
return (format == ReportFormat::Luacheck) ? 0 : failed;
}

View file

@ -51,9 +51,13 @@ static int lua_require(lua_State* L)
return finishrequire(L);
lua_pop(L, 1);
std::optional<std::string> source = readFile(name + ".lua");
std::optional<std::string> source = readFile(name + ".luau");
if (!source)
luaL_argerrorL(L, 1, ("error loading " + name).c_str());
{
source = readFile(name + ".lua"); // try .lua if .luau doesn't exist
if (!source)
luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error
}
// module needs to run in a new thread, isolated from the rest
lua_State* GL = lua_mainthread(L);
@ -183,6 +187,11 @@ static std::string runCode(lua_State* L, const std::string& source)
error += "\nstack backtrace:\n";
error += lua_debugtrace(T);
#ifdef __EMSCRIPTEN__
// nicer formatting for errors in web repl
error = "Error:" + error;
#endif
fprintf(stdout, "%s", error.c_str());
}
@ -190,6 +199,41 @@ static std::string runCode(lua_State* L, const std::string& source)
return std::string();
}
#ifdef __EMSCRIPTEN__
extern "C"
{
const char* executeScript(const char* source)
{
// setup flags
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
if (strncmp(flag->name, "Luau", 4) == 0)
flag->value = true;
// create new state
std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get();
// setup state
setupState(L);
// sandbox thread
luaL_sandboxthread(L);
// run code + collect error
std::string error = runCode(L, source);
// output error(s)
if (error.length())
{
return std::move(error.c_str());
}
return NULL;
}
}
#endif
// Excluded from emscripten compilation to avoid -Wunused-function errors.
#ifndef __EMSCRIPTEN__
static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector<std::string>& completions)
{
std::string_view lookup = editBuffer + start;
@ -511,5 +555,4 @@ int main(int argc, char** argv)
return failed;
}
}
#endif

View file

@ -17,17 +17,26 @@ add_library(Luau.VM STATIC)
if(LUAU_BUILD_CLI)
add_executable(Luau.Repl.CLI)
add_executable(Luau.Analyze.CLI)
if(NOT EMSCRIPTEN)
add_executable(Luau.Analyze.CLI)
else()
# add -fexceptions for emscripten to allow exceptions to be caught in C++
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions")
endif()
# This also adds target `name` on Linux/macOS and `name.exe` on Windows
set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau)
set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze)
if(NOT EMSCRIPTEN)
set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze)
endif()
endif()
if(LUAU_BUILD_TESTS)
if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN)
add_executable(Luau.UnitTest)
add_executable(Luau.Conformance)
endif()
include(Sources.cmake)
target_compile_features(Luau.Ast PUBLIC cxx_std_17)
@ -53,10 +62,6 @@ if(MSVC)
else()
list(APPEND LUAU_OPTIONS -Wall) # All warnings
list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
list(APPEND LUAU_OPTIONS -Wno-unused) # GCC considers variables declared/checked in if() as unused
endif()
endif()
target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS})
@ -65,7 +70,10 @@ target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS})
if(LUAU_BUILD_CLI)
target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS})
target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS})
if(NOT EMSCRIPTEN)
target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS})
endif()
target_include_directories(Luau.Repl.CLI PRIVATE extern)
target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM)
@ -74,10 +82,20 @@ if(LUAU_BUILD_CLI)
target_link_libraries(Luau.Repl.CLI PRIVATE pthread)
endif()
target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis)
if(NOT EMSCRIPTEN)
target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis)
endif()
if(EMSCRIPTEN)
# declare exported functions to emscripten
target_link_options(Luau.Repl.CLI PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -fexceptions)
# custom output directory for wasm + js file
set_target_properties(Luau.Repl.CLI PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/docs/assets/luau)
endif()
endif()
if(LUAU_BUILD_TESTS)
if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN)
target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS})
target_include_directories(Luau.UnitTest PRIVATE extern)
target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler)

View file

@ -2465,9 +2465,10 @@ struct Compiler
}
else if (node->is<AstStatBreak>())
{
LUAU_ASSERT(!loops.empty());
// before exiting out of the loop, we need to close all local variables that were captured in closures since loop start
// normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here
LUAU_ASSERT(!loops.empty());
closeLocals(loops.back().localOffset);
size_t label = bytecode.emitLabel();
@ -2478,12 +2479,13 @@ struct Compiler
}
else if (AstStatContinue* stat = node->as<AstStatContinue>())
{
LUAU_ASSERT(!loops.empty());
if (loops.back().untilCondition)
validateContinueUntil(stat, loops.back().untilCondition);
// before continuing, we need to close all local variables that were captured in closures since loop start
// normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here
LUAU_ASSERT(!loops.empty());
closeLocals(loops.back().localOffset);
size_t label = bytecode.emitLabel();
@ -2900,6 +2902,11 @@ struct Compiler
break;
case AstExprUnary::Len:
if (arg.type == Constant::Type_String)
{
result.type = Constant::Type_Number;
result.valueNumber = double(arg.valueString.size);
}
break;
default:

View file

@ -1,4 +1,5 @@
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
.SUFFIXES:
MAKEFLAGS+=-r -j8
COMMA=,
@ -48,7 +49,10 @@ OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(T
CXXFLAGS=-g -Wall -Werror
LDFLAGS=
CXXFLAGS+=-Wno-unused # temporary, for older gcc versions
# temporary, for older gcc versions as they treat var in `if (type var = val)` as unused
ifeq ($(findstring g++,$(shell $(CXX) --version)),g++)
CXXFLAGS+=-Wno-unused
endif
# configuration-specific flags
ifeq ($(config),release)

View file

@ -22,7 +22,7 @@ You can download the binaries from [a recent release](https://github.com/Roblox/
# Building
To build Luau tools or tests yourself, you can use CMake on all platforms, or alternatively make (on Linux/macOS). For example:
To build Luau tools or tests yourself, you can use CMake on all platforms:
```sh
mkdir cmake && cd cmake
@ -31,6 +31,12 @@ cmake --build . --target Luau.Repl.CLI --config RelWithDebInfo
cmake --build . --target Luau.Analyze.CLI --config RelWithDebInfo
```
Alternatively, on Linus/macOS you can use make:
```sh
make config=release luau luau-analyze
```
To integrate Luau into your CMake application projects, at the minimum you'll need to depend on `Luau.Compiler` and `Luau.VM` projects. From there you need to create a new Luau state (using Lua 5.x API such as `lua_newstate`), compile source to bytecode and load it into the VM like this:
```cpp

View file

@ -46,6 +46,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Module.h
Analysis/include/Luau/ModuleResolver.h
Analysis/include/Luau/Predicate.h
Analysis/include/Luau/Quantify.h
Analysis/include/Luau/RecursionCounter.h
Analysis/include/Luau/RequireTracer.h
Analysis/include/Luau/Scope.h
@ -63,6 +64,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/TypeVar.h
Analysis/include/Luau/Unifiable.h
Analysis/include/Luau/Unifier.h
Analysis/include/Luau/UnifierSharedState.h
Analysis/include/Luau/Variant.h
Analysis/include/Luau/VisitTypeVar.h
@ -77,6 +79,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Linter.cpp
Analysis/src/Module.cpp
Analysis/src/Predicate.cpp
Analysis/src/Quantify.cpp
Analysis/src/RequireTracer.cpp
Analysis/src/Scope.cpp
Analysis/src/Substitution.cpp

View file

@ -13,8 +13,6 @@
#include <string.h>
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n"
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
"$URL: www.lua.org $\n";
@ -1153,7 +1151,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*))
luaC_checkGC(L);
luaC_checkthreadsleep(L);
Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR);
memcpy(u->data + sz, &dtor, sizeof(dtor));
memcpy(&u->data + sz, &dtor, sizeof(dtor));
setuvalue(L, L->top, u);
api_incr_top(L);
return u->data;

View file

@ -5,8 +5,6 @@
#include "lstate.h"
#include "lvm.h"
LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
#define CO_RUN 0 /* running */
#define CO_SUS 1 /* suspended */
#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */
@ -17,7 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
static const char* const statnames[] = {"running", "suspended", "normal", "dead"};
static int costatus(lua_State* L, lua_State* co)
static int auxstatus(lua_State* L, lua_State* co)
{
if (co == L)
return CO_RUN;
@ -34,11 +32,11 @@ static int costatus(lua_State* L, lua_State* co)
return CO_SUS; /* initial state */
}
static int luaB_costatus(lua_State* L)
static int costatus(lua_State* L)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
lua_pushstring(L, statnames[costatus(L, co)]);
lua_pushstring(L, statnames[auxstatus(L, co)]);
return 1;
}
@ -47,7 +45,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg)
// error handling for edge cases
if (co->status != LUA_YIELD)
{
int status = costatus(L, co);
int status = auxstatus(L, co);
if (status != CO_SUS)
{
lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]);
@ -115,7 +113,7 @@ static int auxresumecont(lua_State* L, lua_State* co)
}
}
static int luaB_coresumefinish(lua_State* L, int r)
static int coresumefinish(lua_State* L, int r)
{
if (r < 0)
{
@ -131,7 +129,7 @@ static int luaB_coresumefinish(lua_State* L, int r)
}
}
static int luaB_coresumey(lua_State* L)
static int coresumey(lua_State* L)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
@ -141,10 +139,10 @@ static int luaB_coresumey(lua_State* L)
if (r == CO_STATUS_BREAK)
return interruptThread(L, co);
return luaB_coresumefinish(L, r);
return coresumefinish(L, r);
}
static int luaB_coresumecont(lua_State* L, int status)
static int coresumecont(lua_State* L, int status)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
@ -155,10 +153,10 @@ static int luaB_coresumecont(lua_State* L, int status)
int r = auxresumecont(L, co);
return luaB_coresumefinish(L, r);
return coresumefinish(L, r);
}
static int luaB_auxwrapfinish(lua_State* L, int r)
static int auxwrapfinish(lua_State* L, int r)
{
if (r < 0)
{
@ -173,7 +171,7 @@ static int luaB_auxwrapfinish(lua_State* L, int r)
return r;
}
static int luaB_auxwrapy(lua_State* L)
static int auxwrapy(lua_State* L)
{
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
int narg = cast_int(L->top - L->base);
@ -182,10 +180,10 @@ static int luaB_auxwrapy(lua_State* L)
if (r == CO_STATUS_BREAK)
return interruptThread(L, co);
return luaB_auxwrapfinish(L, r);
return auxwrapfinish(L, r);
}
static int luaB_auxwrapcont(lua_State* L, int status)
static int auxwrapcont(lua_State* L, int status)
{
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
@ -195,62 +193,52 @@ static int luaB_auxwrapcont(lua_State* L, int status)
int r = auxresumecont(L, co);
return luaB_auxwrapfinish(L, r);
return auxwrapfinish(L, r);
}
static int luaB_cocreate(lua_State* L)
static int cocreate(lua_State* L)
{
luaL_checktype(L, 1, LUA_TFUNCTION);
lua_State* NL = lua_newthread(L);
if (FFlag::LuauPreferXpush)
{
lua_xpush(L, NL, 1); // push function on top of NL
}
else
{
lua_pushvalue(L, 1); /* move function to top */
lua_xmove(L, NL, 1); /* move function from L to NL */
}
lua_xpush(L, NL, 1); // push function on top of NL
return 1;
}
static int luaB_cowrap(lua_State* L)
static int cowrap(lua_State* L)
{
luaB_cocreate(L);
cocreate(L);
lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont);
lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont);
return 1;
}
static int luaB_yield(lua_State* L)
static int coyield(lua_State* L)
{
int nres = cast_int(L->top - L->base);
return lua_yield(L, nres);
}
static int luaB_corunning(lua_State* L)
static int corunning(lua_State* L)
{
if (lua_pushthread(L))
lua_pushnil(L); /* main thread is not a coroutine */
return 1;
}
static int luaB_yieldable(lua_State* L)
static int coyieldable(lua_State* L)
{
lua_pushboolean(L, lua_isyieldable(L));
return 1;
}
static const luaL_Reg co_funcs[] = {
{"create", luaB_cocreate},
{"running", luaB_corunning},
{"status", luaB_costatus},
{"wrap", luaB_cowrap},
{"yield", luaB_yield},
{"isyieldable", luaB_yieldable},
{"create", cocreate},
{"running", corunning},
{"status", costatus},
{"wrap", cowrap},
{"yield", coyield},
{"isyieldable", coyieldable},
{NULL, NULL},
};
@ -258,7 +246,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L)
{
luaL_register(L, LUA_COLIBNAME, co_funcs);
lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont);
lua_pushcfunction(L, coresumey, "resume", 0, coresumecont);
lua_setfield(L, -2, "resume");
return 1;

View file

@ -18,6 +18,7 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false)
LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false)
/*
** {======================================================
@ -536,6 +537,12 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
status = LUA_ERRERR;
}
if (FFlag::LuauCcallRestoreFix)
{
// Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored.
L->nCcalls = oldnCcalls;
}
// an error occurred, check if we have a protected error callback
if (L->global->cb.debugprotectederror)
{
@ -549,7 +556,10 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
StkId oldtop = restorestack(L, old_top);
luaF_close(L, oldtop); /* close eventual pending closures */
seterrorobj(L, status, oldtop);
L->nCcalls = oldnCcalls;
if (!FFlag::LuauCcallRestoreFix)
{
L->nCcalls = oldnCcalls;
}
L->ci = restoreci(L, old_ci);
L->base = L->ci->base;
restore_stack_limit(L);

View file

@ -12,11 +12,9 @@
#include <string.h>
#include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false)
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false)
LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false)
LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false)
LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false)
LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false)
LUAU_FASTFLAG(LuauArrayBoundary)
@ -66,13 +64,18 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds,
g->gcstats.currcycle.marktime += seconds;
// atomic step had to be performed during the switch and it's tracked separately
if (g->gcstate == GCSsweepstring)
if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring)
g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime;
break;
case GCSatomic:
g->gcstats.currcycle.atomictime += seconds;
break;
case GCSsweepstring:
case GCSsweep:
g->gcstats.currcycle.sweeptime += seconds;
break;
default:
LUAU_ASSERT(!"Unexpected GC state");
}
if (assist)
@ -183,33 +186,15 @@ static int traversetable(global_State* g, Table* h)
if (h->metatable)
markobject(g, cast_to(Table*, h->metatable));
if (FFlag::LuauShrinkWeakTables)
/* is there a weak mode? */
if (const char* modev = gettablemode(g, h))
{
/* is there a weak mode? */
if (const char* modev = gettablemode(g, h))
{
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
}
}
else
{
const TValue* mode = gfasttm(g, h->metatable, TM_MODE);
if (mode && ttisstring(mode))
{ /* is there a weak mode? */
const char* modev = svalue(mode);
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
}
@ -297,7 +282,7 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack)
for (StkId o = l->stack; o < l->top; o++)
markvalue(g, o);
/* final traversal? */
if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack))
if (g->gcstate == GCSatomic || clearstack)
{
StkId stack_end = l->stack + l->stacksize;
for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */
@ -336,28 +321,16 @@ static size_t propagatemark(global_State* g)
lua_State* th = gco2th(o);
g->gray = th->gclist;
if (FFlag::LuauGcFullSkipInactiveThreads)
LUAU_ASSERT(!luaC_threadsleeping(th));
// threads that are executing and the main thread are not deactivated
bool active = luaC_threadactive(th) || th == th->global->mainthread;
if (!active && g->gcstate == GCSpropagate)
{
LUAU_ASSERT(!luaC_threadsleeping(th));
traversestack(g, th, /* clearstack= */ true);
// threads that are executing and the main thread are not deactivated
bool active = luaC_threadactive(th) || th == th->global->mainthread;
if (!active && g->gcstate == GCSpropagate)
{
traversestack(g, th, /* clearstack= */ true);
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
}
else
{
th->gclist = g->grayagain;
g->grayagain = o;
black2gray(o);
traversestack(g, th, /* clearstack= */ false);
}
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
}
else
{
@ -385,12 +358,14 @@ static size_t propagatemark(global_State* g)
}
}
static void propagateall(global_State* g)
static size_t propagateall(global_State* g)
{
size_t work = 0;
while (g->gray)
{
propagatemark(g);
work += propagatemark(g);
}
return work;
}
/*
@ -415,11 +390,14 @@ static int isobjcleared(GCObject* o)
/*
** clear collected entries from weaktables
*/
static void cleartable(lua_State* L, GCObject* l)
static size_t cleartable(lua_State* L, GCObject* l)
{
size_t work = 0;
while (l)
{
Table* h = gco2h(l);
work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h);
int i = h->sizearray;
while (i--)
{
@ -433,50 +411,36 @@ static void cleartable(lua_State* L, GCObject* l)
{
LuaNode* n = gnode(h, i);
if (FFlag::LuauShrinkWeakTables)
// non-empty entry?
if (!ttisnil(gval(n)))
{
// non-empty entry?
if (!ttisnil(gval(n)))
{
// can we clear key or value?
if (iscleared(gkey(n)) || iscleared(gval(n)))
{
setnilvalue(gval(n)); /* remove value ... */
removeentry(n); /* remove entry from table */
}
else
{
activevalues++;
}
}
}
else
{
if (!ttisnil(gval(n)) && /* non-empty entry? */
(iscleared(gkey(n)) || iscleared(gval(n))))
// can we clear key or value?
if (iscleared(gkey(n)) || iscleared(gval(n)))
{
setnilvalue(gval(n)); /* remove value ... */
removeentry(n); /* remove entry from table */
}
else
{
activevalues++;
}
}
}
if (FFlag::LuauShrinkWeakTables)
if (const char* modev = gettablemode(L->global, h))
{
if (const char* modev = gettablemode(L->global, h))
// are we allowed to shrink this weak table?
if (strchr(modev, 's'))
{
// are we allowed to shrink this weak table?
if (strchr(modev, 's'))
{
// shrink at 37.5% occupancy
if (activevalues < sizenode(h) * 3 / 8)
luaH_resizehash(L, h, activevalues);
}
// shrink at 37.5% occupancy
if (activevalues < sizenode(h) * 3 / 8)
luaH_resizehash(L, h, activevalues);
}
}
l = h->gclist;
}
return work;
}
static void shrinkstack(lua_State* L)
@ -655,37 +619,49 @@ static void markroot(lua_State* L)
g->gcstate = GCSpropagate;
}
static void remarkupvals(global_State* g)
static size_t remarkupvals(global_State* g)
{
UpVal* uv;
for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
size_t work = 0;
for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
{
work += sizeof(UpVal);
LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv);
if (isgray(obj2gco(uv)))
markvalue(g, uv->v);
}
return work;
}
static void atomic(lua_State* L)
static size_t atomic(lua_State* L)
{
global_State* g = L->global;
g->gcstate = GCSatomic;
size_t work = 0;
if (FFlag::LuauSeparateAtomic)
{
LUAU_ASSERT(g->gcstate == GCSatomic);
}
else
{
g->gcstate = GCSatomic;
}
/* remark occasional upvalues of (maybe) dead threads */
remarkupvals(g);
work += remarkupvals(g);
/* traverse objects caught by write barrier and by 'remarkupvals' */
propagateall(g);
work += propagateall(g);
/* remark weak tables */
g->gray = g->weak;
g->weak = NULL;
LUAU_ASSERT(!iswhite(obj2gco(g->mainthread)));
markobject(g, L); /* mark running thread */
markmt(g); /* mark basic metatables (again) */
propagateall(g);
work += propagateall(g);
/* remark gray again */
g->gray = g->grayagain;
g->grayagain = NULL;
propagateall(g);
cleartable(L, g->weak); /* remove collected objects from weak tables */
work += propagateall(g);
work += cleartable(L, g->weak); /* remove collected objects from weak tables */
g->weak = NULL;
/* flip current white */
g->currentwhite = cast_byte(otherwhite(g));
@ -693,7 +669,12 @@ static void atomic(lua_State* L)
g->sweepgc = &g->rootgc;
g->gcstate = GCSsweepstring;
GC_INTERRUPT(GCSatomic);
if (!FFlag::LuauSeparateAtomic)
{
GC_INTERRUPT(GCSatomic);
}
return work;
}
static size_t singlestep(lua_State* L)
@ -705,46 +686,24 @@ static size_t singlestep(lua_State* L)
case GCSpause:
{
markroot(L); /* start a new collection */
LUAU_ASSERT(g->gcstate == GCSpropagate);
break;
}
case GCSpropagate:
{
if (FFlag::LuauRescanGrayAgain)
if (g->gray)
{
if (g->gray)
{
g->gcstats.currcycle.markitems++;
g->gcstats.currcycle.markitems++;
cost = propagatemark(g);
}
else
{
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
g->gcstate = GCSpropagateagain;
}
cost = propagatemark(g);
}
else
{
if (g->gray)
{
g->gcstats.currcycle.markitems++;
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
cost = propagatemark(g);
}
else /* no more `gray' objects */
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
g->gcstate = GCSpropagateagain;
}
break;
}
@ -758,17 +717,34 @@ static size_t singlestep(lua_State* L)
}
else /* no more `gray' objects */
{
double starttimestamp = lua_clock();
if (FFlag::LuauSeparateAtomic)
{
g->gcstate = GCSatomic;
}
else
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
}
break;
}
case GCSatomic:
{
g->gcstats.currcycle.atomicstarttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
cost = atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
break;
}
case GCSsweepstring:
{
size_t traversedcount = 0;
@ -806,7 +782,7 @@ static size_t singlestep(lua_State* L)
break;
}
default:
LUAU_ASSERT(0);
LUAU_ASSERT(!"Unexpected GC state");
}
return cost;
@ -821,48 +797,25 @@ static size_t gcstep(lua_State* L, size_t limit)
case GCSpause:
{
markroot(L); /* start a new collection */
LUAU_ASSERT(g->gcstate == GCSpropagate);
break;
}
case GCSpropagate:
{
if (FFlag::LuauRescanGrayAgain)
while (g->gray && cost < limit)
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
if (!g->gray)
{
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
g->gcstate = GCSpropagateagain;
}
cost += propagatemark(g);
}
else
if (!g->gray)
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
cost += propagatemark(g);
}
if (!g->gray) /* no more `gray' objects */
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
g->gcstate = GCSpropagateagain;
}
break;
}
@ -877,17 +830,34 @@ static size_t gcstep(lua_State* L, size_t limit)
if (!g->gray) /* no more `gray' objects */
{
double starttimestamp = lua_clock();
if (FFlag::LuauSeparateAtomic)
{
g->gcstate = GCSatomic;
}
else
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
}
break;
}
case GCSatomic:
{
g->gcstats.currcycle.atomicstarttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
cost = atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
break;
}
case GCSsweepstring:
{
while (g->sweepstrgc < g->strt.size && cost < limit)
@ -934,7 +904,7 @@ static size_t gcstep(lua_State* L, size_t limit)
break;
}
default:
LUAU_ASSERT(0);
LUAU_ASSERT(!"Unexpected GC state");
}
return cost;
}
@ -1084,7 +1054,7 @@ void luaC_fullgc(lua_State* L)
if (g->gcstate == GCSpause)
startGcCycleStats(g);
if (g->gcstate <= GCSpropagateagain)
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
{
/* reset sweep marks to sweep all elements (returning them to white) */
g->sweepstrgc = 0;
@ -1095,7 +1065,7 @@ void luaC_fullgc(lua_State* L)
g->weak = NULL;
g->gcstate = GCSsweepstring;
}
LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain);
LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep);
/* finish any pending sweep phase */
while (g->gcstate != GCSpause)
{
@ -1143,14 +1113,11 @@ void luaC_fullgc(lua_State* L)
void luaC_barrierupval(lua_State* L, GCObject* v)
{
if (FFlag::LuauGcFullSkipInactiveThreads)
{
global_State* g = L->global;
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
global_State* g = L->global;
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
if (keepinvariant(g))
reallymarkobject(g, v);
}
if (keepinvariant(g))
reallymarkobject(g, v);
}
void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v)
@ -1778,7 +1745,7 @@ int64_t luaC_allocationrate(lua_State* L)
global_State* g = L->global;
const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms
if (g->gcstate <= GCSpropagateagain)
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
{
double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp;

View file

@ -6,8 +6,6 @@
#include "lobject.h"
#include "lstate.h"
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
/*
** Possible states of the Garbage Collector
*/
@ -25,7 +23,7 @@ LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
** still-black objects. The invariant is restored when sweep ends and
** all objects are white again.
*/
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain)
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic)
/*
** some useful bit tricks

View file

@ -199,7 +199,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass)
if (page->freeNext >= 0)
{
block = page->data + page->freeNext;
block = &page->data + page->freeNext;
ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize);
page->freeNext -= page->blockSize;

View file

@ -226,7 +226,7 @@ void luaS_freeudata(lua_State* L, Udata* u)
void (*dtor)(void*) = nullptr;
if (u->tag == UTAG_IDTOR)
memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor));
memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor));
else if (u->tag)
dtor = L->global->udatagc[u->tag];

View file

@ -13,7 +13,7 @@
#include <string.h>
// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens
template <typename T>
template<typename T>
struct TempBuffer
{
lua_State* L;
@ -346,6 +346,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
uint32_t mainid = readVarInt(data, size, offset);
Proto* main = protos[mainid];
luaC_checkthreadsleep(L);
Closure* cl = luaF_newLclosure(L, 0, envt, main);
setclvalue(L, L->top, cl);
incr_top(L);

849
bench/tests/chess.lua Normal file
View file

@ -0,0 +1,849 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
local RANKS = "12345678"
local FILES = "abcdefgh"
local PieceSymbols = "PpRrNnBbQqKk"
local UnicodePieces = {"", "", "", "", "", "", "", "", "", "", "", ""}
local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
--
-- Lua 5.2 Compat
--
if not table.create then
function table.create(n, v)
local result = {}
for i=1,n do result[i] = v end
return result
end
end
if not table.move then
function table.move(a, from, to, start, target)
local dx = start - from
for i=from,to do
target[i+dx] = a[i]
end
end
end
--
-- Utils
--
local function square(s)
return RANKS:find(s:sub(2,2)) * 8 + FILES:find(s:sub(1,1)) - 9
end
local function squareName(n)
local file = n % 8
local rank = (n-file)/8
return FILES:sub(file+1,file+1) .. RANKS:sub(rank+1,rank+1)
end
local function moveName(v )
local from = bit32.extract(v, 6, 6)
local to = bit32.extract(v, 0, 6)
local piece = bit32.extract(v, 20, 4)
local captured = bit32.extract(v, 25, 4)
local move = PieceSymbols:sub(piece,piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to)
if bit32.extract(v,14) == 1 then
if to > from then
return "O-O"
else
return "O-O-O"
end
end
local promote = bit32.extract(v,15,4)
if promote ~= 0 then
move = move .. "=" .. PieceSymbols:sub(promote,promote)
end
return move
end
local function ucimove(m)
local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6))
local promote = bit32.extract(m,15,4)
if promote > 0 then
mm = mm .. PieceSymbols:sub(promote,promote):lower()
end
return mm
end
local _utils = {squareName, moveName}
--
-- Bitboards
--
local Bitboard = {}
function Bitboard:toString()
local out = {}
local src = self.h
for x=7,0,-1 do
table.insert(out, RANKS:sub(x+1,x+1))
table.insert(out, " ")
local bit = bit32.lshift(1,(x%4) * 8)
for x=0,7 do
if bit32.band(src, bit) ~= 0 then
table.insert(out, "x ")
else
table.insert(out, "- ")
end
bit = bit32.lshift(bit, 1)
end
if x == 4 then
src = self.l
end
table.insert(out, "\n")
end
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h)
return table.concat(out)
end
function Bitboard.from(l ,h )
return setmetatable({l=l, h=h}, Bitboard)
end
Bitboard.zero = Bitboard.from(0,0)
Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF)
local Rank1 = Bitboard.from(0x000000FF, 0)
local Rank3 = Bitboard.from(0x00FF0000, 0)
local Rank6 = Bitboard.from(0, 0x0000FF00)
local Rank8 = Bitboard.from(0, 0xFF000000)
local FileA = Bitboard.from(0x01010101, 0x01010101)
local FileB = Bitboard.from(0x02020202, 0x02020202)
local FileC = Bitboard.from(0x04040404, 0x04040404)
local FileD = Bitboard.from(0x08080808, 0x08080808)
local FileE = Bitboard.from(0x10101010, 0x10101010)
local FileF = Bitboard.from(0x20202020, 0x20202020)
local FileG = Bitboard.from(0x40404040, 0x40404040)
local FileH = Bitboard.from(0x80808080, 0x80808080)
local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH}
-- These masks are filled out below for all files
local RightMasks = {FileH}
local LeftMasks = {FileA}
local function popcnt32(i)
i = i - bit32.band(bit32.rshift(i,1), 0x55555555)
i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i,2), 0x33333333)
return bit32.rshift(bit32.band(i + bit32.rshift(i,4), 0x0F0F0F0F) * 0x01010101, 24)
end
function Bitboard:up()
return self:lshift(8)
end
function Bitboard:down()
return self:rshift(8)
end
function Bitboard:right()
return self:band(FileH:inverse()):lshift(1)
end
function Bitboard:left()
return self:band(FileA:inverse()):rshift(1)
end
function Bitboard:move(x,y)
local out = self
if x < 0 then out = out:bandnot(RightMasks[-x]):lshift(-x) end
if x > 0 then out = out:bandnot(LeftMasks[x]):rshift(x) end
if y < 0 then out = out:rshift(-8 * y) end
if y > 0 then out = out:lshift(8 * y) end
return out
end
function Bitboard:popcnt()
return popcnt32(self.l) + popcnt32(self.h)
end
function Bitboard:band(other )
return Bitboard.from(bit32.band(self.l,other.l), bit32.band(self.h, other.h))
end
function Bitboard:bandnot(other )
return Bitboard.from(bit32.band(self.l,bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h)))
end
function Bitboard:bandempty(other )
return bit32.band(self.l,other.l) == 0 and bit32.band(self.h, other.h) == 0
end
function Bitboard:bor(other )
return Bitboard.from(bit32.bor(self.l,other.l), bit32.bor(self.h, other.h))
end
function Bitboard:bxor(other )
return Bitboard.from(bit32.bxor(self.l,other.l), bit32.bxor(self.h, other.h))
end
function Bitboard:inverse()
return Bitboard.from(bit32.bxor(self.l,0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF))
end
function Bitboard:empty()
return self.h == 0 and self.l == 0
end
function Bitboard:ctz()
local target = self.l
local offset = 0
local result = 0
if target == 0 then
target = self.h
result = 32
end
if target == 0 then
return 64
end
while bit32.extract(target, offset) == 0 do
offset = offset + 1
end
return result + offset
end
function Bitboard:ctzafter(start)
start = start + 1
if start < 32 then
for i=start,31 do
if bit32.extract(self.l, i) == 1 then return i end
end
end
for i=math.max(32,start),63 do
if bit32.extract(self.h, i-32) == 1 then return i end
end
return 64
end
function Bitboard:lshift(amt)
assert(amt >= 0)
if amt == 0 then return self end
if amt > 31 then
return Bitboard.from(0, bit32.lshift(self.l, amt-31))
end
local l = bit32.lshift(self.l, amt)
local h = bit32.bor(
bit32.lshift(self.h, amt),
bit32.extract(self.l, 32-amt, amt)
)
return Bitboard.from(l, h)
end
function Bitboard:rshift(amt)
assert(amt >= 0)
if amt == 0 then return self end
local h = bit32.rshift(self.h, amt)
local l = bit32.bor(
bit32.rshift(self.l, amt),
bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt)
)
return Bitboard.from(l, h)
end
function Bitboard:index(i)
if i > 31 then
return bit32.extract(self.h, i - 32)
else
return bit32.extract(self.l, i)
end
end
function Bitboard:set(i , v)
if i > 31 then
return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32))
else
return Bitboard.from(bit32.replace(self.l, v, i), self.h)
end
end
function Bitboard:isolate(i)
return self:band(Bitboard.some(i))
end
function Bitboard.some(idx )
return Bitboard.zero:set(idx, 1)
end
Bitboard.__index = Bitboard
Bitboard.__tostring = Bitboard.toString
for i=2,8 do
RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH)
LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA)
end
--
-- Board
--
local Board = {}
function Board.new()
local boards = table.create(12, Bitboard.zero)
boards.ocupied = Bitboard.zero
boards.white = Bitboard.zero
boards.black = Bitboard.zero
boards.unocupied = Bitboard.full
boards.ep = Bitboard.zero
boards.castle = Bitboard.zero
boards.toMove = 1
boards.hm = 0
boards.moves = 0
boards.material = 0
return setmetatable(boards, Board)
end
function Board.fromFen(fen )
local b = Board.new()
local i = 0
local rank = 7
local file = 0
while true do
i = i + 1
local p = fen:sub(i,i)
if p == '/' then
rank = rank - 1
file = 0
elseif tonumber(p) ~= nil then
file = file + tonumber(p)
else
local pidx = PieceSymbols:find(p)
if pidx == nil then break end
b[pidx] = b[pidx]:set(rank*8+file, 1)
file = file + 1
end
end
local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i)
if move == nil then print(fen:sub(i)) end
b.toMove = move == 'w' and 1 or 2
if ep ~= "-" then
b.ep = Bitboard.some(square(ep))
end
if castle ~= "-" then
local oo = Bitboard.zero
if castle:find("K") then
oo = oo:set(7, 1)
end
if castle:find("Q") then
oo = oo:set(0, 1)
end
if castle:find("k") then
oo = oo:set(63, 1)
end
if castle:find("q") then
oo = oo:set(56, 1)
end
b.castle = oo
end
b.hm = hm
b.moves = m
b:updateCache()
return b
end
function Board:index(idx )
if self.white:index(idx) == 1 then
for p=1,12,2 do
if self[p]:index(idx) == 1 then
return p
end
end
else
for p=2,12,2 do
if self[p]:index(idx) == 1 then
return p
end
end
end
return 0
end
function Board:updateCache()
for i=1,11,2 do
self.white = self.white:bor(self[i])
self.black = self.black:bor(self[i+1])
end
self.ocupied = self.black:bor(self.white)
self.unocupied = self.ocupied:inverse()
self.material =
100*self[1]:popcnt() - 100*self[2]:popcnt() +
500*self[3]:popcnt() - 500*self[4]:popcnt() +
300*self[5]:popcnt() - 300*self[6]:popcnt() +
300*self[7]:popcnt() - 300*self[8]:popcnt() +
900*self[9]:popcnt() - 900*self[10]:popcnt()
end
function Board:fen()
local out = {}
local s = 0
local idx = 56
for i=0,63 do
if i % 8 == 0 and i > 0 then
idx = idx - 16
if s > 0 then
table.insert(out, '' .. s)
s = 0
end
table.insert(out, '/')
end
local p = self:index(idx)
if p == 0 then
s = s + 1
else
if s > 0 then
table.insert(out, '' .. s)
s = 0
end
table.insert(out, PieceSymbols:sub(p,p))
end
idx = idx + 1
end
if s > 0 then
table.insert(out, '' .. s)
end
table.insert(out, self.toMove == 1 and ' w ' or ' b ')
if self.castle:empty() then
table.insert(out, '-')
else
if self.castle:index(7) == 1 then table.insert(out, 'K') end
if self.castle:index(0) == 1 then table.insert(out, 'Q') end
if self.castle:index(63) == 1 then table.insert(out, 'k') end
if self.castle:index(56) == 1 then table.insert(out, 'q') end
end
table.insert(out, ' ')
if self.ep:empty() then
table.insert(out, '-')
else
table.insert(out, squareName(self.ep:ctz()))
end
table.insert(out, ' ' .. self.hm)
table.insert(out, ' ' .. self.moves)
return table.concat(out)
end
function Board:pmoves(idx)
return self:generate(idx)
end
function Board:pcaptures(idx)
return self:generate(idx):band(self.ocupied)
end
local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}}
local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}}
local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}}
local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}}
function Board:generate(idx)
local piece = self:index(idx)
local r = Bitboard.some(idx)
local out = Bitboard.zero
local type = bit32.rshift(piece - 1, 1)
local cancapture = piece % 2 == 1 and self.black or self.white
if piece == 0 then return Bitboard.zero end
if type == 0 then
-- Pawn
local d = -(piece*2 - 3)
local movetwo = piece == 1 and Rank3 or Rank6
out = out:bor(r:move(0,d):band(self.unocupied))
out = out:bor(out:band(movetwo):move(0,d):band(self.unocupied))
local captures = r:move(0,d)
captures = captures:right():bor(captures:left())
if not captures:bandempty(self.ep) then
out = out:bor(self.ep)
end
captures = captures:band(cancapture)
out = out:bor(captures)
return out
elseif type == 5 then
-- King
for x=-1,1,1 do
for y = -1,1,1 do
local w = r:move(x,y)
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
end
end
end
elseif type == 2 then
-- Knight
for _,j in ipairs(KNIGHT_MOVES) do
local w = r:move(j[1],j[2])
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
end
end
else
-- Sliders (Rook, Bishop, Queen)
local slides
if type == 1 then
slides = ROOK_SLIDES
elseif type == 3 then
slides = BISHOP_SLIDES
else
slides = QUEEN_SLIDES
end
for _, op in ipairs(slides) do
local w = r
for i=1,7 do
w = w:move(op[1], op[2])
if w:empty() then break end
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
break
end
end
end
end
return out
end
-- 0-5 - From Square
-- 6-11 - To Square
-- 12 - is Check
-- 13 - Is EnPassent
-- 14 - Is Castle
-- 15-19 - Promotion Piece
-- 20-24 - Moved Pice
-- 25-29 - Captured Piece
function Board:toString(mark )
local out = {}
for x=8,1,-1 do
table.insert(out, RANKS:sub(x,x) .. " ")
for y=1,8 do
local n = 8*x+y-9
local i = self:index(n)
if i == 0 then
table.insert(out, '-')
else
-- out = out .. PieceSymbols:sub(i,i)
table.insert(out, UnicodePieces[i])
end
if mark ~= nil and mark:index(n) ~= 0 then
table.insert(out, ')')
elseif mark ~= nil and n < 63 and y < 8 and mark:index(n+1) ~= 0 then
table.insert(out, '(')
else
table.insert(out, ' ')
end
end
table.insert(out, "\n")
end
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n")
return table.concat(out)
end
function Board:moveList()
local tm = self.toMove == 1 and self.white or self.black
local castle_rank = self.toMove == 1 and Rank1 or Rank8
local out = {}
local function emit(id)
if not self:applyMove(id):illegalyChecked() then
table.insert(out, id)
end
end
local cr = tm:band(self.castle):band(castle_rank)
if not cr:empty() then
local p = self.toMove == 1 and 11 or 12
local tcolor = self.toMove == 1 and self.black or self.white
local kidx = self[p]:ctz()
local castle = bit32.replace(0, p, 20, 4)
castle = bit32.replace(castle, kidx, 6, 6)
castle = bit32.replace(castle, 1, 14)
local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank)
local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p])
if
not cr:bandempty(FileA) and
mustbeemptyl:bandempty(self.ocupied) and
not self:isSquareThreatened(cantbethreatened, tcolor)
then
emit(bit32.replace(castle, kidx - 2, 0, 6))
end
local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank)
if
not cr:bandempty(FileH) and
mustbeemptyr:bandempty(self.ocupied) and
not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor)
then
emit(bit32.replace(castle, kidx + 2, 0, 6))
end
end
local sq = tm:ctz()
repeat
local p = self:index(sq)
local moves = self:pmoves(sq)
while not moves:empty() do
local m = moves:ctz()
moves = moves:set(m, 0)
local id = bit32.replace(m, sq, 6, 6)
id = bit32.replace(id, p, 20, 4)
local mbb = Bitboard.some(m)
if not self.ocupied:bandempty(mbb) then
id = bit32.replace(id, self:index(m), 25, 4)
end
-- Check if pawn needs to be promoted
if p == 1 and m >= 8*7 then
for i=3,9,2 do
emit(bit32.replace(id, i, 15, 4))
end
elseif p == 2 and m < 8 then
for i=4,10,2 do
emit(bit32.replace(id, i, 15, 4))
end
else
emit(id)
end
end
sq = tm:ctzafter(sq)
until sq == 64
return out
end
function Board:illegalyChecked()
local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")]
return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black)
end
function Board:isSquareThreatened(target , color )
local tm = color
local sq = tm:ctz()
repeat
local moves = self:pmoves(sq)
if not moves:bandempty(target) then
return true
end
sq = color:ctzafter(sq)
until sq == 64
return false
end
function Board:perft(depth )
if depth == 0 then return 1 end
if depth == 1 then
return #self:moveList()
end
local result = 0
for k,m in ipairs(self:moveList()) do
local c = self:applyMove(m):perft(depth - 1)
if c == 0 then
-- Perft only counts leaf nodes at target depth
-- result = result + 1
else
result = result + c
end
end
return result
end
function Board:applyMove(move )
local out = Board.new()
table.move(self, 1, 12, 1, out)
local from = bit32.extract(move, 6, 6)
local to = bit32.extract(move, 0, 6)
local promote = bit32.extract(move, 15, 4)
local piece = self:index(from)
local captured = self:index(to)
local tom = Bitboard.some(to)
local isCastle = bit32.extract(move, 14)
if piece % 2 == 0 then
out.moves = self.moves + 1
end
if captured == 1 or piece < 3 then
out.hm = 0
else
out.hm = self.hm + 1
end
out.castle = self.castle
out.toMove = self.toMove == 1 and 2 or 1
if isCastle == 1 then
local rank = piece == 11 and Rank1 or Rank8
local colorOffset = piece - 11
out[3 + colorOffset] = out[3 + colorOffset]:bandnot(from < to and FileH or FileA)
out[3 + colorOffset] = out[3 + colorOffset]:bor((from < to and FileF or FileD):band(rank))
out[piece] = (from < to and FileG or FileC):band(rank)
out.castle = out.castle:bandnot(rank)
out:updateCache()
return out
end
if piece < 3 then
local dist = math.abs(to - from)
-- Pawn moved two squares, set ep square
if dist == 16 then
out.ep = Bitboard.some((from + to) / 2)
end
-- Remove enpasent capture
if not tom:bandempty(self.ep) then
if piece == 1 then
out[2] = out[2]:bandnot(self.ep:down())
end
if piece == 2 then
out[1] = out[1]:bandnot(self.ep:up())
end
end
end
if piece == 3 or piece == 4 then
out.castle = out.castle:set(from, 0)
end
if piece > 10 then
local rank = piece == 11 and Rank1 or Rank8
out.castle = out.castle:bandnot(rank)
end
out[piece] = out[piece]:set(from, 0)
if promote == 0 then
out[piece] = out[piece]:set(to, 1)
else
out[promote] = out[promote]:set(to, 1)
end
if captured ~= 0 then
out[captured] = out[captured]:set(to, 0)
end
out:updateCache()
return out
end
Board.__index = Board
Board.__tostring = Board.toString
--
-- Main
--
local failures = 0
local function test(fen, ply, target)
local b = Board.fromFen(fen)
if b:fen() ~= fen then
print("FEN MISMATCH", fen, b:fen())
failures = failures + 1
return
end
local found = b:perft(ply)
if found ~= target then
print(fen, "Found", found, "target", target)
failures = failures + 1
for k,v in pairs(b:moveList()) do
print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1'))
end
--error("Test Failure")
else
print("OK", found, fen)
end
end
-- From https://www.chessprogramming.org/Perft_Results
-- If interpreter, computers, or algorithm gets too fast
-- feel free to go deeper
local testCases = {}
local function addTest(...) table.insert(testCases, {...}) end
addTest(StartingFen, 3, 8902)
addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039)
addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812)
addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467)
addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486)
addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079)
local function chess()
for k,v in ipairs(testCases) do
test(v[1],v[2],v[3])
end
end
bench.runCode(chess, "chess")

View file

@ -25,3 +25,5 @@ pages:
url: /profile
- title: Library
url: /library
- title: Demo
url: /demo

50
docs/_includes/repl.html Normal file
View file

@ -0,0 +1,50 @@
<form>
<div>
<label>Script:</label>
<br>
<textarea rows="10" cols="70" id="script">print("Hello World!")</textarea>
<br><br>
<button onclick="clearInput(); return false;">
Clear Input
</button>
<button onclick="executeScript(); return false;">
Run
</button>
</div>
<br><br>
<div>
<label>Output:</label>
<br>
<textarea readonly rows="10" cols="70" id="output"></textarea>
<br><br>
<button onclick="clearOutput(); return false;">
Clear Output
</button>
</div>
</form>
<script>
function output(text) {
document.getElementById("output").value += "[" + new Date().toLocaleTimeString() + "] " + text.replace('stdin:', '') + "\n";
}
var Module = {
'print': function (msg) { output(msg) }
};
function clearInput() {
document.getElementById("script").value = "";
}
function clearOutput() {
document.getElementById("output").value = "";
}
function executeScript() {
var err = Module.ccall('executeScript', 'string', ['string'], [document.getElementById("script").value]);
if (err) {
output('Error:' + err.replace('stdin:', ''));
}
}
</script>
<script async src="assets/luau/luau.js"></script>

6
docs/_pages/demo.md Normal file
View file

@ -0,0 +1,6 @@
---
permalink: /demo
title: Demo
---
{% include repl.html %}

View file

@ -46,16 +46,6 @@ This is using the VM feature that is not accessible from scripts, that prevents
By itself this would mean that code that runs in Luau can't use globals at all, since assigning globals would fail. While this is feasible, in Roblox we solve this by creating a new global table for each script, that uses `__index` to point to the builtin global table. This safely sandboxes the builtin globals while still allowing writing globals from each script. This also means that short of exposing special shared globals from the host, all scripts are isolated from each other.
## Thread identity
Environment-level sandboxing is sufficient to implement separation between trusted code and untrusted code, assuming that `getfenv`/`setfenv` are either unavailable (removed from the globals), or that trusted code never interfaces with untrusted code (which prevents untrusted code from ever getting access to trusted functions). When running trusted code, it's possible to inject extra globals from the host into that global table, providing access to special APIs.
However, in some cases it's desirable to restrict access to functions that are exposed both to trusted and untrusted code. For example, both may have access to `game` global, but `game` may expose methods that should only work from trusted code.
To achieve this, each thread in Luau has a security identity, which can only be set by the host. Newly created threads inherit identities from the parent thread, and functions exposed from the host can validate the identity of the calling thread. This makes it possible to provide APIs to trusted code while limiting the access from untrusted code.
> Note: to achieve an even stronger guarantee of isolation between trusted and untrusted code, it's possible to run it in different Luau VMs, which is what Roblox does for extra safety.
## `__gc`
Lua 5.1 exposes a `__gc` metamethod for userdata, which can be used on proxies (`newproxy`) to hook into garbage collector. Later versions of Lua extend this mechanism to work on tables.

View file

@ -1596,7 +1596,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion")
local function target(a: number, b: string) return a + #b end
local function d(a: n@1, b)
return target(a, b)
return target(a, b)
end
)");
@ -1609,7 +1609,7 @@ end
local function target(a: number, b: string) return a + #b end
local function d(a, b: s@1)
return target(a, b)
return target(a, b)
end
)");
@ -1622,7 +1622,7 @@ end
local function target(a: number, b: string) return a + #b end
local function d(a:@1 @2, b)
return target(a, b)
return target(a, b)
end
)");
@ -1640,7 +1640,7 @@ end
local function target(a: number, b: string) return a + #b end
local function d(a, b: @1)@2: number
return target(a, b)
return target(a, b)
end
)");
@ -1682,7 +1682,7 @@ local x = target(function(a: n@1
local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end
local x = target(function(a: n@1, b: @2)
return a + #b
return a + #b
end)
)");
@ -1700,7 +1700,7 @@ end)
local function target(callback: (...number) -> number) return callback(1, 2, 3) end
local x = target(function(a: n@1)
return a
return a
end
)");
@ -1716,7 +1716,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio
local function target(callback: (...number) -> number) return callback(1, 2, 3) end
local x = target(function(...:n@1)
return a
return a
end
)");
@ -1729,7 +1729,7 @@ end
local function target(callback: (...number) -> number) return callback(1, 2, 3) end
local x = target(function(a:number, b:number, ...:@1)
return a + b
return a + b
end
)");
@ -1745,7 +1745,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion")
local function target(callback: () -> number) return callback() end
local x = target(function(): n@1
return 1
return 1
end
)");
@ -1758,7 +1758,7 @@ end
local function target(callback: () -> (number, number)) return callback() end
local x = target(function(): (number, n@1
return 1, 2
return 1, 2
end
)");
@ -1774,7 +1774,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion"
local function target(callback: () -> ...number) return callback() end
local x = target(function(): ...n@1
return 1, 2, 3
return 1, 2, 3
end
)");
@ -1787,7 +1787,7 @@ end
local function target(callback: () -> ...number) return callback() end
local x = target(function(): (number, number, ...n@1
return 1, 2, 3
return 1, 2, 3
end
)");

View file

@ -768,11 +768,11 @@ TEST_CASE("CaptureSelf")
local MaterialsListClass = {}
function MaterialsListClass:_MakeToolTip(guiElement, text)
local function updateTooltipPosition()
self._tweakingTooltipFrame = 5
end
local function updateTooltipPosition()
self._tweakingTooltipFrame = 5
end
updateTooltipPosition()
updateTooltipPosition()
end
return MaterialsListClass
@ -1168,6 +1168,17 @@ RETURN R0 1
)");
}
TEST_CASE("ConstantFoldStringLen")
{
CHECK_EQ("\n" + compileFunction0("return #'string', #'', #'a', #('b')"), R"(
LOADN R0 6
LOADN R1 0
LOADN R2 1
LOADN R3 1
RETURN R0 4
)");
}
TEST_CASE("ConstantFoldCompare")
{
// ordered comparisons
@ -2001,14 +2012,14 @@ TEST_CASE("UpvaluesLoopsBytecode")
{
CHECK_EQ("\n" + compileFunction(R"(
function test()
for i=1,10 do
for i=1,10 do
i = i
foo(function() return i end)
if bar then
break
end
end
return 0
foo(function() return i end)
if bar then
break
end
end
return 0
end
)",
1),
@ -2035,14 +2046,14 @@ RETURN R0 1
CHECK_EQ("\n" + compileFunction(R"(
function test()
for i in ipairs(data) do
for i in ipairs(data) do
i = i
foo(function() return i end)
if bar then
break
end
end
return 0
foo(function() return i end)
if bar then
break
end
end
return 0
end
)",
1),
@ -2068,17 +2079,17 @@ RETURN R0 1
CHECK_EQ("\n" + compileFunction(R"(
function test()
local i = 0
while i < 5 do
local j
local i = 0
while i < 5 do
local j
j = i
foo(function() return j end)
i = i + 1
if bar then
break
end
end
return 0
foo(function() return j end)
i = i + 1
if bar then
break
end
end
return 0
end
)",
1),
@ -2105,17 +2116,17 @@ RETURN R1 1
CHECK_EQ("\n" + compileFunction(R"(
function test()
local i = 0
repeat
local j
local i = 0
repeat
local j
j = i
foo(function() return j end)
i = i + 1
if bar then
break
end
until i < 5
return 0
foo(function() return j end)
i = i + 1
if bar then
break
end
until i < 5
return 0
end
)",
1),
@ -2304,10 +2315,10 @@ local Value1, Value2, Value3 = ...
local Table = {}
Table.SubTable["Key"] = {
Key1 = Value1,
Key2 = Value2,
Key3 = Value3,
Key4 = true,
Key1 = Value1,
Key2 = Value2,
Key3 = Value3,
Key4 = true,
}
)");

View file

@ -801,4 +801,17 @@ TEST_CASE("IfElseExpression")
runConformance("ifelseexpr.lua");
}
TEST_CASE("TagMethodError")
{
ScopedFastFlag sff{"LuauCcallRestoreFix", true};
runConformance("tmerror.lua", [](lua_State* L) {
auto* cb = lua_callbacks(L);
cb->debugprotectederror = [](lua_State* L) {
CHECK(lua_isyieldable(L));
};
});
}
TEST_SUITE_END();

View file

@ -2,6 +2,9 @@
#pragma once
#include <ostream>
#include <optional>
namespace std {
inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
{
@ -9,10 +12,12 @@ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
}
template<typename T>
std::ostream& operator<<(std::ostream& lhs, const std::optional<T>& t)
auto operator<<(std::ostream& lhs, const std::optional<T>& t) -> decltype(lhs << *t) // SFINAE to only instantiate << for supported types
{
if (t)
return lhs << *t;
else
return lhs << "none";
}
} // namespace std

View file

@ -791,13 +791,13 @@ TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings")
{
LintResult result = lint(R"(--!strict
type InputData = {
id: number,
inputType: EnumItem,
inputState: EnumItem,
updated: number,
position: Vector3,
keyCode: EnumItem,
name: string
id: number,
inputType: EnumItem,
inputState: EnumItem,
updated: number,
position: Vector3,
keyCode: EnumItem,
name: string
}
)");

View file

@ -554,4 +554,54 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name")
CHECK_EQ("{number | string}", toString(requireType("p"), {true}));
}
/*
* We had a problem where all type aliases would be prototyped into a child scope that happened
* to have the same level. This caused a problem where, if a sibling function referred to that
* type alias in its type signature, it would erroneously be quantified away, even though it doesn't
* actually belong to the function.
*
* We solved this by ascribing a unique subLevel to each prototyped alias.
*/
TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases")
{
CheckResult result = check(R"(
--!strict
local KeyPool = {}
local function newkey(pool: KeyPool, index)
return {}
end
function newKeyPool()
local pool = {
available = {} :: {Key},
}
return setmetatable(pool, KeyPool)
end
export type KeyPool = typeof(newKeyPool())
export type Key = typeof(newkey(newKeyPool(), 1))
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
/*
* We keep a cache of type alias onto TypeVar to prevent infinite types from
* being constructed via recursive or corecursive aliases. We have to adjust
* the TypeLevels of those generic TypeVars so that the unifier doesn't think
* they have improperly leaked out of their scope.
*/
TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases")
{
CheckResult result = check(R"(
type Array<T> = {T}
type Exclude<T, V> = T
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View file

@ -437,8 +437,6 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde
TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local b: Vector2? = nil
local a = b.X + b.Z

View file

@ -695,4 +695,25 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types")
CHECK(requireType("y1") == requireType("y2"));
}
TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields")
{
ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true};
ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true};
CheckResult result = check(R"(
local exports = {}
local nested = {}
nested.name = function(t, k)
local a = t.x.y
return rawget(t, k)
end
exports.nested = nested
return exports
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View file

@ -9,6 +9,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauEqConstraint)
LUAU_FASTFLAG(LuauQuantifyInPlace2)
using namespace Luau;
@ -42,7 +43,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
end
)";
const std::string expected = R"(
const std::string old_expected = R"(
function f(a:{fn:()->(free,free...)}): ()
if type(a) == 'boolean'then
local a1:boolean=a
@ -51,7 +52,21 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
end
end
)";
CHECK_EQ(expected, decorateWithTypes(code));
const std::string expected = R"(
function f(a:{fn:()->(a,b...)}): ()
if type(a) == 'boolean'then
local a1:boolean=a
elseif a.fn()then
local a2:{fn:()->(a,b...)}=a
end
end
)";
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ(expected, decorateWithTypes(code));
else
CHECK_EQ(old_expected, decorateWithTypes(code));
}
TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns")
@ -263,8 +278,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap")
TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5))
{
ScopedFastInt sffi{"LuauTarjanChildLimit", 50};
ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50};
ScopedFastInt sffi{"LuauTarjanChildLimit", 1};
ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1};
CheckResult result = check(R"LUA(
local Result

View file

@ -8,6 +8,7 @@
LUAU_FASTFLAG(LuauWeakEqConstraint)
LUAU_FASTFLAG(LuauOrPredicate)
LUAU_FASTFLAG(LuauQuantifyInPlace2)
using namespace Luau;
@ -698,10 +699,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector"
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0]));
else
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance"
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
else
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
}
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector")

View file

@ -617,7 +617,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too")
REQUIRE_EQ(indexer.indexType, typeChecker.numberType);
REQUIRE(nullptr != get<GenericTypeVar>(indexer.indexResultType));
REQUIRE(nullptr != get<GenericTypeVar>(follow(indexer.indexResultType)));
}
TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2")

View file

@ -180,7 +180,12 @@ TEST_CASE_FIXTURE(Fixture, "expr_statement")
TEST_CASE_FIXTURE(Fixture, "generic_function")
{
CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)");
CheckResult result = check(R"(
function id(x) return x end
local a = id(55)
local b = id(nil)
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(*typeChecker.numberType, *requireType("a"));
@ -1889,7 +1894,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function")
REQUIRE_EQ(2, argVec.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(argVec[0]);
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(argVec[0]));
REQUIRE(fType != nullptr);
std::vector<TypeId> fArgs = flatten(fType->argTypes).first;
@ -1926,7 +1931,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_2")
REQUIRE_EQ(6, argVec.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(argVec[0]);
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(argVec[0]));
REQUIRE(fType != nullptr);
}
@ -3842,10 +3847,10 @@ local T: any
T = {}
T.__index = T
function T.new(...)
local self = {}
setmetatable(self, T)
self:construct(...)
return self
local self = {}
setmetatable(self, T)
self:construct(...)
return self
end
function T:construct(index)
end
@ -4068,11 +4073,11 @@ function n:Clone() end
local m = {}
function m.a(x)
x:Clone()
x:Clone()
end
function m.b()
m.a(n)
m.a(n)
end
return m
@ -4393,8 +4398,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload")
TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments")
{
ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true);
// Simple direct arg to arg propagation
CheckResult result = check(R"(
type Table = { x: number, y: number }
@ -4681,7 +4684,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early")
{
ScopedFastFlag sffs[] = {
{"LuauSlightlyMoreFlexibleBinaryPredicates", true},
{"LuauExtraNilRecovery", true},
};
CheckResult result = check(R"(
@ -4698,7 +4700,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch")
{
ScopedFastFlag sffs[] = {
{"LuauSlightlyMoreFlexibleBinaryPredicates", true},
{"LuauExtraNilRecovery", true},
};
CheckResult result = check(R"(

View file

@ -8,6 +8,8 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauQuantifyInPlace2);
using namespace Luau;
struct TryUnifyFixture : Fixture
@ -15,7 +17,8 @@ struct TryUnifyFixture : Fixture
TypeArena arena;
ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}};
InternalErrorReporter iceHandler;
Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler};
UnifierSharedState unifierState{&iceHandler};
Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState};
};
TEST_SUITE_BEGIN("TryUnifyTests");
@ -139,7 +142,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails"
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("(number) -> (boolean)", toString(requireType("f")));
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("(number) -> boolean", toString(requireType("f")));
else
CHECK_EQ("(number) -> (boolean)", toString(requireType("f")));
}
TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification")

View file

@ -98,10 +98,10 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function")
std::vector<TypeId> applyArgs = flatten(applyType->argTypes).first;
REQUIRE_EQ(3, applyArgs.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(applyArgs[0]);
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(applyArgs[0]));
REQUIRE(fType != nullptr);
const FunctionTypeVar* gType = get<FunctionTypeVar>(applyArgs[1]);
const FunctionTypeVar* gType = get<FunctionTypeVar>(follow(applyArgs[1]));
REQUIRE(gType != nullptr);
std::vector<TypeId> gArgs = flatten(gType->argTypes).first;
@ -285,7 +285,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail")
{
CheckResult result = check(R"(
local _ = function():((...any)->(...any),()->())
return function() end, function() end
return function() end, function() end
end
for y in _() do
end

View file

@ -181,8 +181,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property")
TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property")
{
ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true);
CheckResult result = check(R"(
type A = {x: number}
type B = {}
@ -242,8 +240,6 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons")
TEST_CASE_FIXTURE(Fixture, "optional_union_members")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local a = { a = { x = 1, y = 2 }, b = 3 }
type A = typeof(a)
@ -259,8 +255,6 @@ local c = bf.a.y
TEST_CASE_FIXTURE(Fixture, "optional_union_functions")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local a = {}
function a.foo(x:number, y:number) return x + y end
@ -276,8 +270,6 @@ local c = b.foo(1, 2)
TEST_CASE_FIXTURE(Fixture, "optional_union_methods")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local a = {}
function a:foo(x:number, y:number) return x + y end
@ -310,8 +302,6 @@ return f()
TEST_CASE_FIXTURE(Fixture, "optional_field_access_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = { x: number }
local b: A? = { x = 2 }
@ -327,8 +317,6 @@ local d = b.y
TEST_CASE_FIXTURE(Fixture, "optional_index_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = {number}
local a: A? = {1, 2, 3}
@ -341,8 +329,6 @@ local b = a[1]
TEST_CASE_FIXTURE(Fixture, "optional_call_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = (number) -> number
local a: A? = function(a) return -a end
@ -355,8 +341,6 @@ local b = a(4)
TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = { x: number }
local a: A? = { x = 2 }
@ -378,8 +362,6 @@ a.x = 2
TEST_CASE_FIXTURE(Fixture, "optional_length_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = {number}
local a: A? = {1, 2, 3}
@ -392,9 +374,6 @@ local b = #a
TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true);
CheckResult result = check(R"(
type A = { x: number, y: number }
type B = { x: number, y: number }

View file

@ -265,4 +265,64 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure")
CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result));
}
TEST_CASE("tagging_tables")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar ttv{TableTypeVar{}};
CHECK(!Luau::hasTag(&ttv, "foo"));
Luau::attachTag(&ttv, "foo");
CHECK(Luau::hasTag(&ttv, "foo"));
}
TEST_CASE("tagging_classes")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
CHECK(!Luau::hasTag(&base, "foo"));
Luau::attachTag(&base, "foo");
CHECK(Luau::hasTag(&base, "foo"));
}
TEST_CASE("tagging_subclasses")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}};
CHECK(!Luau::hasTag(&base, "foo"));
CHECK(!Luau::hasTag(&derived, "foo"));
Luau::attachTag(&base, "foo");
CHECK(Luau::hasTag(&base, "foo"));
CHECK(Luau::hasTag(&derived, "foo"));
Luau::attachTag(&derived, "bar");
CHECK(!Luau::hasTag(&base, "bar"));
CHECK(Luau::hasTag(&derived, "bar"));
}
TEST_CASE("tagging_functions")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypePackVar empty{TypePack{}};
TypeVar ftv{FunctionTypeVar{&empty, &empty}};
CHECK(!Luau::hasTag(&ftv, "foo"));
Luau::attachTag(&ftv, "foo");
CHECK(Luau::hasTag(&ftv, "foo"));
}
TEST_CASE("tagging_props")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
Property prop{};
CHECK(!Luau::hasTag(prop, "foo"));
Luau::attachTag(prop, "foo");
CHECK(Luau::hasTag(prop, "foo"));
}
TEST_SUITE_END();

View file

@ -0,0 +1,15 @@
-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes
-- Generate an error (i.e. throw an exception) inside a tag method which is indirectly
-- called via pcall.
-- This test is meant to detect a regression in handling errors inside a tag method
local testtable = {}
setmetatable(testtable, { __index = function() error("Error") end })
pcall(function()
testtable.missingmethod()
end)
return('OK')

View file

@ -11,9 +11,9 @@ class VariantPrinter:
return type.name + " [" + str(value) + "]"
def match_printer(val):
type = val.type.strip_typedefs()
if type.name and type.name.startswith('Luau::Variant<'):
return VariantPrinter(val)
return None
type = val.type.strip_typedefs()
if type.name and type.name.startswith('Luau::Variant<'):
return VariantPrinter(val)
return None
gdb.pretty_printers.append(match_printer)