Merge branch 'master' into error-function

This commit is contained in:
Arseny Kapoulkine 2021-11-19 08:17:47 -08:00 committed by GitHub
commit 8267765245
Signed by: DevComp
GPG key ID: 4AEE18F83AFDEB23
152 changed files with 11377 additions and 5604 deletions

View file

@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: Help and support
- name: Questions
url: https://github.com/Roblox/luau/discussions
about: Please use GitHub Discussions if you have questions or need help.

View file

@ -27,13 +27,13 @@ jobs:
- uses: actions/checkout@v1
- name: make test
run: |
make -j2 config=sanitize test
make -j2 config=sanitize werror=1 test
- name: make test w/flags
run: |
make -j2 config=sanitize flags=true test
make -j2 config=sanitize werror=1 flags=true test
- name: make cli
run: |
make -j2 config=sanitize luau luau-analyze # match config with tests to improve build time
make -j2 config=sanitize werror=1 luau luau-analyze # match config with tests to improve build time
./luau tests/conformance/assert.lua
./luau-analyze tests/conformance/assert.lua
@ -45,7 +45,7 @@ jobs:
steps:
- uses: actions/checkout@v1
- name: cmake configure
run: cmake . -A ${{matrix.arch}}
run: cmake . -A ${{matrix.arch}} -DLUAU_WERROR=ON
- name: cmake test
shell: bash # necessary for fail-fast
run: |

1
.gitignore vendored
View file

@ -5,3 +5,4 @@
^default.prof*
^fuzz-*
^luau$
/.vs

View file

@ -1,7 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "TypeInfer.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
namespace Luau
{
@ -33,7 +34,6 @@ TypeId makeFunction( // Polymorphic
std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes);
void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachFunctionTag(TypeId ty, std::string constraint);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName);

View file

@ -12,10 +12,17 @@ namespace Luau
struct FunctionDocumentation;
struct TableDocumentation;
struct OverloadedFunctionDocumentation;
struct BasicDocumentation;
using Documentation = Luau::Variant<std::string, FunctionDocumentation, TableDocumentation, OverloadedFunctionDocumentation>;
using Documentation = Luau::Variant<BasicDocumentation, FunctionDocumentation, TableDocumentation, OverloadedFunctionDocumentation>;
using DocumentationSymbol = std::string;
struct BasicDocumentation
{
std::string documentation;
std::string learnMoreLink;
};
struct FunctionParameterDocumentation
{
std::string name;
@ -29,6 +36,7 @@ struct FunctionDocumentation
std::string documentation;
std::vector<FunctionParameterDocumentation> parameters;
std::vector<DocumentationSymbol> returns;
std::string learnMoreLink;
};
struct OverloadedFunctionDocumentation
@ -43,6 +51,7 @@ struct TableDocumentation
{
std::string documentation;
Luau::DenseHashMap<std::string, DocumentationSymbol> keys;
std::string learnMoreLink;
};
using DocumentationDatabase = Luau::DenseHashMap<DocumentationSymbol, Documentation>;

View file

@ -8,11 +8,20 @@
namespace Luau
{
struct TypeError;
struct TypeMismatch
{
TypeId wantedType;
TypeId givenType;
TypeMismatch() = default;
TypeMismatch(TypeId wantedType, TypeId givenType);
TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason);
TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error);
TypeId wantedType = nullptr;
TypeId givenType = nullptr;
std::string reason;
std::shared_ptr<TypeError> error;
bool operator==(const TypeMismatch& rhs) const;
};
@ -120,6 +129,7 @@ struct IncorrectGenericParameterCount
Name name;
TypeFun typeFun;
size_t actualParameters;
size_t actualPackParameters;
bool operator==(const IncorrectGenericParameterCount& rhs) const;
};

View file

@ -25,51 +25,39 @@ struct SourceCode
Type type;
};
struct ModuleInfo
{
ModuleName name;
bool optional = false;
};
struct FileResolver
{
virtual ~FileResolver() {}
/** Fetch the source code associated with the provided ModuleName.
*
* FIXME: This requires a string copy!
*
* @returns The actual Lua code on success.
* @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error.
*/
virtual std::optional<SourceCode> readSource(const ModuleName& name) = 0;
/** Does the module exist?
*
* Saves a string copy over reading the source and throwing it away.
*/
virtual bool moduleExists(const ModuleName& name) const = 0;
virtual std::optional<ModuleInfo> resolveModule(const ModuleInfo* context, AstExpr* expr)
{
return std::nullopt;
}
virtual std::optional<ModuleName> fromAstFragment(AstExpr* expr) const = 0;
/** Given a valid module name and a string of arbitrary data, figure out the concatenation.
*/
virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0;
/** Goes "up" a level in the hierarchy that the ModuleName represents.
*
* For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last
* element of the path. Other ModuleName representations may have other ways of doing this.
*
* @returns The parent ModuleName, if one exists.
* @returns std::nullopt if there is no parent for this module name.
*/
virtual std::optional<ModuleName> getParentModuleName(const ModuleName& name) const = 0;
virtual std::optional<std::string> getHumanReadableModuleName_(const ModuleName& name) const
virtual std::string getHumanReadableModuleName(const ModuleName& name) const
{
return name;
}
virtual std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const = 0;
virtual std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const
{
return std::nullopt;
}
/** LanguageService only:
* std::optional<ModuleName> fromInstance(Instance* inst)
*/
// DEPRECATED APIS
// These are going to be removed with LuauNewRequireTrace2
virtual bool moduleExists(const ModuleName& name) const = 0;
virtual std::optional<ModuleName> fromAstFragment(AstExpr* expr) const = 0;
virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0;
virtual std::optional<ModuleName> getParentModuleName(const ModuleName& name) const = 0;
};
struct NullFileResolver : FileResolver
@ -94,10 +82,6 @@ struct NullFileResolver : FileResolver
{
return std::nullopt;
}
std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const override
{
return std::nullopt;
}
};
} // namespace Luau

View file

@ -90,10 +90,12 @@ struct Module
TypeArena internalTypes;
std::vector<std::pair<Location, ScopePtr>> scopes; // never empty
std::unordered_map<const AstExpr*, TypeId> astTypes;
std::unordered_map<const AstExpr*, TypeId> astExpectedTypes;
std::unordered_map<const AstExpr*, TypeId> astOriginalCallTypes;
std::unordered_map<const AstExpr*, TypeId> astOverloadResolvedTypes;
DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astExpectedTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOriginalCallTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOverloadResolvedTypes{nullptr};
std::unordered_map<Name, TypeId> declaredGlobals;
ErrorVec errors;
Mode mode;

View file

@ -15,12 +15,6 @@ struct Module;
using ModulePtr = std::shared_ptr<Module>;
struct ModuleInfo
{
ModuleName name;
bool optional = false;
};
struct ModuleResolver
{
virtual ~ModuleResolver() {}

View file

@ -0,0 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeVar.h"
namespace Luau
{
struct Module;
using ModulePtr = std::shared_ptr<Module>;
void quantify(ModulePtr module, TypeId ty, TypeLevel level);
} // namespace Luau

View file

@ -17,12 +17,11 @@ struct AstLocal;
struct RequireTraceResult
{
DenseHashMap<const AstExpr*, ModuleName> exprs{0};
DenseHashMap<const AstExpr*, bool> optional{0};
DenseHashMap<const AstExpr*, ModuleInfo> exprs{nullptr};
std::vector<std::pair<ModuleName, Location>> requires;
};
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName);
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName);
} // namespace Luau

View file

@ -0,0 +1,67 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Location.h"
#include "Luau/TypeVar.h"
#include <unordered_map>
#include <optional>
#include <memory>
namespace Luau
{
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct Binding
{
TypeId typeId;
Location location;
bool deprecated = false;
std::string deprecatedSuggestion;
std::optional<std::string> documentationSymbol;
};
struct Scope
{
explicit Scope(TypePackId returnType); // root scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root
std::unordered_map<Symbol, Binding> bindings;
TypePackId returnType;
bool breakOk = false;
std::optional<TypePackId> varargPack;
TypeLevel level;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
std::unordered_map<Name, TypeFun> privateTypeBindings;
std::unordered_map<Name, Location> typeAliasLocations;
std::unordered_map<Name, std::unordered_map<Name, TypeFun>> importedTypeBindings;
std::optional<TypeId> lookup(const Symbol& name);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
std::unordered_map<Name, TypePackId> privateTypePackBindings;
std::optional<TypePackId> lookupPack(const Name& name);
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true);
RefinementMap refinements;
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasTypeParameters;
std::unordered_map<Name, TypePackId> typeAliasTypePackParameters;
};
} // namespace Luau

View file

@ -52,8 +52,6 @@
// `T`, and the type of `f` are in the same SCC, which is why `f` gets
// replaced.
LUAU_FASTFLAG(DebugLuauTrackOwningArena)
namespace Luau
{
@ -188,20 +186,12 @@ struct Substitution : FindDirty
template<typename T>
TypeId addType(const T& tv)
{
TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv);
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
return currentModule->internalTypes.addType(tv);
}
template<typename T>
TypePackId addTypePack(const T& tp)
{
TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp);
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
return currentModule->internalTypes.addTypePack(TypePackVar{tp});
}
};

View file

@ -23,10 +23,11 @@ struct ToStringNameMap
struct ToStringOptions
{
bool exhaustive = false; // If true, we produce complete output rather than comprehensible output
bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable.
bool functionTypeArguments = false; // If true, output function type argument names when they are available
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool exhaustive = false; // If true, we produce complete output rather than comprehensible output
bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable.
bool functionTypeArguments = false; // If true, output function type argument names when they are available
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
std::optional<ToStringNameMap> nameMap;
@ -64,9 +65,13 @@ inline std::string toString(TypePackId ty)
std::string toString(const TypeVar& tv, const ToStringOptions& opts = {});
std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {});
std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {});
// It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class
// These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression
void dump(TypeId ty);
void dump(TypePackId ty);
std::string generateName(size_t n);
} // namespace Luau

View file

@ -12,6 +12,7 @@ struct AstArray;
class AstStat;
bool containsFunctionCall(const AstStat& stat);
bool containsFunctionCallOrReturn(const AstStat& stat);
bool isFunction(const AstStat& stat);
void toposort(std::vector<AstStat*>& stats);

View file

@ -18,6 +18,7 @@ struct TranspileResult
std::string parseError; // Nonempty if the transpile failed
};
std::string toString(AstNode* node);
void dump(AstNode* node);
// Never fails on a well-formed AST
@ -25,6 +26,6 @@ std::string transpile(AstStatBlock& ast);
std::string transpileWithTypes(AstStatBlock& block);
// Only fails when parsing fails
TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{});
TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}, bool withTypes = false);
} // namespace Luau

View file

@ -3,19 +3,37 @@
#include "Luau/TypeVar.h"
LUAU_FASTFLAG(LuauShareTxnSeen);
namespace Luau
{
// Log of where what TypeIds we are rebinding and what they used to be
struct TxnLog
{
TxnLog() = default;
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& seen)
: seen(seen)
TxnLog()
: originalSeenSize(0)
, ownedSeen()
, sharedSeen(&ownedSeen)
{
}
explicit TxnLog(std::vector<std::pair<TypeId, TypeId>>* sharedSeen)
: originalSeenSize(sharedSeen->size())
, ownedSeen()
, sharedSeen(sharedSeen)
{
}
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& ownedSeen)
: originalSeenSize(ownedSeen.size())
, ownedSeen(ownedSeen)
, sharedSeen(nullptr)
{
// This is deprecated!
LUAU_ASSERT(!FFlag::LuauShareTxnSeen);
}
TxnLog(const TxnLog&) = delete;
TxnLog& operator=(const TxnLog&) = delete;
@ -38,9 +56,11 @@ private:
std::vector<std::pair<TypeId, TypeVar>> typeVarChanges;
std::vector<std::pair<TypePackId, TypePackVar>> typePackChanges;
std::vector<std::pair<TableTypeVar*, std::optional<TypeId>>> tableChanges;
size_t originalSeenSize;
public:
std::vector<std::pair<TypeId, TypeId>> seen; // used to avoid infinite recursion when types are cyclic
std::vector<std::pair<TypeId, TypeId>> ownedSeen; // used to avoid infinite recursion when types are cyclic
std::vector<std::pair<TypeId, TypeId>>* sharedSeen; // shared with all the descendent logs
};
} // namespace Luau

View file

@ -11,6 +11,7 @@
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
#include "Luau/Unifier.h"
#include "Luau/UnifierSharedState.h"
#include <memory>
#include <unordered_map>
@ -86,7 +87,10 @@ struct ApplyTypeFunction : Substitution
{
TypeLevel level;
bool encounteredForwardedType;
std::unordered_map<TypeId, TypeId> arguments;
std::unordered_map<TypeId, TypeId> typeArguments;
std::unordered_map<TypePackId, TypePackId> typePackArguments;
bool ignoreChildren(TypeId ty) override;
bool ignoreChildren(TypePackId tp) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
@ -118,7 +122,7 @@ struct TypeChecker
void check(const ScopePtr& scope, const AstStatForIn& forin);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false);
void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
@ -171,10 +175,10 @@ struct TypeChecker
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::optional<ExprResult<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<OverloadErrorEntry>& errors);
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors);
bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors);
ExprResult<TypePackId> reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<OverloadErrorEntry>& errors);
@ -259,8 +263,6 @@ public:
*
*/
TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location);
// Removed by FFlag::LuauRankNTypes
TypePackId DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location);
// Replace any free types or type packs by `any`.
// This is used when exporting types from modules, to make sure free types don't leak.
@ -280,6 +282,14 @@ public:
// Wrapper for merge(l, r, toUnion) but without the lambda junk.
void merge(RefinementMap& l, const RefinementMap& r);
// Produce an "emergency backup type" for recovery from type errors.
// This comes in two flavours, depening on whether or not we can make a good guess
// for an error recovery type.
TypeId errorRecoveryType(TypeId guess);
TypePackId errorRecoveryTypePack(TypePackId guess);
TypeId errorRecoveryType(const ScopePtr& scope);
TypePackId errorRecoveryTypePack(const ScopePtr& scope);
private:
void prepareErrorsForDisplay(ErrorVec& errVec);
void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data);
@ -294,8 +304,10 @@ private:
// Produce a new free type var.
TypeId freshType(const ScopePtr& scope);
TypeId freshType(TypeLevel level);
TypeId DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric = false);
TypeId DEPRECATED_freshType(TypeLevel level, bool canBeGeneric = false);
// Produce a new singleton type var.
TypeId singletonType(bool value);
TypeId singletonType(std::string value);
// Returns nullopt if the predicate filters down the TypeId to 0 options.
std::optional<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
@ -322,17 +334,16 @@ private:
TypePackId addTypePack(std::initializer_list<TypeId>&& ty);
TypePackId freshTypePack(const ScopePtr& scope);
TypePackId freshTypePack(TypeLevel level);
TypePackId DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric = false);
TypePackId DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric = false);
TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false);
TypeId resolveType(const ScopePtr& scope, const AstType& annotation);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams, const Location& location);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams, const Location& location);
// Note: `scope` must be a fresh scope.
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt,
const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
@ -348,7 +359,6 @@ private:
void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
bool isNonstrictMode() const;
@ -379,6 +389,8 @@ public:
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
InternalErrorReporter* iceHandler;
UnifierSharedState unifierState;
public:
const TypeId nilType;
const TypeId numberType;
@ -386,66 +398,15 @@ public:
const TypeId booleanType;
const TypeId threadType;
const TypeId anyType;
const TypeId errorType;
const TypeId optionalNumberType;
const TypePackId anyTypePack;
const TypePackId errorTypePack;
private:
int checkRecursionCount = 0;
int recursionCount = 0;
};
struct Binding
{
TypeId typeId;
Location location;
bool deprecated = false;
std::string deprecatedSuggestion;
std::optional<std::string> documentationSymbol;
};
struct Scope
{
explicit Scope(TypePackId returnType); // root scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root
std::unordered_map<Symbol, Binding> bindings;
TypePackId returnType;
bool breakOk = false;
std::optional<TypePackId> varargPack;
TypeLevel level;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
std::unordered_map<Name, TypeFun> privateTypeBindings;
std::unordered_map<Name, Location> typeAliasLocations;
std::unordered_map<Name, std::unordered_map<Name, TypeFun>> importedTypeBindings;
std::optional<TypeId> lookup(const Symbol& name);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
std::unordered_map<Name, TypePackId> privateTypePackBindings;
std::optional<TypePackId> lookupPack(const Name& name);
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true);
RefinementMap refinements;
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasParameters;
};
// Unit test hook
void setPrintLine(void (*pl)(const std::string& s));
void resetPrintLine();

View file

@ -8,8 +8,6 @@
#include <optional>
#include <set>
LUAU_FASTFLAG(LuauAddMissingFollow)
namespace Luau
{
@ -117,7 +115,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs);
TypePackId follow(TypePackId tp);
size_t size(const TypePackId tp);
size_t size(TypePackId tp);
bool finite(TypePackId tp);
size_t size(const TypePack& tp);
std::optional<TypeId> first(TypePackId tp);
@ -127,13 +126,10 @@ TypePack* asMutable(const TypePack* tp);
template<typename T>
const T* get(TypePackId tp)
{
if (FFlag::LuauAddMissingFollow)
{
LUAU_ASSERT(tp);
LUAU_ASSERT(tp);
if constexpr (!std::is_same_v<T, BoundTypePack>)
LUAU_ASSERT(get_if<BoundTypePack>(&tp->ty) == nullptr);
}
if constexpr (!std::is_same_v<T, BoundTypePack>)
LUAU_ASSERT(get_if<BoundTypePack>(&tp->ty) == nullptr);
return get_if<T>(&(tp->ty));
}
@ -141,13 +137,10 @@ const T* get(TypePackId tp)
template<typename T>
T* getMutable(TypePackId tp)
{
if (FFlag::LuauAddMissingFollow)
{
LUAU_ASSERT(tp);
LUAU_ASSERT(tp);
if constexpr (!std::is_same_v<T, BoundTypePack>)
LUAU_ASSERT(get_if<BoundTypePack>(&tp->ty) == nullptr);
}
if constexpr (!std::is_same_v<T, BoundTypePack>)
LUAU_ASSERT(get_if<BoundTypePack>(&tp->ty) == nullptr);
return get_if<T>(&(asMutable(tp)->ty));
}

View file

@ -18,7 +18,6 @@
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
LUAU_FASTINT(LuauTypeMaximumStringifierLength)
LUAU_FASTFLAG(LuauAddMissingFollow)
namespace Luau
{
@ -109,6 +108,79 @@ struct PrimitiveTypeVar
}
};
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BoolSingleton
{
bool value;
bool operator==(const BoolSingleton& rhs) const
{
return value == rhs.value;
}
bool operator!=(const BoolSingleton& rhs) const
{
return !(*this == rhs);
}
};
// Types for "foo", "bar" etc.
struct StringSingleton
{
std::string value;
bool operator==(const StringSingleton& rhs) const
{
return value == rhs.value;
}
bool operator!=(const StringSingleton& rhs) const
{
return !(*this == rhs);
}
};
// No type for float singletons, partly because === isn't any equalivalence on floats
// (NaN != NaN).
using SingletonVariant = Luau::Variant<BoolSingleton, StringSingleton>;
struct SingletonTypeVar
{
explicit SingletonTypeVar(const SingletonVariant& variant)
: variant(variant)
{
}
explicit SingletonTypeVar(SingletonVariant&& variant)
: variant(std::move(variant))
{
}
// Default operator== is C++20.
bool operator==(const SingletonTypeVar& rhs) const
{
return variant == rhs.variant;
}
bool operator!=(const SingletonTypeVar& rhs) const
{
return !(*this == rhs);
}
SingletonVariant variant;
};
template<typename T>
const T* get(const SingletonTypeVar* stv)
{
if (stv)
return get_if<T>(&stv->variant);
else
return nullptr;
}
struct FunctionArgument
{
Name name;
@ -228,6 +300,7 @@ struct TableTypeVar
std::map<Name, Location> methodDefinitionLocations;
std::vector<TypeId> instantiatedTypeParams;
std::vector<TypePackId> instantiatedTypePackParams;
ModuleName definitionModuleName;
std::optional<TypeId> boundTo;
@ -284,8 +357,9 @@ struct ClassTypeVar
struct TypeFun
{
/// These should all be generic
// These should all be generic
std::vector<TypeId> typeParams;
std::vector<TypePackId> typePackParams;
/** The underlying type.
*
@ -293,6 +367,20 @@ struct TypeFun
* You must first use TypeChecker::instantiateTypeFun to turn it into a real type.
*/
TypeId type;
TypeFun() = default;
TypeFun(std::vector<TypeId> typeParams, TypeId type)
: typeParams(std::move(typeParams))
, type(type)
{
}
TypeFun(std::vector<TypeId> typeParams, std::vector<TypePackId> typePackParams, TypeId type)
: typeParams(std::move(typeParams))
, typePackParams(std::move(typePackParams))
, type(type)
{
}
};
// Anything! All static checking is off.
@ -317,8 +405,8 @@ struct LazyTypeVar
using ErrorTypeVar = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar,
UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar,
AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
struct TypeVar final
{
@ -395,30 +483,32 @@ bool isGeneric(const TypeId ty);
// Checks if a type may be instantiated to one containing generic type binders
bool maybeGeneric(const TypeId ty);
// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton
bool maybeSingleton(TypeId ty);
struct SingletonTypes
{
const TypeId nilType = &nilType_;
const TypeId numberType = &numberType_;
const TypeId stringType = &stringType_;
const TypeId booleanType = &booleanType_;
const TypeId threadType = &threadType_;
const TypeId anyType = &anyType_;
const TypeId errorType = &errorType_;
const TypeId nilType;
const TypeId numberType;
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
const TypeId anyType;
const TypeId optionalNumberType;
const TypePackId anyTypePack;
SingletonTypes();
SingletonTypes(const SingletonTypes&) = delete;
void operator=(const SingletonTypes&) = delete;
TypeId errorRecoveryType(TypeId guess);
TypePackId errorRecoveryTypePack(TypePackId guess);
TypeId errorRecoveryType();
TypePackId errorRecoveryTypePack();
private:
std::unique_ptr<struct TypeArena> arena;
TypeVar nilType_;
TypeVar numberType_;
TypeVar stringType_;
TypeVar booleanType_;
TypeVar threadType_;
TypeVar anyType_;
TypeVar errorType_;
TypeId makeStringMetatable();
};
@ -456,13 +546,10 @@ TypeVar* asMutable(TypeId ty);
template<typename T>
const T* get(TypeId tv)
{
if (FFlag::LuauAddMissingFollow)
{
LUAU_ASSERT(tv);
LUAU_ASSERT(tv);
if constexpr (!std::is_same_v<T, BoundTypeVar>)
LUAU_ASSERT(get_if<BoundTypeVar>(&tv->ty) == nullptr);
}
if constexpr (!std::is_same_v<T, BoundTypeVar>)
LUAU_ASSERT(get_if<BoundTypeVar>(&tv->ty) == nullptr);
return get_if<T>(&tv->ty);
}
@ -470,13 +557,10 @@ const T* get(TypeId tv)
template<typename T>
T* getMutable(TypeId tv)
{
if (FFlag::LuauAddMissingFollow)
{
LUAU_ASSERT(tv);
LUAU_ASSERT(tv);
if constexpr (!std::is_same_v<T, BoundTypeVar>)
LUAU_ASSERT(get_if<BoundTypeVar>(&tv->ty) == nullptr);
}
if constexpr (!std::is_same_v<T, BoundTypeVar>)
LUAU_ASSERT(get_if<BoundTypeVar>(&tv->ty) == nullptr);
return get_if<T>(&asMutable(tv)->ty);
}
@ -524,8 +608,11 @@ UnionTypeVarIterator end(const UnionTypeVar* utv);
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
// TEMP: Clip this prototype with FFlag::LuauStringMetatable
std::optional<ExprResult<TypePackId>> magicFunctionFormat(
struct TypeChecker& typechecker, const std::shared_ptr<struct Scope>& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
void attachTag(TypeId ty, const std::string& tagName);
void attachTag(Property& prop, const std::string& tagName);
bool hasTag(TypeId ty, const std::string& tagName);
bool hasTag(const Property& prop, const std::string& tagName);
bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work.
} // namespace Luau

View file

@ -63,12 +63,9 @@ using Name = std::string;
struct Free
{
explicit Free(TypeLevel level);
Free(TypeLevel level, bool DEPRECATED_canBeGeneric);
int index;
TypeLevel level;
// Removed by FFlag::LuauRankNTypes
bool DEPRECATED_canBeGeneric = false;
// True if this free type variable is part of a mutually
// recursive type alias whose definitions haven't been
// resolved yet.
@ -108,6 +105,8 @@ private:
struct Error
{
// This constructor has to be public, since it's used in TypeVar and TypePack,
// but shouldn't be called directly. Please use errorRecoveryType() instead.
Error();
int index;

View file

@ -6,6 +6,7 @@
#include "Luau/TxnLog.h"
#include "Luau/TypeInfer.h"
#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header.
#include "Luau/UnifierSharedState.h"
#include <unordered_set>
@ -36,12 +37,20 @@ struct Unifier
Variance variance = Covariant;
CountMismatch::Context ctx = CountMismatch::Arg;
std::shared_ptr<UnifierCounters> counters;
InternalErrorReporter* iceHandler;
UnifierCounters* counters;
UnifierCounters countersData;
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters = nullptr);
std::shared_ptr<UnifierCounters> counters_DEPRECATED;
UnifierSharedState& sharedState;
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& ownedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
UnifierCounters* counters = nullptr);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector<std::pair<TypeId, TypeId>>* sharedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
UnifierCounters* counters = nullptr);
// Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId superTy, TypeId subTy);
@ -56,13 +65,17 @@ struct Unifier
private:
void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false);
void tryUnifyPrimitives(TypeId superTy, TypeId subTy);
void tryUnifySingletons(TypeId superTy, TypeId subTy);
void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);
void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);
void tryUnifyFreeTable(TypeId free, TypeId other);
void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection);
void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed);
void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed);
void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer);
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen = {});
void cacheResult(TypeId superTy, TypeId subTy);
public:
void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false);
@ -75,14 +88,13 @@ private:
void tryUnifyWithAny(TypePackId any, TypePackId ty);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry);
public:
// Report an "infinite type error" if the type "needle" already occurs within "haystack"
void occursCheck(TypeId needle, TypeId haystack);
void occursCheck(std::unordered_set<TypeId>& seen, TypeId needle, TypeId haystack);
void occursCheck(std::unordered_set<TypeId>& seen_DEPRECATED, DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
void occursCheck(TypePackId needle, TypePackId haystack);
void occursCheck(std::unordered_set<TypePackId>& seen, TypePackId needle, TypePackId haystack);
void occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
Unifier makeChildUnifier();
@ -90,9 +102,14 @@ private:
bool isNonstrictMode() const;
void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType);
void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType);
[[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message);
// Remove with FFlagLuauCacheUnifyTableResults
DenseHashSet<TypeId> tempSeenTy_DEPRECATED{nullptr};
DenseHashSet<TypePackId> tempSeenTp_DEPRECATED{nullptr};
};
} // namespace Luau

View file

@ -0,0 +1,44 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
#include <utility>
namespace Luau
{
struct InternalErrorReporter;
struct TypeIdPairHash
{
size_t hashOne(Luau::TypeId key) const
{
return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9);
}
size_t operator()(const std::pair<Luau::TypeId, Luau::TypeId>& x) const
{
return hashOne(x.first) ^ (hashOne(x.second) << 1);
}
};
struct UnifierSharedState
{
UnifierSharedState(InternalErrorReporter* iceHandler)
: iceHandler(iceHandler)
{
}
InternalErrorReporter* iceHandler;
DenseHashSet<void*> seenAny{nullptr};
DenseHashMap<TypeId, bool> skipCacheForType{nullptr};
DenseHashSet<std::pair<TypeId, TypeId>, TypeIdPairHash> cachedUnify{{nullptr, nullptr}};
DenseHashSet<TypeId> tempSeenTy{nullptr};
DenseHashSet<TypePackId> tempSeenTp{nullptr};
};
} // namespace Luau

View file

@ -1,9 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
LUAU_FASTFLAG(LuauCacheUnifyTableResults)
namespace Luau
{
@ -32,17 +35,33 @@ inline bool hasSeen(std::unordered_set<void*>& seen, const void* tv)
return !seen.insert(ttv).second;
}
inline bool hasSeen(DenseHashSet<void*>& seen, const void* tv)
{
void* ttv = const_cast<void*>(tv);
if (seen.contains(ttv))
return true;
seen.insert(ttv);
return false;
}
inline void unsee(std::unordered_set<void*>& seen, const void* tv)
{
void* ttv = const_cast<void*>(tv);
seen.erase(ttv);
}
template<typename F>
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen);
inline void unsee(DenseHashSet<void*>& seen, const void* tv)
{
// When DenseHashSet is used for 'visitOnce', where don't forget visited elements
}
template<typename F>
void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen);
template<typename F, typename Set>
void visit(TypeId ty, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, ty))
{
@ -79,15 +98,23 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
else if (auto ttv = get<TableTypeVar>(ty))
{
// Some visitors want to see bound tables, that's why we visit the original type
if (apply(ty, *ttv, seen, f))
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
visit(*ttv->boundTo, f, seen);
}
else
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
}
}
}
}
@ -140,8 +167,8 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
visit_detail::unsee(seen, ty);
}
template<typename F>
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, tp))
{
@ -182,6 +209,7 @@ void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
visit_detail::unsee(seen, tp);
}
} // namespace visit_detail
template<typename TID, typename F>
@ -197,4 +225,11 @@ void visitTypeVar(TID ty, F& f)
visit_detail::visit(ty, f, seen);
}
template<typename TID, typename F>
void visitTypeVarOnce(TID ty, F& f, DenseHashSet<void*>& seen)
{
seen.clear();
visit_detail::visit(ty, f, seen);
}
} // namespace Luau

View file

@ -2,6 +2,7 @@
#include "Luau/AstQuery.h"
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"
#include "Luau/ToString.h"
@ -143,8 +144,8 @@ std::optional<TypeId> findTypeAtPosition(const Module& module, const SourceModul
{
if (auto expr = findExprAtPosition(sourceModule, pos))
{
if (auto it = module.astTypes.find(expr); it != module.astTypes.end())
return it->second;
if (auto it = module.astTypes.find(expr))
return *it;
}
return std::nullopt;
@ -154,8 +155,8 @@ std::optional<TypeId> findExpectedTypeAtPosition(const Module& module, const Sou
{
if (auto expr = findExprAtPosition(sourceModule, pos))
{
if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end())
return it->second;
if (auto it = module.astExpectedTypes.find(expr))
return *it;
}
return std::nullopt;
@ -322,9 +323,9 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end())
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = it->second;
matchingOverload = *it;
}
}
@ -345,9 +346,9 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
{
if (AstExprIndexName* indexName = targetExpr->as<AstExprIndexName>())
{
if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end())
if (auto it = module.astTypes.find(indexName->expr))
{
TypeId parentTy = follow(it->second);
TypeId parentTy = follow(*it);
if (const TableTypeVar* ttv = get<TableTypeVar>(parentTy))
{
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())

View file

@ -12,9 +12,9 @@
#include <unordered_set>
#include <utility>
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false);
LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false);
static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -196,13 +196,27 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) {
InternalErrorReporter iceReporter;
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter);
UnifierSharedState unifierState(&iceReporter);
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState);
unifier.tryUnify(expectedType, actualType);
if (FFlag::LuauAutocompleteAvoidMutation)
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr);
actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr);
bool ok = unifier.errors.empty();
unifier.log.rollback();
return ok;
auto errors = unifier.canUnify(expectedType, actualType);
return errors.empty();
}
else
{
unifier.tryUnify(expectedType, actualType);
bool ok = unifier.errors.empty();
unifier.log.rollback();
return ok;
}
};
auto expr = node->asExpr();
@ -210,10 +224,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
return TypeCorrectKind::None;
auto it = module.astExpectedTypes.find(expr);
if (it == module.astExpectedTypes.end())
if (!it)
return TypeCorrectKind::None;
TypeId expectedType = follow(it->second);
TypeId expectedType = follow(*it);
if (canUnify(expectedType, ty))
return TypeCorrectKind::Correct;
@ -368,20 +382,10 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
while (iter != endIter)
{
if (FFlag::LuauAddMissingFollow)
{
if (isNil(*iter))
++iter;
else
break;
}
if (isNil(*iter))
++iter;
else
{
if (auto primTy = Luau::get<PrimitiveTypeVar>(*iter); primTy && primTy->type == PrimitiveTypeVar::NilType)
++iter;
else
break;
}
break;
}
if (iter == endIter)
@ -396,21 +400,10 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
AutocompleteEntryMap inner;
std::unordered_set<TypeId> innerSeen = seen;
if (FFlag::LuauAddMissingFollow)
if (isNil(*iter))
{
if (isNil(*iter))
{
++iter;
continue;
}
}
else
{
if (auto innerPrimTy = Luau::get<PrimitiveTypeVar>(*iter); innerPrimTy && innerPrimTy->type == PrimitiveTypeVar::NilType)
{
++iter;
continue;
}
++iter;
continue;
}
autocompleteProps(module, typeArena, *iter, indexType, nodes, inner, innerSeen);
@ -495,7 +488,7 @@ static bool canSuggestInferredType(ScopePtr scope, TypeId ty)
return false;
// No syntax for unnamed tables with a metatable
if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
if (get<MetatableTypeVar>(ty))
return false;
if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
@ -682,12 +675,12 @@ static std::optional<bool> functionIsExpectedAt(const Module& module, AstNode* n
return std::nullopt;
auto it = module.astExpectedTypes.find(expr);
if (it == module.astExpectedTypes.end())
if (!it)
return std::nullopt;
TypeId expectedType = follow(it->second);
TypeId expectedType = follow(*it);
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(expectedType))
if (get<FunctionTypeVar>(expectedType))
return true;
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(expectedType))
@ -784,9 +777,9 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
if (AstExprCall* exprCall = expr->as<AstExprCall>())
{
if (auto it = module.astTypes.find(exprCall->func); it != module.astTypes.end())
if (auto it = module.astTypes.find(exprCall->func))
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(it->second)))
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*it)))
{
if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos))
inferredType = *ty;
@ -798,8 +791,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
if (tailPos != 0)
break;
if (auto it = module.astTypes.find(expr); it != module.astTypes.end())
inferredType = it->second;
if (auto it = module.astTypes.find(expr))
inferredType = *it;
}
if (inferredType)
@ -815,10 +808,10 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* {
auto it = module.astExpectedTypes.find(expr);
if (it == module.astExpectedTypes.end())
if (!it)
return nullptr;
TypeId ty = follow(it->second);
TypeId ty = follow(*it);
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
return ftv;
@ -1129,9 +1122,8 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul
if (node->is<AstExprIndexName>())
{
auto it = module.astTypes.find(node->asExpr());
if (it != module.astTypes.end())
autocompleteProps(module, typeArena, it->second, PropIndexType::Point, ancestry, result);
if (auto it = module.astTypes.find(node->asExpr()))
autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result);
}
else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result))
return;
@ -1203,13 +1195,13 @@ static std::optional<const ClassTypeVar*> getMethodContainingClass(const ModuleP
return std::nullopt;
}
auto parentIter = module->astTypes.find(parentExpr);
if (parentIter == module->astTypes.end())
auto parentIt = module->astTypes.find(parentExpr);
if (!parentIt)
{
return std::nullopt;
}
Luau::TypeId parentType = Luau::follow(parentIter->second);
Luau::TypeId parentType = Luau::follow(*parentIt);
if (auto parentClass = Luau::get<ClassTypeVar>(parentType))
{
@ -1250,8 +1242,8 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
return std::nullopt;
}
auto iter = module->astTypes.find(candidate->func);
if (iter == module->astTypes.end())
auto it = module->astTypes.find(candidate->func);
if (!it)
{
return std::nullopt;
}
@ -1267,7 +1259,7 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
return std::nullopt;
};
auto followedId = Luau::follow(iter->second);
auto followedId = Luau::follow(*it);
if (auto functionType = Luau::get<FunctionTypeVar>(followedId))
{
return performCallback(functionType);
@ -1316,10 +1308,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
if (auto indexName = node->as<AstExprIndexName>())
{
auto it = module->astTypes.find(indexName->expr);
if (it == module->astTypes.end())
if (!it)
return {};
TypeId ty = follow(it->second);
TypeId ty = follow(*it);
PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point;
if (isString(ty))
@ -1447,9 +1439,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
// If item doesn't have a key, maybe the value is actually the key
if (key ? key == node : node->is<AstExprGlobal>() && value == node)
{
if (auto it = module->astExpectedTypes.find(exprTable); it != module->astExpectedTypes.end())
if (auto it = module->astExpectedTypes.find(exprTable))
{
auto result = autocompleteProps(*module, typeArena, it->second, PropIndexType::Key, finder.ancestry);
auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry);
// Remove keys that are already completed
for (const auto& item : exprTable->items)
@ -1485,9 +1477,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
{
if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as<AstExprIndexExpr>())
{
if (auto it = module->astTypes.find(idxExpr->expr); it != module->astTypes.end())
if (auto it = module->astTypes.find(idxExpr->expr))
{
return {autocompleteProps(*module, typeArena, follow(it->second), PropIndexType::Point, finder.ancestry), finder.ancestry};
return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry};
}
}
}
@ -1518,11 +1510,9 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName
if (!sourceModule)
return {};
TypeChecker& typeChecker =
(frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
ModulePtr module =
(frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.moduleResolverForAutocomplete.getModule(moduleName)
: frontend.moduleResolver.getModule(moduleName));
TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName)
: frontend.moduleResolver.getModule(moduleName));
if (!module)
return {};
@ -1549,8 +1539,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view
sourceModule->mode = Mode::Strict;
sourceModule->commentLocations = std::move(result.commentLocations);
TypeChecker& typeChecker =
(frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict);

