Merge branch 'upstream' into merge

This commit is contained in:
Andy Friesen 2022-11-04 10:04:57 -07:00
commit 8fd7d22046
78 changed files with 3061 additions and 2986 deletions

View file

@ -0,0 +1,68 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Def.h"
#include "Luau/TypedAllocator.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <memory>
namespace Luau
{
struct Negation;
struct Conjunction;
struct Disjunction;
struct Equivalence;
struct Proposition;
using Connective = Variant<Negation, Conjunction, Disjunction, Equivalence, Proposition>;
using ConnectiveId = Connective*; // Can and most likely is nullptr.
struct Negation
{
ConnectiveId connective;
};
struct Conjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Disjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Equivalence
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Proposition
{
DefId def;
TypeId discriminantTy;
};
template<typename T>
const T* get(ConnectiveId connective)
{
return get_if<T>(connective);
}
struct ConnectiveArena
{
TypedAllocator<Connective> allocator;
ConnectiveId negation(ConnectiveId connective);
ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId proposition(DefId def, TypeId discriminantTy);
};
} // namespace Luau

View file

@ -132,15 +132,16 @@ struct HasPropConstraint
std::string prop; std::string prop;
}; };
struct RefinementConstraint // result ~ if isSingleton D then ~D else unknown where D = discriminantType
struct SingletonOrTopTypeConstraint
{ {
DefId def; TypeId resultType;
TypeId discriminantType; TypeId discriminantType;
}; };
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint, using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint, BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint,
HasPropConstraint, RefinementConstraint>; HasPropConstraint, SingletonOrTopTypeConstraint>;
struct Constraint struct Constraint
{ {

View file

@ -2,6 +2,7 @@
#pragma once #pragma once
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Connective.h"
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/DataFlowGraphBuilder.h" #include "Luau/DataFlowGraphBuilder.h"
#include "Luau/Module.h" #include "Luau/Module.h"
@ -26,11 +27,13 @@ struct DcrLogger;
struct Inference struct Inference
{ {
TypeId ty = nullptr; TypeId ty = nullptr;
ConnectiveId connective = nullptr;
Inference() = default; Inference() = default;
explicit Inference(TypeId ty) explicit Inference(TypeId ty, ConnectiveId connective = nullptr)
: ty(ty) : ty(ty)
, connective(connective)
{ {
} }
}; };
@ -38,11 +41,13 @@ struct Inference
struct InferencePack struct InferencePack
{ {
TypePackId tp = nullptr; TypePackId tp = nullptr;
std::vector<ConnectiveId> connectives;
InferencePack() = default; InferencePack() = default;
explicit InferencePack(TypePackId tp) explicit InferencePack(TypePackId tp, const std::vector<ConnectiveId>& connectives = {})
: tp(tp) : tp(tp)
, connectives(connectives)
{ {
} }
}; };
@ -73,6 +78,7 @@ struct ConstraintGraphBuilder
// Defining scopes for AST nodes. // Defining scopes for AST nodes.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr}; DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
NotNull<const DataFlowGraph> dfg; NotNull<const DataFlowGraph> dfg;
ConnectiveArena connectiveArena;
int recursionCount = 0; int recursionCount = 0;
@ -126,6 +132,8 @@ struct ConstraintGraphBuilder
*/ */
NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c); NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective);
/** /**
* The entry point to the ConstraintGraphBuilder. This will construct a set * The entry point to the ConstraintGraphBuilder. This will construct a set
* of scopes, constraints, and free types that can be solved later. * of scopes, constraints, and free types that can be solved later.
@ -167,10 +175,10 @@ struct ConstraintGraphBuilder
* surrounding context. Used to implement bidirectional type checking. * surrounding context. Used to implement bidirectional type checking.
* @return the type of the expression. * @return the type of the expression.
*/ */
Inference check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}); Inference check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}, bool forceSingleton = false);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprLocal* local); Inference check(const ScopePtr& scope, AstExprLocal* local);
Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference check(const ScopePtr& scope, AstExprGlobal* global);
Inference check(const ScopePtr& scope, AstExprIndexName* indexName); Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
@ -180,6 +188,7 @@ struct ConstraintGraphBuilder
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, ConnectiveId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs); TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);

View file

@ -110,7 +110,7 @@ struct ConstraintSolver
bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const RefinementConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do // for a, ... in some_table do
// also handles __iter metamethod // also handles __iter metamethod

View file

@ -17,10 +17,8 @@ struct SingletonTypes;
using ModulePtr = std::shared_ptr<Module>; using ModulePtr = std::shared_ptr<Module>;
bool isSubtype( bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice,
bool anyIsTop = true);
class TypeIds class TypeIds
{ {
@ -169,12 +167,26 @@ struct NormalizedStringType
bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr);
// A normalized function type is either `never` (represented by `nullopt`) // A normalized function type can be `never`, the top function type `function`,
// or an intersection of function types. // or an intersection of function types.
// NOTE: type normalization can fail on function types with generics //
// (e.g. because we do not support unions and intersections of generic type packs), // NOTE: type normalization can fail on function types with generics (e.g.
// so this type may contain `error`. // because we do not support unions and intersections of generic type packs), so
using NormalizedFunctionType = std::optional<TypeIds>; // this type may contain `error`.
struct NormalizedFunctionType
{
NormalizedFunctionType();
bool isTop = false;
// TODO: Remove this wrapping optional when clipping
// FFlagLuauNegatedFunctionTypes.
std::optional<TypeIds> parts;
void resetToNever();
void resetToTop();
bool isNever() const;
};
// A normalized generic/free type is a union, where each option is of the form (X & T) where // A normalized generic/free type is a union, where each option is of the form (X & T) where
// * X is either a free type or a generic // * X is either a free type or a generic
@ -234,12 +246,14 @@ struct NormalizedType
NormalizedType(NotNull<SingletonTypes> singletonTypes); NormalizedType(NotNull<SingletonTypes> singletonTypes);
NormalizedType(const NormalizedType&) = delete;
NormalizedType(NormalizedType&&) = default;
NormalizedType() = delete; NormalizedType() = delete;
~NormalizedType() = default; ~NormalizedType() = default;
NormalizedType(const NormalizedType&) = delete;
NormalizedType& operator=(const NormalizedType&) = delete;
NormalizedType(NormalizedType&&) = default;
NormalizedType& operator=(NormalizedType&&) = default; NormalizedType& operator=(NormalizedType&&) = default;
NormalizedType& operator=(NormalizedType&) = delete;
}; };
class Normalizer class Normalizer
@ -291,7 +305,7 @@ public:
bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1);
// ------- Negations // ------- Negations
NormalizedType negateNormal(const NormalizedType& here); std::optional<NormalizedType> negateNormal(const NormalizedType& here);
TypeIds negateAll(const TypeIds& theres); TypeIds negateAll(const TypeIds& theres);
TypeId negate(TypeId there); TypeId negate(TypeId there);
void subtractPrimitive(NormalizedType& here, TypeId ty); void subtractPrimitive(NormalizedType& here, TypeId ty);

View file

@ -35,7 +35,7 @@ std::vector<TypeId> flatten(TypeArena& arena, NotNull<SingletonTypes> singletonT
* identity) types. * identity) types.
* @param types the input type list to reduce. * @param types the input type list to reduce.
* @returns the reduced type list. * @returns the reduced type list.
*/ */
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types); std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
/** /**
@ -45,7 +45,7 @@ std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
* @param arena the type arena to allocate the new type in, if necessary * @param arena the type arena to allocate the new type in, if necessary
* @param ty the type to remove nil from * @param ty the type to remove nil from
* @returns a type with nil removed, or nil itself if that were the only option. * @returns a type with nil removed, or nil itself if that were the only option.
*/ */
TypeId stripNil(NotNull<SingletonTypes> singletonTypes, TypeArena& arena, TypeId ty); TypeId stripNil(NotNull<SingletonTypes> singletonTypes, TypeArena& arena, TypeId ty);
} // namespace Luau } // namespace Luau

View file

@ -115,6 +115,7 @@ struct PrimitiveTypeVar
Number, Number,
String, String,
Thread, Thread,
Function,
}; };
Type type; Type type;
@ -504,14 +505,6 @@ struct NeverTypeVar
{ {
}; };
// Invariant 1: there should never be a reason why such UseTypeVar exists without it mapping to another type.
// Invariant 2: UseTypeVar should always disappear across modules.
struct UseTypeVar
{
DefId def;
NotNull<Scope> scope;
};
// ~T // ~T
// TODO: Some simplification step that overwrites the type graph to make sure negation // TODO: Some simplification step that overwrites the type graph to make sure negation
// types disappear from the user's view, and (?) a debug flag to disable that // types disappear from the user's view, and (?) a debug flag to disable that
@ -522,9 +515,9 @@ struct NegationTypeVar
using ErrorTypeVar = Unifiable::Error; using ErrorTypeVar = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar, using TypeVariant =
TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar, Unifiable::Variant<TypeId, PrimitiveTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar,
UseTypeVar, NegationTypeVar>; MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar, NegationTypeVar>;
struct TypeVar final struct TypeVar final
{ {
@ -644,13 +637,14 @@ public:
const TypeId stringType; const TypeId stringType;
const TypeId booleanType; const TypeId booleanType;
const TypeId threadType; const TypeId threadType;
const TypeId functionType;
const TypeId trueType; const TypeId trueType;
const TypeId falseType; const TypeId falseType;
const TypeId anyType; const TypeId anyType;
const TypeId unknownType; const TypeId unknownType;
const TypeId neverType; const TypeId neverType;
const TypeId errorType; const TypeId errorType;
const TypeId falsyType; // No type binding! const TypeId falsyType; // No type binding!
const TypeId truthyType; // No type binding! const TypeId truthyType; // No type binding!
const TypePackId anyTypePack; const TypePackId anyTypePack;

View file

@ -61,7 +61,6 @@ struct Unifier
ErrorVec errors; ErrorVec errors;
Location location; Location location;
Variance variance = Covariant; Variance variance = Covariant;
bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once.
bool normalize; // Normalize unions and intersections if necessary bool normalize; // Normalize unions and intersections if necessary
bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels
CountMismatch::Context ctx = CountMismatch::Arg; CountMismatch::Context ctx = CountMismatch::Arg;
@ -131,6 +130,7 @@ public:
Unifier makeChildUnifier(); Unifier makeChildUnifier();
void reportError(TypeError err); void reportError(TypeError err);
LUAU_NOINLINE void reportError(Location location, TypeErrorData data);
private: private:
bool isNonstrictMode() const; bool isNonstrictMode() const;

View file

@ -58,13 +58,15 @@ public:
constexpr int tid = getTypeId<T>(); constexpr int tid = getTypeId<T>();
typeId = tid; typeId = tid;
new (&storage) TT(value); new (&storage) TT(std::forward<T>(value));
} }
Variant(const Variant& other) Variant(const Variant& other)
{ {
static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy<Ts>...};
typeId = other.typeId; typeId = other.typeId;
tableCopy[typeId](&storage, &other.storage); table[typeId](&storage, &other.storage);
} }
Variant(Variant&& other) Variant(Variant&& other)
@ -192,7 +194,6 @@ private:
return *static_cast<const T*>(lhs) == *static_cast<const T*>(rhs); return *static_cast<const T*>(lhs) == *static_cast<const T*>(rhs);
} }
static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};

View file

@ -155,10 +155,6 @@ struct GenericTypeVarVisitor
{ {
return visit(ty); return visit(ty);
} }
virtual bool visit(TypeId ty, const UseTypeVar& utv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const NegationTypeVar& ntv) virtual bool visit(TypeId ty, const NegationTypeVar& ntv)
{ {
return visit(ty); return visit(ty);
@ -321,8 +317,6 @@ struct GenericTypeVarVisitor
traverse(a); traverse(a);
} }
} }
else if (auto utv = get<UseTypeVar>(ty))
visit(ty, *utv);
else if (auto ntv = get<NegationTypeVar>(ty)) else if (auto ntv = get<NegationTypeVar>(ty))
visit(ty, *ntv); visit(ty, *ntv);
else if (!FFlag::LuauCompleteVisitor) else if (!FFlag::LuauCompleteVisitor)

View file

@ -714,8 +714,7 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
result = arena->addType(UnionTypeVar{std::move(options)}); result = arena->addType(UnionTypeVar{std::move(options)});
TypeId numberType = context.solver->singletonTypes->numberType; TypeId numberType = context.solver->singletonTypes->numberType;
TypeId packedTable = arena->addType( TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed});
TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed});
TypePackId tableTypePack = arena->addTypePack({packedTable}); TypePackId tableTypePack = arena->addTypePack({packedTable});
asMutable(context.result)->ty.emplace<BoundTypePack>(tableTypePack); asMutable(context.result)->ty.emplace<BoundTypePack>(tableTypePack);

View file

@ -62,7 +62,6 @@ struct TypeCloner
void operator()(const LazyTypeVar& t); void operator()(const LazyTypeVar& t);
void operator()(const UnknownTypeVar& t); void operator()(const UnknownTypeVar& t);
void operator()(const NeverTypeVar& t); void operator()(const NeverTypeVar& t);
void operator()(const UseTypeVar& t);
void operator()(const NegationTypeVar& t); void operator()(const NegationTypeVar& t);
}; };
@ -338,12 +337,6 @@ void TypeCloner::operator()(const NeverTypeVar& t)
defaultClone(t); defaultClone(t);
} }
void TypeCloner::operator()(const UseTypeVar& t)
{
TypeId result = dest.addType(BoundTypeVar{follow(typeId)});
seenTypes[typeId] = result;
}
void TypeCloner::operator()(const NegationTypeVar& t) void TypeCloner::operator()(const NegationTypeVar& t)
{ {
TypeId result = dest.addType(AnyTypeVar{}); TypeId result = dest.addType(AnyTypeVar{});

View file

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Connective.h"
namespace Luau
{
ConnectiveId ConnectiveArena::negation(ConnectiveId connective)
{
return NotNull{allocator.allocate(Negation{connective})};
}
ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Conjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Disjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Equivalence{lhs, rhs})};
}
ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy)
{
return NotNull{allocator.allocate(Proposition{def, discriminantTy})};
}
} // namespace Luau

View file

@ -107,6 +107,101 @@ NotNull<Constraint> ConstraintGraphBuilder::addConstraint(const ScopePtr& scope,
return NotNull{scope->constraints.emplace_back(std::move(c)).get()}; return NotNull{scope->constraints.emplace_back(std::move(c)).get()};
} }
static void unionRefinements(const std::unordered_map<DefId, TypeId>& lhs, const std::unordered_map<DefId, TypeId>& rhs,
std::unordered_map<DefId, TypeId>& dest, NotNull<TypeArena> arena)
{
for (auto [def, ty] : lhs)
{
auto rhsIt = rhs.find(def);
if (rhsIt == rhs.end())
continue;
std::vector<TypeId> discriminants{{ty, rhsIt->second}};
if (auto destIt = dest.find(def); destIt != dest.end())
discriminants.push_back(destIt->second);
dest[def] = arena->addType(UnionTypeVar{std::move(discriminants)});
}
}
static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, std::unordered_map<DefId, TypeId>* refis, bool sense,
NotNull<TypeArena> arena, bool eq, std::vector<SingletonOrTopTypeConstraint>* constraints)
{
using RefinementMap = std::unordered_map<DefId, TypeId>;
if (!connective)
return;
else if (auto negation = get<Negation>(connective))
return computeRefinement(scope, negation->connective, refis, !sense, arena, eq, constraints);
else if (auto conjunction = get<Conjunction>(connective))
{
RefinementMap lhsRefis;
RefinementMap rhsRefis;
computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints);
computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints);
if (!sense)
unionRefinements(lhsRefis, rhsRefis, *refis, arena);
}
else if (auto disjunction = get<Disjunction>(connective))
{
RefinementMap lhsRefis;
RefinementMap rhsRefis;
computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints);
computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints);
if (sense)
unionRefinements(lhsRefis, rhsRefis, *refis, arena);
}
else if (auto equivalence = get<Equivalence>(connective))
{
computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints);
computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints);
}
else if (auto proposition = get<Proposition>(connective))
{
TypeId discriminantTy = proposition->discriminantTy;
if (!sense && !eq)
discriminantTy = arena->addType(NegationTypeVar{proposition->discriminantTy});
else if (!sense && eq)
{
discriminantTy = arena->addType(BlockedTypeVar{});
constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy});
}
if (auto it = refis->find(proposition->def); it != refis->end())
(*refis)[proposition->def] = arena->addType(IntersectionTypeVar{{discriminantTy, it->second}});
else
(*refis)[proposition->def] = discriminantTy;
}
}
void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective)
{
if (!connective)
return;
std::unordered_map<DefId, TypeId> refinements;
std::vector<SingletonOrTopTypeConstraint> constraints;
computeRefinement(scope, connective, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints);
for (auto [def, discriminantTy] : refinements)
{
std::optional<TypeId> defTy = scope->lookup(def);
if (!defTy)
ice->ice("Every DefId must map to a type!");
TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy}});
scope->dcrRefinements[def] = resultTy;
}
for (auto& c : constraints)
addConstraint(scope, location, c);
}
void ConstraintGraphBuilder::visit(AstStatBlock* block) void ConstraintGraphBuilder::visit(AstStatBlock* block)
{ {
LUAU_ASSERT(scopes.empty()); LUAU_ASSERT(scopes.empty());
@ -250,14 +345,33 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
if (value->is<AstExprConstantNil>()) if (value->is<AstExprConstantNil>())
{ {
// HACK: we leave nil-initialized things floating under the assumption that they will later be populated. // HACK: we leave nil-initialized things floating under the
// See the test TypeInfer/infer_locals_with_nil_value. // assumption that they will later be populated.
// Better flow awareness should make this obsolete. //
// See the test TypeInfer/infer_locals_with_nil_value. Better flow
// awareness should make this obsolete.
if (!varTypes[i]) if (!varTypes[i])
varTypes[i] = freshType(scope); varTypes[i] = freshType(scope);
} }
else if (i == local->values.size - 1) // Only function calls and vararg expressions can produce packs. All
// other expressions produce exactly one value.
else if (i != local->values.size - 1 || (!value->is<AstExprCall>() && !value->is<AstExprVarargs>()))
{
std::optional<TypeId> expectedType;
if (hasAnnotation)
expectedType = varTypes.at(i);
TypeId exprType = check(scope, value, expectedType).ty;
if (i < varTypes.size())
{
if (varTypes[i])
addConstraint(scope, local->location, SubtypeConstraint{exprType, varTypes[i]});
else
varTypes[i] = exprType;
}
}
else
{ {
std::vector<TypeId> expectedTypes; std::vector<TypeId> expectedTypes;
if (hasAnnotation) if (hasAnnotation)
@ -286,21 +400,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack});
} }
} }
else
{
std::optional<TypeId> expectedType;
if (hasAnnotation)
expectedType = varTypes.at(i);
TypeId exprType = check(scope, value, expectedType).ty;
if (i < varTypes.size())
{
if (varTypes[i])
addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType});
else
varTypes[i] = exprType;
}
}
} }
for (size_t i = 0; i < local->vars.size; ++i) for (size_t i = 0; i < local->vars.size; ++i)
@ -569,14 +668,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement
// TODO: Optimization opportunity, the interior scope of the condition could be // TODO: Optimization opportunity, the interior scope of the condition could be
// reused for the then body, so we don't need to refine twice. // reused for the then body, so we don't need to refine twice.
ScopePtr condScope = childScope(ifStatement->condition, scope); ScopePtr condScope = childScope(ifStatement->condition, scope);
check(condScope, ifStatement->condition, std::nullopt); auto [_, connective] = check(condScope, ifStatement->condition, std::nullopt);
ScopePtr thenScope = childScope(ifStatement->thenbody, scope); ScopePtr thenScope = childScope(ifStatement->thenbody, scope);
applyRefinements(thenScope, Location{}, connective);
visit(thenScope, ifStatement->thenbody); visit(thenScope, ifStatement->thenbody);
if (ifStatement->elsebody) if (ifStatement->elsebody)
{ {
ScopePtr elseScope = childScope(ifStatement->elsebody, scope); ScopePtr elseScope = childScope(ifStatement->elsebody, scope);
applyRefinements(elseScope, Location{}, connectiveArena.negation(connective));
visit(elseScope, ifStatement->elsebody); visit(elseScope, ifStatement->elsebody);
} }
} }
@ -925,7 +1026,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
} }
} }
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType, bool forceSingleton)
{ {
RecursionCounter counter{&recursionCount}; RecursionCounter counter{&recursionCount};
@ -938,13 +1039,13 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st
Inference result; Inference result;
if (auto group = expr->as<AstExprGroup>()) if (auto group = expr->as<AstExprGroup>())
result = check(scope, group->expr, expectedType); result = check(scope, group->expr, expectedType, forceSingleton);
else if (auto stringExpr = expr->as<AstExprConstantString>()) else if (auto stringExpr = expr->as<AstExprConstantString>())
result = check(scope, stringExpr, expectedType); result = check(scope, stringExpr, expectedType, forceSingleton);
else if (expr->is<AstExprConstantNumber>()) else if (expr->is<AstExprConstantNumber>())
result = Inference{singletonTypes->numberType}; result = Inference{singletonTypes->numberType};
else if (auto boolExpr = expr->as<AstExprConstantBool>()) else if (auto boolExpr = expr->as<AstExprConstantBool>())
result = check(scope, boolExpr, expectedType); result = check(scope, boolExpr, expectedType, forceSingleton);
else if (expr->is<AstExprConstantNil>()) else if (expr->is<AstExprConstantNil>())
result = Inference{singletonTypes->nilType}; result = Inference{singletonTypes->nilType};
else if (auto local = expr->as<AstExprLocal>()) else if (auto local = expr->as<AstExprLocal>())
@ -999,8 +1100,11 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st
return result; return result;
} }
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton)
{ {
if (forceSingleton)
return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})};
if (expectedType) if (expectedType)
{ {
const TypeId expectedTy = follow(*expectedType); const TypeId expectedTy = follow(*expectedType);
@ -1020,12 +1124,15 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantSt
return Inference{singletonTypes->stringType}; return Inference{singletonTypes->stringType};
} }
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional<TypeId> expectedType) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional<TypeId> expectedType, bool forceSingleton)
{ {
const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType;
if (forceSingleton)
return Inference{singletonType};
if (expectedType) if (expectedType)
{ {
const TypeId expectedTy = follow(*expectedType); const TypeId expectedTy = follow(*expectedType);
const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType;
if (get<BlockedTypeVar>(expectedTy) || get<PendingExpansionTypeVar>(expectedTy)) if (get<BlockedTypeVar>(expectedTy) || get<PendingExpansionTypeVar>(expectedTy))
{ {
@ -1045,8 +1152,8 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local)
{ {
std::optional<TypeId> resultTy; std::optional<TypeId> resultTy;
auto def = dfg->getDef(local);
if (auto def = dfg->getDef(local)) if (def)
resultTy = scope->lookup(*def); resultTy = scope->lookup(*def);
if (!resultTy) if (!resultTy)
@ -1058,7 +1165,10 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* loc
if (!resultTy) if (!resultTy)
return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition.
return Inference{*resultTy}; if (def)
return Inference{*resultTy, connectiveArena.proposition(*def, singletonTypes->truthyType)};
else
return Inference{*resultTy};
} }
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global)
@ -1107,20 +1217,23 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr*
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary)
{ {
TypeId operandType = check(scope, unary->expr).ty; auto [operandType, connective] = check(scope, unary->expr);
TypeId resultType = arena->addType(BlockedTypeVar{}); TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType});
return Inference{resultType};
if (unary->op == AstExprUnary::Not)
return Inference{resultType, connectiveArena.negation(connective)};
else
return Inference{resultType};
} }
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType)
{ {
TypeId leftType = check(scope, binary->left, expectedType).ty; auto [leftType, rightType, connective] = checkBinary(scope, binary, expectedType);
TypeId rightType = check(scope, binary->right, expectedType).ty;
TypeId resultType = arena->addType(BlockedTypeVar{}); TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType});
return Inference{resultType}; return Inference{resultType, std::move(connective)};
} }
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType)
@ -1147,6 +1260,58 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssert
return Inference{resolveType(scope, typeAssert->annotation)}; return Inference{resolveType(scope, typeAssert->annotation)};
} }
std::tuple<TypeId, TypeId, ConnectiveId> ConstraintGraphBuilder::checkBinary(
const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType)
{
if (binary->op == AstExprBinary::And)
{
auto [leftType, leftConnective] = check(scope, binary->left, expectedType);
ScopePtr rightScope = childScope(binary->right, scope);
applyRefinements(rightScope, binary->right->location, leftConnective);
auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType);
return {leftType, rightType, connectiveArena.conjunction(leftConnective, rightConnective)};
}
else if (binary->op == AstExprBinary::Or)
{
auto [leftType, leftConnective] = check(scope, binary->left, expectedType);
ScopePtr rightScope = childScope(binary->right, scope);
applyRefinements(rightScope, binary->right->location, connectiveArena.negation(leftConnective));
auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType);
return {leftType, rightType, connectiveArena.disjunction(leftConnective, rightConnective)};
}
else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe)
{
TypeId leftType = check(scope, binary->left, expectedType, true).ty;
TypeId rightType = check(scope, binary->right, expectedType, true).ty;
ConnectiveId leftConnective = nullptr;
if (auto def = dfg->getDef(binary->left))
leftConnective = connectiveArena.proposition(*def, rightType);
ConnectiveId rightConnective = nullptr;
if (auto def = dfg->getDef(binary->right))
rightConnective = connectiveArena.proposition(*def, leftType);
if (binary->op == AstExprBinary::CompareNe)
{
leftConnective = connectiveArena.negation(leftConnective);
rightConnective = connectiveArena.negation(rightConnective);
}
return {leftType, rightType, connectiveArena.equivalence(leftConnective, rightConnective)};
}
else
{
TypeId leftType = check(scope, binary->left, expectedType).ty;
TypeId rightType = check(scope, binary->right, expectedType).ty;
return {leftType, rightType, nullptr};
}
}
TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs) TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs)
{ {
std::vector<TypeId> types; std::vector<TypeId> types;
@ -1841,9 +2006,13 @@ std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::
Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack)
{ {
auto [tp] = pack; const auto& [tp, connectives] = pack;
ConnectiveId connective = nullptr;
if (!connectives.empty())
connective = connectives[0];
if (auto f = first(tp)) if (auto f = first(tp))
return Inference{*f}; return Inference{*f, connective};
TypeId typeResult = freshType(scope); TypeId typeResult = freshType(scope);
TypePack onePack{{typeResult}, freshTypePack(scope)}; TypePack onePack{{typeResult}, freshTypePack(scope)};
@ -1851,7 +2020,7 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo
addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack});
return Inference{typeResult}; return Inference{typeResult, connective};
} }
void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err)

View file

@ -440,8 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*fcc, constraint); success = tryDispatch(*fcc, constraint);
else if (auto hpc = get<HasPropConstraint>(*constraint)) else if (auto hpc = get<HasPropConstraint>(*constraint))
success = tryDispatch(*hpc, constraint); success = tryDispatch(*hpc, constraint);
else if (auto rc = get<RefinementConstraint>(*constraint)) else if (auto sottc = get<SingletonOrTopTypeConstraint>(*constraint))
success = tryDispatch(*rc, constraint); success = tryDispatch(*sottc, constraint);
else else
LUAU_ASSERT(false); LUAU_ASSERT(false);
@ -1274,25 +1274,18 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
return true; return true;
} }
bool ConstraintSolver::tryDispatch(const RefinementConstraint& c, NotNull<const Constraint> constraint) bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint)
{ {
// TODO: Figure out exact details on when refinements need to be blocked. if (isBlocked(c.discriminantType))
// It's possible that it never needs to be, since we can just use intersection types with the discriminant type? return false;
if (!constraint->scope->parent) TypeId followed = follow(c.discriminantType);
iceReporter.ice("No parent scope");
std::optional<TypeId> previousTy = constraint->scope->parent->lookup(c.def); // `nil` is a singleton type too! There's only one value of type `nil`.
if (!previousTy) if (get<SingletonTypeVar>(followed) || isNil(followed))
iceReporter.ice("No previous type"); *asMutable(c.resultType) = NegationTypeVar{c.discriminantType};
else
std::optional<TypeId> useTy = constraint->scope->lookup(c.def); *asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType};
if (!useTy)
iceReporter.ice("The def is not bound to a type");
TypeId resultTy = follow(*useTy);
std::vector<TypeId> parts{*previousTy, c.discriminantType};
asMutable(resultTy)->ty.emplace<IntersectionTypeVar>(std::move(parts));
return true; return true;
} }

View file

