Merge branch 'Roblox:master' into master

Nikita Gordeev, 2022-11-06 11:00:19 +07:00, committed by GitHub
Commit dc0263b0fe, signed by: DevComp (GPG key ID: 4AEE18F83AFDEB23)
162 changed files with 10184 additions and 5681 deletions

View file

@@ -73,9 +73,11 @@ jobs:
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt
- name: Checkout benchmark results
uses: actions/checkout@v3

View file

@@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
+#include "Luau/NotNull.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeVar.h"
@@ -26,5 +27,6 @@ TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false);
+TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest);
} // namespace Luau

View file

@@ -0,0 +1,68 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Def.h"
#include "Luau/TypedAllocator.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <memory>
namespace Luau
{
struct Negation;
struct Conjunction;
struct Disjunction;
struct Equivalence;
struct Proposition;
using Connective = Variant<Negation, Conjunction, Disjunction, Equivalence, Proposition>;
using ConnectiveId = Connective*; // Can and most likely is nullptr.
struct Negation
{
ConnectiveId connective;
};
struct Conjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Disjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Equivalence
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Proposition
{
DefId def;
TypeId discriminantTy;
};
template<typename T>
const T* get(ConnectiveId connective)
{
return get_if<T>(connective);
}
struct ConnectiveArena
{
TypedAllocator<Connective> allocator;
ConnectiveId negation(ConnectiveId connective);
ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId proposition(DefId def, TypeId discriminantTy);
};
} // namespace Luau
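// Editor's note: an illustrative sketch, not part of this commit, showing how the
// ConnectiveArena declared above might record the refinement implied by a condition such
// as `x == "cat" and not y`. `defX`, `defY`, `catTy`, and `truthyTy` are assumed to come
// from the surrounding DataFlowGraph and type arena.
ConnectiveArena arena;
ConnectiveId xIsCat = arena.proposition(defX, catTy);                 // x ~ "cat"
ConnectiveId notY = arena.negation(arena.proposition(defY, truthyTy)); // not (y ~ truthy)
ConnectiveId whole = arena.conjunction(xIsCat, notY);
if (const Conjunction* conj = get<Conjunction>(whole))
{
    // conj->lhs holds the proposition about x, conj->rhs the negated proposition about y
}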

View file

@@ -2,9 +2,10 @@
#pragma once
#include "Luau/Ast.h" // Used for some of the enumerations
+#include "Luau/Def.h"
#include "Luau/NotNull.h"
-#include "Luau/Variant.h"
#include "Luau/TypeVar.h"
+#include "Luau/Variant.h"
#include <string>
#include <memory>
@@ -131,9 +132,16 @@ struct HasPropConstraint
std::string prop;
};
-using ConstraintV =
-Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint, BinaryConstraint,
-IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint, HasPropConstraint>;
+// result ~ if isSingleton D then ~D else unknown where D = discriminantType
+struct SingletonOrTopTypeConstraint
+{
+TypeId resultType;
+TypeId discriminantType;
+};
+using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
+BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint,
+HasPropConstraint, SingletonOrTopTypeConstraint>;
struct Constraint
{
@@ -143,7 +151,7 @@ struct Constraint
Constraint& operator=(const Constraint&) = delete;
NotNull<Scope> scope;
-Location location;
+Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations.
ConstraintV c;
std::vector<NotNull<Constraint>> dependencies;
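// Editor's note: illustrative reading of SingletonOrTopTypeConstraint above, not part of
// this commit. Given resultType R and discriminantType D, the solver binds:
//   D = "cat" (a singleton type)   ==>  R = ~"cat" (the negation of D)
//   D = string (not a singleton)   ==>  R = unknown (the top type)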

View file

@@ -1,13 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
-#include <memory>
-#include <vector>
-#include <unordered_map>
#include "Luau/Ast.h"
+#include "Luau/Connective.h"
#include "Luau/Constraint.h"
+#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
#include "Luau/NotNull.h"
@@ -15,6 +12,10 @@
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
+#include <memory>
+#include <vector>
+#include <unordered_map>
namespace Luau
{
@@ -23,6 +24,34 @@ using ScopePtr = std::shared_ptr<Scope>;
struct DcrLogger;
struct Inference
{
TypeId ty = nullptr;
ConnectiveId connective = nullptr;
Inference() = default;
explicit Inference(TypeId ty, ConnectiveId connective = nullptr)
: ty(ty)
, connective(connective)
{
}
};
struct InferencePack
{
TypePackId tp = nullptr;
std::vector<ConnectiveId> connectives;
InferencePack() = default;
explicit InferencePack(TypePackId tp, const std::vector<ConnectiveId>& connectives = {})
: tp(tp)
, connectives(connectives)
{
}
};
struct ConstraintGraphBuilder
{
// A list of all the scopes in the module. This vector holds ownership of the
@@ -48,6 +77,8 @@ struct ConstraintGraphBuilder
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
// Defining scopes for AST nodes.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
+NotNull<const DataFlowGraph> dfg;
+ConnectiveArena connectiveArena;
int recursionCount = 0;
@@ -63,7 +94,8 @@ struct ConstraintGraphBuilder
DcrLogger* logger;
ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull<ModuleResolver> moduleResolver,
-NotNull<SingletonTypes> singletonTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, DcrLogger* logger);
+NotNull<SingletonTypes> singletonTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, DcrLogger* logger,
+NotNull<DataFlowGraph> dfg);
/**
* Fabricates a new free type belonging to a given scope.
@@ -88,15 +120,19 @@ struct ConstraintGraphBuilder
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to.
* @param cv the constraint variant to add.
+* @return the pointer to the inserted constraint
*/
-void addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv);
+NotNull<Constraint> addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv);
/**
* Adds a constraint to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param c the constraint to add.
+* @return the pointer to the inserted constraint
*/
-void addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
+NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
+void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective);
/**
* The entry point to the ConstraintGraphBuilder. This will construct a set
@@ -126,8 +162,10 @@ struct ConstraintGraphBuilder
void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
void visit(const ScopePtr& scope, AstStatError* error);
-TypePackId checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<TypeId>& expectedTypes = {});
-TypePackId checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<TypeId>& expectedTypes = {});
+InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<TypeId>& expectedTypes = {});
+InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<TypeId>& expectedTypes = {});
+InferencePack checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector<TypeId>& expectedTypes);
/**
* Checks an expression that is expected to evaluate to one type.
@@ -137,15 +175,24 @@ struct ConstraintGraphBuilder
* surrounding context. Used to implement bidirectional type checking.
* @return the type of the expression.
*/
-TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {});
-TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
-TypeId check(const ScopePtr& scope, AstExprIndexName* indexName);
-TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
-TypeId check(const ScopePtr& scope, AstExprUnary* unary);
-TypeId check(const ScopePtr& scope, AstExprBinary* binary);
-TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
-TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
+Inference check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}, bool forceSingleton = false);
+Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
+Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
+Inference check(const ScopePtr& scope, AstExprLocal* local);
+Inference check(const ScopePtr& scope, AstExprGlobal* global);
+Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
+Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
+Inference check(const ScopePtr& scope, AstExprUnary* unary);
+Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
+Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
+Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
+Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
+std::tuple<TypeId, TypeId, ConnectiveId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
+TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
+TypeId checkLValue(const ScopePtr& scope, AstExpr* expr);
struct FunctionSignature
{
@@ -191,7 +238,7 @@ struct ConstraintGraphBuilder
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(const ScopePtr& scope, AstArray<AstGenericType> generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(const ScopePtr& scope, AstArray<AstGenericTypePack> packs);
-TypeId flattenPack(const ScopePtr& scope, Location location, TypePackId tp);
+Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);
void reportError(Location location, TypeErrorData err);
void reportCodeTooComplex(Location location);

View file

@@ -110,6 +110,7 @@ struct ConstraintSolver
bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
+bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do
// also handles __iter metamethod
@@ -215,6 +216,8 @@ private:
TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const;
+TypeId unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes);
ToStringOptions opts;
};

View file

@@ -0,0 +1,115 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
// Do not include LValue. It should never be used here.
#include "Luau/Ast.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/Symbol.h"
#include <unordered_map>
namespace Luau
{
struct DataFlowGraph
{
DataFlowGraph(DataFlowGraph&&) = default;
DataFlowGraph& operator=(DataFlowGraph&&) = default;
// TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt.
// We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId.
std::optional<DefId> getDef(const AstExpr* expr) const;
std::optional<DefId> getDef(const AstLocal* local) const;
/// Retrieve the Def that corresponds to the given Symbol.
///
/// We do not perform dataflow analysis on globals, so this function always
/// yields nullopt when passed a global Symbol.
std::optional<DefId> getDef(const Symbol& symbol) const;
private:
DataFlowGraph() = default;
DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena arena;
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
DenseHashMap<const AstLocal*, const Def*> localDefs{nullptr};
friend struct DataFlowGraphBuilder;
};
struct DfgScope
{
DfgScope* parent;
DenseHashMap<Symbol, const Def*> bindings{Symbol{}};
};
struct ExpressionFlowGraph
{
std::optional<DefId> def;
};
// Currently unsound. We do not presently track the control flow of the program.
// Additionally, we do not presently track assignments.
struct DataFlowGraphBuilder
{
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
private:
DataFlowGraphBuilder() = default;
DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete;
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph;
NotNull<DefArena> arena{&graph.arena};
struct InternalErrorReporter* handle;
std::vector<std::unique_ptr<DfgScope>> scopes;
DfgScope* childScope(DfgScope* scope);
std::optional<DefId> use(DfgScope* scope, Symbol symbol, AstExpr* e);
void visit(DfgScope* scope, AstStatBlock* b);
void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b);
// TODO: visit type aliases
void visit(DfgScope* scope, AstStat* s);
void visit(DfgScope* scope, AstStatIf* i);
void visit(DfgScope* scope, AstStatWhile* w);
void visit(DfgScope* scope, AstStatRepeat* r);
void visit(DfgScope* scope, AstStatBreak* b);
void visit(DfgScope* scope, AstStatContinue* c);
void visit(DfgScope* scope, AstStatReturn* r);
void visit(DfgScope* scope, AstStatExpr* e);
void visit(DfgScope* scope, AstStatLocal* l);
void visit(DfgScope* scope, AstStatFor* f);
void visit(DfgScope* scope, AstStatForIn* f);
void visit(DfgScope* scope, AstStatAssign* a);
void visit(DfgScope* scope, AstStatCompoundAssign* c);
void visit(DfgScope* scope, AstStatFunction* f);
void visit(DfgScope* scope, AstStatLocalFunction* l);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i);
// TODO: visitLValue
// TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope)
// TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope)
};
} // namespace Luau
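// Editor's note: an illustrative sketch, not part of this commit. `root` (an AstStatBlock*),
// `ice` (an InternalErrorReporter), and `someLocalExpr` (an AstExprLocal*) are assumed to
// come from the surrounding frontend code.
DataFlowGraph dfg = DataFlowGraphBuilder::build(root, NotNull{&ice});
if (std::optional<DefId> def = dfg.getDef(someLocalExpr))
{
    // The ConstraintGraphBuilder can key refinements on *def.
}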

View file

@@ -0,0 +1,78 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/NotNull.h"
#include "Luau/TypedAllocator.h"
#include "Luau/Variant.h"
namespace Luau
{
using Def = Variant<struct Undefined, struct Phi>;
/**
* We statically approximate a value at runtime using a symbolic value, which we call a Def.
*
* DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that
* can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program.
*
* It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate.
*/
using DefId = NotNull<const Def>;
/**
* A "single-object" value.
*
* Leaky implementation note: sometimes "multiple-object" values, but none of which were interesting enough to warrant creating a phi node instead.
* That can happen because there's no point in creating a phi node that points to either resultant in `if math.random() > 0.5 then 5 else "hello"`.
* This might become of utmost importance if we wanted to do some backward reasoning, e.g. if `5` is taken, then `cond` must be `truthy`.
*/
struct Undefined
{
};
/**
* A phi node is a union of defs.
*
* We need this because we're statically evaluating a program, and sometimes a place may be assigned with
* different defs, and when that happens, we need a special data type that merges in all the defs
* that will flow into that specific place. For example, consider this simple program:
*
* ```
* x-1
* if cond() then
* x-2 = 5
* else
* x-3 = "hello"
* end
* x-4 : {x-2, x-3}
* ```
*
* At x-4, we know for a fact statically that either `5` or `"hello"` can flow into the variable `x` after the branch, but
* we cannot make any definitive decisions about which one, so we just take in both.
*/
struct Phi
{
std::vector<DefId> operands;
};
template<typename T>
T* getMutable(DefId def)
{
return get_if<T>(def.get());
}
template<typename T>
const T* get(DefId def)
{
return getMutable<T>(def);
}
struct DefArena
{
TypedAllocator<Def> allocator;
DefId freshDef();
};
} // namespace Luau
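// Editor's note: an illustrative sketch, not part of this commit, exercising only the API
// declared above. It mirrors the x-4 join-point example from the Phi comment.
DefArena arena;
DefId d1 = arena.freshDef();
DefId d2 = arena.freshDef();
// A join point unions the defs that can reach it:
const Def joined = Phi{{d1, d2}};
const Phi* phi = get<Phi>(NotNull{&joined}); // phi->operands holds d1 and d2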

View file

@@ -7,6 +7,8 @@
#include "Luau/Variant.h"
#include "Luau/TypeArena.h"
+LUAU_FASTFLAG(LuauIceExceptionInheritanceChange)
namespace Luau
{
struct TypeError;
@@ -302,12 +304,20 @@ struct NormalizationTooComplex
}
};
struct TypePackMismatch
{
TypePackId wantedTp;
TypePackId givenTp;
bool operator==(const TypePackMismatch& rhs) const;
};
using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods,
DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire,
IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError,
CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning,
DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty,
-TypesAreUnrelated, NormalizationTooComplex>;
+TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch>;
struct TypeError
{
@@ -374,6 +384,10 @@ struct InternalErrorReporter
class InternalCompilerError : public std::exception
{
public:
explicit InternalCompilerError(const std::string& message)
: message(message)
{
}
explicit InternalCompilerError(const std::string& message, const std::string& moduleName)
: message(message)
, moduleName(moduleName)
@@ -388,8 +402,14 @@ public:
virtual const char* what() const throw();
const std::string message;
-const std::string moduleName;
+const std::optional<std::string> moduleName;
const std::optional<Location> location;
};
// These two function overloads only exist to facilitate fast flagging a change to InternalCompilerError
// Both functions can be removed when FFlagLuauIceExceptionInheritanceChange is removed and calling code
// can directly throw InternalCompilerError.
[[noreturn]] void throwRuntimeError(const std::string& message);
[[noreturn]] void throwRuntimeError(const std::string& message, const std::string& moduleName);
} // namespace Luau

View file

@@ -14,6 +14,8 @@ struct TypeVar;
using TypeId = const TypeVar*;
struct Field;
+// Deprecated. Do not use in new work.
using LValue = Variant<Symbol, Field>;
struct Field

View file

@@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include <unordered_map>
namespace Luau
{
static const std::unordered_map<AstExprBinary::Op, const char*> kBinaryOpMetamethods{
{AstExprBinary::Op::CompareEq, "__eq"},
{AstExprBinary::Op::CompareNe, "__eq"},
{AstExprBinary::Op::CompareGe, "__lt"},
{AstExprBinary::Op::CompareGt, "__le"},
{AstExprBinary::Op::CompareLe, "__le"},
{AstExprBinary::Op::CompareLt, "__lt"},
{AstExprBinary::Op::Add, "__add"},
{AstExprBinary::Op::Sub, "__sub"},
{AstExprBinary::Op::Mul, "__mul"},
{AstExprBinary::Op::Div, "__div"},
{AstExprBinary::Op::Pow, "__pow"},
{AstExprBinary::Op::Mod, "__mod"},
{AstExprBinary::Op::Concat, "__concat"},
};
static const std::unordered_map<AstExprUnary::Op, const char*> kUnaryOpMetamethods{
{AstExprUnary::Op::Minus, "__unm"},
{AstExprUnary::Op::Len, "__len"},
};
} // namespace Luau
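// Editor's note: an illustrative sketch, not part of this commit, showing how the tables
// above might be consulted when typechecking an operator application.
const char* binaryMetamethodName(AstExprBinary::Op op)
{
    auto it = kBinaryOpMetamethods.find(op);
    return it != kBinaryOpMetamethods.end() ? it->second : nullptr;
}
// Usage: binaryMetamethodName(AstExprBinary::Op::Concat) yields "__concat";
// an operator with no metamethod entry yields nullptr.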

View file

@@ -17,19 +17,8 @@ struct SingletonTypes;
using ModulePtr = std::shared_ptr<Module>;
-bool isSubtype(
-TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true);
-bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice,
-bool anyIsTop = true);
-std::pair<TypeId, bool> normalize(
-TypeId ty, NotNull<Scope> scope, TypeArena& arena, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
-std::pair<TypeId, bool> normalize(TypeId ty, NotNull<Module> module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
-std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
-std::pair<TypePackId, bool> normalize(
-TypePackId ty, NotNull<Scope> scope, TypeArena& arena, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
-std::pair<TypePackId, bool> normalize(TypePackId ty, NotNull<Module> module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
-std::pair<TypePackId, bool> normalize(TypePackId ty, const ModulePtr& module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
+bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
+bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
class TypeIds
{
@@ -115,16 +104,89 @@
namespace Luau
{
-// A normalized string type is either `string` (represented by `nullopt`)
-// or a union of string singletons.
-using NormalizedStringType = std::optional<std::map<std::string, TypeId>>;
+/** A normalized string type is either `string` (represented by `nullopt`) or a
+* union of string singletons.
+*
* When FFlagLuauNegatedStringSingletons is unset, the representation is as
* follows:
*
* * The `string` data type is represented by the option `singletons` having the
* value `std::nullopt`.
* * The type `never` is represented by `singletons` being populated with an
* empty map.
* * A union of string singletons is represented by a map populated by the names
* and TypeIds of the singletons contained therein.
*
* When FFlagLuauNegatedStringSingletons is set, the representation is as
* follows:
*
* * A union of string singletons is finite and includes the singletons named by
* the `singletons` field.
* * An intersection of negated string singletons is cofinite and includes the
* singletons excluded by the `singletons` field. It is implied that cofinite
* values are exclusions from `string` itself.
* * The `string` data type is a cofinite set minus zero elements.
* * The `never` data type is a finite set plus zero elements.
*/
struct NormalizedStringType
{
// When false, this type represents a union of singleton string types.
// eg "a" | "b" | "c"
//
// When true, this type represents string intersected with negated string
// singleton types.
// eg string & ~"a" & ~"b" & ...
bool isCofinite = false;
// TODO: This field cannot be nullopt when FFlagLuauNegatedStringSingletons
// is set. When clipping that flag, we can remove the wrapping optional.
std::optional<std::map<std::string, TypeId>> singletons;
void resetToString();
void resetToNever();
bool isNever() const;
bool isString() const;
/// Returns true if the string has finite domain.
///
/// Important subtlety: This method returns true for `never`. The empty set
/// is indeed an empty set.
bool isUnion() const;
/// Returns true if the string has infinite domain.
bool isIntersection() const;
bool includes(const std::string& str) const;
static const NormalizedStringType never;
NormalizedStringType() = default;
NormalizedStringType(bool isCofinite, std::optional<std::map<std::string, TypeId>> singletons);
};
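// Editor's note: an illustrative reading of the representation described above, not part
// of this commit (shown with FFlagLuauNegatedStringSingletons enabled):
//   "a" | "b"             ->  isCofinite = false, singletons = { "a", "b" }   (finite union)
//   string & ~"a" & ~"b"  ->  isCofinite = true,  singletons = { "a", "b" }   (cofinite intersection)
//   string                ->  isCofinite = true,  singletons = {}             (cofinite minus zero elements)
//   never                 ->  isCofinite = false, singletons = {}             (finite plus zero elements)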
bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr);
-// A normalized function type is either `never` (represented by `nullopt`)
-// or an intersection of function types.
-// NOTE: type normalization can fail on function types with generics
-// (e.g. because we do not support unions and intersections of generic type packs),
-// so this type may contain `error`.
-using NormalizedFunctionType = std::optional<TypeIds>;
+// A normalized function type can be `never`, the top function type `function`,
+// or an intersection of function types.
+//
+// NOTE: type normalization can fail on function types with generics (e.g.
+// because we do not support unions and intersections of generic type packs), so
+// this type may contain `error`.
struct NormalizedFunctionType
{
NormalizedFunctionType();
bool isTop = false;
// TODO: Remove this wrapping optional when clipping
// FFlagLuauNegatedFunctionTypes.
std::optional<TypeIds> parts;
void resetToNever();
void resetToTop();
bool isNever() const;
};
// A normalized generic/free type is a union, where each option is of the form (X & T) where
// * X is either a free type or a generic
@@ -166,7 +228,7 @@ struct NormalizedType
// The string part of the type.
// This may be the `string` type, or a union of singletons.
-NormalizedStringType strings = std::map<std::string, TypeId>{};
+NormalizedStringType strings;
// The thread part of the type.
// This type is either never or thread.
@@ -184,12 +246,14 @@ struct NormalizedType
NormalizedType(NotNull<SingletonTypes> singletonTypes);
-NormalizedType(const NormalizedType&) = delete;
-NormalizedType(NormalizedType&&) = default;
NormalizedType() = delete;
~NormalizedType() = default;
+NormalizedType(const NormalizedType&) = delete;
+NormalizedType& operator=(const NormalizedType&) = delete;
+NormalizedType(NormalizedType&&) = default;
NormalizedType& operator=(NormalizedType&&) = default;
-NormalizedType& operator=(NormalizedType&) = delete;
};
class Normalizer
@@ -240,8 +304,14 @@ public:
bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1);
// ------- Negations
std::optional<NormalizedType> negateNormal(const NormalizedType& here);
TypeIds negateAll(const TypeIds& theres);
TypeId negate(TypeId there);
void subtractPrimitive(NormalizedType& here, TypeId ty);
void subtractSingleton(NormalizedType& here, TypeId ty);
// ------- Normalizing intersections
+void intersectTysWithTy(TypeIds& here, TypeId there);
TypeId intersectionOfTops(TypeId here, TypeId there);
TypeId intersectionOfBools(TypeId here, TypeId there);
void intersectClasses(TypeIds& heres, const TypeIds& theres);

View file

@@ -2,6 +2,7 @@
#pragma once
#include "Luau/Common.h"
+#include "Luau/Error.h"
#include <stdexcept>
#include <exception>
@@ -9,10 +10,20 @@
namespace Luau
{
-struct RecursionLimitException : public std::exception
+struct RecursionLimitException : public InternalCompilerError
{
RecursionLimitException()
: InternalCompilerError("Internal recursion counter limit exceeded")
{
LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange);
}
};
struct RecursionLimitException_DEPRECATED : public std::exception
{
const char* what() const noexcept
{
+LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange);
return "Internal recursion counter limit exceeded";
}
};
@@ -42,7 +53,14 @@ struct RecursionLimiter : RecursionCounter
{
if (limit > 0 && *count > limit)
{
-throw RecursionLimitException();
+if (FFlag::LuauIceExceptionInheritanceChange)
{
throw RecursionLimitException();
}
else
{
throw RecursionLimitException_DEPRECATED();
}
}
}
};

View file

@@ -54,7 +54,9 @@ struct Scope
DenseHashSet<Name> builtinTypeNames{""};
void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun);
-std::optional<TypeId> lookup(Symbol sym);
+std::optional<TypeId> lookup(Symbol sym) const;
+std::optional<TypeId> lookup(DefId def) const;
+std::optional<std::pair<TypeId, Scope*>> lookupEx(Symbol sym);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
@@ -66,6 +68,7 @@ struct Scope
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const;
RefinementMap refinements;
+DenseHashMap<const Def*, TypeId> dcrRefinements{nullptr};
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.

View file

@@ -6,10 +6,11 @@
#include <string>
+LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
{
// TODO Rename this to Name once the old type alias is gone.
struct Symbol
{
Symbol()
@@ -40,9 +41,12 @@ struct Symbol
{
if (local)
return local == rhs.local;
-if (global.value)
+else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
-return false;
+else if (FFlag::DebugLuauDeferredConstraintResolution)
+return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
+else
+return false;
}
bool operator!=(const Symbol& rhs) const
@@ -58,8 +62,8 @@ struct Symbol
return global < rhs.global;
else if (local)
return true;
-else
return false;
}
AstName astName() const

View file

@@ -117,6 +117,8 @@ inline std::string toStringNamedFunction(const std::string& funcName, const Func
return toStringNamedFunction(funcName, ftv, opts);
}
+std::optional<std::string> getFunctionNameAsString(const AstExpr& expr);
// It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class
// These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression
std::string dump(TypeId ty);

View file

@@ -48,7 +48,17 @@ struct HashBoolNamePair
size_t operator()(const std::pair<bool, Name>& pair) const;
};
-class TimeLimitError : public std::exception
+class TimeLimitError : public InternalCompilerError
{
public:
explicit TimeLimitError(const std::string& moduleName)
: InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName)
{
LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange);
}
};
class TimeLimitError_DEPRECATED : public std::exception
{
public:
virtual const char* what() const throw();
@@ -192,18 +202,12 @@ struct TypeChecker
ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location);
-void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors);
std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
std::optional<TypeId> getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
-// Reduces the union to its simplest possible shape.
-// (A | B) | B | C yields A | B | C
-std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty);
TypeId stripFromNilAndReport(TypeId ty, const Location& location);
@@ -242,6 +246,7 @@ public:
[[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message);
+[[noreturn]] void throwTimeLimitError();
ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0);
ScopePtr childScope(const ScopePtr& parent, const Location& location);

View file

@@ -29,4 +29,23 @@ std::pair<size_t, std::optional<size_t>> getParameterExtents(const TxnLog* log,
// various other things to get there.
std::vector<TypeId> flatten(TypeArena& arena, NotNull<SingletonTypes> singletonTypes, TypePackId pack, size_t length);
/**
* Reduces a union by decomposing to the any/error type if it appears in the
* type list, and by merging child unions. Also strips out duplicate (by pointer
* identity) types.
* @param types the input type list to reduce.
* @returns the reduced type list.
*/
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
/**
* Tries to remove nil from a union type, if there's another option. T | nil
* reduces to T, but nil itself does not reduce.
* @param singletonTypes the singleton types to use
* @param arena the type arena to allocate the new type in, if necessary
* @param ty the type to remove nil from
* @returns a type with nil removed, or nil itself if that were the only option.
*/
TypeId stripNil(NotNull<SingletonTypes> singletonTypes, TypeArena& arena, TypeId ty);
} // namespace Luau
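// Editor's note: an illustrative sketch, not part of this commit; `singletonTypes` (a
// NotNull<SingletonTypes>), `arena` (a TypeArena), and `optionalTy` are assumed.
std::vector<TypeId> parts{singletonTypes->numberType, singletonTypes->stringType, singletonTypes->numberType};
std::vector<TypeId> reduced = reduceUnion(parts); // the duplicate `number` collapses away
TypeId stripped = stripNil(singletonTypes, arena, optionalTy); // T? becomes T; plain nil stays nil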

View file

@@ -2,22 +2,23 @@
#pragma once
#include "Luau/Ast.h"
+#include "Luau/Common.h"
#include "Luau/DenseHash.h"
+#include "Luau/Def.h"
+#include "Luau/NotNull.h"
#include "Luau/Predicate.h"
#include "Luau/Unifiable.h"
#include "Luau/Variant.h"
-#include "Luau/Common.h"
-#include "Luau/NotNull.h"
-#include <set>
-#include <string>
-#include <map>
-#include <unordered_set>
-#include <unordered_map>
-#include <vector>
#include <deque>
+#include <map>
#include <memory>
#include <optional>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
LUAU_FASTINT(LuauTypeMaximumStringifierLength)
@@ -114,6 +115,7 @@ struct PrimitiveTypeVar
Number,
String,
Thread,
+Function,
};
Type type;
@@ -131,24 +133,6 @@ struct PrimitiveTypeVar
}
};
-struct ConstrainedTypeVar
-{
-explicit ConstrainedTypeVar(TypeLevel level)
-: level(level)
-{
-}
-explicit ConstrainedTypeVar(TypeLevel level, const std::vector<TypeId>& parts)
-: parts(parts)
-, level(level)
-{
-}
-std::vector<TypeId> parts;
-TypeLevel level;
-Scope* scope = nullptr;
-};
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BooleanSingleton
@@ -496,11 +480,13 @@ struct AnyTypeVar
{
};
+// T | U
struct UnionTypeVar
{
std::vector<TypeId> options;
};
+// T & U
struct IntersectionTypeVar
{
std::vector<TypeId> parts;
@@ -519,12 +505,19 @@ struct NeverTypeVar
{
};
// ~T
// TODO: Some simplification step that overwrites the type graph to make sure negation
// types disappear from the user's view, and (?) a debug flag to disable that
struct NegationTypeVar
{
TypeId ty;
};
using ErrorTypeVar = Unifiable::Error;
using TypeVariant =
-Unifiable::Variant<TypeId, PrimitiveTypeVar, ConstrainedTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar,
-TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar>;
+Unifiable::Variant<TypeId, PrimitiveTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar,
+MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar, NegationTypeVar>;
struct TypeVar final
{
@@ -541,7 +534,6 @@ struct TypeVar final
TypeVar(const TypeVariant& ty, bool persistent)
: ty(ty)
, persistent(persistent)
-, normal(persistent) // We assume that all persistent types are irreducable.
{
}
@@ -549,7 +541,6 @@ struct TypeVar final
void reassign(const TypeVar& rhs)
{
ty = rhs.ty;
-normal = rhs.normal;
documentationSymbol = rhs.documentationSymbol;
}
@@ -560,10 +551,6 @@ struct TypeVar final
// Persistent TypeVars do not get cloned.
bool persistent = false;
-// Normalization sets this for types that are fully normalized.
-// This implies that they are transitively immutable.
-bool normal = false;
std::optional<std::string> documentationSymbol;
// Pointer to the type arena that allocated this type.
@@ -650,12 +637,15 @@ public:
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
+const TypeId functionType;
const TypeId trueType;
const TypeId falseType;
const TypeId anyType;
const TypeId unknownType;
const TypeId neverType;
const TypeId errorType;
+const TypeId falsyType; // No type binding!
+const TypeId truthyType; // No type binding!
const TypePackId anyTypePack;
const TypePackId neverTypePack;
@@ -703,7 +693,6 @@ T* getMutable(TypeId tv)
const std::vector<TypeId>& getTypes(const UnionTypeVar* utv);
const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv);
-const std::vector<TypeId>& getTypes(const ConstrainedTypeVar* ctv);
template<typename T>
struct TypeIterator;
@@ -716,10 +705,6 @@ using IntersectionTypeVarIterator = TypeIterator<IntersectionTypeVar>;
IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv);
IntersectionTypeVarIterator end(const IntersectionTypeVar* itv);
-using ConstrainedTypeVarIterator = TypeIterator<ConstrainedTypeVar>;
-ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv);
-ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv);
/* Traverses the type T yielding each TypeId.
* If the iterator encounters a nested type T, it will instead yield each TypeId within.
*/
@@ -793,7 +778,6 @@ struct TypeIterator
// with templates portability in this area, so not worth it. Thanks MSVC.
friend UnionTypeVarIterator end(const UnionTypeVar*);
friend IntersectionTypeVarIterator end(const IntersectionTypeVar*);
-friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*);
private:
TypeIterator() = default;

View file

@@ -61,7 +61,6 @@ struct Unifier
ErrorVec errors;
Location location;
Variance variance = Covariant;
-bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once.
bool normalize; // Normalize unions and intersections if necessary
bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels
CountMismatch::Context ctx = CountMismatch::Arg;
@@ -96,6 +95,8 @@ private:
void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed);
+void tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy);
+void tryUnifyNegationWithType(TypeId subTy, TypeId superTy);
TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args);
@@ -119,12 +120,7 @@ private:
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
-void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy);
-void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy);
public:
-void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel);
// Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error"
bool occursCheck(TypeId needle, TypeId haystack);
bool occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
@@ -134,6 +130,7 @@ public:
Unifier makeChildUnifier();
void reportError(TypeError err);
+LUAU_NOINLINE void reportError(Location location, TypeErrorData data);
private:
bool isNonstrictMode() const;

View file

@@ -58,13 +58,15 @@ public:
constexpr int tid = getTypeId<T>();
typeId = tid;
-new (&storage) TT(value);
+new (&storage) TT(std::forward<T>(value));
}
Variant(const Variant& other)
{
+static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy<Ts>...};
typeId = other.typeId;
-tableCopy[typeId](&storage, &other.storage);
+table[typeId](&storage, &other.storage);
}
Variant(Variant&& other)
@@ -105,7 +107,7 @@ public:
tableDtor[typeId](&storage);
typeId = tid;
-new (&storage) TT(std::forward<Args>(args)...);
+new (&storage) TT{std::forward<Args>(args)...};
return *reinterpret_cast<T*>(&storage);
}
@@ -192,7 +194,6 @@ private:
return *static_cast<const T*>(lhs) == *static_cast<const T*>(rhs);
}
-static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};

View file

@@ -103,10 +103,6 @@ struct GenericTypeVarVisitor
{
return visit(ty);
}
-virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv)
-{
-return visit(ty);
-}
virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv)
{
return visit(ty);
@@ -159,6 +155,10 @@ struct GenericTypeVarVisitor
{
return visit(ty);
}
+virtual bool visit(TypeId ty, const NegationTypeVar& ntv)
+{
+return visit(ty);
+}
virtual bool visit(TypePackId tp)
{
@@ -216,14 +216,6 @@ struct GenericTypeVarVisitor
visit(ty, *gtv);
else if (auto etv = get<ErrorTypeVar>(ty))
visit(ty, *etv);
-else if (auto ctv = get<ConstrainedTypeVar>(ty))
-{
-if (visit(ty, *ctv))
-{
-for (TypeId part : ctv->parts)
-traverse(part);
-}
-}
else if (auto ptv = get<PrimitiveTypeVar>(ty))
visit(ty, *ptv);
else if (auto ftv = get<FunctionTypeVar>(ty))
@@ -325,6 +317,8 @@ struct GenericTypeVarVisitor
traverse(a);
}
}
+else if (auto ntv = get<NegationTypeVar>(ty))
+visit(ty, *ntv);
else if (!FFlag::LuauCompleteVisitor)
return visit_detail::unsee(seen, ty);
else

View file

@@ -37,8 +37,6 @@ bool Anyification::isDirty(TypeId ty)
return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed);
else if (log->getMutable<FreeTypeVar>(ty))
return true;
-else if (get<ConstrainedTypeVar>(ty))
-return true;
else
return false;
}
@@ -65,20 +63,8 @@ TypeId Anyification::clean(TypeId ty)
clone.syntheticName = ttv->syntheticName;
clone.tags = ttv->tags;
TypeId res = addType(std::move(clone));
-asMutable(res)->normal = ty->normal;
return res;
}
-else if (auto ctv = get<ConstrainedTypeVar>(ty))
-{
-std::vector<TypeId> copy = ctv->parts;
-for (TypeId& ty : copy)
-ty = replace(ty);
-TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)});
-auto [t, ok] = normalize(res, scope, *arena, singletonTypes, *iceHandler);
-if (!ok)
-normalizationTooComplex = true;
-return t;
-}
else
return anyType;
}


@ -11,6 +11,8 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAGVARIABLE(LuauCheckOverloadedDocSymbol, false)
namespace Luau namespace Luau
{ {
@ -427,6 +429,38 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos)
return findVisitor.result; return findVisitor.result;
} }
static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional<DocumentationSymbol> documentationSymbol)
{
LUAU_ASSERT(FFlag::LuauCheckOverloadedDocSymbol);
if (!documentationSymbol)
return std::nullopt;
// This might be an overloaded function.
if (get<IntersectionTypeVar>(follow(ty)))
{
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = *it;
}
}
if (matchingOverload)
{
std::string overloadSymbol = *documentationSymbol + "/overload/";
// Default toString options are fine for this purpose.
overloadSymbol += toString(matchingOverload);
return overloadSymbol;
}
}
return documentationSymbol;
}
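checkOverloadedDocumentationSymbol only rewrites the symbol when the binding is an intersection (an overloaded function) and the call site resolved to a concrete overload. The standalone sketch below shows the resulting symbol shape; the base symbol and overload string are made-up examples, only the "<base>/overload/<overload type>" layout comes from the code above.

// --- illustrative standalone sketch, not part of the diff ---
#include <iostream>
#include <optional>
#include <string>

// Shape of the symbol produced for a resolved overload: "<base>/overload/<type>".
std::optional<std::string> overloadSymbol(
    const std::optional<std::string>& base, const std::optional<std::string>& resolvedOverload)
{
    if (!base)
        return std::nullopt;
    if (!resolvedOverload)
        return base; // not a call site, or the overload was not resolved
    return *base + "/overload/" + *resolvedOverload;
}

int main()
{
    // Both strings below are hypothetical examples.
    std::cout << *overloadSymbol(std::string("@test/global/insert"), std::string("<V>({V}, V) -> ()")) << "\n";
    // @test/global/insert/overload/<V>({V}, V) -> ()
}
// --- end of sketch ---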
std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position)
{ {
std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position); std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position);
@ -436,31 +470,38 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
    if (std::optional<Binding> binding = findBindingAtPosition(module, source, position))
    {
-        if (binding->documentationSymbol)
-        {
-            // This might be an overloaded function binding.
-            if (get<IntersectionTypeVar>(follow(binding->typeId)))
-            {
-                TypeId matchingOverload = nullptr;
-                if (parentExpr && parentExpr->is<AstExprCall>())
-                {
-                    if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
-                    {
-                        matchingOverload = *it;
-                    }
-                }
-
-                if (matchingOverload)
-                {
-                    std::string overloadSymbol = *binding->documentationSymbol + "/overload/";
-                    // Default toString options are fine for this purpose.
-                    overloadSymbol += toString(matchingOverload);
-                    return overloadSymbol;
-                }
-            }
-        }
-
-        return binding->documentationSymbol;
+        if (FFlag::LuauCheckOverloadedDocSymbol)
+        {
+            return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol);
+        }
+        else
+        {
+            if (binding->documentationSymbol)
+            {
+                // This might be an overloaded function binding.
+                if (get<IntersectionTypeVar>(follow(binding->typeId)))
+                {
+                    TypeId matchingOverload = nullptr;
+                    if (parentExpr && parentExpr->is<AstExprCall>())
+                    {
+                        if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
+                        {
+                            matchingOverload = *it;
+                        }
+                    }
+
+                    if (matchingOverload)
+                    {
+                        std::string overloadSymbol = *binding->documentationSymbol + "/overload/";
+                        // Default toString options are fine for this purpose.
+                        overloadSymbol += toString(matchingOverload);
+                        return overloadSymbol;
+                    }
+                }
+            }
+
+            return binding->documentationSymbol;
+        }
    }
if (targetExpr) if (targetExpr)
@ -474,14 +515,20 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
{ {
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())
{ {
return propIt->second.documentationSymbol; if (FFlag::LuauCheckOverloadedDocSymbol)
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
else
return propIt->second.documentationSymbol;
} }
} }
else if (const ClassTypeVar* ctv = get<ClassTypeVar>(parentTy)) else if (const ClassTypeVar* ctv = get<ClassTypeVar>(parentTy))
{ {
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
{ {
return propIt->second.documentationSymbol; if (FFlag::LuauCheckOverloadedDocSymbol)
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
else
return propIt->second.documentationSymbol;
} }
} }
} }


@ -10,6 +10,7 @@
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/TypeUtils.h"
#include <algorithm> #include <algorithm>
@ -41,6 +42,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionSelect(MagicFunctionCallContext context);
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context);
static bool dcrMagicFunctionPack(MagicFunctionCallContext context);
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types) TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{ {
@ -333,6 +335,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker)
ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone");
attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); attachMagicFunction(ttv->props["pack"].type, magicFunctionPack);
attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack);
} }
attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire);
@ -660,7 +663,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
options.push_back(vtp->ty); options.push_back(vtp->ty);
} }
options = typechecker.reduceUnion(options); options = reduceUnion(options);
// table.pack() -> {| n: number, [number]: nil |} // table.pack() -> {| n: number, [number]: nil |}
// table.pack(1) -> {| n: number, [number]: number |} // table.pack(1) -> {| n: number, [number]: number |}
@ -679,6 +682,46 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
return WithPredicate<TypePackId>{arena.addTypePack({packedTable})}; return WithPredicate<TypePackId>{arena.addTypePack({packedTable})};
} }
static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
{
TypeArena* arena = context.solver->arena;
const auto& [paramTypes, paramTail] = flatten(context.arguments);
std::vector<TypeId> options;
options.reserve(paramTypes.size());
for (auto type : paramTypes)
options.push_back(type);
if (paramTail)
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(*paramTail))
options.push_back(vtp->ty);
}
options = reduceUnion(options);
// table.pack() -> {| n: number, [number]: nil |}
// table.pack(1) -> {| n: number, [number]: number |}
// table.pack(1, "foo") -> {| n: number, [number]: number | string |}
TypeId result = nullptr;
if (options.empty())
result = context.solver->singletonTypes->nilType;
else if (options.size() == 1)
result = options[0];
else
result = arena->addType(UnionTypeVar{std::move(options)});
TypeId numberType = context.solver->singletonTypes->numberType;
TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed});
TypePackId tableTypePack = arena->addTypePack({packedTable});
asMutable(context.result)->ty.emplace<BoundTypePack>(tableTypePack);
return true;
}
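dcrMagicFunctionPack mirrors the non-DCR magic function: the indexer value type of the packed table is the reduced union of the argument types, or nil when no arguments were passed. The standalone sketch below reproduces that mapping at the string level; the crude dedupe is only a stand-in for reduceUnion and is an assumption, not the real implementation.

// --- illustrative standalone sketch, not part of the diff ---
#include <iostream>
#include <set>
#include <string>
#include <vector>

std::string packedTableType(const std::vector<std::string>& argTypes)
{
    std::set<std::string> reduced(argTypes.begin(), argTypes.end()); // crude "reduceUnion": dedupe only
    std::string value;
    if (reduced.empty())
        value = "nil";
    else
    {
        for (const std::string& t : reduced)
            value += (value.empty() ? "" : " | ") + t;
    }
    return "{| n: number, [number]: " + value + " |}";
}

int main()
{
    std::cout << packedTableType({}) << "\n";                   // {| n: number, [number]: nil |}
    std::cout << packedTableType({"number"}) << "\n";           // {| n: number, [number]: number |}
    std::cout << packedTableType({"number", "string"}) << "\n"; // {| n: number, [number]: number | string |}
}
// --- end of sketch ---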
static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
{ {
// require(foo.parent.bar) will technically work, but it depends on legacy goop that // require(foo.parent.bar) will technically work, but it depends on legacy goop that


@ -1,6 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
@ -51,7 +51,6 @@ struct TypeCloner
void operator()(const BlockedTypeVar& t); void operator()(const BlockedTypeVar& t);
void operator()(const PendingExpansionTypeVar& t); void operator()(const PendingExpansionTypeVar& t);
void operator()(const PrimitiveTypeVar& t); void operator()(const PrimitiveTypeVar& t);
void operator()(const ConstrainedTypeVar& t);
void operator()(const SingletonTypeVar& t); void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t); void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t); void operator()(const TableTypeVar& t);
@ -63,6 +62,7 @@ struct TypeCloner
void operator()(const LazyTypeVar& t); void operator()(const LazyTypeVar& t);
void operator()(const UnknownTypeVar& t); void operator()(const UnknownTypeVar& t);
void operator()(const NeverTypeVar& t); void operator()(const NeverTypeVar& t);
void operator()(const NegationTypeVar& t);
}; };
struct TypePackCloner struct TypePackCloner
@ -198,21 +198,6 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t)
defaultClone(t); defaultClone(t);
} }
void TypeCloner::operator()(const ConstrainedTypeVar& t)
{
TypeId res = dest.addType(ConstrainedTypeVar{t.level});
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(res);
LUAU_ASSERT(ctv);
seenTypes[typeId] = res;
std::vector<TypeId> parts;
for (TypeId part : t.parts)
parts.push_back(clone(part, dest, cloneState));
ctv->parts = std::move(parts);
}
void TypeCloner::operator()(const SingletonTypeVar& t) void TypeCloner::operator()(const SingletonTypeVar& t)
{ {
defaultClone(t); defaultClone(t);
@ -352,6 +337,15 @@ void TypeCloner::operator()(const NeverTypeVar& t)
defaultClone(t); defaultClone(t);
} }
void TypeCloner::operator()(const NegationTypeVar& t)
{
TypeId result = dest.addType(AnyTypeVar{});
seenTypes[typeId] = result;
TypeId ty = clone(t.ty, dest, cloneState);
asMutable(result)->ty = NegationTypeVar{ty};
}
} // anonymous namespace } // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
@ -390,7 +384,6 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
if (!res->persistent) if (!res->persistent)
{ {
asMutable(res)->documentationSymbol = typeId->documentationSymbol; asMutable(res)->documentationSymbol = typeId->documentationSymbol;
asMutable(res)->normal = typeId->normal;
} }
} }
@ -478,11 +471,6 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl
clone.parts = itv->parts; clone.parts = itv->parts;
result = dest.addType(std::move(clone)); result = dest.addType(std::move(clone));
} }
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
ConstrainedTypeVar clone{ctv->level, ctv->parts};
result = dest.addType(std::move(clone));
}
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty)) else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{ {
PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments};
@ -497,6 +485,10 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl
{ {
result = dest.addType(*ty); result = dest.addType(*ty);
} }
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(ty))
{
result = dest.addType(NegationTypeVar{ntv->ty});
}
else else
return result; return result;
@ -504,4 +496,9 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl
return result; return result;
} }
TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest)
{
return shallowClone(ty, *dest, TxnLog::empty());
}
} // namespace Luau } // namespace Luau


@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Connective.h"
namespace Luau
{
ConnectiveId ConnectiveArena::negation(ConnectiveId connective)
{
return NotNull{allocator.allocate(Negation{connective})};
}
ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Conjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Disjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Equivalence{lhs, rhs})};
}
ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy)
{
return NotNull{allocator.allocate(Proposition{def, discriminantTy})};
}
} // namespace Luau
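ConnectiveArena hands out NotNull pointers into a TypedAllocator-backed arena; each factory just allocates a differently shaped node that references previously built connectives. Below is a simplified standalone model of that ownership pattern, with std::unique_ptr storage standing in for TypedAllocator and raw pointers standing in for NotNull; all names are illustrative.

// --- illustrative standalone sketch, not part of the diff ---
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node
{
    std::string kind;
    const Node* lhs = nullptr;
    const Node* rhs = nullptr;
};

// Owns every node it creates and hands out stable, never-null pointers.
struct Arena
{
    std::vector<std::unique_ptr<Node>> storage;

    const Node* allocate(Node n)
    {
        storage.push_back(std::make_unique<Node>(std::move(n)));
        return storage.back().get();
    }

    const Node* negation(const Node* c) { return allocate({"negation", c, nullptr}); }
    const Node* conjunction(const Node* a, const Node* b) { return allocate({"conjunction", a, b}); }
};

int main()
{
    Arena arena;
    const Node* p = arena.allocate({"proposition"});
    const Node* notP = arena.negation(p);
    const Node* both = arena.conjunction(p, notP);
    assert(both->lhs == p && both->rhs == notP);
    std::cout << both->kind << "(" << both->lhs->kind << ", " << both->rhs->kind << ")\n";
}
// --- end of sketch ---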

File diff suppressed because it is too large


@ -3,14 +3,16 @@
#include "Luau/Anyification.h" #include "Luau/Anyification.h"
#include "Luau/ApplyTypeFunction.h" #include "Luau/ApplyTypeFunction.h"
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DcrLogger.h"
#include "Luau/Instantiation.h" #include "Luau/Instantiation.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Metamethods.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/Quantify.h" #include "Luau/Quantify.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeUtils.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/Unifier.h" #include "Luau/Unifier.h"
#include "Luau/DcrLogger.h"
#include "Luau/VisitTypeVar.h" #include "Luau/VisitTypeVar.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
@ -438,6 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*fcc, constraint); success = tryDispatch(*fcc, constraint);
else if (auto hpc = get<HasPropConstraint>(*constraint)) else if (auto hpc = get<HasPropConstraint>(*constraint))
success = tryDispatch(*hpc, constraint); success = tryDispatch(*hpc, constraint);
else if (auto sottc = get<SingletonOrTopTypeConstraint>(*constraint))
success = tryDispatch(*sottc, constraint);
else else
LUAU_ASSERT(false); LUAU_ASSERT(false);
@ -540,6 +544,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull<const Const
} }
case AstExprUnary::Len: case AstExprUnary::Len:
{ {
// __len must return a number.
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->numberType); asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->numberType);
return true; return true;
} }
@ -548,13 +553,46 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull<const Const
if (isNumber(operandType) || get<AnyTypeVar>(operandType) || get<ErrorTypeVar>(operandType)) if (isNumber(operandType) || get<AnyTypeVar>(operandType) || get<ErrorTypeVar>(operandType))
{ {
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(c.operandType); asMutable(c.resultType)->ty.emplace<BoundTypeVar>(c.operandType);
return true;
} }
break; else if (std::optional<TypeId> mm = findMetatableEntry(singletonTypes, errors, operandType, "__unm", constraint->location))
{
const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*mm));
if (!ftv)
{
if (std::optional<TypeId> callMm = findMetatableEntry(singletonTypes, errors, follow(*mm), "__call", constraint->location))
{
ftv = get<FunctionTypeVar>(follow(*callMm));
}
}
if (!ftv)
{
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->errorRecoveryType());
return true;
}
TypePackId argsPack = arena->addTypePack({operandType});
unify(ftv->argTypes, argsPack, constraint->scope);
TypeId result = singletonTypes->errorRecoveryType();
if (ftv)
{
result = first(ftv->retTypes).value_or(singletonTypes->errorRecoveryType());
}
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(result);
}
else
{
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->errorRecoveryType());
}
return true;
} }
} }
-    LUAU_ASSERT(false); // TODO metatable handling
+    LUAU_ASSERT(false);
return false; return false;
} }
@ -564,44 +602,192 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull<const Cons
TypeId rightType = follow(c.rightType); TypeId rightType = follow(c.rightType);
TypeId resultType = follow(c.resultType); TypeId resultType = follow(c.resultType);
-    if (isBlocked(leftType) || isBlocked(rightType))
-    {
-        /* Compound assignments create constraints of the form
-         *
-         * A <: Binary<op, A, B>
-         *
-         * This constraint is the one that is meant to unblock A, so it doesn't
-         * make any sense to stop and wait for someone else to do it.
-         */
-        if (leftType != resultType && rightType != resultType)
-        {
-            block(c.leftType, constraint);
-            block(c.rightType, constraint);
-            return false;
-        }
-    }
-
-    if (isNumber(leftType))
-    {
-        unify(leftType, rightType, constraint->scope);
-        asMutable(resultType)->ty.emplace<BoundTypeVar>(leftType);
-        return true;
-    }
+    bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or;
+
+    /* Compound assignments create constraints of the form
+     *
+     * A <: Binary<op, A, B>
+     *
+     * This constraint is the one that is meant to unblock A, so it doesn't
+     * make any sense to stop and wait for someone else to do it.
+     */
+    if (isBlocked(leftType) && leftType != resultType)
+        return block(c.leftType, constraint);
+
+    if (isBlocked(rightType) && rightType != resultType)
+        return block(c.rightType, constraint);

    if (!force)
    {
-        if (get<FreeTypeVar>(leftType))
+        // Logical expressions may proceed if the LHS is free.
+        if (get<FreeTypeVar>(leftType) && !isLogical)
            return block(leftType, constraint);
    }

-    if (isBlocked(leftType))
+    // Logical expressions may proceed if the LHS is free.
+    if (isBlocked(leftType) || (get<FreeTypeVar>(leftType) && !isLogical))
    {
        asMutable(resultType)->ty.emplace<BoundTypeVar>(errorRecoveryType());
-        // reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation});
+        unblock(resultType);
        return true;
    }

-    // TODO metatables, classes
+    // For or expressions, the LHS will never have nil as a possible output.
+    // Consider:
+    //     local foo = nil or 2
+    // `foo` will always be 2.
+    if (c.op == AstExprBinary::Op::Or)
+        leftType = stripNil(singletonTypes, *arena, leftType);
// Metatables go first, even if there is primitive behavior.
if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end())
{
// Metatables are not the same. The metamethod will not be invoked.
if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) &&
getMetatable(leftType, singletonTypes) != getMetatable(rightType, singletonTypes))
{
// TODO: Boolean singleton false? The result is _always_ boolean false.
asMutable(resultType)->ty.emplace<BoundTypeVar>(singletonTypes->booleanType);
unblock(resultType);
return true;
}
std::optional<TypeId> mm;
// The LHS metatable takes priority over the RHS metatable, where
// present.
if (std::optional<TypeId> leftMm = findMetatableEntry(singletonTypes, errors, leftType, it->second, constraint->location))
mm = leftMm;
else if (std::optional<TypeId> rightMm = findMetatableEntry(singletonTypes, errors, rightType, it->second, constraint->location))
mm = rightMm;
if (mm)
{
// TODO: Is a table with __call legal here?
// TODO: Overloads
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*mm)))
{
TypePackId inferredArgs;
// For >= and > we invoke __lt and __le respectively with
// swapped argument ordering.
if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt)
{
inferredArgs = arena->addTypePack({rightType, leftType});
}
else
{
inferredArgs = arena->addTypePack({leftType, rightType});
}
unify(inferredArgs, ftv->argTypes, constraint->scope);
TypeId mmResult;
// Comparison operations always evaluate to a boolean,
// regardless of what the metamethod returns.
switch (c.op)
{
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
mmResult = singletonTypes->booleanType;
break;
default:
mmResult = first(ftv->retTypes).value_or(errorRecoveryType());
}
asMutable(resultType)->ty.emplace<BoundTypeVar>(mmResult);
unblock(resultType);
return true;
}
}
// If there's no metamethod available, fall back to primitive behavior.
}
// If any is present, the expression must evaluate to any as well.
bool leftAny = get<AnyTypeVar>(leftType) || get<ErrorTypeVar>(leftType);
bool rightAny = get<AnyTypeVar>(rightType) || get<ErrorTypeVar>(rightType);
bool anyPresent = leftAny || rightAny;
switch (c.op)
{
// For arithmetic operators, if the LHS is a number, the RHS must be a
// number as well. The result will also be a number.
case AstExprBinary::Op::Add:
case AstExprBinary::Op::Sub:
case AstExprBinary::Op::Mul:
case AstExprBinary::Op::Div:
case AstExprBinary::Op::Pow:
case AstExprBinary::Op::Mod:
if (isNumber(leftType))
{
unify(leftType, rightType, constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(anyPresent ? singletonTypes->anyType : leftType);
unblock(resultType);
return true;
}
break;
// For concatenation, if the LHS is a string, the RHS must be a string as
// well. The result will also be a string.
case AstExprBinary::Op::Concat:
if (isString(leftType))
{
unify(leftType, rightType, constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(anyPresent ? singletonTypes->anyType : leftType);
unblock(resultType);
return true;
}
break;
// Inexact comparisons require that the types be both numbers or both
// strings, and evaluate to a boolean.
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
if ((isNumber(leftType) && isNumber(rightType)) || (isString(leftType) && isString(rightType)))
{
asMutable(resultType)->ty.emplace<BoundTypeVar>(singletonTypes->booleanType);
unblock(resultType);
return true;
}
break;
// == and ~= always evaluate to a boolean, and impose no other constraints
// on their parameters.
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
asMutable(resultType)->ty.emplace<BoundTypeVar>(singletonTypes->booleanType);
unblock(resultType);
return true;
// And evaluates to a boolean if the LHS is falsey, and the RHS type if LHS is
// truthy.
case AstExprBinary::Op::And:
asMutable(resultType)->ty.emplace<BoundTypeVar>(unionOfTypes(rightType, singletonTypes->booleanType, constraint->scope, false));
unblock(resultType);
return true;
// Or evaluates to the LHS type if the LHS is truthy, and the RHS type if
// LHS is falsey.
case AstExprBinary::Op::Or:
asMutable(resultType)->ty.emplace<BoundTypeVar>(unionOfTypes(rightType, leftType, constraint->scope, true));
unblock(resultType);
return true;
default:
iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location);
break;
}
// We failed to either evaluate a metamethod or invoke primitive behavior.
unify(leftType, errorRecoveryType(), constraint->scope);
unify(rightType, errorRecoveryType(), constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(errorRecoveryType());
unblock(resultType);
return true; return true;
} }
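Two details of the metamethod path above are easy to miss: the LHS metatable wins when both operands have one, and > / >= are answered by __lt / __le with the operands swapped. The standalone sketch below illustrates only the comparison dispatch; the string table is a stand-in for kBinaryOpMetamethods, not the real map.

// --- illustrative standalone sketch, not part of the diff ---
#include <iostream>
#include <string>

struct Dispatch
{
    std::string metamethod;
    bool swapOperands;
};

Dispatch dispatchComparison(const std::string& op)
{
    if (op == "<")
        return {"__lt", false};
    if (op == "<=")
        return {"__le", false};
    if (op == ">")
        return {"__lt", true}; // a > b   is answered by __lt(b, a)
    if (op == ">=")
        return {"__le", true}; // a >= b  is answered by __le(b, a)
    return {"", false};
}

int main()
{
    Dispatch d = dispatchComparison(">=");
    std::cout << d.metamethod << " swapped=" << d.swapOperands << "\n"; // __le swapped=1
}
// --- end of sketch ---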
@ -710,6 +896,10 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull<const Constr
ttv->name = c.name; ttv->name = c.name;
else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(target)) else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(target))
mtv->syntheticName = c.name; mtv->syntheticName = c.name;
else if (get<IntersectionTypeVar>(target) || get<UnionTypeVar>(target))
{
// nothing (yet)
}
else else
return block(c.namedType, constraint); return block(c.namedType, constraint);
@ -943,6 +1133,31 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
return block(c.fn, constraint); return block(c.fn, constraint);
} }
// We don't support magic __call metamethods.
if (std::optional<TypeId> callMm = findMetatableEntry(singletonTypes, errors, fn, "__call", constraint->location))
{
std::vector<TypeId> args{fn};
for (TypeId arg : c.argsPack)
args.push_back(arg);
TypeId instantiatedType = arena->addType(BlockedTypeVar{});
TypeId inferredFnType =
arena->addType(FunctionTypeVar(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result));
// Alter the inner constraints.
LUAU_ASSERT(c.innerConstraints.size() == 2);
asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm};
asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType};
unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints));
asMutable(c.result)->ty.emplace<FreeTypePack>(constraint->scope);
unblock(c.result);
return true;
}
const FunctionTypeVar* ftv = get<FunctionTypeVar>(fn); const FunctionTypeVar* ftv = get<FunctionTypeVar>(fn);
bool usedMagic = false; bool usedMagic = false;
@ -1059,6 +1274,22 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
return true; return true;
} }
bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint)
{
if (isBlocked(c.discriminantType))
return false;
TypeId followed = follow(c.discriminantType);
// `nil` is a singleton type too! There's only one value of type `nil`.
if (get<SingletonTypeVar>(followed) || isNil(followed))
*asMutable(c.resultType) = NegationTypeVar{c.discriminantType};
else
*asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType};
return true;
}
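The constraint only produces a negation when the discriminant is a singleton (including nil, which has exactly one value); anything wider collapses to unknown so the refinement stays sound. A toy string-level model of that decision, with made-up names:

// --- illustrative standalone sketch, not part of the diff ---
#include <iostream>
#include <string>

std::string refineAgainst(const std::string& discriminant, bool isSingleton)
{
    if (isSingleton)
        return "~" + discriminant; // NegationTypeVar in the real solver
    return "unknown";              // BoundTypeVar{unknownType}
}

int main()
{
    std::cout << refineAgainst("\"foo\"", true) << "\n"; // ~"foo"
    std::cout << refineAgainst("nil", true) << "\n";     // ~nil
    std::cout << refineAgainst("number", false) << "\n"; // unknown
}
// --- end of sketch ---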
bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force) bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
{ {
auto block_ = [&](auto&& t) { auto block_ = [&](auto&& t) {
@ -1502,4 +1733,39 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const
return singletonTypes->errorRecoveryTypePack(); return singletonTypes->errorRecoveryTypePack();
} }
TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes)
{
a = follow(a);
b = follow(b);
if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b)))
{
Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant};
u.useScopes = true;
u.tryUnify(b, a);
if (u.errors.empty())
{
u.log.commit();
return a;
}
else
{
return singletonTypes->errorRecoveryType(singletonTypes->anyType);
}
}
if (*a == *b)
return a;
std::vector<TypeId> types = reduceUnion({a, b});
if (types.empty())
return singletonTypes->neverType;
if (types.size() == 1)
return types[0];
return arena->addType(UnionTypeVar{types});
}
}

} // namespace Luau
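unionOfTypes is what gives and/or their result types above: `a and b` gets `boolean | typeof(b)`, while `a or b` unions a nil-stripped LHS with the RHS. The standalone sketch below states those two rules at the string level; the "?"-suffix handling is only a stand-in for stripNil.

// --- illustrative standalone sketch, not part of the diff ---
#include <iostream>
#include <string>

std::string typeOfAnd(const std::string& rhs)
{
    return "boolean | " + rhs; // falsey LHS yields a boolean, truthy LHS yields the RHS
}

std::string typeOfOr(std::string lhs, const std::string& rhs)
{
    // stripNil: `a or b` can never take nil from the LHS.
    if (!lhs.empty() && lhs.back() == '?')
        lhs.pop_back();
    return lhs + " | " + rhs;
}

int main()
{
    std::cout << typeOfAnd("string") << "\n";            // boolean | string
    std::cout << typeOfOr("number?", "number") << "\n";  // number | number (the real solver reduces this to number)
}
// --- end of sketch ---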


@ -0,0 +1,440 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/Error.h"
LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
{
std::optional<DefId> DataFlowGraph::getDef(const AstExpr* expr) const
{
if (auto def = astDefs.find(expr))
return NotNull{*def};
return std::nullopt;
}
std::optional<DefId> DataFlowGraph::getDef(const AstLocal* local) const
{
if (auto def = localDefs.find(local))
return NotNull{*def};
return std::nullopt;
}
std::optional<DefId> DataFlowGraph::getDef(const Symbol& symbol) const
{
if (symbol.local)
return getDef(symbol.local);
else
return std::nullopt;
}
DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle)
{
LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
DataFlowGraphBuilder builder;
builder.handle = handle;
builder.visit(nullptr, block); // nullptr is the root DFG scope.
if (FFlag::DebugLuauFreezeArena)
builder.arena->allocator.freeze();
return std::move(builder.graph);
}
DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope)
{
return scopes.emplace_back(new DfgScope{scope}).get();
}
std::optional<DefId> DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e)
{
for (DfgScope* current = scope; current; current = current->parent)
{
if (auto loc = current->bindings.find(symbol))
{
graph.astDefs[e] = *loc;
return NotNull{*loc};
}
}
return std::nullopt;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b)
{
DfgScope* child = childScope(scope);
return visitBlockWithoutChildScope(child, b);
}
void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b)
{
for (AstStat* s : b->body)
visit(scope, s);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s)
{
if (auto b = s->as<AstStatBlock>())
return visit(scope, b);
else if (auto i = s->as<AstStatIf>())
return visit(scope, i);
else if (auto w = s->as<AstStatWhile>())
return visit(scope, w);
else if (auto r = s->as<AstStatRepeat>())
return visit(scope, r);
else if (auto b = s->as<AstStatBreak>())
return visit(scope, b);
else if (auto c = s->as<AstStatContinue>())
return visit(scope, c);
else if (auto r = s->as<AstStatReturn>())
return visit(scope, r);
else if (auto e = s->as<AstStatExpr>())
return visit(scope, e);
else if (auto l = s->as<AstStatLocal>())
return visit(scope, l);
else if (auto f = s->as<AstStatFor>())
return visit(scope, f);
else if (auto f = s->as<AstStatForIn>())
return visit(scope, f);
else if (auto a = s->as<AstStatAssign>())
return visit(scope, a);
else if (auto c = s->as<AstStatCompoundAssign>())
return visit(scope, c);
else if (auto f = s->as<AstStatFunction>())
return visit(scope, f);
else if (auto l = s->as<AstStatLocalFunction>())
return visit(scope, l);
else if (auto t = s->as<AstStatTypeAlias>())
return; // ok
else if (auto d = s->as<AstStatDeclareFunction>())
return; // ok
else if (auto d = s->as<AstStatDeclareGlobal>())
return; // ok
else if (auto d = s->as<AstStatDeclareFunction>())
return; // ok
else if (auto d = s->as<AstStatDeclareClass>())
return; // ok
else if (auto _ = s->as<AstStatError>())
return; // ok
else
handle->ice("Unknown AstStat in DataFlowGraphBuilder");
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i)
{
DfgScope* condScope = childScope(scope);
visitExpr(condScope, i->condition);
visit(condScope, i->thenbody);
if (i->elsebody)
visit(scope, i->elsebody);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* whileScope = childScope(scope);
visitExpr(whileScope, w->condition);
visit(whileScope, w->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* repeatScope = childScope(scope); // TODO: loop scope.
visitBlockWithoutChildScope(repeatScope, r->body);
visitExpr(repeatScope, r->condition);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b)
{
// TODO: Control flow analysis
return; // ok
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c)
{
// TODO: Control flow analysis
return; // ok
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r)
{
// TODO: Control flow analysis
for (AstExpr* e : r->list)
visitExpr(scope, e);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e)
{
visitExpr(scope, e->expr);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l)
{
// TODO: alias tracking
for (AstExpr* e : l->values)
visitExpr(scope, e);
for (AstLocal* local : l->vars)
{
DefId def = arena->freshDef();
graph.localDefs[local] = def;
scope->bindings[local] = def;
}
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
DefId def = arena->freshDef();
graph.localDefs[f->var] = def;
scope->bindings[f->var] = def;
// TODO(controlflow): entry point has a back edge from exit point
visit(forScope, f->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
for (AstLocal* local : f->vars)
{
DefId def = arena->freshDef();
graph.localDefs[local] = def;
forScope->bindings[local] = def;
}
// TODO(controlflow): entry point has a back edge from exit point
for (AstExpr* e : f->values)
visitExpr(forScope, e);
visit(forScope, f->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a)
{
for (AstExpr* r : a->values)
visitExpr(scope, r);
for (AstExpr* l : a->vars)
{
AstExpr* root = l;
bool isUpdatable = true;
while (true)
{
if (root->is<AstExprLocal>() || root->is<AstExprGlobal>())
break;
AstExprIndexName* indexName = root->as<AstExprIndexName>();
if (!indexName)
{
isUpdatable = false;
break;
}
root = indexName->expr;
}
if (isUpdatable)
{
// TODO global?
if (auto exprLocal = root->as<AstExprLocal>())
{
DefId def = arena->freshDef();
graph.astDefs[exprLocal] = def;
// Update the def in the scope that introduced the local. Not
// the current scope.
AstLocal* local = exprLocal->local;
DfgScope* s = scope;
while (s && !s->bindings.find(local))
s = s->parent;
LUAU_ASSERT(s && s->bindings.find(local));
s->bindings[local] = def;
}
}
visitExpr(scope, l); // TODO: they point to a new def!!
}
}
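The assignment case walks the scope chain so the fresh def lands in the scope that introduced the local rather than the block performing the assignment. A simplified standalone model of that walk, with illustrative names:

// --- illustrative standalone sketch, not part of the diff ---
#include <iostream>
#include <map>
#include <string>

struct Scope
{
    Scope* parent = nullptr;
    std::map<std::string, int> bindings; // name -> def id
};

void assign(Scope* scope, const std::string& name, int freshDef)
{
    Scope* s = scope;
    while (s && s->bindings.find(name) == s->bindings.end())
        s = s->parent;
    if (s)
        s->bindings[name] = freshDef; // update the defining scope, not the current one
}

int main()
{
    Scope outer;
    outer.bindings["x"] = 1; // local x       (def #1)

    Scope inner;
    inner.parent = &outer;   // a nested block
    assign(&inner, "x", 2);  // x = ...        (def #2)

    std::cout << outer.bindings["x"] << "\n"; // 2: the defining scope now holds the new def
}
// --- end of sketch ---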
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c)
{
// TODO(typestates): The lhs is being read and written to. This might or might not be annoying.
visitExpr(scope, c->value);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f)
{
visitExpr(scope, f->name);
visitExpr(scope, f->func);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l)
{
DefId def = arena->freshDef();
graph.localDefs[l->name] = def;
scope->bindings[l->name] = def;
visitExpr(scope, l->func);
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e)
{
if (auto g = e->as<AstExprGroup>())
return visitExpr(scope, g->expr);
else if (auto c = e->as<AstExprConstantNil>())
return {}; // ok
else if (auto c = e->as<AstExprConstantBool>())
return {}; // ok
else if (auto c = e->as<AstExprConstantNumber>())
return {}; // ok
else if (auto c = e->as<AstExprConstantString>())
return {}; // ok
else if (auto l = e->as<AstExprLocal>())
return visitExpr(scope, l);
else if (auto g = e->as<AstExprGlobal>())
return visitExpr(scope, g);
else if (auto v = e->as<AstExprVarargs>())
return {}; // ok
else if (auto c = e->as<AstExprCall>())
return visitExpr(scope, c);
else if (auto i = e->as<AstExprIndexName>())
return visitExpr(scope, i);
else if (auto i = e->as<AstExprIndexExpr>())
return visitExpr(scope, i);
else if (auto f = e->as<AstExprFunction>())
return visitExpr(scope, f);
else if (auto t = e->as<AstExprTable>())
return visitExpr(scope, t);
else if (auto u = e->as<AstExprUnary>())
return visitExpr(scope, u);
else if (auto b = e->as<AstExprBinary>())
return visitExpr(scope, b);
else if (auto t = e->as<AstExprTypeAssertion>())
return visitExpr(scope, t);
else if (auto i = e->as<AstExprIfElse>())
return visitExpr(scope, i);
else if (auto i = e->as<AstExprInterpString>())
return visitExpr(scope, i);
else if (auto _ = e->as<AstExprError>())
return {}; // ok
else
handle->ice("Unknown AstExpr in DataFlowGraphBuilder");
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l)
{
return {use(scope, l->local, l)};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g)
{
return {use(scope, g->name, g)};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c)
{
visitExpr(scope, c->func);
for (AstExpr* arg : c->args)
visitExpr(scope, arg);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i)
{
std::optional<DefId> def = visitExpr(scope, i->expr).def;
if (!def)
return {};
// TODO: properties for the above def.
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i)
{
visitExpr(scope, i->expr);
visitExpr(scope, i->index);
if (i->index->as<AstExprConstantString>())
{
// TODO: properties for the def
}
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f)
{
if (AstLocal* self = f->self)
{
DefId def = arena->freshDef();
graph.localDefs[self] = def;
scope->bindings[self] = def;
}
for (AstLocal* param : f->args)
{
DefId def = arena->freshDef();
graph.localDefs[param] = def;
scope->bindings[param] = def;
}
visit(scope, f->body);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t)
{
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u)
{
visitExpr(scope, u->expr);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b)
{
visitExpr(scope, b->left);
visitExpr(scope, b->right);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t)
{
ExpressionFlowGraph result = visitExpr(scope, t->expr);
// TODO: visit type
return result;
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i)
{
DfgScope* condScope = childScope(scope);
visitExpr(condScope, i->condition);
visitExpr(condScope, i->trueExpr);
visitExpr(scope, i->falseExpr);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i)
{
for (AstExpr* e : i->expressions)
visitExpr(scope, e);
return {};
}
} // namespace Luau

Analysis/src/Def.cpp (new file)

@ -0,0 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Def.h"
namespace Luau
{
DefId DefArena::freshDef()
{
return NotNull{allocator.allocate<Def>(Undefined{})};
}
} // namespace Luau


@ -13,47 +13,47 @@ declare bit32: {
bor: (...number) -> number, bor: (...number) -> number,
bxor: (...number) -> number, bxor: (...number) -> number,
btest: (number, ...number) -> boolean, btest: (number, ...number) -> boolean,
rrotate: (number, number) -> number, rrotate: (x: number, disp: number) -> number,
lrotate: (number, number) -> number, lrotate: (x: number, disp: number) -> number,
lshift: (number, number) -> number, lshift: (x: number, disp: number) -> number,
arshift: (number, number) -> number, arshift: (x: number, disp: number) -> number,
rshift: (number, number) -> number, rshift: (x: number, disp: number) -> number,
bnot: (number) -> number, bnot: (x: number) -> number,
extract: (number, number, number?) -> number, extract: (n: number, field: number, width: number?) -> number,
replace: (number, number, number, number?) -> number, replace: (n: number, v: number, field: number, width: number?) -> number,
countlz: (number) -> number, countlz: (n: number) -> number,
countrz: (number) -> number, countrz: (n: number) -> number,
} }
declare math: { declare math: {
frexp: (number) -> (number, number), frexp: (n: number) -> (number, number),
ldexp: (number, number) -> number, ldexp: (s: number, e: number) -> number,
fmod: (number, number) -> number, fmod: (x: number, y: number) -> number,
modf: (number) -> (number, number), modf: (n: number) -> (number, number),
pow: (number, number) -> number, pow: (x: number, y: number) -> number,
exp: (number) -> number, exp: (n: number) -> number,
ceil: (number) -> number, ceil: (n: number) -> number,
floor: (number) -> number, floor: (n: number) -> number,
abs: (number) -> number, abs: (n: number) -> number,
sqrt: (number) -> number, sqrt: (n: number) -> number,
log: (number, number?) -> number, log: (n: number, base: number?) -> number,
log10: (number) -> number, log10: (n: number) -> number,
rad: (number) -> number, rad: (n: number) -> number,
deg: (number) -> number, deg: (n: number) -> number,
sin: (number) -> number, sin: (n: number) -> number,
cos: (number) -> number, cos: (n: number) -> number,
tan: (number) -> number, tan: (n: number) -> number,
sinh: (number) -> number, sinh: (n: number) -> number,
cosh: (number) -> number, cosh: (n: number) -> number,
tanh: (number) -> number, tanh: (n: number) -> number,
atan: (number) -> number, atan: (n: number) -> number,
acos: (number) -> number, acos: (n: number) -> number,
asin: (number) -> number, asin: (n: number) -> number,
atan2: (number, number) -> number, atan2: (y: number, x: number) -> number,
min: (number, ...number) -> number, min: (number, ...number) -> number,
max: (number, ...number) -> number, max: (number, ...number) -> number,
@ -61,13 +61,13 @@ declare math: {
pi: number, pi: number,
huge: number, huge: number,
randomseed: (number) -> (), randomseed: (seed: number) -> (),
random: (number?, number?) -> number, random: (number?, number?) -> number,
sign: (number) -> number, sign: (n: number) -> number,
clamp: (number, number, number) -> number, clamp: (n: number, min: number, max: number) -> number,
noise: (number, number?, number?) -> number, noise: (x: number, y: number?, z: number?) -> number,
round: (number) -> number, round: (n: number) -> number,
} }
type DateTypeArg = { type DateTypeArg = {
@ -93,9 +93,9 @@ type DateTypeResult = {
} }
declare os: { declare os: {
time: (DateTypeArg?) -> number, time: (time: DateTypeArg?) -> number,
date: (string?, number?) -> DateTypeResult | string, date: (formatString: string?, time: number?) -> DateTypeResult | string,
difftime: (DateTypeResult | number, DateTypeResult | number) -> number, difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number, clock: () -> number,
} }
@ -145,51 +145,51 @@ declare function loadstring<A...>(src: string, chunkname: string?): (((A...) ->
declare function newproxy(mt: boolean?): any declare function newproxy(mt: boolean?): any
declare coroutine: { declare coroutine: {
create: <A..., R...>((A...) -> R...) -> thread, create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(thread, A...) -> (boolean, R...), resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread, running: () -> thread,
status: (thread) -> "dead" | "running" | "normal" | "suspended", status: (co: thread) -> "dead" | "running" | "normal" | "suspended",
-- FIXME: This technically returns a function, but we can't represent this yet. -- FIXME: This technically returns a function, but we can't represent this yet.
wrap: <A..., R...>((A...) -> R...) -> any, wrap: <A..., R...>(f: (A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R..., yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean, isyieldable: () -> boolean,
close: (thread) -> (boolean, any) close: (co: thread) -> (boolean, any)
} }
declare table: { declare table: {
concat: <V>({V}, string?, number?, number?) -> string, concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>({V}, V) -> ()) & (<V>({V}, number, V) -> ()), insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
maxn: <V>({V}) -> number, maxn: <V>(t: {V}) -> number,
remove: <V>({V}, number?) -> V?, remove: <V>(t: {V}, number?) -> V?,
sort: <V>({V}, ((V, V) -> boolean)?) -> (), sort: <V>(t: {V}, comp: ((V, V) -> boolean)?) -> (),
create: <V>(number, V?) -> {V}, create: <V>(count: number, value: V?) -> {V},
find: <V>({V}, V, number?) -> number?, find: <V>(haystack: {V}, needle: V, init: number?) -> number?,
unpack: <V>({V}, number?, number?) -> ...V, unpack: <V>(list: {V}, i: number?, j: number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V }, pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>({V}) -> number, getn: <V>(t: {V}) -> number,
foreach: <K, V>({[K]: V}, (K, V) -> ()) -> (), foreach: <K, V>(t: {[K]: V}, f: (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (), foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>({V}, number, number, number, {V}?) -> {V}, move: <V>(src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V},
clear: <K, V>({[K]: V}) -> (), clear: <K, V>(table: {[K]: V}) -> (),
isfrozen: <K, V>({[K]: V}) -> boolean, isfrozen: <K, V>(t: {[K]: V}) -> boolean,
} }
declare debug: { declare debug: {
info: (<R...>(thread, number, string) -> R...) & (<R...>(number, string) -> R...) & (<A..., R1..., R2...>((A...) -> R1..., string) -> R2...), info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
} }
declare utf8: { declare utf8: {
char: (...number) -> string, char: (...number) -> string,
charpattern: string, charpattern: string,
codes: (string) -> ((string, number) -> (number, number), string, number), codes: (str: string) -> ((string, number) -> (number, number), string, number),
codepoint: (string, number?, number?) -> ...number, codepoint: (str: string, i: number?, j: number?) -> ...number,
len: (string, number?, number?) -> (number?, number?), len: (s: string, i: number?, j: number?) -> (number?, number?),
offset: (string, number?, number?) -> number, offset: (s: string, n: number?, i: number?) -> number,
} }
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.


@ -7,7 +7,7 @@
#include <stdexcept> #include <stdexcept>
-LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false)
+LUAU_FASTFLAGVARIABLE(LuauIceExceptionInheritanceChange, false)
static std::string wrongNumberOfArgsString( static std::string wrongNumberOfArgsString(
size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
@ -70,7 +70,7 @@ struct ErrorConverter
{ {
if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType))
{ {
-            if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr)
+            if (fileResolver != nullptr)
{ {
std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule);
std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule);
@ -96,14 +96,7 @@ struct ErrorConverter
if (!tm.reason.empty()) if (!tm.reason.empty())
result += tm.reason + " "; result += tm.reason + " ";
-            if (FFlag::LuauTypeMismatchModuleNameResolution)
-            {
-                result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
-            }
-            else
-            {
-                result += Luau::toString(*tm.error);
-            }
+            result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
} }
else if (!tm.reason.empty()) else if (!tm.reason.empty())
{ {
@ -469,6 +462,11 @@ struct ErrorConverter
{ {
return "Code is too complex to typecheck! Consider simplifying the code around this area"; return "Code is too complex to typecheck! Consider simplifying the code around this area";
} }
std::string operator()(const TypePackMismatch& e) const
{
return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'";
}
}; };
struct InvalidNameChecker struct InvalidNameChecker
@ -727,6 +725,11 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const
return left == rhs.left && right == rhs.right; return left == rhs.left && right == rhs.right;
} }
bool TypePackMismatch::operator==(const TypePackMismatch& rhs) const
{
return *wantedTp == *rhs.wantedTp && *givenTp == *rhs.givenTp;
}
std::string toString(const TypeError& error) std::string toString(const TypeError& error)
{ {
return toString(error, TypeErrorToStringOptions{}); return toString(error, TypeErrorToStringOptions{});
@ -878,6 +881,11 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState)
else if constexpr (std::is_same_v<T, NormalizationTooComplex>) else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
{ {
} }
else if constexpr (std::is_same_v<T, TypePackMismatch>)
{
e.wantedTp = clone(e.wantedTp);
e.givenTp = clone(e.givenTp);
}
else else
static_assert(always_false_v<T>, "Non-exhaustive type switch"); static_assert(always_false_v<T>, "Non-exhaustive type switch");
} }
@ -922,4 +930,30 @@ const char* InternalCompilerError::what() const throw()
return this->message.data(); return this->message.data();
} }
// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted.
void throwRuntimeError(const std::string& message)
{
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw InternalCompilerError(message);
}
else
{
throw std::runtime_error(message);
}
}
// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted.
void throwRuntimeError(const std::string& message, const std::string& moduleName)
{
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw InternalCompilerError(message, moduleName);
}
else
{
throw std::runtime_error(message);
}
}
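throwRuntimeError lets call sites migrate to InternalCompilerError behind the flag without changing their catch clauses. A standalone sketch of the same pattern with made-up names; a plain bool stands in for the fast flag, and the exception type is illustrative:

// --- illustrative standalone sketch, not part of the diff ---
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>

struct CompilerError : std::exception
{
    std::string message;
    explicit CompilerError(std::string msg) : message(std::move(msg)) {}
    const char* what() const noexcept override { return message.c_str(); }
};

static bool useNewExceptionType = true; // stands in for the fast flag

[[noreturn]] void throwError(const std::string& message)
{
    if (useNewExceptionType)
        throw CompilerError(message);
    else
        throw std::runtime_error(message);
}

int main()
{
    try
    {
        throwError("Frontend::modules does not have data for Foo");
    }
    catch (const std::exception& e) // callers catching std::exception work on both paths
    {
        std::cout << e.what() << "\n";
    }
}
// --- end of sketch ---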
} // namespace Luau } // namespace Luau


@ -1,11 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Config.h" #include "Luau/Config.h"
#include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/FileResolver.h" #include "Luau/FileResolver.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
@ -15,7 +17,6 @@
#include "Luau/TypeChecker2.h" #include "Luau/TypeChecker2.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/BuiltinDefinitions.h"
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono>
@ -26,10 +27,11 @@ LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false)
LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100)
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauLogSolverToJson);
LUAU_FASTFLAGVARIABLE(LuauFixMarkDirtyReverseDeps, false)
LUAU_FASTFLAGVARIABLE(LuauPersistTypesAfterGeneratingDocSyms, false)
namespace Luau namespace Luau
{ {
@ -110,24 +112,57 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c
CloneState cloneState; CloneState cloneState;
for (const auto& [name, ty] : checkedModule->declaredGlobals) if (FFlag::LuauPersistTypesAfterGeneratingDocSyms)
{ {
TypeId globalTy = clone(ty, globalTypes, cloneState); std::vector<TypeId> typesToPersist;
std::string documentationSymbol = packageName + "/global/" + name; typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size());
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
persist(globalTy); for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
typesToPersist.push_back(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
globalScope->exportedTypeBindings[name] = globalTy;
typesToPersist.push_back(globalTy.type);
}
for (TypeId ty : typesToPersist)
{
persist(ty);
}
} }
else
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{ {
TypeFun globalTy = clone(ty, globalTypes, cloneState); for (const auto& [name, ty] : checkedModule->declaredGlobals)
std::string documentationSymbol = packageName + "/globaltype/" + name; {
generateDocumentationSymbols(globalTy.type, documentationSymbol); TypeId globalTy = clone(ty, globalTypes, cloneState);
globalScope->exportedTypeBindings[name] = globalTy; std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
persist(globalTy.type); persist(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
globalScope->exportedTypeBindings[name] = globalTy;
persist(globalTy.type);
}
} }
return LoadDefinitionFileResult{true, parseResult, checkedModule}; return LoadDefinitionFileResult{true, parseResult, checkedModule};
@@ -159,24 +194,57 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
    CloneState cloneState;

    if (FFlag::LuauPersistTypesAfterGeneratingDocSyms)
    {
        std::vector<TypeId> typesToPersist;
        typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size());

        for (const auto& [name, ty] : checkedModule->declaredGlobals)
        {
            TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
            std::string documentationSymbol = packageName + "/global/" + name;
            generateDocumentationSymbols(globalTy, documentationSymbol);
            targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
            typesToPersist.push_back(globalTy);
        }

        for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
        {
            TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
            std::string documentationSymbol = packageName + "/globaltype/" + name;
            generateDocumentationSymbols(globalTy.type, documentationSymbol);
            targetScope->exportedTypeBindings[name] = globalTy;
            typesToPersist.push_back(globalTy.type);
        }

        for (TypeId ty : typesToPersist)
        {
            persist(ty);
        }
    }
    else
    {
        for (const auto& [name, ty] : checkedModule->declaredGlobals)
        {
            TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
            std::string documentationSymbol = packageName + "/global/" + name;
            generateDocumentationSymbols(globalTy, documentationSymbol);
            targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
            persist(globalTy);
        }

        for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
        {
            TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
            std::string documentationSymbol = packageName + "/globaltype/" + name;
            generateDocumentationSymbols(globalTy.type, documentationSymbol);
            targetScope->exportedTypeBindings[name] = globalTy;
            persist(globalTy.type);
        }
    }
return LoadDefinitionFileResult{true, parseResult, checkedModule}; return LoadDefinitionFileResult{true, parseResult, checkedModule};
@ -425,13 +493,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
{ {
auto it2 = moduleResolverForAutocomplete.modules.find(name); auto it2 = moduleResolverForAutocomplete.modules.find(name);
if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr)
throw std::runtime_error("Frontend::modules does not have data for " + name); throwRuntimeError("Frontend::modules does not have data for " + name, name);
} }
else else
{ {
auto it2 = moduleResolver.modules.find(name); auto it2 = moduleResolver.modules.find(name);
if (it2 == moduleResolver.modules.end() || it2->second == nullptr) if (it2 == moduleResolver.modules.end() || it2->second == nullptr)
throw std::runtime_error("Frontend::modules does not have data for " + name); throwRuntimeError("Frontend::modules does not have data for " + name, name);
} }
return CheckResult{ return CheckResult{
@ -488,23 +556,19 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
else else
typeCheckerForAutocomplete.finishTime = std::nullopt; typeCheckerForAutocomplete.finishTime = std::nullopt;
// TODO: This is a dirty ad hoc solution for autocomplete timeouts
// We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit
// so that we'll have type information for the whole file at lower quality instead of a full abort in the middle
if (FInt::LuauTarjanChildLimit > 0)
    typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
else
    typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt;

if (FInt::LuauTypeInferIterationLimit > 0)
    typeCheckerForAutocomplete.unifierIterationLimit =
        std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult));
else
    typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt;
ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution
? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true) ? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true)
@ -518,10 +582,9 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
{ {
checkResult.timeoutHits.push_back(moduleName); checkResult.timeoutHits.push_back(moduleName);
    sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
}
else if (duration < autocompleteTimeLimit / 2.0)
{ {
sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0);
} }
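The two branches above implement a simple multiplicative backoff for the autocomplete type checker: a timeout halves the per-module limit multiplier, and a comfortably fast run doubles it again, capped at 1.0. A hedged Luau sketch of that rule (names here are illustrative, not part of this change):

local mult = 1.0

local function onAutocompleteTimeout()
    mult = mult / 2.0 -- e.g. 1.0 -> 0.5 -> 0.25 after two consecutive timeouts
end

local function onFastFinish()
    mult = math.min(mult * 2.0, 1.0) -- recovers gradually, never above 1.0
end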
@ -543,7 +606,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
stats.filesNonstrict += mode == Mode::Nonstrict; stats.filesNonstrict += mode == Mode::Nonstrict;
if (module == nullptr) if (module == nullptr)
throw std::runtime_error("Frontend::check produced a nullptr module for " + moduleName); throwRuntimeError("Frontend::check produced a nullptr module for " + moduleName, moduleName);
if (!frontendOptions.retainFullTypeGraphs) if (!frontendOptions.retainFullTypeGraphs)
{ {
@ -807,13 +870,26 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
sourceNode.dirtyModule = true; sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true; sourceNode.dirtyModuleForAutocomplete = true;
if (FFlag::LuauFixMarkDirtyReverseDeps)
{
    if (0 == reverseDeps.count(next))
        continue;

    sourceModules.erase(next);

    const std::vector<ModuleName>& dependents = reverseDeps[next];
    queue.insert(queue.end(), dependents.begin(), dependents.end());
}
else
{
    if (0 == reverseDeps.count(name))
        continue;

    sourceModules.erase(name);

    const std::vector<ModuleName>& dependents = reverseDeps[name];
    queue.insert(queue.end(), dependents.begin(), dependents.end());
}
} }
} }
@ -857,13 +933,25 @@ ModulePtr Frontend::check(
} }
} }
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&iceHandler});
const NotNull<ModuleResolver> mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; const NotNull<ModuleResolver> mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver};
const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope};
Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}}; Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}};
ConstraintGraphBuilder cgb{ ConstraintGraphBuilder cgb{
    sourceModule.name,
    result,
    &result->internalTypes,
    mr,
    singletonTypes,
    NotNull(&iceHandler),
    globalScope,
    logger.get(),
    NotNull{&dfg},
};
cgb.visit(sourceModule.root); cgb.visit(sourceModule.root);
result->errors = std::move(cgb.errors); result->errors = std::move(cgb.errors);
@ -986,11 +1074,11 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const
double timestamp = getTimestamp(); double timestamp = getTimestamp();
auto parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions);
stats.timeParse += getTimestamp() - timestamp; stats.timeParse += getTimestamp() - timestamp;
stats.files++; stats.files++;
stats.lines += std::count(src.begin(), src.end(), '\n') + (src.size() && src.back() != '\n'); stats.lines += parseResult.lines;
if (!parseResult.errors.empty()) if (!parseResult.errors.empty())
sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end()); sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end());

View file

@ -188,6 +188,8 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }";
else if constexpr (std::is_same_v<T, NormalizationTooComplex>) else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
stream << "NormalizationTooComplex { }"; stream << "NormalizationTooComplex { }";
else if constexpr (std::is_same_v<T, TypePackMismatch>)
stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }";
else else
static_assert(always_false_v<T>, "Non-exhaustive type switch"); static_assert(always_false_v<T>, "Non-exhaustive type switch");
} }

View file

@ -60,36 +60,6 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos)
return contains(pos, *iter); return contains(pos, *iter);
} }
struct ForceNormal : TypeVarOnceVisitor
{
const TypeArena* typeArena = nullptr;
ForceNormal(const TypeArena* typeArena)
: typeArena(typeArena)
{
}
bool visit(TypeId ty) override
{
if (ty->owningArena != typeArena)
return false;
asMutable(ty)->normal = true;
return true;
}
bool visit(TypeId ty, const FreeTypeVar& ftv) override
{
visit(ty);
return true;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
return true;
}
};
struct ClonePublicInterface : Substitution struct ClonePublicInterface : Substitution
{ {
NotNull<SingletonTypes> singletonTypes; NotNull<SingletonTypes> singletonTypes;
@ -241,8 +211,6 @@ void Module::clonePublicInterface(NotNull<SingletonTypes> singletonTypes, Intern
moduleScope->varargPack = varargPack; moduleScope->varargPack = varargPack;
} }
ForceNormal forceNormal{&interfaceTypes};
if (exportedTypeBindings) if (exportedTypeBindings)
{ {
for (auto& [name, tf] : *exportedTypeBindings) for (auto& [name, tf] : *exportedTypeBindings)
@ -262,7 +230,6 @@ void Module::clonePublicInterface(NotNull<SingletonTypes> singletonTypes, Intern
{ {
auto t = asMutable(ty); auto t = asMutable(ty);
t->ty = AnyTypeVar{}; t->ty = AnyTypeVar{};
t->normal = true;
} }
} }
} }

File diff suppressed because it is too large

View file

@ -57,29 +57,6 @@ struct Quantifier final : TypeVarOnceVisitor
return false; return false;
} }
bool visit(TypeId ty, const ConstrainedTypeVar&) override
{
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty);
seenMutableType = true;
if (!level.subsumes(ctv->level))
return false;
std::vector<TypeId> opts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic
for (TypeId opt : opts)
traverse(opt);
if (opts.size() == 1)
*asMutable(ty) = BoundTypeVar{opts[0]};
else
*asMutable(ty) = UnionTypeVar{std::move(opts)};
return false;
}
bool visit(TypeId ty, const TableTypeVar&) override bool visit(TypeId ty, const TableTypeVar&) override
{ {
LUAU_ASSERT(getMutable<TableTypeVar>(ty)); LUAU_ASSERT(getMutable<TableTypeVar>(ty));

View file

@ -27,6 +27,44 @@ void Scope::addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun)
builtinTypeNames.insert(name); builtinTypeNames.insert(name);
} }
std::optional<TypeId> Scope::lookup(Symbol sym) const
{
auto r = const_cast<Scope*>(this)->lookupEx(sym);
if (r)
return r->first;
else
return std::nullopt;
}
std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(Symbol sym)
{
Scope* s = this;
while (true)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return std::pair{it->second.typeId, s};
if (s->parent)
s = s->parent.get();
else
return std::nullopt;
}
}
// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis.
std::optional<TypeId> Scope::lookup(DefId def) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
if (auto ty = current->dcrRefinements.find(def))
return *ty;
}
return std::nullopt;
}
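Both lookups above walk outward through the parent chain until a binding (or a DCR refinement) is found. For illustration, a small Luau snippet showing the shadowing behavior this models:

local x = "outer"
do
    local x = 5 -- the innermost scope wins while it is in effect
    print(x)    -- 5
end
print(x)        -- "outer": resolved by falling back to the enclosing scope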
std::optional<TypeFun> Scope::lookupType(const Name& name) std::optional<TypeFun> Scope::lookupType(const Name& name)
{ {
const Scope* scope = this; const Scope* scope = this;
@ -111,23 +149,6 @@ std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bo
return std::nullopt; return std::nullopt;
} }
std::optional<TypeId> Scope::lookup(Symbol sym)
{
Scope* s = this;
while (true)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return it->second.typeId;
if (s->parent)
s = s->parent.get();
else
return std::nullopt;
}
}
bool subsumesStrict(Scope* left, Scope* right) bool subsumesStrict(Scope* left, Scope* right)
{ {
while (right) while (right)

View file

@ -73,11 +73,6 @@ void Tarjan::visitChildren(TypeId ty, int index)
for (TypeId part : itv->parts) for (TypeId part : itv->parts)
visitChild(part); visitChild(part);
} }
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
for (TypeId part : ctv->parts)
visitChild(part);
}
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty)) else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{ {
for (TypeId a : petv->typeArguments) for (TypeId a : petv->typeArguments)
@ -97,6 +92,10 @@ void Tarjan::visitChildren(TypeId ty, int index)
if (ctv->metatable) if (ctv->metatable)
visitChild(*ctv->metatable); visitChild(*ctv->metatable);
} }
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(ty))
{
visitChild(ntv->ty);
}
} }
void Tarjan::visitChildren(TypePackId tp, int index) void Tarjan::visitChildren(TypePackId tp, int index)
@ -605,11 +604,6 @@ void Substitution::replaceChildren(TypeId ty)
for (TypeId& part : itv->parts) for (TypeId& part : itv->parts)
part = replace(part); part = replace(part);
} }
else if (ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty))
{
for (TypeId& part : ctv->parts)
part = replace(part);
}
else if (PendingExpansionTypeVar* petv = getMutable<PendingExpansionTypeVar>(ty)) else if (PendingExpansionTypeVar* petv = getMutable<PendingExpansionTypeVar>(ty))
{ {
for (TypeId& a : petv->typeArguments) for (TypeId& a : petv->typeArguments)
@ -629,6 +623,10 @@ void Substitution::replaceChildren(TypeId ty)
if (ctv->metatable) if (ctv->metatable)
ctv->metatable = replace(*ctv->metatable); ctv->metatable = replace(*ctv->metatable);
} }
else if (NegationTypeVar* ntv = getMutable<NegationTypeVar>(ty))
{
ntv->ty = replace(ntv->ty);
}
} }
void Substitution::replaceChildren(TypePackId tp) void Substitution::replaceChildren(TypePackId tp)

View file

@ -237,15 +237,6 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
formatAppend(result, "ConstrainedTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
for (TypeId part : ctv->parts)
visitChild(part, index);
}
else if (get<ErrorTypeVar>(ty)) else if (get<ErrorTypeVar>(ty))
{ {
formatAppend(result, "ErrorTypeVar %d", index); formatAppend(result, "ErrorTypeVar %d", index);

View file

@ -10,11 +10,12 @@
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauLvaluelessPath)
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false)
LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false)
LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false)
LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false)
LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false)
/* /*
* Prefix generic typenames with gen- * Prefix generic typenames with gen-
@ -224,6 +225,20 @@ struct StringifierState
result.name += s; result.name += s;
} }
void emitLevel(Scope* scope)
{
size_t count = 0;
for (Scope* s = scope; s; s = s->parent.get())
++count;
emit(count);
emit("-");
char buffer[16];
uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF);
snprintf(buffer, sizeof(buffer), "0x%x", s);
emit(buffer);
}
void emit(TypeLevel level) void emit(TypeLevel level)
{ {
emit(std::to_string(level.level)); emit(std::to_string(level.level));
@ -295,10 +310,7 @@ struct TypeVarStringifier
if (tv->ty.valueless_by_exception()) if (tv->ty.valueless_by_exception())
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("* VALUELESS BY EXCEPTION *");
state.emit("* VALUELESS BY EXCEPTION *");
else
state.emit("< VALUELESS BY EXCEPTION >");
return; return;
} }
@ -376,7 +388,10 @@ struct TypeVarStringifier
if (FFlag::DebugLuauVerboseTypeNames) if (FFlag::DebugLuauVerboseTypeNames)
{ {
state.emit("-"); state.emit("-");
if (FFlag::DebugLuauDeferredConstraintResolution)
    state.emitLevel(ftv.scope);
else
    state.emit(ftv.level);
} }
} }
@ -398,29 +413,15 @@ struct TypeVarStringifier
} }
else else
state.emit(state.getName(ty)); state.emit(state.getName(ty));
    if (FFlag::DebugLuauVerboseTypeNames)
    {
        state.emit("-");
        if (FFlag::DebugLuauDeferredConstraintResolution)
            state.emitLevel(gtv.scope);
        else
            state.emit(gtv.level);
    }
}

void operator()(TypeId, const ConstrainedTypeVar& ctv)
{
    state.result.invalid = true;
    state.emit("[");
    if (FFlag::DebugLuauVerboseTypeNames)
        state.emit(ctv.level);
    state.emit("[");

    bool first = true;
    for (TypeId ty : ctv.parts)
    {
        if (first)
            first = false;
        else
            state.emit("|");
        stringify(ty);
    }

    state.emit("]]");
}
void operator()(TypeId, const BlockedTypeVar& btv) void operator()(TypeId, const BlockedTypeVar& btv)
@ -456,9 +457,12 @@ struct TypeVarStringifier
case PrimitiveTypeVar::Thread: case PrimitiveTypeVar::Thread:
state.emit("thread"); state.emit("thread");
return; return;
case PrimitiveTypeVar::Function:
state.emit("function");
return;
default: default:
LUAU_ASSERT(!"Unknown primitive type"); LUAU_ASSERT(!"Unknown primitive type");
throw std::runtime_error("Unknown primitive type " + std::to_string(ptv.type)); throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type));
} }
} }
@ -475,7 +479,7 @@ struct TypeVarStringifier
else else
{ {
LUAU_ASSERT(!"Unknown singleton type"); LUAU_ASSERT(!"Unknown singleton type");
throw std::runtime_error("Unknown singleton type"); throwRuntimeError("Unknown singleton type");
} }
} }
@ -484,10 +488,7 @@ struct TypeVarStringifier
if (state.hasSeen(&ftv)) if (state.hasSeen(&ftv))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLE*");
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return; return;
} }
@ -595,10 +596,7 @@ struct TypeVarStringifier
if (state.hasSeen(&ttv)) if (state.hasSeen(&ttv))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLE*");
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return; return;
} }
@ -732,10 +730,7 @@ struct TypeVarStringifier
if (state.hasSeen(&uv)) if (state.hasSeen(&uv))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLE*");
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return; return;
} }
@ -802,10 +797,7 @@ struct TypeVarStringifier
if (state.hasSeen(&uv)) if (state.hasSeen(&uv))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLE*");
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return; return;
} }
@ -850,10 +842,7 @@ struct TypeVarStringifier
void operator()(TypeId, const ErrorTypeVar& tv) void operator()(TypeId, const ErrorTypeVar& tv)
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
} }
void operator()(TypeId, const LazyTypeVar& ltv) void operator()(TypeId, const LazyTypeVar& ltv)
@ -871,6 +860,23 @@ struct TypeVarStringifier
{ {
state.emit("never"); state.emit("never");
} }
void operator()(TypeId, const NegationTypeVar& ntv)
{
state.emit("~");
// The precedence of `~` should be less than `|` and `&`.
TypeId followed = follow(ntv.ty);
bool parens = get<UnionTypeVar>(followed) || get<IntersectionTypeVar>(followed);
if (parens)
state.emit("(");
stringify(ntv.ty);
if (parens)
state.emit(")");
}
}; };
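The negation case above prints `~` and parenthesizes unions and intersections so the output reads with the intended precedence. A hedged Luau sketch of where such types arise (negation types are internal to the solver, so the comments only describe what toString would produce):

local function f(x: string | number)
    if typeof(x) == "string" then
        return x .. "!"
    else
        -- under the new solver the else branch is roughly (string | number) & ~string;
        -- a bare negation prints as ~string, while negating the whole union
        -- would print as ~(string | number), hence the parentheses above
        return x + 1
    end
end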
struct TypePackStringifier struct TypePackStringifier
@ -907,10 +913,7 @@ struct TypePackStringifier
if (tp->ty.valueless_by_exception()) if (tp->ty.valueless_by_exception())
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("* VALUELESS TP BY EXCEPTION *");
state.emit("* VALUELESS TP BY EXCEPTION *");
else
state.emit("< VALUELESS TP BY EXCEPTION >");
return; return;
} }
@ -934,10 +937,7 @@ struct TypePackStringifier
if (state.hasSeen(&tp)) if (state.hasSeen(&tp))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLETP*");
state.emit("*CYCLETP*");
else
state.emit("<CYCLETP>");
return; return;
} }
@ -982,10 +982,7 @@ struct TypePackStringifier
void operator()(TypePackId, const Unifiable::Error& error) void operator()(TypePackId, const Unifiable::Error& error)
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
} }
void operator()(TypePackId, const VariadicTypePack& pack) void operator()(TypePackId, const VariadicTypePack& pack)
@ -993,10 +990,7 @@ struct TypePackStringifier
state.emit("..."); state.emit("...");
if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) if (FFlag::DebugLuauVerboseTypeNames && pack.hidden)
{ {
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*hidden*");
state.emit("*hidden*");
else
state.emit("<hidden>");
} }
stringify(pack.ty); stringify(pack.ty);
} }
@ -1031,7 +1025,10 @@ struct TypePackStringifier
if (FFlag::DebugLuauVerboseTypeNames) if (FFlag::DebugLuauVerboseTypeNames)
{ {
state.emit("-"); state.emit("-");
if (FFlag::DebugLuauDeferredConstraintResolution)
    state.emitLevel(pack.scope);
else
    state.emit(pack.level);
} }
state.emit("..."); state.emit("...");
@ -1204,10 +1201,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
{ {
result.truncated = true; result.truncated = true;
if (FFlag::LuauSpecialTypesAsterisked) result.name += "... *TRUNCATED*";
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
} }
return result; return result;
@ -1280,10 +1274,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
{ {
if (FFlag::LuauSpecialTypesAsterisked) result.name += "... *TRUNCATED*";
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
} }
return result; return result;
@ -1442,7 +1433,7 @@ std::string generateName(size_t i)
std::string toString(const Constraint& constraint, ToStringOptions& opts) std::string toString(const Constraint& constraint, ToStringOptions& opts)
{ {
auto go = [&opts](auto&& c) { auto go = [&opts](auto&& c) -> std::string {
using T = std::decay_t<decltype(c)>; using T = std::decay_t<decltype(c)>;
// TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps
@ -1526,6 +1517,13 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
{ {
return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\"";
} }
else if constexpr (std::is_same_v<T, SingletonOrTopTypeConstraint>)
{
std::string result = tos(c.resultType, opts);
std::string discriminant = tos(c.discriminantType, opts);
return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant;
}
else else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch"); static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
}; };
@ -1545,6 +1543,8 @@ std::string dump(const Constraint& c)
std::string toString(const LValue& lvalue) std::string toString(const LValue& lvalue)
{ {
LUAU_ASSERT(!FFlag::LuauLvaluelessPath);
std::string s; std::string s;
for (const LValue* current = &lvalue; current; current = baseof(*current)) for (const LValue* current = &lvalue; current; current = baseof(*current))
{ {
@ -1559,4 +1559,37 @@ std::string toString(const LValue& lvalue)
return s; return s;
} }
std::optional<std::string> getFunctionNameAsString(const AstExpr& expr)
{
LUAU_ASSERT(FFlag::LuauLvaluelessPath);
const AstExpr* curr = &expr;
std::string s;
for (;;)
{
if (auto local = curr->as<AstExprLocal>())
return local->local->name.value + s;
if (auto global = curr->as<AstExprGlobal>())
return global->name.value + s;
if (auto indexname = curr->as<AstExprIndexName>())
{
curr = indexname->expr;
s = "." + std::string(indexname->index.value) + s;
}
else if (auto group = curr->as<AstExprGroup>())
{
curr = group->expr;
}
else
{
return std::nullopt;
}
}
return s;
}
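As a hypothetical illustration of the traversal above: for the call below, walking the AstExprIndexName chain yields the path "a.b.c", which argument-count diagnostics can report without building an LValue:

local a = { b = { c = function(x: number) return x end } }

a.b.c(1, 2) -- too many arguments; the message can name the callee as "a.b.c"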
} // namespace Luau } // namespace Luau

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TopoSortStatements.h" #include "Luau/TopoSortStatements.h"
#include "Luau/Error.h"
/* Decide the order in which we typecheck Lua statements in a block. /* Decide the order in which we typecheck Lua statements in a block.
* *
* Algorithm: * Algorithm:
@ -149,7 +150,7 @@ Identifier mkName(const AstStatFunction& function)
auto name = mkName(*function.name); auto name = mkName(*function.name);
LUAU_ASSERT(bool(name)); LUAU_ASSERT(bool(name));
if (!name) if (!name)
throw std::runtime_error("Internal error: Function declaration has a bad name"); throwRuntimeError("Internal error: Function declaration has a bad name");
return *name; return *name;
} }
@ -255,7 +256,7 @@ struct ArcCollector : public AstVisitor
{ {
auto name = mkName(*node->name); auto name = mkName(*node->name);
if (!name) if (!name)
throw std::runtime_error("Internal error: AstStatFunction has a bad name"); throwRuntimeError("Internal error: AstStatFunction has a bad name");
add(*name); add(*name);
return true; return true;

View file

@ -251,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional<TypeId> newBoundTo)
PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel)
{ {
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty) || get<ConstrainedTypeVar>(ty)); LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty));
PendingType* newTy = queue(ty); PendingType* newTy = queue(ty);
if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy)) if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy))
@ -267,11 +267,6 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel)
{ {
ftv->level = newLevel; ftv->level = newLevel;
} }
else if (ConstrainedTypeVar* ctv = Luau::getMutable<ConstrainedTypeVar>(newTy))
{
if (FFlag::LuauUnknownAndNeverType)
ctv->level = newLevel;
}
return newTy; return newTy;
} }
@ -291,7 +286,7 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel)
PendingType* TxnLog::changeScope(TypeId ty, NotNull<Scope> newScope) PendingType* TxnLog::changeScope(TypeId ty, NotNull<Scope> newScope)
{ {
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty) || get<ConstrainedTypeVar>(ty)); LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty));
PendingType* newTy = queue(ty); PendingType* newTy = queue(ty);
if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy)) if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy))
@ -307,10 +302,6 @@ PendingType* TxnLog::changeScope(TypeId ty, NotNull<Scope> newScope)
{ {
ftv->scope = newScope; ftv->scope = newScope;
} }
else if (ConstrainedTypeVar* ctv = Luau::getMutable<ConstrainedTypeVar>(newTy))
{
ctv->scope = newScope;
}
return newTy; return newTy;
} }

View file

@ -104,16 +104,6 @@ public:
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*pending-expansion*")); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*pending-expansion*"));
} }
AstType* operator()(const ConstrainedTypeVar& ctv)
{
AstArray<AstType*> types;
types.size = ctv.parts.size();
types.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * ctv.parts.size()));
for (size_t i = 0; i < ctv.parts.size(); ++i)
types.data[i] = Luau::visit(*this, ctv.parts[i]->ty);
return allocator->alloc<AstTypeIntersection>(Location(), types);
}
AstType* operator()(const SingletonTypeVar& stv) AstType* operator()(const SingletonTypeVar& stv)
{ {
if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv)) if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv))
@ -348,6 +338,11 @@ public:
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"}); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"});
} }
AstType* operator()(const NegationTypeVar& ntv)
{
// FIXME: do the same thing we do with ErrorTypeVar
throwRuntimeError("Cannot convert NegationTypeVar into AstNode");
}
private: private:
Allocator* allocator; Allocator* allocator;

View file

@ -5,6 +5,7 @@
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/Instantiation.h" #include "Luau/Instantiation.h"
#include "Luau/Metamethods.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
@ -62,6 +63,23 @@ struct StackPusher
} }
}; };
static std::optional<std::string> getIdentifierOfBaseVar(AstExpr* node)
{
if (AstExprGlobal* expr = node->as<AstExprGlobal>())
return expr->name.value;
if (AstExprLocal* expr = node->as<AstExprLocal>())
return expr->local->name.value;
if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return getIdentifierOfBaseVar(expr->expr);
if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return getIdentifierOfBaseVar(expr->expr);
return std::nullopt;
}
struct TypeChecker2 struct TypeChecker2
{ {
NotNull<SingletonTypes> singletonTypes; NotNull<SingletonTypes> singletonTypes;
@ -283,7 +301,6 @@ struct TypeChecker2
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant};
u.anyIsTop = true;
u.tryUnify(actualRetType, expectedRetType); u.tryUnify(actualRetType, expectedRetType);
const bool ok = u.errors.empty() && u.log.empty(); const bool ok = u.errors.empty() && u.log.empty();
@ -313,16 +330,21 @@ struct TypeChecker2
if (value) if (value)
visit(value); visit(value);
if (i != local->values.size - 1) TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr;
if (i != local->values.size - 1 || maybeValueType)
{ {
AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr;
if (var && var->annotation) if (var && var->annotation)
{ {
TypeId varType = lookupAnnotation(var->annotation); TypeId annotationType = lookupAnnotation(var->annotation);
TypeId valueType = value ? lookupType(value) : nullptr; TypeId valueType = value ? lookupType(value) : nullptr;
if (valueType && !isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (valueType)
reportError(TypeMismatch{varType, valueType}, value->location); {
ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType);
if (!errors.empty())
reportErrors(std::move(errors));
}
} }
} }
else else
@ -588,7 +610,7 @@ struct TypeChecker2
visit(rhs); visit(rhs);
TypeId rhsType = lookupType(rhs); TypeId rhsType = lookupType(rhs);
if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{lhsType, rhsType}, rhs->location); reportError(TypeMismatch{lhsType, rhsType}, rhs->location);
} }
@ -739,7 +761,7 @@ struct TypeChecker2
TypeId actualType = lookupType(number); TypeId actualType = lookupType(number);
TypeId numberType = singletonTypes->numberType; TypeId numberType = singletonTypes->numberType;
if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{actualType, numberType}, number->location); reportError(TypeMismatch{actualType, numberType}, number->location);
} }
@ -750,7 +772,7 @@ struct TypeChecker2
TypeId actualType = lookupType(string); TypeId actualType = lookupType(string);
TypeId stringType = singletonTypes->stringType; TypeId stringType = singletonTypes->stringType;
if (!isSubtype(stringType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{actualType, stringType}, string->location); reportError(TypeMismatch{actualType, stringType}, string->location);
} }
@ -783,26 +805,55 @@ struct TypeChecker2
TypePackId expectedRetType = lookupPack(call); TypePackId expectedRetType = lookupPack(call);
TypeId functionType = lookupType(call->func); TypeId functionType = lookupType(call->func);
LUAU_ASSERT(functionType); TypeId testFunctionType = functionType;
TypePack args;
if (get<AnyTypeVar>(functionType) || get<ErrorTypeVar>(functionType)) if (get<AnyTypeVar>(functionType) || get<ErrorTypeVar>(functionType))
return; return;
else if (std::optional<TypeId> callMm = findMetatableEntry(singletonTypes, module->errors, functionType, "__call", call->func->location))
// TODO: Lots of other types are callable: intersections of functions {
// and things with the __call metamethod. if (get<FunctionTypeVar>(follow(*callMm)))
if (!get<FunctionTypeVar>(functionType)) {
if (std::optional<TypeId> instantiatedCallMm = instantiation.substitute(*callMm))
{
args.head.push_back(functionType);
testFunctionType = follow(*instantiatedCallMm);
}
else
{
reportError(UnificationTooComplex{}, call->func->location);
return;
}
}
else
{
// TODO: This doesn't flag the __call metamethod as the problem
// very clearly.
reportError(CannotCallNonFunction{*callMm}, call->func->location);
return;
}
}
else if (get<FunctionTypeVar>(functionType))
{
if (std::optional<TypeId> instantiatedFunctionType = instantiation.substitute(functionType))
{
testFunctionType = *instantiatedFunctionType;
}
else
{
reportError(UnificationTooComplex{}, call->func->location);
return;
}
}
else
{ {
reportError(CannotCallNonFunction{functionType}, call->func->location); reportError(CannotCallNonFunction{functionType}, call->func->location);
return; return;
} }
TypeId instantiatedFunctionType = follow(instantiation.substitute(functionType).value_or(nullptr));
TypePack args;
for (AstExpr* arg : call->args) for (AstExpr* arg : call->args)
{ {
TypeId argTy = module->astTypes[arg]; TypeId argTy = lookupType(arg);
LUAU_ASSERT(argTy);
args.head.push_back(argTy); args.head.push_back(argTy);
} }
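A hedged Luau sketch of the case the __call branch above handles: a table whose metatable defines __call is checked against that metamethod's signature (with the callee itself as the first argument) instead of being rejected as a non-function:

local callable = setmetatable({}, {
    __call = function(self, x: number): number
        return x * 2
    end,
})

local y = callable(21) -- checked via __call rather than reported as CannotCallNonFunction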
@ -810,7 +861,7 @@ struct TypeChecker2
FunctionTypeVar ftv{argsTp, expectedRetType}; FunctionTypeVar ftv{argsTp, expectedRetType};
TypeId expectedType = arena.addType(ftv); TypeId expectedType = arena.addType(ftv);
if (!isSubtype(instantiatedFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice))
{ {
CloneState cloneState; CloneState cloneState;
expectedType = clone(expectedType, module->internalTypes, cloneState); expectedType = clone(expectedType, module->internalTypes, cloneState);
@ -829,7 +880,7 @@ struct TypeChecker2
getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true);
if (ty) if (ty)
{ {
if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{resultType, *ty}, indexName->location); reportError(TypeMismatch{resultType, *ty}, indexName->location);
} }
@ -862,7 +913,7 @@ struct TypeChecker2
TypeId inferredArgTy = *argIt; TypeId inferredArgTy = *argIt;
TypeId annotatedArgTy = lookupAnnotation(arg->annotation); TypeId annotatedArgTy = lookupAnnotation(arg->annotation);
if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location);
} }
@ -887,15 +938,264 @@ struct TypeChecker2
void visit(AstExprUnary* expr) void visit(AstExprUnary* expr)
{ {
// TODO!
visit(expr->expr); visit(expr->expr);
NotNull<Scope> scope = stack.back();
TypeId operandType = lookupType(expr->expr);
if (get<AnyTypeVar>(operandType) || get<ErrorTypeVar>(operandType) || get<NeverTypeVar>(operandType))
return;
if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end())
{
std::optional<TypeId> mm = findMetatableEntry(singletonTypes, module->errors, operandType, it->second, expr->location);
if (mm)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*mm)))
{
TypePackId expectedArgs = module->internalTypes.addTypePack({operandType});
reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs));
if (std::optional<TypeId> ret = first(ftv->retTypes))
{
if (expr->op == AstExprUnary::Op::Len)
{
reportErrors(tryUnify(scope, expr->location, follow(*ret), singletonTypes->numberType));
}
}
else
{
reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location);
}
}
return;
}
}
if (expr->op == AstExprUnary::Op::Len)
{
DenseHashSet<TypeId> seen{nullptr};
int recursionCount = 0;
if (!hasLength(operandType, seen, &recursionCount))
{
reportError(NotATable{operandType}, expr->location);
}
}
else if (expr->op == AstExprUnary::Op::Minus)
{
reportErrors(tryUnify(scope, expr->location, operandType, singletonTypes->numberType));
}
else if (expr->op == AstExprUnary::Op::Not)
{
}
else
{
LUAU_ASSERT(!"Unhandled unary operator");
}
} }
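For illustration, a small Luau example of what the unary checks above cover: `#` goes through __len when present (and its result is unified with number), while unary minus on a type that is neither a number nor carries __unm is reported. The names below are made up:

local box = setmetatable({ items = { 1, 2, 3 } }, {
    __len = function(self): number
        return #self.items
    end,
})

local n = #box        -- ok: __len exists and returns a number
-- local bad = -"oops" -- would be reported: string is not number and has no __unm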
void visit(AstExprBinary* expr) void visit(AstExprBinary* expr)
{ {
// TODO!
visit(expr->left); visit(expr->left);
visit(expr->right); visit(expr->right);
NotNull<Scope> scope = stack.back();
bool isEquality = expr->op == AstExprBinary::Op::CompareEq || expr->op == AstExprBinary::Op::CompareNe;
bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe;
bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or;
TypeId leftType = lookupType(expr->left);
TypeId rightType = lookupType(expr->right);
if (expr->op == AstExprBinary::Op::Or)
{
leftType = stripNil(singletonTypes, module->internalTypes, leftType);
}
bool isStringOperation = isString(leftType) && isString(rightType);
if (get<AnyTypeVar>(leftType) || get<ErrorTypeVar>(leftType) || get<AnyTypeVar>(rightType) || get<ErrorTypeVar>(rightType))
return;
if ((get<BlockedTypeVar>(leftType) || get<FreeTypeVar>(leftType)) && !isEquality && !isLogical)
{
auto name = getIdentifierOfBaseVar(expr->left);
reportError(CannotInferBinaryOperation{expr->op, name,
isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation},
expr->location);
return;
}
if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end())
{
std::optional<TypeId> leftMt = getMetatable(leftType, singletonTypes);
std::optional<TypeId> rightMt = getMetatable(rightType, singletonTypes);
bool matches = leftMt == rightMt;
if (isEquality && !matches)
{
auto testUnion = [&matches, singletonTypes = this->singletonTypes](const UnionTypeVar* utv, std::optional<TypeId> otherMt) {
for (TypeId option : utv)
{
if (getMetatable(follow(option), singletonTypes) == otherMt)
{
matches = true;
break;
}
}
};
if (const UnionTypeVar* utv = get<UnionTypeVar>(leftType); utv && rightMt)
{
testUnion(utv, rightMt);
}
if (const UnionTypeVar* utv = get<UnionTypeVar>(rightType); utv && leftMt && !matches)
{
testUnion(utv, leftMt);
}
}
if (!matches && isComparison)
{
reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
return;
}
std::optional<TypeId> mm;
if (std::optional<TypeId> leftMm = findMetatableEntry(singletonTypes, module->errors, leftType, it->second, expr->left->location))
mm = leftMm;
else if (std::optional<TypeId> rightMm = findMetatableEntry(singletonTypes, module->errors, rightType, it->second, expr->right->location))
mm = rightMm;
if (mm)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(*mm))
{
TypePackId expectedArgs;
// For >= and > we invoke __lt and __le respectively with
// swapped argument ordering.
if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt)
{
expectedArgs = module->internalTypes.addTypePack({rightType, leftType});
}
else
{
expectedArgs = module->internalTypes.addTypePack({leftType, rightType});
}
reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs));
if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe ||
expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt)
{
TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType});
if (!isSubtype(ftv->retTypes, expectedRets, scope, singletonTypes, ice))
{
reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location);
}
}
else if (!first(ftv->retTypes))
{
reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location);
}
}
else
{
reportError(CannotCallNonFunction{*mm}, expr->location);
}
return;
}
// If this is a string comparison, or a concatenation of strings, we
// want to fall through to primitive behavior.
else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison)))
{
if (leftMt || rightMt)
{
if (isComparison)
{
reportError(GenericError{format(
"Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)},
expr->location);
}
else
{
reportError(GenericError{format(
"Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod",
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)},
expr->location);
}
return;
}
else if (!leftMt && !rightMt && (get<TableTypeVar>(leftType) || get<TableTypeVar>(rightType)))
{
if (isComparison)
{
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
}
else
{
reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable",
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())},
expr->location);
}
return;
}
}
}
switch (expr->op)
{
case AstExprBinary::Op::Add:
case AstExprBinary::Op::Sub:
case AstExprBinary::Op::Mul:
case AstExprBinary::Op::Div:
case AstExprBinary::Op::Pow:
case AstExprBinary::Op::Mod:
reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->numberType));
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType));
break;
case AstExprBinary::Op::Concat:
reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->stringType));
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType));
break;
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
if (isNumber(leftType))
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType));
else if (isString(leftType))
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType));
else
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(),
toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
break;
case AstExprBinary::Op::And:
case AstExprBinary::Op::Or:
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
break;
default:
// Unhandled AstExprBinary::Op possibility.
LUAU_ASSERT(false);
}
} }
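A hedged Luau sketch of the comparison handling above: `>` and `>=` are checked against __lt and __le with the argument order swapped, and the metamethod's return type is expected to be boolean:

local mt = {}
mt.__lt = function(a, b): boolean
    return a.value < b.value
end

local a = setmetatable({ value = 1 }, mt)
local b = setmetatable({ value = 2 }, mt)

print(a < b) -- checked as __lt(a, b)
print(a > b) -- also __lt, but checked as __lt(b, a)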
void visit(AstExprTypeAssertion* expr) void visit(AstExprTypeAssertion* expr)
@ -907,10 +1207,10 @@ struct TypeChecker2
TypeId computedType = lookupType(expr->expr); TypeId computedType = lookupType(expr->expr);
// Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case.
if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice))
return; return;
if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice))
return; return;
reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); reportError(TypesAreUnrelated{computedType, annotationType}, expr->location);
@ -998,9 +1298,8 @@ struct TypeChecker2
Scope* scope = findInnermostScope(ty->location); Scope* scope = findInnermostScope(ty->location);
LUAU_ASSERT(scope); LUAU_ASSERT(scope);
std::optional<TypeFun> alias =
    (ty->prefix) ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value);
if (alias.has_value()) if (alias.has_value())
{ {
@ -1212,7 +1511,6 @@ struct TypeChecker2
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant};
u.anyIsTop = true;
u.tryUnify(subTy, superTy); u.tryUnify(subTy, superTy);
return std::move(u.errors); return std::move(u.errors);

View file

@ -31,12 +31,12 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000)
LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300)
LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAG(LuauTypeNormalization2)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false.
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false)
LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false)
LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false)
LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false)
LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false)
@ -44,15 +44,15 @@ LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false)
LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false)
LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false)
LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false)
LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false)
LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false)
namespace Luau namespace Luau
{ {
const char* TimeLimitError_DEPRECATED::what() const throw()
const char* TimeLimitError::what() const throw()
{ {
LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange);
return "Typeinfer failed to complete in allotted time"; return "Typeinfer failed to complete in allotted time";
} }
@ -265,6 +265,11 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona
reportErrorCodeTooComplex(module.root->location); reportErrorCodeTooComplex(module.root->location);
return std::move(currentModule); return std::move(currentModule);
} }
catch (const RecursionLimitException_DEPRECATED&)
{
reportErrorCodeTooComplex(module.root->location);
return std::move(currentModule);
}
} }
ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope) ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope)
@ -280,11 +285,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
iceHandler->moduleName = module.name; iceHandler->moduleName = module.name;
normalizer.arena = &currentModule->internalTypes; normalizer.arena = &currentModule->internalTypes;
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit;
ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr parentScope = environmentScope.value_or(globalScope);
ScopePtr moduleScope = std::make_shared<Scope>(parentScope); ScopePtr moduleScope = std::make_shared<Scope>(parentScope);
@ -312,6 +314,10 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
{ {
currentModule->timeout = true; currentModule->timeout = true;
} }
catch (const TimeLimitError_DEPRECATED&)
{
currentModule->timeout = true;
}
if (FFlag::DebugLuauSharedSelf) if (FFlag::DebugLuauSharedSelf)
{ {
@ -419,7 +425,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program)
ice("Unknown AstStat"); ice("Unknown AstStat");
if (finishTime && TimeTrace::getClock() > *finishTime) if (finishTime && TimeTrace::getClock() > *finishTime)
throw TimeLimitError(); throwTimeLimitError();
} }
// This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly.
@ -446,6 +452,11 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
reportErrorCodeTooComplex(block.location); reportErrorCodeTooComplex(block.location);
return; return;
} }
catch (const RecursionLimitException_DEPRECATED&)
{
reportErrorCodeTooComplex(block.location);
return;
}
} }
struct InplaceDemoter : TypeVarOnceVisitor struct InplaceDemoter : TypeVarOnceVisitor
@ -773,16 +784,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement)
checkExpr(repScope, *statement.condition); checkExpr(repScope, *statement.condition);
} }
void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location)
{
Unifier state = mkUnifier(scope, location);
state.unifyLowerBound(subTy, superTy, demotedLevel);
state.log.commit();
reportErrors(state.errors);
}
struct Demoter : Substitution struct Demoter : Substitution
{ {
Demoter(TypeArena* arena) Demoter(TypeArena* arena)
@ -2091,39 +2092,6 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromTypeImpl(
return std::nullopt; return std::nullopt;
} }
std::vector<TypeId> TypeChecker::reduceUnion(const std::vector<TypeId>& types)
{
std::vector<TypeId> result;
for (TypeId t : types)
{
t = follow(t);
if (get<NeverTypeVar>(t))
continue;
if (get<ErrorTypeVar>(t) || get<AnyTypeVar>(t))
return {t};
if (const UnionTypeVar* utv = get<UnionTypeVar>(t))
{
for (TypeId ty : utv)
{
ty = follow(ty);
if (get<NeverTypeVar>(ty))
continue;
if (get<ErrorTypeVar>(ty) || get<AnyTypeVar>(ty))
return {ty};
if (result.end() == std::find(result.begin(), result.end(), ty))
result.push_back(ty);
}
}
else if (std::find(result.begin(), result.end(), t) == result.end())
result.push_back(t);
}
return result;
}
std::optional<TypeId> TypeChecker::tryStripUnionFromNil(TypeId ty) std::optional<TypeId> TypeChecker::tryStripUnionFromNil(TypeId ty)
{ {
if (const UnionTypeVar* utv = get<UnionTypeVar>(ty)) if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
@ -2503,11 +2471,8 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op)
TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes) TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes)
{ {
a = follow(a);
b = follow(b);
if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b))) if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b)))
{ {
@ -3643,8 +3608,17 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
location = {state.location.begin, argLocations.back().end}; location = {state.location.begin, argLocations.back().end};
std::string namePath; std::string namePath;
if (FFlag::LuauLvaluelessPath)
{
    if (std::optional<std::string> path = getFunctionNameAsString(funName))
        namePath = *path;
}
else
{
    if (std::optional<LValue> lValue = tryGetLValue(funName))
        namePath = toString(*lValue);
}
auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack);
state.reportError(TypeError{location, state.reportError(TypeError{location,
@ -3753,11 +3727,28 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
bool isVariadic = tail && Luau::isVariadic(*tail); bool isVariadic = tail && Luau::isVariadic(*tail);
std::string namePath; std::string namePath;
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue);
state.reportError(TypeError{ if (FFlag::LuauLvaluelessPath)
state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); {
if (std::optional<std::string> path = getFunctionNameAsString(funName))
namePath = *path;
}
else
{
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue);
}
if (FFlag::LuauArgMismatchReportFunctionLocation)
{
state.reportError(TypeError{
funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}});
}
else
{
state.reportError(TypeError{
state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}});
}
return; return;
} }
++paramIter; ++paramIter;
@ -4597,7 +4588,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat
Instantiation instantiation{log, &currentModule->internalTypes, scope->level, /*scope*/ nullptr}; Instantiation instantiation{log, &currentModule->internalTypes, scope->level, /*scope*/ nullptr};
if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) if (instantiationChildLimit)
instantiation.childLimit = *instantiationChildLimit; instantiation.childLimit = *instantiationChildLimit;
std::optional<TypeId> instantiated = instantiation.substitute(ty); std::optional<TypeId> instantiated = instantiation.substitute(ty);
@ -4694,6 +4685,19 @@ void TypeChecker::ice(const std::string& message)
iceHandler->ice(message); iceHandler->ice(message);
} }
// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted.
void TypeChecker::throwTimeLimitError()
{
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw TimeLimitError(iceHandler->moduleName);
}
else
{
throw TimeLimitError_DEPRECATED();
}
}
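For illustration only: throwTimeLimitError follows the same fast-flag migration pattern as the new catch clause earlier in this file — the flagged-on path uses the new exception type, the flagged-off path keeps the deprecated one until the flag is deleted. A minimal standalone sketch of that pattern, with a plain bool standing in for the fast flag and invented exception names:

#include <stdexcept>
#include <string>

// Stand-ins for the old and new exception types; purely illustrative.
struct ToyTimeLimitError : std::runtime_error
{
    explicit ToyTimeLimitError(const std::string& moduleName)
        : std::runtime_error("time limit exceeded in " + moduleName)
    {
    }
};

struct ToyTimeLimitError_DEPRECATED
{
};

// A plain bool standing in for the fast flag.
static bool useNewExceptionHierarchy = true;

[[noreturn]] void toyThrowTimeLimitError(const std::string& moduleName)
{
    if (useNewExceptionHierarchy)
        throw ToyTimeLimitError(moduleName);
    else
        throw ToyTimeLimitError_DEPRECATED();
}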
void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec)
{ {
// Remove errors with names that were generated by recovery from a parse error // Remove errors with names that were generated by recovery from a parse error

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/Error.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include <stdexcept> #include <stdexcept>
@ -234,7 +235,7 @@ TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper)
cycleTester = nullptr; cycleTester = nullptr;
if (tp == cycleTester) if (tp == cycleTester)
throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); throwRuntimeError("Luau::follow detected a TypeVar cycle!!");
} }
} }
} }

View file

@ -6,6 +6,8 @@
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include <algorithm>
namespace Luau namespace Luau
{ {
@ -146,18 +148,15 @@ std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro
return std::nullopt; return std::nullopt;
} }
goodOptions = reduceUnion(goodOptions);
if (goodOptions.empty()) if (goodOptions.empty())
return singletonTypes->neverType; return singletonTypes->neverType;
if (goodOptions.size() == 1) if (goodOptions.size() == 1)
return goodOptions[0]; return goodOptions[0];
// TODO: inefficient. return arena->addType(UnionTypeVar{std::move(goodOptions)});
TypeId result = arena->addType(UnionTypeVar{std::move(goodOptions)});
auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, singletonTypes, handle);
if (!ok && addErrors)
errors.push_back(TypeError{location, NormalizationTooComplex{}});
return ok ? ty : singletonTypes->anyType;
} }
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type)) else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type))
{ {
@ -264,4 +263,79 @@ std::vector<TypeId> flatten(TypeArena& arena, NotNull<SingletonTypes> singletonT
return result; return result;
} }
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types)
{
std::vector<TypeId> result;
for (TypeId t : types)
{
t = follow(t);
if (get<NeverTypeVar>(t))
continue;
if (get<ErrorTypeVar>(t) || get<AnyTypeVar>(t))
return {t};
if (const UnionTypeVar* utv = get<UnionTypeVar>(t))
{
for (TypeId ty : utv)
{
ty = follow(ty);
if (get<NeverTypeVar>(ty))
continue;
if (get<ErrorTypeVar>(ty) || get<AnyTypeVar>(ty))
return {ty};
if (result.end() == std::find(result.begin(), result.end(), ty))
result.push_back(ty);
}
}
else if (std::find(result.begin(), result.end(), t) == result.end())
result.push_back(t);
}
return result;
}
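As a sketch of what this helper does: reduceUnion, now a free function here, drops never, collapses the whole union to any or error as soon as one appears, flattens nested union options, and deduplicates what remains. The standalone sketch below mirrors that reduction over toy string-based types; ToyType and ToyUnion are invented for the sketch, not Luau's TypeId.

#include <algorithm>
#include <string>
#include <vector>

using ToyType = std::string;           // "never", "any", "error", "number", ...
using ToyUnion = std::vector<ToyType>; // the options of a union; a lone type is a one-element union

// Same shape as reduceUnion: skip never, collapse to any/error outright,
// and deduplicate whatever options remain.
ToyUnion reduceToyUnion(const std::vector<ToyUnion>& options)
{
    ToyUnion result;
    for (const ToyUnion& option : options)
    {
        for (const ToyType& ty : option)
        {
            if (ty == "never")
                continue;
            if (ty == "any" || ty == "error")
                return {ty};
            if (std::find(result.begin(), result.end(), ty) == result.end())
                result.push_back(ty);
        }
    }
    return result;
}

For example, reduceToyUnion({{"number"}, {"string", "number"}, {"never"}}) yields {"number", "string"}.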
static std::optional<TypeId> tryStripUnionFromNil(TypeArena& arena, TypeId ty)
{
if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
{
if (!std::any_of(begin(utv), end(utv), isNil))
return ty;
std::vector<TypeId> result;
for (TypeId option : utv)
{
if (!isNil(option))
result.push_back(option);
}
if (result.empty())
return std::nullopt;
return result.size() == 1 ? result[0] : arena.addType(UnionTypeVar{std::move(result)});
}
return std::nullopt;
}
TypeId stripNil(NotNull<SingletonTypes> singletonTypes, TypeArena& arena, TypeId ty)
{
ty = follow(ty);
if (get<UnionTypeVar>(ty))
{
std::optional<TypeId> cleaned = tryStripUnionFromNil(arena, ty);
// If there is no union option without 'nil'
if (!cleaned)
return singletonTypes->nilType;
return follow(*cleaned);
}
return follow(ty);
}
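Similarly, stripNil and its helper remove the nil option from a union (so number? becomes number) and fall back to nil itself when no other option remains. A standalone sketch of that filtering with the same toy string-based types as above, again not Luau's real API:

#include <string>
#include <vector>

using ToyType = std::string;
using ToyUnion = std::vector<ToyType>; // options of a union; a lone type is a one-element union

// Same shape as stripNil: drop the nil option; if nothing else is left, the result is nil itself.
ToyUnion stripToyNil(const ToyUnion& options)
{
    ToyUnion result;
    for (const ToyType& option : options)
    {
        if (option != "nil")
            result.push_back(option);
    }

    if (result.empty())
        return {"nil"};
    return result;
}

stripToyNil({"number", "nil"}) gives {"number"}, while stripToyNil({"nil"}) gives {"nil"}.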
} // namespace Luau } // namespace Luau

View file

@ -66,7 +66,7 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
{ {
TypeId res = ltv->thunk(); TypeId res = ltv->thunk();
if (get<LazyTypeVar>(res)) if (get<LazyTypeVar>(res))
throw std::runtime_error("Lazy TypeVar cannot resolve to another Lazy TypeVar"); throwRuntimeError("Lazy TypeVar cannot resolve to another Lazy TypeVar");
*asMutable(ty) = BoundTypeVar(res); *asMutable(ty) = BoundTypeVar(res);
} }
@ -104,7 +104,7 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
cycleTester = nullptr; cycleTester = nullptr;
if (t == cycleTester) if (t == cycleTester)
throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); throwRuntimeError("Luau::follow detected a TypeVar cycle!!");
} }
} }
} }
@ -754,12 +754,15 @@ SingletonTypes::SingletonTypes()
, stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true})) , stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}))
, booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true})) , booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}))
, threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true})) , threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}))
, functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true}))
, trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true})) , trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}))
, falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true})) , falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}))
, anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true})) , anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true}))
, unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true})) , unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true}))
, neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true})) , neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true}))
, errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true})) , errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true}))
, falsyType(arena->addType(TypeVar{UnionTypeVar{{falseType, nilType}}, /*persistent*/ true}))
, truthyType(arena->addType(TypeVar{NegationTypeVar{falsyType}, /*persistent*/ true}))
, anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true}))
, neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true}))
, uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack)) , uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack))
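The new falsyType (false | nil) and truthyType (the negation of falsyType) singletons encode Lua's truthiness rule: only false and nil are falsy; everything else, including 0 and the empty string, is truthy. A tiny standalone sketch of that rule over an invented value representation (not Luau's runtime types):

#include <variant>

// Toy runtime value: just enough shape to state the rule.
struct ToyNil
{
};
using ToyValue = std::variant<ToyNil, bool, double>;

// Lua/Luau truthiness: nil and false are falsy; every other value is truthy.
bool isToyTruthy(const ToyValue& v)
{
    if (std::holds_alternative<ToyNil>(v))
        return false;
    if (const bool* b = std::get_if<bool>(&v))
        return *b;
    return true;
}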
@ -896,7 +899,6 @@ void persist(TypeId ty)
continue; continue;
asMutable(t)->persistent = true; asMutable(t)->persistent = true;
asMutable(t)->normal = true; // all persistent types are assumed to be normal
if (auto btv = get<BoundTypeVar>(t)) if (auto btv = get<BoundTypeVar>(t))
queue.push_back(btv->boundTo); queue.push_back(btv->boundTo);
@ -933,17 +935,13 @@ void persist(TypeId ty)
for (TypeId opt : itv->parts) for (TypeId opt : itv->parts)
queue.push_back(opt); queue.push_back(opt);
} }
else if (auto ctv = get<ConstrainedTypeVar>(t))
{
for (TypeId opt : ctv->parts)
queue.push_back(opt);
}
else if (auto mtv = get<MetatableTypeVar>(t)) else if (auto mtv = get<MetatableTypeVar>(t))
{ {
queue.push_back(mtv->table); queue.push_back(mtv->table);
queue.push_back(mtv->metatable); queue.push_back(mtv->metatable);
} }
else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t)) else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t) ||
get<NegationTypeVar>(t))
{ {
} }
else else
@ -990,8 +988,6 @@ const TypeLevel* getLevel(TypeId ty)
return &ttv->level; return &ttv->level;
else if (auto ftv = get<FunctionTypeVar>(ty)) else if (auto ftv = get<FunctionTypeVar>(ty))
return &ftv->level; return &ftv->level;
else if (auto ctv = get<ConstrainedTypeVar>(ty))
return &ctv->level;
else else
return nullptr; return nullptr;
} }
@ -1056,11 +1052,6 @@ const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv)
return itv->parts; return itv->parts;
} }
const std::vector<TypeId>& getTypes(const ConstrainedTypeVar* ctv)
{
return ctv->parts;
}
UnionTypeVarIterator begin(const UnionTypeVar* utv) UnionTypeVarIterator begin(const UnionTypeVar* utv)
{ {
return UnionTypeVarIterator{utv}; return UnionTypeVarIterator{utv};
@ -1081,17 +1072,6 @@ IntersectionTypeVarIterator end(const IntersectionTypeVar* itv)
return IntersectionTypeVarIterator{}; return IntersectionTypeVarIterator{};
} }
ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv)
{
return ConstrainedTypeVarIterator{ctv};
}
ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv)
{
return ConstrainedTypeVarIterator{};
}
static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const char* data, size_t size) static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const char* data, size_t size)
{ {
const char* options = "cdiouxXeEfgGqs*"; const char* options = "cdiouxXeEfgGqs*";

View file

@ -8,23 +8,23 @@
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/TimeTrace.h" #include "Luau/TimeTrace.h"
#include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h" #include "Luau/VisitTypeVar.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include <algorithm> #include <algorithm>
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000);
LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauReportTypeMismatchForTypePackUnificationFailure, false)
LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false);
LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false);
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauNegatedFunctionTypes)
namespace Luau namespace Luau
{ {
@ -95,15 +95,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor
return true; return true;
} }
bool visit(TypeId ty, const ConstrainedTypeVar&) override
{
if (!FFlag::LuauUnknownAndNeverType)
return visit(ty);
promote(ty, log.getMutable<ConstrainedTypeVar>(ty));
return true;
}
bool visit(TypeId ty, const FunctionTypeVar&) override bool visit(TypeId ty, const FunctionTypeVar&) override
{ {
// Type levels of types from other modules are already global, so we don't need to promote anything inside // Type levels of types from other modules are already global, so we don't need to promote anything inside
@ -285,7 +276,7 @@ TypeId Widen::clean(TypeId ty)
TypePackId Widen::clean(TypePackId) TypePackId Widen::clean(TypePackId)
{ {
throw std::runtime_error("Widen attempted to clean a dirty type pack?"); throwRuntimeError("Widen attempted to clean a dirty type pack?");
} }
bool Widen::ignoreChildren(TypeId ty) bool Widen::ignoreChildren(TypeId ty)
@ -368,26 +359,14 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i
void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection)
{ {
RecursionLimiter _ra(&sharedState.counters.recursionCount, RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
++sharedState.counters.iterationCount; ++sharedState.counters.iterationCount;
if (FFlag::LuauAutocompleteDynamicLimits) if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{ {
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) reportError(location, UnificationTooComplex{});
{ return;
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
else
{
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
} }
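With LuauAutocompleteDynamicLimits retired, tryUnify_ reads both its recursion and iteration limits from sharedState.counters rather than from the FInt globals. Below is a standalone sketch of that guard shape — an RAII depth limiter plus an iteration budget — using invented names rather than Luau's RecursionLimiter:

#include <stdexcept>

struct ToyCounters
{
    int recursionCount = 0;
    int recursionLimit = 0;   // 0 means "no limit"
    long iterationCount = 0;
    long iterationLimit = 0;  // 0 means "no limit"
};

// RAII depth guard: throws past the limit, otherwise bumps the depth and restores it on exit.
struct ToyRecursionLimiter
{
    int* count;

    ToyRecursionLimiter(int* count, int limit)
        : count(count)
    {
        if (limit > 0 && *count + 1 > limit)
            throw std::runtime_error("recursion limit exceeded");
        ++*count;
    }

    ~ToyRecursionLimiter()
    {
        --*count;
    }
};

// Returns false once the iteration budget is spent, mirroring the
// "report UnificationTooComplex and bail out" shape of tryUnify_.
bool toyUnifyStep(ToyCounters& counters)
{
    ToyRecursionLimiter guard(&counters.recursionCount, counters.recursionLimit);

    ++counters.iterationCount;
    if (counters.iterationLimit > 0 && counters.iterationLimit < counters.iterationCount)
        return false;

    return true;
}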
superTy = log.follow(superTy); superTy = log.follow(superTy);
@ -396,9 +375,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (superTy == subTy) if (superTy == subTy)
return; return;
if (log.get<ConstrainedTypeVar>(superTy))
return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy);
auto superFree = log.getMutable<FreeTypeVar>(superTy); auto superFree = log.getMutable<FreeTypeVar>(superTy);
auto subFree = log.getMutable<FreeTypeVar>(subTy); auto subFree = log.getMutable<FreeTypeVar>(subTy);
@ -430,7 +406,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) if (subGeneric && !subsumes(useScopes, subGeneric, superFree))
{ {
// TODO: a more informative error message? CLI-39912 // TODO: a more informative error message? CLI-39912
reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); reportError(location, GenericError{"Generic subtype escaping scope"});
return; return;
} }
@ -459,7 +435,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) if (superGeneric && !subsumes(useScopes, superGeneric, subFree))
{ {
// TODO: a more informative error message? CLI-39912 // TODO: a more informative error message? CLI-39912
reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); reportError(location, GenericError{"Generic supertype escaping scope"});
return; return;
} }
@ -476,15 +452,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
return tryUnifyWithAny(subTy, superTy); return tryUnifyWithAny(subTy, superTy);
if (get<AnyTypeVar>(subTy)) if (get<AnyTypeVar>(subTy))
{ return tryUnifyWithAny(superTy, subTy);
if (anyIsTop)
{
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
return;
}
else
return tryUnifyWithAny(superTy, subTy);
}
if (log.get<ErrorTypeVar>(subTy)) if (log.get<ErrorTypeVar>(subTy))
return tryUnifyWithAny(superTy, subTy); return tryUnifyWithAny(superTy, subTy);
@ -504,7 +472,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) if (auto error = sharedState.cachedUnifyError.find({subTy, superTy}))
{ {
reportError(TypeError{location, *error}); reportError(location, *error);
return; return;
} }
} }
@ -520,9 +488,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
size_t errorCount = errors.size(); size_t errorCount = errors.size();
if (log.get<ConstrainedTypeVar>(subTy)) if (const UnionTypeVar* subUnion = log.getMutable<UnionTypeVar>(subTy))
tryUnifyWithConstrainedSubTypeVar(subTy, superTy);
else if (const UnionTypeVar* subUnion = log.getMutable<UnionTypeVar>(subTy))
{ {
tryUnifyUnionWithType(subTy, subUnion, superTy); tryUnifyUnionWithType(subTy, subUnion, superTy);
} }
@ -548,6 +514,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
else if ((log.getMutable<PrimitiveTypeVar>(superTy) || log.getMutable<SingletonTypeVar>(superTy)) && log.getMutable<SingletonTypeVar>(subTy)) else if ((log.getMutable<PrimitiveTypeVar>(superTy) || log.getMutable<SingletonTypeVar>(superTy)) && log.getMutable<SingletonTypeVar>(subTy))
tryUnifySingletons(subTy, superTy); tryUnifySingletons(subTy, superTy);
else if (auto ptv = get<PrimitiveTypeVar>(superTy);
FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get<FunctionTypeVar>(subTy))
{
// Ok. Do nothing. forall functions F, F <: function
}
else if (log.getMutable<FunctionTypeVar>(superTy) && log.getMutable<FunctionTypeVar>(subTy)) else if (log.getMutable<FunctionTypeVar>(superTy) && log.getMutable<FunctionTypeVar>(subTy))
tryUnifyFunctions(subTy, superTy, isFunctionCall); tryUnifyFunctions(subTy, superTy, isFunctionCall);
@ -580,8 +552,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
else if (log.getMutable<ClassTypeVar>(subTy)) else if (log.getMutable<ClassTypeVar>(subTy))
tryUnifyWithClass(subTy, superTy, /*reversed*/ true); tryUnifyWithClass(subTy, superTy, /*reversed*/ true);
else if (log.get<NegationTypeVar>(superTy))
tryUnifyTypeWithNegation(subTy, superTy);
else if (log.get<NegationTypeVar>(subTy))
tryUnifyNegationWithType(subTy, superTy);
else else
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
if (cacheEnabled) if (cacheEnabled)
cacheResult(subTy, superTy, errorCount); cacheResult(subTy, superTy, errorCount);
@ -655,9 +633,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion,
else if (failed) else if (failed)
{ {
if (firstFailedOption) if (firstFailedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption});
else else
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
} }
} }
@ -756,7 +734,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* subNorm = normalizer->normalize(subTy);
const NormalizedType* superNorm = normalizer->normalize(superTy); const NormalizedType* superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm) if (!subNorm || !superNorm)
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else else
@ -765,9 +743,9 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
else if (!found) else if (!found)
{ {
if ((failedOptionCount == 1 || foundHeuristic) && failedOption) if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); reportError(location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption});
else else
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible"});
} }
} }
@ -796,7 +774,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I
if (unificationTooComplex) if (unificationTooComplex)
reportError(*unificationTooComplex); reportError(*unificationTooComplex);
else if (firstFailedOption) else if (firstFailedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption});
} }
void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall)
@ -854,11 +832,11 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV
if (subNorm && superNorm) if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else else
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
} }
else if (!found) else if (!found)
{ {
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"});
} }
} }
@ -870,43 +848,37 @@ void Unifier::tryUnifyNormalizedTypes(
if (get<UnknownTypeVar>(superNorm.tops) || get<AnyTypeVar>(superNorm.tops) || get<AnyTypeVar>(subNorm.tops)) if (get<UnknownTypeVar>(superNorm.tops) || get<AnyTypeVar>(superNorm.tops) || get<AnyTypeVar>(subNorm.tops))
return; return;
else if (get<UnknownTypeVar>(subNorm.tops)) else if (get<UnknownTypeVar>(subNorm.tops))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<ErrorTypeVar>(subNorm.errors)) if (get<ErrorTypeVar>(subNorm.errors))
if (!get<ErrorTypeVar>(superNorm.errors)) if (!get<ErrorTypeVar>(superNorm.errors))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.booleans)) if (get<PrimitiveTypeVar>(subNorm.booleans))
{ {
if (!get<PrimitiveTypeVar>(superNorm.booleans)) if (!get<PrimitiveTypeVar>(superNorm.booleans))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(subNorm.booleans)) else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(subNorm.booleans))
{ {
if (!get<PrimitiveTypeVar>(superNorm.booleans) && stv != get<SingletonTypeVar>(superNorm.booleans)) if (!get<PrimitiveTypeVar>(superNorm.booleans) && stv != get<SingletonTypeVar>(superNorm.booleans))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
if (get<PrimitiveTypeVar>(subNorm.nils)) if (get<PrimitiveTypeVar>(subNorm.nils))
if (!get<PrimitiveTypeVar>(superNorm.nils)) if (!get<PrimitiveTypeVar>(superNorm.nils))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.numbers)) if (get<PrimitiveTypeVar>(subNorm.numbers))
if (!get<PrimitiveTypeVar>(superNorm.numbers)) if (!get<PrimitiveTypeVar>(superNorm.numbers))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (subNorm.strings && superNorm.strings) if (!isSubtype(subNorm.strings, superNorm.strings))
{ return reportError(location, TypeMismatch{superTy, subTy, reason, error});
for (auto [name, ty] : *subNorm.strings)
if (!superNorm.strings->count(name))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
}
else if (!subNorm.strings && superNorm.strings)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
if (get<PrimitiveTypeVar>(subNorm.threads)) if (get<PrimitiveTypeVar>(subNorm.threads))
if (!get<PrimitiveTypeVar>(superNorm.errors)) if (!get<PrimitiveTypeVar>(superNorm.errors))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
for (TypeId subClass : subNorm.classes) for (TypeId subClass : subNorm.classes)
{ {
@ -922,7 +894,7 @@ void Unifier::tryUnifyNormalizedTypes(
} }
} }
if (!found) if (!found)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
for (TypeId subTable : subNorm.tables) for (TypeId subTable : subNorm.tables)
@ -947,21 +919,19 @@ void Unifier::tryUnifyNormalizedTypes(
return reportError(*e); return reportError(*e);
} }
if (!found) if (!found)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
if (subNorm.functions) if (!subNorm.functions.isNever())
{ {
if (!superNorm.functions) if (superNorm.functions.isNever())
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (superNorm.functions->empty()) for (TypeId superFun : *superNorm.functions.parts)
return;
for (TypeId superFun : *superNorm.functions)
{ {
Unifier innerState = makeChildUnifier(); Unifier innerState = makeChildUnifier();
const FunctionTypeVar* superFtv = get<FunctionTypeVar>(superFun); const FunctionTypeVar* superFtv = get<FunctionTypeVar>(superFun);
if (!superFtv) if (!superFtv)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes);
innerState.tryUnify_(tgt, superFtv->retTypes); innerState.tryUnify_(tgt, superFtv->retTypes);
if (innerState.errors.empty()) if (innerState.errors.empty())
@ -969,7 +939,7 @@ void Unifier::tryUnifyNormalizedTypes(
else if (auto e = hasUnificationTooComplex(innerState.errors)) else if (auto e = hasUnificationTooComplex(innerState.errors))
return reportError(*e); return reportError(*e);
else else
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
} }
@ -987,15 +957,15 @@ void Unifier::tryUnifyNormalizedTypes(
TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args) TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args)
{ {
if (!overloads || overloads->empty()) if (overloads.isNever())
{ {
reportError(TypeError{location, CannotCallNonFunction{function}}); reportError(location, CannotCallNonFunction{function});
return singletonTypes->errorRecoveryTypePack(); return singletonTypes->errorRecoveryTypePack();
} }
std::optional<TypePackId> result; std::optional<TypePackId> result;
const FunctionTypeVar* firstFun = nullptr; const FunctionTypeVar* firstFun = nullptr;
for (TypeId overload : *overloads) for (TypeId overload : *overloads.parts)
{ {
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(overload)) if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(overload))
{ {
@ -1011,10 +981,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized
log.concat(std::move(innerState.log)); log.concat(std::move(innerState.log));
if (result) if (result)
{ {
if (FFlag::LuauOverloadedFunctionSubtypingPerf)
{
innerState.log.clear();
innerState.tryUnify_(*result, ftv->retTypes);
}
if (FFlag::LuauOverloadedFunctionSubtypingPerf && innerState.errors.empty())
log.concat(std::move(innerState.log));
// Annoyingly, since we don't support intersection of generic type packs, // Annoyingly, since we don't support intersection of generic type packs,
// the intersection may fail. We rather arbitrarily use the first matching overload // the intersection may fail. We rather arbitrarily use the first matching overload
// in that case. // in that case.
if (std::optional<TypePackId> intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) else if (std::optional<TypePackId> intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes))
result = intersect; result = intersect;
} }
else else
@ -1036,12 +1013,12 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized
// TODO: better error reporting? // TODO: better error reporting?
// The logic for error reporting overload resolution // The logic for error reporting overload resolution
// is currently over in TypeInfer.cpp, should we move it? // is currently over in TypeInfer.cpp, should we move it?
reportError(TypeError{location, GenericError{"No matching overload."}}); reportError(location, GenericError{"No matching overload."});
return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); return singletonTypes->errorRecoveryTypePack(firstFun->retTypes);
} }
else else
{ {
reportError(TypeError{location, CannotCallNonFunction{function}}); reportError(location, CannotCallNonFunction{function});
return singletonTypes->errorRecoveryTypePack(); return singletonTypes->errorRecoveryTypePack();
} }
} }
@ -1214,26 +1191,14 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall
*/ */
void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall)
{ {
RecursionLimiter _ra(&sharedState.counters.recursionCount, RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
++sharedState.counters.iterationCount; ++sharedState.counters.iterationCount;
if (FFlag::LuauAutocompleteDynamicLimits) if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{ {
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) reportError(location, UnificationTooComplex{});
{ return;
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
else
{
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
} }
superTp = log.follow(superTp); superTp = log.follow(superTp);
@ -1405,7 +1370,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
size_t actualSize = size(subTp); size_t actualSize = size(subTp);
if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult) if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult)
std::swap(expectedSize, actualSize); std::swap(expectedSize, actualSize);
reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}}); reportError(location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx});
while (superIter.good()) while (superIter.good())
{ {
@ -1426,7 +1391,10 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
} }
else else
{ {
reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure)
reportError(location, TypePackMismatch{subTp, superTp});
else
reportError(location, GenericError{"Failed to unify type packs"});
} }
} }
@ -1438,7 +1406,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy)
ice("passed non primitive types to unifyPrimitives"); ice("passed non primitive types to unifyPrimitives");
if (superPrim->type != subPrim->type) if (superPrim->type != subPrim->type)
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
} }
void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
@ -1459,7 +1427,7 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
if (superPrim && superPrim->type == PrimitiveTypeVar::String && get<StringSingleton>(subSingleton) && variance == Covariant) if (superPrim && superPrim->type == PrimitiveTypeVar::String && get<StringSingleton>(subSingleton) && variance == Covariant)
return; return;
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
} }
void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall)
@ -1475,7 +1443,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0); bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0);
if (FFlag::LuauInstantiateInSubtyping && variance == Covariant && shouldInstantiate) // TODO: This is unsound when the context is invariant, but the annotation burden without allowing it and without
// read-only properties is too high for lua-apps. Read-only properties _should_ resolve their issue by allowing
// generic methods in tables to be marked read-only.
if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate)
{ {
Instantiation instantiation{&log, types, scope->level, scope}; Instantiation instantiation{&log, types, scope->level, scope};
@ -1492,21 +1463,21 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
} }
else else
{ {
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
} }
} }
else if (numGenerics != subFunction->generics.size()) else if (numGenerics != subFunction->generics.size())
{ {
numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size());
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters"});
} }
if (numGenericPacks != subFunction->genericPacks.size()) if (numGenericPacks != subFunction->genericPacks.size())
{ {
numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size());
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"});
} }
for (size_t i = 0; i < numGenerics; i++) for (size_t i = 0; i < numGenerics; i++)
@ -1533,11 +1504,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty() && innerState.firstPackErrorPos) else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError( reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), innerState.errors.front()});
innerState.errors.front()}});
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()});
innerState.ctx = CountMismatch::FunctionResult; innerState.ctx = CountMismatch::FunctionResult;
innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes);
@ -1547,13 +1517,12 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes))
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()});
else if (!innerState.errors.empty() && innerState.firstPackErrorPos) else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError( reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), innerState.errors.front()});
innerState.errors.front()}});
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()});
} }
log.concat(std::move(innerState.log)); log.concat(std::move(innerState.log));
@ -1610,6 +1579,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
{ {
TableTypeVar* superTable = log.getMutable<TableTypeVar>(superTy); TableTypeVar* superTable = log.getMutable<TableTypeVar>(superTy);
TableTypeVar* subTable = log.getMutable<TableTypeVar>(subTy); TableTypeVar* subTable = log.getMutable<TableTypeVar>(subTy);
TableTypeVar* instantiatedSubTable = subTable;
if (!superTable || !subTable) if (!superTable || !subTable)
ice("passed non-table types to unifyTables"); ice("passed non-table types to unifyTables");
@ -1627,13 +1597,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (instantiated.has_value()) if (instantiated.has_value())
{ {
subTable = log.getMutable<TableTypeVar>(*instantiated); subTable = log.getMutable<TableTypeVar>(*instantiated);
instantiatedSubTable = subTable;
if (!subTable) if (!subTable)
ice("instantiation made a table type into a non-table type in tryUnifyTables"); ice("instantiation made a table type into a non-table type in tryUnifyTables");
} }
else else
{ {
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
} }
} }
} }
@ -1651,7 +1622,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty()) if (!missingProperties.empty())
{ {
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)});
return; return;
} }
} }
@ -1669,7 +1640,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!extraProperties.empty()) if (!extraProperties.empty())
{ {
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra});
return; return;
} }
} }
@ -1730,7 +1701,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
// txn log. // txn log.
TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy); TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy);
TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy); TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy);
if (superTable != newSuperTable || subTable != newSubTable) if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable))
{ {
if (errors.empty()) if (errors.empty())
return tryUnifyTables(subTy, superTy, isIntersection); return tryUnifyTables(subTy, superTy, isIntersection);
@ -1792,7 +1763,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
// txn log. // txn log.
TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy); TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy);
TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy); TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy);
if (superTable != newSuperTable || subTable != newSubTable) if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable))
{ {
if (errors.empty()) if (errors.empty())
return tryUnifyTables(subTy, superTy, isIntersection); return tryUnifyTables(subTy, superTy, isIntersection);
@ -1850,13 +1821,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty()) if (!missingProperties.empty())
{ {
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)});
return; return;
} }
if (!extraProperties.empty()) if (!extraProperties.empty())
{ {
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra});
return; return;
} }
@ -1892,14 +1863,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
std::swap(subTy, superTy); std::swap(subTy, superTy);
if (auto ttv = log.get<TableTypeVar>(superTy); !ttv || ttv->state != TableState::Free) if (auto ttv = log.get<TableTypeVar>(superTy); !ttv || ttv->state != TableState::Free)
return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); return reportError(location, TypeMismatch{osuperTy, osubTy});
auto fail = [&](std::optional<TypeError> e) { auto fail = [&](std::optional<TypeError> e) {
std::string reason = "The former's metatable does not satisfy the requirements."; std::string reason = "The former's metatable does not satisfy the requirements.";
if (e) if (e)
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}}); reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e});
else else
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}}); reportError(location, TypeMismatch{osuperTy, osubTy, reason});
}; };
// Given t1 where t1 = { lower: (t1) -> (a, b...) } // Given t1 where t1 = { lower: (t1) -> (a, b...) }
@ -1931,7 +1902,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
} }
} }
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); reportError(location, TypeMismatch{osuperTy, osubTy});
return; return;
} }
@ -1972,7 +1943,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); reportError(location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()});
log.concat(std::move(innerState.log)); log.concat(std::move(innerState.log));
} }
@ -2049,9 +2020,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
auto fail = [&]() { auto fail = [&]() {
if (!reversed) if (!reversed)
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
else else
reportError(TypeError{location, TypeMismatch{subTy, superTy}}); reportError(location, TypeMismatch{subTy, superTy});
}; };
const ClassTypeVar* superClass = get<ClassTypeVar>(superTy); const ClassTypeVar* superClass = get<ClassTypeVar>(superTy);
@ -2096,7 +2067,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
if (!classProp) if (!classProp)
{ {
ok = false; ok = false;
reportError(TypeError{location, UnknownProperty{superTy, propName}}); reportError(location, UnknownProperty{superTy, propName});
} }
else else
{ {
@ -2120,7 +2091,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
{ {
ok = false; ok = false;
std::string msg = "Class " + superClass->name + " does not have an indexer"; std::string msg = "Class " + superClass->name + " does not have an indexer";
reportError(TypeError{location, GenericError{msg}}); reportError(location, GenericError{msg});
} }
if (!ok) if (!ok)
@ -2132,6 +2103,34 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
return fail(); return fail();
} }
void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy)
{
const NegationTypeVar* ntv = get<NegationTypeVar>(superTy);
if (!ntv)
ice("tryUnifyTypeWithNegation superTy must be a negation type");
const NormalizedType* subNorm = normalizer->normalize(subTy);
const NormalizedType* superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, UnificationTooComplex{});
// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy});
}
void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy)
{
const NegationTypeVar* ntv = get<NegationTypeVar>(subTy);
if (!ntv)
ice("tryUnifyNegationWithType subTy must be a negation type");
// TODO: ~T </: U iff T <: U
reportError(location, TypeMismatch{superTy, subTy});
}
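The new negated-supertype case checks the normalized operands and reports a mismatch exactly when the inner check succeeds — the inversion described by the comment "T </: ~U iff T <: U". A standalone sketch of that inversion over an invented set-based model of types (not Luau's normalization machinery):

#include <algorithm>
#include <set>
#include <string>

// Toy model: a type is the set of value tags it admits; subtyping is set inclusion.
using ToyType = std::set<std::string>;

bool toySubtype(const ToyType& sub, const ToyType& super)
{
    return std::includes(super.begin(), super.end(), sub.begin(), sub.end());
}

// Mirrors the shape of tryUnifyTypeWithNegation: to check T against ~U,
// run the ordinary check of T against U and invert the outcome.
// Returns true when T is accepted against ~U.
bool toyUnifyWithNegation(const ToyType& subTy, const ToyType& negatedOperand)
{
    return !toySubtype(subTy, negatedOperand);
}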
static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)
{ {
while (true) while (true)
@ -2192,7 +2191,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
} }
else if (get<Unifiable::Generic>(tail)) else if (get<Unifiable::Generic>(tail))
{ {
reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); reportError(location, GenericError{"Cannot unify variadic and generic packs"});
} }
else if (get<Unifiable::Error>(tail)) else if (get<Unifiable::Error>(tail))
{ {
@ -2206,7 +2205,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
} }
else else
{ {
reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}}); reportError(location, GenericError{"Failed to unify variadic packs"});
} }
} }
@ -2314,186 +2313,6 @@ std::optional<TypeId> Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N
return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location);
} }
void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy)
{
const ConstrainedTypeVar* subConstrained = get<ConstrainedTypeVar>(subTy);
if (!subConstrained)
ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!");
const std::vector<TypeId>& subTyParts = subConstrained->parts;
// A | B <: T if A <: T and B <: T
bool failed = false;
std::optional<TypeError> unificationTooComplex;
const size_t count = subTyParts.size();
for (size_t i = 0; i < count; ++i)
{
TypeId type = subTyParts[i];
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(type, superTy);
if (i == count - 1)
log.concat(std::move(innerState.log));
++i;
if (auto e = hasUnificationTooComplex(innerState.errors))
unificationTooComplex = e;
if (!innerState.errors.empty())
{
failed = true;
break;
}
}
if (unificationTooComplex)
reportError(*unificationTooComplex);
else if (failed)
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
else
log.replace(subTy, BoundTypeVar{superTy});
}
void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy)
{
ConstrainedTypeVar* superC = log.getMutable<ConstrainedTypeVar>(superTy);
if (!superC)
ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!");
// subTy could be a
// table
// metatable
// class
// function
// primitive
// free
// generic
// intersection
// union
// Do we really just tack it on? I think we might!
// We can certainly do some deduplication.
// Is there any point to deducing Player|Instance when we could just reduce to Instance?
// Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type?
// Maybe we do a simplification step during quantification.
auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy);
if (it != superC->parts.end())
return;
superC->parts.push_back(subTy);
}
void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel)
{
// The duplication between this and regular typepack unification is tragic.
auto superIter = begin(superTy, &log);
auto superEndIter = end(superTy);
auto subIter = begin(subTy, &log);
auto subEndIter = end(subTy);
int count = FInt::LuauTypeInferLowerBoundsIterationLimit;
for (; subIter != subEndIter; ++subIter)
{
if (0 >= --count)
ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound");
if (superIter != superEndIter)
{
tryUnify_(*subIter, *superIter);
++superIter;
continue;
}
if (auto t = superIter.tail())
{
TypePackId tailPack = follow(*t);
if (log.get<FreeTypePack>(tailPack) && occursCheck(tailPack, subTy))
return;
FreeTypePack* freeTailPack = log.getMutable<FreeTypePack>(tailPack);
if (!freeTailPack)
return;
TypePack* tp = getMutable<TypePack>(log.replace(tailPack, TypePack{}));
for (; subIter != subEndIter; ++subIter)
{
tp->head.push_back(types->addType(ConstrainedTypeVar{demotedLevel, {follow(*subIter)}}));
}
tp->tail = subIter.tail();
}
return;
}
if (superIter != superEndIter)
{
if (auto subTail = subIter.tail())
{
TypePackId subTailPack = follow(*subTail);
if (get<FreeTypePack>(subTailPack))
{
TypePack* tp = getMutable<TypePack>(log.replace(subTailPack, TypePack{}));
for (; superIter != superEndIter; ++superIter)
tp->head.push_back(*superIter);
}
else if (const VariadicTypePack* subVariadic = log.getMutable<VariadicTypePack>(subTailPack))
{
while (superIter != superEndIter)
{
tryUnify_(subVariadic->ty, *superIter);
++superIter;
}
}
}
else
{
while (superIter != superEndIter)
{
if (!isOptional(*superIter))
{
errors.push_back(TypeError{location, CountMismatch{size(superTy), std::nullopt, size(subTy), CountMismatch::Return}});
return;
}
++superIter;
}
}
return;
}
// Both iters are at their respective tails
auto subTail = subIter.tail();
auto superTail = superIter.tail();
if (subTail && superTail)
tryUnify(*subTail, *superTail);
else if (subTail)
{
const FreeTypePack* freeSubTail = log.getMutable<FreeTypePack>(*subTail);
if (freeSubTail)
{
log.replace(*subTail, TypePack{});
}
}
else if (superTail)
{
const FreeTypePack* freeSuperTail = log.getMutable<FreeTypePack>(*superTail);
if (freeSuperTail)
{
log.replace(*superTail, TypePack{});
}
}
}
bool Unifier::occursCheck(TypeId needle, TypeId haystack) bool Unifier::occursCheck(TypeId needle, TypeId haystack)
{ {
sharedState.tempSeenTy.clear(); sharedState.tempSeenTy.clear();
@ -2503,8 +2322,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack)
bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack) bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack)
{ {
RecursionLimiter _ra(&sharedState.counters.recursionCount, RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
bool occurrence = false; bool occurrence = false;
@ -2529,7 +2347,7 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
if (needle == haystack) if (needle == haystack)
{ {
reportError(TypeError{location, OccursCheckFailed{}}); reportError(location, OccursCheckFailed{});
log.replace(needle, *singletonTypes->errorRecoveryType()); log.replace(needle, *singletonTypes->errorRecoveryType());
return true; return true;
@ -2547,11 +2365,6 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
for (TypeId ty : a->parts) for (TypeId ty : a->parts)
check(ty); check(ty);
} }
else if (auto a = log.getMutable<ConstrainedTypeVar>(haystack))
{
for (TypeId ty : a->parts)
check(ty);
}
return occurrence; return occurrence;
} }
@ -2579,14 +2392,13 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
if (!log.getMutable<Unifiable::Free>(needle)) if (!log.getMutable<Unifiable::Free>(needle))
ice("Expected needle pack to be free"); ice("Expected needle pack to be free");
RecursionLimiter _ra(&sharedState.counters.recursionCount, RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
while (!log.getMutable<ErrorTypeVar>(haystack)) while (!log.getMutable<ErrorTypeVar>(haystack))
{ {
if (needle == haystack) if (needle == haystack)
{ {
reportError(TypeError{location, OccursCheckFailed{}}); reportError(location, OccursCheckFailed{});
log.replace(needle, *singletonTypes->errorRecoveryTypePack()); log.replace(needle, *singletonTypes->errorRecoveryTypePack());
return true; return true;
@ -2607,18 +2419,31 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
Unifier Unifier::makeChildUnifier() Unifier Unifier::makeChildUnifier()
{ {
Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; Unifier u = Unifier{normalizer, mode, scope, location, variance, &log};
u.anyIsTop = anyIsTop;
u.normalize = normalize; u.normalize = normalize;
u.useScopes = useScopes;
return u; return u;
} }
// A utility function that appends the given error to the unifier's error log. // A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error. // This allows setting a breakpoint wherever the unifier reports an error.
//
// Note: `reportError` accepts its arguments by value intentionally to reduce the stack usage of functions that call it.
void Unifier::reportError(Location location, TypeErrorData data)
{
errors.emplace_back(std::move(location), std::move(data));
}
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
//
// Note: to conserve stack space in calling functions it is generally preferred to call `Unifier::reportError(Location location, TypeErrorData data)`
// instead of this method.
void Unifier::reportError(TypeError err) void Unifier::reportError(TypeError err)
{ {
errors.push_back(std::move(err)); errors.push_back(std::move(err));
} }
bool Unifier::isNonstrictMode() const bool Unifier::isNonstrictMode() const
{ {
return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck);
@ -2629,7 +2454,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId
if (auto e = hasUnificationTooComplex(innerErrors)) if (auto e = hasUnificationTooComplex(innerErrors))
reportError(*e); reportError(*e);
else if (!innerErrors.empty()) else if (!innerErrors.empty())
reportError(TypeError{location, TypeMismatch{wantedType, givenType}}); reportError(location, TypeMismatch{wantedType, givenType});
} }
void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType)
View file
@ -58,6 +58,8 @@ struct Comment
struct ParseResult struct ParseResult
{ {
AstStatBlock* root; AstStatBlock* root;
size_t lines = 0;
std::vector<HotComment> hotcomments; std::vector<HotComment> hotcomments;
std::vector<ParseError> errors; std::vector<ParseError> errors;
View file
@ -302,8 +302,8 @@ private:
AstStatError* reportStatError(const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements, AstStatError* reportStatError(const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements,
const char* format, ...) LUAU_PRINTF_ATTR(5, 6); const char* format, ...) LUAU_PRINTF_ATTR(5, 6);
AstExprError* reportExprError(const Location& location, const AstArray<AstExpr*>& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstExprError* reportExprError(const Location& location, const AstArray<AstExpr*>& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5);
AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, bool isMissing, const char* format, ...) AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, const char* format, ...)
LUAU_PRINTF_ATTR(5, 6); LUAU_PRINTF_ATTR(4, 5);
// `parseErrorLocation` is associated with the parser error // `parseErrorLocation` is associated with the parser error
// `astErrorLocation` is associated with the AstTypeError created // `astErrorLocation` is associated with the AstTypeError created
// It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely
View file
@ -641,8 +641,8 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT
return brokenDoubleBrace; return brokenDoubleBrace;
} }
Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset);
consume(); consume();
Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset - 1);
return lexemeOutput; return lexemeOutput;
} }
View file
@ -23,9 +23,9 @@ LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false)
LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAnnotationLocationChange, false)
LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false)
LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false)
bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_bin_integer = false;
bool lua_telemetry_parsed_out_of_range_hex_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false;
@ -164,15 +164,16 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n
try try
{ {
AstStatBlock* root = p.parseChunk(); AstStatBlock* root = p.parseChunk();
size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n');
return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)};
} }
catch (ParseError& err) catch (ParseError& err)
{ {
// when catching a fatal error, append it to the list of non-fatal errors and return // when catching a fatal error, append it to the list of non-fatal errors and return
p.parseErrors.push_back(err); p.parseErrors.push_back(err);
return ParseResult{nullptr, {}, p.parseErrors}; return ParseResult{nullptr, 0, {}, p.parseErrors};
} }
} }
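// For illustration, a hedged sketch of how the new ParseResult::lines value behaves (the source
// strings below are hypothetical): the count is the zero-based line of the final lexeme plus one
// extra line when the buffer does not end in a newline.
//
//   "local x = 1\n"           -> the Eof lexeme ends on line 1, the last byte is '\n', lines == 1
//   "local x = 1\nreturn x"   -> Eof ends on line 1, no trailing '\n', lines == 1 + 1 == 2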
@ -811,9 +812,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr)
{ {
return AstDeclaredClassProp{fnName.name, return AstDeclaredClassProp{
reportTypeAnnotationError(Location(start, end), {}, /*isMissing*/ false, "'self' must be present as the unannotated first parameter"), fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true};
true};
} }
// Skip the first index. // Skip the first index.
@ -824,8 +824,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
if (args[i].annotation) if (args[i].annotation)
vars.push_back(args[i].annotation); vars.push_back(args[i].annotation);
else else
vars.push_back(reportTypeAnnotationError( vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated"));
Location(start, end), {}, /*isMissing*/ false, "All declaration parameters aside from 'self' must be annotated"));
} }
if (vararg && !varargAnnotation) if (vararg && !varargAnnotation)
@ -1537,7 +1536,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
if (isUnion && isIntersection) if (isUnion && isIntersection)
{ {
return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), /*isMissing*/ false, return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts),
"Mixing union and intersection types is not allowed; consider wrapping in parentheses."); "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
} }
@ -1623,18 +1622,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
return {allocator.alloc<AstTypeSingletonString>(start, svalue)}; return {allocator.alloc<AstTypeSingletonString>(start, svalue)};
} }
else else
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")};
} }
else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple)
{ {
parseInterpString(); parseInterpString();
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Interpolated string literals cannot be used as types")}; return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")};
} }
else if (lexer.current().type == Lexeme::BrokenString) else if (lexer.current().type == Lexeme::BrokenString)
{ {
nextLexeme(); nextLexeme();
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Malformed string")}; return {reportTypeAnnotationError(start, {}, "Malformed string")};
} }
else if (lexer.current().type == Lexeme::Name) else if (lexer.current().type == Lexeme::Name)
{ {
@ -1693,33 +1692,20 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
{ {
nextLexeme(); nextLexeme();
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, return {reportTypeAnnotationError(start, {},
"Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> "
"...any'"), "...any'"),
{}}; {}};
} }
else else
{ {
if (FFlag::LuauTypeAnnotationLocationChange) // For a missing type annotation, capture 'space' between last token and the next one
{ Location astErrorlocation(lexer.previousLocation().end, start.begin);
// For a missing type annotation, capture 'space' between last token and the next one // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message).
Location astErrorlocation(lexer.previousLocation().end, start.begin); // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau.
// The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). Location parseErrorLocation(lexer.previousLocation().end, start.end);
// Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. return {
Location parseErrorLocation(lexer.previousLocation().end, start.end); reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}};
return {
reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()),
{}};
}
else
{
Location location = lexer.current().location;
// For a missing type annotation, capture 'space' between last token and the next one
location = Location(lexer.previousLocation().end, lexer.current().location.begin);
return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}};
}
} }
} }
@ -2325,9 +2311,13 @@ AstExpr* Parser::parseTableConstructor()
MatchLexeme matchBrace = lexer.current(); MatchLexeme matchBrace = lexer.current();
expectAndConsume('{', "table literal"); expectAndConsume('{', "table literal");
unsigned lastElementIndent = 0;
while (lexer.current().type != '}') while (lexer.current().type != '}')
{ {
if (FFlag::LuauTableConstructorRecovery)
lastElementIndent = lexer.current().location.begin.column;
if (lexer.current().type == '[') if (lexer.current().type == '[')
{ {
MatchLexeme matchLocationBracket = lexer.current(); MatchLexeme matchLocationBracket = lexer.current();
@ -2372,10 +2362,14 @@ AstExpr* Parser::parseTableConstructor()
{ {
nextLexeme(); nextLexeme();
} }
else else if (FFlag::LuauTableConstructorRecovery && (lexer.current().type == '[' || lexer.current().type == Lexeme::Name) &&
lexer.current().location.begin.column == lastElementIndent)
{ {
if (lexer.current().type != '}') report(lexer.current().location, "Expected ',' after table constructor element");
break; }
else if (lexer.current().type != '}')
{
break;
} }
} }
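// For illustration, a hedged sketch of the recovery case the LuauTableConstructorRecovery path
// above targets (the Luau snippet is hypothetical): when the next field starts at the same
// indentation as the previous one, the missing comma is reported and parsing continues.
//
//   local t = {
//       first = 1
//       second = 2   -- reports "Expected ',' after table constructor element" and keeps parsing
//   }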
@ -3033,27 +3027,18 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray<A
return allocator.alloc<AstExprError>(location, expressions, unsigned(parseErrors.size() - 1)); return allocator.alloc<AstExprError>(location, expressions, unsigned(parseErrors.size() - 1));
} }
AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, bool isMissing, const char* format, ...) AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, const char* format, ...)
{ {
if (FFlag::LuauTypeAnnotationLocationChange)
{
// Missing type annotations should be using `reportMissingTypeAnnotationError` when LuauTypeAnnotationLocationChange is enabled
// Note: `isMissing` can be removed once FFlag::LuauTypeAnnotationLocationChange is removed since it will always be true.
LUAU_ASSERT(!isMissing);
}
va_list args; va_list args;
va_start(args, format); va_start(args, format);
report(location, format, args); report(location, format, args);
va_end(args); va_end(args);
return allocator.alloc<AstTypeError>(location, types, isMissing, unsigned(parseErrors.size() - 1)); return allocator.alloc<AstTypeError>(location, types, false, unsigned(parseErrors.size() - 1));
} }
AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...)
{ {
LUAU_ASSERT(FFlag::LuauTypeAnnotationLocationChange);
va_list args; va_list args;
va_start(args, format); va_start(args, format);
report(parseErrorLocation, format, args); report(parseErrorLocation, format, args);
View file
@ -14,7 +14,6 @@
#endif #endif
LUAU_FASTFLAG(DebugLuauTimeTracing) LUAU_FASTFLAG(DebugLuauTimeTracing)
LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution)
enum class ReportFormat enum class ReportFormat
{ {
@ -55,11 +54,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con
if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data)) if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data))
report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str());
else if (FFlag::LuauTypeMismatchModuleNameResolution) else
report(format, humanReadableName.c_str(), error.location, "TypeError", report(format, humanReadableName.c_str(), error.location, "TypeError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str());
else
report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str());
} }
static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning)
View file
@ -49,6 +49,8 @@ enum class CompileFormat
Binary, Binary,
Remarks, Remarks,
Codegen, Codegen,
CodegenVerbose,
CodegenNull,
Null Null
}; };
@ -673,21 +675,33 @@ static void reportError(const char* name, const Luau::CompileError& error)
report(name, error.getLocation(), "CompileError", error.what()); report(name, error.getLocation(), "CompileError", error.what());
} }
static std::string getCodegenAssembly(const char* name, const std::string& bytecode) static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options)
{ {
std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close); std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get(); lua_State* L = globalState.get();
setupState(L);
if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0)
return Luau::CodeGen::getAssemblyText(L, -1); return Luau::CodeGen::getAssembly(L, -1, options);
fprintf(stderr, "Error loading bytecode %s\n", name); fprintf(stderr, "Error loading bytecode %s\n", name);
return ""; return "";
} }
static bool compileFile(const char* name, CompileFormat format) static void annotateInstruction(void* context, std::string& text, int fid, int instpos)
{
Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context;
bcb.annotateInstruction(text, fid, instpos);
}
struct CompileStats
{
size_t lines;
size_t bytecode;
size_t codegen;
};
static bool compileFile(const char* name, CompileFormat format, CompileStats& stats)
{ {
std::optional<std::string> source = readFile(name); std::optional<std::string> source = readFile(name);
if (!source) if (!source)
@ -696,9 +710,13 @@ static bool compileFile(const char* name, CompileFormat format)
return false; return false;
} }
// NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example)
// This function is much more complicated because it supports many output human-readable formats through internal interfaces
try try
{ {
Luau::BytecodeBuilder bcb; Luau::BytecodeBuilder bcb;
Luau::CodeGen::AssemblyOptions options = {format == CompileFormat::CodegenNull, format == CompileFormat::Codegen, annotateInstruction, &bcb};
if (format == CompileFormat::Text) if (format == CompileFormat::Text)
{ {
@ -711,8 +729,24 @@ static bool compileFile(const char* name, CompileFormat format)
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source); bcb.setDumpSource(*source);
} }
else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenVerbose)
{
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source);
}
Luau::compileOrThrow(bcb, *source, copts()); Luau::Allocator allocator;
Luau::AstNameTable names(allocator);
Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator);
if (!result.errors.empty())
throw Luau::ParseErrors(result.errors);
stats.lines += result.lines;
Luau::compileOrThrow(bcb, result, names, copts());
stats.bytecode += bcb.getBytecode().size();
switch (format) switch (format)
{ {
@ -726,7 +760,11 @@ static bool compileFile(const char* name, CompileFormat format)
fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout);
break; break;
case CompileFormat::Codegen: case CompileFormat::Codegen:
printf("%s", getCodegenAssembly(name, bcb.getBytecode()).c_str()); case CompileFormat::CodegenVerbose:
printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str());
break;
case CompileFormat::CodegenNull:
stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size();
break; break;
case CompileFormat::Null: case CompileFormat::Null:
break; break;
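// For illustration, a hedged sketch of the simpler public entry points mentioned in the note
// above, assuming "luacode.h", "lua.h", <cstring> and <cstdlib> are included; this helper is
// hypothetical, error handling is abbreviated, and the chunk name "=example" is arbitrary.
static int compileAndRunSource(lua_State* L, const char* source)
{
    size_t bytecodeSize = 0;
    char* bytecode = luau_compile(source, strlen(source), /*options*/ nullptr, &bytecodeSize);
    int status = luau_load(L, "=example", bytecode, bytecodeSize, 0);
    free(bytecode); // luau_compile allocates the buffer with malloc; the caller releases it
    return status == 0 ? lua_pcall(L, 0, LUA_MULTRET, 0) : status;
}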
@ -755,7 +793,7 @@ static void displayHelp(const char* argv0)
printf("\n"); printf("\n");
printf("Available modes:\n"); printf("Available modes:\n");
printf(" omitted: compile and run input files one by one\n"); printf(" omitted: compile and run input files one by one\n");
printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary, text, remarks, codegen or null)\n"); printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n");
printf("\n"); printf("\n");
printf("Available options:\n"); printf("Available options:\n");
printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n");
@ -813,6 +851,14 @@ int replMain(int argc, char** argv)
{ {
compileFormat = CompileFormat::Codegen; compileFormat = CompileFormat::Codegen;
} }
else if (strcmp(argv[1], "--compile=codegenverbose") == 0)
{
compileFormat = CompileFormat::CodegenVerbose;
}
else if (strcmp(argv[1], "--compile=codegennull") == 0)
{
compileFormat = CompileFormat::CodegenNull;
}
else if (strcmp(argv[1], "--compile=null") == 0) else if (strcmp(argv[1], "--compile=null") == 0)
{ {
compileFormat = CompileFormat::Null; compileFormat = CompileFormat::Null;
@ -924,10 +970,17 @@ int replMain(int argc, char** argv)
_setmode(_fileno(stdout), _O_BINARY); _setmode(_fileno(stdout), _O_BINARY);
#endif #endif
CompileStats stats = {};
int failed = 0; int failed = 0;
for (const std::string& path : files) for (const std::string& path : files)
failed += !compileFile(path.c_str(), compileFormat); failed += !compileFile(path.c_str(), compileFormat, stats);
if (compileFormat == CompileFormat::Null)
printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024));
else if (compileFormat == CompileFormat::CodegenNull)
printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024),
int(stats.codegen / 1024));
return failed ? 1 : 0; return failed ? 1 : 0;
} }
View file
@ -143,6 +143,11 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924)
set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-)
endif() endif()
if (NOT MSVC)
# disable support for math_errno which allows compilers to lower sqrt() into a single CPU instruction
target_compile_options(Luau.VM PRIVATE -fno-math-errno)
endif()
if(MSVC AND LUAU_BUILD_CLI) if(MSVC AND LUAU_BUILD_CLI)
# the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger # the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger
set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152)
View file
@ -0,0 +1,52 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/RegisterA64.h"
namespace Luau
{
namespace CodeGen
{
enum class AddressKindA64 : uint8_t
{
imm, // reg + imm
reg, // reg + reg
// TODO:
// reg + reg << shift
// reg + sext(reg) << shift
// reg + uext(reg) << shift
// pc + offset
};
struct AddressA64
{
AddressA64(RegisterA64 base, int off = 0)
: kind(AddressKindA64::imm)
, base(base)
, offset(xzr)
, data(off)
{
LUAU_ASSERT(base.kind == KindA64::x);
LUAU_ASSERT(off >= 0 && off < 4096);
}
AddressA64(RegisterA64 base, RegisterA64 offset)
: kind(AddressKindA64::reg)
, base(base)
, offset(offset)
, data(0)
{
LUAU_ASSERT(base.kind == KindA64::x);
LUAU_ASSERT(offset.kind == KindA64::x);
}
AddressKindA64 kind;
RegisterA64 base;
RegisterA64 offset;
int data;
};
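// For illustration, a hedged sketch of the two addressing forms this struct models; the
// function name and offsets are hypothetical, the registers come from RegisterA64.h.
inline void addressExamples()
{
    AddressA64 fieldAtConstantOffset(x0, 16); // encodes [x0, #16]; the offset must be in [0, 4096)
    AddressA64 dynamicallyIndexed(x0, x1);    // encodes [x0, x1]
    (void)fieldAtConstantOffset;
    (void)dynamicallyIndexed;
}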
} // namespace CodeGen
} // namespace Luau
View file
@ -0,0 +1,144 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/RegisterA64.h"
#include "Luau/AddressA64.h"
#include "Luau/ConditionA64.h"
#include "Luau/Label.h"
#include <string>
#include <vector>
namespace Luau
{
namespace CodeGen
{
class AssemblyBuilderA64
{
public:
explicit AssemblyBuilderA64(bool logText);
~AssemblyBuilderA64();
// Moves
void mov(RegisterA64 dst, RegisterA64 src);
void mov(RegisterA64 dst, uint16_t src, int shift = 0);
void movk(RegisterA64 dst, uint16_t src, int shift = 0);
// Arithmetics
void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
void add(RegisterA64 dst, RegisterA64 src1, int src2);
void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
void sub(RegisterA64 dst, RegisterA64 src1, int src2);
void neg(RegisterA64 dst, RegisterA64 src);
// Comparisons
// Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm
// TODO: add cmp
// Binary
// Note: shifted-register support and bitfield operations are omitted for simplicity
// TODO: support immediate arguments (they have odd encoding and forbid many values)
// TODO: support not variants for and/or/eor (required to support not...)
void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void clz(RegisterA64 dst, RegisterA64 src);
void rbit(RegisterA64 dst, RegisterA64 src);
// Load
// Note: paired loads are currently omitted for simplicity
void ldr(RegisterA64 dst, AddressA64 src);
void ldrb(RegisterA64 dst, AddressA64 src);
void ldrh(RegisterA64 dst, AddressA64 src);
void ldrsb(RegisterA64 dst, AddressA64 src);
void ldrsh(RegisterA64 dst, AddressA64 src);
void ldrsw(RegisterA64 dst, AddressA64 src);
// Store
void str(RegisterA64 src, AddressA64 dst);
void strb(RegisterA64 src, AddressA64 dst);
void strh(RegisterA64 src, AddressA64 dst);
// Control flow
// Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks
void b(ConditionA64 cond, Label& label);
void cbz(RegisterA64 src, Label& label);
void cbnz(RegisterA64 src, Label& label);
void ret();
// Run final checks
bool finalize();
// Places a label at current location and returns it
Label setLabel();
// Assigns label position to the current location
void setLabel(Label& label);
void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3);
uint32_t getCodeSize() const;
// Resulting data and code that need to be copied over one after the other
// The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code'
std::vector<uint8_t> data;
std::vector<uint32_t> code;
std::string text;
const bool logText = false;
private:
// Instruction archetypes
void place0(const char* name, uint32_t word);
void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0);
void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op);
void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2);
void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op);
void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op);
void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0);
void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size);
void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond);
void placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond);
void place(uint32_t word);
void placeLabel(Label& label);
void commit();
LUAU_NOINLINE void extend();
// Data
size_t allocateData(size_t size, size_t align);
// Logging of assembly in text form
LUAU_NOINLINE void log(const char* opcode);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label);
LUAU_NOINLINE void log(const char* opcode, Label label);
LUAU_NOINLINE void log(Label label);
LUAU_NOINLINE void log(RegisterA64 reg);
LUAU_NOINLINE void log(AddressA64 addr);
uint32_t nextLabel = 1;
std::vector<Label> pendingLabels;
std::vector<uint32_t> labelLocations;
bool finalized = false;
size_t dataPos = 0;
uint32_t* codePos = nullptr;
uint32_t* codeEnd = nullptr;
};
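// For illustration, a hedged sketch of driving the builder; the helper name and register
// choices are hypothetical, and finalize() must run before the words in 'code' are consumed.
inline bool emitAddAndReturn(AssemblyBuilderA64& build)
{
    Label skip;
    build.add(x0, x0, x1); // x0 = x0 + x1
    build.cbz(x0, skip);   // forward branch, patched when the label is placed
    build.neg(x0, x0);
    build.setLabel(skip);
    build.ret();
    return build.finalize(); // resolves the pending cbz offset
}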
} // namespace CodeGen
} // namespace Luau
View file
@ -2,8 +2,8 @@
#pragma once #pragma once
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Condition.h"
#include "Luau/Label.h" #include "Luau/Label.h"
#include "Luau/ConditionX64.h"
#include "Luau/OperandX64.h" #include "Luau/OperandX64.h"
#include "Luau/RegisterX64.h" #include "Luau/RegisterX64.h"
@ -23,6 +23,19 @@ enum class RoundingModeX64
RoundToZero = 0b11, RoundToZero = 0b11,
}; };
enum class AlignmentDataX64
{
Nop,
Int3,
Ud2, // int3 will be used as a fall-back if it doesn't fit
};
enum class ABIX64
{
Windows,
SystemV,
};
class AssemblyBuilderX64 class AssemblyBuilderX64
{ {
public: public:
@ -71,7 +84,7 @@ public:
void ret(); void ret();
// Control flow // Control flow
void jcc(Condition cond, Label& label); void jcc(ConditionX64 cond, Label& label);
void jmp(Label& label); void jmp(Label& label);
void jmp(OperandX64 op); void jmp(OperandX64 op);
@ -80,6 +93,10 @@ public:
void int3(); void int3();
// Code alignment
void nop(uint32_t length = 1);
void align(uint32_t alignment, AlignmentDataX64 data = AlignmentDataX64::Nop);
// AVX // AVX
void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2);
@ -131,6 +148,8 @@ public:
void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3);
uint32_t getCodeSize() const;
// Resulting data and code that need to be copied over one after the other // Resulting data and code that need to be copied over one after the other
// The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code'
std::vector<uint8_t> data; std::vector<uint8_t> data;
@ -140,6 +159,8 @@ public:
const bool logText = false; const bool logText = false;
const ABIX64 abi;
private: private:
// Instruction archetypes // Instruction archetypes
void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev,
@ -177,7 +198,6 @@ private:
void commit(); void commit();
LUAU_NOINLINE void extend(); LUAU_NOINLINE void extend();
uint32_t getCodeSize();
// Data // Data
size_t allocateData(size_t size, size_t align); size_t allocateData(size_t size, size_t align);
@ -192,8 +212,8 @@ private:
LUAU_NOINLINE void log(const char* opcode, Label label); LUAU_NOINLINE void log(const char* opcode, Label label);
void log(OperandX64 op); void log(OperandX64 op);
const char* getSizeName(SizeX64 size); const char* getSizeName(SizeX64 size) const;
const char* getRegisterName(RegisterX64 reg); const char* getRegisterName(RegisterX64 reg) const;
uint32_t nextLabel = 1; uint32_t nextLabel = 1;
std::vector<Label> pendingLabels; std::vector<Label> pendingLabels;
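// For illustration (hedged): the new alignment helpers pad the instruction stream, e.g.
// align(32) emits nop bytes until getCodeSize() is a multiple of 32, while
// align(32, AlignmentDataX64::Int3) fills the gap with int3 so stray execution traps.
// The new 'abi' field lets callers pick calling-convention-specific registers, e.g. the
// first integer argument lives in rcx on Windows and in rdi on SystemV targets.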
View file
@ -11,6 +11,8 @@ namespace Luau
namespace CodeGen namespace CodeGen
{ {
constexpr uint32_t kCodeAlignment = 32;
struct CodeAllocator struct CodeAllocator
{ {
CodeAllocator(size_t blockSize, size_t maxTotalSize); CodeAllocator(size_t blockSize, size_t maxTotalSize);
View file
@ -17,8 +17,20 @@ void create(lua_State* L);
// Builds target function and all inner functions // Builds target function and all inner functions
void compile(lua_State* L, int idx); void compile(lua_State* L, int idx);
// Generates assembly text for target function and all inner functions using annotatorFn = void (*)(void* context, std::string& result, int fid, int instpos);
std::string getAssemblyText(lua_State* L, int idx);
struct AssemblyOptions
{
bool outputBinary = false;
bool skipOutlinedCode = false;
// An optional annotator function can be provided to describe each instruction; it receives the function id and the sequential instruction position
annotatorFn annotator = nullptr;
void* annotatorContext = nullptr;
};
// Generates assembly for target function and all inner functions
std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {});
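// For illustration, a hedged sketch of requesting assembly for the function on top of the
// stack; the helper name is hypothetical, and the annotator fields are left at their defaults.
inline std::string disassembleTopFunction(lua_State* L)
{
    AssemblyOptions options;
    options.skipOutlinedCode = true; // only the main path, no outlined fallback blocks
    return getAssembly(L, -1, options);
}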
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau
View file
@ -0,0 +1,37 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
namespace Luau
{
namespace CodeGen
{
enum class ConditionA64
{
Equal,
NotEqual,
CarrySet,
CarryClear,
Minus,
Plus,
Overflow,
NoOverflow,
UnsignedGreater,
UnsignedLessEqual,
GreaterEqual,
Less,
Greater,
LessEqual,
Always,
Count
};
} // namespace CodeGen
} // namespace Luau
View file
@ -6,7 +6,7 @@ namespace Luau
namespace CodeGen namespace CodeGen
{ {
enum class Condition enum class ConditionX64 : uint8_t
{ {
Overflow, Overflow,
NoOverflow, NoOverflow,
View file
@ -61,7 +61,7 @@ struct OperandX64
constexpr OperandX64 operator[](OperandX64&& addr) const constexpr OperandX64 operator[](OperandX64&& addr) const
{ {
LUAU_ASSERT(cat == CategoryX64::mem); LUAU_ASSERT(cat == CategoryX64::mem);
LUAU_ASSERT(memSize != SizeX64::none && index == noreg && scale == 1 && base == noreg && imm == 0); LUAU_ASSERT(index == noreg && scale == 1 && base == noreg && imm == 0);
LUAU_ASSERT(addr.memSize == SizeX64::none); LUAU_ASSERT(addr.memSize == SizeX64::none);
addr.cat = CategoryX64::mem; addr.cat = CategoryX64::mem;
@ -70,13 +70,13 @@ struct OperandX64
} }
}; };
constexpr OperandX64 addr{SizeX64::none, noreg, 1, noreg, 0};
constexpr OperandX64 byte{SizeX64::byte, noreg, 1, noreg, 0}; constexpr OperandX64 byte{SizeX64::byte, noreg, 1, noreg, 0};
constexpr OperandX64 word{SizeX64::word, noreg, 1, noreg, 0}; constexpr OperandX64 word{SizeX64::word, noreg, 1, noreg, 0};
constexpr OperandX64 dword{SizeX64::dword, noreg, 1, noreg, 0}; constexpr OperandX64 dword{SizeX64::dword, noreg, 1, noreg, 0};
constexpr OperandX64 qword{SizeX64::qword, noreg, 1, noreg, 0}; constexpr OperandX64 qword{SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 xmmword{SizeX64::xmmword, noreg, 1, noreg, 0}; constexpr OperandX64 xmmword{SizeX64::xmmword, noreg, 1, noreg, 0};
constexpr OperandX64 ymmword{SizeX64::ymmword, noreg, 1, noreg, 0}; constexpr OperandX64 ymmword{SizeX64::ymmword, noreg, 1, noreg, 0};
constexpr OperandX64 ptr{sizeof(void*) == 4 ? SizeX64::dword : SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 operator*(RegisterX64 reg, uint8_t scale) constexpr OperandX64 operator*(RegisterX64 reg, uint8_t scale)
{ {
@ -94,6 +94,11 @@ constexpr OperandX64 operator+(RegisterX64 reg, int32_t disp)
return OperandX64(SizeX64::none, noreg, 1, reg, disp); return OperandX64(SizeX64::none, noreg, 1, reg, disp);
} }
constexpr OperandX64 operator-(RegisterX64 reg, int32_t disp)
{
return OperandX64(SizeX64::none, noreg, 1, reg, -disp);
}
constexpr OperandX64 operator+(RegisterX64 base, RegisterX64 index) constexpr OperandX64 operator+(RegisterX64 base, RegisterX64 index)
{ {
LUAU_ASSERT(index.index != 4 && "sp cannot be used as index"); LUAU_ASSERT(index.index != 4 && "sp cannot be used as index");
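// For illustration (hedged), the operand expressions these helpers and the new operator-
// compose into; the register choices are arbitrary:
//
//   qword[rax + 16]   // base + positive displacement
//   dword[rbp - 8]    // base + negative displacement via the new operator-
//   byte[rdx + rcx]   // base + index
//   ptr[rsi]          // pointer-sized access (dword on 32-bit, qword on 64-bit)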
View file
@ -0,0 +1,105 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <stdint.h>
namespace Luau
{
namespace CodeGen
{
enum class KindA64 : uint8_t
{
none,
w, // 32-bit GPR
x, // 64-bit GPR
};
struct RegisterA64
{
KindA64 kind : 3;
uint8_t index : 5;
constexpr bool operator==(RegisterA64 rhs) const
{
return kind == rhs.kind && index == rhs.index;
}
constexpr bool operator!=(RegisterA64 rhs) const
{
return !(*this == rhs);
}
};
constexpr RegisterA64 w0{KindA64::w, 0};
constexpr RegisterA64 w1{KindA64::w, 1};
constexpr RegisterA64 w2{KindA64::w, 2};
constexpr RegisterA64 w3{KindA64::w, 3};
constexpr RegisterA64 w4{KindA64::w, 4};
constexpr RegisterA64 w5{KindA64::w, 5};
constexpr RegisterA64 w6{KindA64::w, 6};
constexpr RegisterA64 w7{KindA64::w, 7};
constexpr RegisterA64 w8{KindA64::w, 8};
constexpr RegisterA64 w9{KindA64::w, 9};
constexpr RegisterA64 w10{KindA64::w, 10};
constexpr RegisterA64 w11{KindA64::w, 11};
constexpr RegisterA64 w12{KindA64::w, 12};
constexpr RegisterA64 w13{KindA64::w, 13};
constexpr RegisterA64 w14{KindA64::w, 14};
constexpr RegisterA64 w15{KindA64::w, 15};
constexpr RegisterA64 w16{KindA64::w, 16};
constexpr RegisterA64 w17{KindA64::w, 17};
constexpr RegisterA64 w18{KindA64::w, 18};
constexpr RegisterA64 w19{KindA64::w, 19};
constexpr RegisterA64 w20{KindA64::w, 20};
constexpr RegisterA64 w21{KindA64::w, 21};
constexpr RegisterA64 w22{KindA64::w, 22};
constexpr RegisterA64 w23{KindA64::w, 23};
constexpr RegisterA64 w24{KindA64::w, 24};
constexpr RegisterA64 w25{KindA64::w, 25};
constexpr RegisterA64 w26{KindA64::w, 26};
constexpr RegisterA64 w27{KindA64::w, 27};
constexpr RegisterA64 w28{KindA64::w, 28};
constexpr RegisterA64 w29{KindA64::w, 29};
constexpr RegisterA64 w30{KindA64::w, 30};
constexpr RegisterA64 wzr{KindA64::w, 31};
constexpr RegisterA64 x0{KindA64::x, 0};
constexpr RegisterA64 x1{KindA64::x, 1};
constexpr RegisterA64 x2{KindA64::x, 2};
constexpr RegisterA64 x3{KindA64::x, 3};
constexpr RegisterA64 x4{KindA64::x, 4};
constexpr RegisterA64 x5{KindA64::x, 5};
constexpr RegisterA64 x6{KindA64::x, 6};
constexpr RegisterA64 x7{KindA64::x, 7};
constexpr RegisterA64 x8{KindA64::x, 8};
constexpr RegisterA64 x9{KindA64::x, 9};
constexpr RegisterA64 x10{KindA64::x, 10};
constexpr RegisterA64 x11{KindA64::x, 11};
constexpr RegisterA64 x12{KindA64::x, 12};
constexpr RegisterA64 x13{KindA64::x, 13};
constexpr RegisterA64 x14{KindA64::x, 14};
constexpr RegisterA64 x15{KindA64::x, 15};
constexpr RegisterA64 x16{KindA64::x, 16};
constexpr RegisterA64 x17{KindA64::x, 17};
constexpr RegisterA64 x18{KindA64::x, 18};
constexpr RegisterA64 x19{KindA64::x, 19};
constexpr RegisterA64 x20{KindA64::x, 20};
constexpr RegisterA64 x21{KindA64::x, 21};
constexpr RegisterA64 x22{KindA64::x, 22};
constexpr RegisterA64 x23{KindA64::x, 23};
constexpr RegisterA64 x24{KindA64::x, 24};
constexpr RegisterA64 x25{KindA64::x, 25};
constexpr RegisterA64 x26{KindA64::x, 26};
constexpr RegisterA64 x27{KindA64::x, 27};
constexpr RegisterA64 x28{KindA64::x, 28};
constexpr RegisterA64 x29{KindA64::x, 29};
constexpr RegisterA64 x30{KindA64::x, 30};
constexpr RegisterA64 xzr{KindA64::x, 31};
constexpr RegisterA64 sp{KindA64::none, 31};
} // namespace CodeGen
} // namespace Luau
View file
@ -113,5 +113,25 @@ constexpr RegisterX64 ymm13{SizeX64::ymmword, 13};
constexpr RegisterX64 ymm14{SizeX64::ymmword, 14}; constexpr RegisterX64 ymm14{SizeX64::ymmword, 14};
constexpr RegisterX64 ymm15{SizeX64::ymmword, 15}; constexpr RegisterX64 ymm15{SizeX64::ymmword, 15};
constexpr RegisterX64 byteReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::byte, reg.index};
}
constexpr RegisterX64 wordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::word, reg.index};
}
constexpr RegisterX64 dwordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::dword, reg.index};
}
constexpr RegisterX64 qwordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::qword, reg.index};
}
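// For illustration, a hedged example of re-sizing a register id with the new helpers; the
// names below are hypothetical, rcx and rdx come from the constants defined earlier.
constexpr RegisterX64 loopCounter32 = dwordReg(rcx); // same register index, 32-bit operand size
constexpr RegisterX64 statusByte = byteReg(rdx);     // same register index, 8-bit operand size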
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau
View file
@ -0,0 +1,607 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AssemblyBuilderA64.h"
#include "ByteUtils.h"
#include <stdarg.h>
namespace Luau
{
namespace CodeGen
{
static const uint8_t codeForCondition[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
static const char* textForCondition[] = {
"b.eq", "b.ne", "b.cs", "b.cc", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "b.al"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
const unsigned kMaxAlign = 32;
AssemblyBuilderA64::AssemblyBuilderA64(bool logText)
: logText(logText)
{
data.resize(4096);
dataPos = data.size(); // data is filled backwards
code.resize(1024);
codePos = code.data();
codeEnd = code.data() + code.size();
}
AssemblyBuilderA64::~AssemblyBuilderA64()
{
LUAU_ASSERT(finalized);
}
void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src)
{
placeSR2("mov", dst, src, 0b01'01010);
}
void AssemblyBuilderA64::mov(RegisterA64 dst, uint16_t src, int shift)
{
placeI16("mov", dst, src, 0b10'100101, shift);
}
void AssemblyBuilderA64::movk(RegisterA64 dst, uint16_t src, int shift)
{
placeI16("movk", dst, src, 0b11'100101, shift);
}
void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
placeSR3("add", dst, src1, src2, 0b00'01011, shift);
}
void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, int src2)
{
placeI12("add", dst, src1, src2, 0b00'10001);
}
void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
placeSR3("sub", dst, src1, src2, 0b10'01011, shift);
}
void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, int src2)
{
placeI12("sub", dst, src1, src2, 0b10'10001);
}
void AssemblyBuilderA64::neg(RegisterA64 dst, RegisterA64 src)
{
placeSR2("neg", dst, src, 0b10'01011);
}
void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("and", dst, src1, src2, 0b00'01010);
}
void AssemblyBuilderA64::orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("orr", dst, src1, src2, 0b01'01010);
}
void AssemblyBuilderA64::eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("eor", dst, src1, src2, 0b10'01010);
}
void AssemblyBuilderA64::lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("lsl", dst, src1, src2, 0b11010110, 0b0010'00);
}
void AssemblyBuilderA64::lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("lsr", dst, src1, src2, 0b11010110, 0b0010'01);
}
void AssemblyBuilderA64::asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("asr", dst, src1, src2, 0b11010110, 0b0010'10);
}
void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("ror", dst, src1, src2, 0b11010110, 0b0010'11);
}
void AssemblyBuilderA64::clz(RegisterA64 dst, RegisterA64 src)
{
placeR1("clz", dst, src, 0b10'11010110'00000'00010'0);
}
void AssemblyBuilderA64::rbit(RegisterA64 dst, RegisterA64 src)
{
placeR1("rbit", dst, src, 0b10'11010110'00000'0000'00);
}
void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldr", dst, src, 0b11100001, 0b10 | uint8_t(dst.kind == KindA64::x));
}
void AssemblyBuilderA64::ldrb(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::w);
placeA("ldrb", dst, src, 0b11100001, 0b00);
}
void AssemblyBuilderA64::ldrh(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::w);
placeA("ldrh", dst, src, 0b11100001, 0b01);
}
void AssemblyBuilderA64::ldrsb(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldrsb", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b00);
}
void AssemblyBuilderA64::ldrsh(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldrsh", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b01);
}
void AssemblyBuilderA64::ldrsw(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x);
placeA("ldrsw", dst, src, 0b11100010, 0b10);
}
void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w);
placeA("str", src, dst, 0b11100000, 0b10 | uint8_t(src.kind == KindA64::x));
}
void AssemblyBuilderA64::strb(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::w);
placeA("strb", src, dst, 0b11100000, 0b00);
}
void AssemblyBuilderA64::strh(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::w);
placeA("strh", src, dst, 0b11100000, 0b01);
}
void AssemblyBuilderA64::b(ConditionA64 cond, Label& label)
{
placeBC(textForCondition[int(cond)], label, 0b0101010'0, codeForCondition[int(cond)]);
}
void AssemblyBuilderA64::cbz(RegisterA64 src, Label& label)
{
placeBR("cbz", label, 0b011010'0, src);
}
void AssemblyBuilderA64::cbnz(RegisterA64 src, Label& label)
{
placeBR("cbnz", label, 0b011010'1, src);
}
void AssemblyBuilderA64::ret()
{
place0("ret", 0b1101011'0'0'10'11111'0000'0'0'11110'00000);
}
bool AssemblyBuilderA64::finalize()
{
bool success = true;
code.resize(codePos - code.data());
// Resolve jump targets
for (Label fixup : pendingLabels)
{
// If this assertion fires, a label was used in jmp without calling setLabel
LUAU_ASSERT(labelLocations[fixup.id - 1] != ~0u);
int value = int(labelLocations[fixup.id - 1]) - int(fixup.location);
// imm19 encoding word offset, at bit offset 5
// note that 18 bits of word offsets = 20 bits of byte offsets = +-1MB
if (value > -(1 << 18) && value < (1 << 18))
code[fixup.location] |= (value & ((1 << 19) - 1)) << 5;
else
success = false; // overflow
}
size_t dataSize = data.size() - dataPos;
// Shrink data
if (dataSize > 0)
memmove(&data[0], &data[dataPos], dataSize);
data.resize(dataSize);
finalized = true;
return success;
}
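// For illustration (hedged): a conditional branch recorded at word offset 12 targeting a label
// later placed at word offset 4 gets value 4 - 12 = -8; the low 19 bits of -8 are written at bit
// offset 5 of the branch word, and a target outside roughly +-256K words (about +-1MB of code)
// makes finalize() return false instead of encoding a truncated offset.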
Label AssemblyBuilderA64::setLabel()
{
Label label{nextLabel++, getCodeSize()};
labelLocations.push_back(~0u);
if (logText)
log(label);
return label;
}
void AssemblyBuilderA64::setLabel(Label& label)
{
if (label.id == 0)
{
label.id = nextLabel++;
labelLocations.push_back(~0u);
}
label.location = getCodeSize();
labelLocations[label.id - 1] = label.location;
if (logText)
log(label);
}
void AssemblyBuilderA64::logAppend(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
text.append(buf);
}
uint32_t AssemblyBuilderA64::getCodeSize() const
{
return uint32_t(codePos - code.data());
}
void AssemblyBuilderA64::place0(const char* name, uint32_t op)
{
if (logText)
log(name);
place(op);
commit();
}
void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift)
{
if (logText)
log(name, dst, src1, src2, shift);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);
LUAU_ASSERT(shift >= 0 && shift < 64); // right shift requires changing some encoding bits
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (shift << 10) | (src2.index << 16) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op)
{
if (logText)
log(name, dst, src);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (0x1f << 5) | (src.index << 16) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2)
{
if (logText)
log(name, dst, src1, src2);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (op2 << 10) | (src2.index << 16) | (op << 21) | sf);
commit();
}
void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op)
{
if (logText)
log(name, dst, src);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src.index << 5) | (op << 10) | sf);
commit();
}
void AssemblyBuilderA64::placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op)
{
if (logText)
log(name, dst, src1, src2);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind);
LUAU_ASSERT(src2 >= 0 && src2 < (1 << 12));
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (src2 << 10) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift)
{
if (logText)
log(name, dst, src, shift);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(src >= 0 && src <= 0xffff);
LUAU_ASSERT(shift == 0 || shift == 16 || shift == 32 || shift == 48);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src << 5) | ((shift >> 4) << 21) | (op << 23) | sf);
commit();
}
void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size)
{
if (logText)
log(name, dst, src);
switch (src.kind)
{
case AddressKindA64::imm:
LUAU_ASSERT(src.data % (1 << size) == 0);
place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30));
break;
case AddressKindA64::reg:
place(dst.index | (src.base.index << 5) | (0b10 << 10) | (0b011 << 13) | (src.offset.index << 16) | (1 << 21) | (op << 22) | (size << 30));
break;
}
commit();
}
void AssemblyBuilderA64::placeBC(const char* name, Label& label, uint8_t op, uint8_t cond)
{
placeLabel(label);
if (logText)
log(name, label);
place(cond | (op << 24));
commit();
}
void AssemblyBuilderA64::placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond)
{
placeLabel(label);
if (logText)
log(name, cond, label);
LUAU_ASSERT(cond.kind == KindA64::w || cond.kind == KindA64::x);
uint32_t sf = (cond.kind == KindA64::x) ? 0x80000000 : 0;
place(cond.index | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::place(uint32_t word)
{
LUAU_ASSERT(codePos < codeEnd);
*codePos++ = word;
}
void AssemblyBuilderA64::placeLabel(Label& label)
{
if (label.location == ~0u)
{
if (label.id == 0)
{
label.id = nextLabel++;
labelLocations.push_back(~0u);
}
pendingLabels.push_back({label.id, getCodeSize()});
}
else
{
// note: if label has an assigned location we can in theory avoid patching it later, but
// we need to handle potential overflow of 19-bit offsets
LUAU_ASSERT(label.id != 0);
labelLocations[label.id - 1] = label.location;
pendingLabels.push_back({label.id, getCodeSize()});
}
}
void AssemblyBuilderA64::commit()
{
LUAU_ASSERT(codePos <= codeEnd);
if (codeEnd == codePos)
extend();
}
void AssemblyBuilderA64::extend()
{
uint32_t count = getCodeSize();
code.resize(code.size() * 2);
codePos = code.data() + count;
codeEnd = code.data() + code.size();
}
size_t AssemblyBuilderA64::allocateData(size_t size, size_t align)
{
LUAU_ASSERT(align > 0 && align <= kMaxAlign && (align & (align - 1)) == 0);
if (dataPos < size)
{
size_t oldSize = data.size();
data.resize(data.size() * 2);
memcpy(&data[oldSize], &data[0], oldSize);
memset(&data[0], 0, oldSize);
dataPos += oldSize;
}
dataPos = (dataPos - size) & ~(align - 1);
return dataPos;
}
void AssemblyBuilderA64::log(const char* opcode)
{
logAppend(" %s\n", opcode);
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src1);
text.append(",");
log(src2);
if (shift > 0)
logAppend(" LSL #%d", shift);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src1);
text.append(",");
logAppend("#%d", src2);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, AddressA64 src)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, int src, int shift)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
logAppend("#%d", src);
if (shift > 0)
logAppend(" LSL #%d", shift);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 src, Label label)
{
logAppend(" %-12s", opcode);
log(src);
text.append(",");
logAppend(".L%d\n", label.id);
}
void AssemblyBuilderA64::log(const char* opcode, Label label)
{
logAppend(" %-12s.L%d\n", opcode, label.id);
}
void AssemblyBuilderA64::log(Label label)
{
logAppend(".L%d:\n", label.id);
}
void AssemblyBuilderA64::log(RegisterA64 reg)
{
switch (reg.kind)
{
case KindA64::w:
if (reg.index == 31)
logAppend("wzr");
else
logAppend("w%d", reg.index);
break;
case KindA64::x:
if (reg.index == 31)
logAppend("xzr");
else
logAppend("x%d", reg.index);
break;
case KindA64::none:
LUAU_ASSERT(!"Unexpected register kind");
}
}
void AssemblyBuilderA64::log(AddressA64 addr)
{
text.append("[");
switch (addr.kind)
{
case AddressKindA64::imm:
log(addr.base);
if (addr.data != 0)
logAppend(",#%d", addr.data);
break;
case AddressKindA64::reg:
log(addr.base);
text.append(",");
log(addr.offset);
if (addr.data != 0)
logAppend(" LSL #%d", addr.data);
break;
}
text.append("]");
}
} // namespace CodeGen
} // namespace Luau
View file
@ -13,13 +13,13 @@ namespace CodeGen
{ {
// TODO: more assertions on operand sizes
- const uint8_t codeForCondition[] = {
+ static const uint8_t codeForCondition[] = {
0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd, 0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb};
- static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(Condition::Count), "all conditions have to be covered");
+ static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
- const char* textForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna", "jnae",
-     "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
+ static const char* textForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna",
+     "jnae", "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
- static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(Condition::Count), "all conditions have to be covered");
+ static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
#define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7))
#define OP_PLUS_CC(op, cc) ((op) + uint8_t(cc))
@ -29,7 +29,7 @@ static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(C
#define REX_X(reg) (((reg).index & 0x8) >> 2)
#define REX_B(reg) (((reg).index & 0x8) >> 3)
- #define AVX_W(value) (!(value) ? 0x80 : 0x0)
+ #define AVX_W(value) ((value) ? 0x80 : 0x0)
#define AVX_R(reg) ((~(reg).index & 0x8) << 4)
#define AVX_X(reg) ((~(reg).index & 0x8) << 3)
#define AVX_B(reg) ((~(reg).index & 0x8) << 2)
@ -50,12 +50,23 @@ const unsigned AVX_66 = 0b01;
const unsigned AVX_F3 = 0b10;
const unsigned AVX_F2 = 0b11;
- const unsigned kMaxAlign = 16;
+ const unsigned kMaxAlign = 32;
const unsigned kMaxInstructionLength = 16;
const uint8_t kRoundingPrecisionInexact = 0b1000; const uint8_t kRoundingPrecisionInexact = 0b1000;
static ABIX64 getCurrentX64ABI()
{
#if defined(_WIN32)
return ABIX64::Windows;
#else
return ABIX64::SystemV;
#endif
}
AssemblyBuilderX64::AssemblyBuilderX64(bool logText)
: logText(logText)
+ , abi(getCurrentX64ABI())
{
data.resize(4096);
dataPos = data.size(); // data is filled backwards
@ -317,7 +328,10 @@ void AssemblyBuilderX64::lea(OperandX64 lhs, OperandX64 rhs)
if (logText)
log("lea", lhs, rhs);
- LUAU_ASSERT(rhs.cat == CategoryX64::mem);
+ LUAU_ASSERT(lhs.cat == CategoryX64::reg && rhs.cat == CategoryX64::mem && rhs.memSize == SizeX64::none);
+ LUAU_ASSERT(rhs.base == rip || rhs.base.size == lhs.base.size);
+ LUAU_ASSERT(rhs.index == noreg || rhs.index.size == lhs.base.size);
+ rhs.memSize = lhs.base.size;
placeBinaryRegAndRegMem(lhs, rhs, 0x8d, 0x8d);
}
@ -352,7 +366,7 @@ void AssemblyBuilderX64::ret()
commit();
}
- void AssemblyBuilderX64::jcc(Condition cond, Label& label)
+ void AssemblyBuilderX64::jcc(ConditionX64 cond, Label& label)
{
placeJcc(textForCondition[size_t(cond)], label, codeForCondition[size_t(cond)]);
}
@ -416,6 +430,153 @@ void AssemblyBuilderX64::int3()
log("int3"); log("int3");
place(0xcc); place(0xcc);
commit();
}
void AssemblyBuilderX64::nop(uint32_t length)
{
while (length != 0)
{
uint32_t step = length > 9 ? 9 : length;
length -= step;
switch (step)
{
case 1:
if (logText)
logAppend(" nop\n");
place(0x90);
break;
case 2:
if (logText)
logAppend(" xchg ax, ax ; %u-byte nop\n", step);
place(0x66);
place(0x90);
break;
case 3:
if (logText)
logAppend(" nop dword ptr[rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x00);
break;
case 4:
if (logText)
logAppend(" nop dword ptr[rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x40);
place(0x00);
break;
case 5:
if (logText)
logAppend(" nop dword ptr[rax+rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x44);
place(0x00);
place(0x00);
break;
case 6:
if (logText)
logAppend(" nop word ptr[rax+rax] ; %u-byte nop\n", step);
place(0x66);
place(0x0f);
place(0x1f);
place(0x44);
place(0x00);
place(0x00);
break;
case 7:
if (logText)
logAppend(" nop dword ptr[rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x80);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
break;
case 8:
if (logText)
logAppend(" nop dword ptr[rax+rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x84);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
break;
case 9:
if (logText)
logAppend(" nop word ptr[rax+rax] ; %u-byte nop\n", step);
place(0x66);
place(0x0f);
place(0x1f);
place(0x84);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
break;
}
commit();
}
}
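// [Illustrative note, not from the commit] The align() function below pads the code
// stream with the standard power-of-two round-up and fills the gap via nop()/int3/ud2.
// For example, with getCodeSize() == 37 and alignment == 32:
//   pad = ((37 + 31) & ~31u) - 37 = 64 - 37 = 27
// and nop(27) decomposes that into 9 + 9 + 9 bytes using the multi-byte NOP encodings above.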
void AssemblyBuilderX64::align(uint32_t alignment, AlignmentDataX64 data)
{
LUAU_ASSERT((alignment & (alignment - 1)) == 0);
uint32_t size = getCodeSize();
uint32_t pad = ((size + alignment - 1) & ~(alignment - 1)) - size;
switch (data)
{
case AlignmentDataX64::Nop:
if (logText)
logAppend("; align %u\n", alignment);
nop(pad);
break;
case AlignmentDataX64::Int3:
if (logText)
logAppend("; align %u using int3\n", alignment);
while (codePos + pad > codeEnd)
extend();
for (uint32_t i = 0; i < pad; ++i)
place(0xcc);
commit();
break;
case AlignmentDataX64::Ud2:
if (logText)
logAppend("; align %u using ud2\n", alignment);
while (codePos + pad > codeEnd)
extend();
uint32_t i = 0;
for (; i + 1 < pad; i += 2)
{
place(0x0f);
place(0x0b);
}
if (i < pad)
place(0xcc);
commit();
break;
}
}
void AssemblyBuilderX64::vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
@ -465,12 +626,12 @@ void AssemblyBuilderX64::vucomisd(OperandX64 src1, OperandX64 src2)
void AssemblyBuilderX64::vcvttsd2si(OperandX64 dst, OperandX64 src)
{
- placeAvx("vcvttsd2si", dst, src, 0x2c, dst.base.size == SizeX64::dword, AVX_0F, AVX_F2);
+ placeAvx("vcvttsd2si", dst, src, 0x2c, dst.base.size == SizeX64::qword, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
- placeAvx("vcvtsi2sd", dst, src1, src2, 0x2a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::dword, AVX_0F, AVX_F2);
+ placeAvx("vcvtsi2sd", dst, src1, src2, 0x2a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode)
@ -623,7 +784,22 @@ OperandX64 AssemblyBuilderX64::bytes(const void* ptr, size_t size, size_t align)
{
size_t pos = allocateData(size, align);
memcpy(&data[pos], ptr, size);
- return OperandX64(SizeX64::qword, noreg, 1, rip, int32_t(pos - data.size()));
+ return OperandX64(SizeX64::none, noreg, 1, rip, int32_t(pos - data.size()));
}
void AssemblyBuilderX64::logAppend(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
text.append(buf);
}
uint32_t AssemblyBuilderX64::getCodeSize() const
{
return uint32_t(codePos - code.data());
}
void AssemblyBuilderX64::placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8,
@ -899,7 +1075,7 @@ void AssemblyBuilderX64::placeVex(OperandX64 dst, OperandX64 src1, OperandX64 sr
place(AVX_3_3(setW, src1.base, dst.base.size == SizeX64::ymmword, prefix));
}
- uint8_t getScaleEncoding(uint8_t scale)
+ static uint8_t getScaleEncoding(uint8_t scale)
{
static const uint8_t scales[9] = {0xff, 0, 1, 0xff, 2, 0xff, 0xff, 0xff, 3};
@ -1054,7 +1230,7 @@ void AssemblyBuilderX64::commit()
{
LUAU_ASSERT(codePos <= codeEnd);
- if (codeEnd - codePos < 16)
+ if (codeEnd - codePos < kMaxInstructionLength)
extend();
}
@ -1067,11 +1243,6 @@ void AssemblyBuilderX64::extend()
codeEnd = code.data() + code.size();
}
- uint32_t AssemblyBuilderX64::getCodeSize()
- {
- return uint32_t(codePos - code.data());
- }
size_t AssemblyBuilderX64::allocateData(size_t size, size_t align)
{
LUAU_ASSERT(align > 0 && align <= kMaxAlign && (align & (align - 1)) == 0);
@ -1174,8 +1345,10 @@ void AssemblyBuilderX64::log(OperandX64 op)
{
if (op.imm >= 0 && op.imm <= 9)
logAppend("+%d", op.imm);
- else
+ else if (op.imm > 0)
logAppend("+0%Xh", op.imm);
+ else
+ logAppend("-0%Xh", -op.imm);
}
text.append("]");
@ -1191,17 +1364,7 @@ void AssemblyBuilderX64::log(OperandX64 op)
}
}
- void AssemblyBuilderX64::logAppend(const char* fmt, ...)
- {
- char buf[256];
- va_list args;
- va_start(args, fmt);
- vsnprintf(buf, sizeof(buf), fmt, args);
- va_end(args);
- text.append(buf);
- }
- const char* AssemblyBuilderX64::getSizeName(SizeX64 size)
+ const char* AssemblyBuilderX64::getSizeName(SizeX64 size) const
{
static const char* sizeNames[] = {"none", "byte", "word", "dword", "qword", "xmmword", "ymmword"};
@ -1209,7 +1372,7 @@ const char* AssemblyBuilderX64::getSizeName(SizeX64 size)
return sizeNames[unsigned(size)];
}
- const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg)
+ const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg) const
{
static const char* names[][16] = {{"rip", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""},
{"al", "cl", "dl", "bl", "spl", "bpl", "sil", "dil", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b"},

View file

@ -1,7 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
- #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ #include "Luau/Common.h"
+ #if defined(LUAU_BIG_ENDIAN)
#include <endian.h>
#endif
@ -15,7 +17,7 @@ inline uint8_t* writeu8(uint8_t* target, uint8_t value)
inline uint8_t* writeu32(uint8_t* target, uint32_t value)
{
- #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ #if defined(LUAU_BIG_ENDIAN)
value = htole32(value);
#endif
@ -25,7 +27,7 @@ inline uint8_t* writeu32(uint8_t* target, uint32_t value)
inline uint8_t* writeu64(uint8_t* target, uint64_t value)
{
- #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ #if defined(LUAU_BIG_ENDIAN)
value = htole64(value);
#endif
@ -51,7 +53,7 @@ inline uint8_t* writeuleb128(uint8_t* target, uint64_t value)
inline uint8_t* writef32(uint8_t* target, float value)
{
- #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ #if defined(LUAU_BIG_ENDIAN)
static_assert(sizeof(float) == sizeof(uint32_t), "type size must match to reinterpret data");
uint32_t data;
memcpy(&data, &value, sizeof(value));
@ -65,7 +67,7 @@ inline uint8_t* writef32(uint8_t* target, float value)
inline uint8_t* writef64(uint8_t* target, double value)
{
- #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ #if defined(LUAU_BIG_ENDIAN)
static_assert(sizeof(double) == sizeof(uint64_t), "type size must match to reinterpret data");
uint64_t data;
memcpy(&data, &value, sizeof(value));

View file

@ -110,8 +110,8 @@ CodeAllocator::~CodeAllocator()
bool CodeAllocator::allocate(
uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart)
{
- // 'Round up' to preserve 16 byte alignment
- size_t alignedDataSize = (dataSize + 15) & ~15;
+ // 'Round up' to preserve code alignment
+ size_t alignedDataSize = (dataSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1);
size_t totalSize = alignedDataSize + codeSize;
@ -187,8 +187,8 @@ bool CodeAllocator::allocateNewBlock(size_t& unwindInfoSize)
{
void* unwindInfo = createBlockUnwindInfo(context, block, blockSize, unwindInfoSize);
- // 'Round up' to preserve 16 byte alignment of the following data and code
- unwindInfoSize = (unwindInfoSize + 15) & ~15;
+ // 'Round up' to preserve alignment of the following data and code
+ unwindInfoSize = (unwindInfoSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1);
LUAU_ASSERT(unwindInfoSize <= kMaxReservedDataSize);

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/CodeBlockUnwind.h"
+ #include "Luau/CodeAllocator.h"
#include "Luau/UnwindBuilder.h"
#include <string.h>
@ -58,7 +59,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
// All unwinding related data is placed together at the start of the block
size_t unwindSize = sizeof(RUNTIME_FUNCTION) + unwind->getSize();
- unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes
+ unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment
LUAU_ASSERT(blockSize >= unwindSize);
RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)block;
@ -82,7 +83,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
// All unwinding related data is placed together at the start of the block
size_t unwindSize = unwind->getSize();
- unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes
+ unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment
LUAU_ASSERT(blockSize >= unwindSize);
char* unwindData = (char*)block;

View file

@ -32,7 +32,360 @@ namespace Luau
namespace CodeGen
{
- static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, Proto* proto)
+ constexpr uint32_t kFunctionAlignment = 32;
struct InstructionOutline
{
int pcpos;
int length;
};
static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers)
{
if (build.logText)
build.logAppend("; exitContinueVm\n");
helpers.exitContinueVm = build.setLabel();
emitExit(build, /* continueInVm */ true);
if (build.logText)
build.logAppend("; exitNoContinueVm\n");
helpers.exitNoContinueVm = build.setLabel();
emitExit(build, /* continueInVm */ false);
if (build.logText)
build.logAppend("; continueCallInVm\n");
helpers.continueCallInVm = build.setLabel();
emitContinueCallInVm(build);
}
static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i,
Label* labelarr, Label& fallback)
{
int skip = 0;
switch (op)
{
case LOP_NOP:
break;
case LOP_LOADNIL:
emitInstLoadNil(build, pc);
break;
case LOP_LOADB:
emitInstLoadB(build, pc, i, labelarr);
break;
case LOP_LOADN:
emitInstLoadN(build, pc);
break;
case LOP_LOADK:
emitInstLoadK(build, pc);
break;
case LOP_LOADKX:
emitInstLoadKX(build, pc);
break;
case LOP_MOVE:
emitInstMove(build, pc);
break;
case LOP_GETGLOBAL:
emitInstGetGlobal(build, pc, i, fallback);
break;
case LOP_SETGLOBAL:
emitInstSetGlobal(build, pc, i, labelarr, fallback);
break;
case LOP_CALL:
emitInstCall(build, helpers, pc, i, labelarr);
break;
case LOP_RETURN:
emitInstReturn(build, helpers, pc, i, labelarr);
break;
case LOP_GETTABLE:
emitInstGetTable(build, pc, i, fallback);
break;
case LOP_SETTABLE:
emitInstSetTable(build, pc, i, labelarr, fallback);
break;
case LOP_GETTABLEKS:
emitInstGetTableKS(build, pc, i, fallback);
break;
case LOP_SETTABLEKS:
emitInstSetTableKS(build, pc, i, labelarr, fallback);
break;
case LOP_GETTABLEN:
emitInstGetTableN(build, pc, i, fallback);
break;
case LOP_SETTABLEN:
emitInstSetTableN(build, pc, i, labelarr, fallback);
break;
case LOP_JUMP:
emitInstJump(build, pc, i, labelarr);
break;
case LOP_JUMPBACK:
emitInstJumpBack(build, pc, i, labelarr);
break;
case LOP_JUMPIF:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFNOT:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::LessEqual, fallback);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::Less, fallback);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLessEqual, fallback);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLess, fallback);
break;
case LOP_JUMPX:
emitInstJumpX(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKNIL:
emitInstJumpxEqNil(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKB:
emitInstJumpxEqB(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKN:
emitInstJumpxEqN(build, pc, proto->k, i, labelarr);
break;
case LOP_JUMPXEQKS:
emitInstJumpxEqS(build, pc, i, labelarr);
break;
case LOP_ADD:
emitInstBinary(build, pc, i, TM_ADD, fallback);
break;
case LOP_SUB:
emitInstBinary(build, pc, i, TM_SUB, fallback);
break;
case LOP_MUL:
emitInstBinary(build, pc, i, TM_MUL, fallback);
break;
case LOP_DIV:
emitInstBinary(build, pc, i, TM_DIV, fallback);
break;
case LOP_MOD:
emitInstBinary(build, pc, i, TM_MOD, fallback);
break;
case LOP_POW:
emitInstBinary(build, pc, i, TM_POW, fallback);
break;
case LOP_ADDK:
emitInstBinaryK(build, pc, i, TM_ADD, fallback);
break;
case LOP_SUBK:
emitInstBinaryK(build, pc, i, TM_SUB, fallback);
break;
case LOP_MULK:
emitInstBinaryK(build, pc, i, TM_MUL, fallback);
break;
case LOP_DIVK:
emitInstBinaryK(build, pc, i, TM_DIV, fallback);
break;
case LOP_MODK:
emitInstBinaryK(build, pc, i, TM_MOD, fallback);
break;
case LOP_POWK:
emitInstPowK(build, pc, proto->k, i, fallback);
break;
case LOP_NOT:
emitInstNot(build, pc);
break;
case LOP_MINUS:
emitInstMinus(build, pc, i, fallback);
break;
case LOP_LENGTH:
emitInstLength(build, pc, i, fallback);
break;
case LOP_NEWTABLE:
emitInstNewTable(build, pc, i, labelarr);
break;
case LOP_DUPTABLE:
emitInstDupTable(build, pc, i, labelarr);
break;
case LOP_SETLIST:
emitInstSetList(build, pc, i, labelarr);
break;
case LOP_GETUPVAL:
emitInstGetUpval(build, pc, i);
break;
case LOP_SETUPVAL:
emitInstSetUpval(build, pc, i, labelarr);
break;
case LOP_CLOSEUPVALS:
emitInstCloseUpvals(build, pc, i, labelarr);
break;
case LOP_FASTCALL:
skip = emitInstFastCall(build, pc, i, labelarr);
break;
case LOP_FASTCALL1:
skip = emitInstFastCall1(build, pc, i, labelarr);
break;
case LOP_FASTCALL2:
skip = emitInstFastCall2(build, pc, i, labelarr);
break;
case LOP_FASTCALL2K:
skip = emitInstFastCall2K(build, pc, i, labelarr);
break;
case LOP_FORNPREP:
emitInstForNPrep(build, pc, i, labelarr);
break;
case LOP_FORNLOOP:
emitInstForNLoop(build, pc, i, labelarr);
break;
case LOP_FORGLOOP:
emitinstForGLoop(build, pc, i, labelarr, fallback);
break;
case LOP_FORGPREP_NEXT:
emitInstForGPrepNext(build, pc, i, labelarr, fallback);
break;
case LOP_FORGPREP_INEXT:
emitInstForGPrepInext(build, pc, i, labelarr, fallback);
break;
case LOP_AND:
emitInstAnd(build, pc);
break;
case LOP_ANDK:
emitInstAndK(build, pc);
break;
case LOP_OR:
emitInstOr(build, pc);
break;
case LOP_ORK:
emitInstOrK(build, pc);
break;
case LOP_GETIMPORT:
emitInstGetImport(build, pc, fallback);
break;
case LOP_CONCAT:
emitInstConcat(build, pc, i, labelarr);
break;
default:
emitFallback(build, data, op, i);
break;
}
return skip;
}
static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr)
{
switch (op)
{
case LOP_GETIMPORT:
emitInstGetImportFallback(build, pc, i);
break;
case LOP_GETTABLE:
emitInstGetTableFallback(build, pc, i);
break;
case LOP_SETTABLE:
emitInstSetTableFallback(build, pc, i);
break;
case LOP_GETTABLEN:
emitInstGetTableNFallback(build, pc, i);
break;
case LOP_SETTABLEN:
emitInstSetTableNFallback(build, pc, i);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLess);
break;
case LOP_ADD:
emitInstBinaryFallback(build, pc, i, TM_ADD);
break;
case LOP_SUB:
emitInstBinaryFallback(build, pc, i, TM_SUB);
break;
case LOP_MUL:
emitInstBinaryFallback(build, pc, i, TM_MUL);
break;
case LOP_DIV:
emitInstBinaryFallback(build, pc, i, TM_DIV);
break;
case LOP_MOD:
emitInstBinaryFallback(build, pc, i, TM_MOD);
break;
case LOP_POW:
emitInstBinaryFallback(build, pc, i, TM_POW);
break;
case LOP_ADDK:
emitInstBinaryKFallback(build, pc, i, TM_ADD);
break;
case LOP_SUBK:
emitInstBinaryKFallback(build, pc, i, TM_SUB);
break;
case LOP_MULK:
emitInstBinaryKFallback(build, pc, i, TM_MUL);
break;
case LOP_DIVK:
emitInstBinaryKFallback(build, pc, i, TM_DIV);
break;
case LOP_MODK:
emitInstBinaryKFallback(build, pc, i, TM_MOD);
break;
case LOP_POWK:
emitInstBinaryKFallback(build, pc, i, TM_POW);
break;
case LOP_MINUS:
emitInstMinusFallback(build, pc, i);
break;
case LOP_LENGTH:
emitInstLengthFallback(build, pc, i);
break;
case LOP_FORGLOOP:
emitinstForGLoopFallback(build, pc, i, labelarr);
break;
case LOP_FORGPREP_NEXT:
case LOP_FORGPREP_INEXT:
emitInstForGPrepXnextFallback(build, pc, i, labelarr);
break;
case LOP_GETGLOBAL:
// TODO: luaV_gettable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETGLOBAL:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_GETTABLEKS:
// Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access
// It is also required to perform cached slot update
// TODO: extra fast-paths could be lowered before the full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETTABLEKS:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
default:
LUAU_ASSERT(!"Expected fallback for instruction");
}
}
static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options)
{
NativeProto* result = new NativeProto();
@ -54,6 +407,14 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
std::vector<Label> instLabels;
instLabels.resize(proto->sizecode);
std::vector<Label> instFallbacks;
instFallbacks.resize(proto->sizecode);
std::vector<InstructionOutline> instOutlines;
instOutlines.reserve(64);
build.align(kFunctionAlignment, AlignmentDataX64::Ud2);
Label start = build.setLabel();
for (int i = 0; i < proto->sizecode;)
@ -61,172 +422,90 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
const Instruction* pc = &proto->code[i]; const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc)); LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
build.setLabel(instLabels[i]); build.setLabel(instLabels[i]);
if (build.logText) if (options.annotator)
build.logAppend("; #%d: %s\n", i, data.names[op]); options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
switch (op) int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), instFallbacks[i]);
if (skip != 0)
instOutlines.push_back({nexti, skip});
i = nexti + skip;
LUAU_ASSERT(i <= proto->sizecode);
}
size_t textSize = build.text.size();
uint32_t codeSize = build.getCodeSize();
if (options.annotator && !options.skipOutlinedCode)
build.logAppend("; outlined instructions\n");
for (auto [pcpos, length] : instOutlines)
{
int i = pcpos;
while (i < pcpos + length)
{ {
case LOP_NOP: const Instruction* pc = &proto->code[i];
break; LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
case LOP_LOADNIL:
emitInstLoadNil(build, data, pc); build.setLabel(instLabels[i]);
break;
case LOP_LOADB: if (options.annotator && !options.skipOutlinedCode)
emitInstLoadB(build, data, pc, i, instLabels.data()); options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
break;
case LOP_LOADN: int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), instFallbacks[i]);
emitInstLoadN(build, data, pc); LUAU_ASSERT(skip == 0);
break;
case LOP_LOADK: i += getOpLength(op);
emitInstLoadK(build, data, pc, proto->k);
break;
case LOP_MOVE:
emitInstMove(build, data, pc);
break;
case LOP_GETTABLE:
emitInstGetTable(build, pc, i);
break;
case LOP_SETTABLE:
emitInstSetTable(build, pc, i);
break;
case LOP_GETTABLEN:
emitInstGetTableN(build, pc, i);
break;
case LOP_SETTABLEN:
emitInstSetTableN(build, pc, i);
break;
case LOP_JUMP:
emitInstJump(build, data, pc, i, instLabels.data());
break;
case LOP_JUMPBACK:
emitInstJumpBack(build, data, pc, i, instLabels.data());
break;
case LOP_JUMPIF:
emitInstJumpIf(build, data, pc, i, instLabels.data(), /* not_ */ false);
break;
case LOP_JUMPIFNOT:
emitInstJumpIf(build, data, pc, i, instLabels.data(), /* not_ */ true);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEq(build, data, pc, i, instLabels.data(), /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, data, pc, i, instLabels.data(), /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::NotLess);
break;
case LOP_JUMPX:
emitInstJumpX(build, data, pc, i, instLabels.data());
break;
case LOP_JUMPXEQKNIL:
emitInstJumpxEqNil(build, data, pc, proto->k, i, instLabels.data());
break;
case LOP_JUMPXEQKB:
emitInstJumpxEqB(build, data, pc, proto->k, i, instLabels.data());
break;
case LOP_JUMPXEQKN:
emitInstJumpxEqN(build, data, pc, proto->k, i, instLabels.data());
break;
case LOP_JUMPXEQKS:
emitInstJumpxEqS(build, data, pc, proto->k, i, instLabels.data());
break;
case LOP_ADD:
emitInstAdd(build, pc, i);
break;
case LOP_SUB:
emitInstSub(build, pc, i);
break;
case LOP_MUL:
emitInstMul(build, pc, i);
break;
case LOP_DIV:
emitInstDiv(build, pc, i);
break;
case LOP_MOD:
emitInstMod(build, pc, i);
break;
case LOP_POW:
emitInstPow(build, pc, i);
break;
case LOP_ADDK:
emitInstAddK(build, pc, proto->k, i);
break;
case LOP_SUBK:
emitInstSubK(build, pc, proto->k, i);
break;
case LOP_MULK:
emitInstMulK(build, pc, proto->k, i);
break;
case LOP_DIVK:
emitInstDivK(build, pc, proto->k, i);
break;
case LOP_MODK:
emitInstModK(build, pc, proto->k, i);
break;
case LOP_POWK:
emitInstPowK(build, pc, proto->k, i);
break;
case LOP_NOT:
emitInstNot(build, pc);
break;
case LOP_MINUS:
emitInstMinus(build, pc, i);
break;
case LOP_LENGTH:
emitInstLength(build, pc, i);
break;
case LOP_GETUPVAL:
emitInstGetUpval(build, pc, i);
break;
case LOP_FASTCALL:
emitInstFastCall(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL1:
emitInstFastCall1(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL2:
emitInstFastCall2(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL2K:
emitInstFastCall2K(build, pc, proto->k, i, instLabels.data());
break;
case LOP_FORNPREP:
emitInstForNPrep(build, pc, i, instLabels.data());
break;
case LOP_FORNLOOP:
emitInstForNLoop(build, pc, i, instLabels.data());
break;
case LOP_AND:
emitInstAnd(build, pc);
break;
case LOP_ANDK:
emitInstAndK(build, pc);
break;
case LOP_OR:
emitInstOr(build, pc);
break;
case LOP_ORK:
emitInstOrK(build, pc);
break;
default:
emitFallback(build, data, op, i);
break;
} }
i += getOpLength(op); if (i < proto->sizecode)
LUAU_ASSERT(i <= proto->sizecode); build.jmp(instLabels[i]);
}
if (options.annotator && !options.skipOutlinedCode)
build.logAppend("; outlined code\n");
for (int i = 0, instid = 0; i < proto->sizecode; ++instid)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
if (instFallbacks[i].id == 0)
{
i = nexti;
continue;
}
if (options.annotator && !options.skipOutlinedCode)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, instid);
build.setLabel(instFallbacks[i]);
emitInstFallback(build, data, op, pc, i, instLabels.data());
// Jump back to the next instruction handler
if (nexti < proto->sizecode)
build.jmp(instLabels[nexti]);
i = nexti;
}
// Truncate assembly output if we don't care for outlined code part
if (options.skipOutlinedCode)
{
build.text.resize(textSize);
build.logAppend("; skipping %u bytes of outlined code\n", build.getCodeSize() - codeSize);
} }
result->instTargets = new uintptr_t[proto->sizecode];
@ -386,13 +665,16 @@ void compile(lua_State* L, int idx)
std::vector<Proto*> protos;
gatherFunctions(protos, clvalue(func)->l.p);
+ ModuleHelpers helpers;
+ assembleHelpers(build, helpers);
std::vector<NativeProto*> results;
results.reserve(protos.size());
// Skip protos that have been compiled during previous invocations of CodeGen::compile
for (Proto* p : protos)
if (p && getProtoExecData(p) == nullptr)
- results.push_back(assembleFunction(build, *data, p));
+ results.push_back(assembleFunction(build, *data, helpers, p, {}));
build.finalize();
@ -413,6 +695,9 @@ void compile(lua_State* L, int idx)
{
for (int i = 0; i < result->proto->sizecode; i++)
result->instTargets[i] += uintptr_t(codeStart + result->location);
+ LUAU_ASSERT(result->proto->sizecode);
+ result->entryTarget = result->instTargets[0];
}
// Link native proto objects to Proto; the memory is now managed by VM and will be freed via onDestroyFunction
@ -420,29 +705,35 @@ void compile(lua_State* L, int idx)
setProtoExecData(result->proto, result);
}
- std::string getAssemblyText(lua_State* L, int idx)
+ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options)
{
LUAU_ASSERT(lua_isLfunction(L, idx));
const TValue* func = luaA_toobject(L, idx);
- AssemblyBuilderX64 build(/* logText= */ true);
+ AssemblyBuilderX64 build(/* logText= */ !options.outputBinary);
NativeState data;
initFallbackTable(data);
- initInstructionNames(data);
std::vector<Proto*> protos;
gatherFunctions(protos, clvalue(func)->l.p);
+ ModuleHelpers helpers;
+ assembleHelpers(build, helpers);
for (Proto* p : protos)
if (p)
{
- NativeProto* nativeProto = assembleFunction(build, data, p);
+ NativeProto* nativeProto = assembleFunction(build, data, helpers, p, options);
destroyNativeProto(nativeProto);
}
build.finalize();
- return build.text;
+ if (options.outputBinary)
+ return std::string(build.code.begin(), build.code.end()) + std::string(build.data.begin(), build.data.end());
+ else
+ return build.text;
}
} // namespace CodeGen

View file

@ -0,0 +1,130 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "CodeGenUtils.h"
#include "ldo.h"
#include "ltable.h"
#include "FallbacksProlog.h"
#include <string.h>
namespace Luau
{
namespace CodeGen
{
bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra)
{
// then we advance index through the hash portion
while (unsigned(index - h->sizearray) < unsigned(1 << h->lsizenode))
{
LuaNode* n = &h->node[index - h->sizearray];
if (!ttisnil(gval(n)))
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
getnodekey(L, ra + 3, n);
setobj(L, ra + 4, gval(n));
return true;
}
index++;
}
return false;
}
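// [Illustrative note, not from the commit; example sizes are assumed] If h->sizearray == 4
// and h->lsizenode == 3, the hash part holds 1 << 3 == 8 nodes, so iterator indices 4..11
// map to h->node[0..7]. When a non-nil value is found, ra+2 is updated to index+1 via
// setpvalue so the next FORGLOOP iteration resumes right after the slot that produced the
// key/value pair written to ra+3/ra+4.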
bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux)
{
TValue* base = L->base;
TValue* ra = VM_REG(insnA);
// note: it's safe to push arguments past top for complicated reasons (see lvmexecute.cpp)
setobj2s(L, ra + 3 + 2, ra + 2);
setobj2s(L, ra + 3 + 1, ra + 1);
setobj2s(L, ra + 3, ra);
L->top = ra + 3 + 3; // func + 2 args (state and index)
LUAU_ASSERT(L->top <= L->stack_last);
luaD_call(L, ra + 3, uint8_t(aux));
L->top = L->ci->top;
// recompute ra since stack might have been reallocated
base = L->base;
ra = VM_REG(insnA);
// copy first variable back into the iteration index
setobj2s(L, ra + 2, ra + 3);
return !ttisnil(ra + 3);
}
void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc)
{
if (!ttisfunction(ra))
{
Closure* cl = clvalue(L->ci->func);
L->ci->savedpc = cl->l.p->code + pc;
luaG_typeerror(L, ra, "iterate over");
}
}
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults)
{
// slow-path: not a function call
if (LUAU_UNLIKELY(!ttisfunction(ra)))
{
luaV_tryfuncTM(L, ra);
argtop++; // __call adds an extra self
}
Closure* ccl = clvalue(ra);
CallInfo* ci = incr_ci(L);
ci->func = ra;
ci->base = ra + 1;
ci->top = argtop + ccl->stacksize; // note: technically UB since we haven't reallocated the stack yet
ci->savedpc = NULL;
ci->flags = 0;
ci->nresults = nresults;
L->base = ci->base;
L->top = argtop;
// note: this reallocs stack, but we don't need to VM_PROTECT this
// this is because we're going to modify base/savedpc manually anyhow
// crucially, we can't use ra/argtop after this line
luaD_checkstack(L, ccl->stacksize);
return ccl;
}
void callEpilogC(lua_State* L, int nresults, int n)
{
// ci is our callinfo, cip is our parent
CallInfo* ci = L->ci;
CallInfo* cip = ci - 1;
// copy return values into parent stack (but only up to nresults!), fill the rest with nil
// note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally
StkId res = ci->func;
StkId vali = L->top - n;
StkId valend = L->top;
int i;
for (i = nresults; i != 0 && vali < valend; i--)
setobj2s(L, res++, vali++);
while (i-- > 0)
setnilvalue(res++);
// pop the stack frame
L->ci = cip;
L->base = cip->base;
L->top = (nresults == LUA_MULTRET) ? res : cip->top;
}
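// [Illustrative note, not from the commit] Worked example: with nresults == 2 and n == 1,
// the for loop copies the single returned value and exits with i == 1, so the while loop
// appends one nil; with nresults == LUA_MULTRET (-1) the counter never reaches 0, all n
// values are copied, and L->top is left at res so the caller sees exactly n results.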
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,20 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "lobject.h"
namespace Luau
{
namespace CodeGen
{
bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra);
bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux);
void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc);
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults);
void callEpilogC(lua_State* L, int nresults, int n);
} // namespace CodeGen
} // namespace Luau

View file

@ -48,7 +48,7 @@ bool initEntryFunction(NativeState& data)
unwind.start();
- if (getCurrentX64ABI() == X64ABI::Windows)
+ if (build.abi == ABIX64::Windows)
{
// Place arguments in home space
build.mov(qword[rsp + 16], rArg2);
@ -87,7 +87,7 @@ bool initEntryFunction(NativeState& data)
unwind.allocStack(stacksize + localssize);
// Setup frame pointer
- build.lea(rbp, qword[rsp + stacksize]);
+ build.lea(rbp, addr[rsp + stacksize]);
unwind.setupFrameReg(rbp, stacksize);
unwind.finish();
@ -113,7 +113,7 @@ bool initEntryFunction(NativeState& data)
Label returnOff = build.setLabel();
// Cleanup and exit
- build.lea(rsp, qword[rbp + localssize]);
+ build.lea(rsp, addr[rbp + localssize]);
build.pop(r15);
build.pop(r14);
build.pop(r13);
@ -121,7 +121,7 @@ bool initEntryFunction(NativeState& data)
build.pop(rbp);
build.pop(rbx);
- if (getCurrentX64ABI() == X64ABI::Windows)
+ if (build.abi == ABIX64::Windows)
{
build.pop(rsi);
build.pop(rdi);

View file

@ -126,20 +126,5 @@ inline int getOpLength(LuauOpcode op)
}
}
- enum class X64ABI
- {
- Windows,
- SystemV,
- };
- inline X64ABI getCurrentX64ABI()
- {
- #if defined(_WIN32)
- return X64ABI::Windows;
- #else
- return X64ABI::SystemV;
- #endif
- }
} // namespace CodeGen
} // namespace Luau

View file

@ -14,7 +14,7 @@ namespace Luau
namespace CodeGen
{
- void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label)
+ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label)
{
// Refresher on comi/ucomi EFLAGS:
// CF only: less
@ -35,52 +35,75 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs,
// And because of NaN, integer check interchangeability like 'not less or equal' <-> 'greater' does not hold
switch (cond)
{
- case Condition::NotLessEqual:
+ case ConditionX64::NotLessEqual:
// (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN
- build.jcc(Condition::NotAboveEqual, label);
+ build.jcc(ConditionX64::NotAboveEqual, label);
break;
- case Condition::LessEqual:
+ case ConditionX64::LessEqual:
// (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN
- build.jcc(Condition::AboveEqual, label);
+ build.jcc(ConditionX64::AboveEqual, label);
break;
- case Condition::NotLess:
+ case ConditionX64::NotLess:
// (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN
- build.jcc(Condition::NotAbove, label);
+ build.jcc(ConditionX64::NotAbove, label);
break;
- case Condition::Less:
+ case ConditionX64::Less:
// (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN
- build.jcc(Condition::Above, label);
+ build.jcc(ConditionX64::Above, label);
break;
- case Condition::NotEqual:
+ case ConditionX64::NotEqual:
// ZF=0 or PF=1 means != or NaN
- build.jcc(Condition::NotZero, label);
- build.jcc(Condition::Parity, label);
+ build.jcc(ConditionX64::NotZero, label);
+ build.jcc(ConditionX64::Parity, label);
break;
default:
LUAU_ASSERT(!"Unsupported condition");
}
}
- void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos)
+ void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
- build.mov(rArg1, rState);
- build.lea(rArg2, luauRegValue(ra));
- build.lea(rArg3, luauRegValue(rb));
- if (cond == Condition::NotLessEqual || cond == Condition::LessEqual)
+ build.mov(rArg1, rState);
+ build.lea(rArg2, luauRegAddress(ra));
+ build.lea(rArg3, luauRegAddress(rb));
+ if (cond == ConditionX64::NotLessEqual || cond == ConditionX64::LessEqual)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]);
- else if (cond == Condition::NotLess || cond == Condition::Less)
+ else if (cond == ConditionX64::NotLess || cond == ConditionX64::Less)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]);
- else if (cond == Condition::NotEqual || cond == Condition::Equal)
+ else if (cond == ConditionX64::NotEqual || cond == ConditionX64::Equal)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]);
else
LUAU_ASSERT(!"Unsupported condition");
emitUpdateBase(build);
build.test(eax, eax);
- build.jcc(
-     cond == Condition::NotLessEqual || cond == Condition::NotLess || cond == Condition::NotEqual ? Condition::Zero : Condition::NotZero, label);
+ build.jcc(cond == ConditionX64::NotLessEqual || cond == ConditionX64::NotLess || cond == ConditionX64::NotEqual ? ConditionX64::Zero
+     : ConditionX64::NotZero,
+     label);
}
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos)
{
RegisterX64 node = rdx;
LUAU_ASSERT(tmp != node);
LUAU_ASSERT(table != node);
build.mov(node, qword[table + offsetof(Table, node)]);
// compute cached slot
build.mov(tmp, sCode);
build.movzx(dwordReg(tmp), byte[tmp + pcpos * sizeof(Instruction) + kOffsetOfInstructionC]);
build.and_(byteReg(tmp), byte[table + offsetof(Table, nodemask8)]);
// LuaNode* n = &h->node[slot];
build.shl(dwordReg(tmp), kLuaNodeSizeLog2);
build.add(node, tmp);
return node;
} }
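// [Illustrative note, not from the commit] The cached slot comes from the C operand byte
// of the bytecode word (kOffsetOfInstructionC == 3), masked with nodemask8; shifting it
// left by kLuaNodeSizeLog2 (5) turns the slot index into a byte offset, since each LuaNode
// is 32 bytes — e.g. slot 3 -> offset 96 from Table::node.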
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label)
@ -98,21 +121,21 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi
build.vucomisd(tmp, numd); // Sets ZF=1 if equal or NaN
// We don't need non-integer values
// But to skip the PF=1 check, we proceed with NaN because 0x80000000 index is out of bounds
- build.jcc(Condition::NotZero, label);
+ build.jcc(ConditionX64::NotZero, label);
}
void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm)
{
emitSetSavedPc(build, pcpos + 1);
- if (getCurrentX64ABI() == X64ABI::Windows)
+ if (build.abi == ABIX64::Windows)
build.mov(sArg5, tm);
else
build.mov(rArg5, tm);
build.mov(rArg1, rState);
- build.lea(rArg2, luauRegValue(ra));
- build.lea(rArg3, luauRegValue(rb));
+ build.lea(rArg2, luauRegAddress(ra));
+ build.lea(rArg3, luauRegAddress(rb));
build.lea(rArg4, c);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]);
@ -124,8 +147,8 @@ void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb, int pcpos)
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
- build.lea(rArg2, luauRegValue(ra));
- build.lea(rArg3, luauRegValue(rb));
+ build.lea(rArg2, luauRegAddress(ra));
+ build.lea(rArg3, luauRegAddress(rb));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_dolen)]);
emitUpdateBase(build);
@ -136,9 +159,9 @@ void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, i
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
- build.lea(rArg2, luauRegValue(limit));
- build.lea(rArg3, luauRegValue(step));
- build.lea(rArg4, luauRegValue(init));
+ build.lea(rArg2, luauRegAddress(limit));
+ build.lea(rArg3, luauRegAddress(step));
+ build.lea(rArg4, luauRegAddress(init));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]);
}
@ -147,9 +170,9 @@ void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
- build.lea(rArg2, luauRegValue(rb));
+ build.lea(rArg2, luauRegAddress(rb));
build.lea(rArg3, c);
- build.lea(rArg4, luauRegValue(ra));
+ build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_gettable)]);
emitUpdateBase(build);
@ -160,36 +183,78 @@ void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
- build.lea(rArg2, luauRegValue(rb));
+ build.lea(rArg2, luauRegAddress(rb));
build.lea(rArg3, c);
- build.lea(rArg4, luauRegValue(ra));
+ build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_settable)]);
emitUpdateBase(build);
}
- void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip)
+ // works for luaC_barriertable, luaC_barrierf
+ static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip, int contextOffset)
{
- LUAU_ASSERT(tmp != table);
+ LUAU_ASSERT(tmp != object);
// iscollectable(ra)
build.cmp(luauRegTag(ra), LUA_TSTRING);
- build.jcc(Condition::Less, skip);
+ build.jcc(ConditionX64::Less, skip);
- // isblack(obj2gco(h))
+ // isblack(obj2gco(o))
- build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT));
+ build.test(byte[object + offsetof(GCheader, marked)], bitmask(BLACKBIT));
- build.jcc(Condition::Zero, skip);
+ build.jcc(ConditionX64::Zero, skip);
// iswhite(gcvalue(ra))
build.mov(tmp, luauRegValue(ra));
build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT));
- build.jcc(Condition::Zero, skip);
+ build.jcc(ConditionX64::Zero, skip);
- LUAU_ASSERT(table != rArg3);
+ LUAU_ASSERT(object != rArg3);
build.mov(rArg3, tmp);
- build.mov(rArg2, table);
+ build.mov(rArg2, object);
build.mov(rArg1, rState);
- build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barriertable)]);
+ build.call(qword[rNativeContext + contextOffset]);
}
void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip)
{
callBarrierImpl(build, tmp, table, ra, skip, offsetof(NativeContext, luaC_barriertable));
}
void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip)
{
callBarrierImpl(build, tmp, object, ra, skip, offsetof(NativeContext, luaC_barrierf));
}
void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip)
{
// isblack(obj2gco(t))
build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(ConditionX64::Zero, skip);
// Argument setup re-ordered to avoid conflicts with table register
if (table != rArg2)
build.mov(rArg2, table);
build.lea(rArg3, addr[rArg2 + offsetof(Table, gclist)]);
build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]);
}
void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip)
{
build.mov(rax, qword[rState + offsetof(lua_State, global)]);
build.mov(rdx, qword[rax + offsetof(global_State, totalbytes)]);
build.cmp(rdx, qword[rax + offsetof(global_State, GCthreshold)]);
build.jcc(ConditionX64::Below, skip);
if (savepc)
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), 1);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]);
emitUpdateBase(build);
} }
void emitExit(AssemblyBuilderX64& build, bool continueInVm)
@ -224,20 +289,20 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos)
build.mov(r8, qword[rState + offsetof(lua_State, global)]);
build.mov(r8, qword[r8 + offsetof(global_State, cb.interrupt)]);
build.test(r8, r8);
- build.jcc(Condition::Zero, skip);
+ build.jcc(ConditionX64::Zero, skip);
emitSetSavedPc(build, pcpos + 1); // uses rax/rdx
// Call interrupt
// TODO: This code should move to the end of the function, or even be outlined so that it can be shared by multiple interruptible instructions
build.mov(rArg1, rState);
- build.mov(rArg2d, -1);
+ build.mov(dwordReg(rArg2), -1); // function accepts 'int' here and using qword reg would've forced 8 byte constant here
build.call(r8);
// Check if we need to exit
build.mov(al, byte[rState + offsetof(lua_State, status)]);
build.test(al, al);
- build.jcc(Condition::Zero, skip);
+ build.jcc(ConditionX64::Zero, skip);
build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.sub(qword[rax + offsetof(CallInfo, savedpc)], sizeof(Instruction));
@ -265,17 +330,6 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo
build.mov(rArg4, rConstants);
build.call(qword[rNativeContext + offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback)]);
- // Some instructions may interrupt the execution
- if (opinfo.flags & kFallbackCheckInterrupt)
- {
- Label skip;
- build.test(rax, rax);
- build.jcc(Condition::NotZero, skip);
- emitExit(build, /* continueInVm */ false);
- build.setLabel(skip);
- }
emitUpdateBase(build);
// Some instructions may jump to a different instruction or a completely different function
@ -295,50 +349,17 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo
build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]);
build.jmp(qword[rax * 2 + rcx]);
}
- else if (opinfo.flags & kFallbackUpdateCi)
- {
- // Need to update state of the current function before we jump away
- build.mov(rcx, qword[rState + offsetof(lua_State, ci)]); // L->ci
- build.mov(rcx, qword[rcx + offsetof(CallInfo, func)]); // L->ci->func
- build.mov(rcx, qword[rcx + offsetof(TValue, value.gc)]); // L->ci->func->value.gc aka cl
- build.mov(sClosure, rcx);
- build.mov(rsi, qword[rcx + offsetof(Closure, l.p)]); // cl->l.p aka proto
- build.mov(rConstants, qword[rsi + offsetof(Proto, k)]); // proto->k
- build.mov(rcx, qword[rsi + offsetof(Proto, code)]); // proto->code
- build.mov(sCode, rcx);
- // We'll need original instruction pointer later to handle return to interpreter
- if (op == LOP_CALL)
- build.mov(r9, rax);
- // Get instruction index from instruction pointer
- // To get instruction index from instruction pointer, we need to divide byte offset by 4
- // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out
- build.sub(rax, sCode);
- // We need to check if the new function can be executed natively
- Label returnToInterpreter;
- build.mov(rdx, qword[rsi + offsetofProtoExecData]);
- build.test(rdx, rdx);
- build.jcc(Condition::Zero, returnToInterpreter);
- // Get new instruction location and jump to it
- build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]);
- build.jmp(qword[rax * 2 + rcx]);
- build.setLabel(returnToInterpreter);
- // If we are returning to the interpreter to make a call, we need to update the current instruction
- if (op == LOP_CALL)
- {
- build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
- build.mov(qword[rax + offsetof(CallInfo, savedpc)], r9);
- }
- // Continue in the interpreter
- emitExit(build, /* continueInVm */ true);
- }
- }
+ }
+ void emitContinueCallInVm(AssemblyBuilderX64& build)
+ {
+ RegisterX64 proto = rcx; // Sync with emitInstCall
+ build.mov(rdx, qword[proto + offsetof(Proto, code)]);
+ build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
+ build.mov(qword[rax + offsetof(CallInfo, savedpc)], rdx);
+ emitExit(build, /* continueInVm */ true);
+ }
} // namespace CodeGen

View file

@ -35,11 +35,11 @@ constexpr RegisterX64 rConstants = r12; // TValue* k
constexpr OperandX64 sClosure = qword[rbp + 0]; // Closure* cl
constexpr OperandX64 sCode = qword[rbp + 8]; // Instruction* code
+ // TODO: These should be replaced with a portable call function that checks the ABI at runtime and reorders moves accordingly to avoid conflicts
#if defined(_WIN32)
constexpr RegisterX64 rArg1 = rcx;
constexpr RegisterX64 rArg2 = rdx;
- constexpr RegisterX64 rArg2d = edx;
constexpr RegisterX64 rArg3 = r8;
constexpr RegisterX64 rArg4 = r9;
constexpr RegisterX64 rArg5 = noreg;
@ -51,7 +51,6 @@ constexpr OperandX64 sArg6 = qword[rsp + 40];
constexpr RegisterX64 rArg1 = rdi;
constexpr RegisterX64 rArg2 = rsi;
- constexpr RegisterX64 rArg2d = esi;
constexpr RegisterX64 rArg3 = rdx;
constexpr RegisterX64 rArg4 = rcx;
constexpr RegisterX64 rArg5 = r8;
@ -62,12 +61,30 @@ constexpr OperandX64 sArg6 = noreg;
#endif
constexpr unsigned kTValueSizeLog2 = 4;
constexpr unsigned kLuaNodeSizeLog2 = 5;
constexpr unsigned kLuaNodeTagMask = 0xf;
constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field
constexpr unsigned kOffsetOfInstructionC = 3;
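// Hedged sketch (not part of the patch): the key tag word read at kOffsetOfLuaNodeTag also stores the
// node's chain link in its upper bits, so any comparison must mask with kLuaNodeTagMask first.
inline unsigned cleanLuaNodeKeyTag(unsigned rawKeyTagWord)
{
    return rawKeyTagWord & kLuaNodeTagMask; // keep only the 4-bit type tag, drop the "dirty" upper bits
}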
// Leaf functions that are placed in every module to perform common instruction sequences
struct ModuleHelpers
{
Label exitContinueVm;
Label exitNoContinueVm;
Label continueCallInVm;
};
inline OperandX64 luauReg(int ri) inline OperandX64 luauReg(int ri)
{ {
return xmmword[rBase + ri * sizeof(TValue)]; return xmmword[rBase + ri * sizeof(TValue)];
} }
inline OperandX64 luauRegAddress(int ri)
{
return addr[rBase + ri * sizeof(TValue)];
}
inline OperandX64 luauRegValue(int ri) inline OperandX64 luauRegValue(int ri)
{ {
return qword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)]; return qword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)];
@ -88,11 +105,37 @@ inline OperandX64 luauConstant(int ki)
return xmmword[rConstants + ki * sizeof(TValue)]; return xmmword[rConstants + ki * sizeof(TValue)];
} }
inline OperandX64 luauConstantAddress(int ki)
{
return addr[rConstants + ki * sizeof(TValue)];
}
inline OperandX64 luauConstantTag(int ki)
{
return dword[rConstants + ki * sizeof(TValue) + offsetof(TValue, tt)];
}
inline OperandX64 luauConstantValue(int ki) inline OperandX64 luauConstantValue(int ki)
{ {
return qword[rConstants + ki * sizeof(TValue) + offsetof(TValue, value)]; return qword[rConstants + ki * sizeof(TValue) + offsetof(TValue, value)];
} }
inline OperandX64 luauNodeKeyValue(RegisterX64 node)
{
return qword[node + offsetof(LuaNode, key) + offsetof(TKey, value)];
}
// Note: tag has dirty upper bits
inline OperandX64 luauNodeKeyTag(RegisterX64 node)
{
return dword[node + offsetof(LuaNode, key) + kOffsetOfLuaNodeTag];
}
inline OperandX64 luauNodeValue(RegisterX64 node)
{
return xmmword[node + offsetof(LuaNode, val)];
}
inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, OperandX64 op) inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, OperandX64 op)
{ {
LUAU_ASSERT(op.cat == CategoryX64::mem); LUAU_ASSERT(op.cat == CategoryX64::mem);
@ -101,16 +144,24 @@ inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, Opera
build.vmovups(luauReg(ri), tmp); build.vmovups(luauReg(ri), tmp);
} }
inline void setNodeValue(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 op, int ri)
{
LUAU_ASSERT(op.cat == CategoryX64::mem);
build.vmovups(tmp, luauReg(ri));
build.vmovups(op, tmp);
}
inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label) inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{ {
build.cmp(luauRegTag(ri), tag); build.cmp(luauRegTag(ri), tag);
build.jcc(Condition::Equal, label); build.jcc(ConditionX64::Equal, label);
} }
inline void jumpIfTagIsNot(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label) inline void jumpIfTagIsNot(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{ {
build.cmp(luauRegTag(ri), tag); build.cmp(luauRegTag(ri), tag);
build.jcc(Condition::NotEqual, label); build.jcc(ConditionX64::NotEqual, label);
} }
// Note: fallthrough label should be placed after this condition // Note: fallthrough label should be placed after this condition
@ -120,7 +171,7 @@ inline void jumpIfFalsy(AssemblyBuilderX64& build, int ri, Label& target, Label&
jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, fallthrough); // true if not nil or boolean jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, fallthrough); // true if not nil or boolean
build.cmp(luauRegValueBoolean(ri), 0); build.cmp(luauRegValueBoolean(ri), 0);
build.jcc(Condition::Equal, target); // true if boolean value is 'true' build.jcc(ConditionX64::Equal, target); // true if boolean value is 'true'
} }
// Note: fallthrough label should be placed after this condition // Note: fallthrough label should be placed after this condition
@ -130,13 +181,13 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label
jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, target); // true if not nil or boolean jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, target); // true if not nil or boolean
build.cmp(luauRegValueBoolean(ri), 0); build.cmp(luauRegValueBoolean(ri), 0);
build.jcc(Condition::NotEqual, target); // true if boolean value is 'true' build.jcc(ConditionX64::NotEqual, target); // true if boolean value is 'true'
} }
inline void jumpIfMetatablePresent(AssemblyBuilderX64& build, RegisterX64 table, Label& target) inline void jumpIfMetatablePresent(AssemblyBuilderX64& build, RegisterX64 table, Label& target)
{ {
build.cmp(qword[table + offsetof(Table, metatable)], 0); build.cmp(qword[table + offsetof(Table, metatable)], 0);
build.jcc(Condition::NotEqual, target); build.jcc(ConditionX64::NotEqual, target);
} }
inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& label) inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& label)
@ -144,18 +195,46 @@ inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& l
build.mov(tmp, sClosure); build.mov(tmp, sClosure);
build.mov(tmp, qword[tmp + offsetof(Closure, env)]); build.mov(tmp, qword[tmp + offsetof(Closure, env)]);
build.test(byte[tmp + offsetof(Table, safeenv)], 1); build.test(byte[tmp + offsetof(Table, safeenv)], 1);
build.jcc(Condition::Zero, label); // Not a safe environment build.jcc(ConditionX64::Zero, label); // Not a safe environment
} }
inline void jumpIfTableIsReadOnly(AssemblyBuilderX64& build, RegisterX64 table, Label& label) inline void jumpIfTableIsReadOnly(AssemblyBuilderX64& build, RegisterX64 table, Label& label)
{ {
build.cmp(byte[table + offsetof(Table, readonly)], 0); build.cmp(byte[table + offsetof(Table, readonly)], 0);
build.jcc(Condition::NotEqual, label); build.jcc(ConditionX64::NotEqual, label);
} }
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label); inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, lua_Type tag, Label& label)
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos); {
tmp.size = SizeX64::dword;
build.mov(tmp, luauNodeKeyTag(node));
build.and_(tmp, kLuaNodeTagMask);
build.cmp(tmp, tag);
build.jcc(ConditionX64::NotEqual, label);
}
inline void jumpIfNodeValueTagIs(AssemblyBuilderX64& build, RegisterX64 node, lua_Type tag, Label& label)
{
build.cmp(dword[node + offsetof(LuaNode, val) + offsetof(TValue, tt)], tag);
build.jcc(ConditionX64::Equal, label);
}
inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, Label& label)
{
jumpIfNodeKeyTagIsNot(build, tmp, node, LUA_TSTRING, label);
build.mov(tmp, expectedKey);
build.cmp(tmp, luauNodeKeyValue(node));
build.jcc(ConditionX64::NotEqual, label);
jumpIfNodeValueTagIs(build, node, LUA_TNIL, label);
}
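// Hypothetical usage sketch (not part of the patch; xmm0 and the label name are placeholders): a
// cached-slot fast path can combine the check above with a plain 16-byte copy of the node value.
inline void sketchCachedSlotLoad(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, int ra, Label& fallback)
{
    jumpIfNodeKeyNotInExpectedSlot(build, tmp, node, expectedKey, fallback); // wrong key or nil slot value -> fallback
    build.vmovups(xmm0, luauNodeValue(node));                                // load the cached slot TValue
    build.vmovups(luauReg(ra), xmm0);                                        // store it into the destination register
}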
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label);
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos);
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos);
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label);
void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm); void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm);
@ -164,6 +243,9 @@ void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, i
void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos); void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos);
void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos); void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos);
void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip); void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip);
void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip);
void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip);
void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip);
void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitExit(AssemblyBuilderX64& build, bool continueInVm);
void emitUpdateBase(AssemblyBuilderX64& build); void emitUpdateBase(AssemblyBuilderX64& build);
@ -171,5 +253,8 @@ void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos); // Note: only uses ra
void emitInterrupt(AssemblyBuilderX64& build, int pcpos); void emitInterrupt(AssemblyBuilderX64& build, int pcpos);
void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos);
void emitContinueCallInVm(AssemblyBuilderX64& build);
void emitExitFromLastReturn(AssemblyBuilderX64& build);
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau
File diff suppressed because it is too large

View file
@ -3,6 +3,8 @@
#include <stdint.h> #include <stdint.h>
#include "ltm.h"
typedef uint32_t Instruction; typedef uint32_t Instruction;
typedef struct lua_TValue TValue; typedef struct lua_TValue TValue;
@ -12,55 +14,76 @@ namespace CodeGen
{ {
class AssemblyBuilderX64; class AssemblyBuilderX64;
enum class Condition; enum class ConditionX64 : uint8_t;
struct Label; struct Label;
struct NativeState; struct ModuleHelpers;
void emitInstLoadNil(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc); void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadB(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr); void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstLoadN(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc); void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadK(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k); void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMove(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc); void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstJump(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr); void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstJumpBack(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr); void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIf(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, bool not_); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIfEq(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, bool not_); void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIfCond(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, Condition cond); void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpX(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpxEqNil(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback);
void emitInstJumpxEqB(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpxEqN(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback);
void emitInstJumpxEqS(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond);
void emitInstAdd(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstSub(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstMul(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstDiv(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstMod(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstPow(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback);
void emitInstAddK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos); void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstSubK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos); void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback);
void emitInstMulK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos); void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstDivK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos); void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label& fallback);
void emitInstModK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc); void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc); void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc); void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstGetImportFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau
File diff suppressed because it is too large

View file
@ -10,84 +10,15 @@ typedef uint32_t Instruction;
typedef struct lua_TValue TValue; typedef struct lua_TValue TValue;
typedef TValue* StkId; typedef TValue* StkId;
const Instruction* execute_LOP_NOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADNIL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MOVE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETUPVAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETUPVAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CLOSEUPVALS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETIMPORT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLEN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLEN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_RETURN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIF(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFEQ(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTEQ(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFLT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTLT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ADD(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SUB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MUL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DIV(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MOD(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_POW(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ADDK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SUBK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MULK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DIVK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MODK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_POWK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_AND(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_OR(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ANDK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ORK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CONCAT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NOT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MINUS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LENGTH(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NEWTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DUPTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORNPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORNLOOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGLOOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP_INEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_FORGLOOP_INEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP_NEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_FORGLOOP_NEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPBACK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADKX(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPX(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CAPTURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_JUMPIFEQK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_JUMPIFNOTEQK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL1(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL2(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL2K(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKNIL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
View file
@ -3,46 +3,20 @@
#include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilder.h"
#include "CodeGenUtils.h"
#include "CustomExecUtils.h" #include "CustomExecUtils.h"
#include "Fallbacks.h" #include "Fallbacks.h"
#include "lbuiltins.h" #include "lbuiltins.h"
#include "lgc.h" #include "lgc.h"
#include "ltable.h" #include "ltable.h"
#include "lfunc.h"
#include "lvm.h" #include "lvm.h"
#include <math.h> #include <math.h>
#include <string.h>
#define CODEGEN_SET_FALLBACK(op, flags) data.context.fallback[op] = {execute_##op, flags} #define CODEGEN_SET_FALLBACK(op, flags) data.context.fallback[op] = {execute_##op, flags}
#define CODEGEN_SET_NAME(op) data.names[op] = #op
// Similar to a dispatch table in lvmexecute.cpp
#define CODEGEN_SET_NAMES() \
CODEGEN_SET_NAME(LOP_NOP), CODEGEN_SET_NAME(LOP_BREAK), CODEGEN_SET_NAME(LOP_LOADNIL), CODEGEN_SET_NAME(LOP_LOADB), CODEGEN_SET_NAME(LOP_LOADN), \
CODEGEN_SET_NAME(LOP_LOADK), CODEGEN_SET_NAME(LOP_MOVE), CODEGEN_SET_NAME(LOP_GETGLOBAL), CODEGEN_SET_NAME(LOP_SETGLOBAL), \
CODEGEN_SET_NAME(LOP_GETUPVAL), CODEGEN_SET_NAME(LOP_SETUPVAL), CODEGEN_SET_NAME(LOP_CLOSEUPVALS), CODEGEN_SET_NAME(LOP_GETIMPORT), \
CODEGEN_SET_NAME(LOP_GETTABLE), CODEGEN_SET_NAME(LOP_SETTABLE), CODEGEN_SET_NAME(LOP_GETTABLEKS), CODEGEN_SET_NAME(LOP_SETTABLEKS), \
CODEGEN_SET_NAME(LOP_GETTABLEN), CODEGEN_SET_NAME(LOP_SETTABLEN), CODEGEN_SET_NAME(LOP_NEWCLOSURE), CODEGEN_SET_NAME(LOP_NAMECALL), \
CODEGEN_SET_NAME(LOP_CALL), CODEGEN_SET_NAME(LOP_RETURN), CODEGEN_SET_NAME(LOP_JUMP), CODEGEN_SET_NAME(LOP_JUMPBACK), \
CODEGEN_SET_NAME(LOP_JUMPIF), CODEGEN_SET_NAME(LOP_JUMPIFNOT), CODEGEN_SET_NAME(LOP_JUMPIFEQ), CODEGEN_SET_NAME(LOP_JUMPIFLE), \
CODEGEN_SET_NAME(LOP_JUMPIFLT), CODEGEN_SET_NAME(LOP_JUMPIFNOTEQ), CODEGEN_SET_NAME(LOP_JUMPIFNOTLE), CODEGEN_SET_NAME(LOP_JUMPIFNOTLT), \
CODEGEN_SET_NAME(LOP_ADD), CODEGEN_SET_NAME(LOP_SUB), CODEGEN_SET_NAME(LOP_MUL), CODEGEN_SET_NAME(LOP_DIV), CODEGEN_SET_NAME(LOP_MOD), \
CODEGEN_SET_NAME(LOP_POW), CODEGEN_SET_NAME(LOP_ADDK), CODEGEN_SET_NAME(LOP_SUBK), CODEGEN_SET_NAME(LOP_MULK), CODEGEN_SET_NAME(LOP_DIVK), \
CODEGEN_SET_NAME(LOP_MODK), CODEGEN_SET_NAME(LOP_POWK), CODEGEN_SET_NAME(LOP_AND), CODEGEN_SET_NAME(LOP_OR), CODEGEN_SET_NAME(LOP_ANDK), \
CODEGEN_SET_NAME(LOP_ORK), CODEGEN_SET_NAME(LOP_CONCAT), CODEGEN_SET_NAME(LOP_NOT), CODEGEN_SET_NAME(LOP_MINUS), \
CODEGEN_SET_NAME(LOP_LENGTH), CODEGEN_SET_NAME(LOP_NEWTABLE), CODEGEN_SET_NAME(LOP_DUPTABLE), CODEGEN_SET_NAME(LOP_SETLIST), \
CODEGEN_SET_NAME(LOP_FORNPREP), CODEGEN_SET_NAME(LOP_FORNLOOP), CODEGEN_SET_NAME(LOP_FORGLOOP), CODEGEN_SET_NAME(LOP_FORGPREP_INEXT), \
CODEGEN_SET_NAME(LOP_DEP_FORGLOOP_INEXT), CODEGEN_SET_NAME(LOP_FORGPREP_NEXT), CODEGEN_SET_NAME(LOP_DEP_FORGLOOP_NEXT), \
CODEGEN_SET_NAME(LOP_GETVARARGS), CODEGEN_SET_NAME(LOP_DUPCLOSURE), CODEGEN_SET_NAME(LOP_PREPVARARGS), CODEGEN_SET_NAME(LOP_LOADKX), \
CODEGEN_SET_NAME(LOP_JUMPX), CODEGEN_SET_NAME(LOP_FASTCALL), CODEGEN_SET_NAME(LOP_COVERAGE), CODEGEN_SET_NAME(LOP_CAPTURE), \
CODEGEN_SET_NAME(LOP_DEP_JUMPIFEQK), CODEGEN_SET_NAME(LOP_DEP_JUMPIFNOTEQK), CODEGEN_SET_NAME(LOP_FASTCALL1), \
CODEGEN_SET_NAME(LOP_FASTCALL2), CODEGEN_SET_NAME(LOP_FASTCALL2K), CODEGEN_SET_NAME(LOP_FORGPREP), CODEGEN_SET_NAME(LOP_JUMPXEQKNIL), \
CODEGEN_SET_NAME(LOP_JUMPXEQKB), CODEGEN_SET_NAME(LOP_JUMPXEQKN), CODEGEN_SET_NAME(LOP_JUMPXEQKS)
static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
return -1;
}
namespace Luau namespace Luau
{ {
@ -61,41 +35,27 @@ NativeState::~NativeState() = default;
void initFallbackTable(NativeState& data) void initFallbackTable(NativeState& data)
{ {
CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0); // When fallback is completely removed, remove it from includeInsts list in lvmexecute_split.py
CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0); CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_SETUPVAL, 0);
CODEGEN_SET_FALLBACK(LOP_CLOSEUPVALS, 0);
CODEGEN_SET_FALLBACK(LOP_GETIMPORT, 0);
CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0);
CODEGEN_SET_FALLBACK(LOP_SETTABLEKS, 0);
CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0); CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0);
CODEGEN_SET_FALLBACK(LOP_CALL, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_RETURN, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_CONCAT, 0);
CODEGEN_SET_FALLBACK(LOP_NEWTABLE, 0);
CODEGEN_SET_FALLBACK(LOP_DUPTABLE, 0);
CODEGEN_SET_FALLBACK(LOP_SETLIST, kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc); CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_FORGLOOP, kFallbackUpdatePc | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP_INEXT, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_FORGPREP_NEXT, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0); CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0);
CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0); CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_PREPVARARGS, 0); CODEGEN_SET_FALLBACK(LOP_PREPVARARGS, 0);
CODEGEN_SET_FALLBACK(LOP_LOADKX, 0);
CODEGEN_SET_FALLBACK(LOP_COVERAGE, 0); CODEGEN_SET_FALLBACK(LOP_COVERAGE, 0);
CODEGEN_SET_FALLBACK(LOP_BREAK, 0); CODEGEN_SET_FALLBACK(LOP_BREAK, 0);
// Fallbacks that are called from partial implementation of an instruction
CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0);
CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0);
CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0);
CODEGEN_SET_FALLBACK(LOP_SETTABLEKS, 0);
} }
void initHelperFunctions(NativeState& data) void initHelperFunctions(NativeState& data)
{ {
static_assert(sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0]) == sizeof(luauF_table) / sizeof(luauF_table[0]), static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length");
"fast call tables are not of the same length"); memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table));
// Replace missing fast call functions with an empty placeholder that forces LOP_CALL fallback
for (size_t i = 0; i < sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0]); i++)
data.context.luauF_table[i] = luauF_table[i] ? luauF_table[i] : luauF_missing;
data.context.luaV_lessthan = luaV_lessthan; data.context.luaV_lessthan = luaV_lessthan;
data.context.luaV_lessequal = luaV_lessequal; data.context.luaV_lessequal = luaV_lessequal;
@ -105,17 +65,28 @@ void initHelperFunctions(NativeState& data)
data.context.luaV_prepareFORN = luaV_prepareFORN; data.context.luaV_prepareFORN = luaV_prepareFORN;
data.context.luaV_gettable = luaV_gettable; data.context.luaV_gettable = luaV_gettable;
data.context.luaV_settable = luaV_settable; data.context.luaV_settable = luaV_settable;
data.context.luaV_getimport = luaV_getimport;
data.context.luaV_concat = luaV_concat;
data.context.luaH_getn = luaH_getn; data.context.luaH_getn = luaH_getn;
data.context.luaH_new = luaH_new;
data.context.luaH_clone = luaH_clone;
data.context.luaH_resizearray = luaH_resizearray;
data.context.luaC_barriertable = luaC_barriertable; data.context.luaC_barriertable = luaC_barriertable;
data.context.luaC_barrierf = luaC_barrierf;
data.context.luaC_barrierback = luaC_barrierback;
data.context.luaC_step = luaC_step;
data.context.luaF_close = luaF_close;
data.context.libm_pow = pow; data.context.libm_pow = pow;
}
void initInstructionNames(NativeState& data) data.context.forgLoopNodeIter = forgLoopNodeIter;
{ data.context.forgLoopNonTableFallback = forgLoopNonTableFallback;
CODEGEN_SET_NAMES(); data.context.forgPrepXnextFallback = forgPrepXnextFallback;
data.context.callProlog = callProlog;
data.context.callEpilogC = callEpilogC;
} }
} // namespace CodeGen } // namespace CodeGen
View file
@ -3,13 +3,16 @@
#include "Luau/Bytecode.h" #include "Luau/Bytecode.h"
#include "Luau/CodeAllocator.h" #include "Luau/CodeAllocator.h"
#include "Luau/Label.h"
#include <memory> #include <memory>
#include <stdint.h> #include <stdint.h>
#include "ldebug.h"
#include "lobject.h" #include "lobject.h"
#include "ltm.h" #include "ltm.h"
#include "lstate.h"
typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams); typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams);
@ -23,8 +26,6 @@ class UnwindBuilder;
using FallbackFn = const Instruction*(lua_State* L, const Instruction* pc, StkId base, TValue* k); using FallbackFn = const Instruction*(lua_State* L, const Instruction* pc, StkId base, TValue* k);
constexpr uint8_t kFallbackUpdatePc = 1 << 0; constexpr uint8_t kFallbackUpdatePc = 1 << 0;
constexpr uint8_t kFallbackUpdateCi = 1 << 1;
constexpr uint8_t kFallbackCheckInterrupt = 1 << 2;
struct NativeFallback struct NativeFallback
{ {
@ -34,7 +35,8 @@ struct NativeFallback
struct NativeProto struct NativeProto
{ {
uintptr_t* instTargets = nullptr; uintptr_t entryTarget = 0;
uintptr_t* instTargets = nullptr; // TODO: NativeProto should be variable-size with all target embedded
Proto* proto = nullptr; Proto* proto = nullptr;
uint32_t location = 0; uint32_t location = 0;
@ -61,12 +63,29 @@ struct NativeContext
void (*luaV_prepareFORN)(lua_State* L, StkId plimit, StkId pstep, StkId pinit) = nullptr; void (*luaV_prepareFORN)(lua_State* L, StkId plimit, StkId pstep, StkId pinit) = nullptr;
void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr;
void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr;
void (*luaV_getimport)(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) = nullptr;
void (*luaV_concat)(lua_State* L, int total, int last) = nullptr;
int (*luaH_getn)(Table* t) = nullptr; int (*luaH_getn)(Table* t) = nullptr;
Table* (*luaH_new)(lua_State* L, int narray, int lnhash) = nullptr;
Table* (*luaH_clone)(lua_State* L, Table* tt) = nullptr;
void (*luaH_resizearray)(lua_State* L, Table* t, int nasize) = nullptr;
void (*luaC_barriertable)(lua_State* L, Table* t, GCObject* v) = nullptr; void (*luaC_barriertable)(lua_State* L, Table* t, GCObject* v) = nullptr;
void (*luaC_barrierf)(lua_State* L, GCObject* o, GCObject* v) = nullptr;
void (*luaC_barrierback)(lua_State* L, GCObject* o, GCObject** gclist) = nullptr;
size_t (*luaC_step)(lua_State* L, bool assist) = nullptr;
void (*luaF_close)(lua_State* L, StkId level) = nullptr;
double (*libm_pow)(double, double) = nullptr; double (*libm_pow)(double, double) = nullptr;
// Helper functions
bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr;
bool (*forgLoopNonTableFallback)(lua_State* L, int insnA, int aux) = nullptr;
void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr;
Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr;
void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr;
}; };
struct NativeState struct NativeState
@ -77,9 +96,6 @@ struct NativeState
CodeAllocator codeAllocator; CodeAllocator codeAllocator;
std::unique_ptr<UnwindBuilder> unwindBuilder; std::unique_ptr<UnwindBuilder> unwindBuilder;
// For annotations in assembly text generation
const char* names[LOP__COUNT] = {};
uint8_t* gateData = nullptr; uint8_t* gateData = nullptr;
size_t gateDataSize = 0; size_t gateDataSize = 0;
@ -88,7 +104,6 @@ struct NativeState
void initFallbackTable(NativeState& data); void initFallbackTable(NativeState& data);
void initHelperFunctions(NativeState& data); void initHelperFunctions(NativeState& data);
void initInstructionNames(NativeState& data);
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau
View file
@ -20,6 +20,10 @@
#define LUAU_DEBUGBREAK() __builtin_trap() #define LUAU_DEBUGBREAK() __builtin_trap()
#endif #endif
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define LUAU_BIG_ENDIAN
#endif
namespace Luau namespace Luau
{ {
View file
@ -117,6 +117,8 @@ public:
std::string dumpEverything() const; std::string dumpEverything() const;
std::string dumpSourceRemarks() const; std::string dumpSourceRemarks() const;
void annotateInstruction(std::string& result, uint32_t fid, uint32_t instpos) const;
static uint32_t getImportId(int32_t id0); static uint32_t getImportId(int32_t id0);
static uint32_t getImportId(int32_t id0, int32_t id1); static uint32_t getImportId(int32_t id0, int32_t id1);
static uint32_t getImportId(int32_t id0, int32_t id1, int32_t id2); static uint32_t getImportId(int32_t id0, int32_t id1, int32_t id2);
@ -179,6 +181,7 @@ private:
std::string dump; std::string dump;
std::string dumpname; std::string dumpname;
std::vector<int> dumpinstoffs;
}; };
struct DebugLocal struct DebugLocal
@ -251,11 +254,13 @@ private:
std::vector<std::string> dumpSource; std::vector<std::string> dumpSource;
std::vector<std::pair<int, std::string>> dumpRemarks; std::vector<std::pair<int, std::string>> dumpRemarks;
std::string (BytecodeBuilder::*dumpFunctionPtr)() const = nullptr; std::string (BytecodeBuilder::*dumpFunctionPtr)(std::vector<int>&) const = nullptr;
void validate() const; void validate() const;
void validateInstructions() const;
void validateVariadic() const;
std::string dumpCurrentFunction() const; std::string dumpCurrentFunction(std::vector<int>& dumpinstoffs) const;
void dumpInstruction(const uint32_t* opcode, std::string& output, int targetLabel) const; void dumpInstruction(const uint32_t* opcode, std::string& output, int targetLabel) const;
void writeFunction(std::string& ss, uint32_t id) const; void writeFunction(std::string& ss, uint32_t id) const;
View file
@ -4,8 +4,6 @@
#include "Luau/Bytecode.h" #include "Luau/Bytecode.h"
#include "Luau/Compiler.h" #include "Luau/Compiler.h"
LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinMT, false)
namespace Luau namespace Luau
{ {
namespace Compile namespace Compile
@ -66,13 +64,10 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op
if (builtin.isGlobal("select")) if (builtin.isGlobal("select"))
return LBF_SELECT_VARARG; return LBF_SELECT_VARARG;
if (FFlag::LuauCompileBuiltinMT) if (builtin.isGlobal("getmetatable"))
{ return LBF_GETMETATABLE;
if (builtin.isGlobal("getmetatable")) if (builtin.isGlobal("setmetatable"))
return LBF_GETMETATABLE; return LBF_SETMETATABLE;
if (builtin.isGlobal("setmetatable"))
return LBF_SETMETATABLE;
}
if (builtin.object == "math") if (builtin.object == "math")
{ {

View file

@ -269,7 +269,7 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues)
// this call is indirect to make sure we only gain link time dependency on dumpCurrentFunction when needed // this call is indirect to make sure we only gain link time dependency on dumpCurrentFunction when needed
if (dumpFunctionPtr) if (dumpFunctionPtr)
func.dump = (this->*dumpFunctionPtr)(); func.dump = (this->*dumpFunctionPtr)(func.dumpinstoffs);
insns.clear(); insns.clear();
lines.clear(); lines.clear();
@ -1078,6 +1078,12 @@ uint8_t BytecodeBuilder::getVersion()
#ifdef LUAU_ASSERTENABLED #ifdef LUAU_ASSERTENABLED
void BytecodeBuilder::validate() const void BytecodeBuilder::validate() const
{
validateInstructions();
validateVariadic();
}
void BytecodeBuilder::validateInstructions() const
{ {
#define VREG(v) LUAU_ASSERT(unsigned(v) < func.maxstacksize) #define VREG(v) LUAU_ASSERT(unsigned(v) < func.maxstacksize)
#define VREGRANGE(v, count) LUAU_ASSERT(unsigned(v + (count < 0 ? 0 : count)) <= func.maxstacksize) #define VREGRANGE(v, count) LUAU_ASSERT(unsigned(v + (count < 0 ? 0 : count)) <= func.maxstacksize)
@ -1090,26 +1096,27 @@ void BytecodeBuilder::validate() const
const Function& func = functions[currentFunction]; const Function& func = functions[currentFunction];
// first pass: tag instruction offsets so that we can validate jumps // tag instruction offsets so that we can validate jumps
std::vector<uint8_t> insnvalid(insns.size(), false); std::vector<uint8_t> insnvalid(insns.size(), 0);
for (size_t i = 0; i < insns.size();) for (size_t i = 0; i < insns.size();)
{ {
uint8_t op = LUAU_INSN_OP(insns[i]); uint32_t insn = insns[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
insnvalid[i] = true; insnvalid[i] = true;
i += getOpLength(LuauOpcode(op)); i += getOpLength(op);
LUAU_ASSERT(i <= insns.size()); LUAU_ASSERT(i <= insns.size());
} }
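// For example (illustrative): a two-word instruction like LOP_GETIMPORT only marks its first word in
// insnvalid, so a jump that targets the trailing AUX word is rejected by the jump checks below.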
std::vector<uint8_t> openCaptures; std::vector<uint8_t> openCaptures;
// second pass: validate the rest of the bytecode // validate individual instructions
for (size_t i = 0; i < insns.size();) for (size_t i = 0; i < insns.size();)
{ {
uint32_t insn = insns[i]; uint32_t insn = insns[i];
uint8_t op = LUAU_INSN_OP(insn); LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
switch (op) switch (op)
{ {
@ -1452,7 +1459,7 @@ void BytecodeBuilder::validate() const
LUAU_ASSERT(!"Unsupported opcode"); LUAU_ASSERT(!"Unsupported opcode");
} }
i += getOpLength(LuauOpcode(op)); i += getOpLength(op);
LUAU_ASSERT(i <= insns.size()); LUAU_ASSERT(i <= insns.size());
} }
@ -1469,6 +1476,126 @@ void BytecodeBuilder::validate() const
#undef VCONSTANY #undef VCONSTANY
#undef VJUMP #undef VJUMP
} }
void BytecodeBuilder::validateVariadic() const
{
// validate MULTRET sequences: instructions that produce a variadic sequence and consume one must come in pairs
// we classify instructions into four groups: producers, consumers, neutral and others
// any producer (an instruction that produces more than one value) must be followed by 0 or more neutral instructions
// and a consumer (that consumes more than one value); these form a variadic sequence.
// except for producer, no instruction in the variadic sequence may be a jump target.
// from the execution perspective, producer adjusts L->top to point to one past the last result, neutral instructions
// leave L->top unmodified, and consumer adjusts L->top back to the stack frame end.
// consumers invalidate all values after L->top after they execute (which we currently don't validate)
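// illustrative example (not from the source): in the bytecode for `return f(...)`,
// GETVARARGS with B == 0 produces a variadic sequence, CALL with B == 0 consumes it and (with C == 0)
// produces a new one, and RETURN with B == 0 consumes that final sequence, giving two back-to-back
// producer/consumer pairs of the kind validated below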
bool variadicSeq = false;
std::vector<uint8_t> insntargets(insns.size(), 0);
for (size_t i = 0; i < insns.size();)
{
uint32_t insn = insns[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
int target = getJumpTarget(insn, uint32_t(i));
if (target >= 0 && !isFastCall(op))
{
LUAU_ASSERT(unsigned(target) < insns.size());
insntargets[target] = true;
}
i += getOpLength(op);
LUAU_ASSERT(i <= insns.size());
}
for (size_t i = 0; i < insns.size();)
{
uint32_t insn = insns[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
if (variadicSeq)
{
// no instruction inside the sequence, including the consumer, may be a jump target
// this guarantees uninterrupted L->top adjustment flow
LUAU_ASSERT(!insntargets[i]);
}
if (op == LOP_CALL)
{
// note: calls may end one variadic sequence and start a new one
if (LUAU_INSN_B(insn) == 0)
{
// consumer instruction ends a variadic sequence
LUAU_ASSERT(variadicSeq);
variadicSeq = false;
}
else
{
// CALL is not a neutral instruction so it can't be present in a variadic sequence unless it's a consumer
LUAU_ASSERT(!variadicSeq);
}
if (LUAU_INSN_C(insn) == 0)
{
// producer instruction starts a variadic sequence
LUAU_ASSERT(!variadicSeq);
variadicSeq = true;
}
}
else if (op == LOP_GETVARARGS && LUAU_INSN_B(insn) == 0)
{
// producer instruction starts a variadic sequence
LUAU_ASSERT(!variadicSeq);
variadicSeq = true;
}
else if ((op == LOP_RETURN && LUAU_INSN_B(insn) == 0) || (op == LOP_SETLIST && LUAU_INSN_C(insn) == 0))
{
// consumer instruction ends a variadic sequence
LUAU_ASSERT(variadicSeq);
variadicSeq = false;
}
else if (op == LOP_FASTCALL)
{
int callTarget = int(i + LUAU_INSN_C(insn) + 1);
LUAU_ASSERT(unsigned(callTarget) < insns.size() && LUAU_INSN_OP(insns[callTarget]) == LOP_CALL);
if (LUAU_INSN_B(insns[callTarget]) == 0)
{
// consumer instruction ends a variadic sequence; however, we can't terminate it yet because future analysis of CALL will do it
// during FASTCALL fallback, the instructions between this and the CALL consumer are going to be executed before the CALL adjusts L->top back, so they must
// be neutral; as such, we will defer termination of variadic sequence until CALL analysis
LUAU_ASSERT(variadicSeq);
}
else
{
// FASTCALL is not a neutral instruction so it can't be present in a variadic sequence unless it's linked to CALL consumer
LUAU_ASSERT(!variadicSeq);
}
// note: if FASTCALL is linked to a CALL producer, the instructions between FASTCALL and CALL are technically not part of an executed
// variadic sequence since they are never executed if FASTCALL does anything, so it's okay to skip their validation until CALL
// (we can't simply start a variadic sequence here because that would trigger assertions during linked CALL validation)
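// illustrative shape of such a pairing (operands symbolic and simplified): FASTCALL skips over a
// GETIMPORT/MOVE filler to its linked CALL, and the filler only runs on the fallback path:
//   FASTCALL bfid Cjump    ; Cjump encodes the offset to the linked CALL below
//   GETIMPORT Rf Kimport   ; neutral filler, executed only when the fast path bails out
//   CALL Rf nargs nres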
}
else if (op == LOP_CLOSEUPVALS || op == LOP_NAMECALL || op == LOP_GETIMPORT || op == LOP_MOVE || op == LOP_GETUPVAL || op == LOP_GETGLOBAL ||
op == LOP_GETTABLEKS || op == LOP_COVERAGE)
{
// instructions inside a variadic sequence must be neutral (can't change L->top)
// while there are many neutral instructions like this, here we check that the instruction is one of the few
// that we'd expect to exist in FASTCALL fallback sequences or between consecutive CALLs for encoding reasons
}
else
{
LUAU_ASSERT(!variadicSeq);
}
i += getOpLength(op);
LUAU_ASSERT(i <= insns.size());
}
LUAU_ASSERT(!variadicSeq);
}
#endif #endif
void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, int targetLabel) const void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, int targetLabel) const
@ -1800,7 +1927,7 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result,
} }
} }
std::string BytecodeBuilder::dumpCurrentFunction() const std::string BytecodeBuilder::dumpCurrentFunction(std::vector<int>& dumpinstoffs) const
{ {
if ((dumpFlags & Dump_Code) == 0) if ((dumpFlags & Dump_Code) == 0)
return std::string(); return std::string();
@ -1850,11 +1977,15 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
if (labels[i] == 0) if (labels[i] == 0)
labels[i] = nextLabel++; labels[i] = nextLabel++;
dumpinstoffs.resize(insns.size() + 1, -1);
for (size_t i = 0; i < insns.size();) for (size_t i = 0; i < insns.size();)
{ {
const uint32_t* code = &insns[i]; const uint32_t* code = &insns[i];
uint8_t op = LUAU_INSN_OP(*code); uint8_t op = LUAU_INSN_OP(*code);
dumpinstoffs[i] = int(result.size());
if (op == LOP_PREPVARARGS) if (op == LOP_PREPVARARGS)
{ {
// Don't emit function header in bytecode - it's used for call dispatching and doesn't contain "interesting" information // Don't emit function header in bytecode - it's used for call dispatching and doesn't contain "interesting" information
@ -1897,6 +2028,8 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
LUAU_ASSERT(i <= insns.size()); LUAU_ASSERT(i <= insns.size());
} }
dumpinstoffs[insns.size()] = int(result.size());
return result; return result;
} }
@ -1986,4 +2119,26 @@ std::string BytecodeBuilder::dumpSourceRemarks() const
return result; return result;
} }
void BytecodeBuilder::annotateInstruction(std::string& result, uint32_t fid, uint32_t instpos) const
{
if ((dumpFlags & Dump_Code) == 0)
return;
LUAU_ASSERT(fid < functions.size());
const Function& function = functions[fid];
const std::string& dump = function.dump;
const std::vector<int>& dumpinstoffs = function.dumpinstoffs;
uint32_t next = instpos + 1;
LUAU_ASSERT(next < dumpinstoffs.size());
// Skip locations of multi-dword instructions
while (next < dumpinstoffs.size() && dumpinstoffs[next] == -1)
next++;
formatAppend(result, "%.*s", dumpinstoffs[next] - dumpinstoffs[instpos], dump.data() + dumpinstoffs[instpos]);
}
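// Hypothetical usage sketch (not part of the patch; names are placeholders): a code generator can call
// this once per bytecode instruction to interleave the textual dump with the assembly it emits, e.g.
//   bcb.annotateInstruction(listing, protoId, instpos); // append dump text for the instruction at instpos
//   ... then append the machine code generated for that instruction ...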
} // namespace Luau } // namespace Luau
View file
@ -75,7 +75,7 @@ endif
# configuration-specific flags # configuration-specific flags
ifeq ($(config),release) ifeq ($(config),release)
CXXFLAGS+=-O2 -DNDEBUG CXXFLAGS+=-O2 -DNDEBUG -fno-math-errno
endif endif
ifeq ($(config),coverage) ifeq ($(config),coverage)
@ -102,7 +102,7 @@ ifeq ($(config),fuzz)
endif endif
ifeq ($(config),profile) ifeq ($(config),profile)
CXXFLAGS+=-O2 -DNDEBUG -gdwarf-4 -DCALLGRIND=1 CXXFLAGS+=-O2 -DNDEBUG -fno-math-errno -gdwarf-4 -DCALLGRIND=1
endif endif
ifeq ($(protobuf),download) ifeq ($(protobuf),download)
View file
@@ -55,22 +55,28 @@ target_sources(Luau.Compiler PRIVATE
# Luau.CodeGen Sources
target_sources(Luau.CodeGen PRIVATE
+    CodeGen/include/Luau/AddressA64.h
+    CodeGen/include/Luau/AssemblyBuilderA64.h
    CodeGen/include/Luau/AssemblyBuilderX64.h
    CodeGen/include/Luau/CodeAllocator.h
    CodeGen/include/Luau/CodeBlockUnwind.h
    CodeGen/include/Luau/CodeGen.h
-    CodeGen/include/Luau/Condition.h
+    CodeGen/include/Luau/ConditionA64.h
+    CodeGen/include/Luau/ConditionX64.h
    CodeGen/include/Luau/Label.h
    CodeGen/include/Luau/OperandX64.h
+    CodeGen/include/Luau/RegisterA64.h
    CodeGen/include/Luau/RegisterX64.h
    CodeGen/include/Luau/UnwindBuilder.h
    CodeGen/include/Luau/UnwindBuilderDwarf2.h
    CodeGen/include/Luau/UnwindBuilderWin.h
+    CodeGen/src/AssemblyBuilderA64.cpp
    CodeGen/src/AssemblyBuilderX64.cpp
    CodeGen/src/CodeAllocator.cpp
    CodeGen/src/CodeBlockUnwind.cpp
    CodeGen/src/CodeGen.cpp
+    CodeGen/src/CodeGenUtils.cpp
    CodeGen/src/CodeGenX64.cpp
    CodeGen/src/EmitBuiltinsX64.cpp
    CodeGen/src/EmitCommonX64.cpp

@@ -82,6 +88,7 @@ target_sources(Luau.CodeGen PRIVATE
    CodeGen/src/ByteUtils.h
    CodeGen/src/CustomExecUtils.h
+    CodeGen/src/CodeGenUtils.h
    CodeGen/src/CodeGenX64.h
    CodeGen/src/EmitBuiltinsX64.h
    CodeGen/src/EmitCommonX64.h

@@ -101,10 +108,13 @@ target_sources(Luau.Analysis PRIVATE
    Analysis/include/Luau/BuiltinDefinitions.h
    Analysis/include/Luau/Clone.h
    Analysis/include/Luau/Config.h
+    Analysis/include/Luau/Connective.h
    Analysis/include/Luau/Constraint.h
    Analysis/include/Luau/ConstraintGraphBuilder.h
    Analysis/include/Luau/ConstraintSolver.h
+    Analysis/include/Luau/DataFlowGraphBuilder.h
    Analysis/include/Luau/DcrLogger.h
+    Analysis/include/Luau/Def.h
    Analysis/include/Luau/Documentation.h
    Analysis/include/Luau/Error.h
    Analysis/include/Luau/FileResolver.h

@@ -114,6 +124,7 @@ target_sources(Luau.Analysis PRIVATE
    Analysis/include/Luau/JsonEmitter.h
    Analysis/include/Luau/Linter.h
    Analysis/include/Luau/LValue.h
+    Analysis/include/Luau/Metamethods.h
    Analysis/include/Luau/Module.h
    Analysis/include/Luau/ModuleResolver.h
    Analysis/include/Luau/Normalize.h

@@ -151,10 +162,13 @@ target_sources(Luau.Analysis PRIVATE
    Analysis/src/BuiltinDefinitions.cpp
    Analysis/src/Clone.cpp
    Analysis/src/Config.cpp
+    Analysis/src/Connective.cpp
    Analysis/src/Constraint.cpp
    Analysis/src/ConstraintGraphBuilder.cpp
    Analysis/src/ConstraintSolver.cpp
+    Analysis/src/DataFlowGraphBuilder.cpp
    Analysis/src/DcrLogger.cpp
+    Analysis/src/Def.cpp
    Analysis/src/EmbeddedBuiltinDefinitions.cpp
    Analysis/src/Error.cpp
    Analysis/src/Frontend.cpp

@@ -292,6 +306,7 @@ if(TARGET Luau.UnitTest)
    tests/AstQueryDsl.cpp
    tests/ConstraintGraphBuilderFixture.cpp
    tests/Fixture.cpp
+    tests/AssemblyBuilderA64.test.cpp
    tests/AssemblyBuilderX64.test.cpp
    tests/AstJsonEncoder.test.cpp
    tests/AstQuery.test.cpp

@@ -301,9 +316,9 @@ if(TARGET Luau.UnitTest)
    tests/CodeAllocator.test.cpp
    tests/Compiler.test.cpp
    tests/Config.test.cpp
-    tests/ConstraintGraphBuilder.test.cpp
    tests/ConstraintSolver.test.cpp
    tests/CostModel.test.cpp
+    tests/DataFlowGraphBuilder.test.cpp
    tests/Error.test.cpp
    tests/Frontend.test.cpp
    tests/JsonEmitter.test.cpp

@@ -334,6 +349,7 @@ if(TARGET Luau.UnitTest)
    tests/TypeInfer.intersectionTypes.test.cpp
    tests/TypeInfer.loops.test.cpp
    tests/TypeInfer.modules.test.cpp
+    tests/TypeInfer.negations.test.cpp
    tests/TypeInfer.oop.test.cpp
    tests/TypeInfer.operators.test.cpp
    tests/TypeInfer.primitives.test.cpp

@@ -372,7 +388,7 @@ if(TARGET Luau.CLI.Test)
    CLI/Profiler.h
    CLI/Profiler.cpp
    CLI/Repl.cpp
    tests/Repl.test.cpp
    tests/main.cpp)
endif()

View file

@@ -1239,7 +1239,12 @@ static int luauF_setmetatable(lua_State* L, StkId res, TValue* arg0, int nresult
    return -1;
}

-luau_FastFunction luauF_table[256] = {
+static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
+{
+    return -1;
+}
+
+const luau_FastFunction luauF_table[256] = {
    NULL,
    luauF_assert,

@@ -1317,4 +1322,20 @@ luau_FastFunction luauF_table[256] = {
    luauF_getmetatable,
    luauF_setmetatable,

+    // When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback.
+    // This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing.
+    // Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest.
+#define MISSING8 luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing
+    MISSING8,
+    MISSING8,
+    MISSING8,
+    MISSING8,
+    MISSING8,
+    MISSING8,
+    MISSING8,
+    MISSING8,
+#undef MISSING8
};
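The luauF_missing entries above are easiest to see with a minimal, self-contained sketch of the dispatch pattern they enable. This is not the real VM loop and every name below is invented for illustration: a lookup that lands on a missing entry returns -1, and the caller treats a negative result as "take the regular call path".

    #include <cstdio>

    typedef int (*FastFunction)(double arg, double* result); // simplified stand-in for luau_FastFunction

    static int fastMissing(double, double*)
    {
        return -1; // analogous to luauF_missing: never handles the call
    }

    static int fastAbs(double arg, double* result)
    {
        *result = arg < 0 ? -arg : arg;
        return 1; // fast path produced a result
    }

    // slot 0 is reserved (like the NULL entry); trailing slots fall back to fastMissing
    static const FastFunction kFastTable[4] = {nullptr, fastAbs, fastMissing, fastMissing};

    double call(int bfid, double arg)
    {
        if (FastFunction f = kFastTable[bfid])
        {
            double result;
            if (f(arg, &result) >= 0)
                return result; // fast path succeeded
        }
        // slow path: in the real VM this would perform an ordinary function call
        return arg < 0 ? -arg : arg;
    }

    int main()
    {
        std::printf("%g %g\n", call(1, -3.5), call(3, -3.5)); // both print 3.5
    }

Because the table is now fully populated and const, a runtime that receives bytecode referencing a builtin id it does not know about still dispatches safely and simply falls back to the slow path.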

View file

@@ -6,4 +6,4 @@
typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams);

-extern luau_FastFunction luauF_table[256];
+extern const luau_FastFunction luauF_table[256];

View file

@@ -12,8 +12,6 @@
#include <string.h>
#include <stdio.h>

-LUAU_FASTFLAGVARIABLE(LuauFasterGetInfo, false)
-
static const char* getfuncname(Closure* f);

static int currentpc(lua_State* L, CallInfo* ci)

@@ -105,8 +103,7 @@ static Closure* auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closur
            ar->source = "=[C]";
            ar->what = "C";
            ar->linedefined = -1;
-            if (FFlag::LuauFasterGetInfo)
-                ar->short_src = "[C]";
+            ar->short_src = "[C]";
        }
        else
        {

@@ -114,13 +111,7 @@ static Closure* auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closur
            ar->source = getstr(source);
            ar->what = "Lua";
            ar->linedefined = f->l.p->linedefined;
-            if (FFlag::LuauFasterGetInfo)
-                ar->short_src = luaO_chunkid(ar->ssbuf, sizeof(ar->ssbuf), getstr(source), source->len);
-        }
-        if (!FFlag::LuauFasterGetInfo)
-        {
-            luaO_chunkid(ar->ssbuf, LUA_IDSIZE, ar->source, 0);
-            ar->short_src = ar->ssbuf;
+            ar->short_src = luaO_chunkid(ar->ssbuf, sizeof(ar->ssbuf), getstr(source), source->len);
        }
        break;
    }

@@ -195,25 +186,12 @@ int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar)
    }

    if (f)
    {
-        if (FFlag::LuauFasterGetInfo)
+        // auxgetinfo fills ar and optionally requests to put closure on stack
+        if (Closure* fcl = auxgetinfo(L, what, ar, f, ci))
        {
-            // auxgetinfo fills ar and optionally requests to put closure on stack
-            if (Closure* fcl = auxgetinfo(L, what, ar, f, ci))
-            {
-                luaC_threadbarrier(L);
-                setclvalue(L, L->top, fcl);
-                incr_top(L);
-            }
-        }
-        else
-        {
-            auxgetinfo(L, what, ar, f, ci);
-            if (strchr(what, 'f'))
-            {
-                luaC_threadbarrier(L);
-                setclvalue(L, L->top, f);
-                incr_top(L);
-            }
+            luaC_threadbarrier(L);
+            setclvalue(L, L->top, fcl);
+            incr_top(L);
        }
    }

    return f ? 1 : 0;

View file

@@ -13,8 +13,6 @@
#include <stdio.h>
#include <stdlib.h>

-LUAU_FASTFLAG(LuauFasterGetInfo)
-
const TValue luaO_nilobject_ = {{NULL}, {0}, LUA_TNIL};

int luaO_log2(unsigned int x)

@@ -121,48 +119,20 @@ const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t sr
{
    if (*source == '=')
    {
-        if (FFlag::LuauFasterGetInfo)
-        {
-            if (srclen <= buflen)
-                return source + 1;
-            // truncate the part after =
-            memcpy(buf, source + 1, buflen - 1);
-            buf[buflen - 1] = '\0';
-        }
-        else
-        {
-            source++; // skip the `='
-            size_t len = strlen(source);
-            size_t dstlen = len < buflen ? len : buflen - 1;
-            memcpy(buf, source, dstlen);
-            buf[dstlen] = '\0';
-        }
+        if (srclen <= buflen)
+            return source + 1;
+        // truncate the part after =
+        memcpy(buf, source + 1, buflen - 1);
+        buf[buflen - 1] = '\0';
    }
    else if (*source == '@')
    {
-        if (FFlag::LuauFasterGetInfo)
-        {
-            if (srclen <= buflen)
-                return source + 1;
-            // truncate the part after @
-            memcpy(buf, "...", 3);
-            memcpy(buf + 3, source + srclen - (buflen - 4), buflen - 4);
-            buf[buflen - 1] = '\0';
-        }
-        else
-        {
-            size_t l;
-            source++; // skip the `@'
-            buflen -= sizeof("...");
-            l = strlen(source);
-            strcpy(buf, "");
-            if (l > buflen)
-            {
-                source += (l - buflen); // get last part of file name
-                strcat(buf, "...");
-            }
-            strcat(buf, source);
-        }
+        if (srclen <= buflen)
+            return source + 1;
+        // truncate the part after @
+        memcpy(buf, "...", 3);
+        memcpy(buf + 3, source + srclen - (buflen - 4), buflen - 4);
+        buf[buflen - 1] = '\0';
    }
    else
    { // buf = [string "string"]
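The standalone snippet below mirrors the two new luaO_chunkid branches shown above, simplified and copied out for illustration (it is not the function itself), so the effect of the "=" and "@" source prefixes on a small buffer is easy to see.

    #include <cstdio>
    #include <cstring>

    static const char* chunkid(char* buf, size_t buflen, const char* source, size_t srclen)
    {
        if (*source == '=')
        {
            if (srclen <= buflen)
                return source + 1;
            // truncate the part after =
            memcpy(buf, source + 1, buflen - 1);
            buf[buflen - 1] = '\0';
        }
        else if (*source == '@')
        {
            if (srclen <= buflen)
                return source + 1;
            // keep the tail of the path, prefixed with "..."
            memcpy(buf, "...", 3);
            memcpy(buf + 3, source + srclen - (buflen - 4), buflen - 4);
            buf[buflen - 1] = '\0';
        }
        else
        {
            // the real function formats a quoted excerpt of the source here
            snprintf(buf, buflen, "[string]");
        }
        return buf;
    }

    int main()
    {
        char buf[16];
        const char* path = "@/very/long/path/to/script.lua";
        printf("%s\n", chunkid(buf, sizeof(buf), path, strlen(path))); // ...o/script.lua
        const char* name = "=stdin";
        printf("%s\n", chunkid(buf, sizeof(buf), name, strlen(name))); // stdin
    }

Note that both fast branches return the source string itself (past the prefix) when it already fits, which is why the new code can avoid copying in the common case.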

View file

@@ -16,8 +16,6 @@
#include <string.h>

-LUAU_FASTFLAGVARIABLE(LuauNoTopRestoreInFastCall, false)
-
// Disable c99-designator to avoid the warning in CGOTO dispatch table
#ifdef __clang__
#if __has_warning("-Wc99-designator")

@@ -758,6 +756,8 @@ reentry:
        Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)];
        LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep));

+       VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM
+
        // note: we save closure to stack early in case the code below wants to capture it by value
        Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv);
        setclvalue(L, ra, ncl);

@@ -2056,6 +2056,8 @@ reentry:
        int b = LUAU_INSN_B(insn);
        uint32_t aux = *pc++;

+       VM_PROTECT_PC(); // luaH_new may fail due to OOM
+
        sethvalue(L, ra, luaH_new(L, aux, b == 0 ? 0 : (1 << (b - 1))));
        VM_PROTECT(luaC_checkGC(L));
        VM_NEXT();

@@ -2067,6 +2069,8 @@ reentry:
        StkId ra = VM_REG(LUAU_INSN_A(insn));
        TValue* kv = VM_KV(LUAU_INSN_D(insn));

+       VM_PROTECT_PC(); // luaH_clone may fail due to OOM
+
        sethvalue(L, ra, luaH_clone(L, hvalue(kv)));
        VM_PROTECT(luaC_checkGC(L));
        VM_NEXT();

@@ -2088,12 +2092,17 @@ reentry:
        Table* h = hvalue(ra);

+       // TODO: we really don't need this anymore
        if (!ttistable(ra))
            return; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode

        int last = index + c - 1;
        if (last > h->sizearray)
+       {
+           VM_PROTECT_PC(); // luaH_resizearray may fail due to OOM
            luaH_resizearray(L, h, last);
+       }

        TValue* array = h->array;

@@ -2185,7 +2194,8 @@ reentry:
        // protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP
        if (ttisnil(ra))
        {
-           VM_PROTECT(luaG_typeerror(L, ra, "call"));
+           VM_PROTECT_PC(); // next call always errors
+           luaG_typeerror(L, ra, "call");
        }
    }
    else if (fasttm(L, mt, TM_CALL))

@@ -2202,7 +2212,8 @@ reentry:
        }
        else
        {
-           VM_PROTECT(luaG_typeerror(L, ra, "iterate over"));
+           VM_PROTECT_PC(); // next call always errors
+           luaG_typeerror(L, ra, "iterate over");
        }
    }

@@ -2325,7 +2336,8 @@ reentry:
        }
        else if (!ttisfunction(ra))
        {
-           VM_PROTECT(luaG_typeerror(L, ra, "iterate over"));
+           VM_PROTECT_PC(); // next call always errors
+           luaG_typeerror(L, ra, "iterate over");
        }

        pc += LUAU_INSN_D(insn);

@@ -2353,7 +2365,8 @@ reentry:
        }
        else if (!ttisfunction(ra))
        {
-           VM_PROTECT(luaG_typeerror(L, ra, "iterate over"));
+           VM_PROTECT_PC(); // next call always errors
+           luaG_typeerror(L, ra, "iterate over");
        }

        pc += LUAU_INSN_D(insn);

@@ -2404,6 +2417,8 @@ reentry:
        Closure* kcl = clvalue(kv);

+       VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM
+
        // clone closure if the environment is not shared
        // note: we save closure to stack early in case the code below wants to capture it by value
        Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p);

@@ -2530,10 +2545,11 @@ reentry:
        nparams = (nparams == LUA_MULTRET) ? int(L->top - ra - 1) : nparams;

        luau_FastFunction f = luauF_table[bfid];
+       LUAU_ASSERT(f);

-       if (cl->env->safeenv && f)
+       if (cl->env->safeenv)
        {
-           VM_PROTECT_PC();
+           VM_PROTECT_PC(); // f may fail due to OOM

            int n = f(L, ra, ra + 1, nresults, ra + 2, nparams);

@@ -2608,24 +2624,18 @@ reentry:
        int nresults = LUAU_INSN_C(call) - 1;

        luau_FastFunction f = luauF_table[bfid];
+       LUAU_ASSERT(f);

-       if (cl->env->safeenv && f)
+       if (cl->env->safeenv)
        {
-           VM_PROTECT_PC();
+           VM_PROTECT_PC(); // f may fail due to OOM

            int n = f(L, ra, arg, nresults, NULL, nparams);

            if (n >= 0)
            {
-               if (FFlag::LuauNoTopRestoreInFastCall)
-               {
-                   if (nresults == LUA_MULTRET)
-                       L->top = ra + n;
-               }
-               else
-               {
-                   L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top;
-               }
+               if (nresults == LUA_MULTRET)
+                   L->top = ra + n;

                pc += skip + 1; // skip instructions that compute function as well as CALL
                LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));

@@ -2664,24 +2674,18 @@ reentry:
        int nresults = LUAU_INSN_C(call) - 1;

        luau_FastFunction f = luauF_table[bfid];
+       LUAU_ASSERT(f);

-       if (cl->env->safeenv && f)
+       if (cl->env->safeenv)
        {
-           VM_PROTECT_PC();
+           VM_PROTECT_PC(); // f may fail due to OOM

            int n = f(L, ra, arg1, nresults, arg2, nparams);

            if (n >= 0)
            {
-               if (FFlag::LuauNoTopRestoreInFastCall)
-               {
-                   if (nresults == LUA_MULTRET)
-                       L->top = ra + n;
-               }
-               else
-               {
-                   L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top;
-               }
+               if (nresults == LUA_MULTRET)
+                   L->top = ra + n;

                pc += skip + 1; // skip instructions that compute function as well as CALL
                LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));

@@ -2720,24 +2724,18 @@ reentry:
        int nresults = LUAU_INSN_C(call) - 1;

        luau_FastFunction f = luauF_table[bfid];
+       LUAU_ASSERT(f);

-       if (cl->env->safeenv && f)
+       if (cl->env->safeenv)
        {
-           VM_PROTECT_PC();
+           VM_PROTECT_PC(); // f may fail due to OOM

            int n = f(L, ra, arg1, nresults, arg2, nparams);

            if (n >= 0)
            {
-               if (FFlag::LuauNoTopRestoreInFastCall)
-               {
-                   if (nresults == LUA_MULTRET)
-                       L->top = ra + n;
-               }
-               else
-               {
-                   L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top;
-               }
+               if (nresults == LUA_MULTRET)
+                   L->top = ra + n;

                pc += skip + 1; // skip instructions that compute function as well as CALL
                LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
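The recurring VM_PROTECT_PC additions above follow one pattern: publish the interpreter's current position before any call that may raise (out of memory or a type error), so the failure can be attributed to the right instruction. The toy below only illustrates that idea; it is not Luau's actual macro or interpreter, and all names in it are invented.

    #include <cstdio>
    #include <stdexcept>

    struct Frame
    {
        int savedpc = -1; // last published instruction index
    };

    static void mayThrow(bool fail)
    {
        if (fail)
            throw std::runtime_error("out of memory");
    }

    int main()
    {
        Frame frame;
        const bool failAt[] = {false, false, true};

        try
        {
            for (int pc = 0; pc < 3; ++pc)
            {
                frame.savedpc = pc; // analogous to protecting the pc before a fallible call
                mayThrow(failAt[pc]);
            }
        }
        catch (const std::exception& e)
        {
            // the handler sees the position of the instruction that was executing
            std::printf("error '%s' at instruction %d\n", e.what(), frame.savedpc); // instruction 2
        }
    }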

View file

@@ -1,4 +1,4 @@
---!non-strict
+--!nonstrict

local bench = script and require(script.Parent.bench_support) or require("bench_support")

local stretchTreeDepth = 18 -- about 16Mb

bench/tests/voxelgen.lua (new file, 456 lines)

View file

@@ -0,0 +1,456 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
-- Based on voxel terrain generator by Stickmasterluke
local kSelectedBiomes = {
['Mountains'] = true,
['Canyons'] = true,
['Dunes'] = true,
['Arctic'] = true,
['Lavaflow'] = true,
['Hills'] = true,
['Plains'] = true,
['Marsh'] = true,
['Water'] = true,
}
---------Directly used in Generation---------
local masterSeed = 618033988
local mapWidth = 32
local mapHeight = 32
local biomeSize = 16
local generateCaves = true
local waterLevel = .48
local surfaceThickness = .018
local biomes = {}
---------------------------------------------
local rock = "Rock"
local snow = "Snow"
local ice = "Glacier"
local grass = "Grass"
local ground = "Ground"
local mud = "Mud"
local slate = "Slate"
local concrete = "Concrete"
local lava = "CrackedLava"
local basalt = "Basalt"
local air = "Air"
local sand = "Sand"
local sandstone = "Sandstone"
local water = "Water"
math.randomseed(6180339)
local theseed={}
for i=1,999 do
table.insert(theseed,math.random())
end
local function getPerlin(x,y,z,seed,scale,raw)
local seed = seed or 0
local scale = scale or 1
if not raw then
return math.noise(x/scale+(seed*17)+masterSeed,y/scale-masterSeed,z/scale-seed*seed)*.5 + .5 -- accounts for bleeding from interpolated line
else
return math.noise(x/scale+(seed*17)+masterSeed,y/scale-masterSeed,z/scale-seed*seed)
end
end
local function getNoise(x,y,z,seed1)
local x = x or 0
local y = y or 0
local z = z or 0
local seed1 = seed1 or 7
local wtf=x+y+z+seed1+masterSeed + (masterSeed-x)*(seed1+z) + (seed1-y)*(masterSeed+z) -- + x*(y+z) + z*(masterSeed+seed1) + seed1*(x+y) --x+y+z+seed1+masterSeed + x*y*masterSeed-y*z+(z+masterSeed)*x --((x+y)*(y-seed1)*seed1)-(x+z)*seed2+x*11+z*23-y*17
return theseed[(math.floor(wtf%(#theseed)))+1]
end
local function thresholdFilter(value, bottom, size)
if value <= bottom then
return 0
elseif value >= bottom+size then
return 1
else
return (value-bottom)/size
end
end
local function ridgedFilter(value) --absolute and flip for ridges. and normalize
return value<.5 and value*2 or 2-value*2
end
local function ridgedFlippedFilter(value) --unflipped
return value < .5 and 1-value*2 or value*2-1
end
local function advancedRidgedFilter(value, cutoff)
local cutoff = cutoff or .5
value = value - cutoff
return 1 - (value < 0 and -value or value) * 1/(1-cutoff)
end
local function fractalize(operation,x,y,z, operationCount, scale, offset, gain)
local operationCount = operationCount or 3
local scale = scale or .5
local offset = 0
local gain = gain or 1
local totalValue = 0
local totalScale = 0
for i=1, operationCount do
local thisScale = scale^(i-1)
totalScale = totalScale + thisScale
totalValue = totalValue + (offset + gain * operation(x,y,z,i))*thisScale
end
return totalValue/totalScale
end
local function mountainsOperation(x,y,z,i)
return ridgedFilter(getPerlin(x,y,z,100+i,(1/i)*160))
end
local canyonBandingMaterial = {rock,mud,sand,sand,sandstone,sandstone,sandstone,sandstone,sandstone,sandstone,}
local function findBiomeInfo(choiceBiome,x,y,z,verticalGradientTurbulence)
local choiceBiomeValue = .5
local choiceBiomeSurface = grass
local choiceBiomeFill = rock
if choiceBiome == 'City' then
choiceBiomeValue = .55
choiceBiomeSurface = concrete
choiceBiomeFill = slate
elseif choiceBiome == 'Water' then
choiceBiomeValue = .36+getPerlin(x,y,z,2,50)*.08
choiceBiomeSurface =
(1-verticalGradientTurbulence < .44 and slate)
or sand
elseif choiceBiome == 'Marsh' then
local preLedge = getPerlin(x+getPerlin(x,0,z,5,7,true)*10+getPerlin(x,0,z,6,30,true)*50,0,z+getPerlin(x,0,z,9,7,true)*10+getPerlin(x,0,z,10,30,true)*50,2,70) --could use some turbulence
local grassyLedge = thresholdFilter(preLedge,.65,0)
local largeGradient = getPerlin(x,y,z,4,100)
local smallGradient = getPerlin(x,y,z,3,20)
local smallGradientThreshold = thresholdFilter(smallGradient,.5,0)
choiceBiomeValue = waterLevel-.04
+preLedge*grassyLedge*.025
+largeGradient*.035
+smallGradient*.025
choiceBiomeSurface =
(grassyLedge >= 1 and grass)
or (1-verticalGradientTurbulence < waterLevel-.01 and mud)
or (1-verticalGradientTurbulence < waterLevel+.01 and ground)
or grass
choiceBiomeFill = slate
elseif choiceBiome == 'Plains' then
local rivulet = ridgedFlippedFilter(getPerlin(x+getPerlin(x,y,z,17,40)*25,0,z+getPerlin(x,y,z,19,40)*25,2,200))
local rivuletThreshold = thresholdFilter(rivulet,.01,0)
local rockMap = thresholdFilter(ridgedFlippedFilter(getPerlin(x,0,z,101,7)),.3,.7) --rocks
* thresholdFilter(getPerlin(x,0,z,102,50),.6,.05) --zoning
choiceBiomeValue = .5 --.51
+getPerlin(x,y,z,2,100)*.02 --.05
+rivulet*.05 --.02
+rockMap*.05 --.03
+rivuletThreshold*.005
local verticalGradient = 1-((y-1)/(mapHeight-1))
local surfaceGradient = verticalGradient*.5 + choiceBiomeValue*.5
local thinSurface = surfaceGradient > .5-surfaceThickness*.4 and surfaceGradient < .5+surfaceThickness*.4
choiceBiomeSurface =
(rockMap>0 and rock)
or (not thinSurface and mud)
or (thinSurface and rivuletThreshold <=0 and water)
or (1-verticalGradientTurbulence < waterLevel-.01 and sand)
or grass
choiceBiomeFill =
(rockMap>0 and rock)
or sandstone
elseif choiceBiome == 'Canyons' then
local canyonNoise = ridgedFlippedFilter(getPerlin(x,0,z,2,200))
local canyonNoiseTurbed = ridgedFlippedFilter(getPerlin(x+getPerlin(x,0,z,5,20,true)*20,0,z+getPerlin(x,0,z,9,20,true)*20,2,200))
local sandbank = thresholdFilter(canyonNoiseTurbed,0,.05)
local canyonTop = thresholdFilter(canyonNoiseTurbed,.125,0)
local mesaSlope = thresholdFilter(canyonNoise,.33,.12)
local mesaTop = thresholdFilter(canyonNoiseTurbed,.49,0)
choiceBiomeValue = .42
+getPerlin(x,y,z,2,70)*.05
+canyonNoise*.05
+sandbank*.04 --canyon bottom slope
+thresholdFilter(canyonNoiseTurbed,.05,0)*.08 --canyon cliff
+thresholdFilter(canyonNoiseTurbed,.05,.075)*.04 --canyon cliff top slope
+canyonTop*.01 --canyon cliff top ledge
+thresholdFilter(canyonNoiseTurbed,.0575,.2725)*.01 --plane slope
+mesaSlope*.06 --mesa slope
+thresholdFilter(canyonNoiseTurbed,.45,0)*.14 --mesa cliff
+thresholdFilter(canyonNoiseTurbed,.45,.04)*.025 --mesa cap
+mesaTop*.02 --mesa top ledge
choiceBiomeSurface =
(1-verticalGradientTurbulence < waterLevel+.015 and sand) --this for biome blending in to lakes
or (sandbank>0 and sandbank<1 and sand) --this for canyonbase sandbanks
--or (canyonTop>0 and canyonTop<=1 and mesaSlope<=0 and grass) --this for grassy canyon tops
--or (mesaTop>0 and mesaTop<=1 and grass) --this for grassy mesa tops
or sandstone
choiceBiomeFill = canyonBandingMaterial[math.ceil((1-getNoise(1,y,2))*10)]
elseif choiceBiome == 'Hills' then
local rivulet = ridgedFlippedFilter(getPerlin(x+getPerlin(x,y,z,17,20)*20,0,z+getPerlin(x,y,z,19,20)*20,2,200))^(1/2)
local largeHills = getPerlin(x,y,z,3,60)
choiceBiomeValue = .48
+largeHills*.05
+(.05
+largeHills*.1
+getPerlin(x,y,z,4,25)*.125)
*rivulet
local surfaceMaterialGradient = (1-verticalGradientTurbulence)*.9 + rivulet*.1
choiceBiomeSurface =
(surfaceMaterialGradient < waterLevel-.015 and mud)
or (surfaceMaterialGradient < waterLevel and ground)
or grass
choiceBiomeFill = slate
elseif choiceBiome == 'Dunes' then
local duneTurbulence = getPerlin(x,0,z,227,20)*24
local layer1 = ridgedFilter(getPerlin(x,0,z,201,40))
local layer2 = ridgedFilter(getPerlin(x/10+duneTurbulence,0,z+duneTurbulence,200,48))
choiceBiomeValue = .4+.1*(layer1 + layer2)
choiceBiomeSurface = sand
choiceBiomeFill = sandstone
elseif choiceBiome == 'Mountains' then
local rivulet = ridgedFlippedFilter(getPerlin(x+getPerlin(x,y,z,17,20)*20,0,z+getPerlin(x,y,z,19,20)*20,2,200))
choiceBiomeValue = -.4 --.3
+fractalize(mountainsOperation,x,y/20,z, 8, .65)*1.2
+rivulet*.2
choiceBiomeSurface =
(verticalGradientTurbulence < .275 and snow)
or (verticalGradientTurbulence < .35 and rock)
or (verticalGradientTurbulence < .4 and ground)
or (1-verticalGradientTurbulence < waterLevel and rock)
or (1-verticalGradientTurbulence < waterLevel+.01 and mud)
or (1-verticalGradientTurbulence < waterLevel+.015 and ground)
or grass
elseif choiceBiome == 'Lavaflow' then
local crackX = x+getPerlin(x,y*.25,z,21,8,true)*5
local crackY = y+getPerlin(x,y*.25,z,22,8,true)*5
local crackZ = z+getPerlin(x,y*.25,z,23,8,true)*5
local crack1 = ridgedFilter(getPerlin(crackX+getPerlin(x,y,z,22,30,true)*30,crackY,crackZ+getPerlin(x,y,z,24,30,true)*30,2,120))
local crack2 = ridgedFilter(getPerlin(crackX,crackY,crackZ,3,40))*(crack1*.25+.75)
local crack3 = ridgedFilter(getPerlin(crackX,crackY,crackZ,4,20))*(crack2*.25+.75)
local generalHills = thresholdFilter(getPerlin(x,y,z,9,40),.25,.5)*getPerlin(x,y,z,10,60)
local cracks = math.max(0,1-thresholdFilter(crack1,.975,0)-thresholdFilter(crack2,.925,0)-thresholdFilter(crack3,.9,0))
local spires = thresholdFilter(getPerlin(crackX/40,crackY/300,crackZ/30,123,1),.6,.4)
choiceBiomeValue = waterLevel+.02
+cracks*(.5+generalHills*.5)*.02
+generalHills*.05
+spires*.3
+((1-verticalGradientTurbulence > waterLevel+.01 or spires>0) and .04 or 0) --This lets it lip over water
choiceBiomeFill = (spires>0 and rock) or (cracks<1 and lava) or basalt
choiceBiomeSurface = (choiceBiomeFill == lava and 1-verticalGradientTurbulence < waterLevel and basalt) or choiceBiomeFill
elseif choiceBiome == 'Arctic' then
local preBoundary = getPerlin(x+getPerlin(x,0,z,5,8,true)*5,y/8,z+getPerlin(x,0,z,9,8,true)*5,2,20)
--local cliffs = thresholdFilter(preBoundary,.5,0)
local boundary = ridgedFilter(preBoundary)
local roughChunks = getPerlin(x,y/4,z,436,2)
local boundaryMask = thresholdFilter(boundary,.8,.1) --,.7,.25)
local boundaryTypeMask = getPerlin(x,0,z,6,74)-.5
local boundaryComp = 0
if boundaryTypeMask < 0 then --divergent
boundaryComp = (boundary > (1+boundaryTypeMask*.5) and -.17 or 0)
--* boundaryTypeMask*-2
else --convergent
boundaryComp = boundaryMask*.1*roughChunks
* boundaryTypeMask
end
choiceBiomeValue = .55
+boundary*.05*boundaryTypeMask --.1 --soft slope up or down to boundary
+boundaryComp --convergent/divergent effects
+getPerlin(x,0,z,123,25)*.025 --*cliffs --gentle rolling slopes
choiceBiomeSurface = (1-verticalGradientTurbulence < waterLevel-.1 and ice) or (boundaryMask>.6 and boundaryTypeMask>.1 and roughChunks>.5 and ice) or snow
choiceBiomeFill = ice
end
return choiceBiomeValue, choiceBiomeSurface, choiceBiomeFill
end
function findBiomeTransitionValue(biome,weight,value,averageValue)
if biome == 'Arctic' then
return (weight>.2 and 1 or 0)*value
elseif biome == 'Canyons' then
return (weight>.7 and 1 or 0)*value
elseif biome == 'Mountains' then
local weight = weight^3 --This improves the ease of mountains transitioning to other biomes
return averageValue*(1-weight)+value*weight
else
return averageValue*(1-weight)+value*weight
end
end
function generate()
local mapWidth = mapWidth
local biomeSize = biomeSize
local biomeBlendPercent = .25 --(biomeSize==50 or biomeSize == 100) and .5 or .25
local biomeBlendPercentInverse = 1-biomeBlendPercent
local biomeBlendDistortion = biomeBlendPercent
local smoothScale = .5/mapHeight
biomes = {}
for i,v in pairs(kSelectedBiomes) do
if v then
table.insert(biomes,i)
end
end
if #biomes<=0 then
table.insert(biomes,'Hills')
end
table.sort(biomes)
--local oMap = {}
--local mMap = {}
for x = 1, mapWidth do
local oMapX = {}
--oMap[x] = oMapX
local mMapX = {}
--mMap[x] = mMapX
for z = 1, mapWidth do
local biomeNoCave = false
local cellToBiomeX = x/biomeSize + getPerlin(x,0,z,233,biomeSize*.3)*.25 + getPerlin(x,0,z,235,biomeSize*.05)*.075
local cellToBiomeZ = z/biomeSize + getPerlin(x,0,z,234,biomeSize*.3)*.25 + getPerlin(x,0,z,236,biomeSize*.05)*.075
local closestDistance = 1000000
local biomePoints = {}
for vx=-1,1 do
for vz=-1,1 do
local gridPointX = math.floor(cellToBiomeX+vx+.5)
local gridPointZ = math.floor(cellToBiomeZ+vz+.5)
--local pointX, pointZ = getBiomePoint(gridPointX,gridPointZ)
local pointX = gridPointX+(getNoise(gridPointX,gridPointZ,53)-.5)*.75 --de-uniforming grid for voronoi
local pointZ = gridPointZ+(getNoise(gridPointX,gridPointZ,73)-.5)*.75
local dist = math.sqrt((pointX-cellToBiomeX)^2 + (pointZ-cellToBiomeZ)^2)
if dist < closestDistance then
closestDistance = dist
end
table.insert(biomePoints,{
x = pointX,
z = pointZ,
dist = dist,
biomeNoise = getNoise(gridPointX,gridPointZ),
weight = 0
})
end
end
local weightTotal = 0
local weightPoints = {}
for _,point in pairs(biomePoints) do
local weight = point.dist == closestDistance and 1 or ((closestDistance / point.dist)-biomeBlendPercentInverse)/biomeBlendPercent
if weight > 0 then
local weight = weight^2.1 --this smooths the biome transition from linear to cubic InOut
weightTotal = weightTotal + weight
local biome = biomes[math.ceil(#biomes*(1-point.biomeNoise))] --inverting the noise so that it is limited as (0,1]. One less addition operation when finding a random list index
weightPoints[biome] = {
weight = weightPoints[biome] and weightPoints[biome].weight + weight or weight
}
end
end
for biome,info in pairs(weightPoints) do
info.weight = info.weight / weightTotal
if biome == 'Arctic' then --biomes that don't have caves that breach the surface
biomeNoCave = true
end
end
for y = 1, mapHeight do
local oMapY = oMapX[y] or {}
oMapX[y] = oMapY
local mMapY = mMapX[y] or {}
mMapX[y] = mMapY
--[[local oMapY = {}
oMapX[y] = oMapY
local mMapY = {}
mMapX[z] = mMapY]]
local verticalGradient = 1-((y-1)/(mapHeight-1))
local caves = 0
local verticalGradientTurbulence = verticalGradient*.9 + .1*getPerlin(x,y,z,107,15)
local choiceValue = 0
local choiceSurface = lava
local choiceFill = rock
if verticalGradient > .65 or verticalGradient < .1 then
--under surface of every biome; don't get biome data; waste of time.
choiceValue = .5
elseif #biomes == 1 then
choiceValue, choiceSurface, choiceFill = findBiomeInfo(biomes[1],x,y,z,verticalGradientTurbulence)
else
local averageValue = 0
--local findChoiceMaterial = -getNoise(x,y,z,19)
for biome,info in pairs(weightPoints) do
local biomeValue, biomeSurface, biomeFill = findBiomeInfo(biome,x,y,z,verticalGradientTurbulence)
info.biomeValue = biomeValue
info.biomeSurface = biomeSurface
info.biomeFill = biomeFill
local value = biomeValue * info.weight
averageValue = averageValue + value
--[[if findChoiceMaterial < 0 and findChoiceMaterial + weight >= 0 then
choiceMaterial = biomeMaterial
end
findChoiceMaterial = findChoiceMaterial + weight]]
end
for biome,info in pairs(weightPoints) do
local value = findBiomeTransitionValue(biome,info.weight,info.biomeValue,averageValue)
if value > choiceValue then
choiceValue = value
choiceSurface = info.biomeSurface
choiceFill = info.biomeFill
end
end
end
local preCaveComp = verticalGradient*.5 + choiceValue*.5
local surface = preCaveComp > .5-surfaceThickness and preCaveComp < .5+surfaceThickness
if generateCaves --user wants caves
and (not biomeNoCave or verticalGradient > .65) --biome allows caves or deep enough
and not (surface and (1-verticalGradient) < waterLevel+.005) --caves only breach surface above waterlevel
and not (surface and (1-verticalGradient) > waterLevel+.58) then --caves don't go too high so that they don't cut up mountain tops
local ridged2 = ridgedFilter(getPerlin(x,y,z,4,30))
local caves2 = thresholdFilter(ridged2,.84,.01)
local ridged3 = ridgedFilter(getPerlin(x,y,z,5,30))
local caves3 = thresholdFilter(ridged3,.84,.01)
local ridged4 = ridgedFilter(getPerlin(x,y,z,6,30))
local caves4 = thresholdFilter(ridged4,.84,.01)
local caveOpenings = (surface and 1 or 0) * thresholdFilter(getPerlin(x,0,z,143,62),.35,0) --.45
caves = caves2 * caves3 * caves4 - caveOpenings
caves = caves < 0 and 0 or caves > 1 and 1 or caves
end
local comp = preCaveComp - caves
local smoothedResult = thresholdFilter(comp,.5,smoothScale)
---below water level -above surface -no terrain
if 1-verticalGradient < waterLevel and preCaveComp <= .5 and smoothedResult <= 0 then
smoothedResult = 1
choiceSurface = water
choiceFill = water
surface = true
end
oMapY[z] = (y == 1 and 1) or smoothedResult
mMapY[z] = (y == 1 and lava) or (smoothedResult <= 0 and air) or (surface and choiceSurface) or choiceFill
end
end
-- local regionStart = Vector3.new(mapWidth*-2+(x-1)*4,mapHeight*-2,mapWidth*-2)
-- local regionEnd = Vector3.new(mapWidth*-2+x*4,mapHeight*2,mapWidth*2)
-- local mapRegion = Region3.new(regionStart, regionEnd)
-- terrain:WriteVoxels(mapRegion, 4, {mMapX}, {oMapX})
end
end
bench.runCode(generate, "voxelgen")

View file

@@ -66,8 +66,8 @@ Syntactic subtyping is a syntax-directed recursive algorithm. The interesting ca
* Reflexivity: `T` is a subtype of `T`
* Intersection L: `(T₁ & … & Tⱼ)` is a subtype of `U` whenever some of the `Tᵢ` are subtypes of `U`
* Union L: `(T₁ | … | Tⱼ)` is a subtype of `U` whenever all of the `Tᵢ` are subtypes of `U`
-* Intersection R: `T` is a subtype of `(U₁ & … & Uⱼ)` whenever `T` is a subtype of some of the `Uᵢ`
-* Union R: `T` is a subtype of `(U₁ | … | Uⱼ)` whenever `T` is a subtype of all of the `Uᵢ`.
+* Intersection R: `T` is a subtype of `(U₁ & … & Uⱼ)` whenever `T` is a subtype of all of the `Uᵢ`
+* Union R: `T` is a subtype of `(U₁ | … | Uⱼ)` whenever `T` is a subtype of some of the `Uᵢ`.

For example:

@@ -262,6 +262,10 @@ Semantic subtyping has removed one source of false positives, but we still have
The quest to remove spurious red squiggles continues!

+## Acknowledgments
+
+Thanks to Giuseppe Castagna and Ben Greenman for helpful comments on drafts of this post.
+
## Further reading

If you want to find out more about Luau and semantic subtyping, you might want to check out…

@@ -274,8 +278,9 @@ If you want to find out more about Luau and semantic subtyping, you might want t
* Giuseppe Castagna, *Covariance and Contravariance*, Logical Methods in Computer Science 16(1), 2022. <https://arxiv.org/abs/1809.01427>
* Giuseppe Castagna and Alain Frisch, *A gentle introduction to semantic subtyping*, Proc. Principles and practice of declarative programming (PPDP), pp 198-208, 2005. <https://doi.org/10.1145/1069774.1069793>
* Giuseppe Castagna, Mickaël Laurent, Kim Nguyễn, Matthew Lutze, *On Type-Cases, Union Elimination, and Occurrence Typing*, Principles of Programming Languages (POPL), 2022. <https://doi.org/10.1145/3498674>
-* Sam Tobin-Hochstadt and Matthias Felleisen, *Logical types for untyped languages*. International Conference on Functional Programming (ICFP), 2010. https://doi.org/10.1145/1863543.1863561
-* José Valim, *My Future with Elixir: set-theoretic types*, 2022. https://elixir-lang.org/blog/2022/10/05/my-future-with-elixir-set-theoretic-types/
+* Giuseppe Castagna, *Programming with union, intersection, and negation types*, 2022. <https://arxiv.org/abs/2111.03354>
+* Sam Tobin-Hochstadt and Matthias Felleisen, *Logical types for untyped languages*. International Conference on Functional Programming (ICFP), 2010. <https://doi.org/10.1145/1863543.1863561>
+* José Valim, *My Future with Elixir: set-theoretic types*, 2022. <https://elixir-lang.org/blog/2022/10/05/my-future-with-elixir-set-theoretic-types/>

Some other languages which support semantic subtyping…