View file

@ -8,10 +8,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauParseGenericFunctions)
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAG(LuauStringMetatable)
LUAU_FASTFLAG(LuauNewRequireTrace2)
/** FIXME: Many of these type definitions are not quite completely accurate.
*
@ -106,18 +103,6 @@ void attachMagicFunction(TypeId ty, MagicFunction fn)
LUAU_ASSERT(!"Got a non functional type");
}
void attachFunctionTag(TypeId ty, std::string tag)
{
if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->tags.emplace_back(std::move(tag));
}
else
{
LUAU_ASSERT(!"Got a non functional type");
}
}
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol)
{
return {
@ -197,28 +182,13 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
TypeId numberType = typeChecker.numberType;
TypeId booleanType = typeChecker.booleanType;
TypeId nilType = typeChecker.nilType;
TypeId stringType = typeChecker.stringType;
TypeId threadType = typeChecker.threadType;
TypeId anyType = typeChecker.anyType;
TypeArena& arena = typeChecker.globalTypes;
TypeId optionalNumber = makeOption(typeChecker, arena, numberType);
TypeId optionalString = makeOption(typeChecker, arena, stringType);
TypeId optionalBoolean = makeOption(typeChecker, arena, booleanType);
TypeId stringOrNumber = makeUnion(arena, {stringType, numberType});
TypePackId emptyPack = arena.addTypePack({});
TypePackId oneNumberPack = arena.addTypePack({numberType});
TypePackId oneStringPack = arena.addTypePack({stringType});
TypePackId oneBooleanPack = arena.addTypePack({booleanType});
TypePackId oneAnyPack = arena.addTypePack({anyType});
TypePackId anyTypePack = typeChecker.anyTypePack;
TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}});
TypePackId stringVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{stringType}});
TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList});
TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{
@ -228,8 +198,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack});
TypeId stringToAnyMap = arena.addType(TableTypeVar{{}, TableIndexer(stringType, anyType), typeChecker.globalScope->level});
LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau");
LUAU_ASSERT(loadResult.success);
@ -249,304 +217,19 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest");
}
TypeId anyFunction = arena.addType(FunctionTypeVar{anyTypePack, anyTypePack});
TypeId genericK = arena.addType(GenericTypeVar{"K"});
TypeId genericV = arena.addType(GenericTypeVar{"V"});
TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level});
if (FFlag::LuauStringMetatable)
{
std::optional<TypeId> stringMetatableTy = getMetatable(singletonTypes.stringType);
LUAU_ASSERT(stringMetatableTy);
const TableTypeVar* stringMetatableTable = get<TableTypeVar>(follow(*stringMetatableTy));
LUAU_ASSERT(stringMetatableTable);
std::optional<TypeId> stringMetatableTy = getMetatable(singletonTypes.stringType);
LUAU_ASSERT(stringMetatableTy);
const TableTypeVar* stringMetatableTable = get<TableTypeVar>(follow(*stringMetatableTy));
LUAU_ASSERT(stringMetatableTable);
auto it = stringMetatableTable->props.find("__index");
LUAU_ASSERT(it != stringMetatableTable->props.end());
auto it = stringMetatableTable->props.find("__index");
LUAU_ASSERT(it != stringMetatableTable->props.end());
TypeId stringLib = it->second.type;
addGlobalBinding(typeChecker, "string", stringLib, "@luau");
}
if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions)
{
if (!FFlag::LuauStringMetatable)
{
TypeId stringLibTy = getGlobalBinding(typeChecker, "string");
TableTypeVar* stringLib = getMutable<TableTypeVar>(stringLibTy);
TypeId replArgType = makeUnion(
arena, {stringType,
arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)),
makeFunction(arena, std::nullopt, {stringType}, {stringType})});
TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType});
stringLib->props["gsub"] = makeProperty(gsubFunc, "@luau/global/string.gsub");
}
}
else
{
if (!FFlag::LuauStringMetatable)
{
TypeId stringToStringType = makeFunction(arena, std::nullopt, {stringType}, {stringType});
TypeId gmatchFunc = makeFunction(arena, stringType, {stringType}, {arena.addType(FunctionTypeVar{emptyPack, stringVariadicList})});
TypeId replArgType = makeUnion(
arena, {stringType,
arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)),
makeFunction(arena, std::nullopt, {stringType}, {stringType})});
TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType});
TypeId formatFn = arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack});
TableTypeVar::Props stringLib = {
// FIXME string.byte "can" return a pack of numbers, but only if 2nd or 3rd arguments were supplied
{"byte", {makeFunction(arena, stringType, {optionalNumber, optionalNumber}, {optionalNumber})}},
// FIXME char takes a variadic pack of numbers
{"char", {makeFunction(arena, std::nullopt, {numberType, optionalNumber, optionalNumber, optionalNumber}, {stringType})}},
{"find", {makeFunction(arena, stringType, {stringType, optionalNumber, optionalBoolean}, {optionalNumber, optionalNumber})}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(arena, stringType, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {makeFunction(arena, stringType, {stringType, optionalNumber}, {optionalString})}},
{"rep", {makeFunction(arena, stringType, {numberType}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(arena, stringType, {numberType, optionalNumber}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(arena, stringType, {stringType, optionalString},
{arena.addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, typeChecker.globalScope->level})})}},
{"pack", {arena.addType(FunctionTypeVar{
arena.addTypePack(TypePack{{stringType}, anyTypePack}),
oneStringPack,
})}},
{"packsize", {makeFunction(arena, stringType, {}, {numberType})}},
{"unpack", {arena.addType(FunctionTypeVar{
arena.addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
anyTypePack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
addGlobalBinding(typeChecker, "string",
arena.addType(TableTypeVar{stringLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau");
}
TableTypeVar::Props debugLib{
{"info", {makeIntersection(arena,
{
arena.addType(FunctionTypeVar{arena.addTypePack({typeChecker.threadType, numberType, stringType}), anyTypePack}),
arena.addType(FunctionTypeVar{arena.addTypePack({numberType, stringType}), anyTypePack}),
arena.addType(FunctionTypeVar{arena.addTypePack({anyFunction, stringType}), anyTypePack}),
})}},
{"traceback", {makeIntersection(arena,
{
makeFunction(arena, std::nullopt, {optionalString, optionalNumber}, {stringType}),
makeFunction(arena, std::nullopt, {typeChecker.threadType, optionalString, optionalNumber}, {stringType}),
})}},
};
assignPropDocumentationSymbols(debugLib, "@luau/global/debug");
addGlobalBinding(typeChecker, "debug",
arena.addType(TableTypeVar{debugLib, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}), "@luau");
TableTypeVar::Props utf8Lib = {
{"char", {arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneStringPack})}}, // FIXME
{"charpattern", {stringType}},
{"codes", {makeFunction(arena, std::nullopt, {stringType},
{makeFunction(arena, std::nullopt, {stringType, numberType}, {numberType, numberType}), stringType, numberType})}},
{"codepoint",
{arena.addType(FunctionTypeVar{arena.addTypePack({stringType, optionalNumber, optionalNumber}), listOfAtLeastOneNumber})}}, // FIXME
{"len", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {optionalNumber, numberType})}},
{"offset", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {numberType})}},
{"nfdnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}},
{"graphemes", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber},
{makeFunction(arena, std::nullopt, {}, {numberType, numberType})})}},
{"nfcnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}},
};
assignPropDocumentationSymbols(utf8Lib, "@luau/global/utf8");
addGlobalBinding(
typeChecker, "utf8", arena.addType(TableTypeVar{utf8Lib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau");
TypeId optionalV = makeOption(typeChecker, arena, genericV);
TypeId arrayOfV = arena.addType(TableTypeVar{{}, TableIndexer(numberType, genericV), typeChecker.globalScope->level});
TypePackId unpackArgsPack = arena.addTypePack(TypePack{{arrayOfV, optionalNumber, optionalNumber}});
TypePackId unpackReturnPack = arena.addTypePack(TypePack{{}, anyTypePack});
TypeId unpackFunc = arena.addType(FunctionTypeVar{{genericV}, {}, unpackArgsPack, unpackReturnPack});
TypeId packResult = arena.addType(TableTypeVar{
TableTypeVar::Props{{"n", {numberType}}}, TableIndexer{numberType, numberType}, typeChecker.globalScope->level, TableState::Sealed});
TypePackId packArgsPack = arena.addTypePack(TypePack{{}, anyTypePack});
TypePackId packReturnPack = arena.addTypePack(TypePack{{packResult}});
TypeId comparator = makeFunction(arena, std::nullopt, {genericV, genericV}, {booleanType});
TypeId optionalComparator = makeOption(typeChecker, arena, comparator);
TypeId packFn = arena.addType(FunctionTypeVar(packArgsPack, packReturnPack));
TableTypeVar::Props tableLib = {
{"concat", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalString, optionalNumber, optionalNumber}, {stringType})}},
{"insert", {makeIntersection(arena, {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV}, {}),
makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, genericV}, {})})}},
{"maxn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}},
{"remove", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalNumber}, {optionalV})}},
{"sort", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalComparator}, {})}},
{"create", {makeFunction(arena, std::nullopt, {genericV}, {}, {numberType, optionalV}, {arrayOfV})}},
{"find", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV, optionalNumber}, {optionalNumber})}},
{"unpack", {unpackFunc}}, // FIXME
{"pack", {packFn}},
// Lua 5.0 compat
{"getn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}},
{"foreach", {makeFunction(arena, std::nullopt, {genericK, genericV}, {},
{mapOfKtoV, makeFunction(arena, std::nullopt, {genericK, genericV}, {})}, {})}},
{"foreachi", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, makeFunction(arena, std::nullopt, {genericV}, {})}, {})}},
// backported from Lua 5.3
{"move", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, numberType, numberType, arrayOfV}, {})}},
// added in Luau (borrowed from LuaJIT)
{"clear", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {})}},
{"freeze", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {mapOfKtoV})}},
{"isfrozen", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {booleanType})}},
};
assignPropDocumentationSymbols(tableLib, "@luau/global/table");
addGlobalBinding(
typeChecker, "table", arena.addType(TableTypeVar{tableLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau");
TableTypeVar::Props coroutineLib = {
{"create", {makeFunction(arena, std::nullopt, {anyFunction}, {threadType})}},
{"resume", {arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{threadType}, anyTypePack}), anyTypePack})}},
{"running", {makeFunction(arena, std::nullopt, {}, {threadType})}},
{"status", {makeFunction(arena, std::nullopt, {threadType}, {stringType})}},
{"wrap", {makeFunction(
arena, std::nullopt, {anyFunction}, {anyType})}}, // FIXME this technically returns a function, but we can't represent this
// atm since it can be called with different arg types at different times
{"yield", {arena.addType(FunctionTypeVar{anyTypePack, anyTypePack})}},
{"isyieldable", {makeFunction(arena, std::nullopt, {}, {booleanType})}},
};
assignPropDocumentationSymbols(coroutineLib, "@luau/global/coroutine");
addGlobalBinding(typeChecker, "coroutine",
arena.addType(TableTypeVar{coroutineLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau");
TypeId genericT = arena.addType(GenericTypeVar{"T"});
TypeId genericR = arena.addType(GenericTypeVar{"R"});
// assert returns all arguments
TypePackId assertArgs = arena.addTypePack({genericT, optionalString});
TypePackId assertRets = arena.addTypePack({genericT});
addGlobalBinding(typeChecker, "assert", arena.addType(FunctionTypeVar{assertArgs, assertRets}), "@luau");
addGlobalBinding(typeChecker, "print", arena.addType(FunctionTypeVar{anyTypePack, emptyPack}), "@luau");
addGlobalBinding(typeChecker, "type", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau");
addGlobalBinding(typeChecker, "typeof", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau");
addGlobalBinding(typeChecker, "error", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {}), "@luau");
addGlobalBinding(typeChecker, "tostring", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau");
addGlobalBinding(
typeChecker, "tonumber", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {numberType}), "@luau");
addGlobalBinding(
typeChecker, "rawequal", makeFunction(arena, std::nullopt, {genericT, genericR}, {}, {genericT, genericR}, {booleanType}), "@luau");
addGlobalBinding(
typeChecker, "rawget", makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK}, {genericV}), "@luau");
addGlobalBinding(typeChecker, "rawset",
makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK, genericV}, {mapOfKtoV}), "@luau");
TypePackId genericTPack = arena.addTypePack({genericT});
TypePackId genericRPack = arena.addTypePack({genericR});
TypeId genericArgsToReturnFunction = arena.addType(
FunctionTypeVar{{genericT, genericR}, {}, arena.addTypePack(TypePack{{}, genericTPack}), arena.addTypePack(TypePack{{}, genericRPack})});
TypeId setfenvArgType = makeUnion(arena, {numberType, genericArgsToReturnFunction});
TypeId setfenvReturnType = makeOption(typeChecker, arena, genericArgsToReturnFunction);
addGlobalBinding(typeChecker, "setfenv", makeFunction(arena, std::nullopt, {setfenvArgType, stringToAnyMap}, {setfenvReturnType}), "@luau");
TypePackId ipairsArgsTypePack = arena.addTypePack({arrayOfV});
TypeId ipairsNextFunctionType = arena.addType(
FunctionTypeVar{{genericK, genericV}, {}, arena.addTypePack({arrayOfV, numberType}), arena.addTypePack({numberType, genericV})});
// ipairs returns 'next, Array<V>, 0' so we would need type-level primitives and change to
// again, we have a direct reference to 'next' because ipairs returns it
// ipairs<V>(t: Array<V>) -> ((Array<V>) -> (number, V), Array<V>, 0)
TypePackId ipairsReturnTypePack = arena.addTypePack(TypePack{{ipairsNextFunctionType, arrayOfV, numberType}});
// ipairs<V>(t: Array<V>) -> ((Array<V>) -> (number, V), Array<V>, number)
addGlobalBinding(typeChecker, "ipairs", arena.addType(FunctionTypeVar{{genericV}, {}, ipairsArgsTypePack, ipairsReturnTypePack}), "@luau");
TypePackId pcallArg0FnArgs = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}});
TypePackId pcallArg0FnRet = arena.addTypePack(TypePackVar{GenericTypeVar{"R"}});
TypeId pcallArg0 = arena.addType(FunctionTypeVar{pcallArg0FnArgs, pcallArg0FnRet});
TypePackId pcallArgsTypePack = arena.addTypePack(TypePack{{pcallArg0}, pcallArg0FnArgs});
TypePackId pcallReturnTypePack = arena.addTypePack(TypePack{{booleanType}, pcallArg0FnRet});
// pcall<A..., R...>(f: (A...) -> R..., args: A...) -> boolean, R...
addGlobalBinding(typeChecker, "pcall",
arena.addType(FunctionTypeVar{{}, {pcallArg0FnArgs, pcallArg0FnRet}, pcallArgsTypePack, pcallReturnTypePack}), "@luau");
// errors thrown by the function 'f' are propagated onto the function 'err' that accepts it.
// and either 'f' or 'err' are valid results of this xpcall
// if 'err' did throw an error, then it returns: false, "error in error handling"
// TODO: the above is not represented (nor representable) in the type annotation below.
//
// The real type of xpcall is as such: <E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false,
// R2...)
TypePackId genericAPack = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}});
TypePackId genericR1Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R1"}});
TypePackId genericR2Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R2"}});
TypeId genericE = arena.addType(GenericTypeVar{"E"});
TypeId xpcallFArg = arena.addType(FunctionTypeVar{genericAPack, genericR1Pack});
TypeId xpcallErrArg = arena.addType(FunctionTypeVar{arena.addTypePack({genericE}), genericR2Pack});
TypePackId xpcallArgsPack = arena.addTypePack({{xpcallFArg, xpcallErrArg}, genericAPack});
TypePackId xpcallRetPack = arena.addTypePack({{booleanType}, genericR1Pack}); // FIXME
addGlobalBinding(typeChecker, "xpcall",
arena.addType(FunctionTypeVar{{genericE}, {genericAPack, genericR1Pack, genericR2Pack}, xpcallArgsPack, xpcallRetPack}), "@luau");
addGlobalBinding(typeChecker, "unpack", unpackFunc, "@luau");
TypePackId selectArgsTypePack = arena.addTypePack(TypePack{
{stringOrNumber},
anyTypePack // FIXME? select() is tricky.
});
addGlobalBinding(typeChecker, "select", arena.addType(FunctionTypeVar{selectArgsTypePack, anyTypePack}), "@luau");
// TODO: not completely correct. loadstring's return type should be a function or (nil, string)
TypeId loadstringFunc = arena.addType(FunctionTypeVar{anyTypePack, oneAnyPack});
addGlobalBinding(typeChecker, "loadstring",
makeFunction(arena, std::nullopt, {stringType, optionalString},
{
makeOption(typeChecker, arena, loadstringFunc),
makeOption(typeChecker, arena, stringType),
}),
"@luau");
// a userdata object is "roughly" the same as a sealed empty table
// except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too.
// another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT
// setmetatable.
// TODO: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`.
TypeId sealedTable = arena.addType(TableTypeVar(TableState::Sealed, typeChecker.globalScope->level));
addGlobalBinding(typeChecker, "newproxy", makeFunction(arena, std::nullopt, {optionalBoolean}, {sealedTable}), "@luau");
}
addGlobalBinding(typeChecker, "string", it->second.type, "@luau");
// next<K, V>(t: Table<K, V>, i: K | nil) -> (K, V)
TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}});
@ -555,8 +238,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV});
TypeId pairsNext = (FFlag::LuauRankNTypes ? arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})})
: getGlobalBinding(typeChecker, "next"));
TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})});
TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}});
// NOTE we are missing 'i: K | nil' argument in the first return types' argument.
@ -601,9 +283,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
auto tableLib = getMutable<TableTypeVar>(getGlobalBinding(typeChecker, "table"));
attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack);
auto stringLib = getMutable<TableTypeVar>(getGlobalBinding(typeChecker, "string"));
attachMagicFunction(stringLib->props["format"].type, magicFunctionFormat);
attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire);
}
@ -791,11 +470,11 @@ static std::optional<ExprResult<TypePackId>> magicFunctionRequire(
return std::nullopt;
}
AstExpr* require = expr.args.data[0];
if (!checkRequirePath(typechecker, require))
if (!checkRequirePath(typechecker, expr.args.data[0]))
return std::nullopt;
const AstExpr* require = FFlag::LuauNewRequireTrace2 ? &expr : expr.args.data[0];
if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require))
return ExprResult<TypePackId>{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})};

View file

@ -1,9 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAG(LuauParseGenericFunctions)
LUAU_FASTFLAG(LuauGenericFunctions)
namespace Luau
{
@ -19,6 +16,8 @@ declare bit32: {
bnot: (number) -> number,
extract: (number, number, number?) -> number,
replace: (number, number, number, number?) -> number,
countlz: (number) -> number,
countrz: (number) -> number,
}
declare math: {
@ -103,15 +102,6 @@ declare _VERSION: string
declare function gcinfo(): number
)BUILTIN_SRC";
std::string getBuiltinDefinitionSource()
{
std::string src = kBuiltinDefinitionLuaSrc;
if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions)
{
src += R"(
declare function print<T...>(...: T...)
declare function type<T>(value: T): string
@ -163,6 +153,7 @@ std::string getBuiltinDefinitionSource()
wrap: <A..., R...>((A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: (thread) -> (boolean, any?)
}
declare table: {
@ -206,33 +197,14 @@ std::string getBuiltinDefinitionSource()
graphemes: (string, number?, number?) -> (() -> (number, number)),
}
declare string: {
byte: (string, number?, number?) -> ...number,
char: (number, ...number) -> string,
find: (string, string, number?, boolean?) -> (number?, number?),
-- `string.format` has a magic function attached that will provide more type information for literal format strings.
format: <A...>(string, A...) -> string,
gmatch: (string, string) -> () -> (...string),
-- gsub is defined in C++ because we don't have syntax for describing a generic table.
len: (string) -> number,
lower: (string) -> string,
match: (string, string, number?) -> string?,
rep: (string, number) -> string,
reverse: (string) -> string,
sub: (string, number, number?) -> string,
upper: (string) -> string,
split: (string, string, string?) -> {string},
pack: <A...>(string, A...) -> string,
packsize: (string) -> number,
unpack: <R...>(string, string, number?) -> R...,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)";
}
return src;
)BUILTIN_SRC";
std::string getBuiltinDefinitionSource()
{
return kBuiltinDefinitionLuaSrc;
}
} // namespace Luau

View file

@ -7,9 +7,9 @@
#include <stdexcept>
LUAU_FASTFLAG(LuauFasterStringifier)
LUAU_FASTFLAG(LuauTypeAliasPacks)
static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false)
static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false)
{
std::string s = "expects " + std::to_string(expectedCount) + " ";
@ -41,6 +41,52 @@ static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCo
return s;
}
static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
{
std::string s;
if (FFlag::LuauTypeAliasPacks)
{
s = "expects ";
if (isVariadic)
s += "at least ";
s += std::to_string(expectedCount) + " ";
}
else
{
s = "expects " + std::to_string(expectedCount) + " ";
}
if (argPrefix)
s += std::string(argPrefix) + " ";
s += "argument";
if (expectedCount != 1)
s += "s";
s += ", but ";
if (actualCount == 0)
{
s += "none";
}
else
{
if (actualCount < expectedCount)
s += "only ";
s += std::to_string(actualCount);
}
s += (actualCount == 1) ? " is" : " are";
s += " specified";
return s;
}
namespace Luau
{
@ -48,8 +94,23 @@ struct ErrorConverter
{
std::string operator()(const Luau::TypeMismatch& tm) const
{
ToStringOptions opts;
return "Type '" + Luau::toString(tm.givenType, opts) + "' could not be converted into '" + Luau::toString(tm.wantedType, opts) + "'";
std::string result = "Type '" + Luau::toString(tm.givenType) + "' could not be converted into '" + Luau::toString(tm.wantedType) + "'";
if (tm.error)
{
result += "\ncaused by:\n ";
if (!tm.reason.empty())
result += tm.reason + ". ";
result += Luau::toString(*tm.error);
}
else if (!tm.reason.empty())
{
result += "; " + tm.reason;
}
return result;
}
std::string operator()(const Luau::UnknownSymbol& e) const
@ -119,15 +180,18 @@ struct ErrorConverter
switch (e.context)
{
case CountMismatch::Return:
return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " +
std::to_string(e.actual) + " " + actualVerb + " returned here";
return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " +
actualVerb + " returned here";
case CountMismatch::Result:
// It is alright if right hand side produces more values than the
// left hand side accepts. In this context consider only the opposite case.
return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " +
std::to_string(e.actual) + " are required here";
return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) +
" are required here";
case CountMismatch::Arg:
return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual);
if (FFlag::LuauTypeAliasPacks)
return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual);
else
return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual);
}
LUAU_ASSERT(!"Unknown context");
@ -159,13 +223,16 @@ struct ErrorConverter
std::string operator()(const Luau::UnknownRequire& e) const
{
return "Unknown require: " + e.modulePath;
if (e.modulePath.empty())
return "Unknown require: unsupported path";
else
return "Unknown require: " + e.modulePath;
}
std::string operator()(const Luau::IncorrectGenericParameterCount& e) const
{
std::string name = e.name;
if (!e.typeFun.typeParams.empty())
if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty()))
{
name += "<";
bool first = true;
@ -178,10 +245,37 @@ struct ErrorConverter
name += toString(t);
}
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId t : e.typeFun.typePackParams)
{
if (first)
first = false;
else
name += ", ";
name += toString(t);
}
}
name += ">";
}
return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true);
if (FFlag::LuauTypeAliasPacks)
{
if (e.typeFun.typeParams.size() != e.actualParameters)
return "Generic type '" + name + "' " +
wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty());
return "Generic type '" + name + "' " +
wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false);
}
else
{
return "Generic type '" + name + "' " +
wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true);
}
}
std::string operator()(const Luau::SyntaxError& e) const
@ -399,9 +493,36 @@ struct InvalidNameChecker
}
};
TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType)
: wantedType(wantedType)
, givenType(givenType)
{
}
TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason)
: wantedType(wantedType)
, givenType(givenType)
, reason(reason)
{
}
TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error)
: wantedType(wantedType)
, givenType(givenType)
, reason(reason)
, error(std::make_shared<TypeError>(std::move(error)))
{
}
bool TypeMismatch::operator==(const TypeMismatch& rhs) const
{
return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType;
if (!!error != !!rhs.error)
return false;
if (error && !(*error == *rhs.error))
return false;
return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType && reason == rhs.reason;
}
bool UnknownSymbol::operator==(const UnknownSymbol& rhs) const
@ -470,9 +591,26 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC
if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size())
return false;
if (FFlag::LuauTypeAliasPacks)
{
if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size())
return false;
}
for (size_t i = 0; i < typeFun.typeParams.size(); ++i)
{
if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i])
return false;
}
if (FFlag::LuauTypeAliasPacks)
{
for (size_t i = 0; i < typeFun.typePackParams.size(); ++i)
{
if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i])
return false;
}
}
return true;
}
@ -594,130 +732,141 @@ bool containsParseErrorName(const TypeError& error)
return Luau::visit(InvalidNameChecker{}, error.data);
}
void copyErrors(ErrorVec& errors, struct TypeArena& destArena)
template<typename T>
void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks)
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
auto clone = [&](auto&& ty) {
return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks);
};
auto visitErrorData = [&](auto&& e) {
using T = std::decay_t<decltype(e)>;
copyError(e, destArena, seenTypes, seenTypePacks);
};
if constexpr (false)
{
}
else if constexpr (std::is_same_v<T, TypeMismatch>)
{
e.wantedType = clone(e.wantedType);
e.givenType = clone(e.givenType);
}
else if constexpr (std::is_same_v<T, UnknownSymbol>)
{
}
else if constexpr (std::is_same_v<T, UnknownProperty>)
{
e.table = clone(e.table);
}
else if constexpr (std::is_same_v<T, NotATable>)
{
e.ty = clone(e.ty);
}
else if constexpr (std::is_same_v<T, CannotExtendTable>)
{
e.tableType = clone(e.tableType);
}
else if constexpr (std::is_same_v<T, OnlyTablesCanHaveMethods>)
{
e.tableType = clone(e.tableType);
}
else if constexpr (std::is_same_v<T, DuplicateTypeDefinition>)
{
}
else if constexpr (std::is_same_v<T, CountMismatch>)
{
}
else if constexpr (std::is_same_v<T, FunctionDoesNotTakeSelf>)
{
}
else if constexpr (std::is_same_v<T, FunctionRequiresSelf>)
{
}
else if constexpr (std::is_same_v<T, OccursCheckFailed>)
{
}
else if constexpr (std::is_same_v<T, UnknownRequire>)
{
}
else if constexpr (std::is_same_v<T, IncorrectGenericParameterCount>)
{
e.typeFun = clone(e.typeFun);
}
else if constexpr (std::is_same_v<T, SyntaxError>)
{
}
else if constexpr (std::is_same_v<T, CodeTooComplex>)
{
}
else if constexpr (std::is_same_v<T, UnificationTooComplex>)
{
}
else if constexpr (std::is_same_v<T, UnknownPropButFoundLikeProp>)
{
e.table = clone(e.table);
}
else if constexpr (std::is_same_v<T, GenericError>)
{
}
else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
{
e.ty = clone(e.ty);
}
else if constexpr (std::is_same_v<T, ExtraInformation>)
{
}
else if constexpr (std::is_same_v<T, DeprecatedApiUsed>)
{
}
else if constexpr (std::is_same_v<T, ModuleHasCyclicDependency>)
{
}
else if constexpr (std::is_same_v<T, IllegalRequire>)
{
}
else if constexpr (std::is_same_v<T, FunctionExitsWithoutReturning>)
{
e.expectedReturnType = clone(e.expectedReturnType);
}
else if constexpr (std::is_same_v<T, DuplicateGenericParameter>)
{
}
else if constexpr (std::is_same_v<T, CannotInferBinaryOperation>)
{
}
else if constexpr (std::is_same_v<T, MissingProperties>)
{
e.superType = clone(e.superType);
e.subType = clone(e.subType);
}
else if constexpr (std::is_same_v<T, SwappedGenericTypeParameter>)
{
}
else if constexpr (std::is_same_v<T, OptionalValueAccess>)
{
e.optional = clone(e.optional);
}
else if constexpr (std::is_same_v<T, MissingUnionProperty>)
{
e.type = clone(e.type);
if constexpr (false)
{
}
else if constexpr (std::is_same_v<T, TypeMismatch>)
{
e.wantedType = clone(e.wantedType);
e.givenType = clone(e.givenType);
for (auto& ty : e.missing)
ty = clone(ty);
}
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
if (e.error)
visit(visitErrorData, e.error->data);
}
else if constexpr (std::is_same_v<T, UnknownSymbol>)
{
}
else if constexpr (std::is_same_v<T, UnknownProperty>)
{
e.table = clone(e.table);
}
else if constexpr (std::is_same_v<T, NotATable>)
{
e.ty = clone(e.ty);
}
else if constexpr (std::is_same_v<T, CannotExtendTable>)
{
e.tableType = clone(e.tableType);
}
else if constexpr (std::is_same_v<T, OnlyTablesCanHaveMethods>)
{
e.tableType = clone(e.tableType);
}
else if constexpr (std::is_same_v<T, DuplicateTypeDefinition>)
{
}
else if constexpr (std::is_same_v<T, CountMismatch>)
{
}
else if constexpr (std::is_same_v<T, FunctionDoesNotTakeSelf>)
{
}
else if constexpr (std::is_same_v<T, FunctionRequiresSelf>)
{
}
else if constexpr (std::is_same_v<T, OccursCheckFailed>)
{
}
else if constexpr (std::is_same_v<T, UnknownRequire>)
{
}
else if constexpr (std::is_same_v<T, IncorrectGenericParameterCount>)
{
e.typeFun = clone(e.typeFun);
}
else if constexpr (std::is_same_v<T, SyntaxError>)
{
}
else if constexpr (std::is_same_v<T, CodeTooComplex>)
{
}
else if constexpr (std::is_same_v<T, UnificationTooComplex>)
{
}
else if constexpr (std::is_same_v<T, UnknownPropButFoundLikeProp>)
{
e.table = clone(e.table);
}
else if constexpr (std::is_same_v<T, GenericError>)
{
}
else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
{
e.ty = clone(e.ty);
}
else if constexpr (std::is_same_v<T, ExtraInformation>)
{
}
else if constexpr (std::is_same_v<T, DeprecatedApiUsed>)
{
}
else if constexpr (std::is_same_v<T, ModuleHasCyclicDependency>)
{
}
else if constexpr (std::is_same_v<T, IllegalRequire>)
{
}
else if constexpr (std::is_same_v<T, FunctionExitsWithoutReturning>)
{
e.expectedReturnType = clone(e.expectedReturnType);
}
else if constexpr (std::is_same_v<T, DuplicateGenericParameter>)
{
}
else if constexpr (std::is_same_v<T, CannotInferBinaryOperation>)
{
}
else if constexpr (std::is_same_v<T, MissingProperties>)
{
e.superType = clone(e.superType);
e.subType = clone(e.subType);
}
else if constexpr (std::is_same_v<T, SwappedGenericTypeParameter>)
{
}
else if constexpr (std::is_same_v<T, OptionalValueAccess>)
{
e.optional = clone(e.optional);
}
else if constexpr (std::is_same_v<T, MissingUnionProperty>)
{
e.type = clone(e.type);
for (auto& ty : e.missing)
ty = clone(ty);
}
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
void copyErrors(ErrorVec& errors, TypeArena& destArena)
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
auto visitErrorData = [&](auto&& e) {
copyError(e, destArena, seenTypes, seenTypePacks);
};
LUAU_ASSERT(!destArena.typeVars.isFrozen());

View file

@ -1,9 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Frontend.h"
#include "Luau/Common.h"
#include "Luau/Config.h"
#include "Luau/FileResolver.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/TypeInfer.h"
#include "Luau/Variant.h"
#include "Luau/Common.h"
@ -15,10 +18,10 @@
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false)
LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false)
LUAU_FASTFLAG(LuauTraceRequireLookupChild)
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
LUAU_FASTFLAG(LuauNewRequireTrace2)
namespace Luau
{
@ -69,6 +72,8 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName)
LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName)
{
LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend");
Luau::Allocator allocator;
Luau::AstNameTable names(allocator);
@ -242,7 +247,7 @@ struct RequireCycle
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true)
std::vector<RequireCycle> getRequireCycles(
const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
const FileResolver* resolver, const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
{
std::vector<RequireCycle> result;
@ -276,9 +281,9 @@ std::vector<RequireCycle> getRequireCycles(
if (top == start)
{
for (const SourceNode* node : path)
cycle.push_back(node->name);
cycle.push_back(resolver->getHumanReadableModuleName(node->name));
cycle.push_back(top->name);
cycle.push_back(resolver->getHumanReadableModuleName(top->name));
break;
}
}
@ -350,6 +355,9 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend)
CheckResult Frontend::check(const ModuleName& name)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
CheckResult checkResult;
auto it = sourceNodes.find(name);
@ -395,7 +403,7 @@ CheckResult Frontend::check(const ModuleName& name)
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely
if (cycleDetected)
requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck);
requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck);
// This is used by the type checker to replace the resulting type of cyclic modules with any
sourceModule.cyclic = !requireCycles.empty();
@ -405,7 +413,7 @@ CheckResult Frontend::check(const ModuleName& name)
// If we're typechecking twice, we do so.
// The second typecheck is always in strict mode with DM awareness
// to provide better typen information for IDE features.
if (options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel)
if (options.typecheckTwice)
{
ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict);
moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete;
@ -449,6 +457,7 @@ CheckResult Frontend::check(const ModuleName& name)
module->astTypes.clear();
module->astExpectedTypes.clear();
module->astOriginalCallTypes.clear();
module->scopes.resize(1);
}
if (mode != Mode::NoCheck)
@ -479,6 +488,9 @@ CheckResult Frontend::check(const ModuleName& name)
bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root)
{
LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend");
LUAU_TIMETRACE_ARGUMENT("root", root.c_str());
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
enum Mark
{
@ -597,6 +609,9 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config
LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
CheckResult checkResult;
auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name);
@ -608,6 +623,8 @@ LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOption
std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view source, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend");
const Config& config = configResolver->getConfig("");
SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions);
@ -627,6 +644,9 @@ std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view sour
CheckResult Frontend::check(const SourceModule& module)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
const Config& config = configResolver->getConfig(module.name);
Mode mode = module.mode.value_or(config.mode);
@ -648,6 +668,9 @@ CheckResult Frontend::check(const SourceModule& module)
LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
const Config& config = configResolver->getConfig(module.name);
LintOptions options = enabledLintWarnings.value_or(config.enabledLint);
@ -746,6 +769,9 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons
// Read AST into sourceModules if necessary. Trace require()s. Report parse errors.
std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name)
{
LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
auto it = sourceNodes.find(name);
if (it != sourceNodes.end() && !it->second.dirty)
{
@ -815,6 +841,9 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& check
*/
SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions)
{
LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
SourceModule sourceModule;
double timestamp = getTimestamp();
@ -864,20 +893,11 @@ std::optional<ModuleInfo> FrontendModuleResolver::resolveModuleInfo(const Module
const auto& exprs = it->second.exprs;
const ModuleName* relativeName = exprs.find(&pathExpr);
if (!relativeName || relativeName->empty())
const ModuleInfo* info = exprs.find(&pathExpr);
if (!info || (!FFlag::LuauNewRequireTrace2 && info->name.empty()))
return std::nullopt;
if (FFlag::LuauTraceRequireLookupChild)
{
const bool* optional = it->second.optional.find(&pathExpr);
return {{*relativeName, optional ? *optional : false}};
}
else
{
return {{*relativeName, false}};
}
return *info;
}
const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const
@ -891,12 +911,15 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName)
bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const
{
return frontend->fileResolver->moduleExists(moduleName);
if (FFlag::LuauNewRequireTrace2)
return frontend->sourceNodes.count(moduleName) != 0;
else
return frontend->fileResolver->moduleExists(moduleName);
}
std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const
{
return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName);
return frontend->fileResolver->getHumanReadableModuleName(moduleName);
}
ScopePtr Frontend::addEnvironment(const std::string& environmentName)