@ -13,16 +13,16 @@ declare bit32: {
bor: (...number) -> number, bor: (...number) -> number,
bxor: (...number) -> number, bxor: (...number) -> number,
btest: (number, ...number) -> boolean, btest: (number, ...number) -> boolean,
rrotate: (number, number) -> number, rrotate: (x: number, disp: number) -> number,
lrotate: (number, number) -> number, lrotate: (x: number, disp: number) -> number,
lshift: (number, number) -> number, lshift: (x: number, disp: number) -> number,
arshift: (number, number) -> number, arshift: (x: number, disp: number) -> number,
rshift: (number, number) -> number, rshift: (x: number, disp: number) -> number,
bnot: (number) -> number, bnot: (x: number) -> number,
extract: (number, number, number?) -> number, extract: (n: number, field: number, width: number?) -> number,
replace: (number, number, number, number?) -> number, replace: (n: number, v: number, field: number, width: number?) -> number,
countlz: (number) -> number, countlz: (n: number) -> number,
countrz: (number) -> number, countrz: (n: number) -> number,
} }
declare math: { declare math: {
@ -93,9 +93,9 @@ type DateTypeResult = {
} }
declare os: { declare os: {
time: (DateTypeArg?) -> number, time: (time: DateTypeArg?) -> number,
date: (string?, number?) -> DateTypeResult | string, date: (formatString: string?, time: number?) -> DateTypeResult | string,
difftime: (DateTypeResult | number, DateTypeResult | number) -> number, difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number, clock: () -> number,
} }
@ -145,51 +145,51 @@ declare function loadstring<A...>(src: string, chunkname: string?): (((A...) ->
declare function newproxy(mt: boolean?): any declare function newproxy(mt: boolean?): any
declare coroutine: { declare coroutine: {
create: <A..., R...>((A...) -> R...) -> thread, create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(thread, A...) -> (boolean, R...), resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread, running: () -> thread,
status: (thread) -> "dead" | "running" | "normal" | "suspended", status: (co: thread) -> "dead" | "running" | "normal" | "suspended",
-- FIXME: This technically returns a function, but we can't represent this yet. -- FIXME: This technically returns a function, but we can't represent this yet.
wrap: <A..., R...>((A...) -> R...) -> any, wrap: <A..., R...>(f: (A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R..., yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean, isyieldable: () -> boolean,
close: (thread) -> (boolean, any) close: (co: thread) -> (boolean, any)
} }
declare table: { declare table: {
concat: <V>({V}, string?, number?, number?) -> string, concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>({V}, V) -> ()) & (<V>({V}, number, V) -> ()), insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
maxn: <V>({V}) -> number, maxn: <V>(t: {V}) -> number,
remove: <V>({V}, number?) -> V?, remove: <V>(t: {V}, number?) -> V?,
sort: <V>({V}, ((V, V) -> boolean)?) -> (), sort: <V>(t: {V}, comp: ((V, V) -> boolean)?) -> (),
create: <V>(number, V?) -> {V}, create: <V>(count: number, value: V?) -> {V},
find: <V>({V}, V, number?) -> number?, find: <V>(haystack: {V}, needle: V, init: number?) -> number?,
unpack: <V>({V}, number?, number?) -> ...V, unpack: <V>(list: {V}, i: number?, j: number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V }, pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>({V}) -> number, getn: <V>(t: {V}) -> number,
foreach: <K, V>({[K]: V}, (K, V) -> ()) -> (), foreach: <K, V>(t: {[K]: V}, f: (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (), foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>({V}, number, number, number, {V}?) -> {V}, move: <V>(src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V},
clear: <K, V>({[K]: V}) -> (), clear: <K, V>(table: {[K]: V}) -> (),
isfrozen: <K, V>({[K]: V}) -> boolean, isfrozen: <K, V>(t: {[K]: V}) -> boolean,
} }
declare debug: { declare debug: {
info: (<R...>(thread, number, string) -> R...) & (<R...>(number, string) -> R...) & (<A..., R1..., R2...>((A...) -> R1..., string) -> R2...), info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
} }
declare utf8: { declare utf8: {
char: (...number) -> string, char: (...number) -> string,
charpattern: string, charpattern: string,
codes: (string) -> ((string, number) -> (number, number), string, number), codes: (str: string) -> ((string, number) -> (number, number), string, number),
codepoint: (string, number?, number?) -> ...number, codepoint: (str: string, i: number?, j: number?) -> ...number,
len: (string, number?, number?) -> (number?, number?), len: (s: string, i: number?, j: number?) -> (number?, number?),
offset: (string, number?, number?) -> number, offset: (s: string, n: number?, i: number?) -> number,
} }
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.

View file

@ -7,9 +7,9 @@
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/RecursionCounter.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/Unifier.h" #include "Luau/Unifier.h"
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false)
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
@ -20,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false);
LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false); LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false);
LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf);
@ -206,6 +207,28 @@ bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& s
return true; return true;
} }
NormalizedFunctionType::NormalizedFunctionType()
: parts(FFlag::LuauNegatedFunctionTypes ? std::optional<TypeIds>{TypeIds{}} : std::nullopt)
{
}
void NormalizedFunctionType::resetToTop()
{
isTop = true;
parts.emplace();
}
void NormalizedFunctionType::resetToNever()
{
isTop = false;
parts.emplace();
}
bool NormalizedFunctionType::isNever() const
{
return !isTop && (!parts || parts->empty());
}
NormalizedType::NormalizedType(NotNull<SingletonTypes> singletonTypes) NormalizedType::NormalizedType(NotNull<SingletonTypes> singletonTypes)
: tops(singletonTypes->neverType) : tops(singletonTypes->neverType)
, booleans(singletonTypes->neverType) , booleans(singletonTypes->neverType)
@ -220,8 +243,8 @@ NormalizedType::NormalizedType(NotNull<SingletonTypes> singletonTypes)
static bool isInhabited(const NormalizedType& norm) static bool isInhabited(const NormalizedType& norm)
{ {
return !get<NeverTypeVar>(norm.tops) || !get<NeverTypeVar>(norm.booleans) || !norm.classes.empty() || !get<NeverTypeVar>(norm.errors) || return !get<NeverTypeVar>(norm.tops) || !get<NeverTypeVar>(norm.booleans) || !norm.classes.empty() || !get<NeverTypeVar>(norm.errors) ||
!get<NeverTypeVar>(norm.nils) || !get<NeverTypeVar>(norm.numbers) || !norm.strings.isNever() || !get<NeverTypeVar>(norm.nils) || !get<NeverTypeVar>(norm.numbers) || !norm.strings.isNever() || !get<NeverTypeVar>(norm.threads) ||
!get<NeverTypeVar>(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty(); !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty();
} }
static int tyvarIndex(TypeId ty) static int tyvarIndex(TypeId ty)
@ -317,10 +340,14 @@ static bool isNormalizedThread(TypeId ty)
static bool areNormalizedFunctions(const NormalizedFunctionType& tys) static bool areNormalizedFunctions(const NormalizedFunctionType& tys)
{ {
if (tys) if (tys.parts)
for (TypeId ty : *tys) {
for (TypeId ty : *tys.parts)
{
if (!get<FunctionTypeVar>(ty) && !get<ErrorTypeVar>(ty)) if (!get<FunctionTypeVar>(ty) && !get<ErrorTypeVar>(ty))
return false; return false;
}
}
return true; return true;
} }
@ -420,7 +447,7 @@ void Normalizer::clearNormal(NormalizedType& norm)
norm.strings.resetToNever(); norm.strings.resetToNever();
norm.threads = singletonTypes->neverType; norm.threads = singletonTypes->neverType;
norm.tables.clear(); norm.tables.clear();
norm.functions = std::nullopt; norm.functions.resetToNever();
norm.tyvars.clear(); norm.tyvars.clear();
} }
@ -809,20 +836,28 @@ std::optional<TypeId> Normalizer::unionOfFunctions(TypeId here, TypeId there)
void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres)
{ {
if (!theres) if (FFlag::LuauNegatedFunctionTypes)
{
if (heres.isTop)
return;
if (theres.isTop)
heres.resetToTop();
}
if (theres.isNever())
return; return;
TypeIds tmps; TypeIds tmps;
if (!heres) if (heres.isNever())
{ {
tmps.insert(theres->begin(), theres->end()); tmps.insert(theres.parts->begin(), theres.parts->end());
heres = std::move(tmps); heres.parts = std::move(tmps);
return; return;
} }
for (TypeId here : *heres) for (TypeId here : *heres.parts)
for (TypeId there : *theres) for (TypeId there : *theres.parts)
{ {
if (std::optional<TypeId> fun = unionOfFunctions(here, there)) if (std::optional<TypeId> fun = unionOfFunctions(here, there))
tmps.insert(*fun); tmps.insert(*fun);
@ -830,28 +865,28 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF
tmps.insert(singletonTypes->errorRecoveryType(there)); tmps.insert(singletonTypes->errorRecoveryType(there));
} }
heres = std::move(tmps); heres.parts = std::move(tmps);
} }
void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there)
{ {
if (!heres) if (heres.isNever())
{ {
TypeIds tmps; TypeIds tmps;
tmps.insert(there); tmps.insert(there);
heres = std::move(tmps); heres.parts = std::move(tmps);
return; return;
} }
TypeIds tmps; TypeIds tmps;
for (TypeId here : *heres) for (TypeId here : *heres.parts)
{ {
if (std::optional<TypeId> fun = unionOfFunctions(here, there)) if (std::optional<TypeId> fun = unionOfFunctions(here, there))
tmps.insert(*fun); tmps.insert(*fun);
else else
tmps.insert(singletonTypes->errorRecoveryType(there)); tmps.insert(singletonTypes->errorRecoveryType(there));
} }
heres = std::move(tmps); heres.parts = std::move(tmps);
} }
void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there)
@ -1004,6 +1039,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
here.strings.resetToString(); here.strings.resetToString();
else if (ptv->type == PrimitiveTypeVar::Thread) else if (ptv->type == PrimitiveTypeVar::Thread)
here.threads = there; here.threads = there;
else if (ptv->type == PrimitiveTypeVar::Function)
{
LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes);
here.functions.resetToTop();
}
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
} }
@ -1036,8 +1076,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(there)) else if (const NegationTypeVar* ntv = get<NegationTypeVar>(there))
{ {
const NormalizedType* thereNormal = normalize(ntv->ty); const NormalizedType* thereNormal = normalize(ntv->ty);
NormalizedType tn = negateNormal(*thereNormal); std::optional<NormalizedType> tn = negateNormal(*thereNormal);
if (!unionNormals(here, tn)) if (!tn)
return false;
if (!unionNormals(here, *tn))
return false; return false;
} }
else else
@ -1053,7 +1096,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
// ------- Negations // ------- Negations
NormalizedType Normalizer::negateNormal(const NormalizedType& here) std::optional<NormalizedType> Normalizer::negateNormal(const NormalizedType& here)
{ {
NormalizedType result{singletonTypes}; NormalizedType result{singletonTypes};
if (!get<NeverTypeVar>(here.tops)) if (!get<NeverTypeVar>(here.tops))
@ -1092,8 +1135,22 @@ NormalizedType Normalizer::negateNormal(const NormalizedType& here)
result.threads = get<NeverTypeVar>(here.threads) ? singletonTypes->threadType : singletonTypes->neverType; result.threads = get<NeverTypeVar>(here.threads) ? singletonTypes->threadType : singletonTypes->neverType;
/*
* Things get weird and so, so complicated if we allow negations of
* arbitrary function types. Ordinary code can never form these kinds of
* types, so we decline to negate them.
*/
if (FFlag::LuauNegatedFunctionTypes)
{
if (here.functions.isNever())
result.functions.resetToTop();
else if (here.functions.isTop)
result.functions.resetToNever();
else
return std::nullopt;
}
// TODO: negating tables // TODO: negating tables
// TODO: negating functions
// TODO: negating tyvars? // TODO: negating tyvars?
return result; return result;
@ -1142,21 +1199,25 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty)
LUAU_ASSERT(ptv); LUAU_ASSERT(ptv);
switch (ptv->type) switch (ptv->type)
{ {
case PrimitiveTypeVar::NilType: case PrimitiveTypeVar::NilType:
here.nils = singletonTypes->neverType; here.nils = singletonTypes->neverType;
break; break;
case PrimitiveTypeVar::Boolean: case PrimitiveTypeVar::Boolean:
here.booleans = singletonTypes->neverType; here.booleans = singletonTypes->neverType;
break; break;
case PrimitiveTypeVar::Number: case PrimitiveTypeVar::Number:
here.numbers = singletonTypes->neverType; here.numbers = singletonTypes->neverType;
break; break;
case PrimitiveTypeVar::String: case PrimitiveTypeVar::String:
here.strings.resetToNever(); here.strings.resetToNever();
break; break;
case PrimitiveTypeVar::Thread: case PrimitiveTypeVar::Thread:
here.threads = singletonTypes->neverType; here.threads = singletonTypes->neverType;
break; break;
case PrimitiveTypeVar::Function:
LUAU_ASSERT(FFlag::LuauNegatedStringSingletons);
here.functions.resetToNever();
break;
} }
} }
@ -1738,18 +1799,20 @@ std::optional<TypeId> Normalizer::unionSaturatedFunctions(TypeId here, TypeId th
void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there)
{ {
if (!heres) if (heres.isNever())
return; return;
for (auto it = heres->begin(); it != heres->end();) heres.isTop = false;
for (auto it = heres.parts->begin(); it != heres.parts->end();)
{ {
TypeId here = *it; TypeId here = *it;
if (get<ErrorTypeVar>(here)) if (get<ErrorTypeVar>(here))
it++; it++;
else if (std::optional<TypeId> tmp = intersectionOfFunctions(here, there)) else if (std::optional<TypeId> tmp = intersectionOfFunctions(here, there))
{ {
heres->erase(it); heres.parts->erase(it);
heres->insert(*tmp); heres.parts->insert(*tmp);
return; return;
} }
else else
@ -1757,27 +1820,27 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T
} }
TypeIds tmps; TypeIds tmps;
for (TypeId here : *heres) for (TypeId here : *heres.parts)
{ {
if (std::optional<TypeId> tmp = unionSaturatedFunctions(here, there)) if (std::optional<TypeId> tmp = unionSaturatedFunctions(here, there))
tmps.insert(*tmp); tmps.insert(*tmp);
} }
heres->insert(there); heres.parts->insert(there);
heres->insert(tmps.begin(), tmps.end()); heres.parts->insert(tmps.begin(), tmps.end());
} }
void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres)
{ {
if (!heres) if (heres.isNever())
return; return;
else if (!theres) else if (theres.isNever())
{ {
heres = std::nullopt; heres.resetToNever();
return; return;
} }
else else
{ {
for (TypeId there : *theres) for (TypeId there : *theres.parts)
intersectFunctionsWithFunction(heres, there); intersectFunctionsWithFunction(heres, there);
} }
} }
@ -1935,6 +1998,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
TypeId nils = here.nils; TypeId nils = here.nils;
TypeId numbers = here.numbers; TypeId numbers = here.numbers;
NormalizedStringType strings = std::move(here.strings); NormalizedStringType strings = std::move(here.strings);
NormalizedFunctionType functions = std::move(here.functions);
TypeId threads = here.threads; TypeId threads = here.threads;
clearNormal(here); clearNormal(here);
@ -1949,6 +2013,11 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
here.strings = std::move(strings); here.strings = std::move(strings);
else if (ptv->type == PrimitiveTypeVar::Thread) else if (ptv->type == PrimitiveTypeVar::Thread)
here.threads = threads; here.threads = threads;
else if (ptv->type == PrimitiveTypeVar::Function)
{
LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes);
here.functions = std::move(functions);
}
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
} }
@ -1981,8 +2050,10 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
for (TypeId part : itv->options) for (TypeId part : itv->options)
{ {
const NormalizedType* normalPart = normalize(part); const NormalizedType* normalPart = normalize(part);
NormalizedType negated = negateNormal(*normalPart); std::optional<NormalizedType> negated = negateNormal(*normalPart);
intersectNormals(here, negated); if (!negated)
return false;
intersectNormals(here, *negated);
} }
} }
else else
@ -2016,14 +2087,16 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
result.insert(result.end(), norm.classes.begin(), norm.classes.end()); result.insert(result.end(), norm.classes.begin(), norm.classes.end());
if (!get<NeverTypeVar>(norm.errors)) if (!get<NeverTypeVar>(norm.errors))
result.push_back(norm.errors); result.push_back(norm.errors);
if (norm.functions) if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop)
result.push_back(singletonTypes->functionType);
else if (!norm.functions.isNever())
{ {
if (norm.functions->size() == 1) if (norm.functions.parts->size() == 1)
result.push_back(*norm.functions->begin()); result.push_back(*norm.functions.parts->begin());
else else
{ {
std::vector<TypeId> parts; std::vector<TypeId> parts;
parts.insert(parts.end(), norm.functions->begin(), norm.functions->end()); parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end());
result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)}));
} }
} }
@ -2070,62 +2143,24 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
return arena->addType(UnionTypeVar{std::move(result)}); return arena->addType(UnionTypeVar{std::move(result)});
} }
namespace bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{
struct Replacer
{
TypeArena* arena;
TypeId sourceType;
TypeId replacedType;
DenseHashMap<TypeId, TypeId> newTypes;
Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType)
: arena(arena)
, sourceType(sourceType)
, replacedType(replacedType)
, newTypes(nullptr)
{
}
TypeId smartClone(TypeId t)
{
t = follow(t);
TypeId* res = newTypes.find(t);
if (res)
return *res;
TypeId result = shallowClone(t, *arena, TxnLog::empty());
newTypes[t] = result;
newTypes[result] = result;
return result;
}
};
} // anonymous namespace
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice, bool anyIsTop)
{ {
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant};
u.anyIsTop = anyIsTop;
u.tryUnify(subTy, superTy); u.tryUnify(subTy, superTy);
const bool ok = u.errors.empty() && u.log.empty(); const bool ok = u.errors.empty() && u.log.empty();
return ok; return ok;
} }
bool isSubtype( bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice, bool anyIsTop)
{ {
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant};
u.anyIsTop = anyIsTop;
u.tryUnify(subPack, superPack); u.tryUnify(subPack, superPack);
const bool ok = u.errors.empty() && u.log.empty(); const bool ok = u.errors.empty() && u.log.empty();

View file

@ -10,12 +10,12 @@
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauLvaluelessPath) LUAU_FASTFLAG(LuauLvaluelessPath)
LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false)
LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false)
LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false)
LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false)
/* /*
* Prefix generic typenames with gen- * Prefix generic typenames with gen-
@ -225,6 +225,20 @@ struct StringifierState
result.name += s; result.name += s;
} }
void emitLevel(Scope* scope)
{
size_t count = 0;
for (Scope* s = scope; s; s = s->parent.get())
++count;
emit(count);
emit("-");
char buffer[16];
uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF);
snprintf(buffer, sizeof(buffer), "0x%x", s);
emit(buffer);
}
void emit(TypeLevel level) void emit(TypeLevel level)
{ {
emit(std::to_string(level.level)); emit(std::to_string(level.level));
@ -296,10 +310,7 @@ struct TypeVarStringifier
if (tv->ty.valueless_by_exception()) if (tv->ty.valueless_by_exception())
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("* VALUELESS BY EXCEPTION *");
state.emit("* VALUELESS BY EXCEPTION *");
else
state.emit("< VALUELESS BY EXCEPTION >");
return; return;
} }
@ -377,7 +388,10 @@ struct TypeVarStringifier
if (FFlag::DebugLuauVerboseTypeNames) if (FFlag::DebugLuauVerboseTypeNames)
{ {
state.emit("-"); state.emit("-");
state.emit(ftv.level); if (FFlag::DebugLuauDeferredConstraintResolution)
state.emitLevel(ftv.scope);
else
state.emit(ftv.level);
} }
} }
@ -399,6 +413,15 @@ struct TypeVarStringifier
} }
else else
state.emit(state.getName(ty)); state.emit(state.getName(ty));
if (FFlag::DebugLuauVerboseTypeNames)
{
state.emit("-");
if (FFlag::DebugLuauDeferredConstraintResolution)
state.emitLevel(gtv.scope);
else
state.emit(gtv.level);
}
} }
void operator()(TypeId, const BlockedTypeVar& btv) void operator()(TypeId, const BlockedTypeVar& btv)
@ -434,6 +457,9 @@ struct TypeVarStringifier
case PrimitiveTypeVar::Thread: case PrimitiveTypeVar::Thread:
state.emit("thread"); state.emit("thread");
return; return;
case PrimitiveTypeVar::Function:
state.emit("function");
return;
default: default:
LUAU_ASSERT(!"Unknown primitive type"); LUAU_ASSERT(!"Unknown primitive type");
throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type)); throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type));
@ -462,10 +488,7 @@ struct TypeVarStringifier
if (state.hasSeen(&ftv)) if (state.hasSeen(&ftv))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLE*");
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return; return;
} }
@ -573,10 +596,7 @@ struct TypeVarStringifier
if (state.hasSeen(&ttv)) if (state.hasSeen(&ttv))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLE*");
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return; return;
} }
@ -710,10 +730,7 @@ struct TypeVarStringifier
if (state.hasSeen(&uv)) if (state.hasSeen(&uv))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLE*");
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return; return;
} }
@ -780,10 +797,7 @@ struct TypeVarStringifier
if (state.hasSeen(&uv)) if (state.hasSeen(&uv))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLE*");
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return; return;
} }
@ -828,10 +842,7 @@ struct TypeVarStringifier
void operator()(TypeId, const ErrorTypeVar& tv) void operator()(TypeId, const ErrorTypeVar& tv)
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
} }
void operator()(TypeId, const LazyTypeVar& ltv) void operator()(TypeId, const LazyTypeVar& ltv)
@ -850,11 +861,6 @@ struct TypeVarStringifier
state.emit("never"); state.emit("never");
} }
void operator()(TypeId ty, const UseTypeVar&)
{
stringify(follow(ty));
}
void operator()(TypeId, const NegationTypeVar& ntv) void operator()(TypeId, const NegationTypeVar& ntv)
{ {
state.emit("~"); state.emit("~");
@ -907,10 +913,7 @@ struct TypePackStringifier
if (tp->ty.valueless_by_exception()) if (tp->ty.valueless_by_exception())
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("* VALUELESS TP BY EXCEPTION *");
state.emit("* VALUELESS TP BY EXCEPTION *");
else
state.emit("< VALUELESS TP BY EXCEPTION >");
return; return;
} }
@ -934,10 +937,7 @@ struct TypePackStringifier
if (state.hasSeen(&tp)) if (state.hasSeen(&tp))
{ {
state.result.cycle = true; state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*CYCLETP*");
state.emit("*CYCLETP*");
else
state.emit("<CYCLETP>");
return; return;
} }
@ -982,10 +982,7 @@ struct TypePackStringifier
void operator()(TypePackId, const Unifiable::Error& error) void operator()(TypePackId, const Unifiable::Error& error)
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked) state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
} }
void operator()(TypePackId, const VariadicTypePack& pack) void operator()(TypePackId, const VariadicTypePack& pack)
@ -993,10 +990,7 @@ struct TypePackStringifier
state.emit("..."); state.emit("...");
if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) if (FFlag::DebugLuauVerboseTypeNames && pack.hidden)
{ {
if (FFlag::LuauSpecialTypesAsterisked) state.emit("*hidden*");
state.emit("*hidden*");
else
state.emit("<hidden>");
} }
stringify(pack.ty); stringify(pack.ty);
} }
@ -1031,7 +1025,10 @@ struct TypePackStringifier
if (FFlag::DebugLuauVerboseTypeNames) if (FFlag::DebugLuauVerboseTypeNames)
{ {
state.emit("-"); state.emit("-");
state.emit(pack.level); if (FFlag::DebugLuauDeferredConstraintResolution)
state.emitLevel(pack.scope);
else
state.emit(pack.level);
} }
state.emit("..."); state.emit("...");
@ -1204,10 +1201,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
{ {
result.truncated = true; result.truncated = true;
if (FFlag::LuauSpecialTypesAsterisked) result.name += "... *TRUNCATED*";
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
} }
return result; return result;
@ -1280,10 +1274,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
{ {
if (FFlag::LuauSpecialTypesAsterisked) result.name += "... *TRUNCATED*";
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
} }
return result; return result;
@ -1526,9 +1517,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
{ {
return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\"";
} }
else if constexpr (std::is_same_v<T, RefinementConstraint>) else if constexpr (std::is_same_v<T, SingletonOrTopTypeConstraint>)
{ {
return "TODO"; std::string result = tos(c.resultType, opts);
std::string discriminant = tos(c.discriminantType, opts);
return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant;
} }
else else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch"); static_assert(always_false_v<T>, "Non-exhaustive constraint switch");

View file

@ -338,12 +338,6 @@ public:
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"}); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"});
} }
AstType* operator()(const UseTypeVar& utv)
{
std::optional<TypeId> ty = utv.scope->lookup(utv.def);
LUAU_ASSERT(ty);
return Luau::visit(*this, (*ty)->ty);
}
AstType* operator()(const NegationTypeVar& ntv) AstType* operator()(const NegationTypeVar& ntv)
{ {
// FIXME: do the same thing we do with ErrorTypeVar // FIXME: do the same thing we do with ErrorTypeVar

View file

@ -301,7 +301,6 @@ struct TypeChecker2
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant};
u.anyIsTop = true;
u.tryUnify(actualRetType, expectedRetType); u.tryUnify(actualRetType, expectedRetType);
const bool ok = u.errors.empty() && u.log.empty(); const bool ok = u.errors.empty() && u.log.empty();
@ -331,16 +330,21 @@ struct TypeChecker2
if (value) if (value)
visit(value); visit(value);
if (i != local->values.size - 1) TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr;
if (i != local->values.size - 1 || maybeValueType)
{ {
AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr;
if (var && var->annotation) if (var && var->annotation)
{ {
TypeId varType = lookupAnnotation(var->annotation); TypeId annotationType = lookupAnnotation(var->annotation);
TypeId valueType = value ? lookupType(value) : nullptr; TypeId valueType = value ? lookupType(value) : nullptr;
if (valueType && !isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (valueType)
reportError(TypeMismatch{varType, valueType}, value->location); {
ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType);
if (!errors.empty())
reportErrors(std::move(errors));
}
} }
} }
else else
@ -606,7 +610,7 @@ struct TypeChecker2
visit(rhs); visit(rhs);
TypeId rhsType = lookupType(rhs); TypeId rhsType = lookupType(rhs);
if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{lhsType, rhsType}, rhs->location); reportError(TypeMismatch{lhsType, rhsType}, rhs->location);
} }
@ -757,7 +761,7 @@ struct TypeChecker2
TypeId actualType = lookupType(number); TypeId actualType = lookupType(number);
TypeId numberType = singletonTypes->numberType; TypeId numberType = singletonTypes->numberType;
if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{actualType, numberType}, number->location); reportError(TypeMismatch{actualType, numberType}, number->location);
} }
@ -768,7 +772,7 @@ struct TypeChecker2
TypeId actualType = lookupType(string); TypeId actualType = lookupType(string);
TypeId stringType = singletonTypes->stringType; TypeId stringType = singletonTypes->stringType;
if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{actualType, stringType}, string->location); reportError(TypeMismatch{actualType, stringType}, string->location);
} }
@ -857,7 +861,7 @@ struct TypeChecker2
FunctionTypeVar ftv{argsTp, expectedRetType}; FunctionTypeVar ftv{argsTp, expectedRetType};
TypeId expectedType = arena.addType(ftv); TypeId expectedType = arena.addType(ftv);
if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice))
{ {
CloneState cloneState; CloneState cloneState;
expectedType = clone(expectedType, module->internalTypes, cloneState); expectedType = clone(expectedType, module->internalTypes, cloneState);
@ -876,7 +880,7 @@ struct TypeChecker2
getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true);
if (ty) if (ty)
{ {
if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{resultType, *ty}, indexName->location); reportError(TypeMismatch{resultType, *ty}, indexName->location);
} }
@ -909,7 +913,7 @@ struct TypeChecker2
TypeId inferredArgTy = *argIt; TypeId inferredArgTy = *argIt;
TypeId annotatedArgTy = lookupAnnotation(arg->annotation); TypeId annotatedArgTy = lookupAnnotation(arg->annotation);
if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice))
{ {
reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location);
} }
@ -1203,10 +1207,10 @@ struct TypeChecker2
TypeId computedType = lookupType(expr->expr); TypeId computedType = lookupType(expr->expr);
// Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case.
if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice))
return; return;
if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice))
return; return;
reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); reportError(TypesAreUnrelated{computedType, annotationType}, expr->location);
@ -1507,7 +1511,6 @@ struct TypeChecker2
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant};
u.anyIsTop = true;
u.tryUnify(subTy, superTy); u.tryUnify(subTy, superTy);
return std::move(u.errors); return std::move(u.errors);

View file

@ -57,13 +57,6 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
return btv->boundTo; return btv->boundTo;
else if (auto ttv = get<TableTypeVar>(mapper(ty))) else if (auto ttv = get<TableTypeVar>(mapper(ty)))
return ttv->boundTo; return ttv->boundTo;
else if (auto utv = get<UseTypeVar>(mapper(ty)))
{
std::optional<TypeId> ty = utv->scope->lookup(utv->def);
if (!ty)
throwRuntimeError("UseTypeVar must map to another TypeId");
return *ty;
}
else else
return std::nullopt; return std::nullopt;
}; };
@ -761,6 +754,7 @@ SingletonTypes::SingletonTypes()
, stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true})) , stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}))
, booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true})) , booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}))
, threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true})) , threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}))
, functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true}))
, trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true})) , trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}))
, falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true})) , falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}))
, anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true})) , anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true}))
@ -946,7 +940,8 @@ void persist(TypeId ty)
queue.push_back(mtv->table); queue.push_back(mtv->table);
queue.push_back(mtv->metatable); queue.push_back(mtv->metatable);
} }
else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t) || get<NegationTypeVar>(t)) else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t) ||
get<NegationTypeVar>(t))
{ {
} }
else else

View file

@ -8,6 +8,7 @@
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/TimeTrace.h" #include "Luau/TimeTrace.h"
#include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h" #include "Luau/VisitTypeVar.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
@ -23,6 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false);
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauNegatedFunctionTypes)
namespace Luau namespace Luau
{ {
@ -363,7 +365,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{ {
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
return; return;
} }
@ -404,7 +406,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) if (subGeneric && !subsumes(useScopes, subGeneric, superFree))
{ {
// TODO: a more informative error message? CLI-39912 // TODO: a more informative error message? CLI-39912
reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); reportError(location, GenericError{"Generic subtype escaping scope"});
return; return;
} }
@ -433,7 +435,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) if (superGeneric && !subsumes(useScopes, superGeneric, subFree))
{ {
// TODO: a more informative error message? CLI-39912 // TODO: a more informative error message? CLI-39912
reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); reportError(location, GenericError{"Generic supertype escaping scope"});
return; return;
} }
@ -450,15 +452,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
return tryUnifyWithAny(subTy, superTy); return tryUnifyWithAny(subTy, superTy);
if (get<AnyTypeVar>(subTy)) if (get<AnyTypeVar>(subTy))
{ return tryUnifyWithAny(superTy, subTy);
if (anyIsTop)
{
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
return;
}
else
return tryUnifyWithAny(superTy, subTy);
}
if (log.get<ErrorTypeVar>(subTy)) if (log.get<ErrorTypeVar>(subTy))
return tryUnifyWithAny(superTy, subTy); return tryUnifyWithAny(superTy, subTy);
@ -478,7 +472,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) if (auto error = sharedState.cachedUnifyError.find({subTy, superTy}))
{ {
reportError(TypeError{location, *error}); reportError(location, *error);
return; return;
} }
} }
@ -520,6 +514,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
else if ((log.getMutable<PrimitiveTypeVar>(superTy) || log.getMutable<SingletonTypeVar>(superTy)) && log.getMutable<SingletonTypeVar>(subTy)) else if ((log.getMutable<PrimitiveTypeVar>(superTy) || log.getMutable<SingletonTypeVar>(superTy)) && log.getMutable<SingletonTypeVar>(subTy))
tryUnifySingletons(subTy, superTy); tryUnifySingletons(subTy, superTy);
else if (auto ptv = get<PrimitiveTypeVar>(superTy);
FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get<FunctionTypeVar>(subTy))
{
// Ok. Do nothing. forall functions F, F <: function
}
else if (log.getMutable<FunctionTypeVar>(superTy) && log.getMutable<FunctionTypeVar>(subTy)) else if (log.getMutable<FunctionTypeVar>(superTy) && log.getMutable<FunctionTypeVar>(subTy))
tryUnifyFunctions(subTy, superTy, isFunctionCall); tryUnifyFunctions(subTy, superTy, isFunctionCall);
@ -559,7 +559,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
tryUnifyNegationWithType(subTy, superTy); tryUnifyNegationWithType(subTy, superTy);
else else
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
if (cacheEnabled) if (cacheEnabled)
cacheResult(subTy, superTy, errorCount); cacheResult(subTy, superTy, errorCount);
@ -633,9 +633,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion,
else if (failed) else if (failed)
{ {
if (firstFailedOption) if (firstFailedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption});
else else
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
} }
} }
@ -734,7 +734,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* subNorm = normalizer->normalize(subTy);
const NormalizedType* superNorm = normalizer->normalize(superTy); const NormalizedType* superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm) if (!subNorm || !superNorm)
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else else
@ -743,9 +743,9 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
else if (!found) else if (!found)
{ {
if ((failedOptionCount == 1 || foundHeuristic) && failedOption) if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); reportError(location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption});
else else
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible"});
} }
} }
@ -774,7 +774,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I
if (unificationTooComplex) if (unificationTooComplex)
reportError(*unificationTooComplex); reportError(*unificationTooComplex);
else if (firstFailedOption) else if (firstFailedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption});
} }
void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall)
@ -832,11 +832,11 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV
if (subNorm && superNorm) if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else else
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
} }
else if (!found) else if (!found)
{ {
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"});
} }
} }
@ -848,37 +848,37 @@ void Unifier::tryUnifyNormalizedTypes(
if (get<UnknownTypeVar>(superNorm.tops) || get<AnyTypeVar>(superNorm.tops) || get<AnyTypeVar>(subNorm.tops)) if (get<UnknownTypeVar>(superNorm.tops) || get<AnyTypeVar>(superNorm.tops) || get<AnyTypeVar>(subNorm.tops))
return; return;
else if (get<UnknownTypeVar>(subNorm.tops)) else if (get<UnknownTypeVar>(subNorm.tops))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<ErrorTypeVar>(subNorm.errors)) if (get<ErrorTypeVar>(subNorm.errors))
if (!get<ErrorTypeVar>(superNorm.errors)) if (!get<ErrorTypeVar>(superNorm.errors))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.booleans)) if (get<PrimitiveTypeVar>(subNorm.booleans))
{ {
if (!get<PrimitiveTypeVar>(superNorm.booleans)) if (!get<PrimitiveTypeVar>(superNorm.booleans))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(subNorm.booleans)) else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(subNorm.booleans))
{ {
if (!get<PrimitiveTypeVar>(superNorm.booleans) && stv != get<SingletonTypeVar>(superNorm.booleans)) if (!get<PrimitiveTypeVar>(superNorm.booleans) && stv != get<SingletonTypeVar>(superNorm.booleans))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
if (get<PrimitiveTypeVar>(subNorm.nils)) if (get<PrimitiveTypeVar>(subNorm.nils))
if (!get<PrimitiveTypeVar>(superNorm.nils)) if (!get<PrimitiveTypeVar>(superNorm.nils))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.numbers)) if (get<PrimitiveTypeVar>(subNorm.numbers))
if (!get<PrimitiveTypeVar>(superNorm.numbers)) if (!get<PrimitiveTypeVar>(superNorm.numbers))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (!isSubtype(subNorm.strings, superNorm.strings)) if (!isSubtype(subNorm.strings, superNorm.strings))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.threads)) if (get<PrimitiveTypeVar>(subNorm.threads))
if (!get<PrimitiveTypeVar>(superNorm.errors)) if (!get<PrimitiveTypeVar>(superNorm.errors))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
for (TypeId subClass : subNorm.classes) for (TypeId subClass : subNorm.classes)
{ {
@ -894,7 +894,7 @@ void Unifier::tryUnifyNormalizedTypes(
} }
} }
if (!found) if (!found)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
for (TypeId subTable : subNorm.tables) for (TypeId subTable : subNorm.tables)
@ -919,21 +919,19 @@ void Unifier::tryUnifyNormalizedTypes(
return reportError(*e); return reportError(*e);
} }
if (!found) if (!found)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
if (subNorm.functions) if (!subNorm.functions.isNever())
{ {
if (!superNorm.functions) if (superNorm.functions.isNever())
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (superNorm.functions->empty()) for (TypeId superFun : *superNorm.functions.parts)
return;
for (TypeId superFun : *superNorm.functions)
{ {
Unifier innerState = makeChildUnifier(); Unifier innerState = makeChildUnifier();
const FunctionTypeVar* superFtv = get<FunctionTypeVar>(superFun); const FunctionTypeVar* superFtv = get<FunctionTypeVar>(superFun);
if (!superFtv) if (!superFtv)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes);
innerState.tryUnify_(tgt, superFtv->retTypes); innerState.tryUnify_(tgt, superFtv->retTypes);
if (innerState.errors.empty()) if (innerState.errors.empty())
@ -941,7 +939,7 @@ void Unifier::tryUnifyNormalizedTypes(
else if (auto e = hasUnificationTooComplex(innerState.errors)) else if (auto e = hasUnificationTooComplex(innerState.errors))
return reportError(*e); return reportError(*e);
else else
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(location, TypeMismatch{superTy, subTy, reason, error});
} }
} }
@ -959,15 +957,15 @@ void Unifier::tryUnifyNormalizedTypes(
TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args) TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args)
{ {
if (!overloads || overloads->empty()) if (overloads.isNever())
{ {
reportError(TypeError{location, CannotCallNonFunction{function}}); reportError(location, CannotCallNonFunction{function});
return singletonTypes->errorRecoveryTypePack(); return singletonTypes->errorRecoveryTypePack();
} }
std::optional<TypePackId> result; std::optional<TypePackId> result;
const FunctionTypeVar* firstFun = nullptr; const FunctionTypeVar* firstFun = nullptr;
for (TypeId overload : *overloads) for (TypeId overload : *overloads.parts)
{ {
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(overload)) if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(overload))
{ {
@ -1015,12 +1013,12 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized
// TODO: better error reporting? // TODO: better error reporting?
// The logic for error reporting overload resolution // The logic for error reporting overload resolution
// is currently over in TypeInfer.cpp, should we move it? // is currently over in TypeInfer.cpp, should we move it?
reportError(TypeError{location, GenericError{"No matching overload."}}); reportError(location, GenericError{"No matching overload."});
return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); return singletonTypes->errorRecoveryTypePack(firstFun->retTypes);
} }
else else
{ {
reportError(TypeError{location, CannotCallNonFunction{function}}); reportError(location, CannotCallNonFunction{function});
return singletonTypes->errorRecoveryTypePack(); return singletonTypes->errorRecoveryTypePack();
} }
} }
@ -1199,7 +1197,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{ {
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
return; return;
} }
@ -1372,7 +1370,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
size_t actualSize = size(subTp); size_t actualSize = size(subTp);
if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult) if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult)
std::swap(expectedSize, actualSize); std::swap(expectedSize, actualSize);
reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}}); reportError(location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx});
while (superIter.good()) while (superIter.good())
{ {
@ -1394,9 +1392,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
else else
{ {
if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure) if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure)
reportError(TypeError{location, TypePackMismatch{subTp, superTp}}); reportError(location, TypePackMismatch{subTp, superTp});
else else
reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); reportError(location, GenericError{"Failed to unify type packs"});
} }
} }
@ -1408,7 +1406,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy)
ice("passed non primitive types to unifyPrimitives"); ice("passed non primitive types to unifyPrimitives");
if (superPrim->type != subPrim->type) if (superPrim->type != subPrim->type)
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
} }
void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
@ -1429,7 +1427,7 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
if (superPrim && superPrim->type == PrimitiveTypeVar::String && get<StringSingleton>(subSingleton) && variance == Covariant) if (superPrim && superPrim->type == PrimitiveTypeVar::String && get<StringSingleton>(subSingleton) && variance == Covariant)
return; return;
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
} }
void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall)
@ -1465,21 +1463,21 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
} }
else else
{ {
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
} }
} }
else if (numGenerics != subFunction->generics.size()) else if (numGenerics != subFunction->generics.size())
{ {
numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size());
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters"});
} }
if (numGenericPacks != subFunction->genericPacks.size()) if (numGenericPacks != subFunction->genericPacks.size())
{ {
numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size());
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"});
} }
for (size_t i = 0; i < numGenerics; i++) for (size_t i = 0; i < numGenerics; i++)
@ -1506,11 +1504,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty() && innerState.firstPackErrorPos) else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError( reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), innerState.errors.front()});
innerState.errors.front()}});
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()});
innerState.ctx = CountMismatch::FunctionResult; innerState.ctx = CountMismatch::FunctionResult;
innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes);
@ -1520,13 +1517,12 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes))
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()});
else if (!innerState.errors.empty() && innerState.firstPackErrorPos) else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError( reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), innerState.errors.front()});
innerState.errors.front()}});
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()});
} }
log.concat(std::move(innerState.log)); log.concat(std::move(innerState.log));
@ -1608,7 +1604,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
} }
else else
{ {
reportError(TypeError{location, UnificationTooComplex{}}); reportError(location, UnificationTooComplex{});
} }
} }
} }
@ -1626,7 +1622,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty()) if (!missingProperties.empty())
{ {
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)});
return; return;
} }
} }
@ -1644,7 +1640,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!extraProperties.empty()) if (!extraProperties.empty())
{ {
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra});
return; return;
} }
} }
@ -1825,13 +1821,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty()) if (!missingProperties.empty())
{ {
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)});
return; return;
} }
if (!extraProperties.empty()) if (!extraProperties.empty())
{ {
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra});
return; return;
} }
@ -1867,14 +1863,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
std::swap(subTy, superTy); std::swap(subTy, superTy);
if (auto ttv = log.get<TableTypeVar>(superTy); !ttv || ttv->state != TableState::Free) if (auto ttv = log.get<TableTypeVar>(superTy); !ttv || ttv->state != TableState::Free)
return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); return reportError(location, TypeMismatch{osuperTy, osubTy});
auto fail = [&](std::optional<TypeError> e) { auto fail = [&](std::optional<TypeError> e) {
std::string reason = "The former's metatable does not satisfy the requirements."; std::string reason = "The former's metatable does not satisfy the requirements.";
if (e) if (e)
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}}); reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e});
else else
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}}); reportError(location, TypeMismatch{osuperTy, osubTy, reason});
}; };
// Given t1 where t1 = { lower: (t1) -> (a, b...) } // Given t1 where t1 = { lower: (t1) -> (a, b...) }
@ -1906,7 +1902,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
} }
} }
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); reportError(location, TypeMismatch{osuperTy, osubTy});
return; return;
} }
@ -1947,7 +1943,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); reportError(location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()});
log.concat(std::move(innerState.log)); log.concat(std::move(innerState.log));
} }
@ -2024,9 +2020,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
auto fail = [&]() { auto fail = [&]() {
if (!reversed) if (!reversed)
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
else else
reportError(TypeError{location, TypeMismatch{subTy, superTy}}); reportError(location, TypeMismatch{subTy, superTy});
}; };
const ClassTypeVar* superClass = get<ClassTypeVar>(superTy); const ClassTypeVar* superClass = get<ClassTypeVar>(superTy);
@ -2071,7 +2067,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
if (!classProp) if (!classProp)
{ {
ok = false; ok = false;
reportError(TypeError{location, UnknownProperty{superTy, propName}}); reportError(location, UnknownProperty{superTy, propName});
} }
else else
{ {
@ -2095,7 +2091,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
{ {
ok = false; ok = false;
std::string msg = "Class " + superClass->name + " does not have an indexer"; std::string msg = "Class " + superClass->name + " does not have an indexer";
reportError(TypeError{location, GenericError{msg}}); reportError(location, GenericError{msg});
} }
if (!ok) if (!ok)
@ -2116,13 +2112,13 @@ void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy)
const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* subNorm = normalizer->normalize(subTy);
const NormalizedType* superNorm = normalizer->normalize(superTy); const NormalizedType* superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm) if (!subNorm || !superNorm)
return reportError(TypeError{location, UnificationTooComplex{}}); return reportError(location, UnificationTooComplex{});
// T </: ~U iff T <: U // T </: ~U iff T <: U
Unifier state = makeChildUnifier(); Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, ""); state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty()) if (state.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
} }
void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy) void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy)
@ -2132,7 +2128,7 @@ void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy)
ice("tryUnifyNegationWithType subTy must be a negation type"); ice("tryUnifyNegationWithType subTy must be a negation type");
// TODO: ~T </: U iff T <: U // TODO: ~T </: U iff T <: U
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(location, TypeMismatch{superTy, subTy});
} }
static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)
@ -2195,7 +2191,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
} }
else if (get<Unifiable::Generic>(tail)) else if (get<Unifiable::Generic>(tail))
{ {
reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); reportError(location, GenericError{"Cannot unify variadic and generic packs"});
} }
else if (get<Unifiable::Error>(tail)) else if (get<Unifiable::Error>(tail))
{ {
@ -2209,7 +2205,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
} }
else else
{ {
reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}}); reportError(location, GenericError{"Failed to unify variadic packs"});
} }
} }
@ -2351,7 +2347,7 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
if (needle == haystack) if (needle == haystack)
{ {
reportError(TypeError{location, OccursCheckFailed{}}); reportError(location, OccursCheckFailed{});
log.replace(needle, *singletonTypes->errorRecoveryType()); log.replace(needle, *singletonTypes->errorRecoveryType());
return true; return true;
@ -2402,7 +2398,7 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
{ {
if (needle == haystack) if (needle == haystack)
{ {
reportError(TypeError{location, OccursCheckFailed{}}); reportError(location, OccursCheckFailed{});
log.replace(needle, *singletonTypes->errorRecoveryTypePack()); log.replace(needle, *singletonTypes->errorRecoveryTypePack());
return true; return true;
@ -2423,18 +2419,31 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
Unifier Unifier::makeChildUnifier() Unifier Unifier::makeChildUnifier()
{ {
Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; Unifier u = Unifier{normalizer, mode, scope, location, variance, &log};
u.anyIsTop = anyIsTop;
u.normalize = normalize; u.normalize = normalize;
u.useScopes = useScopes;
return u; return u;
} }
// A utility function that appends the given error to the unifier's error log. // A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error. // This allows setting a breakpoint wherever the unifier reports an error.
//
// Note: report error accepts its arguments by value intentionally to reduce the stack usage of functions which call `reportError`.
void Unifier::reportError(Location location, TypeErrorData data)
{
errors.emplace_back(std::move(location), std::move(data));
}
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
//
// Note: to conserve stack space in calling functions it is generally preferred to call `Unifier::reportError(Location location, TypeErrorData data)`
// instead of this method.
void Unifier::reportError(TypeError err) void Unifier::reportError(TypeError err)
{ {
errors.push_back(std::move(err)); errors.push_back(std::move(err));
} }
bool Unifier::isNonstrictMode() const bool Unifier::isNonstrictMode() const
{ {
return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck);
@ -2445,7 +2454,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId
if (auto e = hasUnificationTooComplex(innerErrors)) if (auto e = hasUnificationTooComplex(innerErrors))
reportError(*e); reportError(*e);
else if (!innerErrors.empty()) else if (!innerErrors.empty())
reportError(TypeError{location, TypeMismatch{wantedType, givenType}}); reportError(location, TypeMismatch{wantedType, givenType});
} }
void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType)

View file

@ -25,6 +25,7 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false)
LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false)
LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false)
bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_bin_integer = false;
bool lua_telemetry_parsed_out_of_range_hex_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false;
@ -2310,9 +2311,13 @@ AstExpr* Parser::parseTableConstructor()
MatchLexeme matchBrace = lexer.current(); MatchLexeme matchBrace = lexer.current();
expectAndConsume('{', "table literal"); expectAndConsume('{', "table literal");
unsigned lastElementIndent = 0;
while (lexer.current().type != '}') while (lexer.current().type != '}')
{ {
if (FFlag::LuauTableConstructorRecovery)
lastElementIndent = lexer.current().location.begin.column;
if (lexer.current().type == '[') if (lexer.current().type == '[')
{ {
MatchLexeme matchLocationBracket = lexer.current(); MatchLexeme matchLocationBracket = lexer.current();
@ -2357,10 +2362,14 @@ AstExpr* Parser::parseTableConstructor()
{ {
nextLexeme(); nextLexeme();
} }
else else if (FFlag::LuauTableConstructorRecovery && (lexer.current().type == '[' || lexer.current().type == Lexeme::Name) &&
lexer.current().location.begin.column == lastElementIndent)
{ {
if (lexer.current().type != '}') report(lexer.current().location, "Expected ',' after table constructor element");
break; }
else if (lexer.current().type != '}')
{
break;
} }
} }

View file

@ -978,7 +978,8 @@ int replMain(int argc, char** argv)
if (compileFormat == CompileFormat::Null) if (compileFormat == CompileFormat::Null)
printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024));
else if (compileFormat == CompileFormat::CodegenNull) else if (compileFormat == CompileFormat::CodegenNull)
printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024)); printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024),
int(stats.codegen / 1024));
return failed ? 1 : 0; return failed ? 1 : 0;
} }

View file

@ -0,0 +1,52 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/RegisterA64.h"
namespace Luau
{
namespace CodeGen
{
enum class AddressKindA64 : uint8_t
{
imm, // reg + imm
reg, // reg + reg
// TODO:
// reg + reg << shift
// reg + sext(reg) << shift
// reg + uext(reg) << shift
// pc + offset
};
struct AddressA64
{
AddressA64(RegisterA64 base, int off = 0)
: kind(AddressKindA64::imm)
, base(base)
, offset(xzr)
, data(off)
{
LUAU_ASSERT(base.kind == KindA64::x);
LUAU_ASSERT(off >= 0 && off < 4096);
}
AddressA64(RegisterA64 base, RegisterA64 offset)
: kind(AddressKindA64::reg)
, base(base)
, offset(offset)
, data(0)
{
LUAU_ASSERT(base.kind == KindA64::x);
LUAU_ASSERT(offset.kind == KindA64::x);
}
AddressKindA64 kind;
RegisterA64 base;
RegisterA64 offset;
int data;
};
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,144 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/RegisterA64.h"
#include "Luau/AddressA64.h"
#include "Luau/ConditionA64.h"
#include "Luau/Label.h"
#include <string>
#include <vector>
namespace Luau
{
namespace CodeGen
{
class AssemblyBuilderA64
{
public:
explicit AssemblyBuilderA64(bool logText);
~AssemblyBuilderA64();
// Moves
void mov(RegisterA64 dst, RegisterA64 src);
void mov(RegisterA64 dst, uint16_t src, int shift = 0);
void movk(RegisterA64 dst, uint16_t src, int shift = 0);
// Arithmetics
void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
void add(RegisterA64 dst, RegisterA64 src1, int src2);
void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
void sub(RegisterA64 dst, RegisterA64 src1, int src2);
void neg(RegisterA64 dst, RegisterA64 src);
// Comparisons
// Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm
// TODO: add cmp
// Binary
// Note: shifted-register support and bitfield operations are omitted for simplicity
// TODO: support immediate arguments (they have odd encoding and forbid many values)
// TODO: support not variants for and/or/eor (required to support not...)
void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void clz(RegisterA64 dst, RegisterA64 src);
void rbit(RegisterA64 dst, RegisterA64 src);
// Load
// Note: paired loads are currently omitted for simplicity
void ldr(RegisterA64 dst, AddressA64 src);
void ldrb(RegisterA64 dst, AddressA64 src);
void ldrh(RegisterA64 dst, AddressA64 src);
void ldrsb(RegisterA64 dst, AddressA64 src);
void ldrsh(RegisterA64 dst, AddressA64 src);
void ldrsw(RegisterA64 dst, AddressA64 src);
// Store
void str(RegisterA64 src, AddressA64 dst);
void strb(RegisterA64 src, AddressA64 dst);
void strh(RegisterA64 src, AddressA64 dst);
// Control flow
// Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks
void b(ConditionA64 cond, Label& label);
void cbz(RegisterA64 src, Label& label);
void cbnz(RegisterA64 src, Label& label);
void ret();
// Run final checks
bool finalize();
// Places a label at current location and returns it
Label setLabel();
// Assigns label position to the current location
void setLabel(Label& label);
void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3);
uint32_t getCodeSize() const;
// Resulting data and code that need to be copied over one after the other
// The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code'
std::vector<uint8_t> data;
std::vector<uint32_t> code;
std::string text;
const bool logText = false;
private:
// Instruction archetypes
void place0(const char* name, uint32_t word);
void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0);
void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op);
void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2);
void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op);
void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op);
void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0);
void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size);
void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond);
void placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond);
void place(uint32_t word);
void placeLabel(Label& label);
void commit();
LUAU_NOINLINE void extend();
// Data
size_t allocateData(size_t size, size_t align);
// Logging of assembly in text form
LUAU_NOINLINE void log(const char* opcode);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label);
LUAU_NOINLINE void log(const char* opcode, Label label);
LUAU_NOINLINE void log(Label label);
LUAU_NOINLINE void log(RegisterA64 reg);
LUAU_NOINLINE void log(AddressA64 addr);
uint32_t nextLabel = 1;
std::vector<Label> pendingLabels;
std::vector<uint32_t> labelLocations;
bool finalized = false;
size_t dataPos = 0;
uint32_t* codePos = nullptr;
uint32_t* codeEnd = nullptr;
};
} // namespace CodeGen
} // namespace Luau

View file

@ -2,8 +2,8 @@
#pragma once #pragma once
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Condition.h"
#include "Luau/Label.h" #include "Luau/Label.h"
#include "Luau/ConditionX64.h"
#include "Luau/OperandX64.h" #include "Luau/OperandX64.h"
#include "Luau/RegisterX64.h" #include "Luau/RegisterX64.h"
@ -84,7 +84,7 @@ public:
void ret(); void ret();
// Control flow // Control flow
void jcc(Condition cond, Label& label); void jcc(ConditionX64 cond, Label& label);
void jmp(Label& label); void jmp(Label& label);
void jmp(OperandX64 op); void jmp(OperandX64 op);

View file

@ -0,0 +1,37 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
namespace Luau
{
namespace CodeGen
{
enum class ConditionA64
{
Equal,
NotEqual,
CarrySet,
CarryClear,
Minus,
Plus,
Overflow,
NoOverflow,
UnsignedGreater,
UnsignedLessEqual,
GreaterEqual,
Less,
Greater,
LessEqual,
Always,
Count
};
} // namespace CodeGen
} // namespace Luau

View file

@ -6,7 +6,7 @@ namespace Luau
namespace CodeGen namespace CodeGen
{ {
enum class Condition enum class ConditionX64 : uint8_t
{ {
Overflow, Overflow,
NoOverflow, NoOverflow,

View file

@ -61,7 +61,7 @@ struct OperandX64
constexpr OperandX64 operator[](OperandX64&& addr) const constexpr OperandX64 operator[](OperandX64&& addr) const
{ {
LUAU_ASSERT(cat == CategoryX64::mem); LUAU_ASSERT(cat == CategoryX64::mem);
LUAU_ASSERT(memSize != SizeX64::none && index == noreg && scale == 1 && base == noreg && imm == 0); LUAU_ASSERT(index == noreg && scale == 1 && base == noreg && imm == 0);
LUAU_ASSERT(addr.memSize == SizeX64::none); LUAU_ASSERT(addr.memSize == SizeX64::none);
addr.cat = CategoryX64::mem; addr.cat = CategoryX64::mem;
@ -70,13 +70,13 @@ struct OperandX64
} }
}; };
constexpr OperandX64 addr{SizeX64::none, noreg, 1, noreg, 0};
constexpr OperandX64 byte{SizeX64::byte, noreg, 1, noreg, 0}; constexpr OperandX64 byte{SizeX64::byte, noreg, 1, noreg, 0};
constexpr OperandX64 word{SizeX64::word, noreg, 1, noreg, 0}; constexpr OperandX64 word{SizeX64::word, noreg, 1, noreg, 0};
constexpr OperandX64 dword{SizeX64::dword, noreg, 1, noreg, 0}; constexpr OperandX64 dword{SizeX64::dword, noreg, 1, noreg, 0};
constexpr OperandX64 qword{SizeX64::qword, noreg, 1, noreg, 0}; constexpr OperandX64 qword{SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 xmmword{SizeX64::xmmword, noreg, 1, noreg, 0}; constexpr OperandX64 xmmword{SizeX64::xmmword, noreg, 1, noreg, 0};
constexpr OperandX64 ymmword{SizeX64::ymmword, noreg, 1, noreg, 0}; constexpr OperandX64 ymmword{SizeX64::ymmword, noreg, 1, noreg, 0};
constexpr OperandX64 ptr{sizeof(void*) == 4 ? SizeX64::dword : SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 operator*(RegisterX64 reg, uint8_t scale) constexpr OperandX64 operator*(RegisterX64 reg, uint8_t scale)
{ {

View file

@ -0,0 +1,105 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <stdint.h>
namespace Luau
{
namespace CodeGen
{
enum class KindA64 : uint8_t
{
none,
w, // 32-bit GPR
x, // 64-bit GPR
};
struct RegisterA64
{
KindA64 kind : 3;
uint8_t index : 5;
constexpr bool operator==(RegisterA64 rhs) const
{
return kind == rhs.kind && index == rhs.index;
}
constexpr bool operator!=(RegisterA64 rhs) const
{
return !(*this == rhs);
}
};
constexpr RegisterA64 w0{KindA64::w, 0};
constexpr RegisterA64 w1{KindA64::w, 1};
constexpr RegisterA64 w2{KindA64::w, 2};
constexpr RegisterA64 w3{KindA64::w, 3};
constexpr RegisterA64 w4{KindA64::w, 4};
constexpr RegisterA64 w5{KindA64::w, 5};
constexpr RegisterA64 w6{KindA64::w, 6};
constexpr RegisterA64 w7{KindA64::w, 7};
constexpr RegisterA64 w8{KindA64::w, 8};
constexpr RegisterA64 w9{KindA64::w, 9};
constexpr RegisterA64 w10{KindA64::w, 10};
constexpr RegisterA64 w11{KindA64::w, 11};
constexpr RegisterA64 w12{KindA64::w, 12};
constexpr RegisterA64 w13{KindA64::w, 13};
constexpr RegisterA64 w14{KindA64::w, 14};
constexpr RegisterA64 w15{KindA64::w, 15};
constexpr RegisterA64 w16{KindA64::w, 16};
constexpr RegisterA64 w17{KindA64::w, 17};
constexpr RegisterA64 w18{KindA64::w, 18};
constexpr RegisterA64 w19{KindA64::w, 19};
constexpr RegisterA64 w20{KindA64::w, 20};
constexpr RegisterA64 w21{KindA64::w, 21};
constexpr RegisterA64 w22{KindA64::w, 22};
constexpr RegisterA64 w23{KindA64::w, 23};
constexpr RegisterA64 w24{KindA64::w, 24};
constexpr RegisterA64 w25{KindA64::w, 25};
constexpr RegisterA64 w26{KindA64::w, 26};
constexpr RegisterA64 w27{KindA64::w, 27};
constexpr RegisterA64 w28{KindA64::w, 28};
constexpr RegisterA64 w29{KindA64::w, 29};
constexpr RegisterA64 w30{KindA64::w, 30};
constexpr RegisterA64 wzr{KindA64::w, 31};
constexpr RegisterA64 x0{KindA64::x, 0};
constexpr RegisterA64 x1{KindA64::x, 1};
constexpr RegisterA64 x2{KindA64::x, 2};
constexpr RegisterA64 x3{KindA64::x, 3};
constexpr RegisterA64 x4{KindA64::x, 4};
constexpr RegisterA64 x5{KindA64::x, 5};
constexpr RegisterA64 x6{KindA64::x, 6};
constexpr RegisterA64 x7{KindA64::x, 7};
constexpr RegisterA64 x8{KindA64::x, 8};
constexpr RegisterA64 x9{KindA64::x, 9};
constexpr RegisterA64 x10{KindA64::x, 10};
constexpr RegisterA64 x11{KindA64::x, 11};
constexpr RegisterA64 x12{KindA64::x, 12};
constexpr RegisterA64 x13{KindA64::x, 13};
constexpr RegisterA64 x14{KindA64::x, 14};
constexpr RegisterA64 x15{KindA64::x, 15};
constexpr RegisterA64 x16{KindA64::x, 16};
constexpr RegisterA64 x17{KindA64::x, 17};
constexpr RegisterA64 x18{KindA64::x, 18};
constexpr RegisterA64 x19{KindA64::x, 19};
constexpr RegisterA64 x20{KindA64::x, 20};
constexpr RegisterA64 x21{KindA64::x, 21};
constexpr RegisterA64 x22{KindA64::x, 22};
constexpr RegisterA64 x23{KindA64::x, 23};
constexpr RegisterA64 x24{KindA64::x, 24};
constexpr RegisterA64 x25{KindA64::x, 25};
constexpr RegisterA64 x26{KindA64::x, 26};
constexpr RegisterA64 x27{KindA64::x, 27};
constexpr RegisterA64 x28{KindA64::x, 28};
constexpr RegisterA64 x29{KindA64::x, 29};
constexpr RegisterA64 x30{KindA64::x, 30};
constexpr RegisterA64 xzr{KindA64::x, 31};
constexpr RegisterA64 sp{KindA64::none, 31};
} // namespace CodeGen
} // namespace Luau

View file

@ -128,5 +128,10 @@ constexpr RegisterX64 dwordReg(RegisterX64 reg)
return RegisterX64{SizeX64::dword, reg.index}; return RegisterX64{SizeX64::dword, reg.index};
} }
constexpr RegisterX64 qwordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::qword, reg.index};
}
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,607 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AssemblyBuilderA64.h"
#include "ByteUtils.h"
#include <stdarg.h>
namespace Luau
{
namespace CodeGen
{
static const uint8_t codeForCondition[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
static const char* textForCondition[] = {
"b.eq", "b.ne", "b.cs", "b.cc", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "b.al"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
const unsigned kMaxAlign = 32;
AssemblyBuilderA64::AssemblyBuilderA64(bool logText)
: logText(logText)
{
data.resize(4096);
dataPos = data.size(); // data is filled backwards
code.resize(1024);
codePos = code.data();
codeEnd = code.data() + code.size();
}
AssemblyBuilderA64::~AssemblyBuilderA64()
{
LUAU_ASSERT(finalized);
}
void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src)
{
placeSR2("mov", dst, src, 0b01'01010);
}
void AssemblyBuilderA64::mov(RegisterA64 dst, uint16_t src, int shift)
{
placeI16("mov", dst, src, 0b10'100101, shift);
}
void AssemblyBuilderA64::movk(RegisterA64 dst, uint16_t src, int shift)
{
placeI16("movk", dst, src, 0b11'100101, shift);
}
void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
placeSR3("add", dst, src1, src2, 0b00'01011, shift);
}
void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, int src2)
{
placeI12("add", dst, src1, src2, 0b00'10001);
}
void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
placeSR3("sub", dst, src1, src2, 0b10'01011, shift);
}
void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, int src2)
{
placeI12("sub", dst, src1, src2, 0b10'10001);
}
void AssemblyBuilderA64::neg(RegisterA64 dst, RegisterA64 src)
{
placeSR2("neg", dst, src, 0b10'01011);
}
void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("and", dst, src1, src2, 0b00'01010);
}
void AssemblyBuilderA64::orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("orr", dst, src1, src2, 0b01'01010);
}
void AssemblyBuilderA64::eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("eor", dst, src1, src2, 0b10'01010);
}
void AssemblyBuilderA64::lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("lsl", dst, src1, src2, 0b11010110, 0b0010'00);
}
void AssemblyBuilderA64::lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("lsr", dst, src1, src2, 0b11010110, 0b0010'01);
}
void AssemblyBuilderA64::asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("asr", dst, src1, src2, 0b11010110, 0b0010'10);
}
void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("ror", dst, src1, src2, 0b11010110, 0b0010'11);
}
void AssemblyBuilderA64::clz(RegisterA64 dst, RegisterA64 src)
{
placeR1("clz", dst, src, 0b10'11010110'00000'00010'0);
}
void AssemblyBuilderA64::rbit(RegisterA64 dst, RegisterA64 src)
{
placeR1("rbit", dst, src, 0b10'11010110'00000'0000'00);
}
void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldr", dst, src, 0b11100001, 0b10 | uint8_t(dst.kind == KindA64::x));
}
void AssemblyBuilderA64::ldrb(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::w);
placeA("ldrb", dst, src, 0b11100001, 0b00);
}
void AssemblyBuilderA64::ldrh(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::w);
placeA("ldrh", dst, src, 0b11100001, 0b01);
}
void AssemblyBuilderA64::ldrsb(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldrsb", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b00);
}
void AssemblyBuilderA64::ldrsh(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldrsh", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b01);
}
void AssemblyBuilderA64::ldrsw(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x);
placeA("ldrsw", dst, src, 0b11100010, 0b10);
}
void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w);
placeA("str", src, dst, 0b11100000, 0b10 | uint8_t(src.kind == KindA64::x));
}
void AssemblyBuilderA64::strb(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::w);
placeA("strb", src, dst, 0b11100000, 0b00);
}
void AssemblyBuilderA64::strh(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::w);
placeA("strh", src, dst, 0b11100000, 0b01);
}
void AssemblyBuilderA64::b(ConditionA64 cond, Label& label)
{
placeBC(textForCondition[int(cond)], label, 0b0101010'0, codeForCondition[int(cond)]);
}
void AssemblyBuilderA64::cbz(RegisterA64 src, Label& label)
{
placeBR("cbz", label, 0b011010'0, src);
}
void AssemblyBuilderA64::cbnz(RegisterA64 src, Label& label)
{
placeBR("cbnz", label, 0b011010'1, src);
}
void AssemblyBuilderA64::ret()
{
place0("ret", 0b1101011'0'0'10'11111'0000'0'0'11110'00000);
}
bool AssemblyBuilderA64::finalize()
{
bool success = true;
code.resize(codePos - code.data());
// Resolve jump targets
for (Label fixup : pendingLabels)
{
// If this assertion fires, a label was used in jmp without calling setLabel
LUAU_ASSERT(labelLocations[fixup.id - 1] != ~0u);
int value = int(labelLocations[fixup.id - 1]) - int(fixup.location);
// imm19 encoding word offset, at bit offset 5
// note that 18 bits of word offsets = 20 bits of byte offsets = +-1MB
if (value > -(1 << 18) && value < (1 << 18))
code[fixup.location] |= (value & ((1 << 19) - 1)) << 5;
else
success = false; // overflow
}
size_t dataSize = data.size() - dataPos;
// Shrink data
if (dataSize > 0)
memmove(&data[0], &data[dataPos], dataSize);
data.resize(dataSize);
finalized = true;
return success;
}
Label AssemblyBuilderA64::setLabel()
{
Label label{nextLabel++, getCodeSize()};
labelLocations.push_back(~0u);
if (logText)
log(label);
return label;
}
void AssemblyBuilderA64::setLabel(Label& label)
{
if (label.id == 0)
{
label.id = nextLabel++;
labelLocations.push_back(~0u);
}
label.location = getCodeSize();
labelLocations[label.id - 1] = label.location;
if (logText)
log(label);
}
void AssemblyBuilderA64::logAppend(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
text.append(buf);
}
uint32_t AssemblyBuilderA64::getCodeSize() const
{
return uint32_t(codePos - code.data());
}
void AssemblyBuilderA64::place0(const char* name, uint32_t op)
{
if (logText)
log(name);
place(op);
commit();
}
void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift)
{
if (logText)
log(name, dst, src1, src2, shift);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);
LUAU_ASSERT(shift >= 0 && shift < 64); // right shift requires changing some encoding bits
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (shift << 10) | (src2.index << 16) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op)
{
if (logText)
log(name, dst, src);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (0x1f << 5) | (src.index << 16) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2)
{
if (logText)
log(name, dst, src1, src2);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (op2 << 10) | (src2.index << 16) | (op << 21) | sf);
commit();
}
void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op)
{
if (logText)
log(name, dst, src);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src.index << 5) | (op << 10) | sf);
commit();
}
void AssemblyBuilderA64::placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op)
{
if (logText)
log(name, dst, src1, src2);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind);
LUAU_ASSERT(src2 >= 0 && src2 < (1 << 12));
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (src2 << 10) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift)
{
if (logText)
log(name, dst, src, shift);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(src >= 0 && src <= 0xffff);
LUAU_ASSERT(shift == 0 || shift == 16 || shift == 32 || shift == 48);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src << 5) | ((shift >> 4) << 21) | (op << 23) | sf);
commit();
}
void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size)
{
if (logText)
log(name, dst, src);
switch (src.kind)
{
case AddressKindA64::imm:
LUAU_ASSERT(src.data % (1 << size) == 0);
place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30));
break;
case AddressKindA64::reg:
place(dst.index | (src.base.index << 5) | (0b10 << 10) | (0b011 << 13) | (src.offset.index << 16) | (1 << 21) | (op << 22) | (size << 30));
break;
}
commit();
}
void AssemblyBuilderA64::placeBC(const char* name, Label& label, uint8_t op, uint8_t cond)
{
placeLabel(label);
if (logText)
log(name, label);
place(cond | (op << 24));
commit();
}
void AssemblyBuilderA64::placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond)
{
placeLabel(label);
if (logText)
log(name, cond, label);
LUAU_ASSERT(cond.kind == KindA64::w || cond.kind == KindA64::x);
uint32_t sf = (cond.kind == KindA64::x) ? 0x80000000 : 0;
place(cond.index | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::place(uint32_t word)
{
LUAU_ASSERT(codePos < codeEnd);
*codePos++ = word;
}
void AssemblyBuilderA64::placeLabel(Label& label)
{
if (label.location == ~0u)
{
if (label.id == 0)
{
label.id = nextLabel++;
labelLocations.push_back(~0u);
}
pendingLabels.push_back({label.id, getCodeSize()});
}
else
{
// note: if label has an assigned location we can in theory avoid patching it later, but
// we need to handle potential overflow of 19-bit offsets
LUAU_ASSERT(label.id != 0);
labelLocations[label.id - 1] = label.location;
pendingLabels.push_back({label.id, getCodeSize()});
}
}
void AssemblyBuilderA64::commit()
{
LUAU_ASSERT(codePos <= codeEnd);
if (codeEnd == codePos)
extend();
}
void AssemblyBuilderA64::extend()
{
uint32_t count = getCodeSize();
code.resize(code.size() * 2);
codePos = code.data() + count;
codeEnd = code.data() + code.size();
}
size_t AssemblyBuilderA64::allocateData(size_t size, size_t align)
{
LUAU_ASSERT(align > 0 && align <= kMaxAlign && (align & (align - 1)) == 0);
if (dataPos < size)
{
size_t oldSize = data.size();
data.resize(data.size() * 2);
memcpy(&data[oldSize], &data[0], oldSize);
memset(&data[0], 0, oldSize);
dataPos += oldSize;
}
dataPos = (dataPos - size) & ~(align - 1);
return dataPos;
}
void AssemblyBuilderA64::log(const char* opcode)
{
logAppend(" %s\n", opcode);
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src1);
text.append(",");
log(src2);
if (shift > 0)
logAppend(" LSL #%d", shift);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src1);
text.append(",");
logAppend("#%d", src2);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, AddressA64 src)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, int src, int shift)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
logAppend("#%d", src);
if (shift > 0)
logAppend(" LSL #%d", shift);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 src, Label label)
{
logAppend(" %-12s", opcode);
log(src);
text.append(",");
logAppend(".L%d\n", label.id);
}
void AssemblyBuilderA64::log(const char* opcode, Label label)
{
logAppend(" %-12s.L%d\n", opcode, label.id);
}
void AssemblyBuilderA64::log(Label label)
{
logAppend(".L%d:\n", label.id);
}
void AssemblyBuilderA64::log(RegisterA64 reg)
{
switch (reg.kind)
{
case KindA64::w:
if (reg.index == 31)
logAppend("wzr");
else
logAppend("w%d", reg.index);
break;
case KindA64::x:
if (reg.index == 31)
logAppend("xzr");
else
logAppend("x%d", reg.index);
break;
case KindA64::none:
LUAU_ASSERT(!"Unexpected register kind");
}
}
void AssemblyBuilderA64::log(AddressA64 addr)
{
text.append("[");
switch (addr.kind)
{
case AddressKindA64::imm:
log(addr.base);
if (addr.data != 0)
logAppend(",#%d", addr.data);
break;
case AddressKindA64::reg:
log(addr.base);
text.append(",");
log(addr.offset);
if (addr.data != 0)
logAppend(" LSL #%d", addr.data);
break;
}
text.append("]");
}
} // namespace CodeGen
} // namespace Luau

View file

@ -13,13 +13,13 @@ namespace CodeGen
{ {
// TODO: more assertions on operand sizes // TODO: more assertions on operand sizes
const uint8_t codeForCondition[] = { static const uint8_t codeForCondition[] = {
0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd, 0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb}; 0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd, 0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(Condition::Count), "all conditions have to be covered"); static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
const char* textForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna", "jnae", static const char* textForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna",
"jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"}; "jnae", "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(Condition::Count), "all conditions have to be covered"); static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
#define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7)) #define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7))
#define OP_PLUS_CC(op, cc) ((op) + uint8_t(cc)) #define OP_PLUS_CC(op, cc) ((op) + uint8_t(cc))
@ -328,7 +328,10 @@ void AssemblyBuilderX64::lea(OperandX64 lhs, OperandX64 rhs)
if (logText) if (logText)
log("lea", lhs, rhs); log("lea", lhs, rhs);
LUAU_ASSERT(rhs.cat == CategoryX64::mem); LUAU_ASSERT(lhs.cat == CategoryX64::reg && rhs.cat == CategoryX64::mem && rhs.memSize == SizeX64::none);
LUAU_ASSERT(rhs.base == rip || rhs.base.size == lhs.base.size);
LUAU_ASSERT(rhs.index == noreg || rhs.index.size == lhs.base.size);
rhs.memSize = lhs.base.size;
placeBinaryRegAndRegMem(lhs, rhs, 0x8d, 0x8d); placeBinaryRegAndRegMem(lhs, rhs, 0x8d, 0x8d);
} }
@ -363,7 +366,7 @@ void AssemblyBuilderX64::ret()
commit(); commit();
} }
void AssemblyBuilderX64::jcc(Condition cond, Label& label) void AssemblyBuilderX64::jcc(ConditionX64 cond, Label& label)
{ {
placeJcc(textForCondition[size_t(cond)], label, codeForCondition[size_t(cond)]); placeJcc(textForCondition[size_t(cond)], label, codeForCondition[size_t(cond)]);
} }
@ -781,7 +784,7 @@ OperandX64 AssemblyBuilderX64::bytes(const void* ptr, size_t size, size_t align)
{ {
size_t pos = allocateData(size, align); size_t pos = allocateData(size, align);
memcpy(&data[pos], ptr, size); memcpy(&data[pos], ptr, size);
return OperandX64(SizeX64::qword, noreg, 1, rip, int32_t(pos - data.size())); return OperandX64(SizeX64::none, noreg, 1, rip, int32_t(pos - data.size()));
} }
void AssemblyBuilderX64::logAppend(const char* fmt, ...) void AssemblyBuilderX64::logAppend(const char* fmt, ...)
@ -1072,7 +1075,7 @@ void AssemblyBuilderX64::placeVex(OperandX64 dst, OperandX64 src1, OperandX64 sr
place(AVX_3_3(setW, src1.base, dst.base.size == SizeX64::ymmword, prefix)); place(AVX_3_3(setW, src1.base, dst.base.size == SizeX64::ymmword, prefix));
} }
uint8_t getScaleEncoding(uint8_t scale) static uint8_t getScaleEncoding(uint8_t scale)
{ {
static const uint8_t scales[9] = {0xff, 0, 1, 0xff, 2, 0xff, 0xff, 0xff, 3}; static const uint8_t scales[9] = {0xff, 0, 1, 0xff, 2, 0xff, 0xff, 0xff, 3};

View file

@ -51,10 +51,15 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers)
build.logAppend("; exitNoContinueVm\n"); build.logAppend("; exitNoContinueVm\n");
helpers.exitNoContinueVm = build.setLabel(); helpers.exitNoContinueVm = build.setLabel();
emitExit(build, /* continueInVm */ false); emitExit(build, /* continueInVm */ false);
if (build.logText)
build.logAppend("; continueCallInVm\n");
helpers.continueCallInVm = build.setLabel();
emitContinueCallInVm(build);
} }
static int emitInst( static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i,
AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i, Label* labelarr, Label& fallback) Label* labelarr, Label& fallback)
{ {
int skip = 0; int skip = 0;
@ -86,6 +91,9 @@ static int emitInst(
case LOP_SETGLOBAL: case LOP_SETGLOBAL:
emitInstSetGlobal(build, pc, i, labelarr, fallback); emitInstSetGlobal(build, pc, i, labelarr, fallback);
break; break;
case LOP_CALL:
emitInstCall(build, helpers, pc, i, labelarr);
break;
case LOP_RETURN: case LOP_RETURN:
emitInstReturn(build, helpers, pc, i, labelarr); emitInstReturn(build, helpers, pc, i, labelarr);
break; break;
@ -123,19 +131,19 @@ static int emitInst(
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback); emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback);
break; break;
case LOP_JUMPIFLE: case LOP_JUMPIFLE:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::LessEqual, fallback); emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::LessEqual, fallback);
break; break;
case LOP_JUMPIFLT: case LOP_JUMPIFLT:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::Less, fallback); emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::Less, fallback);
break; break;
case LOP_JUMPIFNOTEQ: case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback); emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback);
break; break;
case LOP_JUMPIFNOTLE: case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLessEqual, fallback); emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLessEqual, fallback);
break; break;
case LOP_JUMPIFNOTLT: case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLess, fallback); emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLess, fallback);
break; break;
case LOP_JUMPX: case LOP_JUMPX:
emitInstJumpX(build, pc, i, labelarr); emitInstJumpX(build, pc, i, labelarr);
@ -291,19 +299,19 @@ static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauO
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false); emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false);
break; break;
case LOP_JUMPIFLE: case LOP_JUMPIFLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::LessEqual); emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::LessEqual);
break; break;
case LOP_JUMPIFLT: case LOP_JUMPIFLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::Less); emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::Less);
break; break;
case LOP_JUMPIFNOTEQ: case LOP_JUMPIFNOTEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true); emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true);
break; break;
case LOP_JUMPIFNOTLE: case LOP_JUMPIFNOTLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLessEqual); emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLessEqual);
break; break;
case LOP_JUMPIFNOTLT: case LOP_JUMPIFNOTLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLess); emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLess);
break; break;
case LOP_ADD: case LOP_ADD:
emitInstBinaryFallback(build, pc, i, TM_ADD); emitInstBinaryFallback(build, pc, i, TM_ADD);
@ -687,6 +695,9 @@ void compile(lua_State* L, int idx)
{ {
for (int i = 0; i < result->proto->sizecode; i++) for (int i = 0; i < result->proto->sizecode; i++)
result->instTargets[i] += uintptr_t(codeStart + result->location); result->instTargets[i] += uintptr_t(codeStart + result->location);
LUAU_ASSERT(result->proto->sizecode);
result->entryTarget = result->instTargets[0];
} }
// Link native proto objects to Proto; the memory is now managed by VM and will be freed via onDestroyFunction // Link native proto objects to Proto; the memory is now managed by VM and will be freed via onDestroyFunction