View file

@ -2,6 +2,8 @@
#include "Luau/IostreamHelpers.h"
#include "Luau/ToString.h"
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -92,7 +94,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo
{
stream << "IncorrectGenericParameterCount { name = " << error.name;
if (!error.typeFun.typeParams.empty())
if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty()))
{
stream << "<";
bool first = true;
@ -105,6 +107,20 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo
stream << toString(t);
}
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId t : error.typeFun.typePackParams)
{
if (first)
first = false;
else
stream << ", ";
stream << toString(t);
}
}
stream << ">";
}

View file

@ -3,6 +3,9 @@
#include "Luau/Ast.h"
#include "Luau/StringUtils.h"
#include "Luau/Common.h"
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -612,6 +615,12 @@ struct AstJsonEncoder : public AstVisitor
writeNode(node, "AstStatTypeAlias", [&]() {
PROP(name);
PROP(generics);
if (FFlag::LuauTypeAliasPacks)
{
PROP(genericPacks);
}
PROP(type);
PROP(exported);
});
@ -664,13 +673,21 @@ struct AstJsonEncoder : public AstVisitor
});
}
void write(struct AstTypeOrPack node)
{
if (node.type)
write(node.type);
else
write(node.typePack);
}
void write(class AstTypeReference* node)
{
writeNode(node, "AstTypeReference", [&]() {
if (node->hasPrefix)
PROP(prefix);
PROP(name);
PROP(generics);
PROP(parameters);
});
}
@ -734,6 +751,13 @@ struct AstJsonEncoder : public AstVisitor
});
}
void write(class AstTypePackExplicit* node)
{
writeNode(node, "AstTypePackExplicit", [&]() {
PROP(typeList);
});
}
void write(class AstTypePackVariadic* node)
{
writeNode(node, "AstTypePackVariadic", [&]() {
@ -1018,6 +1042,12 @@ struct AstJsonEncoder : public AstVisitor
return false;
}
bool visit(class AstTypePackExplicit* node) override
{
write(node);
return false;
}
bool visit(class AstTypePackVariadic* node) override
{
write(node);

View file

@ -3,6 +3,7 @@
#include "Luau/AstQuery.h"
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/StringUtils.h"
#include "Luau/Common.h"
@ -11,8 +12,6 @@
#include <math.h>
#include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false)
namespace Luau
{
@ -85,10 +84,10 @@ struct LintContext
return std::nullopt;
auto it = module->astTypes.find(expr);
if (it == module->astTypes.end())
if (!it)
return std::nullopt;
return it->second;
return *it;
}
};
@ -1108,10 +1107,7 @@ private:
if (g && g->name == "type")
{
if (FFlag::LuauLinterUnknownTypeVectorAware)
validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type");
else
validateType(arg, {Kind_Primitive}, "primitive type");
validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type");
}
else if (g && g->name == "typeof")
{
@ -2144,6 +2140,19 @@ private:
"wrap it in parentheses to silence");
}
if (func->index == "move" && node->args.size >= 4)
{
// table.move(t, 0, _, _)
if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
// table.move(t, _, _, 0)
else if (isConstant(args[3], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
}
return true;
}

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -11,8 +12,9 @@
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false)
LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false)
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false)
namespace Luau
{
@ -159,6 +161,7 @@ struct TypeCloner
void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t);
void operator()(const PrimitiveTypeVar& t);
void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t);
void operator()(const MetatableTypeVar& t);
@ -188,7 +191,7 @@ struct TypePackCloner
template<typename T>
void defaultClone(const T& t)
{
TypePackId cloned = dest.typePacks.allocate(t);
TypePackId cloned = dest.addTypePack(TypePackVar{t});
seenTypePacks[typePackId] = cloned;
}
@ -197,7 +200,9 @@ struct TypePackCloner
if (encounteredFreeType)
*encounteredFreeType = true;
seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}});
TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack);
TypePackId cloned = dest.addTypePack(*err);
seenTypePacks[typePackId] = cloned;
}
void operator()(const Unifiable::Generic& t)
@ -219,13 +224,13 @@ struct TypePackCloner
void operator()(const VariadicTypePack& t)
{
TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)});
TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}});
seenTypePacks[typePackId] = cloned;
}
void operator()(const TypePack& t)
{
TypePackId cloned = dest.typePacks.allocate(TypePack{});
TypePackId cloned = dest.addTypePack(TypePack{});
TypePack* destTp = getMutable<TypePack>(cloned);
LUAU_ASSERT(destTp != nullptr);
seenTypePacks[typePackId] = cloned;
@ -241,7 +246,7 @@ struct TypePackCloner
template<typename T>
void TypeCloner::defaultClone(const T& t)
{
TypeId cloned = dest.typeVars.allocate(t);
TypeId cloned = dest.addType(t);
seenTypes[typeId] = cloned;
}
@ -249,8 +254,9 @@ void TypeCloner::operator()(const Unifiable::Free& t)
{
if (encounteredFreeType)
*encounteredFreeType = true;
seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{});
TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType);
TypeId cloned = dest.addType(*err);
seenTypes[typeId] = cloned;
}
void TypeCloner::operator()(const Unifiable::Generic& t)
@ -268,14 +274,20 @@ void TypeCloner::operator()(const Unifiable::Error& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const PrimitiveTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const SingletonTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const FunctionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result);
LUAU_ASSERT(ftv != nullptr);
@ -287,9 +299,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t)
for (TypePackId genericPack : t.genericPacks)
ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType));
if (FFlag::LuauSecondTypecheckKnowsTheDataModel)
ftv->tags = t.tags;
ftv->tags = t.tags;
ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType);
ftv->argNames = t.argNames;
ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType);
@ -297,7 +307,15 @@ void TypeCloner::operator()(const FunctionTypeVar& t)
void TypeCloner::operator()(const TableTypeVar& t)
{
TypeId result = dest.typeVars.allocate(TableTypeVar{});
// If table is now bound to another one, we ignore the content of the original
if (FFlag::LuauCloneBoundTables && t.boundTo)
{
TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
seenTypes[typeId] = boundTo;
return;
}
TypeId result = dest.addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
LUAU_ASSERT(ttv != nullptr);
@ -308,26 +326,30 @@ void TypeCloner::operator()(const TableTypeVar& t)
ttv->level = TypeLevel{0, 0};
for (const auto& [name, prop] : t.props)
{
if (FFlag::LuauSecondTypecheckKnowsTheDataModel)
ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags};
else
ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location};
}
ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags};
if (t.indexer)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType),
clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)};
if (t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
if (!FFlag::LuauCloneBoundTables)
{
if (t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType));
arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId& arg : ttv->instantiatedTypePackParams)
arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
if (ttv->state == TableState::Free)
{
if (!t.boundTo)
if (FFlag::LuauCloneBoundTables || !t.boundTo)
{
if (encounteredFreeType)
*encounteredFreeType = true;
@ -343,7 +365,7 @@ void TypeCloner::operator()(const TableTypeVar& t)
void TypeCloner::operator()(const MetatableTypeVar& t)
{
TypeId result = dest.typeVars.allocate(MetatableTypeVar{});
TypeId result = dest.addType(MetatableTypeVar{});
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(result);
seenTypes[typeId] = result;
@ -353,16 +375,13 @@ void TypeCloner::operator()(const MetatableTypeVar& t)
void TypeCloner::operator()(const ClassTypeVar& t)
{
TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData});
TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData});
ClassTypeVar* ctv = getMutable<ClassTypeVar>(result);
seenTypes[typeId] = result;
for (const auto& [name, prop] : t.props)
if (FFlag::LuauSecondTypecheckKnowsTheDataModel)
ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags};
else
ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location};
ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags};
if (t.parent)
ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType);
@ -378,7 +397,7 @@ void TypeCloner::operator()(const AnyTypeVar& t)
void TypeCloner::operator()(const UnionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(UnionTypeVar{});
TypeId result = dest.addType(UnionTypeVar{});
seenTypes[typeId] = result;
UnionTypeVar* option = getMutable<UnionTypeVar>(result);
@ -390,7 +409,7 @@ void TypeCloner::operator()(const UnionTypeVar& t)
void TypeCloner::operator()(const IntersectionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(IntersectionTypeVar{});
TypeId result = dest.addType(IntersectionTypeVar{});
seenTypes[typeId] = result;
IntersectionTypeVar* option = getMutable<IntersectionTypeVar>(result);
@ -451,8 +470,14 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType)
{
TypeFun result;
for (TypeId param : typeFun.typeParams)
result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType));
for (TypeId ty : typeFun.typeParams)
result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType));
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId tp : typeFun.typePackParams)
result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType));
}
result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType);

View file

@ -3,8 +3,6 @@
#include "Luau/Ast.h"
LUAU_FASTFLAG(LuauOrPredicate)
namespace Luau
{
@ -60,8 +58,6 @@ std::string toString(const LValue& lvalue)
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f)
{
LUAU_ASSERT(FFlag::LuauOrPredicate);
auto itL = l.begin();
auto itR = r.begin();
while (itL != l.end() && itR != r.end())

90
Analysis/src/Quantify.cpp Normal file
View file

@ -0,0 +1,90 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Quantify.h"
#include "Luau/VisitTypeVar.h"
namespace Luau
{
struct Quantifier
{
ModulePtr module;
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
Quantifier(ModulePtr module, TypeLevel level)
: module(module)
, level(level)
{
}
void cycle(TypeId) {}
void cycle(TypePackId) {}
bool operator()(TypeId ty, const FreeTypeVar& ftv)
{
if (!level.subsumes(ftv.level))
return false;
*asMutable(ty) = GenericTypeVar{level};
generics.push_back(ty);
return false;
}
template<typename T>
bool operator()(TypeId ty, const T& t)
{
return true;
}
template<typename T>
bool operator()(TypePackId, const T&)
{
return true;
}
bool operator()(TypeId ty, const TableTypeVar&)
{
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic)
return false;
if (!level.subsumes(ttv.level))
return false;
if (ttv.state == TableState::Free)
ttv.state = TableState::Generic;
else if (ttv.state == TableState::Unsealed)
ttv.state = TableState::Sealed;
ttv.level = level;
return true;
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
if (!level.subsumes(ftp.level))
return false;
*asMutable(tp) = GenericTypePack{level};
genericPacks.push_back(tp);
return true;
}
};
void quantify(ModulePtr module, TypeId ty, TypeLevel level)
{
Quantifier q{std::move(module), level};
visitTypeVar(ty, q);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
}
} // namespace Luau

View file

@ -5,6 +5,7 @@
#include "Luau/Module.h"
LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false)
LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace2, false)
namespace Luau
{
@ -12,17 +13,18 @@ namespace Luau
namespace
{
struct RequireTracer : AstVisitor
struct RequireTracerOld : AstVisitor
{
explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName)
explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName)
: fileResolver(fileResolver)
, currentModuleName(std::move(currentModuleName))
, currentModuleName(currentModuleName)
{
LUAU_ASSERT(!FFlag::LuauNewRequireTrace2);
}
FileResolver* const fileResolver;
ModuleName currentModuleName;
DenseHashMap<AstLocal*, ModuleName> locals{0};
DenseHashMap<AstLocal*, ModuleName> locals{nullptr};
RequireTraceResult result;
std::optional<ModuleName> fromAstFragment(AstExpr* expr)
@ -50,9 +52,9 @@ struct RequireTracer : AstVisitor
AstExpr* expr = stat->values.data[i];
expr->visit(this);
const ModuleName* name = result.exprs.find(expr);
if (name)
locals[local] = *name;
const ModuleInfo* info = result.exprs.find(expr);
if (info)
locals[local] = info->name;
}
}
@ -63,7 +65,7 @@ struct RequireTracer : AstVisitor
{
std::optional<ModuleName> name = fromAstFragment(global);
if (name)
result.exprs[global] = *name;
result.exprs[global] = {*name};
return false;
}
@ -72,7 +74,7 @@ struct RequireTracer : AstVisitor
{
const ModuleName* name = locals.find(local->local);
if (name)
result.exprs[local] = *name;
result.exprs[local] = {*name};
return false;
}
@ -81,16 +83,16 @@ struct RequireTracer : AstVisitor
{
indexName->expr->visit(this);
const ModuleName* name = result.exprs.find(indexName->expr);
if (name)
const ModuleInfo* info = result.exprs.find(indexName->expr);
if (info)
{
if (indexName->index == "parent" || indexName->index == "Parent")
{
if (auto parent = fileResolver->getParentModuleName(*name))
result.exprs[indexName] = *parent;
if (auto parent = fileResolver->getParentModuleName(info->name))
result.exprs[indexName] = {*parent};
}
else
result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value);
result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)};
}
return false;
@ -100,11 +102,11 @@ struct RequireTracer : AstVisitor
{
indexExpr->expr->visit(this);
const ModuleName* name = result.exprs.find(indexExpr->expr);
const ModuleInfo* info = result.exprs.find(indexExpr->expr);
const AstExprConstantString* str = indexExpr->index->as<AstExprConstantString>();
if (name && str)
if (info && str)
{
result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size));
result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))};
}
indexExpr->index->visit(this);
@ -129,8 +131,8 @@ struct RequireTracer : AstVisitor
AstExprGlobal* globalName = call->func->as<AstExprGlobal>();
if (globalName && globalName->name == "require" && call->args.size >= 1)
{
if (const ModuleName* moduleName = result.exprs.find(call->args.data[0]))
result.requires.push_back({*moduleName, call->location});
if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0]))
result.requires.push_back({moduleInfo->name, call->location});
return false;
}
@ -143,8 +145,8 @@ struct RequireTracer : AstVisitor
if (FFlag::LuauTraceRequireLookupChild && !rootName)
{
if (const ModuleName* moduleName = result.exprs.find(indexName->expr))
rootName = *moduleName;
if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr))
rootName = moduleInfo->name;
}
if (!rootName)
@ -167,24 +169,183 @@ struct RequireTracer : AstVisitor
if (v.end() != std::find(v.begin(), v.end(), '/'))
return false;
result.exprs[call] = fileResolver->concat(*rootName, v);
result.exprs[call] = {fileResolver->concat(*rootName, v)};
// 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime
// If we fail to find such module, we will not report an UnknownRequire error
if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild")
result.optional[call] = true;
result.exprs[call].optional = true;
return false;
}
};
struct RequireTracer : AstVisitor
{
RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName)
: result(result)
, fileResolver(fileResolver)
, currentModuleName(currentModuleName)
, locals(nullptr)
{
LUAU_ASSERT(FFlag::LuauNewRequireTrace2);
}
bool visit(AstExprTypeAssertion* expr) override
{
// suppress `require() :: any`
return false;
}
bool visit(AstExprCall* expr) override
{
AstExprGlobal* global = expr->func->as<AstExprGlobal>();
if (global && global->name == "require" && expr->args.size >= 1)
requires.push_back(expr);
return true;
}
bool visit(AstStatLocal* stat) override
{
for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i)
{
AstLocal* local = stat->vars.data[i];
AstExpr* expr = stat->values.data[i];
// track initializing expression to be able to trace modules through locals
locals[local] = expr;
}
return true;
}
bool visit(AstStatAssign* stat) override
{
for (size_t i = 0; i < stat->vars.size; ++i)
{
// locals that are assigned don't have a known expression
if (AstExprLocal* expr = stat->vars.data[i]->as<AstExprLocal>())
locals[expr->local] = nullptr;
}
return true;
}
bool visit(AstType* node) override
{
// allow resolving require inside `typeof` annotations
return true;
}
AstExpr* getDependent(AstExpr* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local];
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return expr->expr;
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return expr->expr;
else if (AstExprCall* expr = node->as<AstExprCall>(); expr && expr->self)
return expr->func->as<AstExprIndexName>()->expr;
else
return nullptr;
}
void process()
{
ModuleInfo moduleContext{currentModuleName};
// seed worklist with require arguments
work.reserve(requires.size());
for (AstExprCall* require : requires)
work.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work.size(); ++i)
if (AstExpr* dep = getDependent(work[i]))
work.push_back(dep);
// resolve all expressions to a module info
for (size_t i = work.size(); i > 0; --i)
{
AstExpr* expr = work[i - 1];
// when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr))
continue;
std::optional<ModuleInfo> info;
if (AstExpr* dep = getDependent(expr))
{
const ModuleInfo* context = result.exprs.find(dep);
// locals just inherit their dependent context, no resolution required
if (expr->is<AstExprLocal>())
info = context ? std::optional<ModuleInfo>(*context) : std::nullopt;
else
info = fileResolver->resolveModule(context, expr);
}
else
{
info = fileResolver->resolveModule(&moduleContext, expr);
}
if (info)
result.exprs[expr] = std::move(*info);
}
// resolve all requires according to their argument
result.requires.reserve(requires.size());
for (AstExprCall* require : requires)
{
AstExpr* arg = require->args.data[0];
if (const ModuleInfo* info = result.exprs.find(arg))
{
result.requires.push_back({info->name, require->location});
ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info!
result.exprs[require] = std::move(infoCopy);
}
else
{
result.exprs[require] = {}; // mark require as unresolved
}
}
}
RequireTraceResult& result;
FileResolver* fileResolver;
ModuleName currentModuleName;
DenseHashMap<AstLocal*, AstExpr*> locals;
std::vector<AstExpr*> work;
std::vector<AstExprCall*> requires;
};
} // anonymous namespace
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName)
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName)
{
RequireTracer tracer{fileResolver, std::move(currentModuleName)};
root->visit(&tracer);
return tracer.result;
if (FFlag::LuauNewRequireTrace2)
{
RequireTraceResult result;
RequireTracer tracer{result, fileResolver, currentModuleName};
root->visit(&tracer);
tracer.process();
return result;
}
else
{
RequireTracerOld tracer{fileResolver, currentModuleName};
root->visit(&tracer);
return tracer.result;
}
}
} // namespace Luau

123
Analysis/src/Scope.cpp Normal file
View file

@ -0,0 +1,123 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Scope.h"
namespace Luau
{
Scope::Scope(TypePackId returnType)
: parent(nullptr)
, returnType(returnType)
, level(TypeLevel())
{
}
Scope::Scope(const ScopePtr& parent, int subLevel)
: parent(parent)
, returnType(parent->returnType)
, level(parent->level.incr())
{
level.subLevel = subLevel;
}
std::optional<TypeId> Scope::lookup(const Symbol& name)
{
Scope* scope = this;
while (scope)
{
auto it = scope->bindings.find(name);
if (it != scope->bindings.end())
return it->second.typeId;
scope = scope->parent.get();
}
return std::nullopt;
}
std::optional<TypeFun> Scope::lookupType(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->exportedTypeBindings.find(name);
if (it != scope->exportedTypeBindings.end())
return it->second;
it = scope->privateTypeBindings.find(name);
if (it != scope->privateTypeBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<TypeFun> Scope::lookupImportedType(const Name& moduleAlias, const Name& name)
{
const Scope* scope = this;
while (scope)
{
auto it = scope->importedTypeBindings.find(moduleAlias);
if (it == scope->importedTypeBindings.end())
{
scope = scope->parent.get();
continue;
}
auto it2 = it->second.find(name);
if (it2 == it->second.end())
{
scope = scope->parent.get();
continue;
}
return it2->second;
}
return std::nullopt;
}
std::optional<TypePackId> Scope::lookupPack(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->privateTypePackBindings.find(name);
if (it != scope->privateTypePackBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain)
{
Scope* scope = this;
while (scope)
{
for (const auto& [n, binding] : scope->bindings)
{
if (n.local && n.local->name == name.c_str())
return binding;
else if (n.global.value && n.global == name.c_str())
return binding;
}
scope = scope->parent.get();
if (!traverseScopeChain)
break;
}
return std::nullopt;
}
} // namespace Luau

View file

@ -6,9 +6,9 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0)
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000)
LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -17,7 +17,7 @@ void Tarjan::visitChildren(TypeId ty, int index)
{
ty = follow(ty);
if (FFlag::LuauRankNTypes && ignoreChildren(ty))
if (ignoreChildren(ty))
return;
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
@ -35,8 +35,15 @@ void Tarjan::visitChildren(TypeId ty, int index)
visitChild(ttv->indexer->indexType);
visitChild(ttv->indexer->indexResultType);
}
for (TypeId itp : ttv->instantiatedTypeParams)
visitChild(itp);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp);
}
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
@ -59,7 +66,7 @@ void Tarjan::visitChildren(TypePackId tp, int index)
{
tp = follow(tp);
if (FFlag::LuauRankNTypes && ignoreChildren(tp))
if (ignoreChildren(tp))
return;
if (const TypePack* tpp = get<TypePack>(tp))
@ -332,9 +339,11 @@ std::optional<TypeId> Substitution::substitute(TypeId ty)
return std::nullopt;
for (auto [oldTy, newTy] : newTypes)
replaceChildren(newTy);
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy))
replaceChildren(newTy);
for (auto [oldTp, newTp] : newPacks)
replaceChildren(newTp);
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp))
replaceChildren(newTp);
TypeId newTy = replace(ty);
return newTy;
}
@ -350,9 +359,11 @@ std::optional<TypePackId> Substitution::substitute(TypePackId tp)
return std::nullopt;
for (auto [oldTy, newTy] : newTypes)
replaceChildren(newTy);
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy))
replaceChildren(newTy);
for (auto [oldTp, newTp] : newPacks)
replaceChildren(newTp);
if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp))
replaceChildren(newTp);
TypePackId newTp = replace(tp);
return newTp;
}
@ -382,8 +393,11 @@ TypeId Substitution::clone(TypeId ty)
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.instantiatedTypeParams = ttv->instantiatedTypeParams;
if (FFlag::LuauSecondTypecheckKnowsTheDataModel)
clone.tags = ttv->tags;
if (FFlag::LuauTypeAliasPacks)
clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams;
clone.tags = ttv->tags;
result = addType(std::move(clone));
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
@ -469,7 +483,7 @@ void Substitution::replaceChildren(TypeId ty)
{
ty = follow(ty);
if (FFlag::LuauRankNTypes && ignoreChildren(ty))
if (ignoreChildren(ty))
return;
if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty))
@ -487,8 +501,15 @@ void Substitution::replaceChildren(TypeId ty)
ttv->indexer->indexType = replace(ttv->indexer->indexType);
ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType);
}
for (TypeId& itp : ttv->instantiatedTypeParams)
itp = replace(itp);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId& itp : ttv->instantiatedTypePackParams)
itp = replace(itp);
}
}
else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(ty))
{
@ -511,7 +532,7 @@ void Substitution::replaceChildren(TypePackId tp)
{
tp = follow(tp);
if (FFlag::LuauRankNTypes && ignoreChildren(tp))
if (ignoreChildren(tp))
return;
if (TypePack* tpp = getMutable<TypePack>(tp))

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ToString.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -9,10 +10,8 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauToStringFollowsBoundTo)
LUAU_FASTFLAG(LuauExtraNilRecovery)
LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions)
LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace Luau
{
@ -59,6 +58,13 @@ struct FindCyclicTypes
{
for (TypeId itp : ttv.instantiatedTypeParams)
visitTypeVar(itp, *this, seen);
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId itp : ttv.instantiatedTypePackParams)
visitTypeVar(itp, *this, seen);
}
return exhaustive;
}
@ -151,15 +157,6 @@ struct StringifierState
seen.erase(iter);
}
static std::string generateName(size_t i)
{
std::string n;
n = char('a' + i % 26);
if (i >= 26)
n += std::to_string(i / 26);
return n;
}
std::string getName(TypeId ty)
{
const size_t s = result.nameMap.typeVars.size();
@ -239,15 +236,6 @@ struct TypeVarStringifier
return;
}
if (!FFlag::LuauAddMissingFollow)
{
if (get<FreeTypeVar>(tv))
{
state.emit(state.getName(tv));
return;
}
}
Luau::visit(
[this, tv](auto&& t) {
return (*this)(tv, t);
@ -258,34 +246,67 @@ struct TypeVarStringifier
void stringify(TypePackId tp);
void stringify(TypePackId tpid, const std::vector<std::optional<FunctionArgument>>& names);
void stringify(const std::vector<TypeId>& types)
void stringify(const std::vector<TypeId>& types, const std::vector<TypePackId>& typePacks)
{
if (types.size() == 0)
if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0))
return;
if (types.size())
if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size()))
state.emit("<");
for (size_t i = 0; i < types.size(); ++i)
if (FFlag::LuauTypeAliasPacks)
{
if (i > 0)
state.emit(", ");
bool first = true;
stringify(types[i]);
for (TypeId ty : types)
{
if (!first)
state.emit(", ");
first = false;
stringify(ty);
}
bool singleTp = typePacks.size() == 1;
for (TypePackId tp : typePacks)
{
if (isEmpty(tp) && singleTp)
continue;
if (!first)
state.emit(", ");
else
first = false;
if (!singleTp)
state.emit("(");
stringify(tp);
if (!singleTp)
state.emit(")");
}
}
else
{
for (size_t i = 0; i < types.size(); ++i)
{
if (i > 0)
state.emit(", ");
stringify(types[i]);
}
}
if (types.size())
if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size()))
state.emit(">");
}
void operator()(TypeId ty, const Unifiable::Free& ftv)
{
state.result.invalid = true;
if (FFlag::LuauAddMissingFollow)
state.emit(state.getName(ty));
else
state.emit("<FREE>");
state.emit(state.getName(ty));
}
void operator()(TypeId, const BoundTypeVar& btv)
@ -329,6 +350,23 @@ struct TypeVarStringifier
}
}
void operator()(TypeId, const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = Luau::get<BoolSingleton>(&stv))
state.emit(bs->value ? "true" : "false");
else if (const StringSingleton* ss = Luau::get<StringSingleton>(&stv))
{
state.emit("\"");
state.emit(escape(ss->value));
state.emit("\"");
}
else
{
LUAU_ASSERT(!"Unknown singleton type");
throw std::runtime_error("Unknown singleton type");
}
}
void operator()(TypeId, const FunctionTypeVar& ftv)
{
if (state.hasSeen(&ftv))
@ -338,6 +376,7 @@ struct TypeVarStringifier
return;
}
// We should not be respecting opts.hideNamedFunctionTypeParameters here.
if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0)
{
state.emit("<");
@ -388,7 +427,7 @@ struct TypeVarStringifier
void operator()(TypeId, const TableTypeVar& ttv)
{
if (FFlag::LuauToStringFollowsBoundTo && ttv.boundTo)
if (ttv.boundTo)
return stringify(*ttv.boundTo);
if (!state.exhaustive)
@ -411,14 +450,14 @@ struct TypeVarStringifier
}
state.emit(*ttv.name);
stringify(ttv.instantiatedTypeParams);
stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams);
return;
}
if (ttv.syntheticName)
{
state.result.invalid = true;
state.emit(*ttv.syntheticName);
stringify(ttv.instantiatedTypeParams);
stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams);
return;
}
}
@ -493,7 +532,14 @@ struct TypeVarStringifier
break;
}
state.emit(name);
if (isIdentifier(name))
state.emit(name);
else
{
state.emit("[\"");
state.emit(escape(name));
state.emit("\"]");
}
state.emit(": ");
stringify(prop.type);
comma = true;
@ -539,8 +585,7 @@ struct TypeVarStringifier
std::vector<std::string> results = {};
for (auto el : &uv)
{
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
el = follow(el);
el = follow(el);
if (isNil(el))
{
@ -604,8 +649,7 @@ struct TypeVarStringifier
std::vector<std::string> results = {};
for (auto el : uv.parts)
{
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
el = follow(el);
el = follow(el);
std::string saved = std::move(state.result.name);
@ -691,16 +735,6 @@ struct TypePackStringifier
return;
}
if (!FFlag::LuauAddMissingFollow)
{
if (get<FreeTypePack>(tp))
{
state.emit(state.getName(tp));
state.emit("...");
return;
}
}
auto it = state.cycleTpNames.find(tp);
if (it != state.cycleTpNames.end())
{
@ -788,16 +822,8 @@ struct TypePackStringifier
void operator()(TypePackId tp, const FreeTypePack& pack)
{
state.result.invalid = true;
if (FFlag::LuauAddMissingFollow)
{
state.emit(state.getName(tp));
state.emit("...");
}
else
{
state.emit("<FREETP>");
}
state.emit(state.getName(tp));
state.emit("...");
}
void operator()(TypePackId, const BoundTypePack& btv)
@ -831,23 +857,15 @@ static void assignCycleNames(const std::unordered_set<TypeId>& cycles, const std
std::string name;
// TODO: use the stringified type list if there are no cycles
if (FFlag::LuauInstantiatedTypeParamRecursion)
if (auto ttv = get<TableTypeVar>(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name))
{
if (auto ttv = get<TableTypeVar>(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name))
{
// If we have a cycle type in type parameters, assign a cycle name for this named table
if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) {
return cycles.count(follow(el));
}) != ttv->instantiatedTypeParams.end())
cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName;
// If we have a cycle type in type parameters, assign a cycle name for this named table
if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) {
return cycles.count(follow(el));
}) != ttv->instantiatedTypeParams.end())
cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName;
continue;
}
}
else
{
if (auto ttv = get<TableTypeVar>(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name))
continue;
continue;
}
name = "t" + std::to_string(nextIndex);
@ -879,45 +897,6 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
ToStringResult result;
if (!FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive)
{
if (auto ttv = get<TableTypeVar>(ty); ttv && (ttv->name || ttv->syntheticName))
{
if (ttv->syntheticName)
result.invalid = true;
// If scope if provided, add module name and check visibility
if (ttv->name && opts.scope)
{
auto [success, moduleName] = canUseTypeNameInScope(opts.scope, *ttv->name);
if (!success)
result.invalid = true;
if (moduleName)
result.name = format("%s.", moduleName->c_str());
}
result.name += ttv->name ? *ttv->name : *ttv->syntheticName;
if (ttv->instantiatedTypeParams.empty())
return result;
std::vector<std::string> params;
for (TypeId tp : ttv->instantiatedTypeParams)
params.push_back(toString(tp));
result.name += "<" + join(params, ", ") + ">";
return result;
}
else if (auto mtv = get<MetatableTypeVar>(ty); mtv && mtv->syntheticName)
{
result.invalid = true;
result.name = *mtv->syntheticName;
return result;
}
}
StringifierState state{opts, result, opts.nameMap};
std::unordered_set<TypeId> cycles;
@ -929,7 +908,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
TypeVarStringifier tvs{state};
if (FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive)
if (!opts.exhaustive)
{
if (auto ttv = get<TableTypeVar>(ty); ttv && (ttv->name || ttv->syntheticName))
{
@ -950,30 +929,37 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
result.name += ttv->name ? *ttv->name : *ttv->syntheticName;
if (ttv->instantiatedTypeParams.empty())
return result;
result.name += "<";
bool first = true;
for (TypeId ty : ttv->instantiatedTypeParams)
if (FFlag::LuauTypeAliasPacks)
{
if (!first)
result.name += ", ";
else
first = false;
tvs.stringify(ty);
}
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
{
result.truncated = true;
result.name += "... <TRUNCATED>";
tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams);
}
else
{
result.name += ">";
if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty()))
return result;
result.name += "<";
bool first = true;
for (TypeId ty : ttv->instantiatedTypeParams)
{
if (!first)
result.name += ", ";
else
first = false;
tvs.stringify(ty);
}
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
{
result.truncated = true;
result.name += "... <TRUNCATED>";
}
else
{
result.name += ">";
}
}
return result;
@ -1123,6 +1109,94 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts)
return toString(const_cast<TypePackId>(&tp), std::move(opts));
}
std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts)
{
std::string s = prefix;
auto toString_ = [&opts](TypeId ty) -> std::string {
ToStringResult res = toStringDetailed(ty, opts);
opts.nameMap = std::move(res.nameMap);
return res.name;
};
auto toStringPack_ = [&opts](TypePackId ty) -> std::string {
ToStringResult res = toStringDetailed(ty, opts);
opts.nameMap = std::move(res.nameMap);
return res.name;
};
if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty()))
{
s += "<";
bool first = true;
for (TypeId g : ftv.generics)
{
if (!first)
s += ", ";
first = false;
s += toString_(g);
}
for (TypePackId gp : ftv.genericPacks)
{
if (!first)
s += ", ";
first = false;
s += toStringPack_(gp);
}
s += ">";
}
s += "(";
auto argPackIter = begin(ftv.argTypes);
auto argNameIter = ftv.argNames.begin();
bool first = true;
while (argPackIter != end(ftv.argTypes))
{
if (!first)
s += ", ";
first = false;
// argNames is guaranteed to be equal to argTypes iff argNames is not empty.
// We don't currently respect opts.functionTypeArguments. I don't think this function should.
if (!ftv.argNames.empty())
s += (*argNameIter ? (*argNameIter)->name : "_") + ": ";
s += toString_(*argPackIter);
++argPackIter;
if (!ftv.argNames.empty())
{
LUAU_ASSERT(argNameIter != ftv.argNames.end());
++argNameIter;
}
}
if (argPackIter.tail())
{
if (auto vtp = get<VariadicTypePack>(*argPackIter.tail()))
s += ", ...: " + toString_(vtp->ty);
else
s += ", ...: " + toStringPack_(*argPackIter.tail());
}
s += "): ";
size_t retSize = size(ftv.retType);
bool hasTail = !finite(ftv.retType);
if (retSize == 0 && !hasTail)
s += "()";
else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail))
s += toStringPack_(ftv.retType);
else
s += "(" + toStringPack_(ftv.retType) + ")";
return s;
}
void dump(TypeId ty)
{
ToStringOptions opts;
@ -1139,4 +1213,13 @@ void dump(TypePackId ty)
printf("%s\n", toString(ty, opts).c_str());
}
std::string generateName(size_t i)
{
std::string n;
n = char('a' + i % 26);
if (i >= 26)
n += std::to_string(i / 26);
return n;
}
} // namespace Luau