View file

@ -72,5 +72,59 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc)
} }
} }
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults)
{
// slow-path: not a function call
if (LUAU_UNLIKELY(!ttisfunction(ra)))
{
luaV_tryfuncTM(L, ra);
argtop++; // __call adds an extra self
}
Closure* ccl = clvalue(ra);
CallInfo* ci = incr_ci(L);
ci->func = ra;
ci->base = ra + 1;
ci->top = argtop + ccl->stacksize; // note: technically UB since we haven't reallocated the stack yet
ci->savedpc = NULL;
ci->flags = 0;
ci->nresults = nresults;
L->base = ci->base;
L->top = argtop;
// note: this reallocs stack, but we don't need to VM_PROTECT this
// this is because we're going to modify base/savedpc manually anyhow
// crucially, we can't use ra/argtop after this line
luaD_checkstack(L, ccl->stacksize);
return ccl;
}
void callEpilogC(lua_State* L, int nresults, int n)
{
// ci is our callinfo, cip is our parent
CallInfo* ci = L->ci;
CallInfo* cip = ci - 1;
// copy return values into parent stack (but only up to nresults!), fill the rest with nil
// note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally
StkId res = ci->func;
StkId vali = L->top - n;
StkId valend = L->top;
int i;
for (i = nresults; i != 0 && vali < valend; i--)
setobj2s(L, res++, vali++);
while (i-- > 0)
setnilvalue(res++);
// pop the stack frame
L->ci = cip;
L->base = cip->base;
L->top = (nresults == LUA_MULTRET) ? res : cip->top;
}
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -13,5 +13,8 @@ bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux);
void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc);
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults);
void callEpilogC(lua_State* L, int nresults, int n);
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -87,7 +87,7 @@ bool initEntryFunction(NativeState& data)
unwind.allocStack(stacksize + localssize); unwind.allocStack(stacksize + localssize);
// Setup frame pointer // Setup frame pointer
build.lea(rbp, qword[rsp + stacksize]); build.lea(rbp, addr[rsp + stacksize]);
unwind.setupFrameReg(rbp, stacksize); unwind.setupFrameReg(rbp, stacksize);
unwind.finish(); unwind.finish();
@ -113,7 +113,7 @@ bool initEntryFunction(NativeState& data)
Label returnOff = build.setLabel(); Label returnOff = build.setLabel();
// Cleanup and exit // Cleanup and exit
build.lea(rsp, qword[rbp + localssize]); build.lea(rsp, addr[rbp + localssize]);
build.pop(r15); build.pop(r15);
build.pop(r14); build.pop(r14);
build.pop(r13); build.pop(r13);