View file

@ -298,8 +298,15 @@ struct ArcCollector : public AstVisitor
struct ContainsFunctionCall : public AstVisitor
{
bool alsoReturn = false;
bool result = false;
ContainsFunctionCall() = default;
explicit ContainsFunctionCall(bool alsoReturn)
: alsoReturn(alsoReturn)
{
}
bool visit(AstExpr*) override
{
return !result; // short circuit if result is true
@ -318,6 +325,17 @@ struct ContainsFunctionCall : public AstVisitor
return false;
}
bool visit(AstStatReturn* stat) override
{
if (alsoReturn)
{
result = true;
return false;
}
else
return AstVisitor::visit(stat);
}
bool visit(AstExprFunction*) override
{
return false;
@ -479,6 +497,13 @@ bool containsFunctionCall(const AstStat& stat)
return cfc.result;
}
bool containsFunctionCallOrReturn(const AstStat& stat)
{
detail::ContainsFunctionCall cfc{true};
const_cast<AstStat&>(stat).visit(&cfc);
return cfc.result;
}
bool isFunction(const AstStat& stat)
{
return stat.is<AstStatFunction>() || stat.is<AstStatLocalFunction>();

View file

@ -10,65 +10,10 @@
#include <limits>
#include <math.h>
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace
{
std::string escape(std::string_view s)
{
std::string r;
r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting
for (uint8_t c : s)
{
if (c >= ' ' && c != '\\' && c != '\'' && c != '\"')
r += c;
else
{
r += '\\';
switch (c)
{
case '\a':
r += 'a';
break;
case '\b':
r += 'b';
break;
case '\f':
r += 'f';
break;
case '\n':
r += 'n';
break;
case '\r':
r += 'r';
break;
case '\t':
r += 't';
break;
case '\v':
r += 'v';
break;
case '\'':
r += '\'';
break;
case '\"':
r += '\"';
break;
case '\\':
r += '\\';
break;
default:
Luau::formatAppend(r, "%03u", c);
}
}
}
return r;
}
bool isIdentifierStartChar(char c)
{
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_';
@ -96,9 +41,6 @@ struct Writer
{
virtual ~Writer() {}
virtual void begin() {}
virtual void end() {}
virtual void advance(const Position&) = 0;
virtual void newline() = 0;
virtual void space() = 0;
@ -130,6 +72,7 @@ struct StringWriter : Writer
if (pos.column < newPos.column)
write(std::string(newPos.column - pos.column, ' '));
}
void maybeSpace(const Position& newPos, int reserve) override
{
if (pos.column + reserve < newPos.column)
@ -278,12 +221,25 @@ struct Printer
writer.identifier(func->index.value);
}
void visualizeTypePackAnnotation(const AstTypePack& annotation)
void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg)
{
if (const AstTypePackVariadic* variadic = annotation.as<AstTypePackVariadic>())
advance(annotation.location.begin);
if (const AstTypePackVariadic* variadicTp = annotation.as<AstTypePackVariadic>())
{
if (!forVarArg)
writer.symbol("...");
visualizeTypeAnnotation(*variadicTp->variadicType);
}
else if (const AstTypePackGeneric* genericTp = annotation.as<AstTypePackGeneric>())
{
writer.symbol(genericTp->genericName.value);
writer.symbol("...");
visualizeTypeAnnotation(*variadic->variadicType);
}
else if (const AstTypePackExplicit* explicitTp = annotation.as<AstTypePackExplicit>())
{
LUAU_ASSERT(!forVarArg);
visualizeTypeList(explicitTp->typeList, true);
}
else
{
@ -307,7 +263,7 @@ struct Printer
// Only variadic tail
if (list.types.size == 0)
{
visualizeTypePackAnnotation(*list.tailType);
visualizeTypePackAnnotation(*list.tailType, false);
}
else
{
@ -335,7 +291,7 @@ struct Printer
if (list.tailType)
{
writer.symbol(",");
visualizeTypePackAnnotation(*list.tailType);
visualizeTypePackAnnotation(*list.tailType, false);
}
writer.symbol(")");
@ -532,6 +488,7 @@ struct Printer
case AstExprBinary::CompareLt:
case AstExprBinary::CompareGt:
writer.maybeSpace(a->right->location.begin, 2);
writer.symbol(toString(a->op));
break;
case AstExprBinary::Concat:
case AstExprBinary::CompareNe:
@ -540,19 +497,35 @@ struct Printer
case AstExprBinary::CompareGe:
case AstExprBinary::Or:
writer.maybeSpace(a->right->location.begin, 3);
writer.keyword(toString(a->op));
break;
case AstExprBinary::And:
writer.maybeSpace(a->right->location.begin, 4);
writer.keyword(toString(a->op));
break;
}
writer.symbol(toString(a->op));
visualize(*a->right);
}
else if (const auto& a = expr.as<AstExprTypeAssertion>())
{
visualize(*a->expr);
if (writeTypes)
{
writer.maybeSpace(a->annotation->location.begin, 2);
writer.symbol("::");
visualizeTypeAnnotation(*a->annotation);
}
}
else if (const auto& a = expr.as<AstExprIfElse>())
{
writer.keyword("if");
visualize(*a->condition);
writer.keyword("then");
visualize(*a->trueExpr);
writer.keyword("else");
visualize(*a->falseExpr);
}
else if (const auto& a = expr.as<AstExprError>())
{
@ -759,24 +732,31 @@ struct Printer
switch (a->op)
{
case AstExprBinary::Add:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("+=");
break;
case AstExprBinary::Sub:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("-=");
break;
case AstExprBinary::Mul:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("*=");
break;
case AstExprBinary::Div:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("/=");
break;
case AstExprBinary::Mod:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("%=");
break;
case AstExprBinary::Pow:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("^=");
break;
case AstExprBinary::Concat:
writer.maybeSpace(a->value->location.begin, 3);
writer.symbol("..=");
break;
default:
@ -807,7 +787,7 @@ struct Printer
writer.keyword("type");
writer.identifier(a->name.value);
if (a->generics.size > 0)
if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0))
{
writer.symbol("<");
CommaSeparatorInserter comma(writer);
@ -817,6 +797,17 @@ struct Printer
comma();
writer.identifier(o.value);
}
if (FFlag::LuauTypeAliasPacks)
{
for (auto o : a->genericPacks)
{
comma();
writer.identifier(o.value);
writer.symbol("...");
}
}
writer.symbol(">");
}
writer.maybeSpace(a->type->location.begin, 2);
@ -853,7 +844,7 @@ struct Printer
void visualizeFunctionBody(AstExprFunction& func)
{
if (FFlag::LuauGenericFunctions && (func.generics.size > 0 || func.genericPacks.size > 0))
if (func.generics.size > 0 || func.genericPacks.size > 0)
{
CommaSeparatorInserter comma(writer);
writer.symbol("<");
@ -892,12 +883,13 @@ struct Printer
if (func.vararg)
{
comma();
advance(func.varargLocation.begin);
writer.symbol("...");
if (func.varargAnnotation)
{
writer.symbol(":");
visualizeTypePackAnnotation(*func.varargAnnotation);
visualizeTypePackAnnotation(*func.varargAnnotation, true);
}
}
@ -959,22 +951,33 @@ struct Printer
advance(typeAnnotation.location.begin);
if (const auto& a = typeAnnotation.as<AstTypeReference>())
{
if (a->hasPrefix)
{
writer.write(a->prefix.value);
writer.symbol(".");
}
writer.write(a->name.value);
if (a->generics.size > 0)
if (a->parameters.size > 0 || a->hasParameterList)
{
CommaSeparatorInserter comma(writer);
writer.symbol("<");
for (auto o : a->generics)
for (auto o : a->parameters)
{
comma();
visualizeTypeAnnotation(*o);
if (o.type)
visualizeTypeAnnotation(*o.type);
else
visualizeTypePackAnnotation(*o.typePack, false);
}
writer.symbol(">");
}
}
else if (const auto& a = typeAnnotation.as<AstTypeFunction>())
{
if (FFlag::LuauGenericFunctions && (a->generics.size > 0 || a->genericPacks.size > 0))
if (a->generics.size > 0 || a->genericPacks.size > 0)
{
CommaSeparatorInserter comma(writer);
writer.symbol("<");
@ -1049,7 +1052,16 @@ struct Printer
auto rta = r->as<AstTypeReference>();
if (rta && rta->name == "nil")
{
bool wrap = l->as<AstTypeIntersection>() || l->as<AstTypeFunction>();
if (wrap)
writer.symbol("(");
visualizeTypeAnnotation(*l);
if (wrap)
writer.symbol(")");
writer.symbol("?");
return;
}
@ -1063,7 +1075,15 @@ struct Printer
writer.symbol("|");
}
bool wrap = a->types.data[i]->as<AstTypeIntersection>() || a->types.data[i]->as<AstTypeFunction>();
if (wrap)
writer.symbol("(");
visualizeTypeAnnotation(*a->types.data[i]);
if (wrap)
writer.symbol(")");
}
}
else if (const auto& a = typeAnnotation.as<AstTypeIntersection>())
@ -1076,7 +1096,15 @@ struct Printer
writer.symbol("&");
}
bool wrap = a->types.data[i]->as<AstTypeUnion>() || a->types.data[i]->as<AstTypeFunction>();
if (wrap)
writer.symbol("(");
visualizeTypeAnnotation(*a->types.data[i]);
if (wrap)
writer.symbol(")");
}
}
else if (typeAnnotation.is<AstTypeError>())
@ -1090,31 +1118,27 @@ struct Printer
}
};
void dump(AstNode* node)
std::string toString(AstNode* node)
{
StringWriter writer;
writer.pos = node->location.begin;
Printer printer(writer);
printer.writeTypes = true;
if (auto statNode = dynamic_cast<AstStat*>(node))
{
printer.visualize(*statNode);
printf("%s\n", writer.str().c_str());
}
else if (auto exprNode = dynamic_cast<AstExpr*>(node))
{
printer.visualize(*exprNode);
printf("%s\n", writer.str().c_str());
}
else if (auto typeNode = dynamic_cast<AstType*>(node))
{
printer.visualizeTypeAnnotation(*typeNode);
printf("%s\n", writer.str().c_str());
}
else
{
printf("Can't dump this node\n");
}
return writer.str();
}
void dump(AstNode* node)
{
printf("%s\n", toString(node).c_str());
}
std::string transpile(AstStatBlock& block)
@ -1123,6 +1147,7 @@ std::string transpile(AstStatBlock& block)
Printer(writer).visualizeBlock(block);
return writer.str();
}
std::string transpileWithTypes(AstStatBlock& block)
{
StringWriter writer;
@ -1132,7 +1157,7 @@ std::string transpileWithTypes(AstStatBlock& block)
return writer.str();
}
TranspileResult transpile(std::string_view source, ParseOptions options)
TranspileResult transpile(std::string_view source, ParseOptions options, bool withTypes)
{
auto allocator = Allocator{};
auto names = AstNameTable{allocator};
@ -1150,6 +1175,9 @@ TranspileResult transpile(std::string_view source, ParseOptions options)
if (!parseResult.root)
return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"};
if (withTypes)
return TranspileResult{transpileWithTypes(*parseResult.root)};
return TranspileResult{transpile(*parseResult.root)};
}

View file

@ -5,6 +5,8 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false)
namespace Luau
{
@ -33,6 +35,12 @@ void TxnLog::rollback()
for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it)
std::swap(it->first->boundTo, it->second);
if (FFlag::LuauShareTxnSeen)
{
LUAU_ASSERT(originalSeenSize <= sharedSeen->size());
sharedSeen->resize(originalSeenSize);
}
}
void TxnLog::concat(TxnLog rhs)
@ -46,27 +54,44 @@ void TxnLog::concat(TxnLog rhs)
tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end());
rhs.tableChanges.clear();
seen.swap(rhs.seen);
rhs.seen.clear();
if (!FFlag::LuauShareTxnSeen)
{
ownedSeen.swap(rhs.ownedSeen);
rhs.ownedSeen.clear();
}
}
bool TxnLog::haveSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair));
if (FFlag::LuauShareTxnSeen)
return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair));
else
return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair));
}
void TxnLog::pushSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
seen.push_back(sortedPair);
if (FFlag::LuauShareTxnSeen)
sharedSeen->push_back(sortedPair);
else
ownedSeen.push_back(sortedPair);
}
void TxnLog::popSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
LUAU_ASSERT(sortedPair == seen.back());
seen.pop_back();
if (FFlag::LuauShareTxnSeen)
{
LUAU_ASSERT(sortedPair == sharedSeen->back());
sharedSeen->pop_back();
}
else
{
LUAU_ASSERT(sortedPair == ownedSeen.back());
ownedSeen.pop_back();
}
}
} // namespace Luau

View file

@ -5,13 +5,15 @@
#include "Luau/Module.h"
#include "Luau/Parser.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
#include <string>
LUAU_FASTFLAG(LuauGenericFunctions)
LUAU_FASTFLAG(LuauTypeAliasPacks)
static char* allocateString(Luau::Allocator& allocator, std::string_view contents)
{
@ -31,15 +33,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data
return result;
}
using SyntheticNames = std::unordered_map<const void*, char*>;
namespace Luau
{
static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen)
{
size_t s = syntheticNames->size();
char*& n = (*syntheticNames)[&gen];
if (!n)
{
std::string str = gen.explicitName ? gen.name : generateName(s);
n = static_cast<char*>(allocator->allocate(str.size() + 1));
strcpy(n, str.c_str());
}
return n;
}
class TypeRehydrationVisitor
{
mutable std::map<void*, int> seen;
mutable int count = 0;
std::map<void*, int> seen;
int count = 0;
bool hasSeen(const void* tv) const
bool hasSeen(const void* tv)
{
void* ttv = const_cast<void*>(tv);
auto it = seen.find(ttv);
@ -51,13 +69,16 @@ class TypeRehydrationVisitor
}
public:
TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions())
TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions())
: allocator(alloc)
, syntheticNames(syntheticNames)
, options(options)
{
}
AstType* operator()(const PrimitiveTypeVar& ptv) const
AstTypePack* rehydrate(TypePackId tp);
AstType* operator()(const PrimitiveTypeVar& ptv)
{
switch (ptv.type)
{
@ -75,26 +96,50 @@ public:
return nullptr;
}
}
AstType* operator()(const AnyTypeVar&) const
AstType* operator()(const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = get<BoolSingleton>(&stv))
return allocator->alloc<AstTypeSingletonBool>(Location(), bs->value);
else if (const StringSingleton* ss = get<StringSingleton>(&stv))
{
AstArray<char> value;
value.data = const_cast<char*>(ss->value.c_str());
value.size = strlen(value.data);
return allocator->alloc<AstTypeSingletonString>(Location(), value);
}
else
return nullptr;
}
AstType* operator()(const AnyTypeVar&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"));
}
AstType* operator()(const TableTypeVar& ttv) const
AstType* operator()(const TableTypeVar& ttv)
{
RecursionCounter counter(&count);
if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end())
{
AstArray<AstType*> generics;
generics.size = ttv.instantiatedTypeParams.size();
generics.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * generics.size));
AstArray<AstTypeOrPack> parameters;
parameters.size = ttv.instantiatedTypeParams.size();
parameters.data = static_cast<AstTypeOrPack*>(allocator->allocate(sizeof(AstTypeOrPack) * parameters.size));
for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i)
{
generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty);
parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}};
}
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(ttv.name->c_str()), generics);
if (FFlag::LuauTypeAliasPacks)
{
for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i)
{
parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])};
}
}
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters);
}
if (hasSeen(&ttv))
@ -133,12 +178,12 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props, indexer);
}
AstType* operator()(const MetatableTypeVar& mtv) const
AstType* operator()(const MetatableTypeVar& mtv)
{
return Luau::visit(*this, mtv.table->ty);
}
AstType* operator()(const ClassTypeVar& ctv) const
AstType* operator()(const ClassTypeVar& ctv)
{
RecursionCounter counter(&count);
@ -165,7 +210,7 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props);
}
AstType* operator()(const FunctionTypeVar& ftv) const
AstType* operator()(const FunctionTypeVar& ftv)
{
RecursionCounter counter(&count);
@ -173,39 +218,23 @@ public:
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>"));
AstArray<AstName> generics;
if (FFlag::LuauGenericFunctions)
generics.size = ftv.generics.size();
generics.data = static_cast<AstName*>(allocator->allocate(sizeof(AstName) * generics.size));
size_t numGenerics = 0;
for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it)
{
generics.size = ftv.generics.size();
generics.data = static_cast<AstName*>(allocator->allocate(sizeof(AstName) * generics.size));
size_t i = 0;
for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it)
{
if (auto gtv = get<GenericTypeVar>(*it))
generics.data[i++] = AstName(gtv->name.c_str());
}
}
else
{
generics.size = 0;
generics.data = nullptr;
if (auto gtv = get<GenericTypeVar>(*it))
generics.data[numGenerics++] = AstName(gtv->name.c_str());
}
AstArray<AstName> genericPacks;
if (FFlag::LuauGenericFunctions)
genericPacks.size = ftv.genericPacks.size();
genericPacks.data = static_cast<AstName*>(allocator->allocate(sizeof(AstName) * genericPacks.size));
size_t numGenericPacks = 0;
for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it)
{
genericPacks.size = ftv.genericPacks.size();
genericPacks.data = static_cast<AstName*>(allocator->allocate(sizeof(AstName) * genericPacks.size));
size_t i = 0;
for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it)
{
if (auto gtv = get<GenericTypeVar>(*it))
genericPacks.data[i++] = AstName(gtv->name.c_str());
}
}
else
{
generics.size = 0;
generics.data = nullptr;
if (auto gtv = get<GenericTypeVar>(*it))
genericPacks.data[numGenericPacks++] = AstName(gtv->name.c_str());
}
AstArray<AstType*> argTypes;
@ -222,10 +251,17 @@ public:
AstTypePack* argTailAnnotation = nullptr;
if (argTail)
{
TypePackId tail = *argTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
if (FFlag::LuauTypeAliasPacks)
{
argTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
argTailAnnotation = rehydrate(*argTail);
}
else
{
TypePackId tail = *argTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
{
argTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
}
}
}
@ -235,10 +271,12 @@ public:
size_t i = 0;
for (const auto& el : ftv.argNames)
{
std::optional<AstArgumentName>* arg = &argNames.data[i++];
if (el)
argNames.data[i++] = {AstName(el->name.c_str()), el->location};
new (arg) std::optional<AstArgumentName>(AstArgumentName(AstName(el->name.c_str()), el->location));
else
argNames.data[i++] = {};
new (arg) std::optional<AstArgumentName>();
}
AstArray<AstType*> returnTypes;
@ -255,33 +293,40 @@ public:
AstTypePack* retTailAnnotation = nullptr;
if (retTail)
{
TypePackId tail = *retTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
if (FFlag::LuauTypeAliasPacks)
{
retTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
retTailAnnotation = rehydrate(*retTail);
}
else
{
TypePackId tail = *retTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
{
retTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
}
}
}
return allocator->alloc<AstTypeFunction>(
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation});
}
AstType* operator()(const Unifiable::Error&) const
AstType* operator()(const Unifiable::Error&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>"));
}
AstType* operator()(const GenericTypeVar& gtv) const
AstType* operator()(const GenericTypeVar& gtv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(gtv.name.c_str()));
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)));
}
AstType* operator()(const Unifiable::Bound<TypeId>& bound) const
AstType* operator()(const Unifiable::Bound<TypeId>& bound)
{
return Luau::visit(*this, bound.boundTo->ty);
}
AstType* operator()(Unifiable::Free ftv) const
AstType* operator()(const FreeTypeVar& ftv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"));
}
AstType* operator()(const UnionTypeVar& uv) const
AstType* operator()(const UnionTypeVar& uv)
{
AstArray<AstType*> unionTypes;
unionTypes.size = uv.options.size();
@ -292,7 +337,7 @@ public:
}
return allocator->alloc<AstTypeUnion>(Location(), unionTypes);
}
AstType* operator()(const IntersectionTypeVar& uv) const
AstType* operator()(const IntersectionTypeVar& uv)
{
AstArray<AstType*> intersectionTypes;
intersectionTypes.size = uv.parts.size();
@ -303,16 +348,84 @@ public:
}
return allocator->alloc<AstTypeIntersection>(Location(), intersectionTypes);
}
AstType* operator()(const LazyTypeVar& ltv) const
AstType* operator()(const LazyTypeVar& ltv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Lazy?>"));
}
private:
Allocator* allocator;
SyntheticNames* syntheticNames;
const TypeRehydrationOptions& options;
};
class TypePackRehydrationVisitor
{
public:
TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor)
: allocator(allocator)
, syntheticNames(syntheticNames)
, typeVisitor(typeVisitor)
{
LUAU_ASSERT(allocator);
LUAU_ASSERT(syntheticNames);
LUAU_ASSERT(typeVisitor);
}
AstTypePack* operator()(const BoundTypePack& btp) const
{
return Luau::visit(*this, btp.boundTo->ty);
}
AstTypePack* operator()(const TypePack& tp) const
{
AstArray<AstType*> head;
head.size = tp.head.size();
head.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * tp.head.size()));
for (size_t i = 0; i < tp.head.size(); i++)
head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty);
AstTypePack* tail = nullptr;
if (tp.tail)
tail = Luau::visit(*this, (*tp.tail)->ty);
return allocator->alloc<AstTypePackExplicit>(Location(), AstTypeList{head, tail});
}
AstTypePack* operator()(const VariadicTypePack& vtp) const
{
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*typeVisitor, vtp.ty->ty));
}
AstTypePack* operator()(const GenericTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(getName(allocator, syntheticNames, gtp)));
}
AstTypePack* operator()(const FreeTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("free"));
}
AstTypePack* operator()(const Unifiable::Error&) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("Unifiable<Error>"));
}
private:
Allocator* allocator;
SyntheticNames* syntheticNames;
TypeRehydrationVisitor* typeVisitor;
};
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp)
{
TypePackRehydrationVisitor tprv(allocator, syntheticNames, this);
return Luau::visit(tprv, tp->ty);
}
class TypeAttacher : public AstVisitor
{
public:
@ -344,7 +457,7 @@ public:
{
if (!type)
return nullptr;
return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty);
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty);
}
AstArray<Luau::AstType*> typeAstPack(TypePackId type)
@ -356,7 +469,7 @@ public:
result.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * v.size()));
for (size_t i = 0; i < v.size(); ++i)
{
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty);
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty);
}
return result;
}
@ -406,9 +519,16 @@ public:
if (tail)
{
TypePackId tailPack = *tail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tailPack))
variadicAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), typeAst(vtp->ty));
if (FFlag::LuauTypeAliasPacks)
{
variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail);
}
else
{
TypePackId tailPack = *tail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tailPack))
variadicAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), typeAst(vtp->ty));
}
}
fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation};
@ -421,6 +541,7 @@ public:
private:
Module& module;
Allocator* allocator;
SyntheticNames syntheticNames;
};
void attachTypeData(SourceModule& source, Module& result)
@ -431,7 +552,8 @@ void attachTypeData(SourceModule& source, Module& result)
AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options)
{
return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty);
SyntheticNames syntheticNames;
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty);
}
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -97,7 +97,7 @@ TypePackIterator begin(TypePackId tp)
TypePackIterator end(TypePackId tp)
{
return FFlag::LuauAddMissingFollow ? TypePackIterator{} : TypePackIterator{nullptr};
return TypePackIterator{};
}
bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs)
@ -203,18 +203,31 @@ TypePackId follow(TypePackId tp)
size_t size(TypePackId tp)
{
if (auto pack = get<TypePack>(FFlag::LuauAddMissingFollow ? follow(tp) : tp))
if (auto pack = get<TypePack>(follow(tp)))
return size(*pack);
else
return 0;
}
bool finite(TypePackId tp)
{
tp = follow(tp);
if (auto pack = get<TypePack>(tp))
return pack->tail ? finite(*pack->tail) : true;
if (get<VariadicTypePack>(tp))
return false;
return true;
}
size_t size(const TypePack& tp)
{
size_t result = tp.head.size();
if (tp.tail)
{
const TypePack* tail = get<TypePack>(FFlag::LuauAddMissingFollow ? follow(*tp.tail) : *tp.tail);
const TypePack* tail = get<TypePack>(follow(*tp.tail));
if (tail)
result += size(*tail);
}
@ -273,5 +286,4 @@ TypePack* asMutable(const TypePack* tp)
{
return const_cast<TypePack*>(tp);
}
} // namespace Luau

View file

@ -1,11 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeUtils.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
LUAU_FASTFLAG(LuauStringMetatable)
namespace Luau
{
@ -13,21 +12,6 @@ std::optional<TypeId> findMetatableEntry(ErrorVec& errors, const ScopePtr& globa
{
type = follow(type);
if (!FFlag::LuauStringMetatable)
{
if (const PrimitiveTypeVar* primType = get<PrimitiveTypeVar>(type))
{
if (primType->type != PrimitiveTypeVar::String || "__index" != entry)
return std::nullopt;
auto it = globalScope->bindings.find(AstName{"string"});
if (it != globalScope->bindings.end())
return it->second.typeId;
else
return std::nullopt;
}
}
std::optional<TypeId> metatable = getMetatable(type);
if (!metatable)
return std::nullopt;

View file

@ -19,11 +19,9 @@
LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2)
LUAU_FASTFLAGVARIABLE(LuauToStringFollowsBoundTo, false)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAGVARIABLE(LuauStringMetatable, false)
LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false)
LUAU_FASTFLAG(LuauErrorRecoveryType)
namespace Luau
{
@ -43,7 +41,7 @@ TypeId follow(TypeId t)
};
auto force = [](TypeId ty) {
if (auto ltv = FFlag::LuauAddMissingFollow ? get_if<LazyTypeVar>(&ty->ty) : get<LazyTypeVar>(ty))
if (auto ltv = get_if<LazyTypeVar>(&ty->ty))
{
TypeId res = ltv->thunk();
if (get<LazyTypeVar>(res))
@ -193,27 +191,11 @@ bool isOptional(TypeId ty)
bool isTableIntersection(TypeId ty)
{
if (FFlag::LuauImprovedTypeGuardPredicate2)
{
if (!get<IntersectionTypeVar>(follow(ty)))
return false;
std::vector<TypeId> parts = flattenIntersection(ty);
return std::all_of(parts.begin(), parts.end(), getTableType);
}
else
{
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
for (TypeId part : itv->parts)
{
if (getTableType(follow(part)))
return true;
}
}
if (!get<IntersectionTypeVar>(follow(ty)))
return false;
}
std::vector<TypeId> parts = flattenIntersection(ty);
return std::all_of(parts.begin(), parts.end(), getTableType);
}
bool isOverloadedFunction(TypeId ty)
@ -235,8 +217,7 @@ std::optional<TypeId> getMetatable(TypeId type)
return mtType->metatable;
else if (const ClassTypeVar* classType = get<ClassTypeVar>(type))
return classType->metatable;
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type);
FFlag::LuauStringMetatable && primitiveType && primitiveType->metatable)
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type); primitiveType && primitiveType->metatable)
{
LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String);
return primitiveType->metatable;
@ -313,8 +294,8 @@ bool isGeneric(TypeId ty)
bool maybeGeneric(TypeId ty)
{
ty = follow(ty);
if (auto ftv = get<FreeTypeVar>(ty))
return FFlag::LuauRankNTypes || ftv->DEPRECATED_canBeGeneric;
if (get<FreeTypeVar>(ty))
return true;
else if (auto ttv = get<TableTypeVar>(ty))
{
// TODO: recurse on table types CLI-39914
@ -325,6 +306,18 @@ bool maybeGeneric(TypeId ty)
return isGeneric(ty);
}
bool maybeSingleton(TypeId ty)
{
ty = follow(ty);
if (get<SingletonTypeVar>(ty))
return true;
if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
for (TypeId option : utv)
if (get<SingletonTypeVar>(follow(option)))
return true;
return false;
}
FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn, bool hasSelf)
: argTypes(argTypes)
, retType(retType)
@ -563,15 +556,28 @@ TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initi
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes);
static TypeVar nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true};
static TypeVar numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true};
static TypeVar stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true};
static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true};
static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true};
static TypeVar anyType_{AnyTypeVar{}};
static TypeVar errorType_{ErrorTypeVar{}};
static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}};
static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true};
static TypePackVar errorTypePack_{Unifiable::Error{}};
SingletonTypes::SingletonTypes()
: arena(new TypeArena)
, nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true}
, numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true}
, stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}
, booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}
, threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}
, anyType_{AnyTypeVar{}}
, errorType_{ErrorTypeVar{}}
: nilType(&nilType_)
, numberType(&numberType_)
, stringType(&stringType_)
, booleanType(&booleanType_)
, threadType(&threadType_)
, anyType(&anyType_)
, optionalNumberType(&optionalNumberType_)
, anyTypePack(&anyTypePack_)
, arena(new TypeArena)
{
TypeId stringMetatable = makeStringMetatable();
stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()};
@ -639,6 +645,32 @@ TypeId SingletonTypes::makeStringMetatable()
return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
TypeId SingletonTypes::errorRecoveryType()
{
return &errorType_;
}
TypePackId SingletonTypes::errorRecoveryTypePack()
{
return &errorTypePack_;
}
TypeId SingletonTypes::errorRecoveryType(TypeId guess)
{
if (FFlag::LuauErrorRecoveryType)
return guess;
else
return &errorType_;
}
TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess)
{
if (FFlag::LuauErrorRecoveryType)
return guess;
else
return &errorTypePack_;
}
SingletonTypes singletonTypes;
void persist(TypeId ty)
@ -767,9 +799,9 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName)
if (opts.duplicatePrimitives && canDuplicatePrimitive(ty))
{
if (const PrimitiveTypeVar* ptv = get<PrimitiveTypeVar>(ty))
if (get<PrimitiveTypeVar>(ty))
formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str());
else if (const AnyTypeVar* atv = get<AnyTypeVar>(ty))
else if (get<AnyTypeVar>(ty))
formatAppend(result, "n%d [label=\"any\"];\n", index);
}
else
@ -871,6 +903,12 @@ void StateDot::visitChildren(TypeId ty, int index)
}
for (TypeId itp : ttv->instantiatedTypeParams)
visitChild(itp, index, "typeParam");
if (FFlag::LuauTypeAliasPacks)
{
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp, index, "typePackParam");
}
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
@ -914,19 +952,19 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty);
finishNode();
}
else if (const AnyTypeVar* atv = get<AnyTypeVar>(ty))
else if (get<AnyTypeVar>(ty))
{
formatAppend(result, "AnyTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
}
else if (const PrimitiveTypeVar* ptv = get<PrimitiveTypeVar>(ty))
else if (get<PrimitiveTypeVar>(ty))
{
formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str());
finishNodeLabel(ty);
finishNode();
}
else if (const ErrorTypeVar* etv = get<ErrorTypeVar>(ty))
else if (get<ErrorTypeVar>(ty))
{
formatAppend(result, "ErrorTypeVar %d", index);
finishNodeLabel(ty);
@ -1006,7 +1044,7 @@ void StateDot::visitChildren(TypePackId tp, int index)
finishNodeLabel(tp);
finishNode();
}
else if (const Unifiable::Error* etp = get<Unifiable::Error>(tp))
else if (get<Unifiable::Error>(tp))
{
formatAppend(result, "ErrorTypePack %d", index);
finishNodeLabel(tp);
@ -1140,6 +1178,11 @@ struct QVarFinder
return false;
}
bool operator()(const SingletonTypeVar&) const
{
return false;
}
bool operator()(const FunctionTypeVar& ftv) const
{
if (hasGeneric(ftv.argTypes))
@ -1384,24 +1427,6 @@ UnionTypeVarIterator end(const UnionTypeVar* utv)
return UnionTypeVarIterator{};
}
static std::vector<TypeId> DEPRECATED_filterMap(TypeId type, TypeIdPredicate predicate)
{
std::vector<TypeId> result;
if (auto utv = get<UnionTypeVar>(follow(type)))
{
for (TypeId option : utv)
{
if (auto out = predicate(follow(option)))
result.push_back(*out);
}
}
else if (auto out = predicate(follow(type)))
return {*out};
return result;
}
static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const char* data, size_t size)
{
const char* options = "cdiouxXeEfgGqs";
@ -1429,7 +1454,7 @@ static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const cha
else if (strchr(options, data[i]))
result.push_back(typechecker.numberType);
else
result.push_back(typechecker.errorType);
result.push_back(typechecker.errorRecoveryType(typechecker.anyType));
}
}
@ -1482,9 +1507,6 @@ std::optional<ExprResult<TypePackId>> magicFunctionFormat(
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate)
{
if (!FFlag::LuauTypeGuardPeelsAwaySubclasses)
return DEPRECATED_filterMap(type, predicate);
type = follow(type);
if (auto utv = get<UnionTypeVar>(type))
@ -1502,4 +1524,86 @@ std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate)
return {};
}
static Tags* getTags(TypeId ty)
{
ty = follow(ty);
if (auto ftv = getMutable<FunctionTypeVar>(ty))
return &ftv->tags;
else if (auto ttv = getMutable<TableTypeVar>(ty))
return &ttv->tags;
else if (auto ctv = getMutable<ClassTypeVar>(ty))
return &ctv->tags;
return nullptr;
}
void attachTag(TypeId ty, const std::string& tagName)
{
if (!FFlag::LuauRefactorTagging)
{
if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->tags.emplace_back(tagName);
}
else
{
LUAU_ASSERT(!"Got a non functional type");
}
}
else
{
if (auto tags = getTags(ty))
tags->push_back(tagName);
else
LUAU_ASSERT(!"This TypeId does not support tags");
}
}
void attachTag(Property& prop, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
prop.tags.push_back(tagName);
}
// We would ideally not expose this because it could cause a footgun.
// If the Base class has a tag and you ask if Derived has that tag, it would return false.
// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it.
bool hasTag(const Tags& tags, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
return std::find(tags.begin(), tags.end(), tagName) != tags.end();
}
bool hasTag(TypeId ty, const std::string& tagName)
{
ty = follow(ty);
// We special case classes because getTags only returns a pointer to one vector of tags.
// But classes has multiple vector of tags, represented throughout the hierarchy.
if (auto ctv = get<ClassTypeVar>(ty))
{
while (ctv)
{
if (hasTag(ctv->tags, tagName))
return true;
else if (!ctv->parent)
return false;
ctv = get<ClassTypeVar>(*ctv->parent);
LUAU_ASSERT(ctv);
}
}
else if (auto tags = getTags(ty))
return hasTag(*tags, tagName);
return false;
}
bool hasTag(const Property& prop, const std::string& tagName)
{
return hasTag(prop.tags, tagName);
}
} // namespace Luau

View file