View file

@ -14,7 +14,7 @@ namespace Luau
namespace CodeGen namespace CodeGen
{ {
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label) void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label)
{ {
// Refresher on comi/ucomi EFLAGS: // Refresher on comi/ucomi EFLAGS:
// CF only: less // CF only: less
@ -35,53 +35,54 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs,
// And because of NaN, integer check interchangeability like 'not less or equal' <-> 'greater' does not hold // And because of NaN, integer check interchangeability like 'not less or equal' <-> 'greater' does not hold
switch (cond) switch (cond)
{ {
case Condition::NotLessEqual: case ConditionX64::NotLessEqual:
// (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN // (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN
build.jcc(Condition::NotAboveEqual, label); build.jcc(ConditionX64::NotAboveEqual, label);
break; break;
case Condition::LessEqual: case ConditionX64::LessEqual:
// (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN // (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN
build.jcc(Condition::AboveEqual, label); build.jcc(ConditionX64::AboveEqual, label);
break; break;
case Condition::NotLess: case ConditionX64::NotLess:
// (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN // (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN
build.jcc(Condition::NotAbove, label); build.jcc(ConditionX64::NotAbove, label);
break; break;
case Condition::Less: case ConditionX64::Less:
// (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN // (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN
build.jcc(Condition::Above, label); build.jcc(ConditionX64::Above, label);
break; break;
case Condition::NotEqual: case ConditionX64::NotEqual:
// ZF=0 or PF=1 means != or NaN // ZF=0 or PF=1 means != or NaN
build.jcc(Condition::NotZero, label); build.jcc(ConditionX64::NotZero, label);
build.jcc(Condition::Parity, label); build.jcc(ConditionX64::Parity, label);
break; break;
default: default:
LUAU_ASSERT(!"Unsupported condition"); LUAU_ASSERT(!"Unsupported condition");
} }
} }
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos) void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos)
{ {
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra)); build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegValue(rb)); build.lea(rArg3, luauRegAddress(rb));
if (cond == Condition::NotLessEqual || cond == Condition::LessEqual) if (cond == ConditionX64::NotLessEqual || cond == ConditionX64::LessEqual)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]);
else if (cond == Condition::NotLess || cond == Condition::Less) else if (cond == ConditionX64::NotLess || cond == ConditionX64::Less)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]);
else if (cond == Condition::NotEqual || cond == Condition::Equal) else if (cond == ConditionX64::NotEqual || cond == ConditionX64::Equal)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]);
else else
LUAU_ASSERT(!"Unsupported condition"); LUAU_ASSERT(!"Unsupported condition");
emitUpdateBase(build); emitUpdateBase(build);
build.test(eax, eax); build.test(eax, eax);
build.jcc( build.jcc(cond == ConditionX64::NotLessEqual || cond == ConditionX64::NotLess || cond == ConditionX64::NotEqual ? ConditionX64::Zero
cond == Condition::NotLessEqual || cond == Condition::NotLess || cond == Condition::NotEqual ? Condition::Zero : Condition::NotZero, label); : ConditionX64::NotZero,
label);
} }
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos) RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos)
@ -120,7 +121,7 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi
build.vucomisd(tmp, numd); // Sets ZF=1 if equal or NaN build.vucomisd(tmp, numd); // Sets ZF=1 if equal or NaN
// We don't need non-integer values // We don't need non-integer values
// But to skip the PF=1 check, we proceed with NaN because 0x80000000 index is out of bounds // But to skip the PF=1 check, we proceed with NaN because 0x80000000 index is out of bounds
build.jcc(Condition::NotZero, label); build.jcc(ConditionX64::NotZero, label);
} }
void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm) void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm)
@ -133,8 +134,8 @@ void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, in
build.mov(rArg5, tm); build.mov(rArg5, tm);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra)); build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegValue(rb)); build.lea(rArg3, luauRegAddress(rb));
build.lea(rArg4, c); build.lea(rArg4, c);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]);
@ -146,8 +147,8 @@ void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb, int pcpos)
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra)); build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegValue(rb)); build.lea(rArg3, luauRegAddress(rb));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_dolen)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_dolen)]);
emitUpdateBase(build); emitUpdateBase(build);
@ -158,9 +159,9 @@ void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, i
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(limit)); build.lea(rArg2, luauRegAddress(limit));
build.lea(rArg3, luauRegValue(step)); build.lea(rArg3, luauRegAddress(step));
build.lea(rArg4, luauRegValue(init)); build.lea(rArg4, luauRegAddress(init));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]);
} }
@ -169,9 +170,9 @@ void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(rb)); build.lea(rArg2, luauRegAddress(rb));
build.lea(rArg3, c); build.lea(rArg3, c);
build.lea(rArg4, luauRegValue(ra)); build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_gettable)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_gettable)]);
emitUpdateBase(build); emitUpdateBase(build);
@ -182,9 +183,9 @@ void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(rb)); build.lea(rArg2, luauRegAddress(rb));
build.lea(rArg3, c); build.lea(rArg3, c);
build.lea(rArg4, luauRegValue(ra)); build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_settable)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_settable)]);
emitUpdateBase(build); emitUpdateBase(build);
@ -197,16 +198,16 @@ static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, Register
// iscollectable(ra) // iscollectable(ra)
build.cmp(luauRegTag(ra), LUA_TSTRING); build.cmp(luauRegTag(ra), LUA_TSTRING);
build.jcc(Condition::Less, skip); build.jcc(ConditionX64::Less, skip);
// isblack(obj2gco(o)) // isblack(obj2gco(o))
build.test(byte[object + offsetof(GCheader, marked)], bitmask(BLACKBIT)); build.test(byte[object + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(Condition::Zero, skip); build.jcc(ConditionX64::Zero, skip);
// iswhite(gcvalue(ra)) // iswhite(gcvalue(ra))
build.mov(tmp, luauRegValue(ra)); build.mov(tmp, luauRegValue(ra));
build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT)); build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT));
build.jcc(Condition::Zero, skip); build.jcc(ConditionX64::Zero, skip);
LUAU_ASSERT(object != rArg3); LUAU_ASSERT(object != rArg3);
build.mov(rArg3, tmp); build.mov(rArg3, tmp);
@ -229,12 +230,12 @@ void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& s
{ {
// isblack(obj2gco(t)) // isblack(obj2gco(t))
build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT)); build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(Condition::Zero, skip); build.jcc(ConditionX64::Zero, skip);
// Argument setup re-ordered to avoid conflicts with table register // Argument setup re-ordered to avoid conflicts with table register
if (table != rArg2) if (table != rArg2)
build.mov(rArg2, table); build.mov(rArg2, table);
build.lea(rArg3, qword[rArg2 + offsetof(Table, gclist)]); build.lea(rArg3, addr[rArg2 + offsetof(Table, gclist)]);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]);
} }
@ -244,13 +245,13 @@ void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip)
build.mov(rax, qword[rState + offsetof(lua_State, global)]); build.mov(rax, qword[rState + offsetof(lua_State, global)]);
build.mov(rdx, qword[rax + offsetof(global_State, totalbytes)]); build.mov(rdx, qword[rax + offsetof(global_State, totalbytes)]);
build.cmp(rdx, qword[rax + offsetof(global_State, GCthreshold)]); build.cmp(rdx, qword[rax + offsetof(global_State, GCthreshold)]);
build.jcc(Condition::Below, skip); build.jcc(ConditionX64::Below, skip);
if (savepc) if (savepc)
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.mov(rArg2, 1); build.mov(dwordReg(rArg2), 1);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]);
emitUpdateBase(build); emitUpdateBase(build);
@ -288,7 +289,7 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos)
build.mov(r8, qword[rState + offsetof(lua_State, global)]); build.mov(r8, qword[rState + offsetof(lua_State, global)]);
build.mov(r8, qword[r8 + offsetof(global_State, cb.interrupt)]); build.mov(r8, qword[r8 + offsetof(global_State, cb.interrupt)]);
build.test(r8, r8); build.test(r8, r8);
build.jcc(Condition::Zero, skip); build.jcc(ConditionX64::Zero, skip);
emitSetSavedPc(build, pcpos + 1); // uses rax/rdx emitSetSavedPc(build, pcpos + 1); // uses rax/rdx
@ -301,7 +302,7 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos)
// Check if we need to exit // Check if we need to exit
build.mov(al, byte[rState + offsetof(lua_State, status)]); build.mov(al, byte[rState + offsetof(lua_State, status)]);
build.test(al, al); build.test(al, al);
build.jcc(Condition::Zero, skip); build.jcc(ConditionX64::Zero, skip);
build.mov(rax, qword[rState + offsetof(lua_State, ci)]); build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.sub(qword[rax + offsetof(CallInfo, savedpc)], sizeof(Instruction)); build.sub(qword[rax + offsetof(CallInfo, savedpc)], sizeof(Instruction));
@ -329,17 +330,6 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo
build.mov(rArg4, rConstants); build.mov(rArg4, rConstants);
build.call(qword[rNativeContext + offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback)]); build.call(qword[rNativeContext + offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback)]);
// Some instructions may interrupt the execution
if (opinfo.flags & kFallbackCheckInterrupt)
{
Label skip;
build.test(rax, rax);
build.jcc(Condition::NotZero, skip);
emitExit(build, /* continueInVm */ false);
build.setLabel(skip);
}
emitUpdateBase(build); emitUpdateBase(build);
// Some instructions may jump to a different instruction or a completely different function // Some instructions may jump to a different instruction or a completely different function
@ -359,50 +349,17 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo
build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]); build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]);
build.jmp(qword[rax * 2 + rcx]); build.jmp(qword[rax * 2 + rcx]);
} }
else if (opinfo.flags & kFallbackUpdateCi) }
{
// Need to update state of the current function before we jump away
build.mov(rcx, qword[rState + offsetof(lua_State, ci)]); // L->ci
build.mov(rcx, qword[rcx + offsetof(CallInfo, func)]); // L->ci->func
build.mov(rcx, qword[rcx + offsetof(TValue, value.gc)]); // L->ci->func->value.gc aka cl
build.mov(sClosure, rcx);
build.mov(rsi, qword[rcx + offsetof(Closure, l.p)]); // cl->l.p aka proto
build.mov(rConstants, qword[rsi + offsetof(Proto, k)]); // proto->k
build.mov(rcx, qword[rsi + offsetof(Proto, code)]); // proto->code
build.mov(sCode, rcx);
// We'll need original instruction pointer later to handle return to interpreter void emitContinueCallInVm(AssemblyBuilderX64& build)
if (op == LOP_CALL) {
build.mov(r9, rax); RegisterX64 proto = rcx; // Sync with emitInstCall
// Get instruction index from instruction pointer build.mov(rdx, qword[proto + offsetof(Proto, code)]);
// To get instruction index from instruction pointer, we need to divide byte offset by 4 build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
// But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out build.mov(qword[rax + offsetof(CallInfo, savedpc)], rdx);
build.sub(rax, sCode);
// We need to check if the new function can be executed natively emitExit(build, /* continueInVm */ true);
Label returnToInterpreter;
build.mov(rdx, qword[rsi + offsetofProtoExecData]);
build.test(rdx, rdx);
build.jcc(Condition::Zero, returnToInterpreter);
// Get new instruction location and jump to it
build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]);
build.jmp(qword[rax * 2 + rcx]);
build.setLabel(returnToInterpreter);
// If we are returning to the interpreter to make a call, we need to update the current instruction
if (op == LOP_CALL)
{
build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.mov(qword[rax + offsetof(CallInfo, savedpc)], r9);
}
// Continue in the interpreter
emitExit(build, /* continueInVm */ true);
}
} }
} // namespace CodeGen } // namespace CodeGen