@ -1,8 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Unifiable.h"
LUAU_FASTFLAG(LuauRankNTypes)
namespace Luau
{
namespace Unifiable
@ -14,14 +12,6 @@ Free::Free(TypeLevel level)
{
}
Free::Free(TypeLevel level, bool DEPRECATED_canBeGeneric)
: index(++nextIndex)
, level(level)
, DEPRECATED_canBeGeneric(DEPRECATED_canBeGeneric)
{
LUAU_ASSERT(!FFlag::LuauRankNTypes);
}
int Free::nextIndex = 0;
Generic::Generic()

File diff suppressed because it is too large Load diff

View file

@ -255,6 +255,14 @@ public:
{
return visit((class AstType*)node);
}
virtual bool visit(class AstTypeSingletonBool* node)
{
return visit((class AstType*)node);
}
virtual bool visit(class AstTypeSingletonString* node)
{
return visit((class AstType*)node);
}
virtual bool visit(class AstTypeError* node)
{
return visit((class AstType*)node);
@ -264,6 +272,10 @@ public:
{
return false;
}
virtual bool visit(class AstTypePackExplicit* node)
{
return visit((class AstTypePack*)node);
}
virtual bool visit(class AstTypePackVariadic* node)
{
return visit((class AstTypePack*)node);
@ -930,12 +942,14 @@ class AstStatTypeAlias : public AstStat
public:
LUAU_RTTI(AstStatTypeAlias)
AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, AstType* type, bool exported);
AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, const AstArray<AstName>& genericPacks,
AstType* type, bool exported);
void visit(AstVisitor* visitor) override;
AstName name;
AstArray<AstName> generics;
AstArray<AstName> genericPacks;
AstType* type;
bool exported;
};
@ -1007,19 +1021,28 @@ public:
}
};
// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use
struct AstTypeOrPack
{
AstType* type = nullptr;
AstTypePack* typePack = nullptr;
};
class AstTypeReference : public AstType
{
public:
LUAU_RTTI(AstTypeReference)
AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, const AstArray<AstType*>& generics = {});
AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, bool hasParameterList = false,
const AstArray<AstTypeOrPack>& parameters = {});
void visit(AstVisitor* visitor) override;
bool hasPrefix;
bool hasParameterList;
AstName prefix;
AstName name;
AstArray<AstType*> generics;
AstArray<AstTypeOrPack> parameters;
};
struct AstTableProp
@ -1143,6 +1166,30 @@ public:
unsigned messageIndex;
};
class AstTypeSingletonBool : public AstType
{
public:
LUAU_RTTI(AstTypeSingletonBool)
AstTypeSingletonBool(const Location& location, bool value);
void visit(AstVisitor* visitor) override;
bool value;
};
class AstTypeSingletonString : public AstType
{
public:
LUAU_RTTI(AstTypeSingletonString)
AstTypeSingletonString(const Location& location, const AstArray<char>& value);
void visit(AstVisitor* visitor) override;
const AstArray<char> value;
};
class AstTypePack : public AstNode
{
public:
@ -1152,6 +1199,18 @@ public:
}
};
class AstTypePackExplicit : public AstTypePack
{
public:
LUAU_RTTI(AstTypePackExplicit)
AstTypePackExplicit(const Location& location, AstTypeList typeList);
void visit(AstVisitor* visitor) override;
AstTypeList typeList;
};
class AstTypePackVariadic : public AstTypePack
{
public:

View file

@ -136,7 +136,10 @@ public:
const Key& key = ItemInterface::getKey(data[i]);
if (!eq(key, empty_key))
*newtable.insert_unsafe(key) = data[i];
{
Item* item = newtable.insert_unsafe(key);
*item = std::move(data[i]);
}
}
LUAU_ASSERT(count == newtable.count);

View file

@ -218,13 +218,14 @@ private:
AstTableIndexer* parseTableIndexerAnnotation();
AstType* parseFunctionTypeAnnotation();
AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack);
AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstName> generics, AstArray<AstName> genericPacks,
AstArray<AstType*>& params, AstArray<std::optional<AstArgumentName>>& paramNames, AstTypePack* varargAnnotation);
AstType* parseTableTypeAnnotation();
AstType* parseSimpleTypeAnnotation();
AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack);
AstTypeOrPack parseTypeOrPackAnnotation();
AstType* parseTypeAnnotation(TempVector<AstType*>& parts, const Location& begin);
AstType* parseTypeAnnotation();
@ -281,11 +282,11 @@ private:
// `<' namelist `>'
std::pair<AstArray<AstName>, AstArray<AstName>> parseGenericTypeList();
std::pair<AstArray<AstName>, AstArray<AstName>> parseGenericTypeListIfFFlagParseGenericFunctions();
// `<' typeAnnotation[, ...] `>'
AstArray<AstType*> parseTypeParams();
AstArray<AstTypeOrPack> parseTypeParams();
std::optional<AstArray<char>> parseCharArray();
AstExpr* parseString();
AstLocal* pushLocal(const Binding& binding);
@ -413,6 +414,7 @@ private:
std::vector<AstLocal*> scratchLocal;
std::vector<AstTableProp> scratchTableTypeProps;
std::vector<AstType*> scratchAnnotation;
std::vector<AstTypeOrPack> scratchTypeOrPackAnnotation;
std::vector<AstDeclaredClassProp> scratchDeclaredClassProps;
std::vector<AstExprTable::Item> scratchItem;
std::vector<AstArgumentName> scratchArgName;

View file

@ -34,4 +34,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs);
size_t hashRange(const char* data, size_t size);
std::string escape(std::string_view s);
bool isIdentifier(std::string_view s);
} // namespace Luau

View file

@ -0,0 +1,223 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Common.h"
#include <vector>
#include <stdint.h>
LUAU_FASTFLAG(DebugLuauTimeTracing)
#if defined(LUAU_ENABLE_TIME_TRACE)
namespace Luau
{
namespace TimeTrace
{
uint32_t getClockMicroseconds();
struct Token
{
const char* name;
const char* category;
};
enum class EventType : uint8_t
{
Enter,
Leave,
ArgName,
ArgValue,
};
struct Event
{
EventType type;
uint16_t token;
union
{
uint32_t microsec; // 1 hour trace limit
uint32_t dataPos;
} data;
};
struct GlobalContext;
struct ThreadContext;
GlobalContext& getGlobalContext();
uint16_t createToken(GlobalContext& context, const char* name, const char* category);
uint32_t createThread(GlobalContext& context, ThreadContext* threadContext);
void releaseThread(GlobalContext& context, ThreadContext* threadContext);
void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Event>& events, const std::vector<char>& data);
struct ThreadContext
{
ThreadContext()
: globalContext(getGlobalContext())
{
threadId = createThread(globalContext, this);
}
~ThreadContext()
{
if (!events.empty())
flushEvents();
releaseThread(globalContext, this);
}
void flushEvents()
{
static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace");
events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}});
TimeTrace::flushEvents(globalContext, threadId, events, data);
events.clear();
data.clear();
events.push_back({EventType::Leave, 0, {getClockMicroseconds()}});
}
void eventEnter(uint16_t token)
{
eventEnter(token, getClockMicroseconds());
}
void eventEnter(uint16_t token, uint32_t microsec)
{
events.push_back({EventType::Enter, token, {microsec}});
}
void eventLeave()
{
eventLeave(getClockMicroseconds());
}
void eventLeave(uint32_t microsec)
{
events.push_back({EventType::Leave, 0, {microsec}});
if (events.size() > kEventFlushLimit)
flushEvents();
}
void eventArgument(const char* name, const char* value)
{
uint32_t pos = uint32_t(data.size());
data.insert(data.end(), name, name + strlen(name) + 1);
events.push_back({EventType::ArgName, 0, {pos}});
pos = uint32_t(data.size());
data.insert(data.end(), value, value + strlen(value) + 1);
events.push_back({EventType::ArgValue, 0, {pos}});
}
GlobalContext& globalContext;
uint32_t threadId;
std::vector<Event> events;
std::vector<char> data;
static constexpr size_t kEventFlushLimit = 8192;
};
ThreadContext& getThreadContext();
struct Scope
{
explicit Scope(ThreadContext& context, uint16_t token)
: context(context)
{
if (!FFlag::DebugLuauTimeTracing)
return;
context.eventEnter(token);
}
~Scope()
{
if (!FFlag::DebugLuauTimeTracing)
return;
context.eventLeave();
}
ThreadContext& context;
};
struct OptionalTailScope
{
explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold)
: context(context)
, token(token)
, threshold(threshold)
{
if (!FFlag::DebugLuauTimeTracing)
return;
pos = uint32_t(context.events.size());
microsec = getClockMicroseconds();
}
~OptionalTailScope()
{
if (!FFlag::DebugLuauTimeTracing)
return;
if (pos == context.events.size())
{
uint32_t curr = getClockMicroseconds();
if (curr - microsec > threshold)
{
context.eventEnter(token, microsec);
context.eventLeave(curr);
}
}
}
ThreadContext& context;
uint16_t token;
uint32_t threshold;
uint32_t microsec;
uint32_t pos;
};
LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeData(const char* name, const char* category);
} // namespace TimeTrace
} // namespace Luau
// Regular scope
#define LUAU_TIMETRACE_SCOPE(name, category) \
static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first)
// A scope without nested scopes that may be skipped if the time it took is less than the threshold
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec)
// Extra key/value data can be added to regular scopes
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
if (FFlag::DebugLuauTimeTracing) \
lttScopeStatic.second.eventArgument(name, value); \
} while (false)
#else
#define LUAU_TIMETRACE_SCOPE(name, category)
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec)
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
} while (false)
#endif

View file

@ -641,10 +641,12 @@ void AstStatLocalFunction::visit(AstVisitor* visitor)
func->visit(visitor);
}
AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, AstType* type, bool exported)
AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics,
const AstArray<AstName>& genericPacks, AstType* type, bool exported)
: AstStat(ClassIndex(), location)
, name(name)
, generics(generics)
, genericPacks(genericPacks)
, type(type)
, exported(exported)
{
@ -729,12 +731,14 @@ void AstStatError::visit(AstVisitor* visitor)
}
}
AstTypeReference::AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, const AstArray<AstType*>& generics)
AstTypeReference::AstTypeReference(
const Location& location, std::optional<AstName> prefix, AstName name, bool hasParameterList, const AstArray<AstTypeOrPack>& parameters)
: AstType(ClassIndex(), location)
, hasPrefix(bool(prefix))
, hasParameterList(hasParameterList)
, prefix(prefix ? *prefix : AstName())
, name(name)
, generics(generics)
, parameters(parameters)
{
}
@ -742,8 +746,13 @@ void AstTypeReference::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
{
for (AstType* generic : generics)
generic->visit(visitor);
for (const AstTypeOrPack& param : parameters)
{
if (param.type)
param.type->visit(visitor);
else
param.typePack->visit(visitor);
}
}
}
@ -832,6 +841,28 @@ void AstTypeIntersection::visit(AstVisitor* visitor)
}
}
AstTypeSingletonBool::AstTypeSingletonBool(const Location& location, bool value)
: AstType(ClassIndex(), location)
, value(value)
{
}
void AstTypeSingletonBool::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
AstTypeSingletonString::AstTypeSingletonString(const Location& location, const AstArray<char>& value)
: AstType(ClassIndex(), location)
, value(value)
{
}
void AstTypeSingletonString::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
AstTypeError::AstTypeError(const Location& location, const AstArray<AstType*>& types, bool isMissing, unsigned messageIndex)
: AstType(ClassIndex(), location)
, types(types)
@ -849,6 +880,24 @@ void AstTypeError::visit(AstVisitor* visitor)
}
}
AstTypePackExplicit::AstTypePackExplicit(const Location& location, AstTypeList typeList)
: AstTypePack(ClassIndex(), location)
, typeList(typeList)
{
}
void AstTypePackExplicit::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
{
for (AstType* type : typeList.types)
type->visit(visitor);
if (typeList.tailType)
typeList.tailType->visit(visitor);
}
}
AstTypePackVariadic::AstTypePackVariadic(const Location& location, AstType* variadicType)
: AstTypePack(ClassIndex(), location)
, variadicType(variadicType)

View file

@ -1,6 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "Luau/TimeTrace.h"
#include <algorithm>
// Warning: If you are introducing new syntax, ensure that it is behind a separate
@ -8,11 +10,14 @@
// See docs/SyntaxChanges.md for an explanation.
LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsParserFix, false)
LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false)
LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false)
LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false)
LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false)
LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false)
LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false)
LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false)
namespace Luau
{
@ -148,6 +153,8 @@ static bool shouldParseTypePackAnnotation(Lexer& lexer)
ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options)
{
LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser");
Parser p(buffer, bufferSize, names, allocator);
try
@ -769,14 +776,14 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
if (!name)
name = Name(nameError, lexer.current().location);
// TODO: support generic type pack parameters in type aliases CLI-39907
auto [generics, genericPacks] = parseGenericTypeList();
expectAndConsume('=', "type alias");
AstType* type = parseTypeAnnotation();
return allocator.alloc<AstStatTypeAlias>(Location(start, type->location), name->name, generics, type, exported);
return allocator.alloc<AstStatTypeAlias>(
Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray<AstName>{}, type, exported);
}
AstDeclaredClassProp Parser::parseDeclaredClassMethod()
@ -951,7 +958,7 @@ AstStat* Parser::parseAssignment(AstExpr* initial)
{
nextLexeme();
AstExpr* expr = parsePrimaryExpr(/* asStatement= */ false);
AstExpr* expr = parsePrimaryExpr(/* asStatement= */ FFlag::LuauFixAmbiguousErrorRecoveryInAssign);
if (!isExprLValue(expr))
expr = reportExprError(expr->location, copy({expr}), "Assigned expression must be a variable or a field");
@ -989,7 +996,7 @@ std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
{
Location start = matchFunction.location;
auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions();
auto [generics, genericPacks] = parseGenericTypeList();
Lexeme matchParen = lexer.current();
expectAndConsume('(', "function");
@ -1272,7 +1279,27 @@ AstType* Parser::parseTableTypeAnnotation()
while (lexer.current().type != '}')
{
if (lexer.current().type == '[')
if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' &&
(lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString))
{
const Lexeme begin = lexer.current();
nextLexeme(); // [
std::optional<AstArray<char>> chars = parseCharArray();
expectMatchAndConsume(']', begin);
expectAndConsume(':', "table field");
AstType* type = parseTypeAnnotation();
// TODO: since AstName conains a char*, it can't contain null
bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);
if (chars && !containsNull)
props.push_back({AstName(chars->data), begin.location, type});
else
report(begin.location, "String literal contains malformed escape sequence");
}
else if (lexer.current().type == '[')
{
if (indexer)
{
@ -1333,23 +1360,22 @@ AstType* Parser::parseTableTypeAnnotation()
// ReturnType ::= TypeAnnotation | `(' TypeList `)'
// FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstType* Parser::parseFunctionTypeAnnotation()
AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
{
incrementRecursionCounter("type annotation");
bool monomorphic = !(FFlag::LuauParseGenericFunctions && lexer.current().type == '<');
auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions();
bool monomorphic = lexer.current().type != '<';
Lexeme begin = lexer.current();
if (FFlag::LuauGenericFunctionsParserFix)
expectAndConsume('(', "function parameters");
else
{
LUAU_ASSERT(begin.type == '(');
nextLexeme(); // (
}
auto [generics, genericPacks] = parseGenericTypeList();
Lexeme parameterStart = lexer.current();
if (!FFlag::LuauParseGenericFunctionTypeBegin)
begin = parameterStart;
expectAndConsume('(', "function parameters");
matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++;
@ -1360,18 +1386,27 @@ AstType* Parser::parseFunctionTypeAnnotation()
if (lexer.current().type != ')')
varargAnnotation = parseTypeList(params, names);
expectMatchAndConsume(')', begin, true);
expectMatchAndConsume(')', parameterStart, true);
matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--;
// Not a function at all. Just a parenthesized type.
if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow)
return params[0];
AstArray<AstType*> paramTypes = copy(params);
// Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element
if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow)
{
if (allowPack)
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, nullptr})};
else
return {params[0], {}};
}
if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack)
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, varargAnnotation})};
AstArray<std::optional<AstArgumentName>> paramNames = copy(names);
return parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation);
return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
}
AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstName> generics, AstArray<AstName> genericPacks,
@ -1421,7 +1456,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
if (c == '|')
{
nextLexeme();
parts.push_back(parseSimpleTypeAnnotation());
parts.push_back(parseSimpleTypeAnnotation(false).type);
isUnion = true;
}
else if (c == '?')
@ -1434,7 +1469,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
else if (c == '&')
{
nextLexeme();
parts.push_back(parseSimpleTypeAnnotation());
parts.push_back(parseSimpleTypeAnnotation(false).type);
isIntersection = true;
}
else
@ -1462,6 +1497,30 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
ParseError::raise(begin, "Composite type was not an intersection or union.");
}
AstTypeOrPack Parser::parseTypeOrPackAnnotation()
{
unsigned int oldRecursionCount = recursionCounter;
incrementRecursionCounter("type annotation");
Location begin = lexer.current().location;
TempVector<AstType*> parts(scratchAnnotation);
auto [type, typePack] = parseSimpleTypeAnnotation(true);
if (typePack)
{
LUAU_ASSERT(!type);
return {{}, typePack};
}
parts.push_back(type);
recursionCounter = oldRecursionCount;
return {parseTypeAnnotation(parts, begin), {}};
}
AstType* Parser::parseTypeAnnotation()
{
unsigned int oldRecursionCount = recursionCounter;
@ -1470,7 +1529,7 @@ AstType* Parser::parseTypeAnnotation()
Location begin = lexer.current().location;
TempVector<AstType*> parts(scratchAnnotation);
parts.push_back(parseSimpleTypeAnnotation());
parts.push_back(parseSimpleTypeAnnotation(false).type);
recursionCounter = oldRecursionCount;
@ -1479,7 +1538,7 @@ AstType* Parser::parseTypeAnnotation()
// typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}'
// | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstType* Parser::parseSimpleTypeAnnotation()
AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
{
incrementRecursionCounter("type annotation");
@ -1488,7 +1547,33 @@ AstType* Parser::parseSimpleTypeAnnotation()
if (lexer.current().type == Lexeme::ReservedNil)
{
nextLexeme();
return allocator.alloc<AstTypeReference>(begin, std::nullopt, nameNil);
return {allocator.alloc<AstTypeReference>(begin, std::nullopt, nameNil), {}};
}
else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue)
{
nextLexeme();
return {allocator.alloc<AstTypeSingletonBool>(begin, true)};
}
else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse)
{
nextLexeme();
return {allocator.alloc<AstTypeSingletonBool>(begin, false)};
}
else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString))
{
if (std::optional<AstArray<char>> value = parseCharArray())
{
AstArray<char> svalue = *value;
return {allocator.alloc<AstTypeSingletonString>(begin, svalue)};
}
else
return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")};
}
else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString)
{
Location location = lexer.current().location;
nextLexeme();
return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, "Malformed string")};
}
else if (lexer.current().type == Lexeme::Name)
{
@ -1514,22 +1599,41 @@ AstType* Parser::parseSimpleTypeAnnotation()
expectMatchAndConsume(')', typeofBegin);
return allocator.alloc<AstTypeTypeof>(Location(begin, end), expr);
return {allocator.alloc<AstTypeTypeof>(Location(begin, end), expr), {}};
}
AstArray<AstType*> generics = parseTypeParams();
if (FFlag::LuauParseTypePackTypeParameters)
{
bool hasParameters = false;
AstArray<AstTypeOrPack> parameters{};
Location end = lexer.previousLocation();
if (lexer.current().type == '<')
{
hasParameters = true;
parameters = parseTypeParams();
}
return allocator.alloc<AstTypeReference>(Location(begin, end), prefix, name.name, generics);
Location end = lexer.previousLocation();
return {allocator.alloc<AstTypeReference>(Location(begin, end), prefix, name.name, hasParameters, parameters), {}};
}
else
{
AstArray<AstTypeOrPack> generics = parseTypeParams();
Location end = lexer.previousLocation();
// false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks
return {allocator.alloc<AstTypeReference>(Location(begin, end), prefix, name.name, false, generics), {}};
}
}
else if (lexer.current().type == '{')
{
return parseTableTypeAnnotation();
return {parseTableTypeAnnotation(), {}};
}
else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<'))
else if (lexer.current().type == '(' || lexer.current().type == '<')
{
return parseFunctionTypeAnnotation();
return parseFunctionTypeAnnotation(allowPack);
}
else
{
@ -1538,7 +1642,7 @@ AstType* Parser::parseSimpleTypeAnnotation()
// For a missing type annotation, capture 'space' between last token and the next one
location = Location(lexer.previousLocation().end, lexer.current().location.begin);
return reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str());
return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}};
}
}
@ -2257,19 +2361,6 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou
return Name(nameError, location);
}
std::pair<AstArray<AstName>, AstArray<AstName>> Parser::parseGenericTypeListIfFFlagParseGenericFunctions()
{
if (FFlag::LuauParseGenericFunctions)
return Parser::parseGenericTypeList();
AstArray<AstName> generics;
AstArray<AstName> genericPacks;
generics.size = 0;
generics.data = nullptr;
genericPacks.size = 0;
genericPacks.data = nullptr;
return std::pair(generics, genericPacks);
}
std::pair<AstArray<AstName>, AstArray<AstName>> Parser::parseGenericTypeList()
{
TempVector<AstName> names{scratchName};
@ -2284,7 +2375,7 @@ std::pair<AstArray<AstName>, AstArray<AstName>> Parser::parseGenericTypeList()
while (true)
{
AstName name = parseName().name;
if (FFlag::LuauParseGenericFunctions && lexer.current().type == Lexeme::Dot3)
if (lexer.current().type == Lexeme::Dot3)
{
seenPack = true;
nextLexeme();
@ -2312,9 +2403,9 @@ std::pair<AstArray<AstName>, AstArray<AstName>> Parser::parseGenericTypeList()
return {generics, genericPacks};
}
AstArray<AstType*> Parser::parseTypeParams()
AstArray<AstTypeOrPack> Parser::parseTypeParams()
{
TempVector<AstType*> result{scratchAnnotation};
TempVector<AstTypeOrPack> parameters{scratchTypeOrPackAnnotation};
if (lexer.current().type == '<')
{
@ -2323,7 +2414,43 @@ AstArray<AstType*> Parser::parseTypeParams()
while (true)
{
result.push_back(parseTypeAnnotation());
if (FFlag::LuauParseTypePackTypeParameters)
{
if (shouldParseTypePackAnnotation(lexer))
{
auto typePack = parseTypePackAnnotation();
if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them
parameters.push_back({{}, typePack});
}
else if (lexer.current().type == '(')
{
auto [type, typePack] = parseTypeOrPackAnnotation();
if (typePack)
{
if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them
parameters.push_back({{}, typePack});
}
else
{
parameters.push_back({type, {}});
}
}
else if (lexer.current().type == '>' && parameters.empty())
{
break;
}
else
{
parameters.push_back({parseTypeAnnotation(), {}});
}
}
else
{
parameters.push_back({parseTypeAnnotation(), {}});
}
if (lexer.current().type == ',')
nextLexeme();
else
@ -2333,10 +2460,10 @@ AstArray<AstType*> Parser::parseTypeParams()
expectMatchAndConsume('>', begin);
}
return copy(result);
return copy(parameters);
}
AstExpr* Parser::parseString()
std::optional<AstArray<char>> Parser::parseCharArray()
{
LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString);
@ -2346,11 +2473,8 @@ AstExpr* Parser::parseString()
{
if (!Lexer::fixupQuotedString(scratchData))
{
Location location = lexer.current().location;
nextLexeme();
return reportExprError(location, {}, "String literal contains malformed escape sequence");
return std::nullopt;
}
}
else
@ -2358,12 +2482,18 @@ AstExpr* Parser::parseString()
Lexer::fixupMultilineString(scratchData);
}
Location start = lexer.current().location;
AstArray<char> value = copy(scratchData);
nextLexeme();
return value;
}
return allocator.alloc<AstExprConstantString>(start, value);
AstExpr* Parser::parseString()
{
Location location = lexer.current().location;
if (std::optional<AstArray<char>> value = parseCharArray())
return allocator.alloc<AstExprConstantString>(location, *value);
else
return reportExprError(location, {}, "String literal contains malformed escape sequence");
}
AstLocal* Parser::pushLocal(const Binding& binding)

View file

@ -225,4 +225,62 @@ size_t hashRange(const char* data, size_t size)
return hash;
}
bool isIdentifier(std::string_view s)
{
return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos);
}
std::string escape(std::string_view s)
{
std::string r;
r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting
for (uint8_t c : s)
{
if (c >= ' ' && c != '\\' && c != '\'' && c != '\"')
r += c;
else
{
r += '\\';
switch (c)
{
case '\a':
r += 'a';
break;
case '\b':
r += 'b';
break;
case '\f':
r += 'f';
break;
case '\n':
r += 'n';
break;
case '\r':
r += 'r';
break;
case '\t':
r += 't';
break;
case '\v':
r += 'v';
break;
case '\'':
r += '\'';
break;
case '\"':
r += '\"';
break;
case '\\':
r += '\\';
break;
default:
Luau::formatAppend(r, "%03u", c);
}
}
}
return r;
}
} // namespace Luau

251
Ast/src/TimeTrace.cpp Normal file
View file

@ -0,0 +1,251 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TimeTrace.h"
#include "Luau/StringUtils.h"
#include <mutex>
#include <string>
#include <stdlib.h>
#ifdef _WIN32
#include <Windows.h>
#endif
#ifdef __APPLE__
#include <mach/mach.h>
#include <mach/mach_time.h>
#endif
#include <time.h>
LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false)
#if defined(LUAU_ENABLE_TIME_TRACE)
namespace Luau
{
namespace TimeTrace
{
static double getClockPeriod()
{
#if defined(_WIN32)
LARGE_INTEGER result = {};
QueryPerformanceFrequency(&result);
return 1.0 / double(result.QuadPart);
#elif defined(__APPLE__)
mach_timebase_info_data_t result = {};
mach_timebase_info(&result);
return double(result.numer) / double(result.denom) * 1e-9;
#elif defined(__linux__)
return 1e-9;
#else
return 1.0 / double(CLOCKS_PER_SEC);
#endif
}
static double getClockTimestamp()
{
#if defined(_WIN32)
LARGE_INTEGER result = {};
QueryPerformanceCounter(&result);
return double(result.QuadPart);
#elif defined(__APPLE__)
return double(mach_absolute_time());
#elif defined(__linux__)
timespec now;
clock_gettime(CLOCK_MONOTONIC, &now);
return now.tv_sec * 1e9 + now.tv_nsec;
#else
return double(clock());
#endif
}
uint32_t getClockMicroseconds()
{
static double period = getClockPeriod() * 1e6;
static double start = getClockTimestamp();
return uint32_t((getClockTimestamp() - start) * period);
}
struct GlobalContext
{
GlobalContext() = default;
~GlobalContext()
{
// Ideally we would want all ThreadContext destructors to run
// But in VS, not all thread_local object instances are destroyed
for (ThreadContext* context : threads)
{
if (!context->events.empty())
context->flushEvents();
}
if (traceFile)
fclose(traceFile);
}
std::mutex mutex;
std::vector<ThreadContext*> threads;
uint32_t nextThreadId = 0;
std::vector<Token> tokens;
FILE* traceFile = nullptr;
};
GlobalContext& getGlobalContext()
{
static GlobalContext context;
return context;
}
uint16_t createToken(GlobalContext& context, const char* name, const char* category)
{
std::scoped_lock lock(context.mutex);
LUAU_ASSERT(context.tokens.size() < 64 * 1024);
context.tokens.push_back({name, category});
return uint16_t(context.tokens.size() - 1);
}
uint32_t createThread(GlobalContext& context, ThreadContext* threadContext)
{
std::scoped_lock lock(context.mutex);
context.threads.push_back(threadContext);
return ++context.nextThreadId;
}
void releaseThread(GlobalContext& context, ThreadContext* threadContext)
{
std::scoped_lock lock(context.mutex);
if (auto it = std::find(context.threads.begin(), context.threads.end(), threadContext); it != context.threads.end())
context.threads.erase(it);
}
void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Event>& events, const std::vector<char>& data)
{
std::scoped_lock lock(context.mutex);
if (!context.traceFile)
{
context.traceFile = fopen("trace.json", "w");
if (!context.traceFile)
return;
fprintf(context.traceFile, "[\n");
}
std::string temp;
const unsigned tempReserve = 64 * 1024;
temp.reserve(tempReserve);
const char* rawData = data.data();
// Formatting state
bool unfinishedEnter = false;
bool unfinishedArgs = false;
for (const Event& ev : events)
{
switch (ev.type)
{
case EventType::Enter:
{
if (unfinishedArgs)
{
formatAppend(temp, "}");
unfinishedArgs = false;
}
if (unfinishedEnter)
{
formatAppend(temp, "},\n");
unfinishedEnter = false;
}
Token& token = context.tokens[ev.token];
formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category,
ev.data.microsec, threadId);
unfinishedEnter = true;
}
break;
case EventType::Leave:
if (unfinishedArgs)
{
formatAppend(temp, "}");
unfinishedArgs = false;
}
if (unfinishedEnter)
{
formatAppend(temp, "},\n");
unfinishedEnter = false;
}
formatAppend(temp,
R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)"
"\n",
ev.data.microsec, threadId);
break;
case EventType::ArgName:
LUAU_ASSERT(unfinishedEnter);
if (!unfinishedArgs)
{
formatAppend(temp, R"(, "args": { "%s": )", rawData + ev.data.dataPos);
unfinishedArgs = true;
}
else
{
formatAppend(temp, R"(, "%s": )", rawData + ev.data.dataPos);
}
break;
case EventType::ArgValue:
LUAU_ASSERT(unfinishedArgs);
formatAppend(temp, R"("%s")", rawData + ev.data.dataPos);
break;
}
// Don't want to hit the string capacity and reallocate
if (temp.size() > tempReserve - 1024)
{
fwrite(temp.data(), 1, temp.size(), context.traceFile);
temp.clear();
}
}
if (unfinishedArgs)
{
formatAppend(temp, "}");
unfinishedArgs = false;
}
if (unfinishedEnter)
{
formatAppend(temp, "},\n");
unfinishedEnter = false;
}
fwrite(temp.data(), 1, temp.size(), context.traceFile);
fflush(context.traceFile);
}
ThreadContext& getThreadContext()
{
thread_local ThreadContext context;
return context;
}
std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeData(const char* name, const char* category)
{
uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category);
return {token, Luau::TimeTrace::getThreadContext()};
}
} // namespace TimeTrace
} // namespace Luau
#endif

View file

@ -34,8 +34,10 @@ static void report(ReportFormat format, const char* name, const Luau::Location&
}
}
static void reportError(ReportFormat format, const char* name, const Luau::TypeError& error)
static void reportError(ReportFormat format, const Luau::TypeError& error)
{
const char* name = error.moduleName.c_str();
if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data))
report(format, name, error.location, "SyntaxError", syntaxError->message.c_str());
else
@ -49,7 +51,10 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin
static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate)
{
Luau::CheckResult cr = frontend.check(name);
Luau::CheckResult cr;
if (frontend.isDirty(name))
cr = frontend.check(name);
if (!frontend.getSourceModule(name))
{
@ -58,7 +63,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat
}
for (auto& error : cr.errors)
reportError(format, name, error);
reportError(format, error);
Luau::LintResult lr = frontend.lint(name);
@ -111,11 +116,29 @@ struct CliFileResolver : Luau::FileResolver
return Luau::SourceCode{*source, Luau::SourceCode::Module};
}
std::optional<Luau::ModuleInfo> resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override
{
if (Luau::AstExprConstantString* expr = node->as<Luau::AstExprConstantString>())
{
Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau";
if (!moduleExists(name))
{
// fall back to .lua if a module with .luau doesn't exist
name = std::string(expr->value.data, expr->value.size) + ".lua";
}
return {{name}};
}
return std::nullopt;
}
bool moduleExists(const Luau::ModuleName& name) const override
{
return !!readFile(name);
}
std::optional<Luau::ModuleName> fromAstFragment(Luau::AstExpr* expr) const override
{
return std::nullopt;
@ -130,11 +153,6 @@ struct CliFileResolver : Luau::FileResolver
{
return std::nullopt;
}
std::optional<std::string> getEnvironmentForModule(const Luau::ModuleName& name) const override
{
return std::nullopt;
}
};
struct CliConfigResolver : Luau::ConfigResolver
@ -218,25 +236,12 @@ int main(int argc, char** argv)
Luau::registerBuiltinTypes(frontend.typeChecker);
Luau::freeze(frontend.typeChecker.globalTypes);
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
failed += !analyzeFile(frontend, name.c_str(), format, annotate);
});
}
else
{
failed += !analyzeFile(frontend, argv[i], format, annotate);
}
}
for (const std::string& path : files)
failed += !analyzeFile(frontend, path.c_str(), format, annotate);
if (!configResolver.configErrors.empty())
{
@ -248,5 +253,3 @@ int main(int argc, char** argv)
return (format == ReportFormat::Luacheck) ? 0 : failed;
}

View file

@ -142,6 +142,7 @@ static bool traverseDirectoryRec(const std::string& path, const std::function<vo
joinPaths(buf, path.c_str(), data.d_name);
int type = data.d_type;
int mode = -1;
// we need to stat DT_UNKNOWN to be able to tell the type
if (type == DT_UNKNOWN)
@ -153,18 +154,18 @@ static bool traverseDirectoryRec(const std::string& path, const std::function<vo
lstat(buf.c_str(), &st);
#endif
type = IFTODT(st.st_mode);
mode = st.st_mode;
}
if (type == DT_DIR)
if (type == DT_DIR || mode == S_IFDIR)
{
traverseDirectoryRec(buf, callback);
}
else if (type == DT_REG)
else if (type == DT_REG || mode == S_IFREG)
{
callback(buf);
}
else if (type == DT_LNK)
else if (type == DT_LNK || mode == S_IFLNK)
{
// Skip symbolic links to avoid handling cycles
}
@ -222,3 +223,40 @@ std::optional<std::string> getParentPath(const std::string& path)
return "";
}
static std::string getExtension(const std::string& path)
{
std::string::size_type dot = path.find_last_of(".\\/");
if (dot == std::string::npos || path[dot] != '.')
return "";
return path.substr(dot);
}
std::vector<std::string> getSourceFiles(int argc, char** argv)
{
std::vector<std::string> files;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
std::string ext = getExtension(name);
if (ext == ".lua" || ext == ".luau")
files.push_back(name);
});
}
else
{
files.push_back(argv[i]);
}
}
return files;
}

View file

@ -4,6 +4,7 @@
#include <optional>
#include <string>
#include <functional>
#include <vector>
std::optional<std::string> readFile(const std::string& name);
@ -12,3 +13,5 @@ bool traverseDirectory(const std::string& path, const std::function<void(const s
std::string joinPaths(const std::string& lhs, const std::string& rhs);
std::optional<std::string> getParentPath(const std::string& path);
std::vector<std::string> getSourceFiles(int argc, char** argv);

View file

@ -13,6 +13,17 @@
#include <memory>
#ifdef _WIN32
#include <io.h>
#include <fcntl.h>
#endif
enum class CompileFormat
{
Text,
Binary
};
static int lua_loadstring(lua_State* L)
{
size_t l = 0;
@ -22,7 +33,7 @@ static int lua_loadstring(lua_State* L)
lua_setsafeenv(L, LUA_ENVIRONINDEX, false);
std::string bytecode = Luau::compile(std::string(s, l));
if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0)
if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0)
return 1;
lua_pushnil(L);
@ -51,9 +62,13 @@ static int lua_require(lua_State* L)
return finishrequire(L);
lua_pop(L, 1);
std::optional<std::string> source = readFile(name + ".lua");
std::optional<std::string> source = readFile(name + ".luau");
if (!source)
luaL_argerrorL(L, 1, ("error loading " + name).c_str());
{
source = readFile(name + ".lua"); // try .lua if .luau doesn't exist
if (!source)
luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error
}
// module needs to run in a new thread, isolated from the rest
lua_State* GL = lua_mainthread(L);
@ -65,7 +80,7 @@ static int lua_require(lua_State* L)
// now we can compile & run module on the new thread
std::string bytecode = Luau::compile(*source);
if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0)
if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
int status = lua_resume(ML, L, 0);
@ -136,7 +151,7 @@ static std::string runCode(lua_State* L, const std::string& source)
{
std::string bytecode = Luau::compile(source);
if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0)
if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0)
{
size_t len;
const char* msg = lua_tolstring(L, -1, &len);
@ -175,7 +190,7 @@ static std::string runCode(lua_State* L, const std::string& source)
{
error = "thread yielded unexpectedly";
}
else if (const char* str = lua_tostring(L, -1))
else if (const char* str = lua_tostring(T, -1))
{
error = str;
}
@ -183,6 +198,11 @@ static std::string runCode(lua_State* L, const std::string& source)
error += "\nstack backtrace:\n";
error += lua_debugtrace(T);
#ifdef __EMSCRIPTEN__
// nicer formatting for errors in web repl
error = "Error:" + error;
#endif
fprintf(stdout, "%s", error.c_str());
}
@ -190,6 +210,39 @@ static std::string runCode(lua_State* L, const std::string& source)
return std::string();
}
#ifdef __EMSCRIPTEN__
extern "C"
{
const char* executeScript(const char* source)
{
// setup flags
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
if (strncmp(flag->name, "Luau", 4) == 0)
flag->value = true;
// create new state
std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get();
// setup state
setupState(L);
// sandbox thread
luaL_sandboxthread(L);
// static string for caching result (prevents dangling ptr on function exit)
static std::string result;
// run code + collect error
result = runCode(L, source);
return result.empty() ? NULL : result.c_str();
}
}
#endif
// Excluded from emscripten compilation to avoid -Wunused-function errors.
#ifndef __EMSCRIPTEN__
static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector<std::string>& completions)
{
std::string_view lookup = editBuffer + start;
@ -317,7 +370,7 @@ static bool runFile(const char* name, lua_State* GL)
std::string bytecode = Luau::compile(*source);
int status = 0;
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0)
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
status = lua_resume(L, NULL, 0);
}
@ -326,11 +379,7 @@ static bool runFile(const char* name, lua_State* GL)
status = LUA_ERRSYNTAX;
}
if (status == 0)
{
return true;
}
else
if (status != 0)
{
std::string error;
@ -347,8 +396,10 @@ static bool runFile(const char* name, lua_State* GL)
error += lua_debugtrace(L);
fprintf(stderr, "%s", error.c_str());
return false;
}
lua_pop(GL, 1);
return status == 0;
}
static void report(const char* name, const Luau::Location& location, const char* type, const char* message)
@ -366,7 +417,7 @@ static void reportError(const char* name, const Luau::CompileError& error)
report(name, error.getLocation(), "CompileError", error.what());
}
static bool compileFile(const char* name)
static bool compileFile(const char* name, CompileFormat format)
{
std::optional<std::string> source = readFile(name);
if (!source)
@ -378,12 +429,24 @@ static bool compileFile(const char* name)
try
{
Luau::BytecodeBuilder bcb;
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source);
bcb.setDumpSource(*source);
if (format == CompileFormat::Text)
{
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source);
bcb.setDumpSource(*source);
}
Luau::compileOrThrow(bcb, *source);
printf("%s", bcb.dumpEverything().c_str());
switch (format)
{
case CompileFormat::Text:
printf("%s", bcb.dumpEverything().c_str());
break;
case CompileFormat::Binary:
fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout);
break;
}
return true;
}
@ -408,7 +471,7 @@ static void displayHelp(const char* argv0)
printf("\n");
printf("Available modes:\n");
printf(" omitted: compile and run input files one by one\n");
printf(" --compile: compile input files and output resulting bytecode\n");
printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n");
printf("\n");
printf("Available options:\n");
printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n");
@ -440,27 +503,25 @@ int main(int argc, char** argv)
return 0;
}
if (argc >= 2 && strcmp(argv[1], "--compile") == 0)
if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0)
{
CompileFormat format = CompileFormat::Text;
if (strcmp(argv[1], "--compile=binary") == 0)
format = CompileFormat::Binary;
#ifdef _WIN32
if (format == CompileFormat::Binary)
_setmode(_fileno(stdout), _O_BINARY);
#endif
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0;
for (int i = 2; i < argc; ++i)
{
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
failed += !compileFile(name.c_str());
});
}
else
{
failed += !compileFile(argv[i]);
}
}
for (const std::string& path : files)
failed += !compileFile(path.c_str(), format);
return failed;
}
@ -474,33 +535,25 @@ int main(int argc, char** argv)
int profile = 0;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] != '-')
continue;
if (strcmp(argv[i], "--profile") == 0)
profile = 10000; // default to 10 KHz
else if (strncmp(argv[i], "--profile=", 10) == 0)
profile = atoi(argv[i] + 10);
}
if (profile)
profilerStart(L, profile);
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
failed += !runFile(name.c_str(), L);
});
}
else
{
failed += !runFile(argv[i], L);
}
}
for (const std::string& path : files)
failed += !runFile(path.c_str(), L);
if (profile)
{
@ -511,5 +564,5 @@ int main(int argc, char** argv)
return failed;
}
}
#endif

View file

@ -9,6 +9,7 @@ project(Luau LANGUAGES CXX)
option(LUAU_BUILD_CLI "Build CLI" ON)
option(LUAU_BUILD_TESTS "Build tests" ON)
option(LUAU_WERROR "Warnings as errors" OFF)
add_library(Luau.Ast STATIC)
add_library(Luau.Compiler STATIC)
@ -17,17 +18,26 @@ add_library(Luau.VM STATIC)
if(LUAU_BUILD_CLI)
add_executable(Luau.Repl.CLI)
add_executable(Luau.Analyze.CLI)
if(NOT EMSCRIPTEN)
add_executable(Luau.Analyze.CLI)
else()
# add -fexceptions for emscripten to allow exceptions to be caught in C++
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions")
endif()
# This also adds target `name` on Linux/macOS and `name.exe` on Windows
set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau)
set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze)
if(NOT EMSCRIPTEN)
set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze)
endif()
endif()
if(LUAU_BUILD_TESTS)
if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN)
add_executable(Luau.UnitTest)
add_executable(Luau.Conformance)
endif()
include(Sources.cmake)
target_compile_features(Luau.Ast PUBLIC cxx_std_17)
@ -48,14 +58,17 @@ set(LUAU_OPTIONS)
if(MSVC)
list(APPEND LUAU_OPTIONS /D_CRT_SECURE_NO_WARNINGS) # We need to use the portable CRT functions.
list(APPEND LUAU_OPTIONS /WX) # Warnings are errors
list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores
else()
list(APPEND LUAU_OPTIONS -Wall) # All warnings
list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
list(APPEND LUAU_OPTIONS -Wno-unused) # GCC considers variables declared/checked in if() as unused
# Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere
if(LUAU_WERROR)
if(MSVC)
list(APPEND LUAU_OPTIONS /WX) # Warnings are errors
else()
list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors
endif()
endif()
@ -65,19 +78,35 @@ target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS})
if(LUAU_BUILD_CLI)
target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS})
target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS})
if(NOT EMSCRIPTEN)
target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS})
endif()
target_include_directories(Luau.Repl.CLI PRIVATE extern)
target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM)
if(UNIX)
target_link_libraries(Luau.Repl.CLI PRIVATE pthread)
find_library(LIBPTHREAD pthread)
if (LIBPTHREAD)
target_link_libraries(Luau.Repl.CLI PRIVATE pthread)
endif()
endif()
target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis)
if(NOT EMSCRIPTEN)
target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis)
endif()
if(EMSCRIPTEN)
# declare exported functions to emscripten
target_link_options(Luau.Repl.CLI PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -fexceptions)
# custom output directory for wasm + js file
set_target_properties(Luau.Repl.CLI PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/docs/assets/luau)
endif()
endif()
if(LUAU_BUILD_TESTS)
if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN)
target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS})
target_include_directories(Luau.UnitTest PRIVATE extern)
target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler)

View file

@ -467,6 +467,10 @@ enum LuauBuiltinFunction
// vector ctor
LBF_VECTOR,
// bit32.count
LBF_BIT32_COUNTLZ,
LBF_BIT32_COUNTRZ,
};
// Capture type, used in LOP_CAPTURE

View file

@ -13,11 +13,9 @@ class AstNameTable;
class BytecodeBuilder;
class BytecodeEncoder;
// Note: this structure is duplicated in luacode.h, don't forget to change these in sync!
struct CompileOptions
{
// default bytecode version target; can be used to compile code for older clients
int bytecodeVersion = 1;
// 0 - no optimization
// 1 - baseline optimization level that doesn't prevent debuggability
// 2 - includes optimizations that harm debuggability such as inlining
@ -36,6 +34,9 @@ struct CompileOptions
// global builtin to construct vectors; disabled by default
const char* vectorLib = nullptr;
const char* vectorCtor = nullptr;
// null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these
const char** mutableGlobals = nullptr;
};
class CompileError : public std::exception

View file

@ -0,0 +1,39 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <stddef.h>
/* Can be used to reconfigure visibility/exports for public APIs */
#ifndef LUACODE_API
#define LUACODE_API extern
#endif
typedef struct lua_CompileOptions lua_CompileOptions;
struct lua_CompileOptions
{
// 0 - no optimization
// 1 - baseline optimization level that doesn't prevent debuggability
// 2 - includes optimizations that harm debuggability such as inlining
int optimizationLevel; // default=1
// 0 - no debugging support
// 1 - line info & function names only; sufficient for backtraces
// 2 - full debug info with local & upvalue names; necessary for debugger
int debugLevel; // default=1
// 0 - no code coverage support
// 1 - statement coverage
// 2 - statement and expression coverage (verbose)
int coverageLevel; // default=0
// global builtin to construct vectors; disabled by default
const char* vectorLib;
const char* vectorCtor;
// null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these
const char** mutableGlobals;
};
/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */
LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize);

View file

@ -4,15 +4,15 @@
#include "Luau/Parser.h"
#include "Luau/BytecodeBuilder.h"
#include "Luau/Common.h"
#include "Luau/TimeTrace.h"
#include <algorithm>
#include <bitset>
#include <math.h>
LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false)
LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false)
LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false)
LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport)
LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false)
namespace Luau
{
@ -21,8 +21,6 @@ static const uint32_t kMaxRegisterCount = 255;
static const uint32_t kMaxUpvalueCount = 200;
static const uint32_t kMaxLocalCount = 200;
static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"};
CompileError::CompileError(const Location& location, const std::string& message)
: location(location)
, message(message)
@ -137,6 +135,11 @@ struct Compiler
uint32_t compileFunction(AstExprFunction* func)
{
LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler");
if (func->debugname.value)
LUAU_TIMETRACE_ARGUMENT("name", func->debugname.value);
LUAU_ASSERT(!functions.contains(func));
LUAU_ASSERT(regTop == 0 && stackSize == 0 && localStack.empty() && upvals.empty());
@ -457,7 +460,7 @@ struct Compiler
bool shared = false;
if (FFlag::LuauPreloadClosuresUpval)
if (FFlag::LuauPreloadClosures)
{
// Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure
// objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it
@ -473,18 +476,6 @@ struct Compiler
}
}
}
// Optimization: when closure has no upvalues, instead of allocating it every time we can share closure objects
// (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it is used)
else if (FFlag::LuauPreloadClosures && options.optimizationLevel >= 1 && f->upvals.empty() && !setfenvUsed)
{
int32_t cid = bytecode.addConstantClosure(f->id);
if (cid >= 0 && cid < 32768)
{
bytecode.emitAD(LOP_DUPCLOSURE, target, cid);
return;
}
}
if (!shared)
bytecode.emitAD(LOP_NEWCLOSURE, target, pid);
@ -1271,7 +1262,7 @@ struct Compiler
{
const Global* global = globals.find(expr->name);
return options.optimizationLevel >= 1 && (!global || (!global->written && !global->special));
return options.optimizationLevel >= 1 && (!global || (!global->written && !global->writable));
}
void compileExprIndexName(AstExprIndexName* expr, uint8_t target)
@ -2459,9 +2450,10 @@ struct Compiler
}
else if (node->is<AstStatBreak>())
{
LUAU_ASSERT(!loops.empty());
// before exiting out of the loop, we need to close all local variables that were captured in closures since loop start
// normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here
LUAU_ASSERT(!loops.empty());
closeLocals(loops.back().localOffset);
size_t label = bytecode.emitLabel();
@ -2472,12 +2464,13 @@ struct Compiler
}
else if (AstStatContinue* stat = node->as<AstStatContinue>())
{
LUAU_ASSERT(!loops.empty());
if (loops.back().untilCondition)
validateContinueUntil(stat, loops.back().untilCondition);
// before continuing, we need to close all local variables that were captured in closures since loop start
// normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here
LUAU_ASSERT(!loops.empty());
closeLocals(loops.back().localOffset);
size_t label = bytecode.emitLabel();
@ -2894,6 +2887,11 @@ struct Compiler
break;
case AstExprUnary::Len:
if (arg.type == Constant::Type_String)
{
result.type = Constant::Type_Number;
result.valueNumber = double(arg.valueString.size);
}
break;
default:
@ -3282,8 +3280,7 @@ struct Compiler
bool visit(AstStatLocalFunction* node) override
{
// record local->function association for some optimizations
if (FFlag::LuauPreloadClosuresUpval)
self->locals[node->name].func = node->func;
self->locals[node->name].func = node->func;
return true;
}
@ -3434,7 +3431,7 @@ struct Compiler
struct Global
{
bool special = false;
bool writable = false;
bool written = false;
};
@ -3492,7 +3489,7 @@ struct Compiler
{
Global* g = globals.find(object->name);
return !g || (!g->special && !g->written) ? Builtin{object->name, expr->index} : Builtin();
return !g || (!g->writable && !g->written) ? Builtin{object->name, expr->index} : Builtin();
}
else
{
@ -3623,6 +3620,10 @@ struct Compiler
return LBF_BIT32_RROTATE;
if (builtin.method == "rshift")
return LBF_BIT32_RSHIFT;
if (builtin.method == "countlz" && FFlag::LuauBit32CountBuiltin)
return LBF_BIT32_COUNTLZ;
if (builtin.method == "countrz" && FFlag::LuauBit32CountBuiltin)
return LBF_BIT32_COUNTRZ;
}
if (builtin.object == "string")
@ -3686,16 +3687,18 @@ struct Compiler
void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options)
{
LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler");
Compiler compiler(bytecode, options);
// since access to some global objects may result in values that change over time, we block table imports
for (const char* global : kSpecialGlobals)
{
AstName name = names.get(global);
// since access to some global objects may result in values that change over time, we block imports from non-readonly tables
if (AstName name = names.get("_G"); name.value)
compiler.globals[name].writable = true;
if (name.value)
compiler.globals[name].special = true;
}
if (options.mutableGlobals)
for (const char** ptr = options.mutableGlobals; *ptr; ++ptr)
if (AstName name = names.get(*ptr); name.value)
compiler.globals[name].writable = true;
// this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written
Compiler::AssignmentVisitor assignmentVisitor(&compiler);
@ -3709,7 +3712,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
}
// this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found
if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1)
if (options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value))
{
Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed);
root->visit(&fenvVisitor);
@ -3748,6 +3751,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const
std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder)
{
LUAU_TIMETRACE_SCOPE("compile", "Compiler");
Allocator allocator;
AstNameTable names(allocator);
ParseResult result = Parser::parse(source.c_str(), source.size(), names, allocator, parseOptions);

29
Compiler/src/lcode.cpp Normal file
View file

@ -0,0 +1,29 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "luacode.h"
#include "Luau/Compiler.h"
#include <string.h>
char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize)
{
LUAU_ASSERT(outsize);
Luau::CompileOptions opts;
if (options)
{
static_assert(sizeof(lua_CompileOptions) == sizeof(Luau::CompileOptions), "C and C++ interface must match");
memcpy(static_cast<void*>(&opts), options, sizeof(opts));
}
std::string result = compile(std::string(source, size), opts);
char* copy = static_cast<char*>(malloc(result.size()));
if (!copy)
return nullptr;
memcpy(copy, result.data(), result.size());
*outsize = result.size();
return copy;
}

View file

@ -1,4 +1,5 @@
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
.SUFFIXES:
MAKEFLAGS+=-r -j8
COMMA=,
@ -45,10 +46,19 @@ endif
OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS)
# common flags
CXXFLAGS=-g -Wall -Werror
CXXFLAGS=-g -Wall
LDFLAGS=
CXXFLAGS+=-Wno-unused # temporary, for older gcc versions
# some gcc versions treat var in `if (type var = val)` as unused
# some gcc versions treat variables used in constexpr if blocks as unused
ifeq ($(findstring g++,$(shell $(CXX) --version)),g++)
CXXFLAGS+=-Wno-unused
endif
# enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere
ifneq ($(werror),)
CXXFLAGS+=-Werror
endif
# configuration-specific flags
ifeq ($(config),release)
@ -133,12 +143,11 @@ $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET):
# executable targets for fuzzing
fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET)
$(CXX) $^ $(LDFLAGS) -o $@
fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator
fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator
fuzz-%:
$(CXX) $^ $(LDFLAGS) -o $@
# static library targets
$(AST_TARGET): $(AST_OBJECTS)
$(COMPILER_TARGET): $(COMPILER_OBJECTS)

View file

@ -22,7 +22,7 @@ You can download the binaries from [a recent release](https://github.com/Roblox/
# Building
To build Luau tools or tests yourself, you can use CMake on all platforms, or alternatively make (on Linux/macOS). For example:
To build Luau tools or tests yourself, you can use CMake on all platforms:
```sh
mkdir cmake && cd cmake
@ -31,11 +31,22 @@ cmake --build . --target Luau.Repl.CLI --config RelWithDebInfo
cmake --build . --target Luau.Analyze.CLI --config RelWithDebInfo
```
Alternatively, on Linux/macOS you can use make:
```sh
make config=release luau luau-analyze
```
To integrate Luau into your CMake application projects, at the minimum you'll need to depend on `Luau.Compiler` and `Luau.VM` projects. From there you need to create a new Luau state (using Lua 5.x API such as `lua_newstate`), compile source to bytecode and load it into the VM like this:
```cpp
std::string bytecode = Luau::compile(source); // needs Luau/Compiler.h include
if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0)
// needs lua.h and luacode.h
size_t bytecodeSize = 0;
char* bytecode = luau_compile(source, strlen(source), NULL, &bytecodeSize);
int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0);
free(bytecode);
if (result == 0)
return 1; /* return chunk main function */
```

14
SECURITY.md Normal file
View file

@ -0,0 +1,14 @@
# Security Guarantees
Luau provides a safe sandbox that scripts can not escape from, short of vulnerabilities in custom C functions exposed by the host. This includes the virtual machine and builtin libraries.
Any source code can not result in memory safety errors or crashes during its compilation or execution. Violations of memory safety are considered vulnerabilities.
Note that Luau does not provide termination guarantees - some code may exhaust CPU or RAM resources on the system during compilation or execution.
The runtime expects valid bytecode as an input. Feeding bytecode that was not produced by Luau compiler into the VM is not supported and
doesn't come with any security guarantees; make sure to sign and/or encrypt the bytecode when it crosses a network or file system boundary to avoid tampering.
# Reporting a Vulnerability
You can report security bugs via [Hackerone](https://hackerone.com/roblox). Please refer to the linked page for rules of the bounty program.

View file

@ -9,6 +9,7 @@ target_sources(Luau.Ast PRIVATE
Ast/include/Luau/ParseOptions.h
Ast/include/Luau/Parser.h
Ast/include/Luau/StringUtils.h
Ast/include/Luau/TimeTrace.h
Ast/src/Ast.cpp
Ast/src/Confusables.cpp
@ -16,6 +17,7 @@ target_sources(Luau.Ast PRIVATE
Ast/src/Location.cpp
Ast/src/Parser.cpp
Ast/src/StringUtils.cpp
Ast/src/TimeTrace.cpp
)
# Luau.Compiler Sources
@ -23,9 +25,11 @@ target_sources(Luau.Compiler PRIVATE
Compiler/include/Luau/Bytecode.h
Compiler/include/Luau/BytecodeBuilder.h
Compiler/include/Luau/Compiler.h
Compiler/include/luacode.h
Compiler/src/BytecodeBuilder.cpp
Compiler/src/Compiler.cpp
Compiler/src/lcode.cpp
)
# Luau.Analysis Sources
@ -44,8 +48,10 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Module.h
Analysis/include/Luau/ModuleResolver.h
Analysis/include/Luau/Predicate.h
Analysis/include/Luau/Quantify.h
Analysis/include/Luau/RecursionCounter.h
Analysis/include/Luau/RequireTracer.h
Analysis/include/Luau/Scope.h
Analysis/include/Luau/Substitution.h
Analysis/include/Luau/Symbol.h
Analysis/include/Luau/TopoSortStatements.h
@ -60,6 +66,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/TypeVar.h
Analysis/include/Luau/Unifiable.h
Analysis/include/Luau/Unifier.h
Analysis/include/Luau/UnifierSharedState.h
Analysis/include/Luau/Variant.h
Analysis/include/Luau/VisitTypeVar.h
@ -74,7 +81,9 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Linter.cpp
Analysis/src/Module.cpp
Analysis/src/Predicate.cpp
Analysis/src/Quantify.cpp
Analysis/src/RequireTracer.cpp
Analysis/src/Scope.cpp
Analysis/src/Substitution.cpp
Analysis/src/Symbol.cpp
Analysis/src/TopoSortStatements.cpp
@ -188,6 +197,7 @@ if(TARGET Luau.UnitTest)
tests/TopoSort.test.cpp
tests/ToString.test.cpp
tests/Transpiler.test.cpp
tests/TypeInfer.aliases.test.cpp
tests/TypeInfer.annotations.test.cpp
tests/TypeInfer.builtins.test.cpp
tests/TypeInfer.classes.test.cpp
@ -196,6 +206,7 @@ if(TARGET Luau.UnitTest)
tests/TypeInfer.intersectionTypes.test.cpp
tests/TypeInfer.provisional.test.cpp
tests/TypeInfer.refinements.test.cpp
tests/TypeInfer.singletons.test.cpp
tests/TypeInfer.tables.test.cpp
tests/TypeInfer.test.cpp
tests/TypeInfer.tryUnify.test.cpp

View file

@ -102,6 +102,8 @@ LUA_API lua_State* lua_newstate(lua_Alloc f, void* ud);
LUA_API void lua_close(lua_State* L);
LUA_API lua_State* lua_newthread(lua_State* L);
LUA_API lua_State* lua_mainthread(lua_State* L);
LUA_API void lua_resetthread(lua_State* L);
LUA_API int lua_isthreadreset(lua_State* L);
/*
** basic stack manipulation
@ -162,8 +164,7 @@ LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l);
LUA_API void lua_pushstring(lua_State* L, const char* s);
LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp);
LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...);
LUA_API void lua_pushcfunction(
lua_State* L, lua_CFunction fn, const char* debugname = NULL, int nup = 0, lua_Continuation cont = NULL);
LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont);
LUA_API void lua_pushboolean(lua_State* L, int b);
LUA_API void lua_pushlightuserdata(lua_State* L, void* p);
LUA_API int lua_pushthread(lua_State* L);
@ -178,9 +179,9 @@ LUA_API void lua_rawget(lua_State* L, int idx);
LUA_API void lua_rawgeti(lua_State* L, int idx, int n);
LUA_API void lua_createtable(lua_State* L, int narr, int nrec);
LUA_API void lua_setreadonly(lua_State* L, int idx, bool value);
LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled);
LUA_API int lua_getreadonly(lua_State* L, int idx);
LUA_API void lua_setsafeenv(lua_State* L, int idx, bool value);
LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled);
LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag);
LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*));
@ -200,7 +201,7 @@ LUA_API int lua_setfenv(lua_State* L, int idx);
/*
** `load' and `call' functions (load and run Luau bytecode)
*/
LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env = 0);
LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env);
LUA_API void lua_call(lua_State* L, int nargs, int nresults);
LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc);
@ -213,6 +214,8 @@ LUA_API int lua_resume(lua_State* L, lua_State* from, int narg);
LUA_API int lua_resumeerror(lua_State* L, lua_State* from);
LUA_API int lua_status(lua_State* L);
LUA_API int lua_isyieldable(lua_State* L);
LUA_API void* lua_getthreaddata(lua_State* L);
LUA_API void lua_setthreaddata(lua_State* L, void* data);
/*
** garbage-collection function and options
@ -291,6 +294,8 @@ LUA_API void lua_unref(lua_State* L, int ref);
#define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL)
#define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1)
#define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL)
#define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL)
#define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s))
#define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s))
@ -317,8 +322,8 @@ LUA_API const char* lua_setlocal(lua_State* L, int level, int n);
LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n);
LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n);
LUA_API void lua_singlestep(lua_State* L, bool singlestep);
LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable);
LUA_API void lua_singlestep(lua_State* L, int enabled);
LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled);
/* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */
LUA_API const char* lua_debugtrace(lua_State* L);
@ -346,6 +351,8 @@ struct lua_Debug
* can only be changed when the VM is not running any code */
struct lua_Callbacks
{
void* userdata; /* arbitrary userdata pointer that is never overwritten by Luau */
void (*interrupt)(lua_State* L, int gc); /* gets called at safepoints (loop back edges, call/ret, gc) if set */
void (*panic)(lua_State* L, int errcode); /* gets called when an unprotected error is raised (if longjmp is used) */
@ -357,6 +364,7 @@ struct lua_Callbacks
void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */
void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */
};
typedef struct lua_Callbacks lua_Callbacks;
LUA_API lua_Callbacks* lua_callbacks(lua_State* L);

View file

@ -8,11 +8,12 @@
#define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname)
#define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg)
typedef struct luaL_Reg
struct luaL_Reg
{
const char* name;
lua_CFunction func;
} luaL_Reg;
};
typedef struct luaL_Reg luaL_Reg;
LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l);
LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e);
@ -75,6 +76,7 @@ struct luaL_Buffer
struct TString* storage;
char buffer[LUA_BUFFERSIZE];
};
typedef struct luaL_Buffer luaL_Buffer;
// when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack
// in general, functions expect the mutable string buffer to be placed on top of the stack (top-1)

View file

@ -13,8 +13,6 @@
#include <string.h>
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n"
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
"$URL: www.lua.org $\n";
@ -595,7 +593,7 @@ const char* lua_pushfstringL(lua_State* L, const char* fmt, ...)
return ret;
}
void lua_pushcfunction(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont)
void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont)
{
luaC_checkGC(L);
luaC_checkthreadsleep(L);
@ -700,12 +698,13 @@ void lua_createtable(lua_State* L, int narray, int nrec)
return;
}
void lua_setreadonly(lua_State* L, int objindex, bool value)
void lua_setreadonly(lua_State* L, int objindex, int enabled)
{
const TValue* o = index2adr(L, objindex);
api_check(L, ttistable(o));
Table* t = hvalue(o);
t->readonly = value;
api_check(L, t != hvalue(registry(L)));
t->readonly = bool(enabled);
return;
}
@ -718,12 +717,12 @@ int lua_getreadonly(lua_State* L, int objindex)
return res;
}
void lua_setsafeenv(lua_State* L, int objindex, bool value)
void lua_setsafeenv(lua_State* L, int objindex, int enabled)
{
const TValue* o = index2adr(L, objindex);
api_check(L, ttistable(o));
Table* t = hvalue(o);
t->safeenv = value;
t->safeenv = bool(enabled);
return;
}
@ -989,6 +988,16 @@ int lua_status(lua_State* L)
return L->status;
}
void* lua_getthreaddata(lua_State* L)
{
return L->userdata;
}
void lua_setthreaddata(lua_State* L, void* data)
{
L->userdata = data;
}
/*
** Garbage-collection function
*/
@ -1153,7 +1162,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*))
luaC_checkGC(L);
luaC_checkthreadsleep(L);
Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR);
memcpy(u->data + sz, &dtor, sizeof(dtor));
memcpy(&u->data + sz, &dtor, sizeof(dtor));
setuvalue(L, L->top, u);
api_incr_top(L);
return u->data;

View file

@ -436,8 +436,8 @@ static const luaL_Reg base_funcs[] = {
static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u)
{
lua_pushcfunction(L, u);
lua_pushcfunction(L, f, name, 1);
lua_pushcfunction(L, u, NULL);
lua_pushcclosure(L, f, name, 1);
lua_setfield(L, -2, name);
}
@ -456,10 +456,10 @@ LUALIB_API int luaopen_base(lua_State* L)
auxopen(L, "ipairs", luaB_ipairs, luaB_inext);
auxopen(L, "pairs", luaB_pairs, luaB_next);
lua_pushcfunction(L, luaB_pcally, "pcall", 0, luaB_pcallcont);
lua_pushcclosurek(L, luaB_pcally, "pcall", 0, luaB_pcallcont);
lua_setfield(L, -2, "pcall");
lua_pushcfunction(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont);
lua_pushcclosurek(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont);
lua_setfield(L, -2, "xpcall");
return 1;

View file

@ -2,8 +2,11 @@
// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details
#include "lualib.h"
#include "lcommon.h"
#include "lnumutils.h"
LUAU_FASTFLAGVARIABLE(LuauBit32Count, false)
#define ALLONES ~0u
#define NBITS int(8 * sizeof(unsigned))
@ -177,6 +180,44 @@ static int b_replace(lua_State* L)
return 1;
}
static int b_countlz(lua_State* L)
{
if (!FFlag::LuauBit32Count)
luaL_error(L, "bit32.countlz isn't enabled");
b_uint v = luaL_checkunsigned(L, 1);
b_uint r = NBITS;
for (int i = 0; i < NBITS; ++i)
if (v & (1u << (NBITS - 1 - i)))
{
r = i;
break;
}
lua_pushunsigned(L, r);
return 1;
}
static int b_countrz(lua_State* L)
{
if (!FFlag::LuauBit32Count)
luaL_error(L, "bit32.countrz isn't enabled");
b_uint v = luaL_checkunsigned(L, 1);
b_uint r = NBITS;
for (int i = 0; i < NBITS; ++i)
if (v & (1u << i))
{
r = i;
break;
}
lua_pushunsigned(L, r);
return 1;
}
static const luaL_Reg bitlib[] = {
{"arshift", b_arshift},
{"band", b_and},
@ -190,6 +231,8 @@ static const luaL_Reg bitlib[] = {
{"replace", b_replace},
{"rrotate", b_rrot},
{"rshift", b_rshift},
{"countlz", b_countlz},
{"countrz", b_countrz},
{NULL, NULL},
};

View file

@ -20,8 +20,9 @@
// If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path
// If luauF_* succeeds, it needs to return *all* requested arguments, filling results with nil as appropriate.
// On input, nparams refers to the actual number of arguments (0+), whereas nresults contains LUA_MULTRET for arbitrary returns or 0+ for a
// fixed-length return Because of this, and the fact that "extra" returned values will be ignored, implementations below typically check that nresults
// is <= expected number, which covers the LUA_MULTRET case.
// fixed-length return
// Because of this, and the fact that "extra" returned values will be ignored, implementations below typically check that nresults is <= expected
// number, which covers the LUA_MULTRET case.
static int luauF_assert(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
@ -1030,6 +1031,52 @@ static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, Stk
return -1;
}
static int luauF_countlz(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
{
double a1 = nvalue(arg0);
unsigned n;
luai_num2unsigned(n, a1);
#ifdef _MSC_VER
unsigned long rl;
int r = _BitScanReverse(&rl, n) ? 31 - int(rl) : 32;
#else
int r = n == 0 ? 32 : __builtin_clz(n);
#endif
setnvalue(res, double(r));
return 1;
}
return -1;
}
static int luauF_countrz(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
{
double a1 = nvalue(arg0);
unsigned n;
luai_num2unsigned(n, a1);
#ifdef _MSC_VER
unsigned long rl;
int r = _BitScanForward(&rl, n) ? int(rl) : 32;
#else
int r = n == 0 ? 32 : __builtin_ctz(n);
#endif
setnvalue(res, double(r));
return 1;
}
return -1;
}
luau_FastFunction luauF_table[256] = {
NULL,
luauF_assert,
@ -1096,4 +1143,7 @@ luau_FastFunction luauF_table[256] = {
luauF_tunpack,
luauF_vector,
luauF_countlz,
luauF_countrz,
};

View file

@ -5,7 +5,7 @@
#include "lstate.h"
#include "lvm.h"
LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false)
#define CO_RUN 0 /* running */
#define CO_SUS 1 /* suspended */
@ -17,7 +17,7 @@ LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
static const char* const statnames[] = {"running", "suspended", "normal", "dead"};
static int costatus(lua_State* L, lua_State* co)
static int auxstatus(lua_State* L, lua_State* co)
{
if (co == L)
return CO_RUN;
@ -34,11 +34,11 @@ static int costatus(lua_State* L, lua_State* co)
return CO_SUS; /* initial state */
}
static int luaB_costatus(lua_State* L)
static int costatus(lua_State* L)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
lua_pushstring(L, statnames[costatus(L, co)]);
lua_pushstring(L, statnames[auxstatus(L, co)]);
return 1;
}
@ -47,7 +47,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg)
// error handling for edge cases
if (co->status != LUA_YIELD)
{
int status = costatus(L, co);
int status = auxstatus(L, co);
if (status != CO_SUS)
{
lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]);
@ -115,7 +115,7 @@ static int auxresumecont(lua_State* L, lua_State* co)
}
}
static int luaB_coresumefinish(lua_State* L, int r)
static int coresumefinish(lua_State* L, int r)
{
if (r < 0)
{
@ -131,7 +131,7 @@ static int luaB_coresumefinish(lua_State* L, int r)
}
}
static int luaB_coresumey(lua_State* L)
static int coresumey(lua_State* L)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
@ -141,10 +141,10 @@ static int luaB_coresumey(lua_State* L)
if (r == CO_STATUS_BREAK)
return interruptThread(L, co);
return luaB_coresumefinish(L, r);
return coresumefinish(L, r);
}
static int luaB_coresumecont(lua_State* L, int status)
static int coresumecont(lua_State* L, int status)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
@ -155,10 +155,10 @@ static int luaB_coresumecont(lua_State* L, int status)
int r = auxresumecont(L, co);
return luaB_coresumefinish(L, r);
return coresumefinish(L, r);
}
static int luaB_auxwrapfinish(lua_State* L, int r)
static int auxwrapfinish(lua_State* L, int r)
{
if (r < 0)
{
@ -173,7 +173,7 @@ static int luaB_auxwrapfinish(lua_State* L, int r)
return r;
}
static int luaB_auxwrapy(lua_State* L)
static int auxwrapy(lua_State* L)
{
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
int narg = cast_int(L->top - L->base);
@ -182,10 +182,10 @@ static int luaB_auxwrapy(lua_State* L)
if (r == CO_STATUS_BREAK)
return interruptThread(L, co);
return luaB_auxwrapfinish(L, r);
return auxwrapfinish(L, r);
}
static int luaB_auxwrapcont(lua_State* L, int status)
static int auxwrapcont(lua_State* L, int status)
{
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
@ -195,62 +195,80 @@ static int luaB_auxwrapcont(lua_State* L, int status)
int r = auxresumecont(L, co);
return luaB_auxwrapfinish(L, r);
return auxwrapfinish(L, r);
}
static int luaB_cocreate(lua_State* L)
static int cocreate(lua_State* L)
{
luaL_checktype(L, 1, LUA_TFUNCTION);
lua_State* NL = lua_newthread(L);
if (FFlag::LuauPreferXpush)
{
lua_xpush(L, NL, 1); // push function on top of NL
}
else
{
lua_pushvalue(L, 1); /* move function to top */
lua_xmove(L, NL, 1); /* move function from L to NL */
}
lua_xpush(L, NL, 1); // push function on top of NL
return 1;
}
static int luaB_cowrap(lua_State* L)
static int cowrap(lua_State* L)
{
luaB_cocreate(L);
lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont);
cocreate(L);
lua_pushcclosurek(L, auxwrapy, NULL, 1, auxwrapcont);
return 1;
}
static int luaB_yield(lua_State* L)
static int coyield(lua_State* L)
{
int nres = cast_int(L->top - L->base);
return lua_yield(L, nres);
}
static int luaB_corunning(lua_State* L)
static int corunning(lua_State* L)
{
if (lua_pushthread(L))
lua_pushnil(L); /* main thread is not a coroutine */
return 1;
}
static int luaB_yieldable(lua_State* L)
static int coyieldable(lua_State* L)
{
lua_pushboolean(L, lua_isyieldable(L));
return 1;
}
static int coclose(lua_State* L)
{
if (!FFlag::LuauCoroutineClose)
luaL_error(L, "coroutine.close is not enabled");
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
int status = auxstatus(L, co);
if (status != CO_DEAD && status != CO_SUS)
luaL_error(L, "cannot close %s coroutine", statnames[status]);
if (co->status == LUA_OK || co->status == LUA_YIELD)
{
lua_pushboolean(L, true);
lua_resetthread(co);
return 1;
}
else
{
lua_pushboolean(L, false);
if (lua_gettop(co))
lua_xmove(co, L, 1); /* move error message */
lua_resetthread(co);
return 2;
}
}
static const luaL_Reg co_funcs[] = {
{"create", luaB_cocreate},
{"running", luaB_corunning},
{"status", luaB_costatus},
{"wrap", luaB_cowrap},
{"yield", luaB_yield},
{"isyieldable", luaB_yieldable},
{"create", cocreate},
{"running", corunning},
{"status", costatus},
{"wrap", cowrap},
{"yield", coyield},
{"isyieldable", coyieldable},
{"close", coclose},
{NULL, NULL},
};
@ -258,7 +276,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L)
{
luaL_register(L, LUA_COLIBNAME, co_funcs);
lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont);
lua_pushcclosurek(L, coresumey, "resume", 0, coresumecont);
lua_setfield(L, -2, "resume");
return 1;