View file

@ -72,6 +72,7 @@ struct ModuleHelpers
{ {
Label exitContinueVm; Label exitContinueVm;
Label exitNoContinueVm; Label exitNoContinueVm;
Label continueCallInVm;
}; };
inline OperandX64 luauReg(int ri) inline OperandX64 luauReg(int ri)
@ -79,6 +80,11 @@ inline OperandX64 luauReg(int ri)
return xmmword[rBase + ri * sizeof(TValue)]; return xmmword[rBase + ri * sizeof(TValue)];
} }
inline OperandX64 luauRegAddress(int ri)
{
return addr[rBase + ri * sizeof(TValue)];
}
inline OperandX64 luauRegValue(int ri) inline OperandX64 luauRegValue(int ri)
{ {
return qword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)]; return qword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)];
@ -99,6 +105,11 @@ inline OperandX64 luauConstant(int ki)
return xmmword[rConstants + ki * sizeof(TValue)]; return xmmword[rConstants + ki * sizeof(TValue)];
} }
inline OperandX64 luauConstantAddress(int ki)
{
return addr[rConstants + ki * sizeof(TValue)];
}
inline OperandX64 luauConstantTag(int ki) inline OperandX64 luauConstantTag(int ki)
{ {
return dword[rConstants + ki * sizeof(TValue) + offsetof(TValue, tt)]; return dword[rConstants + ki * sizeof(TValue) + offsetof(TValue, tt)];
@ -144,13 +155,13 @@ inline void setNodeValue(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64
inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label) inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{ {
build.cmp(luauRegTag(ri), tag); build.cmp(luauRegTag(ri), tag);
build.jcc(Condition::Equal, label); build.jcc(ConditionX64::Equal, label);
} }
inline void jumpIfTagIsNot(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label) inline void jumpIfTagIsNot(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{ {
build.cmp(luauRegTag(ri), tag); build.cmp(luauRegTag(ri), tag);
build.jcc(Condition::NotEqual, label); build.jcc(ConditionX64::NotEqual, label);
} }
// Note: fallthrough label should be placed after this condition // Note: fallthrough label should be placed after this condition
@ -160,7 +171,7 @@ inline void jumpIfFalsy(AssemblyBuilderX64& build, int ri, Label& target, Label&
jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, fallthrough); // true if not nil or boolean jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, fallthrough); // true if not nil or boolean
build.cmp(luauRegValueBoolean(ri), 0); build.cmp(luauRegValueBoolean(ri), 0);
build.jcc(Condition::Equal, target); // true if boolean value is 'true' build.jcc(ConditionX64::Equal, target); // true if boolean value is 'true'
} }
// Note: fallthrough label should be placed after this condition // Note: fallthrough label should be placed after this condition
@ -170,13 +181,13 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label
jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, target); // true if not nil or boolean jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, target); // true if not nil or boolean
build.cmp(luauRegValueBoolean(ri), 0); build.cmp(luauRegValueBoolean(ri), 0);
build.jcc(Condition::NotEqual, target); // true if boolean value is 'true' build.jcc(ConditionX64::NotEqual, target); // true if boolean value is 'true'
} }
inline void jumpIfMetatablePresent(AssemblyBuilderX64& build, RegisterX64 table, Label& target) inline void jumpIfMetatablePresent(AssemblyBuilderX64& build, RegisterX64 table, Label& target)
{ {
build.cmp(qword[table + offsetof(Table, metatable)], 0); build.cmp(qword[table + offsetof(Table, metatable)], 0);
build.jcc(Condition::NotEqual, target); build.jcc(ConditionX64::NotEqual, target);
} }
inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& label) inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& label)
@ -184,13 +195,13 @@ inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& l
build.mov(tmp, sClosure); build.mov(tmp, sClosure);
build.mov(tmp, qword[tmp + offsetof(Closure, env)]); build.mov(tmp, qword[tmp + offsetof(Closure, env)]);
build.test(byte[tmp + offsetof(Table, safeenv)], 1); build.test(byte[tmp + offsetof(Table, safeenv)], 1);
build.jcc(Condition::Zero, label); // Not a safe environment build.jcc(ConditionX64::Zero, label); // Not a safe environment
} }
inline void jumpIfTableIsReadOnly(AssemblyBuilderX64& build, RegisterX64 table, Label& label) inline void jumpIfTableIsReadOnly(AssemblyBuilderX64& build, RegisterX64 table, Label& label)
{ {
build.cmp(byte[table + offsetof(Table, readonly)], 0); build.cmp(byte[table + offsetof(Table, readonly)], 0);
build.jcc(Condition::NotEqual, label); build.jcc(ConditionX64::NotEqual, label);
} }
inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, lua_Type tag, Label& label) inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, lua_Type tag, Label& label)
@ -200,13 +211,13 @@ inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, Re
build.mov(tmp, luauNodeKeyTag(node)); build.mov(tmp, luauNodeKeyTag(node));
build.and_(tmp, kLuaNodeTagMask); build.and_(tmp, kLuaNodeTagMask);
build.cmp(tmp, tag); build.cmp(tmp, tag);
build.jcc(Condition::NotEqual, label); build.jcc(ConditionX64::NotEqual, label);
} }
inline void jumpIfNodeValueTagIs(AssemblyBuilderX64& build, RegisterX64 node, lua_Type tag, Label& label) inline void jumpIfNodeValueTagIs(AssemblyBuilderX64& build, RegisterX64 node, lua_Type tag, Label& label)
{ {
build.cmp(dword[node + offsetof(LuaNode, val) + offsetof(TValue, tt)], tag); build.cmp(dword[node + offsetof(LuaNode, val) + offsetof(TValue, tt)], tag);
build.jcc(Condition::Equal, label); build.jcc(ConditionX64::Equal, label);
} }
inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, Label& label) inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, Label& label)
@ -215,13 +226,13 @@ inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX6
build.mov(tmp, expectedKey); build.mov(tmp, expectedKey);
build.cmp(tmp, luauNodeKeyValue(node)); build.cmp(tmp, luauNodeKeyValue(node));
build.jcc(Condition::NotEqual, label); build.jcc(ConditionX64::NotEqual, label);
jumpIfNodeValueTagIs(build, node, LUA_TNIL, label); jumpIfNodeValueTagIs(build, node, LUA_TNIL, label);
} }
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label); void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label);
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos); void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos);
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos); RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos);
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label);
@ -242,5 +253,8 @@ void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos); // Note: only uses ra
void emitInterrupt(AssemblyBuilderX64& build, int pcpos); void emitInterrupt(AssemblyBuilderX64& build, int pcpos);
void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos);
void emitContinueCallInVm(AssemblyBuilderX64& build);
void emitExitFromLastReturn(AssemblyBuilderX64& build);
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -11,9 +11,6 @@
#include "lobject.h" #include "lobject.h"
#include "ltm.h" #include "ltm.h"
// TODO: all uses of luauRegValue and luauConstantValue need to be audited; some need to be changed to luauReg/ConstantAddress (doesn't exist yet)
// (the problem with existing use is that it includes additional offsetof(TValue, value) which happens to be 0 but isn't guaranteed to be)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -72,6 +69,159 @@ void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc)
build.vmovups(luauReg(ra), xmm0); build.vmovups(luauReg(ra), xmm0);
} }
void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
int nparams = LUAU_INSN_B(*pc) - 1;
int nresults = LUAU_INSN_C(*pc) - 1;
emitInterrupt(build, pcpos);
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegAddress(ra));
if (nparams == LUA_MULTRET)
build.mov(rArg3, qword[rState + offsetof(lua_State, top)]);
else
build.lea(rArg3, luauRegAddress(ra + 1 + nparams));
build.mov(dwordReg(rArg4), nresults);
build.call(qword[rNativeContext + offsetof(NativeContext, callProlog)]);
RegisterX64 ccl = rax; // Returned from callProlog
emitUpdateBase(build);
Label cFuncCall;
build.test(byte[ccl + offsetof(Closure, isC)], 1);
build.jcc(ConditionX64::NotZero, cFuncCall);
{
RegisterX64 proto = rcx; // Sync with emitContinueCallInVm
RegisterX64 ci = rdx;
RegisterX64 argi = rsi;
RegisterX64 argend = rdi;
build.mov(proto, qword[ccl + offsetof(Closure, l.p)]);
// Switch current Closure
build.mov(sClosure, ccl); // Last use of 'ccl'
build.mov(ci, qword[rState + offsetof(lua_State, ci)]);
Label fillnil, exitfillnil;
// argi = L->top
build.mov(argi, qword[rState + offsetof(lua_State, top)]);
// argend = L->base + p->numparams
build.movzx(eax, byte[proto + offsetof(Proto, numparams)]);
build.shl(eax, kTValueSizeLog2);
build.lea(argend, addr[rBase + rax]);
// while (argi < argend) setnilvalue(argi++);
build.setLabel(fillnil);
build.cmp(argi, argend);
build.jcc(ConditionX64::NotBelow, exitfillnil);
build.mov(dword[argi + offsetof(TValue, tt)], LUA_TNIL);
build.add(argi, sizeof(TValue));
build.jmp(fillnil); // This loop rarely runs so it's not worth repeating cmp/jcc
build.setLabel(exitfillnil);
// Set L->top to ci->top as most function expect (no vararg)
build.mov(rax, qword[ci + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax);
build.mov(rax, qword[proto + offsetofProtoExecData]); // We'll need this value later
// But if it is vararg, update it to 'argi'
Label skipVararg;
build.test(byte[proto + offsetof(Proto, is_vararg)], 1);
build.jcc(ConditionX64::Zero, skipVararg);
build.mov(qword[rState + offsetof(lua_State, top)], argi);
build.setLabel(skipVararg);
// Check native function data
build.test(rax, rax);
build.jcc(ConditionX64::Zero, helpers.continueCallInVm);
// Switch current constants
build.mov(rConstants, qword[proto + offsetof(Proto, k)]);
// Switch current code
build.mov(rdx, qword[proto + offsetof(Proto, code)]);
build.mov(sCode, rdx);
build.jmp(qword[rax + offsetof(NativeProto, entryTarget)]);
}
build.setLabel(cFuncCall);
{
// results = ccl->c.f(L);
build.mov(rArg1, rState);
build.call(qword[ccl + offsetof(Closure, c.f)]); // Last use of 'ccl'
RegisterX64 results = eax;
build.test(results, results); // test here will set SF=1 for a negative number and it always sets OF to 0
build.jcc(ConditionX64::Less, helpers.exitNoContinueVm); // jl jumps if SF != OF
// We have special handling for small number of expected results below
if (nresults != 0 && nresults != 1)
{
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), nresults);
build.mov(dwordReg(rArg3), results);
build.call(qword[rNativeContext + offsetof(NativeContext, callEpilogC)]);
emitUpdateBase(build);
return;
}
RegisterX64 ci = rdx;
RegisterX64 cip = rcx;
RegisterX64 vali = rsi;
build.mov(ci, qword[rState + offsetof(lua_State, ci)]);
build.lea(cip, addr[ci - sizeof(CallInfo)]);
// L->base = cip->base
build.mov(rBase, qword[cip + offsetof(CallInfo, base)]);
build.mov(qword[rState + offsetof(lua_State, base)], rBase);
if (nresults == 1)
{
// Opportunistically copy the result we expected from (L->top - results)
build.mov(vali, qword[rState + offsetof(lua_State, top)]);
build.shl(results, kTValueSizeLog2);
build.sub(vali, qwordReg(results));
build.vmovups(xmm0, xmmword[vali]);
build.vmovups(luauReg(ra), xmm0);
Label skipnil;
// If there was no result, override the value with 'nil'
build.test(results, results);
build.jcc(ConditionX64::NotZero, skipnil);
build.mov(luauRegTag(ra), LUA_TNIL);
build.setLabel(skipnil);
}
// L->ci = cip
build.mov(qword[rState + offsetof(lua_State, ci)], cip);
// L->top = cip->top
build.mov(rax, qword[cip + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
}
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr) void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr)
{ {
emitInterrupt(build, pcpos); emitInterrupt(build, pcpos);
@ -85,7 +235,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
RegisterX64 nresults = esi; RegisterX64 nresults = esi;
build.mov(ci, qword[rState + offsetof(lua_State, ci)]); build.mov(ci, qword[rState + offsetof(lua_State, ci)]);
build.lea(cip, qword[ci - sizeof(CallInfo)]); build.lea(cip, addr[ci - sizeof(CallInfo)]);
// res = ci->func; note: we assume CALL always puts func+args and expects results to start at func // res = ci->func; note: we assume CALL always puts func+args and expects results to start at func
build.mov(res, qword[ci + offsetof(CallInfo, func)]); build.mov(res, qword[ci + offsetof(CallInfo, func)]);
@ -100,8 +250,8 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
if (b == 0) if (b == 0)
{ {
// Our instruction doesn't have any results, so just fill results expected in parent with 'nil' // Our instruction doesn't have any results, so just fill results expected in parent with 'nil'
build.test(nresults, nresults); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0 build.test(nresults, nresults); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0
build.jcc(Condition::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1 build.jcc(ConditionX64::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
build.mov(counter, nresults); build.mov(counter, nresults);
@ -109,29 +259,29 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL); build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL);
build.add(res, sizeof(TValue)); build.add(res, sizeof(TValue));
build.dec(counter); build.dec(counter);
build.jcc(Condition::NotZero, repeatNilLoop); build.jcc(ConditionX64::NotZero, repeatNilLoop);
} }
else if (b == 1) else if (b == 1)
{ {
// Try setting our 1 result // Try setting our 1 result
build.test(nresults, nresults); build.test(nresults, nresults);
build.jcc(Condition::Zero, skipResultCopy); build.jcc(ConditionX64::Zero, skipResultCopy);
build.lea(counter, dword[nresults - 1]); build.lea(counter, addr[nresults - 1]);
build.vmovups(xmm0, luauReg(ra)); build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[res], xmm0); build.vmovups(xmmword[res], xmm0);
build.add(res, sizeof(TValue)); build.add(res, sizeof(TValue));
// Fill the rest of the expected results with 'nil' // Fill the rest of the expected results with 'nil'
build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0 build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0
build.jcc(Condition::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1 build.jcc(ConditionX64::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
Label repeatNilLoop = build.setLabel(); Label repeatNilLoop = build.setLabel();
build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL); build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL);
build.add(res, sizeof(TValue)); build.add(res, sizeof(TValue));
build.dec(counter); build.dec(counter);
build.jcc(Condition::NotZero, repeatNilLoop); build.jcc(ConditionX64::NotZero, repeatNilLoop);
} }
else else
{ {
@ -140,16 +290,16 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
// Copy return values into parent stack (but only up to nresults!) // Copy return values into parent stack (but only up to nresults!)
build.test(nresults, nresults); build.test(nresults, nresults);
build.jcc(Condition::Zero, skipResultCopy); build.jcc(ConditionX64::Zero, skipResultCopy);
// vali = ra // vali = ra
build.lea(vali, luauRegValue(ra)); build.lea(vali, luauRegAddress(ra));
// Copy as much as possible for MULTRET calls, and only as much as needed otherwise // Copy as much as possible for MULTRET calls, and only as much as needed otherwise
if (b == LUA_MULTRET) if (b == LUA_MULTRET)
build.mov(valend, qword[rState + offsetof(lua_State, top)]); // valend = L->top build.mov(valend, qword[rState + offsetof(lua_State, top)]); // valend = L->top
else else
build.lea(valend, luauRegValue(ra + b)); // valend = ra + b build.lea(valend, luauRegAddress(ra + b)); // valend = ra + b
build.mov(counter, nresults); build.mov(counter, nresults);
@ -157,26 +307,26 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
build.setLabel(repeatValueLoop); build.setLabel(repeatValueLoop);
build.cmp(vali, valend); build.cmp(vali, valend);
build.jcc(Condition::NotBelow, exitValueLoop); build.jcc(ConditionX64::NotBelow, exitValueLoop);
build.vmovups(xmm0, xmmword[vali]); build.vmovups(xmm0, xmmword[vali]);
build.vmovups(xmmword[res], xmm0); build.vmovups(xmmword[res], xmm0);
build.add(vali, sizeof(TValue)); build.add(vali, sizeof(TValue));
build.add(res, sizeof(TValue)); build.add(res, sizeof(TValue));
build.dec(counter); build.dec(counter);
build.jcc(Condition::NotZero, repeatValueLoop); build.jcc(ConditionX64::NotZero, repeatValueLoop);
build.setLabel(exitValueLoop); build.setLabel(exitValueLoop);
// Fill the rest of the expected results with 'nil' // Fill the rest of the expected results with 'nil'
build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0 build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0
build.jcc(Condition::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1 build.jcc(ConditionX64::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
Label repeatNilLoop = build.setLabel(); Label repeatNilLoop = build.setLabel();
build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL); build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL);
build.add(res, sizeof(TValue)); build.add(res, sizeof(TValue));
build.dec(counter); build.dec(counter);
build.jcc(Condition::NotZero, repeatNilLoop); build.jcc(ConditionX64::NotZero, repeatNilLoop);
} }
build.setLabel(skipResultCopy); build.setLabel(skipResultCopy);
@ -191,11 +341,11 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
// Unlikely, but this might be the last return from VM // Unlikely, but this might be the last return from VM
build.test(byte[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_RETURN); build.test(byte[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_RETURN);
build.jcc(Condition::NotZero, helpers.exitNoContinueVm); build.jcc(ConditionX64::NotZero, helpers.exitNoContinueVm);
Label skipFixedRetTop; Label skipFixedRetTop;
build.test(nresults, nresults); // test here will set SF=1 for a negative number and it always sets OF to 0 build.test(nresults, nresults); // test here will set SF=1 for a negative number and it always sets OF to 0
build.jcc(Condition::Less, skipFixedRetTop); // jl jumps if SF != OF build.jcc(ConditionX64::Less, skipFixedRetTop); // jl jumps if SF != OF
build.mov(rax, qword[cip + offsetof(CallInfo, top)]); build.mov(rax, qword[cip + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax); // L->top = cip->top build.mov(qword[rState + offsetof(lua_State, top)], rax); // L->top = cip->top
build.setLabel(skipFixedRetTop); build.setLabel(skipFixedRetTop);
@ -214,7 +364,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
build.mov(execdata, qword[proto + offsetofProtoExecData]); build.mov(execdata, qword[proto + offsetofProtoExecData]);
build.test(execdata, execdata); build.test(execdata, execdata);
build.jcc(Condition::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data build.jcc(ConditionX64::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data
// Change constants // Change constants
build.mov(rConstants, qword[proto + offsetof(Proto, k)]); build.mov(rConstants, qword[proto + offsetof(Proto, k)]);
@ -269,13 +419,13 @@ void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
build.mov(eax, luauRegTag(ra)); build.mov(eax, luauRegTag(ra));
build.cmp(eax, luauRegTag(rb)); build.cmp(eax, luauRegTag(rb));
build.jcc(Condition::NotEqual, not_ ? target : exit); build.jcc(ConditionX64::NotEqual, not_ ? target : exit);
// fast-path: number // fast-path: number
build.cmp(eax, LUA_TNUMBER); build.cmp(eax, LUA_TNUMBER);
build.jcc(Condition::NotEqual, fallback); build.jcc(ConditionX64::NotEqual, fallback);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), Condition::NotEqual, not_ ? target : exit); jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), ConditionX64::NotEqual, not_ ? target : exit);
if (!not_) if (!not_)
build.jmp(target); build.jmp(target);
@ -285,10 +435,10 @@ void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc,
{ {
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? Condition::NotEqual : Condition::Equal, target, pcpos); jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target, pcpos);
} }
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond, Label& fallback) void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = pc[1]; int rb = pc[1];
@ -302,7 +452,7 @@ void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pc
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), cond, target); jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), cond, target);
} }
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond) void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond)
{ {
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
@ -324,7 +474,7 @@ void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pc
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
build.cmp(luauRegTag(ra), LUA_TNIL); build.cmp(luauRegTag(ra), LUA_TNIL);
build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target); build.jcc(not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
} }
void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
@ -339,7 +489,7 @@ void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
jumpIfTagIsNot(build, ra, LUA_TBOOLEAN, not_ ? target : exit); jumpIfTagIsNot(build, ra, LUA_TBOOLEAN, not_ ? target : exit);
build.test(luauRegValueBoolean(ra), 1); build.test(luauRegValueBoolean(ra), 1);
build.jcc((aux & 0x1) ^ not_ ? Condition::NotZero : Condition::Zero, target); build.jcc((aux & 0x1) ^ not_ ? ConditionX64::NotZero : ConditionX64::Zero, target);
} }
void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr) void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr)
@ -356,15 +506,15 @@ void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TV
if (not_) if (not_)
{ {
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), build.f64(kv.value.n), Condition::NotEqual, target); jumpOnNumberCmp(build, xmm0, luauRegValue(ra), build.f64(kv.value.n), ConditionX64::NotEqual, target);
} }
else else
{ {
// Compact equality check requires two labels, so it's not supported in generic 'jumpOnNumberCmp' // Compact equality check requires two labels, so it's not supported in generic 'jumpOnNumberCmp'
build.vmovsd(xmm0, luauRegValue(ra)); build.vmovsd(xmm0, luauRegValue(ra));
build.vucomisd(xmm0, build.f64(kv.value.n)); build.vucomisd(xmm0, build.f64(kv.value.n));
build.jcc(Condition::Parity, exit); // We first have to check PF=1 for NaN operands, because it also sets ZF=1 build.jcc(ConditionX64::Parity, exit); // We first have to check PF=1 for NaN operands, because it also sets ZF=1
build.jcc(Condition::Zero, target); // Now that NaN is out of the way, we can check ZF=1 for equality build.jcc(ConditionX64::Zero, target); // Now that NaN is out of the way, we can check ZF=1 for equality
} }
} }
@ -381,7 +531,7 @@ void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
build.mov(rax, luauRegValue(ra)); build.mov(rax, luauRegValue(ra));
build.cmp(rax, luauConstantValue(aux & 0xffffff)); build.cmp(rax, luauConstantValue(aux & 0xffffff));
build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target); build.jcc(not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
} }
static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int rc, OperandX64 opc, int pcpos, TMS tm, Label& fallback) static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int rc, OperandX64 opc, int pcpos, TMS tm, Label& fallback)
@ -437,7 +587,7 @@ void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, int pcpos,
void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm) void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{ {
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, tm); callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), pcpos, tm);
} }
void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback) void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback)
@ -447,7 +597,7 @@ void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm) void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{ {
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstantValue(LUAU_INSN_C(*pc)), pcpos, tm); callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstantAddress(LUAU_INSN_C(*pc)), pcpos, tm);
} }
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label& fallback) void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label& fallback)
@ -525,7 +675,7 @@ void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos,
void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{ {
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_B(*pc)), pcpos, TM_UNM); callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_B(*pc)), pcpos, TM_UNM);
} }
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback) void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
@ -563,8 +713,8 @@ void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.mov(rArg2, aux); build.mov(dwordReg(rArg2), aux);
build.mov(rArg3, b == 0 ? 0 : 1 << (b - 1)); build.mov(dwordReg(rArg3), 1 << (b - 1));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]);
build.mov(luauRegValue(ra), rax); build.mov(luauRegValue(ra), rax);
build.mov(luauRegTag(ra), LUA_TTABLE); build.mov(luauRegTag(ra), LUA_TTABLE);
@ -610,7 +760,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
// c = L->top - rb // c = L->top - rb
build.mov(cscaled, qword[rState + offsetof(lua_State, top)]); build.mov(cscaled, qword[rState + offsetof(lua_State, top)]);
build.lea(tmp, luauRegValue(rb)); build.lea(tmp, luauRegAddress(rb));
build.sub(cscaled, tmp); // Using byte difference build.sub(cscaled, tmp); // Using byte difference
// L->top = L->ci->top // L->top = L->ci->top
@ -633,7 +783,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
// Resize if h->sizearray < last // Resize if h->sizearray < last
build.cmp(dword[table + offsetof(Table, sizearray)], last); build.cmp(dword[table + offsetof(Table, sizearray)], last);
build.jcc(Condition::NotBelow, skipResize); build.jcc(ConditionX64::NotBelow, skipResize);
// Argument setup reordered to avoid conflicts // Argument setup reordered to avoid conflicts
LUAU_ASSERT(rArg3 != table); LUAU_ASSERT(rArg3 != table);
@ -676,7 +826,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
if (c == LUA_MULTRET) if (c == LUA_MULTRET)
{ {
build.cmp(offset, limit); build.cmp(offset, limit);
build.jcc(Condition::NotBelow, endLoop); build.jcc(ConditionX64::NotBelow, endLoop);
} }
build.setLabel(repeatLoop); build.setLabel(repeatLoop);
@ -687,7 +837,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
build.add(offset, sizeof(TValue)); build.add(offset, sizeof(TValue));
build.cmp(offset, limit); build.cmp(offset, limit);
build.jcc(Condition::Below, repeatLoop); build.jcc(ConditionX64::Below, repeatLoop);
build.setLabel(endLoop); build.setLabel(endLoop);
} }
@ -707,7 +857,7 @@ void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
Label skip; Label skip;
// TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though // TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though
build.cmp(dword[rax + offsetof(TValue, tt)], LUA_TUPVAL); build.cmp(dword[rax + offsetof(TValue, tt)], LUA_TUPVAL);
build.jcc(Condition::NotEqual, skip); build.jcc(ConditionX64::NotEqual, skip);
// UpVal.v points to the value (either on stack, or on heap inside each UpVal, but we can deref it unconditionally) // UpVal.v points to the value (either on stack, or on heap inside each UpVal, but we can deref it unconditionally)
build.mov(rax, qword[rax + offsetof(TValue, value.gc)]); build.mov(rax, qword[rax + offsetof(TValue, value.gc)]);
@ -746,12 +896,12 @@ void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int p
// L->openupval != 0 // L->openupval != 0
build.mov(rax, qword[rState + offsetof(lua_State, openupval)]); build.mov(rax, qword[rState + offsetof(lua_State, openupval)]);
build.test(rax, rax); build.test(rax, rax);
build.jcc(Condition::Zero, skip); build.jcc(ConditionX64::Zero, skip);
// ra <= L->openuval->v // ra <= L->openuval->v
build.lea(rcx, qword[rBase + ra * sizeof(TValue)]); build.lea(rcx, addr[rBase + ra * sizeof(TValue)]);
build.cmp(rcx, qword[rax + offsetof(UpVal, v)]); build.cmp(rcx, qword[rax + offsetof(UpVal, v)]);
build.jcc(Condition::Above, skip); build.jcc(ConditionX64::Above, skip);
build.mov(rArg2, rcx); build.mov(rArg2, rcx);
build.mov(rArg1, rState); build.mov(rArg1, rState);
@ -771,7 +921,7 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b
int nparams = customParams ? customParamCount : LUAU_INSN_B(call) - 1; int nparams = customParams ? customParamCount : LUAU_INSN_B(call) - 1;
int nresults = LUAU_INSN_C(call) - 1; int nresults = LUAU_INSN_C(call) - 1;
int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1; int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1;
OperandX64 args = customParams ? customArgs : luauRegValue(ra + 2); OperandX64 args = customParams ? customArgs : luauRegAddress(ra + 2);
Label& exit = labelarr[pcpos + instLen]; Label& exit = labelarr[pcpos + instLen];
@ -784,7 +934,7 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b
if (nresults == LUA_MULTRET) if (nresults == LUA_MULTRET)
{ {
// L->top = ra + n; // L->top = ra + n;
build.lea(rax, qword[rBase + (ra + br.actualResultCount) * sizeof(TValue)]); build.lea(rax, addr[rBase + (ra + br.actualResultCount) * sizeof(TValue)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax); build.mov(qword[rState + offsetof(lua_State, top)], rax);
} }
else if (nparams == LUA_MULTRET) else if (nparams == LUA_MULTRET)
@ -822,7 +972,7 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b
// TODO: for SystemV ABI we can compute the result directly into rArg6 // TODO: for SystemV ABI we can compute the result directly into rArg6
// L->top - (ra + 1) // L->top - (ra + 1)
build.mov(rcx, qword[rState + offsetof(lua_State, top)]); build.mov(rcx, qword[rState + offsetof(lua_State, top)]);
build.lea(rdx, qword[rBase + (ra + 1) * sizeof(TValue)]); build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]);
build.sub(rcx, rdx); build.sub(rcx, rdx);
build.shr(rcx, kTValueSizeLog2); build.shr(rcx, kTValueSizeLog2);
@ -840,20 +990,20 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b
} }
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra)); build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegValue(arg)); build.lea(rArg3, luauRegAddress(arg));
build.mov(rArg4, nresults); build.mov(dwordReg(rArg4), nresults);
build.call(rax); build.call(rax);
build.test(eax, eax); // test here will set SF=1 for a negative number and it always sets OF to 0 build.test(eax, eax); // test here will set SF=1 for a negative number and it always sets OF to 0
build.jcc(Condition::Less, exit); // jl jumps if SF != OF build.jcc(ConditionX64::Less, exit); // jl jumps if SF != OF
if (nresults == LUA_MULTRET) if (nresults == LUA_MULTRET)
{ {
// L->top = ra + n; // L->top = ra + n;
build.shl(rax, kTValueSizeLog2); build.shl(rax, kTValueSizeLog2);
build.lea(rax, qword[rBase + rax + ra * sizeof(TValue)]); build.lea(rax, addr[rBase + rax + ra * sizeof(TValue)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax); build.mov(qword[rState + offsetof(lua_State, top)], rax);
} }
else if (nparams == LUA_MULTRET) else if (nparams == LUA_MULTRET)
@ -875,13 +1025,13 @@ int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
return emitInstFastCallN( return emitInstFastCallN(
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegValue(pc[1]), pcpos, /* instLen */ 2, labelarr); build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegAddress(pc[1]), pcpos, /* instLen */ 2, labelarr);
} }
int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
return emitInstFastCallN( return emitInstFastCallN(
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantValue(pc[1]), pcpos, /* instLen */ 2, labelarr); build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantAddress(pc[1]), pcpos, /* instLen */ 2, labelarr);
} }
int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
@ -916,16 +1066,16 @@ void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
Label reverse; Label reverse;
// step <= 0 // step <= 0
jumpOnNumberCmp(build, noreg, step, zero, Condition::LessEqual, reverse); jumpOnNumberCmp(build, noreg, step, zero, ConditionX64::LessEqual, reverse);
// TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation // TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation
// false: idx <= limit // false: idx <= limit
jumpOnNumberCmp(build, noreg, idx, limit, Condition::LessEqual, exit); jumpOnNumberCmp(build, noreg, idx, limit, ConditionX64::LessEqual, exit);
build.jmp(loopExit); build.jmp(loopExit);
// true: limit <= idx // true: limit <= idx
build.setLabel(reverse); build.setLabel(reverse);
jumpOnNumberCmp(build, noreg, limit, idx, Condition::LessEqual, exit); jumpOnNumberCmp(build, noreg, limit, idx, ConditionX64::LessEqual, exit);
build.jmp(loopExit); build.jmp(loopExit);
// TOOD: place at the end of the function // TOOD: place at the end of the function
@ -958,15 +1108,15 @@ void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
Label reverse, exit; Label reverse, exit;
// step <= 0 // step <= 0
jumpOnNumberCmp(build, noreg, step, zero, Condition::LessEqual, reverse); jumpOnNumberCmp(build, noreg, step, zero, ConditionX64::LessEqual, reverse);
// false: idx <= limit // false: idx <= limit
jumpOnNumberCmp(build, noreg, idx, limit, Condition::LessEqual, loopRepeat); jumpOnNumberCmp(build, noreg, idx, limit, ConditionX64::LessEqual, loopRepeat);
build.jmp(exit); build.jmp(exit);
// true: limit <= idx // true: limit <= idx
build.setLabel(reverse); build.setLabel(reverse);
jumpOnNumberCmp(build, noreg, limit, idx, Condition::LessEqual, loopRepeat); jumpOnNumberCmp(build, noreg, limit, idx, ConditionX64::LessEqual, loopRepeat);
build.setLabel(exit); build.setLabel(exit);
} }
@ -1010,13 +1160,13 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
// while (unsigned(index) < unsigned(sizearray)) // while (unsigned(index) < unsigned(sizearray))
Label arrayLoop = build.setLabel(); Label arrayLoop = build.setLabel();
build.cmp(dwordReg(index), dword[table + offsetof(Table, sizearray)]); build.cmp(dwordReg(index), dword[table + offsetof(Table, sizearray)]);
build.jcc(Condition::NotBelow, isIpairsIter ? exit : skipArray); build.jcc(ConditionX64::NotBelow, isIpairsIter ? exit : skipArray);
// If element is nil, we increment the index; if it's not, we still need 'index + 1' inside // If element is nil, we increment the index; if it's not, we still need 'index + 1' inside
build.inc(index); build.inc(index);
build.cmp(dword[elemPtr + offsetof(TValue, tt)], LUA_TNIL); build.cmp(dword[elemPtr + offsetof(TValue, tt)], LUA_TNIL);
build.jcc(Condition::Equal, isIpairsIter ? exit : skipArrayNil); build.jcc(ConditionX64::Equal, isIpairsIter ? exit : skipArrayNil);
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1))); // setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
build.mov(luauRegValue(ra + 2), index); build.mov(luauRegValue(ra + 2), index);
@ -1045,10 +1195,10 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
// Call helper to assign next node value or to signal loop exit // Call helper to assign next node value or to signal loop exit
build.mov(rArg1, rState); build.mov(rArg1, rState);
// rArg2 and rArg3 are already set // rArg2 and rArg3 are already set
build.lea(rArg4, luauRegValue(ra)); build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNodeIter)]); build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNodeIter)]);
build.test(al, al); build.test(al, al);
build.jcc(Condition::NotZero, loopRepeat); build.jcc(ConditionX64::NotZero, loopRepeat);
} }
} }
@ -1062,12 +1212,12 @@ void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc,
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.mov(rArg2, ra); build.mov(dwordReg(rArg2), ra);
build.mov(rArg3, aux); build.mov(dwordReg(rArg3), aux);
build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNonTableFallback)]); build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNonTableFallback)]);
emitUpdateBase(build); emitUpdateBase(build);
build.test(al, al); build.test(al, al);
build.jcc(Condition::NotZero, loopRepeat); build.jcc(ConditionX64::NotZero, loopRepeat);
} }
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback) void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
@ -1103,7 +1253,7 @@ void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, int
build.vxorpd(xmm0, xmm0, xmm0); build.vxorpd(xmm0, xmm0, xmm0);
build.vmovsd(xmm1, luauRegValue(ra + 2)); build.vmovsd(xmm1, luauRegValue(ra + 2));
jumpOnNumberCmp(build, noreg, xmm0, xmm1, Condition::NotEqual, fallback); jumpOnNumberCmp(build, noreg, xmm0, xmm1, ConditionX64::NotEqual, fallback);
build.mov(luauRegTag(ra), LUA_TNIL); build.mov(luauRegTag(ra), LUA_TNIL);
@ -1121,8 +1271,8 @@ void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction*
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra)); build.lea(rArg2, luauRegAddress(ra));
build.mov(rArg3, pcpos + 1); build.mov(dwordReg(rArg3), pcpos + 1);
build.call(qword[rNativeContext + offsetof(NativeContext, forgPrepXnextFallback)]); build.call(qword[rNativeContext + offsetof(NativeContext, forgPrepXnextFallback)]);
build.jmp(target); build.jmp(target);
} }
@ -1216,7 +1366,7 @@ void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcp
// unsigned(c) < unsigned(h->sizearray) // unsigned(c) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], c); build.cmp(dword[table + offsetof(Table, sizearray)], c);
build.jcc(Condition::BelowEqual, fallback); build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback); jumpIfMetatablePresent(build, table, fallback);
@ -1244,7 +1394,7 @@ void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcp
// unsigned(c) < unsigned(h->sizearray) // unsigned(c) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], c); build.cmp(dword[table + offsetof(Table, sizearray)], c);
build.jcc(Condition::BelowEqual, fallback); build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback); jumpIfMetatablePresent(build, table, fallback);
jumpIfTableIsReadOnly(build, table, fallback); jumpIfTableIsReadOnly(build, table, fallback);
@ -1284,7 +1434,7 @@ void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
// unsigned(index - 1) < unsigned(h->sizearray) // unsigned(index - 1) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], eax); build.cmp(dword[table + offsetof(Table, sizearray)], eax);
build.jcc(Condition::BelowEqual, fallback); build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback); jumpIfMetatablePresent(build, table, fallback);
@ -1296,7 +1446,7 @@ void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{ {
callGetTable(build, LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos); callGetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos);
} }
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback) void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
@ -1319,7 +1469,7 @@ void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
// unsigned(index - 1) < unsigned(h->sizearray) // unsigned(index - 1) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], eax); build.cmp(dword[table + offsetof(Table, sizearray)], eax);
build.jcc(Condition::BelowEqual, fallback); build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback); jumpIfMetatablePresent(build, table, fallback);
jumpIfTableIsReadOnly(build, table, fallback); jumpIfTableIsReadOnly(build, table, fallback);
@ -1335,7 +1485,7 @@ void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{ {
callSetTable(build, LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos); callSetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos);
} }
void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback) void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
@ -1348,7 +1498,7 @@ void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label&
// note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback // note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback
// this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime // this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime
build.cmp(luauConstantTag(k), LUA_TNIL); build.cmp(luauConstantTag(k), LUA_TNIL);
build.jcc(Condition::Equal, fallback); build.jcc(ConditionX64::Equal, fallback);
build.vmovups(xmm0, luauConstant(k)); build.vmovups(xmm0, luauConstant(k));
build.vmovups(luauReg(ra), xmm0); build.vmovups(luauReg(ra), xmm0);
@ -1367,7 +1517,7 @@ void emitInstGetImportFallback(AssemblyBuilderX64& build, const Instruction* pc,
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.mov(rArg2, qword[rax + offsetof(Closure, env)]); build.mov(rArg2, qword[rax + offsetof(Closure, env)]);
build.mov(rArg3, rConstants); build.mov(rArg3, rConstants);
build.mov(rArg4, aux); build.mov(dwordReg(rArg4), aux);
if (build.abi == ABIX64::Windows) if (build.abi == ABIX64::Windows)
build.mov(sArg5, 0); build.mov(sArg5, 0);
@ -1470,8 +1620,8 @@ void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos,
// luaV_concat(L, c - b + 1, c) // luaV_concat(L, c - b + 1, c)
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.mov(rArg2, rc - rb + 1); build.mov(dwordReg(rArg2), rc - rb + 1);
build.mov(rArg3, rc); build.mov(dwordReg(rArg3), rc);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]);
emitUpdateBase(build); emitUpdateBase(build);