View file

@ -316,7 +316,7 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable)
p->debuginsn[j] = LUAU_INSN_OP(p->code[j]);
}
uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->code[i]);
uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->debuginsn[i]);
// patch just the opcode byte, leave arguments alone
p->code[i] &= ~0xff;
@ -357,17 +357,17 @@ int luaG_getline(Proto* p, int pc)
return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc];
}
void lua_singlestep(lua_State* L, bool singlestep)
void lua_singlestep(lua_State* L, int enabled)
{
L->singlestep = singlestep;
L->singlestep = bool(enabled);
}
void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable)
void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled)
{
const TValue* func = luaA_toobject(L, funcindex);
api_check(L, ttisfunction(func) && !clvalue(func)->isC);
luaG_breakpoint(L, clvalue(func)->l.p, line, enable);
luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled));
}
static size_t append(char* buf, size_t bufsize, size_t offset, const char* data)

View file

@ -8,12 +8,18 @@
#include "lmem.h"
#include "lvm.h"
#include <stdexcept>
#if LUA_USE_LONGJMP
#include <setjmp.h>
#include <stdlib.h>
#else
#include <stdexcept>
#endif
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false)
LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false)
LUAU_FASTFLAG(LuauCoroutineClose)
/*
** {======================================================
@ -51,8 +57,8 @@ l_noret luaD_throw(lua_State* L, int errcode)
longjmp(jb->buf, 1);
}
if (L->global->panic)
L->global->panic(L, errcode);
if (L->global->cb.panic)
L->global->cb.panic(L, errcode);
abort();
}
@ -295,7 +301,10 @@ static void resume(lua_State* L, void* ud)
if (L->status == 0)
{
// start coroutine
LUAU_ASSERT(L->ci == L->base_ci && firstArg > L->base);
LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base);
if (FFlag::LuauCoroutineClose && firstArg == L->base)
luaG_runerror(L, "cannot resume dead coroutine");
if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA)
return;
@ -532,6 +541,12 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
status = LUA_ERRERR;
}
if (FFlag::LuauCcallRestoreFix)
{
// Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored.
L->nCcalls = oldnCcalls;
}
// an error occurred, check if we have a protected error callback
if (L->global->cb.debugprotectederror)
{
@ -545,7 +560,10 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
StkId oldtop = restorestack(L, old_top);
luaF_close(L, oldtop); /* close eventual pending closures */
seterrorobj(L, status, oldtop);
L->nCcalls = oldnCcalls;
if (!FFlag::LuauCcallRestoreFix)
{
L->nCcalls = oldnCcalls;
}
L->ci = restoreci(L, old_ci);
L->base = L->ci->base;
restore_stack_limit(L);

View file

@ -12,10 +12,9 @@
#include <string.h>
#include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false)
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false)
LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false)
LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false)
LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false)
LUAU_FASTFLAG(LuauArrayBoundary)
#define GC_SWEEPMAX 40
@ -64,13 +63,18 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds,
g->gcstats.currcycle.marktime += seconds;
// atomic step had to be performed during the switch and it's tracked separately
if (g->gcstate == GCSsweepstring)
if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring)
g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime;
break;
case GCSatomic:
g->gcstats.currcycle.atomictime += seconds;
break;
case GCSsweepstring:
case GCSsweep:
g->gcstats.currcycle.sweeptime += seconds;
break;
default:
LUAU_ASSERT(!"Unexpected GC state");
}
if (assist)
@ -181,33 +185,15 @@ static int traversetable(global_State* g, Table* h)
if (h->metatable)
markobject(g, cast_to(Table*, h->metatable));
if (FFlag::LuauShrinkWeakTables)
/* is there a weak mode? */
if (const char* modev = gettablemode(g, h))
{
/* is there a weak mode? */
if (const char* modev = gettablemode(g, h))
{
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
}
}
else
{
const TValue* mode = gfasttm(g, h->metatable, TM_MODE);
if (mode && ttisstring(mode))
{ /* is there a weak mode? */
const char* modev = svalue(mode);
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
}
@ -295,7 +281,7 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack)
for (StkId o = l->stack; o < l->top; o++)
markvalue(g, o);
/* final traversal? */
if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack))
if (g->gcstate == GCSatomic || clearstack)
{
StkId stack_end = l->stack + l->stacksize;
for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */
@ -334,28 +320,16 @@ static size_t propagatemark(global_State* g)
lua_State* th = gco2th(o);
g->gray = th->gclist;
if (FFlag::LuauGcFullSkipInactiveThreads)
LUAU_ASSERT(!luaC_threadsleeping(th));
// threads that are executing and the main thread are not deactivated
bool active = luaC_threadactive(th) || th == th->global->mainthread;
if (!active && g->gcstate == GCSpropagate)
{
LUAU_ASSERT(!luaC_threadsleeping(th));
traversestack(g, th, /* clearstack= */ true);
// threads that are executing and the main thread are not deactivated
bool active = luaC_threadactive(th) || th == th->global->mainthread;
if (!active && g->gcstate == GCSpropagate)
{
traversestack(g, th, /* clearstack= */ true);
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
}
else
{
th->gclist = g->grayagain;
g->grayagain = o;
black2gray(o);
traversestack(g, th, /* clearstack= */ false);
}
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
}
else
{
@ -383,12 +357,14 @@ static size_t propagatemark(global_State* g)
}
}
static void propagateall(global_State* g)
static size_t propagateall(global_State* g)
{
size_t work = 0;
while (g->gray)
{
propagatemark(g);
work += propagatemark(g);
}
return work;
}
/*
@ -413,11 +389,14 @@ static int isobjcleared(GCObject* o)
/*
** clear collected entries from weaktables
*/
static void cleartable(lua_State* L, GCObject* l)
static size_t cleartable(lua_State* L, GCObject* l)
{
size_t work = 0;
while (l)
{
Table* h = gco2h(l);
work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h);
int i = h->sizearray;
while (i--)
{
@ -431,50 +410,36 @@ static void cleartable(lua_State* L, GCObject* l)
{
LuaNode* n = gnode(h, i);
if (FFlag::LuauShrinkWeakTables)
// non-empty entry?
if (!ttisnil(gval(n)))
{
// non-empty entry?
if (!ttisnil(gval(n)))
{
// can we clear key or value?
if (iscleared(gkey(n)) || iscleared(gval(n)))
{
setnilvalue(gval(n)); /* remove value ... */
removeentry(n); /* remove entry from table */
}
else
{
activevalues++;
}
}
}
else
{
if (!ttisnil(gval(n)) && /* non-empty entry? */
(iscleared(gkey(n)) || iscleared(gval(n))))
// can we clear key or value?
if (iscleared(gkey(n)) || iscleared(gval(n)))
{
setnilvalue(gval(n)); /* remove value ... */
removeentry(n); /* remove entry from table */
}
else
{
activevalues++;
}
}
}
if (FFlag::LuauShrinkWeakTables)
if (const char* modev = gettablemode(L->global, h))
{
if (const char* modev = gettablemode(L->global, h))
// are we allowed to shrink this weak table?
if (strchr(modev, 's'))
{
// are we allowed to shrink this weak table?
if (strchr(modev, 's'))
{
// shrink at 37.5% occupancy
if (activevalues < sizenode(h) * 3 / 8)
luaH_resizehash(L, h, activevalues);
}
// shrink at 37.5% occupancy
if (activevalues < sizenode(h) * 3 / 8)
luaH_resizehash(L, h, activevalues);
}
}
l = h->gclist;
}
return work;
}
static void shrinkstack(lua_State* L)
@ -653,37 +618,49 @@ static void markroot(lua_State* L)
g->gcstate = GCSpropagate;
}
static void remarkupvals(global_State* g)
static size_t remarkupvals(global_State* g)
{
UpVal* uv;
for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
size_t work = 0;
for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
{
work += sizeof(UpVal);
LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv);
if (isgray(obj2gco(uv)))
markvalue(g, uv->v);
}
return work;
}
static void atomic(lua_State* L)
static size_t atomic(lua_State* L)
{
global_State* g = L->global;
g->gcstate = GCSatomic;
size_t work = 0;
if (FFlag::LuauSeparateAtomic)
{
LUAU_ASSERT(g->gcstate == GCSatomic);
}
else
{
g->gcstate = GCSatomic;
}
/* remark occasional upvalues of (maybe) dead threads */
remarkupvals(g);
work += remarkupvals(g);
/* traverse objects caught by write barrier and by 'remarkupvals' */
propagateall(g);
work += propagateall(g);
/* remark weak tables */
g->gray = g->weak;
g->weak = NULL;
LUAU_ASSERT(!iswhite(obj2gco(g->mainthread)));
markobject(g, L); /* mark running thread */
markmt(g); /* mark basic metatables (again) */
propagateall(g);
work += propagateall(g);
/* remark gray again */
g->gray = g->grayagain;
g->grayagain = NULL;
propagateall(g);
cleartable(L, g->weak); /* remove collected objects from weak tables */
work += propagateall(g);
work += cleartable(L, g->weak); /* remove collected objects from weak tables */
g->weak = NULL;
/* flip current white */
g->currentwhite = cast_byte(otherwhite(g));
@ -691,10 +668,15 @@ static void atomic(lua_State* L)
g->sweepgc = &g->rootgc;
g->gcstate = GCSsweepstring;
GC_INTERRUPT(GCSatomic);
if (!FFlag::LuauSeparateAtomic)
{
GC_INTERRUPT(GCSatomic);
}
return work;
}
static size_t singlestep(lua_State* L)
static size_t gcstep(lua_State* L, size_t limit)
{
size_t cost = 0;
global_State* g = L->global;
@ -703,36 +685,44 @@ static size_t singlestep(lua_State* L)
case GCSpause:
{
markroot(L); /* start a new collection */
LUAU_ASSERT(g->gcstate == GCSpropagate);
break;
}
case GCSpropagate:
{
if (FFlag::LuauRescanGrayAgain)
while (g->gray && cost < limit)
{
if (g->gray)
{
g->gcstats.currcycle.markitems++;
g->gcstats.currcycle.markitems++;
cost = propagatemark(g);
cost += propagatemark(g);
}
if (!g->gray)
{
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
g->gcstate = GCSpropagateagain;
}
break;
}
case GCSpropagateagain:
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
if (!g->gray) /* no more `gray' objects */
{
if (FFlag::LuauSeparateAtomic)
{
g->gcstate = GCSatomic;
}
else
{
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
g->gcstate = GCSpropagateagain;
}
}
else
{
if (g->gray)
{
g->gcstats.currcycle.markitems++;
cost = propagatemark(g);
}
else /* no more `gray' objects */
{
double starttimestamp = lua_clock();
@ -740,73 +730,70 @@ static size_t singlestep(lua_State* L)
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
}
break;
}
case GCSpropagateagain:
case GCSatomic:
{
if (g->gray)
{
g->gcstats.currcycle.markitems++;
g->gcstats.currcycle.atomicstarttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
cost = propagatemark(g);
}
else /* no more `gray' objects */
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
cost = atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
break;
}
case GCSsweepstring:
{
size_t traversedcount = 0;
sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount);
while (g->sweepstrgc < g->strt.size && cost < limit)
{
size_t traversedcount = 0;
sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount);
g->gcstats.currcycle.sweepitems += traversedcount;
cost += GC_SWEEPCOST;
}
// nothing more to sweep?
if (g->sweepstrgc >= g->strt.size)
{
// sweep string buffer list and preserve used string count
uint32_t nuse = L->global->strt.nuse;
size_t traversedcount = 0;
sweepwholelist(L, &g->strbufgc, &traversedcount);
L->global->strt.nuse = nuse;
g->gcstats.currcycle.sweepitems += traversedcount;
g->gcstate = GCSsweep; // end sweep-string phase
}
g->gcstats.currcycle.sweepitems += traversedcount;
cost = GC_SWEEPCOST;
break;
}
case GCSsweep:
{
size_t traversedcount = 0;
g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount);
while (*g->sweepgc && cost < limit)
{
size_t traversedcount = 0;
g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount);
g->gcstats.currcycle.sweepitems += traversedcount;
g->gcstats.currcycle.sweepitems += traversedcount;
cost += GC_SWEEPMAX * GC_SWEEPCOST;
}
if (*g->sweepgc == NULL)
{ /* nothing more to sweep? */
shrinkbuffers(L);
g->gcstate = GCSpause; /* end collection */
}
cost = GC_SWEEPMAX * GC_SWEEPCOST;
break;
}
default:
LUAU_ASSERT(0);
LUAU_ASSERT(!"Unexpected GC state");
}
return cost;
}
@ -878,33 +865,15 @@ void luaC_step(lua_State* L, bool assist)
if (g->gcstate == GCSpause)
startGcCycleStats(g);
if (assist)
g->gcstats.currcycle.assistwork += lim;
else
g->gcstats.currcycle.explicitwork += lim;
int lastgcstate = g->gcstate;
double lasttimestamp = lua_clock();
// always perform at least one single step
do
{
lim -= singlestep(L);
size_t work = gcstep(L, lim);
// if we have switched to a different state, capture the duration of last stage
// this way we reduce the number of timer calls we make
if (lastgcstate != g->gcstate)
{
GC_INTERRUPT(lastgcstate);
double now = lua_clock();
recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist);
lasttimestamp = now;
lastgcstate = g->gcstate;
}
} while (lim > 0 && g->gcstate != GCSpause);
if (assist)
g->gcstats.currcycle.assistwork += work;
else
g->gcstats.currcycle.explicitwork += work;
recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist);
@ -931,7 +900,7 @@ void luaC_step(lua_State* L, bool assist)
g->GCthreshold -= debt;
}
GC_INTERRUPT(g->gcstate);
GC_INTERRUPT(lastgcstate);
}
void luaC_fullgc(lua_State* L)
@ -941,7 +910,7 @@ void luaC_fullgc(lua_State* L)
if (g->gcstate == GCSpause)
startGcCycleStats(g);
if (g->gcstate <= GCSpropagateagain)
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
{
/* reset sweep marks to sweep all elements (returning them to white) */
g->sweepstrgc = 0;
@ -952,12 +921,12 @@ void luaC_fullgc(lua_State* L)
g->weak = NULL;
g->gcstate = GCSsweepstring;
}
LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain);
LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep);
/* finish any pending sweep phase */
while (g->gcstate != GCSpause)
{
LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep);
singlestep(L);
gcstep(L, SIZE_MAX);
}
finishGcCycleStats(g);
@ -968,7 +937,7 @@ void luaC_fullgc(lua_State* L)
markroot(L);
while (g->gcstate != GCSpause)
{
singlestep(L);
gcstep(L, SIZE_MAX);
}
/* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */
shrinkbuffersfull(L);
@ -994,14 +963,11 @@ void luaC_fullgc(lua_State* L)
void luaC_barrierupval(lua_State* L, GCObject* v)
{
if (FFlag::LuauGcFullSkipInactiveThreads)
{
global_State* g = L->global;
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
global_State* g = L->global;
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
if (keepinvariant(g))
reallymarkobject(g, v);
}
if (keepinvariant(g))
reallymarkobject(g, v);
}
void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v)
@ -1629,7 +1595,7 @@ int64_t luaC_allocationrate(lua_State* L)
global_State* g = L->global;
const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms
if (g->gcstate <= GCSpropagateagain)
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
{
double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp;

View file

@ -6,8 +6,6 @@
#include "lobject.h"
#include "lstate.h"
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
/*
** Possible states of the Garbage Collector
*/
@ -25,7 +23,7 @@ LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
** still-black objects. The invariant is restored when sweep ends and
** all objects are white again.
*/
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain)
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic)
/*
** some useful bit tricks
@ -147,4 +145,4 @@ LUAI_FUNC void luaC_validate(lua_State* L);
LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat));
LUAI_FUNC int64_t luaC_allocationrate(lua_State* L);
LUAI_FUNC void luaC_wakethread(lua_State* L);
LUAI_FUNC const char* luaC_statename(int state);
LUAI_FUNC const char* luaC_statename(int state);

View file

@ -22,7 +22,7 @@ LUALIB_API void luaL_openlibs(lua_State* L)
const luaL_Reg* lib = lualibs;
for (; lib->func; lib++)
{
lua_pushcfunction(L, lib->func);
lua_pushcfunction(L, lib->func, NULL);
lua_pushstring(L, lib->name);
lua_call(L, 1, 0);
}

View file

@ -199,7 +199,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass)
if (page->freeNext >= 0)
{
block = page->data + page->freeNext;
block = &page->data + page->freeNext;
ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize);
page->freeNext -= page->blockSize;

View file

@ -124,6 +124,34 @@ void luaE_freethread(lua_State* L, lua_State* L1)
luaM_free(L, L1, sizeof(lua_State), L1->memcat);
}
void lua_resetthread(lua_State* L)
{
/* close upvalues before clearing anything */
luaF_close(L, L->stack);
/* clear call frames */
CallInfo* ci = L->base_ci;
ci->func = L->stack;
ci->base = ci->func + 1;
ci->top = ci->base + LUA_MINSTACK;
setnilvalue(ci->func);
L->ci = ci;
luaD_reallocCI(L, BASIC_CI_SIZE);
/* clear thread state */
L->status = LUA_OK;
L->base = L->ci->base;
L->top = L->ci->base;
L->nCcalls = L->baseCcalls = 0;
/* clear thread stack */
luaD_reallocstack(L, BASIC_STACK_SIZE);
for (int i = 0; i < L->stacksize; i++)
setnilvalue(L->stack + i);
}
int lua_isthreadreset(lua_State* L)
{
return L->ci == L->base_ci && L->base == L->top && L->status == LUA_OK;
}
lua_State* lua_newstate(lua_Alloc f, void* ud)
{
int i;

View file

@ -226,7 +226,7 @@ void luaS_freeudata(lua_State* L, Udata* u)
void (*dtor)(void*) = nullptr;
if (u->tag == UTAG_IDTOR)
memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor));
memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor));
else if (u->tag)
dtor = L->global->udatagc[u->tag];

View file

@ -8,6 +8,8 @@
#include <string.h>
#include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauStrPackUBCastFix, false)
/* macro to `unsign' a character */
#define uchar(c) ((unsigned char)(c))
@ -746,7 +748,7 @@ static int gmatch(lua_State* L)
luaL_checkstring(L, 2);
lua_settop(L, 2);
lua_pushinteger(L, 0);
lua_pushcfunction(L, gmatch_aux, NULL, 3);
lua_pushcclosure(L, gmatch_aux, NULL, 3);
return 1;
}
@ -1404,10 +1406,20 @@ static int str_pack(lua_State* L)
}
case Kuint:
{ /* unsigned integers */
unsigned long long n = (unsigned long long)luaL_checknumber(L, arg);
if (size < SZINT) /* need overflow check? */
luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow");
packint(&b, n, h.islittle, size, 0);
if (FFlag::LuauStrPackUBCastFix)
{
long long n = (long long)luaL_checknumber(L, arg);
if (size < SZINT) /* need overflow check? */
luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow");
packint(&b, (unsigned long long)n, h.islittle, size, 0);
}
else
{
unsigned long long n = (unsigned long long)luaL_checknumber(L, arg);
if (size < SZINT) /* need overflow check? */
luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow");
packint(&b, n, h.islittle, size, 0);
}
break;
}
case Kfloat:

View file

@ -30,6 +30,7 @@ LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false)
#define MAXBITS 26
#define MAXSIZE (1 << MAXBITS)
static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast in gval2slot is incorrect");
// TKey is bitpacked for memory efficiency so we need to validate bit counts for worst case
static_assert(TKey{{NULL}, 0, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt");
static_assert(TKey{{NULL}, 0, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next");

View file

@ -9,7 +9,6 @@
#define gval(n) (&(n)->val)
#define gnext(n) ((n)->key.next)
static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast below is incorrect");
#define gval2slot(t, v) int(cast_to(LuaNode*, static_cast<const TValue*>(v)) - t->node)
LUAI_FUNC const TValue* luaH_getnum(Table* t, int key);

View file

@ -9,14 +9,6 @@
#include "ldebug.h"
#include "lvm.h"
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry, false)
LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false)
bool lua_telemetry_table_move_oob_src_from = false;
bool lua_telemetry_table_move_oob_src_to = false;
bool lua_telemetry_table_move_oob_dst = false;
static int foreachi(lua_State* L)
{
luaL_checktype(L, 1, LUA_TTABLE);
@ -202,22 +194,6 @@ static int tmove(lua_State* L)
int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */
luaL_checktype(L, tt, LUA_TTABLE);
if (DFFlag::LuauTableMoveTelemetry)
{
int nf = lua_objlen(L, 1);
int nt = lua_objlen(L, tt);
// source index range must be in bounds in source table unless the table is empty (permits 1..#t moves)
if (!(f == 1 || (f >= 1 && f <= nf)))
lua_telemetry_table_move_oob_src_from = true;
if (!(e == nf || (e >= 1 && e <= nf)))
lua_telemetry_table_move_oob_src_to = true;
// destination index must be in bounds in dest table or be exactly at the first empty element (permits concats)
if (!(t == nt + 1 || (t >= 1 && t <= nt + 1)))
lua_telemetry_table_move_oob_dst = true;
}
if (e >= f)
{ /* otherwise, nothing to move */
luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move");
@ -513,9 +489,6 @@ static int tclear(lua_State* L)
static int tfreeze(lua_State* L)
{
if (!FFlag::LuauTableFreeze)
luaG_runerror(L, "table.freeze is disabled");
luaL_checktype(L, 1, LUA_TTABLE);
luaL_argcheck(L, !lua_getreadonly(L, 1), 1, "table is already frozen");
luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable");
@ -528,9 +501,6 @@ static int tfreeze(lua_State* L)
static int tisfrozen(lua_State* L)
{
if (!FFlag::LuauTableFreeze)
luaG_runerror(L, "table.isfrozen is disabled");
luaL_checktype(L, 1, LUA_TTABLE);
lua_pushboolean(L, lua_getreadonly(L, 1));

View file

@ -265,7 +265,7 @@ static int iter_aux(lua_State* L)
static int iter_codes(lua_State* L)
{
luaL_checkstring(L, 1);
lua_pushcfunction(L, iter_aux);
lua_pushcfunction(L, iter_aux, NULL);
lua_pushvalue(L, 1);
lua_pushinteger(L, 0);
return 3;

View file

@ -16,8 +16,6 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauLoopUseSafeenv, false)
// Disable c99-designator to avoid the warning in CGOTO dispatch table
#ifdef __clang__
#if __has_warning("-Wc99-designator")
@ -292,10 +290,6 @@ inline bool luau_skipstep(uint8_t op)
return op == LOP_PREPVARARGS || op == LOP_BREAK;
}
// declared in lbaselib.cpp, needed to support cases when pairs/ipairs have been replaced via setfenv
LUAI_FUNC int luaB_inext(lua_State* L);
LUAI_FUNC int luaB_next(lua_State* L);
template<bool SingleStep>
static void luau_execute(lua_State* L)
{
@ -2223,8 +2217,7 @@ static void luau_execute(lua_State* L)
StkId ra = VM_REG(LUAU_INSN_A(insn));
// fast-path: ipairs/inext
bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_inext;
if (safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0)
if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0)
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
}
@ -2304,8 +2297,7 @@ static void luau_execute(lua_State* L)
StkId ra = VM_REG(LUAU_INSN_A(insn));
// fast-path: pairs/next
bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_next;
if (safeenv && ttistable(ra + 1) && ttisnil(ra + 2))
if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2))
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
}

View file

@ -12,7 +12,32 @@
#include <string.h>
#include <vector>
// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens
template<typename T>
struct TempBuffer
{
lua_State* L;
T* data;
size_t count;
TempBuffer(lua_State* L, size_t count)
: L(L)
, data(luaM_newarray(L, count, T, 0))
, count(count)
{
}
~TempBuffer()
{
luaM_freearray(L, data, count, T, 0);
}
T& operator[](size_t index)
{
LUAU_ASSERT(index < count);
return data[index];
}
};
void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil)
{
@ -67,7 +92,7 @@ static unsigned int readVarInt(const char* data, size_t size, size_t& offset)
return result;
}
static TString* readString(std::vector<TString*>& strings, const char* data, size_t size, size_t& offset)
static TString* readString(TempBuffer<TString*>& strings, const char* data, size_t size, size_t& offset)
{
unsigned int id = readVarInt(data, size, offset);
@ -133,6 +158,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
}
// pause GC for the duration of deserialization - some objects we're creating aren't rooted
// TODO: if an allocation error happens mid-load, we do not unpause GC!
size_t GCthreshold = L->global->GCthreshold;
L->global->GCthreshold = SIZE_MAX;
@ -144,7 +170,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
// string table
unsigned int stringCount = readVarInt(data, size, offset);
std::vector<TString*> strings(stringCount);
TempBuffer<TString*> strings(L, stringCount);
for (unsigned int i = 0; i < stringCount; ++i)
{
@ -156,7 +182,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
// proto table
unsigned int protoCount = readVarInt(data, size, offset);
std::vector<Proto*> protos(protoCount);
TempBuffer<Proto*> protos(L, protoCount);
for (unsigned int i = 0; i < protoCount; ++i)
{
@ -320,6 +346,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
uint32_t mainid = readVarInt(data, size, offset);
Proto* main = protos[mainid];
luaC_checkthreadsleep(L);
Closure* cl = luaF_newLclosure(L, 0, envt, main);
setclvalue(L, L->top, cl);
incr_top(L);

View file

@ -25,8 +25,8 @@ try:
import scipy
from scipy import stats
except ModuleNotFoundError:
print("scipy package is required")
exit(1)
print("Warning: scipy package is not installed, confidence values will not be available")
stats = None
scriptdir = os.path.dirname(os.path.realpath(__file__))
defaultVm = 'luau.exe' if os.name == "nt" else './luau'
@ -200,11 +200,14 @@ def finalizeResult(result):
result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1))
result.unbiasedEst = result.sampleStdDev * result.sampleStdDev
# Two-tailed distribution with 95% conf.
tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1)
if stats:
# Two-tailed distribution with 95% conf.
tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1)
# Compute confidence interval
result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count)
# Compute confidence interval
result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count)
else:
result.sampleConfidenceInterval = result.sampleStdDev
else:
result.sampleStdDev = 0
result.unbiasedEst = 0
@ -377,14 +380,19 @@ def analyzeResult(subdir, main, comparisons):
tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count))
degreesOfFreedom = 2 * main.count - 2
# Two-tailed distribution with 95% conf.
tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom)
if stats:
# Two-tailed distribution with 95% conf.
tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom)
noSignificantDifference = tStat < tCritical
noSignificantDifference = tStat < tCritical
pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom))
else:
noSignificantDifference = None
pValue = -1
pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom))
if noSignificantDifference:
if noSignificantDifference is None:
verdict = ""
elif noSignificantDifference:
verdict = "likely same"
elif main.avg < compare.avg:
verdict = "likely worse"

859
bench/tests/chess.lua Normal file
View file