View file

@ -14,7 +14,7 @@ namespace CodeGen
{ {
class AssemblyBuilderX64; class AssemblyBuilderX64;
enum class Condition; enum class ConditionX64 : uint8_t;
struct Label; struct Label;
struct ModuleHelpers; struct ModuleHelpers;
@ -24,14 +24,15 @@ void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc); void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_); void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback); void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback);
void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_); void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond, Label& fallback); void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback);
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond); void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond);
void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);

File diff suppressed because it is too large Load diff

View file

@ -10,84 +10,15 @@ typedef uint32_t Instruction;
typedef struct lua_TValue TValue; typedef struct lua_TValue TValue;
typedef TValue* StkId; typedef TValue* StkId;
const Instruction* execute_LOP_NOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADNIL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MOVE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETUPVAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETUPVAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CLOSEUPVALS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETIMPORT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLEN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLEN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_RETURN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIF(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFEQ(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTEQ(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFLT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTLT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ADD(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SUB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MUL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DIV(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MOD(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_POW(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ADDK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SUBK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MULK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DIVK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MODK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_POWK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_AND(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_OR(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ANDK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ORK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CONCAT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NOT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MINUS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LENGTH(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NEWTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DUPTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORNPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORNLOOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGLOOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP_INEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_FORGLOOP_INEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP_NEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_FORGLOOP_NEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPBACK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADKX(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPX(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CAPTURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_JUMPIFEQK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_JUMPIFNOTEQK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL1(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL2(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL2K(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k); const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKNIL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);

View file

@ -35,10 +35,9 @@ NativeState::~NativeState() = default;
void initFallbackTable(NativeState& data) void initFallbackTable(NativeState& data)
{ {
// TODO: lvmexecute_split.py could be taught to generate a subset of instructions we actually need // When fallback is completely removed, remove it from includeInsts list in lvmexecute_split.py
CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, 0); CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0); CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0);
CODEGEN_SET_FALLBACK(LOP_CALL, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc); CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0); CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0);
CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0); CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0);
@ -86,6 +85,8 @@ void initHelperFunctions(NativeState& data)
data.context.forgLoopNodeIter = forgLoopNodeIter; data.context.forgLoopNodeIter = forgLoopNodeIter;
data.context.forgLoopNonTableFallback = forgLoopNonTableFallback; data.context.forgLoopNonTableFallback = forgLoopNonTableFallback;
data.context.forgPrepXnextFallback = forgPrepXnextFallback; data.context.forgPrepXnextFallback = forgPrepXnextFallback;
data.context.callProlog = callProlog;
data.context.callEpilogC = callEpilogC;
} }
} // namespace CodeGen } // namespace CodeGen

View file

@ -26,8 +26,6 @@ class UnwindBuilder;
using FallbackFn = const Instruction*(lua_State* L, const Instruction* pc, StkId base, TValue* k); using FallbackFn = const Instruction*(lua_State* L, const Instruction* pc, StkId base, TValue* k);
constexpr uint8_t kFallbackUpdatePc = 1 << 0; constexpr uint8_t kFallbackUpdatePc = 1 << 0;
constexpr uint8_t kFallbackUpdateCi = 1 << 1;
constexpr uint8_t kFallbackCheckInterrupt = 1 << 2;
struct NativeFallback struct NativeFallback
{ {
@ -37,7 +35,8 @@ struct NativeFallback
struct NativeProto struct NativeProto
{ {
uintptr_t* instTargets = nullptr; uintptr_t entryTarget = 0;
uintptr_t* instTargets = nullptr; // TODO: NativeProto should be variable-size with all target embedded
Proto* proto = nullptr; Proto* proto = nullptr;
uint32_t location = 0; uint32_t location = 0;
@ -85,6 +84,8 @@ struct NativeContext
bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr; bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr;
bool (*forgLoopNonTableFallback)(lua_State* L, int insnA, int aux) = nullptr; bool (*forgLoopNonTableFallback)(lua_State* L, int insnA, int aux) = nullptr;
void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr;
Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr;
void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr;
}; };
struct NativeState struct NativeState

View file

@ -55,18 +55,23 @@ target_sources(Luau.Compiler PRIVATE
# Luau.CodeGen Sources # Luau.CodeGen Sources
target_sources(Luau.CodeGen PRIVATE target_sources(Luau.CodeGen PRIVATE
CodeGen/include/Luau/AddressA64.h
CodeGen/include/Luau/AssemblyBuilderA64.h
CodeGen/include/Luau/AssemblyBuilderX64.h CodeGen/include/Luau/AssemblyBuilderX64.h
CodeGen/include/Luau/CodeAllocator.h CodeGen/include/Luau/CodeAllocator.h
CodeGen/include/Luau/CodeBlockUnwind.h CodeGen/include/Luau/CodeBlockUnwind.h
CodeGen/include/Luau/CodeGen.h CodeGen/include/Luau/CodeGen.h
CodeGen/include/Luau/Condition.h CodeGen/include/Luau/ConditionA64.h
CodeGen/include/Luau/ConditionX64.h
CodeGen/include/Luau/Label.h CodeGen/include/Luau/Label.h
CodeGen/include/Luau/OperandX64.h CodeGen/include/Luau/OperandX64.h
CodeGen/include/Luau/RegisterA64.h
CodeGen/include/Luau/RegisterX64.h CodeGen/include/Luau/RegisterX64.h
CodeGen/include/Luau/UnwindBuilder.h CodeGen/include/Luau/UnwindBuilder.h
CodeGen/include/Luau/UnwindBuilderDwarf2.h CodeGen/include/Luau/UnwindBuilderDwarf2.h
CodeGen/include/Luau/UnwindBuilderWin.h CodeGen/include/Luau/UnwindBuilderWin.h
CodeGen/src/AssemblyBuilderA64.cpp
CodeGen/src/AssemblyBuilderX64.cpp CodeGen/src/AssemblyBuilderX64.cpp
CodeGen/src/CodeAllocator.cpp CodeGen/src/CodeAllocator.cpp
CodeGen/src/CodeBlockUnwind.cpp CodeGen/src/CodeBlockUnwind.cpp
@ -103,6 +108,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/BuiltinDefinitions.h
Analysis/include/Luau/Clone.h Analysis/include/Luau/Clone.h
Analysis/include/Luau/Config.h Analysis/include/Luau/Config.h
Analysis/include/Luau/Connective.h
Analysis/include/Luau/Constraint.h Analysis/include/Luau/Constraint.h
Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintGraphBuilder.h
Analysis/include/Luau/ConstraintSolver.h Analysis/include/Luau/ConstraintSolver.h
@ -156,6 +162,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/BuiltinDefinitions.cpp Analysis/src/BuiltinDefinitions.cpp
Analysis/src/Clone.cpp Analysis/src/Clone.cpp
Analysis/src/Config.cpp Analysis/src/Config.cpp
Analysis/src/Connective.cpp
Analysis/src/Constraint.cpp Analysis/src/Constraint.cpp
Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintGraphBuilder.cpp
Analysis/src/ConstraintSolver.cpp Analysis/src/ConstraintSolver.cpp
@ -299,6 +306,7 @@ if(TARGET Luau.UnitTest)
tests/AstQueryDsl.cpp tests/AstQueryDsl.cpp
tests/ConstraintGraphBuilderFixture.cpp tests/ConstraintGraphBuilderFixture.cpp
tests/Fixture.cpp tests/Fixture.cpp
tests/AssemblyBuilderA64.test.cpp
tests/AssemblyBuilderX64.test.cpp tests/AssemblyBuilderX64.test.cpp
tests/AstJsonEncoder.test.cpp tests/AstJsonEncoder.test.cpp
tests/AstQuery.test.cpp tests/AstQuery.test.cpp

View file

@ -20,14 +20,6 @@
#define LUAU_FASTMATH_END #define LUAU_FASTMATH_END
#endif #endif
// Some functions like floor/ceil have SSE4.1 equivalents but we currently support systems without SSE4.1
// On newer GCC and Clang we can use function multi-versioning to generate SSE4.1 code plus CPUID based dispatch.
#if !defined(__APPLE__) && (defined(__x86_64__) || defined(_M_X64)) && ((defined(__clang__) && __clang_major__ >= 14) || (defined(__GNUC__) && __GNUC__ >= 6)) && !defined(__SSE4_1__)
#define LUAU_DISPATCH_SSE41 __attribute__((target_clones("default", "sse4.1")))
#else
#define LUAU_DISPATCH_SSE41
#endif
// Used on functions that have a printf-like interface to validate them statically // Used on functions that have a printf-like interface to validate them statically
#if defined(__GNUC__) #if defined(__GNUC__)
#define LUA_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg))) #define LUA_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg)))

View file

@ -96,7 +96,6 @@ static int luauF_atan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId
} }
LUAU_FASTMATH_BEGIN LUAU_FASTMATH_BEGIN
LUAU_DISPATCH_SSE41
static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
@ -160,7 +159,6 @@ static int luauF_exp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId
} }
LUAU_FASTMATH_BEGIN LUAU_FASTMATH_BEGIN
LUAU_DISPATCH_SSE41
static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
@ -938,7 +936,6 @@ static int luauF_sign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId
} }
LUAU_FASTMATH_BEGIN LUAU_FASTMATH_BEGIN
LUAU_DISPATCH_SSE41
static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
@ -1326,9 +1323,9 @@ const luau_FastFunction luauF_table[256] = {
luauF_getmetatable, luauF_getmetatable,
luauF_setmetatable, luauF_setmetatable,
// When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback. // When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback.
// This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing. // This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing.
// Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest. // Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest.
#define MISSING8 luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing #define MISSING8 luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing
MISSING8, MISSING8,

View file

@ -34,7 +34,6 @@ inline bool luai_vecisnan(const float* a)
} }
LUAU_FASTMATH_BEGIN LUAU_FASTMATH_BEGIN
// TODO: LUAU_DISPATCH_SSE41 would be nice here, but clang-14 doesn't support it correctly on inline functions...
inline double luai_nummod(double a, double b) inline double luai_nummod(double a, double b)
{ {
return a - floor(a / b) * b; return a - floor(a / b) * b;

View file

@ -0,0 +1,221 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AssemblyBuilderA64.h"
#include "Luau/StringUtils.h"
#include "doctest.h"
#include <string.h>
using namespace Luau::CodeGen;
static std::string bytecodeAsArray(const std::vector<uint8_t>& bytecode)
{
std::string result = "{";
for (size_t i = 0; i < bytecode.size(); i++)
Luau::formatAppend(result, "%s0x%02x", i == 0 ? "" : ", ", bytecode[i]);
return result.append("}");
}
static std::string bytecodeAsArray(const std::vector<uint32_t>& code)
{
std::string result = "{";
for (size_t i = 0; i < code.size(); i++)
Luau::formatAppend(result, "%s0x%08x", i == 0 ? "" : ", ", code[i]);
return result.append("}");
}
class AssemblyBuilderA64Fixture
{
public:
bool check(void (*f)(AssemblyBuilderA64& build), std::vector<uint32_t> code, std::vector<uint8_t> data = {})
{
AssemblyBuilderA64 build(/* logText= */ false);
f(build);
build.finalize();
if (build.code != code)
{
printf("Expected code: %s\nReceived code: %s\n", bytecodeAsArray(code).c_str(), bytecodeAsArray(build.code).c_str());
return false;
}
if (build.data != data)
{
printf("Expected data: %s\nReceived data: %s\n", bytecodeAsArray(data).c_str(), bytecodeAsArray(build.data).c_str());
return false;
}
return true;
}
};
// armconverter.com can be used to validate instruction sequences
TEST_SUITE_BEGIN("A64Assembly");
#define SINGLE_COMPARE(inst, ...) \
CHECK(check( \
[](AssemblyBuilderA64& build) { \
build.inst; \
}, \
{__VA_ARGS__}))
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Unary")
{
SINGLE_COMPARE(neg(x0, x1), 0xCB0103E0);
SINGLE_COMPARE(neg(w0, w1), 0x4B0103E0);
SINGLE_COMPARE(clz(x0, x1), 0xDAC01020);
SINGLE_COMPARE(clz(w0, w1), 0x5AC01020);
SINGLE_COMPARE(rbit(x0, x1), 0xDAC00020);
SINGLE_COMPARE(rbit(w0, w1), 0x5AC00020);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary")
{
// reg, reg
SINGLE_COMPARE(add(x0, x1, x2), 0x8B020020);
SINGLE_COMPARE(add(w0, w1, w2), 0x0B020020);
SINGLE_COMPARE(add(x0, x1, x2, 7), 0x8B021C20);
SINGLE_COMPARE(sub(x0, x1, x2), 0xCB020020);
SINGLE_COMPARE(and_(x0, x1, x2), 0x8A020020);
SINGLE_COMPARE(orr(x0, x1, x2), 0xAA020020);
SINGLE_COMPARE(eor(x0, x1, x2), 0xCA020020);
SINGLE_COMPARE(lsl(x0, x1, x2), 0x9AC22020);
SINGLE_COMPARE(lsl(w0, w1, w2), 0x1AC22020);
SINGLE_COMPARE(lsr(x0, x1, x2), 0x9AC22420);
SINGLE_COMPARE(asr(x0, x1, x2), 0x9AC22820);
SINGLE_COMPARE(ror(x0, x1, x2), 0x9AC22C20);
// reg, imm
SINGLE_COMPARE(add(x3, x7, 78), 0x910138E3);
SINGLE_COMPARE(add(w3, w7, 78), 0x110138E3);
SINGLE_COMPARE(sub(w3, w7, 78), 0x510138E3);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads")
{
// address forms
SINGLE_COMPARE(ldr(x0, x1), 0xF9400020);
SINGLE_COMPARE(ldr(x0, AddressA64(x1, 8)), 0xF9400420);
SINGLE_COMPARE(ldr(x0, AddressA64(x1, x7)), 0xF8676820);
// load sizes
SINGLE_COMPARE(ldr(x0, x1), 0xF9400020);
SINGLE_COMPARE(ldr(w0, x1), 0xB9400020);
SINGLE_COMPARE(ldrb(w0, x1), 0x39400020);
SINGLE_COMPARE(ldrh(w0, x1), 0x79400020);
SINGLE_COMPARE(ldrsb(x0, x1), 0x39800020);
SINGLE_COMPARE(ldrsb(w0, x1), 0x39C00020);
SINGLE_COMPARE(ldrsh(x0, x1), 0x79800020);
SINGLE_COMPARE(ldrsh(w0, x1), 0x79C00020);
SINGLE_COMPARE(ldrsw(x0, x1), 0xB9800020);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores")
{
// address forms
SINGLE_COMPARE(str(x0, x1), 0xF9000020);
SINGLE_COMPARE(str(x0, AddressA64(x1, 8)), 0xF9000420);
SINGLE_COMPARE(str(x0, AddressA64(x1, x7)), 0xF8276820);
// store sizes
SINGLE_COMPARE(str(x0, x1), 0xF9000020);
SINGLE_COMPARE(str(w0, x1), 0xB9000020);
SINGLE_COMPARE(strb(w0, x1), 0x39000020);
SINGLE_COMPARE(strh(w0, x1), 0x79000020);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Moves")
{
SINGLE_COMPARE(mov(x0, x1), 0xAA0103E0);
SINGLE_COMPARE(mov(w0, w1), 0x2A0103E0);
SINGLE_COMPARE(mov(x0, 42), 0xD2800540);
SINGLE_COMPARE(mov(w0, 42), 0x52800540);
SINGLE_COMPARE(movk(x0, 42, 16), 0xF2A00540);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ControlFlow")
{
// Jump back
CHECK(check(
[](AssemblyBuilderA64& build) {
Label start = build.setLabel();
build.mov(x0, x1);
build.b(ConditionA64::Equal, start);
},
{0xAA0103E0, 0x54FFFFE0}));
// Jump forward
CHECK(check(
[](AssemblyBuilderA64& build) {
Label skip;
build.b(ConditionA64::Equal, skip);
build.mov(x0, x1);
build.setLabel(skip);
},
{0x54000040, 0xAA0103E0}));
// Jumps
CHECK(check(
[](AssemblyBuilderA64& build) {
Label skip;
build.b(ConditionA64::Equal, skip);
build.cbz(x0, skip);
build.cbnz(x0, skip);
build.setLabel(skip);
},
{0x54000060, 0xB4000040, 0xB5000020}));
// Basic control flow
SINGLE_COMPARE(ret(), 0xD65F03C0);
}
TEST_CASE("LogTest")
{
AssemblyBuilderA64 build(/* logText= */ true);
build.add(w0, w1, w2);
build.add(x0, x1, x2, 2);
build.add(w7, w8, 5);
build.add(x7, x8, 5);
build.ldr(x7, x8);
build.ldr(x7, AddressA64(x8, 8));
build.ldr(x7, AddressA64(x8, x9));
build.mov(x1, x2);
build.movk(x1, 42, 16);
Label l;
build.b(ConditionA64::Plus, l);
build.cbz(x7, l);
build.setLabel(l);
build.ret();
build.finalize();
std::string expected = R"(
add w0,w1,w2
add x0,x1,x2 LSL #2
add w7,w8,#5
add x7,x8,#5
ldr x7,[x8]
ldr x7,[x8,#8]
ldr x7,[x8,x9]
mov x1,x2
movk x1,#42 LSL #16
b.pl .L1
cbz x7,.L1
.L1:
ret
)";
CHECK("\n" + build.text == expected);
}
TEST_SUITE_END();

View file

@ -8,7 +8,7 @@
using namespace Luau::CodeGen; using namespace Luau::CodeGen;
std::string bytecodeAsArray(const std::vector<uint8_t>& bytecode) static std::string bytecodeAsArray(const std::vector<uint8_t>& bytecode)
{ {
std::string result = "{"; std::string result = "{";
@ -21,7 +21,7 @@ std::string bytecodeAsArray(const std::vector<uint8_t>& bytecode)
class AssemblyBuilderX64Fixture class AssemblyBuilderX64Fixture
{ {
public: public:
void check(void (*f)(AssemblyBuilderX64& build), std::vector<uint8_t> code, std::vector<uint8_t> data = {}) bool check(void (*f)(AssemblyBuilderX64& build), std::vector<uint8_t> code, std::vector<uint8_t> data = {})
{ {
AssemblyBuilderX64 build(/* logText= */ false); AssemblyBuilderX64 build(/* logText= */ false);
@ -32,25 +32,27 @@ public:
if (build.code != code) if (build.code != code)
{ {
printf("Expected code: %s\nReceived code: %s\n", bytecodeAsArray(code).c_str(), bytecodeAsArray(build.code).c_str()); printf("Expected code: %s\nReceived code: %s\n", bytecodeAsArray(code).c_str(), bytecodeAsArray(build.code).c_str());
CHECK(false); return false;
} }
if (build.data != data) if (build.data != data)
{ {
printf("Expected data: %s\nReceived data: %s\n", bytecodeAsArray(data).c_str(), bytecodeAsArray(build.data).c_str()); printf("Expected data: %s\nReceived data: %s\n", bytecodeAsArray(data).c_str(), bytecodeAsArray(build.data).c_str());
CHECK(false); return false;
} }
return true;
} }
}; };
TEST_SUITE_BEGIN("x64Assembly"); TEST_SUITE_BEGIN("x64Assembly");
#define SINGLE_COMPARE(inst, ...) \ #define SINGLE_COMPARE(inst, ...) \
check( \ CHECK(check( \
[](AssemblyBuilderX64& build) { \ [](AssemblyBuilderX64& build) { \
build.inst; \ build.inst; \
}, \ }, \
{__VA_ARGS__}) {__VA_ARGS__}))
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseBinaryInstructionForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseBinaryInstructionForms")
{ {
@ -236,9 +238,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfShift")
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea")
{ {
SINGLE_COMPARE(lea(rax, qword[rdx + rcx]), 0x48, 0x8d, 0x04, 0x0a); SINGLE_COMPARE(lea(rax, addr[rdx + rcx]), 0x48, 0x8d, 0x04, 0x0a);
SINGLE_COMPARE(lea(rax, qword[rdx + rax * 4]), 0x48, 0x8d, 0x04, 0x82); SINGLE_COMPARE(lea(rax, addr[rdx + rax * 4]), 0x48, 0x8d, 0x04, 0x82);
SINGLE_COMPARE(lea(rax, qword[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04); SINGLE_COMPARE(lea(rax, addr[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps")
@ -280,34 +282,34 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "NopForms")
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentForms")
{ {
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
build.ret(); build.ret();
build.align(8, AlignmentDataX64::Nop); build.align(8, AlignmentDataX64::Nop);
}, },
{0xc3, 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00}); {0xc3, 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00}));
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
build.ret(); build.ret();
build.align(32, AlignmentDataX64::Nop); build.align(32, AlignmentDataX64::Nop);
}, },
{0xc3, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, {0xc3, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84,
0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x1f, 0x40, 0x00}); 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x1f, 0x40, 0x00}));
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
build.ret(); build.ret();
build.align(8, AlignmentDataX64::Int3); build.align(8, AlignmentDataX64::Int3);
}, },
{0xc3, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc}); {0xc3, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc}));
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
build.ret(); build.ret();
build.align(8, AlignmentDataX64::Ud2); build.align(8, AlignmentDataX64::Ud2);
}, },
{0xc3, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0xcc}); {0xc3, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0xcc}));
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentOverflow") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentOverflow")
@ -349,40 +351,40 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentOverflow")
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow")
{ {
// Jump back // Jump back
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
Label start = build.setLabel(); Label start = build.setLabel();
build.add(rsi, 1); build.add(rsi, 1);
build.cmp(rsi, rdi); build.cmp(rsi, rdi);
build.jcc(Condition::Equal, start); build.jcc(ConditionX64::Equal, start);
}, },
{0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf3, 0xff, 0xff, 0xff}); {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf3, 0xff, 0xff, 0xff}));
// Jump back, but the label is set before use // Jump back, but the label is set before use
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
Label start; Label start;
build.add(rsi, 1); build.add(rsi, 1);
build.setLabel(start); build.setLabel(start);
build.cmp(rsi, rdi); build.cmp(rsi, rdi);
build.jcc(Condition::Equal, start); build.jcc(ConditionX64::Equal, start);
}, },
{0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf7, 0xff, 0xff, 0xff}); {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf7, 0xff, 0xff, 0xff}));
// Jump forward // Jump forward
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
Label skip; Label skip;
build.cmp(rsi, rdi); build.cmp(rsi, rdi);
build.jcc(Condition::Greater, skip); build.jcc(ConditionX64::Greater, skip);
build.or_(rdi, 0x3e); build.or_(rdi, 0x3e);
build.setLabel(skip); build.setLabel(skip);
}, },
{0x48, 0x3b, 0xf7, 0x0f, 0x8f, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xcf, 0x3e}); {0x48, 0x3b, 0xf7, 0x0f, 0x8f, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xcf, 0x3e}));
// Regular jump // Regular jump
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
Label skip; Label skip;
@ -390,12 +392,12 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow")
build.and_(rdi, 0x3e); build.and_(rdi, 0x3e);
build.setLabel(skip); build.setLabel(skip);
}, },
{0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e}); {0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e}));
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall")
{ {
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
Label fnB; Label fnB;
@ -404,10 +406,10 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall")
build.ret(); build.ret();
build.setLabel(fnB); build.setLabel(fnB);
build.lea(rax, qword[rcx + 0x1f]); build.lea(rax, addr[rcx + 0x1f]);
build.ret(); build.ret();
}, },
{0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3}); {0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3}));
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms")
@ -518,7 +520,7 @@ TEST_CASE("LogTest")
Label start = build.setLabel(); Label start = build.setLabel();
build.cmp(rsi, rdi); build.cmp(rsi, rdi);
build.jcc(Condition::Equal, start); build.jcc(ConditionX64::Equal, start);
build.jmp(qword[rdx]); build.jmp(qword[rdx]);
build.vaddps(ymm9, ymm12, ymmword[rbp + 0xc]); build.vaddps(ymm9, ymm12, ymmword[rbp + 0xc]);
@ -548,7 +550,7 @@ TEST_CASE("LogTest")
build.finalize(); build.finalize();
bool same = "\n" + build.text == R"( std::string expected = R"(
push r12 push r12
; align 8 ; align 8
nop word ptr[rax+rax] ; 6-byte nop nop word ptr[rax+rax] ; 6-byte nop
@ -588,13 +590,14 @@ TEST_CASE("LogTest")
nop dword ptr[rax+rax] ; 8-byte nop nop dword ptr[rax+rax] ; 8-byte nop
nop word ptr[rax+rax] ; 9-byte nop nop word ptr[rax+rax] ; 9-byte nop
)"; )";
CHECK(same);
CHECK("\n" + build.text == expected);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "Constants") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "Constants")
{ {
// clang-format off // clang-format off
check( CHECK(check(
[](AssemblyBuilderX64& build) { [](AssemblyBuilderX64& build) {
build.xor_(rax, rax); build.xor_(rax, rax);
build.add(rax, build.i64(0x1234567887654321)); build.add(rax, build.i64(0x1234567887654321));
@ -625,7 +628,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "Constants")
0x00, 0x00, 0x00, 0x00, // padding to align f64 0x00, 0x00, 0x00, 0x00, // padding to align f64
0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x80, 0x3f,
0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12, 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12,
}); }));
// clang-format on // clang-format on
} }

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AssemblyBuilderX64.h" #include "Luau/AssemblyBuilderX64.h"
#include "Luau/AssemblyBuilderA64.h"
#include "Luau/CodeAllocator.h" #include "Luau/CodeAllocator.h"
#include "Luau/CodeBlockUnwind.h" #include "Luau/CodeBlockUnwind.h"
#include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilder.h"
@ -207,7 +208,7 @@ constexpr RegisterX64 rArg3 = rdx;
constexpr RegisterX64 rNonVol1 = r12; constexpr RegisterX64 rNonVol1 = r12;
constexpr RegisterX64 rNonVol2 = rbx; constexpr RegisterX64 rNonVol2 = rbx;
TEST_CASE("GeneratedCodeExecution") TEST_CASE("GeneratedCodeExecutionX64")
{ {
AssemblyBuilderX64 build(/* logText= */ false); AssemblyBuilderX64 build(/* logText= */ false);
@ -241,7 +242,7 @@ void throwing(int64_t arg)
throw std::runtime_error("testing"); throw std::runtime_error("testing");
} }
TEST_CASE("GeneratedCodeExecutionWithThrow") TEST_CASE("GeneratedCodeExecutionWithThrowX64")
{ {
AssemblyBuilderX64 build(/* logText= */ false); AssemblyBuilderX64 build(/* logText= */ false);
@ -267,7 +268,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrow")
build.sub(rsp, stackSize + localsSize); build.sub(rsp, stackSize + localsSize);
unwind->allocStack(stackSize + localsSize); unwind->allocStack(stackSize + localsSize);
build.lea(rbp, qword[rsp + stackSize]); build.lea(rbp, addr[rsp + stackSize]);
unwind->setupFrameReg(rbp, stackSize); unwind->setupFrameReg(rbp, stackSize);
unwind->finish(); unwind->finish();
@ -281,7 +282,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrow")
build.call(rNonVol2); build.call(rNonVol2);
// Epilogue // Epilogue
build.lea(rsp, qword[rbp + localsSize]); build.lea(rsp, addr[rbp + localsSize]);
build.pop(rbp); build.pop(rbp);
build.pop(rNonVol2); build.pop(rNonVol2);
build.pop(rNonVol1); build.pop(rNonVol1);
@ -317,7 +318,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrow")
} }
} }
TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate") TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64")
{ {
AssemblyBuilderX64 build(/* logText= */ false); AssemblyBuilderX64 build(/* logText= */ false);
@ -351,7 +352,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate")
build.sub(rsp, stackSize + localsSize); build.sub(rsp, stackSize + localsSize);
unwind->allocStack(stackSize + localsSize); unwind->allocStack(stackSize + localsSize);
build.lea(rbp, qword[rsp + stackSize]); build.lea(rbp, addr[rsp + stackSize]);
unwind->setupFrameReg(rbp, stackSize); unwind->setupFrameReg(rbp, stackSize);
unwind->finish(); unwind->finish();
@ -366,7 +367,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate")
Label returnOffset = build.setLabel(); Label returnOffset = build.setLabel();
// Epilogue // Epilogue
build.lea(rsp, qword[rbp + localsSize]); build.lea(rsp, addr[rbp + localsSize]);
build.pop(rbp); build.pop(rbp);
build.pop(r15); build.pop(r15);
build.pop(r14); build.pop(r14);
@ -432,4 +433,43 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate")
#endif #endif
#if defined(__aarch64__)
TEST_CASE("GeneratedCodeExecutionA64")
{
AssemblyBuilderA64 build(/* logText= */ false);
Label skip;
build.cbz(x1, skip);
build.ldrsw(x1, x1);
build.cbnz(x1, skip);
build.mov(x1, 0); // doesn't execute due to cbnz above
build.setLabel(skip);
build.add(x1, x1, 1);
build.add(x0, x0, x1, /* LSL */ 1);
build.ret();
build.finalize();
size_t blockSize = 1024 * 1024;
size_t maxTotalSize = 1024 * 1024;
CodeAllocator allocator(blockSize, maxTotalSize);
uint8_t* nativeData;
size_t sizeNativeData;
uint8_t* nativeEntry;
REQUIRE(allocator.allocate(build.data.data(), build.data.size(), reinterpret_cast<uint8_t*>(build.code.data()), build.code.size() * 4, nativeData,
sizeNativeData, nativeEntry));
REQUIRE(nativeEntry);
using FunctionType = int64_t(int64_t, int*);
FunctionType* f = (FunctionType*)nativeEntry;
int input = 10;
int64_t result = f(20, &input);
CHECK(result == 42);
}
#endif
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -539,8 +539,14 @@ TEST_CASE("Debugger")
static bool singlestep = false; static bool singlestep = false;
static int stephits = 0; static int stephits = 0;
SUBCASE("") { singlestep = false; } SUBCASE("")
SUBCASE("SingleStep") { singlestep = true; } {
singlestep = false;
}
SUBCASE("SingleStep")
{
singlestep = true;
}
breakhits = 0; breakhits = 0;
interruptedthread = nullptr; interruptedthread = nullptr;

View file

@ -506,13 +506,14 @@ std::optional<TypeId> linearSearchForBinding(Scope* scope, const char* name)
return std::nullopt; return std::nullopt;
} }
void registerNotType(Fixture& fixture, TypeArena& arena) void registerHiddenTypes(Fixture& fixture, TypeArena& arena)
{ {
TypeId t = arena.addType(GenericTypeVar{"T"}); TypeId t = arena.addType(GenericTypeVar{"T"});
GenericTypeDefinition genericT{t}; GenericTypeDefinition genericT{t};
ScopePtr moduleScope = fixture.frontend.getGlobalScope(); ScopePtr moduleScope = fixture.frontend.getGlobalScope();
moduleScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, arena.addType(NegationTypeVar{t})}; moduleScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, arena.addType(NegationTypeVar{t})};
moduleScope->exportedTypeBindings["fun"] = TypeFun{{}, fixture.singletonTypes->functionType};
} }
void dump(const std::vector<Constraint>& constraints) void dump(const std::vector<Constraint>& constraints)

View file

@ -186,7 +186,7 @@ std::optional<TypeId> lookupName(ScopePtr scope, const std::string& name); // Wa
std::optional<TypeId> linearSearchForBinding(Scope* scope, const char* name); std::optional<TypeId> linearSearchForBinding(Scope* scope, const char* name);
void registerNotType(Fixture& fixture, TypeArena& arena); void registerHiddenTypes(Fixture& fixture, TypeArena& arena);
} // namespace Luau } // namespace Luau

View file

@ -3,6 +3,7 @@
#include "Fixture.h" #include "Fixture.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/TypeVar.h"
#include "doctest.h" #include "doctest.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
@ -19,7 +20,7 @@ struct IsSubtypeFixture : Fixture
return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice);
} }
}; };
} } // namespace
void createSomeClasses(Frontend& frontend) void createSomeClasses(Frontend& frontend)
{ {
@ -109,12 +110,10 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_and_any")
TypeId a = requireType("a"); TypeId a = requireType("a");
TypeId b = requireType("b"); TypeId b = requireType("b");
// Intuition: // any makes things work even when it makes no sense.
// We cannot use b where a is required because we cannot rely on b to return a string.
// We cannot use a where b is required because we cannot rely on a to accept non-number arguments.
CHECK(!isSubtype(b, a)); CHECK(isSubtype(b, a));
CHECK(!isSubtype(a, b)); CHECK(isSubtype(a, b));
} }
TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head") TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head")
@ -199,7 +198,7 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop")
TypeId b = requireType("b"); TypeId b = requireType("b");
CHECK(isSubtype(a, b)); CHECK(isSubtype(a, b));
CHECK(!isSubtype(b, a)); CHECK(isSubtype(b, a));
} }
TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection") TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection")
@ -261,7 +260,7 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "tables")
TypeId d = requireType("d"); TypeId d = requireType("d");
CHECK(isSubtype(a, b)); CHECK(isSubtype(a, b));
CHECK(!isSubtype(b, a)); CHECK(isSubtype(b, a));
CHECK(!isSubtype(c, a)); CHECK(!isSubtype(c, a));
CHECK(!isSubtype(a, c)); CHECK(!isSubtype(a, c));
@ -394,7 +393,8 @@ TEST_SUITE_END();
struct NormalizeFixture : Fixture struct NormalizeFixture : Fixture
{ {
ScopedFastFlag sff{"LuauNegatedStringSingletons", true}; ScopedFastFlag sff0{"LuauNegatedStringSingletons", true};
ScopedFastFlag sff1{"LuauNegatedFunctionTypes", true};
TypeArena arena; TypeArena arena;
InternalErrorReporter iceHandler; InternalErrorReporter iceHandler;
@ -403,16 +403,21 @@ struct NormalizeFixture : Fixture
NormalizeFixture() NormalizeFixture()
{ {
registerNotType(*this, arena); registerHiddenTypes(*this, arena);
} }
TypeId normal(const std::string& annotation) const NormalizedType* toNormalizedType(const std::string& annotation)
{ {
CheckResult result = check("type _Res = " + annotation); CheckResult result = check("type _Res = " + annotation);
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = lookupType("_Res"); std::optional<TypeId> ty = lookupType("_Res");
REQUIRE(ty); REQUIRE(ty);
const NormalizedType* norm = normalizer.normalize(*ty); return normalizer.normalize(*ty);
}
TypeId normal(const std::string& annotation)
{
const NormalizedType* norm = toNormalizedType(annotation);
REQUIRE(norm); REQUIRE(norm);
return normalizer.typeFromNormal(*norm); return normalizer.typeFromNormal(*norm);
} }
@ -476,6 +481,13 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_of_negations")
)"))); )")));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "disjoint_negations_normalize_to_string")
{
CHECK(R"(string)" == toString(normal(R"(
(string & Not<"hello"> & Not<"world">) | (string & Not<"goodbye">)
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean") TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean")
{ {
CHECK("true" == toString(normal(R"( CHECK("true" == toString(normal(R"(
@ -490,10 +502,43 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean_2")
)"))); )")));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "bare_negation") TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function")
{
CHECK("() -> ()" == toString(normal(R"(
fun & (() -> ())
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function_reverse")
{
CHECK("() -> ()" == toString(normal(R"(
(() -> ()) & fun
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "union_function_and_top_function")
{
CHECK("function" == toString(normal(R"(
fun | (() -> ())
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function")
{
CHECK("(boolean | number | string | thread)?" == toString(normal(R"(
Not<fun>
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "specific_functions_cannot_be_negated")
{
CHECK(nullptr == toNormalizedType("Not<(boolean) -> boolean>"));
}
TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean")
{ {
// TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function // TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function
CHECK("(number | string | thread)?" == toString(normal(R"( CHECK("(function | number | string | thread)?" == toString(normal(R"(
Not<boolean> Not<boolean>
)"))); )")));
} }

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h" #include "Luau/Parser.h"
#include "AstQueryDsl.h"
#include "Fixture.h" #include "Fixture.h"
#include "ScopedFlags.h" #include "ScopedFlags.h"
@ -2775,4 +2776,51 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the
CHECK(2 == f->generics.size); CHECK(2 == f->generics.size);
} }
TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_between_table_members")
{
ScopedFastFlag luauTableConstructorRecovery{"LuauTableConstructorRecovery", true};
ParseResult result = tryParse(R"(
local t = {
first = 1
second = 2,
third = 3,
fouth = 4,
}
)");
REQUIRE(1 == result.errors.size());
CHECK(Location({3, 12}, {3, 18}) == result.errors[0].getLocation());
CHECK("Expected ',' after table constructor element" == result.errors[0].getMessage());
REQUIRE(1 == result.root->body.size);
AstExprTable* table = Luau::query<AstExprTable>(result.root);
REQUIRE(table);
CHECK(table->items.size == 4);
}
TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_table_member")
{
ParseResult result = tryParse(R"(
local t = {
first = 1
local ok = true
local good = ok == true
)");
REQUIRE(1 == result.errors.size());
CHECK(Location({4, 8}, {4, 13}) == result.errors[0].getLocation());
CHECK("Expected '}' (to close '{' at line 2), got 'local'" == result.errors[0].getMessage());
REQUIRE(3 == result.root->body.size);
AstExprTable* table = Luau::query<AstExprTable>(result.root);
REQUIRE(table);
CHECK(table->items.size == 1);
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -10,7 +10,6 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction);
LUAU_FASTFLAG(LuauSpecialTypesAsterisked);
LUAU_FASTFLAG(LuauFixNameMaps); LUAU_FASTFLAG(LuauFixNameMaps);
LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup); LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup);
@ -270,16 +269,8 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded")
o.maxTypeLength = 40; o.maxTypeLength = 40;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()");
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
{ CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
}
else
{
CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... <TRUNCATED>");
CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... <TRUNCATED>");
}
} }
TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive")
@ -297,16 +288,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive")
o.maxTypeLength = 40; o.maxTypeLength = 40;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()");
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
{ CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
}
else
{
CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... <TRUNCATED>");
CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... <TRUNCATED>");
}
} }
TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces") TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces")
@ -535,10 +518,7 @@ local function target(callback: nil) return callback(4, "hello") end
)"); )");
LUAU_REQUIRE_ERRORS(result); LUAU_REQUIRE_ERRORS(result);
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("(nil) -> (*error-type*)", toString(requireType("target")));
CHECK_EQ("(nil) -> (*error-type*)", toString(requireType("target")));
else
CHECK_EQ("(nil) -> (<error-type>)", toString(requireType("target")));
} }
TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") TEST_CASE_FIXTURE(Fixture, "toStringGenericPack")

View file

@ -13,8 +13,6 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
TEST_SUITE_BEGIN("TypeInferAnyError"); TEST_SUITE_BEGIN("TypeInferAnyError");
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any")
@ -96,10 +94,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error")
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("a")));
CHECK_EQ("*error-type*", toString(requireType("a")));
else
CHECK_EQ("<error-type>", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2")
@ -115,10 +110,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2")
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("a")));
CHECK_EQ("*error-type*", toString(requireType("a")));
else
CHECK_EQ("<error-type>", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error")
@ -233,10 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error")
CHECK_EQ("unknown", err->name); CHECK_EQ("unknown", err->name);
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("a")));
CHECK_EQ("*error-type*", toString(requireType("a")));
else
CHECK_EQ("<error-type>", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error")
@ -245,10 +234,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error")
local a = Utility.Create "Foo" {} local a = Utility.Create "Foo" {}
)"); )");
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("a")));
CHECK_EQ("*error-type*", toString(requireType("a")));
else
CHECK_EQ("<error-type>", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any")

View file

@ -8,7 +8,6 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
TEST_SUITE_BEGIN("BuiltinTests"); TEST_SUITE_BEGIN("BuiltinTests");
@ -685,7 +684,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution && FFlag::LuauSpecialTypesAsterisked) if (FFlag::DebugLuauDeferredConstraintResolution)
{ {
CHECK_EQ("string", toString(requireType("foo"))); CHECK_EQ("string", toString(requireType("foo")));
CHECK_EQ("*error-type*", toString(requireType("bar"))); CHECK_EQ("*error-type*", toString(requireType("bar")));
@ -714,7 +713,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_strin
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution && FFlag::LuauSpecialTypesAsterisked) if (FFlag::DebugLuauDeferredConstraintResolution)
{ {
CHECK_EQ("string", toString(requireType("foo"))); CHECK_EQ("string", toString(requireType("foo")));
CHECK_EQ("string", toString(requireType("bar"))); CHECK_EQ("string", toString(requireType("bar")));
@ -1016,10 +1015,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic")
CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("number", toString(requireType("a")));
CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("string", toString(requireType("b")));
CHECK_EQ("boolean", toString(requireType("c"))); CHECK_EQ("boolean", toString(requireType("c")));
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("d")));
CHECK_EQ("*error-type*", toString(requireType("d")));
else
CHECK_EQ("<error-type>", toString(requireType("d")));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments")

View file

@ -15,7 +15,6 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauInstantiateInSubtyping);
LUAU_FASTFLAG(LuauSpecialTypesAsterisked);
TEST_SUITE_BEGIN("TypeInferFunctions"); TEST_SUITE_BEGIN("TypeInferFunctions");
@ -980,19 +979,13 @@ TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language")
REQUIRE(tm1); REQUIRE(tm1);
CHECK_EQ("(string) -> number", toString(tm1->wantedType)); CHECK_EQ("(string) -> number", toString(tm1->wantedType));
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("(string, *error-type*) -> number", toString(tm1->givenType));
CHECK_EQ("(string, *error-type*) -> number", toString(tm1->givenType));
else
CHECK_EQ("(string, <error-type>) -> number", toString(tm1->givenType));
auto tm2 = get<TypeMismatch>(result.errors[1]); auto tm2 = get<TypeMismatch>(result.errors[1]);
REQUIRE(tm2); REQUIRE(tm2);
CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType));
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("(string, *error-type*) -> number", toString(tm2->givenType));
CHECK_EQ("(string, *error-type*) -> number", toString(tm2->givenType));
else
CHECK_EQ("(string, <error-type>) -> number", toString(tm2->givenType));
} }
TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type")
@ -1538,20 +1531,10 @@ function t:b() return 2 end -- not OK
)"); )");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ(R"(Type '(*error-type*) -> number' could not be converted into '() -> number'
{
CHECK_EQ(R"(Type '(*error-type*) -> number' could not be converted into '() -> number'
caused by: caused by:
Argument count mismatch. Function expects 1 argument, but none are specified)", Argument count mismatch. Function expects 1 argument, but none are specified)",
toString(result.errors[0])); toString(result.errors[0]));
}
else
{
CHECK_EQ(R"(Type '(<error-type>) -> number' could not be converted into '() -> number'
caused by:
Argument count mismatch. Function expects 1 argument, but none are specified)",
toString(result.errors[0]));
}
} }
TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic")
@ -1779,7 +1762,72 @@ z = y -- Not OK, so the line is colorable
)"); )");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") -> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & ((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> false'; none of the intersection parts are compatible"); CHECK_EQ(toString(result.errors[0]),
"Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") "
"-> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & "
"((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> "
"(\"blue\" | \"red\") -> false'; none of the intersection parts are compatible");
}
TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions")
{
ScopedFastFlag sff{"LuauNegatedFunctionTypes", true};
registerHiddenTypes(*this, frontend.globalTypes);
CheckResult result = check(R"(
function foo(f: fun) end
function a() end
function id(x) return x end
foo(a)
foo(id)
foo(foo)
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function")
{
ScopedFastFlag sff{"LuauNegatedFunctionTypes", true};
registerHiddenTypes(*this, frontend.globalTypes);
CheckResult result = check(R"(
local a: fun = function() end
function one(arg: () -> ()) end
function two(arg: <T>(T) -> T) end
one(a)
two(a)
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK(6 == result.errors[0].location.begin.line);
CHECK(7 == result.errors[1].location.begin.line);
}
TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function")
{
ScopedFastFlag sff{"LuauNegatedFunctionTypes", true};
registerHiddenTypes(*this, frontend.globalTypes);
CheckResult result = check(R"(
local a: fun = function() end
local b: {} = a
local c: boolean = a
local d: fun = true
local e: fun = {}
)");
LUAU_REQUIRE_ERROR_COUNT(4, result);
CHECK(2 == result.errors[0].location.begin.line);
CHECK(3 == result.errors[1].location.begin.line);
CHECK(4 == result.errors[2].location.begin.line);
CHECK(5 == result.errors[3].location.begin.line);
} }
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -10,7 +10,6 @@
#include "doctest.h" #include "doctest.h"
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
using namespace Luau; using namespace Luau;
@ -1011,10 +1010,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying")
std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0"); std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0");
REQUIRE(t0); REQUIRE(t0);
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(t0->type));
CHECK_EQ("*error-type*", toString(t0->type));
else
CHECK_EQ("<error-type>", toString(t0->type));
auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) {
return get<OccursCheckFailed>(err); return get<OccursCheckFailed>(err);

View file

@ -13,7 +13,6 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
TEST_SUITE_BEGIN("TypeInferLoops"); TEST_SUITE_BEGIN("TypeInferLoops");
@ -157,10 +156,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error")
LUAU_REQUIRE_ERROR_COUNT(2, result); LUAU_REQUIRE_ERROR_COUNT(2, result);
TypeId p = requireType("p"); TypeId p = requireType("p");
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(p));
CHECK_EQ("*error-type*", toString(p));
else
CHECK_EQ("<error-type>", toString(p));
} }
TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function")