@ -0,0 +1,859 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
local RANKS = "12345678"
local FILES = "abcdefgh"
local PieceSymbols = "PpRrNnBbQqKk"
local UnicodePieces = {"", "", "", "", "", "", "", "", "", "", "", ""}
local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
--
-- Lua 5.2 Compat
--
if not table.create then
function table.create(n, v)
local result = {}
for i=1,n do result[i] = v end
return result
end
end
if not table.move then
function table.move(a, from, to, start, target)
local dx = start - from
for i=from,to do
target[i+dx] = a[i]
end
end
end
--
-- Utils
--
local function square(s)
return RANKS:find(s:sub(2,2)) * 8 + FILES:find(s:sub(1,1)) - 9
end
local function squareName(n)
local file = n % 8
local rank = (n-file)/8
return FILES:sub(file+1,file+1) .. RANKS:sub(rank+1,rank+1)
end
local function moveName(v )
local from = bit32.extract(v, 6, 6)
local to = bit32.extract(v, 0, 6)
local piece = bit32.extract(v, 20, 4)
local captured = bit32.extract(v, 25, 4)
local move = PieceSymbols:sub(piece,piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to)
if bit32.extract(v,14) == 1 then
if to > from then
return "O-O"
else
return "O-O-O"
end
end
local promote = bit32.extract(v,15,4)
if promote ~= 0 then
move = move .. "=" .. PieceSymbols:sub(promote,promote)
end
return move
end
local function ucimove(m)
local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6))
local promote = bit32.extract(m,15,4)
if promote > 0 then
mm = mm .. PieceSymbols:sub(promote,promote):lower()
end
return mm
end
local _utils = {squareName, moveName}
--
-- Bitboards
--
local Bitboard = {}
function Bitboard:toString()
local out = {}
local src = self.h
for x=7,0,-1 do
table.insert(out, RANKS:sub(x+1,x+1))
table.insert(out, " ")
local bit = bit32.lshift(1,(x%4) * 8)
for x=0,7 do
if bit32.band(src, bit) ~= 0 then
table.insert(out, "x ")
else
table.insert(out, "- ")
end
bit = bit32.lshift(bit, 1)
end
if x == 4 then
src = self.l
end
table.insert(out, "\n")
end
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h)
return table.concat(out)
end
function Bitboard.from(l ,h )
return setmetatable({l=l, h=h}, Bitboard)
end
Bitboard.zero = Bitboard.from(0,0)
Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF)
local Rank1 = Bitboard.from(0x000000FF, 0)
local Rank3 = Bitboard.from(0x00FF0000, 0)
local Rank6 = Bitboard.from(0, 0x0000FF00)
local Rank8 = Bitboard.from(0, 0xFF000000)
local FileA = Bitboard.from(0x01010101, 0x01010101)
local FileB = Bitboard.from(0x02020202, 0x02020202)
local FileC = Bitboard.from(0x04040404, 0x04040404)
local FileD = Bitboard.from(0x08080808, 0x08080808)
local FileE = Bitboard.from(0x10101010, 0x10101010)
local FileF = Bitboard.from(0x20202020, 0x20202020)
local FileG = Bitboard.from(0x40404040, 0x40404040)
local FileH = Bitboard.from(0x80808080, 0x80808080)
local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH}
-- These masks are filled out below for all files
local RightMasks = {FileH}
local LeftMasks = {FileA}
local function popcnt32(i)
i = i - bit32.band(bit32.rshift(i,1), 0x55555555)
i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i,2), 0x33333333)
return bit32.rshift(bit32.band(i + bit32.rshift(i,4), 0x0F0F0F0F) * 0x01010101, 24)
end
function Bitboard:up()
return self:lshift(8)
end
function Bitboard:down()
return self:rshift(8)
end
function Bitboard:right()
return self:band(FileH:inverse()):lshift(1)
end
function Bitboard:left()
return self:band(FileA:inverse()):rshift(1)
end
function Bitboard:move(x,y)
local out = self
if x < 0 then out = out:bandnot(RightMasks[-x]):lshift(-x) end
if x > 0 then out = out:bandnot(LeftMasks[x]):rshift(x) end
if y < 0 then out = out:rshift(-8 * y) end
if y > 0 then out = out:lshift(8 * y) end
return out
end
function Bitboard:popcnt()
return popcnt32(self.l) + popcnt32(self.h)
end
function Bitboard:band(other )
return Bitboard.from(bit32.band(self.l,other.l), bit32.band(self.h, other.h))
end
function Bitboard:bandnot(other )
return Bitboard.from(bit32.band(self.l,bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h)))
end
function Bitboard:bandempty(other )
return bit32.band(self.l,other.l) == 0 and bit32.band(self.h, other.h) == 0
end
function Bitboard:bor(other )
return Bitboard.from(bit32.bor(self.l,other.l), bit32.bor(self.h, other.h))
end
function Bitboard:bxor(other )
return Bitboard.from(bit32.bxor(self.l,other.l), bit32.bxor(self.h, other.h))
end
function Bitboard:inverse()
return Bitboard.from(bit32.bxor(self.l,0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF))
end
function Bitboard:empty()
return self.h == 0 and self.l == 0
end
if not bit32.countrz then
local function ctz(v)
if v == 0 then return 32 end
local offset = 0
while bit32.extract(v, offset) == 0 do
offset = offset + 1
end
return offset
end
function Bitboard:ctz()
local result = ctz(self.l)
if result == 32 then
return ctz(self.h) + 32
else
return result
end
end
function Bitboard:ctzafter(start)
start = start + 1
if start < 32 then
for i=start,31 do
if bit32.extract(self.l, i) == 1 then return i end
end
end
for i=math.max(32,start),63 do
if bit32.extract(self.h, i-32) == 1 then return i end
end
return 64
end
else
function Bitboard:ctz()
local result = bit32.countrz(self.l)
if result == 32 then
return bit32.countrz(self.h) + 32
else
return result
end
end
function Bitboard:ctzafter(start)
local masked = self:band(Bitboard.full:lshift(start+1))
return masked:ctz()
end
end
function Bitboard:lshift(amt)
assert(amt >= 0)
if amt == 0 then return self end
if amt > 31 then
return Bitboard.from(0, bit32.lshift(self.l, amt-32))
end
local l = bit32.lshift(self.l, amt)
local h = bit32.bor(
bit32.lshift(self.h, amt),
bit32.extract(self.l, 32-amt, amt)
)
return Bitboard.from(l, h)
end
function Bitboard:rshift(amt)
assert(amt >= 0)
if amt == 0 then return self end
local h = bit32.rshift(self.h, amt)
local l = bit32.bor(
bit32.rshift(self.l, amt),
bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt)
)
return Bitboard.from(l, h)
end
function Bitboard:index(i)
if i > 31 then
return bit32.extract(self.h, i - 32)
else
return bit32.extract(self.l, i)
end
end
function Bitboard:set(i , v)
if i > 31 then
return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32))
else
return Bitboard.from(bit32.replace(self.l, v, i), self.h)
end
end
function Bitboard:isolate(i)
return self:band(Bitboard.some(i))
end
function Bitboard.some(idx )
return Bitboard.zero:set(idx, 1)
end
Bitboard.__index = Bitboard
Bitboard.__tostring = Bitboard.toString
for i=2,8 do
RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH)
LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA)
end
--
-- Board
--
local Board = {}
function Board.new()
local boards = table.create(12, Bitboard.zero)
boards.ocupied = Bitboard.zero
boards.white = Bitboard.zero
boards.black = Bitboard.zero
boards.unocupied = Bitboard.full
boards.ep = Bitboard.zero
boards.castle = Bitboard.zero
boards.toMove = 1
boards.hm = 0
boards.moves = 0
boards.material = 0
return setmetatable(boards, Board)
end
function Board.fromFen(fen )
local b = Board.new()
local i = 0
local rank = 7
local file = 0
while true do
i = i + 1
local p = fen:sub(i,i)
if p == '/' then
rank = rank - 1
file = 0
elseif tonumber(p) ~= nil then
file = file + tonumber(p)
else
local pidx = PieceSymbols:find(p)
if pidx == nil then break end
b[pidx] = b[pidx]:set(rank*8+file, 1)
file = file + 1
end
end
local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i)
if move == nil then print(fen:sub(i)) end
b.toMove = move == 'w' and 1 or 2
if ep ~= "-" then
b.ep = Bitboard.some(square(ep))
end
if castle ~= "-" then
local oo = Bitboard.zero
if castle:find("K") then
oo = oo:set(7, 1)
end
if castle:find("Q") then
oo = oo:set(0, 1)
end
if castle:find("k") then
oo = oo:set(63, 1)
end
if castle:find("q") then
oo = oo:set(56, 1)
end
b.castle = oo
end
b.hm = hm
b.moves = m
b:updateCache()
return b
end
function Board:index(idx )
if self.white:index(idx) == 1 then
for p=1,12,2 do
if self[p]:index(idx) == 1 then
return p
end
end
else
for p=2,12,2 do
if self[p]:index(idx) == 1 then
return p
end
end
end
return 0
end
function Board:updateCache()
for i=1,11,2 do
self.white = self.white:bor(self[i])
self.black = self.black:bor(self[i+1])
end
self.ocupied = self.black:bor(self.white)
self.unocupied = self.ocupied:inverse()
self.material =
100*self[1]:popcnt() - 100*self[2]:popcnt() +
500*self[3]:popcnt() - 500*self[4]:popcnt() +
300*self[5]:popcnt() - 300*self[6]:popcnt() +
300*self[7]:popcnt() - 300*self[8]:popcnt() +
900*self[9]:popcnt() - 900*self[10]:popcnt()
end
function Board:fen()
local out = {}
local s = 0
local idx = 56
for i=0,63 do
if i % 8 == 0 and i > 0 then
idx = idx - 16
if s > 0 then
table.insert(out, '' .. s)
s = 0
end
table.insert(out, '/')
end
local p = self:index(idx)
if p == 0 then
s = s + 1
else
if s > 0 then
table.insert(out, '' .. s)
s = 0
end
table.insert(out, PieceSymbols:sub(p,p))
end
idx = idx + 1
end
if s > 0 then
table.insert(out, '' .. s)
end
table.insert(out, self.toMove == 1 and ' w ' or ' b ')
if self.castle:empty() then
table.insert(out, '-')
else
if self.castle:index(7) == 1 then table.insert(out, 'K') end
if self.castle:index(0) == 1 then table.insert(out, 'Q') end
if self.castle:index(63) == 1 then table.insert(out, 'k') end
if self.castle:index(56) == 1 then table.insert(out, 'q') end
end
table.insert(out, ' ')
if self.ep:empty() then
table.insert(out, '-')
else
table.insert(out, squareName(self.ep:ctz()))
end
table.insert(out, ' ' .. self.hm)
table.insert(out, ' ' .. self.moves)
return table.concat(out)
end
function Board:pmoves(idx)
return self:generate(idx)
end
function Board:pcaptures(idx)
return self:generate(idx):band(self.ocupied)
end
local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}}
local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}}
local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}}
local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}}
function Board:generate(idx)
local piece = self:index(idx)
local r = Bitboard.some(idx)
local out = Bitboard.zero
local type = bit32.rshift(piece - 1, 1)
local cancapture = piece % 2 == 1 and self.black or self.white
if piece == 0 then return Bitboard.zero end
if type == 0 then
-- Pawn
local d = -(piece*2 - 3)
local movetwo = piece == 1 and Rank3 or Rank6
out = out:bor(r:move(0,d):band(self.unocupied))
out = out:bor(out:band(movetwo):move(0,d):band(self.unocupied))
local captures = r:move(0,d)
captures = captures:right():bor(captures:left())
if not captures:bandempty(self.ep) then
out = out:bor(self.ep)
end
captures = captures:band(cancapture)
out = out:bor(captures)
return out
elseif type == 5 then
-- King
for x=-1,1,1 do
for y = -1,1,1 do
local w = r:move(x,y)
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
end
end
end
elseif type == 2 then
-- Knight
for _,j in ipairs(KNIGHT_MOVES) do
local w = r:move(j[1],j[2])
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
end
end
else
-- Sliders (Rook, Bishop, Queen)
local slides
if type == 1 then
slides = ROOK_SLIDES
elseif type == 3 then
slides = BISHOP_SLIDES
else
slides = QUEEN_SLIDES
end
for _, op in ipairs(slides) do
local w = r
for i=1,7 do
w = w:move(op[1], op[2])
if w:empty() then break end
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
break
end
end
end
end
return out
end
-- 0-5 - From Square
-- 6-11 - To Square
-- 12 - is Check
-- 13 - Is EnPassent
-- 14 - Is Castle
-- 15-19 - Promotion Piece
-- 20-24 - Moved Pice
-- 25-29 - Captured Piece
function Board:toString(mark )
local out = {}
for x=8,1,-1 do
table.insert(out, RANKS:sub(x,x) .. " ")
for y=1,8 do
local n = 8*x+y-9
local i = self:index(n)
if i == 0 then
table.insert(out, '-')
else
-- out = out .. PieceSymbols:sub(i,i)
table.insert(out, UnicodePieces[i])
end
if mark ~= nil and mark:index(n) ~= 0 then
table.insert(out, ')')
elseif mark ~= nil and n < 63 and y < 8 and mark:index(n+1) ~= 0 then
table.insert(out, '(')
else
table.insert(out, ' ')
end
end
table.insert(out, "\n")
end
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n")
return table.concat(out)
end
function Board:moveList()
local tm = self.toMove == 1 and self.white or self.black
local castle_rank = self.toMove == 1 and Rank1 or Rank8
local out = {}
local function emit(id)
if not self:applyMove(id):illegalyChecked() then
table.insert(out, id)
end
end
local cr = tm:band(self.castle):band(castle_rank)
if not cr:empty() then
local p = self.toMove == 1 and 11 or 12
local tcolor = self.toMove == 1 and self.black or self.white
local kidx = self[p]:ctz()
local castle = bit32.replace(0, p, 20, 4)
castle = bit32.replace(castle, kidx, 6, 6)
castle = bit32.replace(castle, 1, 14)
local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank)
local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p])
if
not cr:bandempty(FileA) and
mustbeemptyl:bandempty(self.ocupied) and
not self:isSquareThreatened(cantbethreatened, tcolor)
then
emit(bit32.replace(castle, kidx - 2, 0, 6))
end
local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank)
if
not cr:bandempty(FileH) and
mustbeemptyr:bandempty(self.ocupied) and
not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor)
then
emit(bit32.replace(castle, kidx + 2, 0, 6))
end
end
local sq = tm:ctz()
repeat
local p = self:index(sq)
local moves = self:pmoves(sq)
while not moves:empty() do
local m = moves:ctz()
moves = moves:set(m, 0)
local id = bit32.replace(m, sq, 6, 6)
id = bit32.replace(id, p, 20, 4)
local mbb = Bitboard.some(m)
if not self.ocupied:bandempty(mbb) then
id = bit32.replace(id, self:index(m), 25, 4)
end
-- Check if pawn needs to be promoted
if p == 1 and m >= 8*7 then
for i=3,9,2 do
emit(bit32.replace(id, i, 15, 4))
end
elseif p == 2 and m < 8 then
for i=4,10,2 do
emit(bit32.replace(id, i, 15, 4))
end
else
emit(id)
end
end
sq = tm:ctzafter(sq)
until sq == 64
return out
end
function Board:illegalyChecked()
local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")]
return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black)
end
function Board:isSquareThreatened(target , color )
local tm = color
local sq = tm:ctz()
repeat
local moves = self:pmoves(sq)
if not moves:bandempty(target) then
return true
end
sq = color:ctzafter(sq)
until sq == 64
return false
end
function Board:perft(depth )
if depth == 0 then return 1 end
if depth == 1 then
return #self:moveList()
end
local result = 0
for k,m in ipairs(self:moveList()) do
local c = self:applyMove(m):perft(depth - 1)
if c == 0 then
-- Perft only counts leaf nodes at target depth
-- result = result + 1
else
result = result + c
end
end
return result
end
function Board:applyMove(move )
local out = Board.new()
table.move(self, 1, 12, 1, out)
local from = bit32.extract(move, 6, 6)
local to = bit32.extract(move, 0, 6)
local promote = bit32.extract(move, 15, 4)
local piece = self:index(from)
local captured = self:index(to)
local tom = Bitboard.some(to)
local isCastle = bit32.extract(move, 14)
if piece % 2 == 0 then
out.moves = self.moves + 1
end
if captured == 1 or piece < 3 then
out.hm = 0
else
out.hm = self.hm + 1
end
out.castle = self.castle
out.toMove = self.toMove == 1 and 2 or 1
if isCastle == 1 then
local rank = piece == 11 and Rank1 or Rank8
local colorOffset = piece - 11
out[3 + colorOffset] = out[3 + colorOffset]:bandnot(from < to and FileH or FileA)
out[3 + colorOffset] = out[3 + colorOffset]:bor((from < to and FileF or FileD):band(rank))
out[piece] = (from < to and FileG or FileC):band(rank)
out.castle = out.castle:bandnot(rank)
out:updateCache()
return out
end
if piece < 3 then
local dist = math.abs(to - from)
-- Pawn moved two squares, set ep square
if dist == 16 then
out.ep = Bitboard.some((from + to) / 2)
end
-- Remove enpasent capture
if not tom:bandempty(self.ep) then
if piece == 1 then
out[2] = out[2]:bandnot(self.ep:down())
end
if piece == 2 then
out[1] = out[1]:bandnot(self.ep:up())
end
end
end
if piece == 3 or piece == 4 then
out.castle = out.castle:set(from, 0)
end
if piece > 10 then
local rank = piece == 11 and Rank1 or Rank8
out.castle = out.castle:bandnot(rank)
end
out[piece] = out[piece]:set(from, 0)
if promote == 0 then
out[piece] = out[piece]:set(to, 1)
else
out[promote] = out[promote]:set(to, 1)
end
if captured ~= 0 then
out[captured] = out[captured]:set(to, 0)
end
out:updateCache()
return out
end
Board.__index = Board
Board.__tostring = Board.toString
--
-- Main
--
local failures = 0
local function test(fen, ply, target)
local b = Board.fromFen(fen)
if b:fen() ~= fen then
print("FEN MISMATCH", fen, b:fen())
failures = failures + 1
return
end
local found = b:perft(ply)
if found ~= target then
print(fen, "Found", found, "target", target)
failures = failures + 1
for k,v in pairs(b:moveList()) do
print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1'))
end
--error("Test Failure")
else
print("OK", found, fen)
end
end
-- From https://www.chessprogramming.org/Perft_Results
-- If interpreter, computers, or algorithm gets too fast
-- feel free to go deeper
local testCases = {}
local function addTest(...) table.insert(testCases, {...}) end
addTest(StartingFen, 2, 400)
addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 1, 48)
addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 2, 191)
addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 2, 264)
addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 1, 44)
addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 1, 46)
local function chess()
for k,v in ipairs(testCases) do
test(v[1],v[2],v[3])
end
end
bench.runCode(chess, "chess")

View file

@ -1,934 +0,0 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
-- Copyright 2008 the V8 project authors. All rights reserved.
-- Copyright 1996 John Maloney and Mario Wolczko.
-- This program is free software; you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation; either version 2 of the License, or
-- (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program; if not, write to the Free Software
-- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
-- This implementation of the DeltaBlue benchmark is derived
-- from the Smalltalk implementation by John Maloney and Mario
-- Wolczko. Some parts have been translated directly, whereas
-- others have been modified more aggressively to make it feel
-- more like a JavaScript program.
--
-- A JavaScript implementation of the DeltaBlue constraint-solving
-- algorithm, as described in:
--
-- "The DeltaBlue Algorithm: An Incremental Constraint Hierarchy Solver"
-- Bjorn N. Freeman-Benson and John Maloney
-- January 1990 Communications of the ACM,
-- also available as University of Washington TR 89-08-06.
--
-- Beware: this benchmark is written in a grotesque style where
-- the constraint model is built by side-effects from constructors.
-- I've kept it this way to avoid deviating too much from the original
-- implementation.
--
function class(base)
local T = {}
T.__index = T
if base then
T.super = base
setmetatable(T, base)
end
function T.new(...)
local O = {}
setmetatable(O, T)
O:constructor(...)
return O
end
return T
end
local planner
--- O b j e c t M o d e l ---
local function alert (...) print(...) end
local OrderedCollection = class()
function OrderedCollection:constructor()
self.elms = {}
end
function OrderedCollection:add(elm)
self.elms[#self.elms + 1] = elm
end
function OrderedCollection:at (index)
return self.elms[index]
end
function OrderedCollection:size ()
return #self.elms
end
function OrderedCollection:removeFirst ()
local e = self.elms[#self.elms]
self.elms[#self.elms] = nil
return e
end
function OrderedCollection:remove (elm)
local index = 0
local skipped = 0
for i = 1, #self.elms do
local value = self.elms[i]
if value ~= elm then
self.elms[index] = value
index = index + 1
else
skipped = skipped + 1
end
end
local l = #self.elms
for i = 1, skipped do self.elms[l - i + 1] = nil end
end
--
-- S t r e n g t h
--
--
-- Strengths are used to measure the relative importance of constraints.
-- New strengths may be inserted in the strength hierarchy without
-- disrupting current constraints. Strengths cannot be created outside
-- this class, so pointer comparison can be used for value comparison.
--
local Strength = class()
function Strength:constructor(strengthValue, name)
self.strengthValue = strengthValue
self.name = name
end
function Strength.stronger (s1, s2)
return s1.strengthValue < s2.strengthValue
end
function Strength.weaker (s1, s2)
return s1.strengthValue > s2.strengthValue
end
function Strength.weakestOf (s1, s2)
return Strength.weaker(s1, s2) and s1 or s2
end
function Strength.strongest (s1, s2)
return Strength.stronger(s1, s2) and s1 or s2
end
function Strength:nextWeaker ()
local v = self.strengthValue
if v == 0 then return Strength.WEAKEST
elseif v == 1 then return Strength.WEAK_DEFAULT
elseif v == 2 then return Strength.NORMAL
elseif v == 3 then return Strength.STRONG_DEFAULT
elseif v == 4 then return Strength.PREFERRED
elseif v == 5 then return Strength.REQUIRED
end
end
-- Strength constants.
Strength.REQUIRED = Strength.new(0, "required");
Strength.STONG_PREFERRED = Strength.new(1, "strongPreferred");
Strength.PREFERRED = Strength.new(2, "preferred");
Strength.STRONG_DEFAULT = Strength.new(3, "strongDefault");
Strength.NORMAL = Strength.new(4, "normal");
Strength.WEAK_DEFAULT = Strength.new(5, "weakDefault");
Strength.WEAKEST = Strength.new(6, "weakest");
--
-- C o n s t r a i n t
--
--
-- An abstract class representing a system-maintainable relationship
-- (or "constraint") between a set of variables. A constraint supplies
-- a strength instance variable; concrete subclasses provide a means
-- of storing the constrained variables and other information required
-- to represent a constraint.
--
local Constraint = class ()
function Constraint:constructor(strength)
self.strength = strength
end
--
-- Activate this constraint and attempt to satisfy it.
--
function Constraint:addConstraint ()
self:addToGraph()
planner:incrementalAdd(self)
end
--
-- Attempt to find a way to enforce this constraint. If successful,
-- record the solution, perhaps modifying the current dataflow
-- graph. Answer the constraint that this constraint overrides, if
-- there is one, or nil, if there isn't.
-- Assume: I am not already satisfied.
--
function Constraint:satisfy (mark)
self:chooseMethod(mark)
if not self:isSatisfied() then
if self.strength == Strength.REQUIRED then
alert("Could not satisfy a required constraint!")
end
return nil
end
self:markInputs(mark)
local out = self:output()
local overridden = out.determinedBy
if overridden ~= nil then overridden:markUnsatisfied() end
out.determinedBy = self
if not planner:addPropagate(self, mark) then alert("Cycle encountered") end
out.mark = mark
return overridden
end
function Constraint:destroyConstraint ()
if self:isSatisfied()
then planner:incrementalRemove(self)
else self:removeFromGraph()
end
end
--
-- Normal constraints are not input constraints. An input constraint
-- is one that depends on external state, such as the mouse, the
-- keyboard, a clock, or some arbitrary piece of imperative code.
--
function Constraint:isInput ()
return false
end
--
-- U n a r y C o n s t r a i n t
--
--
-- Abstract superclass for constraints having a single possible output
-- variable.
--
local UnaryConstraint = class(Constraint)
function UnaryConstraint:constructor (v, strength)
UnaryConstraint.super.constructor(self, strength)
self.myOutput = v
self.satisfied = false
self:addConstraint()
end
--
-- Adds this constraint to the constraint graph
--
function UnaryConstraint:addToGraph ()
self.myOutput:addConstraint(self)
self.satisfied = false
end
--
-- Decides if this constraint can be satisfied and records that
-- decision.
--
function UnaryConstraint:chooseMethod (mark)
self.satisfied = (self.myOutput.mark ~= mark)
and Strength.stronger(self.strength, self.myOutput.walkStrength);
end
--
-- Returns true if this constraint is satisfied in the current solution.
--
function UnaryConstraint:isSatisfied ()
return self.satisfied;
end
function UnaryConstraint:markInputs (mark)
-- has no inputs
end
--
-- Returns the current output variable.
--
function UnaryConstraint:output ()
return self.myOutput
end
--
-- Calculate the walkabout strength, the stay flag, and, if it is
-- 'stay', the value for the current output of this constraint. Assume
-- this constraint is satisfied.
--
function UnaryConstraint:recalculate ()
self.myOutput.walkStrength = self.strength
self.myOutput.stay = not self:isInput()
if self.myOutput.stay then
self:execute() -- Stay optimization
end
end
--
-- Records that this constraint is unsatisfied
--
function UnaryConstraint:markUnsatisfied ()
self.satisfied = false
end
function UnaryConstraint:inputsKnown ()
return true
end
function UnaryConstraint:removeFromGraph ()
if self.myOutput ~= nil then
self.myOutput:removeConstraint(self)
end
self.satisfied = false
end
--
-- S t a y C o n s t r a i n t
--
--
-- Variables that should, with some level of preference, stay the same.
-- Planners may exploit the fact that instances, if satisfied, will not
-- change their output during plan execution. This is called "stay
-- optimization".
--
local StayConstraint = class(UnaryConstraint)
function StayConstraint:constructor(v, str)
StayConstraint.super.constructor(self, v, str)
end
function StayConstraint:execute ()
-- Stay constraints do nothing
end
--
-- E d i t C o n s t r a i n t
--
--
-- A unary input constraint used to mark a variable that the client
-- wishes to change.
--
local EditConstraint = class (UnaryConstraint)
function EditConstraint:constructor(v, str)
EditConstraint.super.constructor(self, v, str)
end
--
-- Edits indicate that a variable is to be changed by imperative code.
--
function EditConstraint:isInput ()
return true
end
function EditConstraint:execute ()
-- Edit constraints do nothing
end
--
-- B i n a r y C o n s t r a i n t
--
local Direction = {}
Direction.NONE = 0
Direction.FORWARD = 1
Direction.BACKWARD = -1
--
-- Abstract superclass for constraints having two possible output
-- variables.
--
local BinaryConstraint = class(Constraint)
function BinaryConstraint:constructor(var1, var2, strength)
BinaryConstraint.super.constructor(self, strength);
self.v1 = var1
self.v2 = var2
self.direction = Direction.NONE
self:addConstraint()
end
--
-- Decides if this constraint can be satisfied and which way it
-- should flow based on the relative strength of the variables related,
-- and record that decision.
--
function BinaryConstraint:chooseMethod (mark)
if self.v1.mark == mark then
self.direction = (self.v2.mark ~= mark and Strength.stronger(self.strength, self.v2.walkStrength)) and Direction.FORWARD or Direction.NONE
end
if self.v2.mark == mark then
self.direction = (self.v1.mark ~= mark and Strength.stronger(self.strength, self.v1.walkStrength)) and Direction.BACKWARD or Direction.NONE
end
if Strength.weaker(self.v1.walkStrength, self.v2.walkStrength) then
self.direction = Strength.stronger(self.strength, self.v1.walkStrength) and Direction.BACKWARD or Direction.NONE
else
self.direction = Strength.stronger(self.strength, self.v2.walkStrength) and Direction.FORWARD or Direction.BACKWARD
end
end
--
-- Add this constraint to the constraint graph
--
function BinaryConstraint:addToGraph ()
self.v1:addConstraint(self)
self.v2:addConstraint(self)
self.direction = Direction.NONE
end
--
-- Answer true if this constraint is satisfied in the current solution.
--
function BinaryConstraint:isSatisfied ()
return self.direction ~= Direction.NONE
end
--
-- Mark the input variable with the given mark.
--
function BinaryConstraint:markInputs (mark)
self:input().mark = mark
end
--
-- Returns the current input variable
--
function BinaryConstraint:input ()
return (self.direction == Direction.FORWARD) and self.v1 or self.v2
end
--
-- Returns the current output variable
--
function BinaryConstraint:output ()
return (self.direction == Direction.FORWARD) and self.v2 or self.v1
end
--
-- Calculate the walkabout strength, the stay flag, and, if it is
-- 'stay', the value for the current output of this
-- constraint. Assume this constraint is satisfied.
--
function BinaryConstraint:recalculate ()
local ihn = self:input()
local out = self:output()
out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength);
out.stay = ihn.stay
if out.stay then self:execute() end
end
--
-- Record the fact that self constraint is unsatisfied.
--
function BinaryConstraint:markUnsatisfied ()
self.direction = Direction.NONE
end
function BinaryConstraint:inputsKnown (mark)
local i = self:input()
return i.mark == mark or i.stay or i.determinedBy == nil
end
function BinaryConstraint:removeFromGraph ()
if (self.v1 ~= nil) then self.v1:removeConstraint(self) end
if (self.v2 ~= nil) then self.v2:removeConstraint(self) end
self.direction = Direction.NONE
end
--
-- S c a l e C o n s t r a i n t
--
--
-- Relates two variables by the linear scaling relationship: "v2 =
-- (v1 * scale) + offset". Either v1 or v2 may be changed to maintain
-- this relationship but the scale factor and offset are considered
-- read-only.
--
local ScaleConstraint = class (BinaryConstraint)
function ScaleConstraint:constructor(src, scale, offset, dest, strength)
self.direction = Direction.NONE
self.scale = scale
self.offset = offset
ScaleConstraint.super.constructor(self, src, dest, strength)
end
--
-- Adds this constraint to the constraint graph.
--
function ScaleConstraint:addToGraph ()
ScaleConstraint.super.addToGraph(self)
self.scale:addConstraint(self)
self.offset:addConstraint(self)
end
function ScaleConstraint:removeFromGraph ()
ScaleConstraint.super.removeFromGraph(self)
if (self.scale ~= nil) then self.scale:removeConstraint(self) end
if (self.offset ~= nil) then self.offset:removeConstraint(self) end
end
function ScaleConstraint:markInputs (mark)
ScaleConstraint.super.markInputs(self, mark);
self.offset.mark = mark
self.scale.mark = mark
end
--
-- Enforce this constraint. Assume that it is satisfied.
--
function ScaleConstraint:execute ()
if self.direction == Direction.FORWARD then
self.v2.value = self.v1.value * self.scale.value + self.offset.value
else
self.v1.value = (self.v2.value - self.offset.value) / self.scale.value
end
end
--
-- Calculate the walkabout strength, the stay flag, and, if it is
-- 'stay', the value for the current output of this constraint. Assume
-- this constraint is satisfied.
--
function ScaleConstraint:recalculate ()
local ihn = self:input()
local out = self:output()
out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength)
out.stay = ihn.stay and self.scale.stay and self.offset.stay
if out.stay then self:execute() end
end
--
-- E q u a l i t y C o n s t r a i n t
--
--
-- Constrains two variables to have the same value.
--
local EqualityConstraint = class (BinaryConstraint)
function EqualityConstraint:constructor(var1, var2, strength)
EqualityConstraint.super.constructor(self, var1, var2, strength)
end
--
-- Enforce this constraint. Assume that it is satisfied.
--
function EqualityConstraint:execute ()
self:output().value = self:input().value
end
--
-- V a r i a b l e
--
--
-- A constrained variable. In addition to its value, it maintain the
-- structure of the constraint graph, the current dataflow graph, and
-- various parameters of interest to the DeltaBlue incremental
-- constraint solver.
--
local Variable = class ()
function Variable:constructor(name, initialValue)
self.value = initialValue or 0
self.constraints = OrderedCollection.new()
self.determinedBy = nil
self.mark = 0
self.walkStrength = Strength.WEAKEST
self.stay = true
self.name = name
end
--
-- Add the given constraint to the set of all constraints that refer
-- this variable.
--
function Variable:addConstraint (c)
self.constraints:add(c)
end
--
-- Removes all traces of c from this variable.
--
function Variable:removeConstraint (c)
self.constraints:remove(c)
if self.determinedBy == c then
self.determinedBy = nil
end
end
--
-- P l a n n e r
--
--
-- The DeltaBlue planner
--
local Planner = class()
function Planner:constructor()
self.currentMark = 0
end
--
-- Attempt to satisfy the given constraint and, if successful,
-- incrementally update the dataflow graph. Details: If satisfying
-- the constraint is successful, it may override a weaker constraint
-- on its output. The algorithm attempts to resatisfy that
-- constraint using some other method. This process is repeated
-- until either a) it reaches a variable that was not previously
-- determined by any constraint or b) it reaches a constraint that
-- is too weak to be satisfied using any of its methods. The
-- variables of constraints that have been processed are marked with
-- a unique mark value so that we know where we've been. This allows
-- the algorithm to avoid getting into an infinite loop even if the
-- constraint graph has an inadvertent cycle.
--
function Planner:incrementalAdd (c)
local mark = self:newMark()
local overridden = c:satisfy(mark)
while overridden ~= nil do
overridden = overridden:satisfy(mark)
end
end
--
-- Entry point for retracting a constraint. Remove the given
-- constraint and incrementally update the dataflow graph.
-- Details: Retracting the given constraint may allow some currently
-- unsatisfiable downstream constraint to be satisfied. We therefore collect
-- a list of unsatisfied downstream constraints and attempt to
-- satisfy each one in turn. This list is traversed by constraint
-- strength, strongest first, as a heuristic for avoiding
-- unnecessarily adding and then overriding weak constraints.
-- Assume: c is satisfied.
--
function Planner:incrementalRemove (c)
local out = c:output()
c:markUnsatisfied()
c:removeFromGraph()
local unsatisfied = self:removePropagateFrom(out)
local strength = Strength.REQUIRED
repeat
for i = 1, unsatisfied:size() do
local u = unsatisfied:at(i)
if u.strength == strength then
self:incrementalAdd(u)
end
end
strength = strength:nextWeaker()
until strength == Strength.WEAKEST
end
--
-- Select a previously unused mark value.
--
function Planner:newMark ()
self.currentMark = self.currentMark + 1
return self.currentMark
end
--
-- Extract a plan for resatisfaction starting from the given source
-- constraints, usually a set of input constraints. This method
-- assumes that stay optimization is desired; the plan will contain
-- only constraints whose output variables are not stay. Constraints
-- that do no computation, such as stay and edit constraints, are
-- not included in the plan.
-- Details: The outputs of a constraint are marked when it is added
-- to the plan under construction. A constraint may be appended to
-- the plan when all its input variables are known. A variable is
-- known if either a) the variable is marked (indicating that has
-- been computed by a constraint appearing earlier in the plan), b)
-- the variable is 'stay' (i.e. it is a constant at plan execution
-- time), or c) the variable is not determined by any
-- constraint. The last provision is for past states of history
-- variables, which are not stay but which are also not computed by
-- any constraint.
-- Assume: sources are all satisfied.
--
local Plan -- FORWARD DECLARATION
function Planner:makePlan (sources)
local mark = self:newMark()
local plan = Plan.new()
local todo = sources
while todo:size() > 0 do
local c = todo:removeFirst()
if c:output().mark ~= mark and c:inputsKnown(mark) then
plan:addConstraint(c)
c:output().mark = mark
self:addConstraintsConsumingTo(c:output(), todo)
end
end
return plan
end
--
-- Extract a plan for resatisfying starting from the output of the
-- given constraints, usually a set of input constraints.
--
function Planner:extractPlanFromConstraints (constraints)
local sources = OrderedCollection.new()
for i = 1, constraints:size() do
local c = constraints:at(i)
if c:isInput() and c:isSatisfied() then
-- not in plan already and eligible for inclusion
sources:add(c)
end
end
return self:makePlan(sources)
end
--
-- Recompute the walkabout strengths and stay flags of all variables
-- downstream of the given constraint and recompute the actual
-- values of all variables whose stay flag is true. If a cycle is
-- detected, remove the given constraint and answer
-- false. Otherwise, answer true.
-- Details: Cycles are detected when a marked variable is
-- encountered downstream of the given constraint. The sender is
-- assumed to have marked the inputs of the given constraint with
-- the given mark. Thus, encountering a marked node downstream of
-- the output constraint means that there is a path from the
-- constraint's output to one of its inputs.
--
function Planner:addPropagate (c, mark)
local todo = OrderedCollection.new()
todo:add(c)
while todo:size() > 0 do
local d = todo:removeFirst()
if d:output().mark == mark then
self:incrementalRemove(c)
return false
end
d:recalculate()
self:addConstraintsConsumingTo(d:output(), todo)
end
return true
end
--
-- Update the walkabout strengths and stay flags of all variables
-- downstream of the given constraint. Answer a collection of
-- unsatisfied constraints sorted in order of decreasing strength.
--
function Planner:removePropagateFrom (out)
out.determinedBy = nil
out.walkStrength = Strength.WEAKEST
out.stay = true
local unsatisfied = OrderedCollection.new()
local todo = OrderedCollection.new()
todo:add(out)
while todo:size() > 0 do
local v = todo:removeFirst()
for i = 1, v.constraints:size() do
local c = v.constraints:at(i)
if not c:isSatisfied() then unsatisfied:add(c) end
end
local determining = v.determinedBy
for i = 1, v.constraints:size() do
local next = v.constraints:at(i);
if next ~= determining and next:isSatisfied() then
next:recalculate()
todo:add(next:output())
end
end
end
return unsatisfied
end
function Planner:addConstraintsConsumingTo (v, coll)
local determining = v.determinedBy
local cc = v.constraints
for i = 1, cc:size() do
local c = cc:at(i)
if c ~= determining and c:isSatisfied() then
coll:add(c)
end
end
end
--
-- P l a n
--
--
-- A Plan is an ordered list of constraints to be executed in sequence
-- to resatisfy all currently satisfiable constraints in the face of
-- one or more changing inputs.
--
Plan = class()
function Plan:constructor()
self.v = OrderedCollection.new()
end
function Plan:addConstraint (c)
self.v:add(c)
end
function Plan:size ()
return self.v:size()
end
function Plan:constraintAt (index)
return self.v:at(index)
end
function Plan:execute ()
for i = 1, self:size() do
local c = self:constraintAt(i)
c:execute()
end
end
--
-- M a i n
--
--
-- This is the standard DeltaBlue benchmark. A long chain of equality
-- constraints is constructed with a stay constraint on one end. An
-- edit constraint is then added to the opposite end and the time is
-- measured for adding and removing this constraint, and extracting
-- and executing a constraint satisfaction plan. There are two cases.
-- In case 1, the added constraint is stronger than the stay
-- constraint and values must propagate down the entire length of the
-- chain. In case 2, the added constraint is weaker than the stay
-- constraint so it cannot be accommodated. The cost in this case is,
-- of course, very low. Typical situations lie somewhere between these
-- two extremes.
--
local function chainTest(n)
planner = Planner.new()
local prev = nil
local first = nil
local last = nil
-- Build chain of n equality constraints
for i = 0, n do
local name = "v" .. i;
local v = Variable.new(name)
if prev ~= nil then EqualityConstraint.new(prev, v, Strength.REQUIRED) end
if i == 0 then first = v end
if i == n then last = v end
prev = v
end
StayConstraint.new(last, Strength.STRONG_DEFAULT)
local edit = EditConstraint.new(first, Strength.PREFERRED)
local edits = OrderedCollection.new()
edits:add(edit)
local plan = planner:extractPlanFromConstraints(edits)
for i = 0, 99 do
first.value = i
plan:execute()
if last.value ~= i then
alert("Chain test failed.")
end
end
end
local function change(v, newValue)
local edit = EditConstraint.new(v, Strength.PREFERRED)
local edits = OrderedCollection.new()
edits:add(edit)
local plan = planner:extractPlanFromConstraints(edits)
for i = 1, 10 do
v.value = newValue
plan:execute()
end
edit:destroyConstraint()
end
--
-- This test constructs a two sets of variables related to each
-- other by a simple linear transformation (scale and offset). The
-- time is measured to change a variable on either side of the
-- mapping and to change the scale and offset factors.
--
local function projectionTest(n)
planner = Planner.new();
local scale = Variable.new("scale", 10);
local offset = Variable.new("offset", 1000);
local src = nil
local dst = nil;
local dests = OrderedCollection.new();
for i = 0, n - 1 do
src = Variable.new("src" .. i, i);
dst = Variable.new("dst" .. i, i);
dests:add(dst);
StayConstraint.new(src, Strength.NORMAL);
ScaleConstraint.new(src, scale, offset, dst, Strength.REQUIRED);
end
change(src, 17)
if dst.value ~= 1170 then alert("Projection 1 failed") end
change(dst, 1050)
if src.value ~= 5 then alert("Projection 2 failed") end
change(scale, 5)
for i = 0, n - 2 do
if dests:at(i + 1).value ~= i * 5 + 1000 then
alert("Projection 3 failed")
end
end
change(offset, 2000)
for i = 0, n - 2 do
if dests:at(i + 1).value ~= i * 5 + 2000 then
alert("Projection 4 failed")
end
end
end
function test()
local t0 = os.clock()
chainTest(1000);
projectionTest(1000);
local t1 = os.clock()
return t1-t0
end
bench.runCode(test, "deltablue")

View file

@ -3,6 +3,8 @@ main:
url: /news
- title: Getting Started
url: /getting-started
- title: GitHub
url: https://github.com/Roblox/luau
pages:
- title: Getting Started
@ -21,5 +23,11 @@ pages:
url: /compatibility
- title: Typechecking
url: /typecheck
- title: Profiling
url: /profile
- title: Library
url: /library
# Remove demo pages until solution is found
# - title: Demo
# url: /demo

50
docs/_includes/repl.html Normal file
View file

@ -0,0 +1,50 @@
<form>
<div>
<label>Script:</label>
<br>
<textarea rows="10" cols="70" id="script">print("Hello World!")</textarea>
<br><br>
<button onclick="clearInput(); return false;">
Clear Input
</button>
<button onclick="executeScript(); return false;">
Run
</button>
</div>
<br><br>
<div>
<label>Output:</label>
<br>
<textarea readonly rows="10" cols="70" id="output"></textarea>
<br><br>
<button onclick="clearOutput(); return false;">
Clear Output
</button>
</div>
</form>
<script>
function output(text) {
document.getElementById("output").value += "[" + new Date().toLocaleTimeString() + "] " + text.replace('stdin:', '') + "\n";
}
var Module = {
'print': function (msg) { output(msg) }
};
function clearInput() {
document.getElementById("script").value = "";
}
function clearOutput() {
document.getElementById("output").value = "";
}
function executeScript() {
var err = Module.ccall('executeScript', 'string', ['string'], [document.getElementById("script").value]);
if (err) {
output('Error:' + err.replace('stdin:', ''));
}
}
</script>
<script async src="assets/luau/luau.js"></script>

Some files were not shown because too many files have changed in this diff Show more