View file

@ -14,8 +14,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping)
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
TEST_SUITE_BEGIN("TypeInferModules"); TEST_SUITE_BEGIN("TypeInferModules");
TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_require_basic") TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_require_basic")
@ -176,10 +174,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export")
auto hootyType = requireType(bModule, "Hooty"); auto hootyType = requireType(bModule, "Hooty");
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(hootyType));
CHECK_EQ("*error-type*", toString(hootyType));
else
CHECK_EQ("<error-type>", toString(hootyType));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript") TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript")
@ -282,10 +277,7 @@ local ModuleA = require(game.A)
std::optional<TypeId> oty = requireType("ModuleA"); std::optional<TypeId> oty = requireType("ModuleA");
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(*oty));
CHECK_EQ("*error-type*", toString(*oty));
else
CHECK_EQ("<error-type>", toString(*oty));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types") TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types")

View file

@ -13,17 +13,17 @@ namespace
struct NegationFixture : Fixture struct NegationFixture : Fixture
{ {
TypeArena arena; TypeArena arena;
ScopedFastFlag sff[2] { ScopedFastFlag sff[2]{
{"LuauNegatedStringSingletons", true}, {"LuauNegatedStringSingletons", true},
{"LuauSubtypeNormalizer", true}, {"LuauSubtypeNormalizer", true},
}; };
NegationFixture() NegationFixture()
{ {
registerNotType(*this, arena); registerHiddenTypes(*this, arena);
} }
}; };
} } // namespace
TEST_SUITE_BEGIN("Negations"); TEST_SUITE_BEGIN("Negations");

View file

@ -11,8 +11,6 @@
#include "doctest.h" #include "doctest.h"
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
using namespace Luau; using namespace Luau;
TEST_SUITE_BEGIN("TypeInferPrimitives"); TEST_SUITE_BEGIN("TypeInferPrimitives");
@ -49,10 +47,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index")
REQUIRE(nat); REQUIRE(nat);
CHECK_EQ("string", toString(nat->ty)); CHECK_EQ("string", toString(nat->ty));
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("t")));
CHECK_EQ("*error-type*", toString(requireType("t")));
else
CHECK_EQ("<error-type>", toString(requireType("t")));
} }
TEST_CASE_FIXTURE(Fixture, "string_method") TEST_CASE_FIXTURE(Fixture, "string_method")

View file

@ -633,7 +633,7 @@ struct IsSubtypeFixture : Fixture
return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice);
} }
}; };
} } // namespace
TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection_of_functions_of_different_arities") TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection_of_functions_of_different_arities")
{ {

View file

@ -7,7 +7,6 @@
#include "doctest.h" #include "doctest.h"
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
using namespace Luau; using namespace Luau;
@ -127,7 +126,7 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint")
if (FFlag::DebugLuauDeferredConstraintResolution) if (FFlag::DebugLuauDeferredConstraintResolution)
{ {
CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("(string?) & ~~~(false?)", toString(requireTypeAtPosition({5, 26}))); CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26})));
} }
else else
{ {
@ -153,7 +152,7 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through")
if (FFlag::DebugLuauDeferredConstraintResolution) if (FFlag::DebugLuauDeferredConstraintResolution)
{ {
CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("(string?) & ~~~(false?)", toString(requireTypeAtPosition({5, 26}))); CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26})));
} }
else else
{ {
@ -178,8 +177,16 @@ TEST_CASE_FIXTURE(Fixture, "and_constraint")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("number", toString(requireTypeAtPosition({4, 26}))); {
CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({4, 26})));
}
else
{
CHECK_EQ("string", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("number", toString(requireTypeAtPosition({4, 26})));
}
CHECK_EQ("string?", toString(requireTypeAtPosition({6, 26}))); CHECK_EQ("string?", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("number?", toString(requireTypeAtPosition({7, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({7, 26})));
@ -204,8 +211,16 @@ TEST_CASE_FIXTURE(Fixture, "not_and_constraint")
CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26})));
CHECK_EQ("string", toString(requireTypeAtPosition({6, 26}))); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("number", toString(requireTypeAtPosition({7, 26}))); {
CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({7, 26})));
}
else
{
CHECK_EQ("string", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("number", toString(requireTypeAtPosition({7, 26})));
}
} }
TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates")
@ -227,8 +242,56 @@ TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates")
CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26})));
CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); {
CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({7, 26})));
}
else
{
CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26})));
}
}
TEST_CASE_FIXTURE(Fixture, "a_and_b_or_a_and_c")
{
CheckResult result = check(R"(
function f(a: string?, b: number?, c: boolean)
if (a and b) or (a and c) then
local foo = a
local bar = b
local baz = c
else
local foo = a
local bar = b
local baz = c
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("(string?) & (~(false?) | ~(false?))", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28})));
CHECK_EQ("boolean", toString(requireTypeAtPosition({5, 28})));
CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28})));
CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28})));
CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28})));
}
else
{
CHECK_EQ("string", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28})));
CHECK_EQ("true", toString(requireTypeAtPosition({5, 28}))); // oh no! :(
CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28})));
CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28})));
CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28})));
}
} }
TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints")
@ -244,8 +307,17 @@ TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("string", toString(requireTypeAtPosition({4, 26}))); {
CHECK_EQ("number?", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("string?", toString(requireTypeAtPosition({4, 26})));
}
else
{
// We're going to drop support for type refinements through type assertions.
CHECK_EQ("number", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("string", toString(requireTypeAtPosition({4, 26})));
}
} }
TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position")
@ -381,11 +453,22 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b {
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "((number | string)?) & (boolean?)"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "((number | string)?) & (boolean?)"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "((number | string)?) & unknown"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "(boolean?) & unknown"); // a ~= b
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b
}
} }
TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term")
@ -402,8 +485,16 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 {
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & number"); // a == 1
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & unknown"); // a ~= 1
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1;
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1
}
} }
TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue")
@ -420,8 +511,16 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello" if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" {
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello" & ((number | string)?))"); // a == "hello"
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"(((number | string)?) & ~"hello")"); // a ~= "hello"
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello"
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"((number | string)?)"); // a ~= "hello"
}
} }
TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil")
@ -438,8 +537,16 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil {
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & ~nil"); // a ~= nil
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & nil"); // a == nil
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil
}
} }
TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue")
@ -454,8 +561,17 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b {
ToStringOptions opts;
CHECK_EQ(toString(requireTypeAtPosition({3, 33}), opts), "(string?) & a"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36}), opts), "(string?) & a"); // a == b
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b
}
} }
TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal")
@ -470,8 +586,16 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b {
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any & unknown"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "({| x: number |}?) & unknown"); // a ~= b
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b
}
} }
TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil")
@ -490,11 +614,22 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b {
CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string & unknown"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "(string?) & unknown"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "(string?) & string"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "(string?) & string"); // a == b
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b
}
} }
TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable")
@ -525,10 +660,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28})));
else
CHECK_EQ("<error-type>", toString(requireTypeAtPosition({3, 28})));
} }
TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true")
@ -715,8 +847,16 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); {
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28})));
}
else
{
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28})));
}
} }
TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2")
@ -732,8 +872,16 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); {
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28})));
}
else
{
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28})));
}
} }
TEST_CASE_FIXTURE(Fixture, "either_number_or_string") TEST_CASE_FIXTURE(Fixture, "either_number_or_string")

View file

@ -79,6 +79,16 @@ TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype_multi_assignment")
{
CheckResult result = check(R"(
local a: "foo" = "foo"
local b: string, c: number = a, 10
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(

View file

@ -17,7 +17,6 @@ using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
TEST_SUITE_BEGIN("TableTests"); TEST_SUITE_BEGIN("TableTests");
@ -3272,11 +3271,12 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
// TODO: test behavior is wrong until we can re-enable the covariant requirement for instantiation in subtyping // TODO: test behavior is wrong until we can re-enable the covariant requirement for instantiation in subtyping
// LUAU_REQUIRE_ERRORS(result); // LUAU_REQUIRE_ERRORS(result);
// CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' // CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}'
// caused by: // caused by:
// Property 'm' is not compatible. Type '<a>(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); // Property 'm' is not compatible. Type '<a>(a) -> a' could not be converted into '(number) -> number'; different number of generic type
// // this error message is not great since the underlying issue is that the context is invariant, // parameters)");
// // this error message is not great since the underlying issue is that the context is invariant,
// and `(number) -> number` cannot be a subtype of `<a>(a) -> a`. // and `(number) -> number` cannot be a subtype of `<a>(a) -> a`.
} }
@ -3335,10 +3335,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated")
ToStringOptions opts; ToStringOptions opts;
opts.exhaustive = true; opts.exhaustive = true;
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts));
CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts));
else
CHECK_EQ("{ x: <error-type>, y: number }", toString(requireType("t"), opts));
} }
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -17,7 +17,6 @@
LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauInstantiateInSubtyping);
LUAU_FASTFLAG(LuauSpecialTypesAsterisked);
using namespace Luau; using namespace Luau;
@ -238,20 +237,10 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types")
// TODO: Should we assert anything about these tests when DCR is being used? // TODO: Should we assert anything about these tests when DCR is being used?
if (!FFlag::DebugLuauDeferredConstraintResolution) if (!FFlag::DebugLuauDeferredConstraintResolution)
{ {
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("c")));
{ CHECK_EQ("*error-type*", toString(requireType("d")));
CHECK_EQ("*error-type*", toString(requireType("c"))); CHECK_EQ("*error-type*", toString(requireType("e")));
CHECK_EQ("*error-type*", toString(requireType("d"))); CHECK_EQ("*error-type*", toString(requireType("f")));
CHECK_EQ("*error-type*", toString(requireType("e")));
CHECK_EQ("*error-type*", toString(requireType("f")));
}
else
{
CHECK_EQ("<error-type>", toString(requireType("c")));
CHECK_EQ("<error-type>", toString(requireType("d")));
CHECK_EQ("<error-type>", toString(requireType("e")));
CHECK_EQ("<error-type>", toString(requireType("f")));
}
} }
} }
@ -662,10 +651,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional")
std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0"); std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0");
REQUIRE(t0); REQUIRE(t0);
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(t0->type));
CHECK_EQ("*error-type*", toString(t0->type));
else
CHECK_EQ("<error-type>", toString(t0->type));
auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) {
return get<OccursCheckFailed>(err); return get<OccursCheckFailed>(err);

View file

@ -9,8 +9,6 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
struct TryUnifyFixture : Fixture struct TryUnifyFixture : Fixture
{ {
TypeArena arena; TypeArena arena;
@ -124,10 +122,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("a", toString(requireType("a"))); CHECK_EQ("a", toString(requireType("a")));
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("b")));
CHECK_EQ("*error-type*", toString(requireType("b")));
else
CHECK_EQ("<error-type>", toString(requireType("b")));
} }
TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained")
@ -142,10 +137,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_con
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("a", toString(requireType("a"))); CHECK_EQ("a", toString(requireType("a")));
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("b")));
CHECK_EQ("*error-type*", toString(requireType("b")));
else
CHECK_EQ("<error-type>", toString(requireType("b")));
CHECK_EQ("number", toString(requireType("c"))); CHECK_EQ("number", toString(requireType("c")));
} }

View file

@ -6,8 +6,6 @@
#include "doctest.h" #include "doctest.h"
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
using namespace Luau; using namespace Luau;
TEST_SUITE_BEGIN("UnionTypes"); TEST_SUITE_BEGIN("UnionTypes");
@ -199,10 +197,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property")
CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->missing[0], *bTy);
CHECK_EQ(mup->key, "x"); CHECK_EQ(mup->key, "x");
if (FFlag::LuauSpecialTypesAsterisked) CHECK_EQ("*error-type*", toString(requireType("r")));
CHECK_EQ("*error-type*", toString(requireType("r")));
else
CHECK_EQ("<error-type>", toString(requireType("r")));
} }
TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any")

View file

@ -217,4 +217,35 @@ TEST_CASE("Visit")
CHECK(r3 == "1231147"); CHECK(r3 == "1231147");
} }
struct MoveOnly
{
MoveOnly() = default;
MoveOnly(const MoveOnly&) = delete;
MoveOnly& operator=(const MoveOnly&) = delete;
MoveOnly(MoveOnly&&) = default;
MoveOnly& operator=(MoveOnly&&) = default;
};
TEST_CASE("Move")
{
Variant<MoveOnly> v1 = MoveOnly{};
Variant<MoveOnly> v2 = std::move(v1);
}
TEST_CASE("MoveWithCopyableAlternative")
{
Variant<std::string, MoveOnly> v1 = std::string{"Hello, world! I am longer than a normal hello world string to avoid SSO."};
Variant<std::string, MoveOnly> v2 = std::move(v1);
std::string* s1 = get_if<std::string>(&v1);
REQUIRE(s1);
CHECK(*s1 == "");
std::string* s2 = get_if<std::string>(&v2);
REQUIRE(s2);
CHECK(*s2 == "Hello, world! I am longer than a normal hello world string to avoid SSO.");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -10,7 +10,9 @@ AnnotationTests.too_many_type_params
AnnotationTests.two_type_params AnnotationTests.two_type_params
AnnotationTests.unknown_type_reference_generates_error AnnotationTests.unknown_type_reference_generates_error
AstQuery.last_argument_function_call_type AstQuery.last_argument_function_call_type
AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method
AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AstQuery::getDocumentationSymbolAtPosition.overloaded_fn
AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop
AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_first_function_arg_expected_type
AutocompleteTest.autocomplete_interpolated_string AutocompleteTest.autocomplete_interpolated_string
AutocompleteTest.autocomplete_oop_implicit_self AutocompleteTest.autocomplete_oop_implicit_self
@ -106,17 +108,14 @@ GenericsTests.correctly_instantiate_polymorphic_member_functions
GenericsTests.do_not_infer_generic_functions GenericsTests.do_not_infer_generic_functions
GenericsTests.duplicate_generic_type_packs GenericsTests.duplicate_generic_type_packs
GenericsTests.duplicate_generic_types GenericsTests.duplicate_generic_types
GenericsTests.factories_of_generics
GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_few
GenericsTests.generic_argument_count_too_many GenericsTests.generic_argument_count_too_many
GenericsTests.generic_factories GenericsTests.generic_factories
GenericsTests.generic_functions_in_types
GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_functions_should_be_memory_safe
GenericsTests.generic_table_method GenericsTests.generic_table_method
GenericsTests.generic_type_pack_parentheses GenericsTests.generic_type_pack_parentheses
GenericsTests.generic_type_pack_unification1 GenericsTests.generic_type_pack_unification1
GenericsTests.generic_type_pack_unification2 GenericsTests.generic_type_pack_unification2
GenericsTests.generic_type_pack_unification3
GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments
GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument
GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_function_function_argument_overloaded
@ -172,7 +171,6 @@ ProvisionalTests.table_insert_with_a_singleton_argument
ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.typeguard_inference_incomplete
ProvisionalTests.weirditer_should_not_loop_forever ProvisionalTests.weirditer_should_not_loop_forever
ProvisionalTests.while_body_are_also_refined ProvisionalTests.while_body_are_also_refined
RefinementTest.and_constraint
RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string
RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number
RefinementTest.assert_non_binary_expressions_actually_resolve_constraints RefinementTest.assert_non_binary_expressions_actually_resolve_constraints
@ -187,28 +185,17 @@ RefinementTest.either_number_or_string
RefinementTest.eliminate_subclasses_of_instance RefinementTest.eliminate_subclasses_of_instance
RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil
RefinementTest.index_on_a_refined_property RefinementTest.index_on_a_refined_property
RefinementTest.invert_is_truthy_constraint
RefinementTest.invert_is_truthy_constraint_ifelse_expression RefinementTest.invert_is_truthy_constraint_ifelse_expression
RefinementTest.is_truthy_constraint
RefinementTest.is_truthy_constraint_ifelse_expression RefinementTest.is_truthy_constraint_ifelse_expression
RefinementTest.lvalue_is_not_nil
RefinementTest.merge_should_be_fully_agnostic_of_hashmap_ordering RefinementTest.merge_should_be_fully_agnostic_of_hashmap_ordering
RefinementTest.narrow_boolean_to_true_or_false
RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.narrow_property_of_a_bounded_variable
RefinementTest.narrow_this_large_union RefinementTest.narrow_this_large_union
RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true
RefinementTest.not_a_and_not_b
RefinementTest.not_a_and_not_b2
RefinementTest.not_and_constraint
RefinementTest.not_t_or_some_prop_of_t RefinementTest.not_t_or_some_prop_of_t
RefinementTest.or_predicate_with_truthy_predicates
RefinementTest.parenthesized_expressions_are_followed_through
RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table
RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string
RefinementTest.refine_unknowns RefinementTest.refine_unknowns
RefinementTest.term_is_equal_to_an_lvalue
RefinementTest.truthy_constraint_on_properties RefinementTest.truthy_constraint_on_properties
RefinementTest.type_assertion_expr_carry_its_constraints
RefinementTest.type_comparison_ifelse_expression RefinementTest.type_comparison_ifelse_expression
RefinementTest.type_guard_can_filter_for_intersection_of_tables RefinementTest.type_guard_can_filter_for_intersection_of_tables
RefinementTest.type_guard_can_filter_for_overloaded_function RefinementTest.type_guard_can_filter_for_overloaded_function
@ -271,6 +258,7 @@ TableTests.infer_indexer_from_value_property_in_literal
TableTests.inferred_return_type_of_free_table TableTests.inferred_return_type_of_free_table
TableTests.inferring_crazy_table_should_also_be_quick TableTests.inferring_crazy_table_should_also_be_quick
TableTests.instantiate_table_cloning_3 TableTests.instantiate_table_cloning_3
TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound
TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound
TableTests.leaking_bad_metatable_errors TableTests.leaking_bad_metatable_errors
TableTests.less_exponential_blowup_please TableTests.less_exponential_blowup_please
@ -315,7 +303,6 @@ TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors
TableTests.tables_get_names_from_their_locals TableTests.tables_get_names_from_their_locals
TableTests.tc_member_function TableTests.tc_member_function
TableTests.tc_member_function_2 TableTests.tc_member_function_2
TableTests.top_table_type
TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.type_mismatch_on_massive_table_is_cut_short
TableTests.unification_of_unions_in_a_self_referential_type TableTests.unification_of_unions_in_a_self_referential_type
TableTests.unifying_tables_shouldnt_uaf2 TableTests.unifying_tables_shouldnt_uaf2
@ -417,12 +404,12 @@ TypeInferFunctions.too_few_arguments_variadic
TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic
TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_few_arguments_variadic_generic2
TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_arguments
TypeInferFunctions.too_many_arguments_error_location
TypeInferFunctions.too_many_return_values TypeInferFunctions.too_many_return_values
TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_in_parentheses
TypeInferFunctions.too_many_return_values_no_function TypeInferFunctions.too_many_return_values_no_function
TypeInferFunctions.vararg_function_is_quantified TypeInferFunctions.vararg_function_is_quantified
TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values
TypeInferLoops.for_in_loop_with_custom_iterator
TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_loop_with_next
TypeInferLoops.for_in_with_generic_next TypeInferLoops.for_in_with_generic_next
TypeInferLoops.for_in_with_just_one_iterator_is_ok TypeInferLoops.for_in_with_just_one_iterator_is_ok
@ -430,7 +417,6 @@ TypeInferLoops.loop_iter_no_indexer_nonstrict
TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.loop_iter_trailing_nil
TypeInferLoops.unreachable_code_after_infinite_loop TypeInferLoops.unreachable_code_after_infinite_loop
TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free
TypeInferModules.bound_free_table_export_is_ok
TypeInferModules.custom_require_global TypeInferModules.custom_require_global
TypeInferModules.do_not_modify_imported_types TypeInferModules.do_not_modify_imported_types
TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict
@ -443,6 +429,7 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2
TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon
TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory
TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.methods_are_topologically_sorted
TypeInferOOP.object_constructor_can_refer_to_method_of_self
TypeInferOperators.and_or_ternary TypeInferOperators.and_or_ternary
TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.CallAndOrOfFunctions
TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable

View file

@ -32,6 +32,9 @@ source = """// This file is part of the Luau programming language and is license
""" """
function = "" function = ""
signature = ""
includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_COVERAGE", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS"]
state = 0 state = 0
@ -44,7 +47,6 @@ for line in input:
if match: if match:
inst = match[1] inst = match[1]
signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)" signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)"
header += signature + ";\n"
function = signature + "\n" function = signature + "\n"
function += "{\n" function += "{\n"
function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n" function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n"
@ -84,7 +86,10 @@ for line in input:
function = function[:-len(finalline)] function = function[:-len(finalline)]
function += " return pc;\n}\n" function += " return pc;\n}\n"
source += function + "\n" if inst in includeInsts:
header += signature + ";\n"
source += function + "\n"
state = 0 state = 0
# skip LUA_CUSTOM_EXECUTION code blocks # skip LUA_CUSTOM_EXECUTION code blocks

View file

@ -0,0 +1,173 @@
#!/usr/bin/python
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
# The purpose of this script is to analyze disassembly generated by objdump or
# dumpbin to print (or to compare) the stack usage of functions/methods.
# This is a quickly written script, so it is quite possible it may not handle
# all code properly.
#
# The script expects the user to create a text assembly dump to be passed to
# the script.
#
# objdump Example
# objdump --demangle --disassemble objfile.o > objfile.s
#
# dumpbin Example
# dumpbin /disasm objfile.obj > objfile.s
#
# If the script is passed a single file, then all stack size information that
# is found it printed. If two files are passed, then the script compares the
# stack usage of the two files (useful for A/B comparisons).
# Currently more than two input files are not supported. (But adding support shouldn't
# be very difficult.)
#
# Note: The script only handles x64 disassembly. Supporting x86 is likely
# trivial, but ARM support could be difficult.
# Thus far the script has been tested with MSVC on Win64 and clang on OSX.
import argparse
import re
blank_re = re.compile('\s*')
class LineReader:
def __init__(self, lines):
self.lines = list(reversed(lines))
def get_line(self):
return self.lines.pop(-1)
def peek_line(self):
return self.lines[-1]
def consume_blank_lines(self):
while blank_re.fullmatch(self.peek_line()):
self.get_line()
def is_empty(self):
return len(self.lines) == 0
def parse_objdump_assembly(in_file):
results = {}
text_section_re = re.compile('Disassembly of section __TEXT,__text:\s*')
symbol_re = re.compile('[^<]*<(.*)>:\s*')
stack_alloc = re.compile('.*subq\s*\$(\d*), %rsp\s*')
lr = LineReader(in_file.readlines())
def find_stack_alloc_size():
while True:
if lr.is_empty():
return None
if blank_re.fullmatch(lr.peek_line()):
return None
line = lr.get_line()
mo = stack_alloc.fullmatch(line)
if mo:
lr.consume_blank_lines()
return int(mo.group(1))
# Find beginning of disassembly
while not text_section_re.fullmatch(lr.get_line()):
pass
# Scan for symbols
while not lr.is_empty():
lr.consume_blank_lines()
if lr.is_empty():
break
line = lr.get_line()
mo = symbol_re.fullmatch(line)
# Found a symbol
if mo:
symbol = mo.group(1)
stack_size = find_stack_alloc_size()
if stack_size != None:
results[symbol] = stack_size
return results
def parse_dumpbin_assembly(in_file):
results = {}
file_type_re = re.compile('File Type: COFF OBJECT\s*')
symbol_re = re.compile('[^(]*\((.*)\):\s*')
summary_re = re.compile('\s*Summary\s*')
stack_alloc = re.compile('.*sub\s*rsp,([A-Z0-9]*)h\s*')
lr = LineReader(in_file.readlines())
def find_stack_alloc_size():
while True:
if lr.is_empty():
return None
if blank_re.fullmatch(lr.peek_line()):
return None
line = lr.get_line()
mo = stack_alloc.fullmatch(line)
if mo:
lr.consume_blank_lines()
return int(mo.group(1), 16) # return value in decimal
# Find beginning of disassembly
while not file_type_re.fullmatch(lr.get_line()):
pass
# Scan for symbols
while not lr.is_empty():
lr.consume_blank_lines()
if lr.is_empty():
break
line = lr.get_line()
if summary_re.fullmatch(line):
break
mo = symbol_re.fullmatch(line)
# Found a symbol
if mo:
symbol = mo.group(1)
stack_size = find_stack_alloc_size()
if stack_size != None:
results[symbol] = stack_size
return results
def main():
parser = argparse.ArgumentParser(description='Tool used for reporting or comparing the stack usage of functions/methods')
parser.add_argument('--format', choices=['dumpbin', 'objdump'], required=True, help='Specifies the program used to generate the input files')
parser.add_argument('--input', action='append', required=True, help='Input assembly file. This option may be specified multiple times.')
parser.add_argument('--md-output', action='store_true', help='Show table output in markdown format')
parser.add_argument('--only-diffs', action='store_true', help='Only show stack info when it differs between the input files')
args = parser.parse_args()
parsers = {'dumpbin': parse_dumpbin_assembly, 'objdump' : parse_objdump_assembly}
parse_func = parsers[args.format]
input_results = []
for input_name in args.input:
with open(input_name) as in_file:
results = parse_func(in_file)
input_results.append(results)
if len(input_results) == 1:
# Print out the results sorted by size
size_sorted = sorted([(size, symbol) for symbol, size in results.items()], reverse=True)
print(input_name)
for size, symbol in size_sorted:
print(f'{size:10}\t{symbol}')
print()
elif len(input_results) == 2:
common_symbols = set(input_results[0].keys()).intersection(set(input_results[1].keys()))
print(f'Found {len(common_symbols)} common symbols')
stack_sizes = sorted([(input_results[0][sym], input_results[1][sym], sym) for sym in common_symbols], reverse=True)
if args.md_output:
print('Before | After | Symbol')
print('-- | -- | --')
for size0, size1, symbol in stack_sizes:
if args.only_diffs and size0 == size1:
continue
if args.md_output:
print(f'{size0} | {size1} | {symbol}')
else:
print(f'{size0:10}\t{size1:10}\t{symbol}')
else:
print("TODO support more than 2 inputs")
if __name__ == '__main__':
main()