Sync to upstream/release/550 (#723)

* Support `["prop"]` syntax on class definitions in definition files.
(#704)
* Improve type checking performance for complex overloaded functions
* Fix rare cases of incorrect stack traces for out of memory errors at
runtime
This commit is contained in:
Andy Friesen 2022-10-21 10:54:01 -07:00 committed by GitHub
parent 12ee1407a1
commit 54324867df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
104 changed files with 4210 additions and 2266 deletions

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include <Luau/NotNull.h>
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
@ -26,5 +27,6 @@ TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false); TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false);
TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest);
} // namespace Luau } // namespace Luau

View file

@ -2,9 +2,10 @@
#pragma once #pragma once
#include "Luau/Ast.h" // Used for some of the enumerations #include "Luau/Ast.h" // Used for some of the enumerations
#include "Luau/Def.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Variant.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <string> #include <string>
#include <memory> #include <memory>
@ -131,9 +132,15 @@ struct HasPropConstraint
std::string prop; std::string prop;
}; };
using ConstraintV = struct RefinementConstraint
Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint, BinaryConstraint, {
IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint, HasPropConstraint>; DefId def;
TypeId discriminantType;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint,
HasPropConstraint, RefinementConstraint>;
struct Constraint struct Constraint
{ {
@ -143,7 +150,7 @@ struct Constraint
Constraint& operator=(const Constraint&) = delete; Constraint& operator=(const Constraint&) = delete;
NotNull<Scope> scope; NotNull<Scope> scope;
Location location; Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations.
ConstraintV c; ConstraintV c;
std::vector<NotNull<Constraint>> dependencies; std::vector<NotNull<Constraint>> dependencies;

View file

@ -1,13 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include <memory>
#include <vector>
#include <unordered_map>
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
@ -15,6 +11,10 @@
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include <memory>
#include <vector>
#include <unordered_map>
namespace Luau namespace Luau
{ {
@ -48,6 +48,7 @@ struct ConstraintGraphBuilder
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr}; DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
// Defining scopes for AST nodes. // Defining scopes for AST nodes.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr}; DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
NotNull<const DataFlowGraph> dfg;
int recursionCount = 0; int recursionCount = 0;
@ -63,7 +64,8 @@ struct ConstraintGraphBuilder
DcrLogger* logger; DcrLogger* logger;
ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull<ModuleResolver> moduleResolver, ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull<ModuleResolver> moduleResolver,
NotNull<SingletonTypes> singletonTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, DcrLogger* logger); NotNull<SingletonTypes> singletonTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, DcrLogger* logger,
NotNull<DataFlowGraph> dfg);
/** /**
* Fabricates a new free type belonging to a given scope. * Fabricates a new free type belonging to a given scope.
@ -88,15 +90,17 @@ struct ConstraintGraphBuilder
* Adds a new constraint with no dependencies to a given scope. * Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to. * @param scope the scope to add the constraint to.
* @param cv the constraint variant to add. * @param cv the constraint variant to add.
* @return the pointer to the inserted constraint
*/ */
void addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); NotNull<Constraint> addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv);
/** /**
* Adds a constraint to a given scope. * Adds a constraint to a given scope.
* @param scope the scope to add the constraint to. Must not be null. * @param scope the scope to add the constraint to. Must not be null.
* @param c the constraint to add. * @param c the constraint to add.
* @return the pointer to the inserted constraint
*/ */
void addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c); NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
/** /**
* The entry point to the ConstraintGraphBuilder. This will construct a set * The entry point to the ConstraintGraphBuilder. This will construct a set
@ -139,13 +143,20 @@ struct ConstraintGraphBuilder
*/ */
TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}); TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {});
TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType); TypeId check(const ScopePtr& scope, AstExprLocal* local);
TypeId check(const ScopePtr& scope, AstExprGlobal* global);
TypeId check(const ScopePtr& scope, AstExprIndexName* indexName); TypeId check(const ScopePtr& scope, AstExprIndexName* indexName);
TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
TypeId check(const ScopePtr& scope, AstExprUnary* unary); TypeId check(const ScopePtr& scope, AstExprUnary* unary);
TypeId check(const ScopePtr& scope, AstExprBinary* binary); TypeId check_(const ScopePtr& scope, AstExprUnary* unary);
TypeId check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType); TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
TypeId checkLValue(const ScopePtr& scope, AstExpr* expr);
struct FunctionSignature struct FunctionSignature
{ {

View file

@ -110,6 +110,7 @@ struct ConstraintSolver
bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const RefinementConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do // for a, ... in some_table do
// also handles __iter metamethod // also handles __iter metamethod
@ -215,6 +216,8 @@ private:
TypeId errorRecoveryType() const; TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const; TypePackId errorRecoveryTypePack() const;
TypeId unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes);
ToStringOptions opts; ToStringOptions opts;
}; };

View file

@ -0,0 +1,115 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
// Do not include LValue. It should never be used here.
#include "Luau/Ast.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/Symbol.h"
#include <unordered_map>
namespace Luau
{
struct DataFlowGraph
{
DataFlowGraph(DataFlowGraph&&) = default;
DataFlowGraph& operator=(DataFlowGraph&&) = default;
// TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt.
// We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId.
std::optional<DefId> getDef(const AstExpr* expr) const;
std::optional<DefId> getDef(const AstLocal* local) const;
/// Retrieve the Def that corresponds to the given Symbol.
///
/// We do not perform dataflow analysis on globals, so this function always
/// yields nullopt when passed a global Symbol.
std::optional<DefId> getDef(const Symbol& symbol) const;
private:
DataFlowGraph() = default;
DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena arena;
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
DenseHashMap<const AstLocal*, const Def*> localDefs{nullptr};
friend struct DataFlowGraphBuilder;
};
struct DfgScope
{
DfgScope* parent;
DenseHashMap<Symbol, const Def*> bindings{Symbol{}};
};
struct ExpressionFlowGraph
{
std::optional<DefId> def;
};
// Currently unsound. We do not presently track the control flow of the program.
// Additionally, we do not presently track assignments.
struct DataFlowGraphBuilder
{
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
private:
DataFlowGraphBuilder() = default;
DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete;
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph;
NotNull<DefArena> arena{&graph.arena};
struct InternalErrorReporter* handle;
std::vector<std::unique_ptr<DfgScope>> scopes;
DfgScope* childScope(DfgScope* scope);
std::optional<DefId> use(DfgScope* scope, Symbol symbol, AstExpr* e);
void visit(DfgScope* scope, AstStatBlock* b);
void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b);
// TODO: visit type aliases
void visit(DfgScope* scope, AstStat* s);
void visit(DfgScope* scope, AstStatIf* i);
void visit(DfgScope* scope, AstStatWhile* w);
void visit(DfgScope* scope, AstStatRepeat* r);
void visit(DfgScope* scope, AstStatBreak* b);
void visit(DfgScope* scope, AstStatContinue* c);
void visit(DfgScope* scope, AstStatReturn* r);
void visit(DfgScope* scope, AstStatExpr* e);
void visit(DfgScope* scope, AstStatLocal* l);
void visit(DfgScope* scope, AstStatFor* f);
void visit(DfgScope* scope, AstStatForIn* f);
void visit(DfgScope* scope, AstStatAssign* a);
void visit(DfgScope* scope, AstStatCompoundAssign* c);
void visit(DfgScope* scope, AstStatFunction* f);
void visit(DfgScope* scope, AstStatLocalFunction* l);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i);
// TODO: visitLValue
// TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope)
// TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope)
};
} // namespace Luau

View file

@ -0,0 +1,78 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/NotNull.h"
#include "Luau/TypedAllocator.h"
#include "Luau/Variant.h"
namespace Luau
{
using Def = Variant<struct Undefined, struct Phi>;
/**
* We statically approximate a value at runtime using a symbolic value, which we call a Def.
*
* DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that
* can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program.
*
* It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate.
*/
using DefId = NotNull<const Def>;
/**
* A "single-object" value.
*
* Leaky implementation note: sometimes "multiple-object" values, but none of which were interesting enough to warrant creating a phi node instead.
* That can happen because there's no point in creating a phi node that points to either resultant in `if math.random() > 0.5 then 5 else "hello"`.
* This might become of utmost importance if we wanted to do some backward reasoning, e.g. if `5` is taken, then `cond` must be `truthy`.
*/
struct Undefined
{
};
/**
* A phi node is a union of defs.
*
* We need this because we're statically evaluating a program, and sometimes a place may be assigned with
* different defs, and when that happens, we need a special data type that merges in all the defs
* that will flow into that specific place. For example, consider this simple program:
*
* ```
* x-1
* if cond() then
* x-2 = 5
* else
* x-3 = "hello"
* end
* x-4 : {x-2, x-3}
* ```
*
* At x-4, we know for a fact statically that either `5` or `"hello"` can flow into the variable `x` after the branch, but
* we cannot make any definitive decisions about which one, so we just take in both.
*/
struct Phi
{
std::vector<DefId> operands;
};
template<typename T>
T* getMutable(DefId def)
{
return get_if<T>(def.get());
}
template<typename T>
const T* get(DefId def)
{
return getMutable<T>(def);
}
struct DefArena
{
TypedAllocator<Def> allocator;
DefId freshDef();
};
} // namespace Luau

View file

@ -14,6 +14,8 @@ struct TypeVar;
using TypeId = const TypeVar*; using TypeId = const TypeVar*;
struct Field; struct Field;
// Deprecated. Do not use in new work.
using LValue = Variant<Symbol, Field>; using LValue = Variant<Symbol, Field>;
struct Field struct Field

View file

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include <unordered_map>
namespace Luau
{
static const std::unordered_map<AstExprBinary::Op, const char*> kBinaryOpMetamethods{
{AstExprBinary::Op::CompareEq, "__eq"},
{AstExprBinary::Op::CompareNe, "__eq"},
{AstExprBinary::Op::CompareGe, "__lt"},
{AstExprBinary::Op::CompareGt, "__le"},
{AstExprBinary::Op::CompareLe, "__le"},
{AstExprBinary::Op::CompareLt, "__lt"},
{AstExprBinary::Op::Add, "__add"},
{AstExprBinary::Op::Sub, "__sub"},
{AstExprBinary::Op::Mul, "__mul"},
{AstExprBinary::Op::Div, "__div"},
{AstExprBinary::Op::Pow, "__pow"},
{AstExprBinary::Op::Mod, "__mod"},
{AstExprBinary::Op::Concat, "__concat"},
};
static const std::unordered_map<AstExprUnary::Op, const char*> kUnaryOpMetamethods{
{AstExprUnary::Op::Minus, "__unm"},
{AstExprUnary::Op::Len, "__len"},
};
} // namespace Luau

View file

@ -22,15 +22,6 @@ bool isSubtype(
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice, bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice,
bool anyIsTop = true); bool anyIsTop = true);
std::pair<TypeId, bool> normalize(
TypeId ty, NotNull<Scope> scope, TypeArena& arena, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, NotNull<Module> module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(
TypePackId ty, NotNull<Scope> scope, TypeArena& arena, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, NotNull<Module> module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, const ModulePtr& module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
class TypeIds class TypeIds
{ {
private: private:

View file

@ -54,7 +54,9 @@ struct Scope
DenseHashSet<Name> builtinTypeNames{""}; DenseHashSet<Name> builtinTypeNames{""};
void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun); void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun);
std::optional<TypeId> lookup(Symbol sym); std::optional<TypeId> lookup(Symbol sym) const;
std::optional<TypeId> lookup(DefId def) const;
std::optional<std::pair<TypeId, Scope*>> lookupEx(Symbol sym);
std::optional<TypeFun> lookupType(const Name& name); std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name); std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
@ -66,6 +68,7 @@ struct Scope
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const;
RefinementMap refinements; RefinementMap refinements;
DenseHashMap<const Def*, TypeId> dcrRefinements{nullptr};
// For mutually recursive type aliases, it's important that // For mutually recursive type aliases, it's important that
// they use the same types for the same names. // they use the same types for the same names.

View file

@ -6,10 +6,11 @@
#include <string> #include <string>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau namespace Luau
{ {
// TODO Rename this to Name once the old type alias is gone.
struct Symbol struct Symbol
{ {
Symbol() Symbol()
@ -40,9 +41,12 @@ struct Symbol
{ {
if (local) if (local)
return local == rhs.local; return local == rhs.local;
if (global.value) else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
return false; else if (FFlag::DebugLuauDeferredConstraintResolution)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else
return false;
} }
bool operator!=(const Symbol& rhs) const bool operator!=(const Symbol& rhs) const
@ -58,8 +62,8 @@ struct Symbol
return global < rhs.global; return global < rhs.global;
else if (local) else if (local)
return true; return true;
else
return false; return false;
} }
AstName astName() const AstName astName() const

View file

@ -192,18 +192,12 @@ struct TypeChecker
ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location);
void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors);
std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
std::optional<TypeId> getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); std::optional<TypeId> getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
// Reduces the union to its simplest possible shape.
// (A | B) | B | C yields A | B | C
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty); std::optional<TypeId> tryStripUnionFromNil(TypeId ty);
TypeId stripFromNilAndReport(TypeId ty, const Location& location); TypeId stripFromNilAndReport(TypeId ty, const Location& location);

View file

@ -29,4 +29,23 @@ std::pair<size_t, std::optional<size_t>> getParameterExtents(const TxnLog* log,
// various other things to get there. // various other things to get there.
std::vector<TypeId> flatten(TypeArena& arena, NotNull<SingletonTypes> singletonTypes, TypePackId pack, size_t length); std::vector<TypeId> flatten(TypeArena& arena, NotNull<SingletonTypes> singletonTypes, TypePackId pack, size_t length);
/**
* Reduces a union by decomposing to the any/error type if it appears in the
* type list, and by merging child unions. Also strips out duplicate (by pointer
* identity) types.
* @param types the input type list to reduce.
* @returns the reduced type list.
*/
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
/**
* Tries to remove nil from a union type, if there's another option. T | nil
* reduces to T, but nil itself does not reduce.
* @param singletonTypes the singleton types to use
* @param arena the type arena to allocate the new type in, if necessary
* @param ty the type to remove nil from
* @returns a type with nil removed, or nil itself if that were the only option.
*/
TypeId stripNil(NotNull<SingletonTypes> singletonTypes, TypeArena& arena, TypeId ty);
} // namespace Luau } // namespace Luau

View file

@ -2,22 +2,23 @@
#pragma once #pragma once
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Predicate.h" #include "Luau/Predicate.h"
#include "Luau/Unifiable.h" #include "Luau/Unifiable.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/Common.h"
#include "Luau/NotNull.h"
#include <set>
#include <string>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include <deque> #include <deque>
#include <map>
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
LUAU_FASTINT(LuauTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength)
@ -131,24 +132,6 @@ struct PrimitiveTypeVar
} }
}; };
struct ConstrainedTypeVar
{
explicit ConstrainedTypeVar(TypeLevel level)
: level(level)
{
}
explicit ConstrainedTypeVar(TypeLevel level, const std::vector<TypeId>& parts)
: parts(parts)
, level(level)
{
}
std::vector<TypeId> parts;
TypeLevel level;
Scope* scope = nullptr;
};
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false // Types for true and false
struct BooleanSingleton struct BooleanSingleton
@ -496,11 +479,13 @@ struct AnyTypeVar
{ {
}; };
// T | U
struct UnionTypeVar struct UnionTypeVar
{ {
std::vector<TypeId> options; std::vector<TypeId> options;
}; };
// T & U
struct IntersectionTypeVar struct IntersectionTypeVar
{ {
std::vector<TypeId> parts; std::vector<TypeId> parts;
@ -519,12 +504,27 @@ struct NeverTypeVar
{ {
}; };
// Invariant 1: there should never be a reason why such UseTypeVar exists without it mapping to another type.
// Invariant 2: UseTypeVar should always disappear across modules.
struct UseTypeVar
{
DefId def;
NotNull<Scope> scope;
};
// ~T
// TODO: Some simplification step that overwrites the type graph to make sure negation
// types disappear from the user's view, and (?) a debug flag to disable that
struct NegationTypeVar
{
TypeId ty;
};
using ErrorTypeVar = Unifiable::Error; using ErrorTypeVar = Unifiable::Error;
using TypeVariant = using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar,
Unifiable::Variant<TypeId, PrimitiveTypeVar, ConstrainedTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar,
TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar>; UseTypeVar, NegationTypeVar>;
struct TypeVar final struct TypeVar final
{ {
@ -541,7 +541,6 @@ struct TypeVar final
TypeVar(const TypeVariant& ty, bool persistent) TypeVar(const TypeVariant& ty, bool persistent)
: ty(ty) : ty(ty)
, persistent(persistent) , persistent(persistent)
, normal(persistent) // We assume that all persistent types are irreducable.
{ {
} }
@ -549,7 +548,6 @@ struct TypeVar final
void reassign(const TypeVar& rhs) void reassign(const TypeVar& rhs)
{ {
ty = rhs.ty; ty = rhs.ty;
normal = rhs.normal;
documentationSymbol = rhs.documentationSymbol; documentationSymbol = rhs.documentationSymbol;
} }
@ -560,10 +558,6 @@ struct TypeVar final
// Persistent TypeVars do not get cloned. // Persistent TypeVars do not get cloned.
bool persistent = false; bool persistent = false;
// Normalization sets this for types that are fully normalized.
// This implies that they are transitively immutable.
bool normal = false;
std::optional<std::string> documentationSymbol; std::optional<std::string> documentationSymbol;
// Pointer to the type arena that allocated this type. // Pointer to the type arena that allocated this type.
@ -656,6 +650,8 @@ public:
const TypeId unknownType; const TypeId unknownType;
const TypeId neverType; const TypeId neverType;
const TypeId errorType; const TypeId errorType;
const TypeId falsyType; // No type binding!
const TypeId truthyType; // No type binding!
const TypePackId anyTypePack; const TypePackId anyTypePack;
const TypePackId neverTypePack; const TypePackId neverTypePack;
@ -703,7 +699,6 @@ T* getMutable(TypeId tv)
const std::vector<TypeId>& getTypes(const UnionTypeVar* utv); const std::vector<TypeId>& getTypes(const UnionTypeVar* utv);
const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv); const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv);
const std::vector<TypeId>& getTypes(const ConstrainedTypeVar* ctv);
template<typename T> template<typename T>
struct TypeIterator; struct TypeIterator;
@ -716,10 +711,6 @@ using IntersectionTypeVarIterator = TypeIterator<IntersectionTypeVar>;
IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv); IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv);
IntersectionTypeVarIterator end(const IntersectionTypeVar* itv); IntersectionTypeVarIterator end(const IntersectionTypeVar* itv);
using ConstrainedTypeVarIterator = TypeIterator<ConstrainedTypeVar>;
ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv);
ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv);
/* Traverses the type T yielding each TypeId. /* Traverses the type T yielding each TypeId.
* If the iterator encounters a nested type T, it will instead yield each TypeId within. * If the iterator encounters a nested type T, it will instead yield each TypeId within.
*/ */
@ -793,7 +784,6 @@ struct TypeIterator
// with templates portability in this area, so not worth it. Thanks MSVC. // with templates portability in this area, so not worth it. Thanks MSVC.
friend UnionTypeVarIterator end(const UnionTypeVar*); friend UnionTypeVarIterator end(const UnionTypeVar*);
friend IntersectionTypeVarIterator end(const IntersectionTypeVar*); friend IntersectionTypeVarIterator end(const IntersectionTypeVar*);
friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*);
private: private:
TypeIterator() = default; TypeIterator() = default;

View file

@ -119,12 +119,7 @@ private:
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name); std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy);
void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy);
public: public:
void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel);
// Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error" // Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error"
bool occursCheck(TypeId needle, TypeId haystack); bool occursCheck(TypeId needle, TypeId haystack);
bool occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack); bool occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);

View file

@ -105,7 +105,7 @@ public:
tableDtor[typeId](&storage); tableDtor[typeId](&storage);
typeId = tid; typeId = tid;
new (&storage) TT(std::forward<Args>(args)...); new (&storage) TT{std::forward<Args>(args)...};
return *reinterpret_cast<T*>(&storage); return *reinterpret_cast<T*>(&storage);
} }

View file

@ -103,10 +103,6 @@ struct GenericTypeVarVisitor
{ {
return visit(ty); return visit(ty);
} }
virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv)
{ {
return visit(ty); return visit(ty);
@ -159,6 +155,14 @@ struct GenericTypeVarVisitor
{ {
return visit(ty); return visit(ty);
} }
virtual bool visit(TypeId ty, const UseTypeVar& utv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const NegationTypeVar& ntv)
{
return visit(ty);
}
virtual bool visit(TypePackId tp) virtual bool visit(TypePackId tp)
{ {
@ -216,14 +220,6 @@ struct GenericTypeVarVisitor
visit(ty, *gtv); visit(ty, *gtv);
else if (auto etv = get<ErrorTypeVar>(ty)) else if (auto etv = get<ErrorTypeVar>(ty))
visit(ty, *etv); visit(ty, *etv);
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
if (visit(ty, *ctv))
{
for (TypeId part : ctv->parts)
traverse(part);
}
}
else if (auto ptv = get<PrimitiveTypeVar>(ty)) else if (auto ptv = get<PrimitiveTypeVar>(ty))
visit(ty, *ptv); visit(ty, *ptv);
else if (auto ftv = get<FunctionTypeVar>(ty)) else if (auto ftv = get<FunctionTypeVar>(ty))
@ -325,6 +321,10 @@ struct GenericTypeVarVisitor
traverse(a); traverse(a);
} }
} }
else if (auto utv = get<UseTypeVar>(ty))
visit(ty, *utv);
else if (auto ntv = get<NegationTypeVar>(ty))
visit(ty, *ntv);
else if (!FFlag::LuauCompleteVisitor) else if (!FFlag::LuauCompleteVisitor)
return visit_detail::unsee(seen, ty); return visit_detail::unsee(seen, ty);
else else

View file

@ -37,8 +37,6 @@ bool Anyification::isDirty(TypeId ty)
return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed);
else if (log->getMutable<FreeTypeVar>(ty)) else if (log->getMutable<FreeTypeVar>(ty))
return true; return true;
else if (get<ConstrainedTypeVar>(ty))
return true;
else else
return false; return false;
} }
@ -65,20 +63,8 @@ TypeId Anyification::clean(TypeId ty)
clone.syntheticName = ttv->syntheticName; clone.syntheticName = ttv->syntheticName;
clone.tags = ttv->tags; clone.tags = ttv->tags;
TypeId res = addType(std::move(clone)); TypeId res = addType(std::move(clone));
asMutable(res)->normal = ty->normal;
return res; return res;
} }
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
std::vector<TypeId> copy = ctv->parts;
for (TypeId& ty : copy)
ty = replace(ty);
TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)});
auto [t, ok] = normalize(res, scope, *arena, singletonTypes, *iceHandler);
if (!ok)
normalizationTooComplex = true;
return t;
}
else else
return anyType; return anyType;
} }

View file

@ -10,6 +10,7 @@
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/TypeUtils.h"
#include <algorithm> #include <algorithm>
@ -41,6 +42,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionSelect(MagicFunctionCallContext context);
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context);
static bool dcrMagicFunctionPack(MagicFunctionCallContext context);
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types) TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{ {
@ -333,6 +335,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker)
ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone");
attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); attachMagicFunction(ttv->props["pack"].type, magicFunctionPack);
attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack);
} }
attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire);
@ -660,7 +663,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
options.push_back(vtp->ty); options.push_back(vtp->ty);
} }
options = typechecker.reduceUnion(options); options = reduceUnion(options);
// table.pack() -> {| n: number, [number]: nil |} // table.pack() -> {| n: number, [number]: nil |}
// table.pack(1) -> {| n: number, [number]: number |} // table.pack(1) -> {| n: number, [number]: number |}
@ -679,6 +682,47 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
return WithPredicate<TypePackId>{arena.addTypePack({packedTable})}; return WithPredicate<TypePackId>{arena.addTypePack({packedTable})};
} }
static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
{
TypeArena* arena = context.solver->arena;
const auto& [paramTypes, paramTail] = flatten(context.arguments);
std::vector<TypeId> options;
options.reserve(paramTypes.size());
for (auto type : paramTypes)
options.push_back(type);
if (paramTail)
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(*paramTail))
options.push_back(vtp->ty);
}
options = reduceUnion(options);
// table.pack() -> {| n: number, [number]: nil |}
// table.pack(1) -> {| n: number, [number]: number |}
// table.pack(1, "foo") -> {| n: number, [number]: number | string |}
TypeId result = nullptr;
if (options.empty())
result = context.solver->singletonTypes->nilType;
else if (options.size() == 1)
result = options[0];
else
result = arena->addType(UnionTypeVar{std::move(options)});
TypeId numberType = context.solver->singletonTypes->numberType;
TypeId packedTable = arena->addType(
TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed});
TypePackId tableTypePack = arena->addTypePack({packedTable});
asMutable(context.result)->ty.emplace<BoundTypePack>(tableTypePack);
return true;
}
static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
{ {
// require(foo.parent.bar) will technically work, but it depends on legacy goop that // require(foo.parent.bar) will technically work, but it depends on legacy goop that

View file

@ -1,6 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
@ -51,7 +51,6 @@ struct TypeCloner
void operator()(const BlockedTypeVar& t); void operator()(const BlockedTypeVar& t);
void operator()(const PendingExpansionTypeVar& t); void operator()(const PendingExpansionTypeVar& t);
void operator()(const PrimitiveTypeVar& t); void operator()(const PrimitiveTypeVar& t);
void operator()(const ConstrainedTypeVar& t);
void operator()(const SingletonTypeVar& t); void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t); void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t); void operator()(const TableTypeVar& t);
@ -63,6 +62,8 @@ struct TypeCloner
void operator()(const LazyTypeVar& t); void operator()(const LazyTypeVar& t);
void operator()(const UnknownTypeVar& t); void operator()(const UnknownTypeVar& t);
void operator()(const NeverTypeVar& t); void operator()(const NeverTypeVar& t);
void operator()(const UseTypeVar& t);
void operator()(const NegationTypeVar& t);
}; };
struct TypePackCloner struct TypePackCloner
@ -198,21 +199,6 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t)
defaultClone(t); defaultClone(t);
} }
void TypeCloner::operator()(const ConstrainedTypeVar& t)
{
TypeId res = dest.addType(ConstrainedTypeVar{t.level});
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(res);
LUAU_ASSERT(ctv);
seenTypes[typeId] = res;
std::vector<TypeId> parts;
for (TypeId part : t.parts)
parts.push_back(clone(part, dest, cloneState));
ctv->parts = std::move(parts);
}
void TypeCloner::operator()(const SingletonTypeVar& t) void TypeCloner::operator()(const SingletonTypeVar& t)
{ {
defaultClone(t); defaultClone(t);
@ -352,6 +338,21 @@ void TypeCloner::operator()(const NeverTypeVar& t)
defaultClone(t); defaultClone(t);
} }
void TypeCloner::operator()(const UseTypeVar& t)
{
TypeId result = dest.addType(BoundTypeVar{follow(typeId)});
seenTypes[typeId] = result;
}
void TypeCloner::operator()(const NegationTypeVar& t)
{
TypeId result = dest.addType(AnyTypeVar{});
seenTypes[typeId] = result;
TypeId ty = clone(t.ty, dest, cloneState);
asMutable(result)->ty = NegationTypeVar{ty};
}
} // anonymous namespace } // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
@ -390,7 +391,6 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
if (!res->persistent) if (!res->persistent)
{ {
asMutable(res)->documentationSymbol = typeId->documentationSymbol; asMutable(res)->documentationSymbol = typeId->documentationSymbol;
asMutable(res)->normal = typeId->normal;
} }
} }
@ -478,11 +478,6 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl
clone.parts = itv->parts; clone.parts = itv->parts;
result = dest.addType(std::move(clone)); result = dest.addType(std::move(clone));
} }
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
ConstrainedTypeVar clone{ctv->level, ctv->parts};
result = dest.addType(std::move(clone));
}
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty)) else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{ {
PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments};
@ -497,6 +492,10 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl
{ {
result = dest.addType(*ty); result = dest.addType(*ty);
} }
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(ty))
{
result = dest.addType(NegationTypeVar{ntv->ty});
}
else else
return result; return result;
@ -504,4 +503,9 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl
return result; return result;
} }
TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest)
{
return shallowClone(ty, *dest, TxnLog::empty());
}
} // namespace Luau } // namespace Luau

View file

@ -1,20 +1,21 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintGraphBuilder.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Clone.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/DcrLogger.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/DcrLogger.h" #include "Luau/TypeUtils.h"
LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauCheckRecursionLimit);
LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauLogSolverToJson);
LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(DebugLuauMagicTypes);
#include "Luau/Scope.h"
namespace Luau namespace Luau
{ {
@ -53,12 +54,13 @@ static bool matchSetmetatable(const AstExprCall& call)
ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena,
NotNull<ModuleResolver> moduleResolver, NotNull<SingletonTypes> singletonTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, NotNull<ModuleResolver> moduleResolver, NotNull<SingletonTypes> singletonTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope,
DcrLogger* logger) DcrLogger* logger, NotNull<DataFlowGraph> dfg)
: moduleName(moduleName) : moduleName(moduleName)
, module(module) , module(module)
, singletonTypes(singletonTypes) , singletonTypes(singletonTypes)
, arena(arena) , arena(arena)
, rootScope(nullptr) , rootScope(nullptr)
, dfg(dfg)
, moduleResolver(moduleResolver) , moduleResolver(moduleResolver)
, ice(ice) , ice(ice)
, globalScope(globalScope) , globalScope(globalScope)
@ -95,14 +97,14 @@ ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& paren
return scope; return scope;
} }
void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) NotNull<Constraint> ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv)
{ {
scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}); return NotNull{scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()};
} }
void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c) NotNull<Constraint> ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c)
{ {
scope->constraints.emplace_back(std::move(c)); return NotNull{scope->constraints.emplace_back(std::move(c)).get()};
} }
void ConstraintGraphBuilder::visit(AstStatBlock* block) void ConstraintGraphBuilder::visit(AstStatBlock* block)
@ -229,22 +231,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
{ {
std::vector<TypeId> varTypes; std::vector<TypeId> varTypes;
varTypes.reserve(local->vars.size);
for (AstLocal* local : local->vars) for (AstLocal* local : local->vars)
{ {
TypeId ty = nullptr; TypeId ty = nullptr;
Location location = local->location;
if (local->annotation) if (local->annotation)
{
location = local->annotation->location;
ty = resolveType(scope, local->annotation, /* topLevel */ true); ty = resolveType(scope, local->annotation, /* topLevel */ true);
}
else
ty = freshType(scope);
varTypes.push_back(ty); varTypes.push_back(ty);
scope->bindings[local] = Binding{ty, location};
} }
for (size_t i = 0; i < local->values.size; ++i) for (size_t i = 0; i < local->values.size; ++i)
@ -257,6 +253,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
// HACK: we leave nil-initialized things floating under the assumption that they will later be populated. // HACK: we leave nil-initialized things floating under the assumption that they will later be populated.
// See the test TypeInfer/infer_locals_with_nil_value. // See the test TypeInfer/infer_locals_with_nil_value.
// Better flow awareness should make this obsolete. // Better flow awareness should make this obsolete.
if (!varTypes[i])
varTypes[i] = freshType(scope);
} }
else if (i == local->values.size - 1) else if (i == local->values.size - 1)
{ {
@ -268,6 +267,20 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
if (i < local->vars.size) if (i < local->vars.size)
{ {
std::vector<TypeId> packTypes = flatten(*arena, singletonTypes, exprPack, varTypes.size() - i);
// fill out missing values in varTypes with values from exprPack
for (size_t j = i; j < varTypes.size(); ++j)
{
if (!varTypes[j])
{
if (j - i < packTypes.size())
varTypes[j] = packTypes[j - i];
else
varTypes[j] = freshType(scope);
}
}
std::vector<TypeId> tailValues{varTypes.begin() + i, varTypes.end()}; std::vector<TypeId> tailValues{varTypes.begin() + i, varTypes.end()};
TypePackId tailPack = arena->addTypePack(std::move(tailValues)); TypePackId tailPack = arena->addTypePack(std::move(tailValues));
addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack});
@ -281,10 +294,31 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
TypeId exprType = check(scope, value, expectedType); TypeId exprType = check(scope, value, expectedType);
if (i < varTypes.size()) if (i < varTypes.size())
addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); {
if (varTypes[i])
addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType});
else
varTypes[i] = exprType;
}
} }
} }
for (size_t i = 0; i < local->vars.size; ++i)
{
AstLocal* l = local->vars.data[i];
Location location = l->location;
if (!varTypes[i])
varTypes[i] = freshType(scope);
scope->bindings[l] = Binding{varTypes[i], location};
// HACK: In the greedy solver, we say the type state of a variable is the type annotation itself, but
// the actual type state is the corresponding initializer expression (if it exists) or nil otherwise.
if (auto def = dfg->getDef(l))
scope->dcrRefinements[*def] = varTypes[i];
}
if (local->values.size > 0) if (local->values.size > 0)
{ {
// To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'.
@ -510,7 +544,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign)
{ {
TypePackId varPackId = checkPack(scope, assign->vars); TypePackId varPackId = checkLValues(scope, assign->vars);
TypePackId valuePack = checkPack(scope, assign->values); TypePackId valuePack = checkPack(scope, assign->values);
addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId});
@ -532,7 +566,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign*
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement)
{ {
check(scope, ifStatement->condition); // TODO: Optimization opportunity, the interior scope of the condition could be
// reused for the then body, so we don't need to refine twice.
ScopePtr condScope = childScope(ifStatement->condition, scope);
check(condScope, ifStatement->condition, std::nullopt);
ScopePtr thenScope = childScope(ifStatement->thenbody, scope); ScopePtr thenScope = childScope(ifStatement->thenbody, scope);
visit(thenScope, ifStatement->thenbody); visit(thenScope, ifStatement->thenbody);
@ -893,7 +930,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::
TypeId result = nullptr; TypeId result = nullptr;
if (auto group = expr->as<AstExprGroup>()) if (auto group = expr->as<AstExprGroup>())
result = check(scope, group->expr); result = check(scope, group->expr, expectedType);
else if (auto stringExpr = expr->as<AstExprConstantString>()) else if (auto stringExpr = expr->as<AstExprConstantString>())
{ {
if (expectedType) if (expectedType)
@ -937,32 +974,14 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::
} }
else if (expr->is<AstExprConstantNil>()) else if (expr->is<AstExprConstantNil>())
result = singletonTypes->nilType; result = singletonTypes->nilType;
else if (auto a = expr->as<AstExprLocal>()) else if (auto local = expr->as<AstExprLocal>())
{ result = check(scope, local);
std::optional<TypeId> ty = scope->lookup(a->local); else if (auto global = expr->as<AstExprGlobal>())
if (ty) result = check(scope, global);
result = *ty;
else
result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point?
}
else if (auto g = expr->as<AstExprGlobal>())
{
std::optional<TypeId> ty = scope->lookup(g->name);
if (ty)
result = *ty;
else
{
/* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any
* global that is not already in-scope is definitely an unknown symbol.
*/
reportError(g->location, UnknownSymbol{g->name.value});
result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point?
}
}
else if (expr->is<AstExprVarargs>()) else if (expr->is<AstExprVarargs>())
result = flattenPack(scope, expr->location, checkPack(scope, expr)); result = flattenPack(scope, expr->location, checkPack(scope, expr));
else if (expr->is<AstExprCall>()) else if (expr->is<AstExprCall>())
result = flattenPack(scope, expr->location, checkPack(scope, expr)); result = flattenPack(scope, expr->location, checkPack(scope, expr)); // TODO: needs predicates too
else if (auto a = expr->as<AstExprFunction>()) else if (auto a = expr->as<AstExprFunction>())
{ {
FunctionSignature sig = checkFunctionSignature(scope, a); FunctionSignature sig = checkFunctionSignature(scope, a);
@ -978,7 +997,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::
else if (auto unary = expr->as<AstExprUnary>()) else if (auto unary = expr->as<AstExprUnary>())
result = check(scope, unary); result = check(scope, unary);
else if (auto binary = expr->as<AstExprBinary>()) else if (auto binary = expr->as<AstExprBinary>())
result = check(scope, binary); result = check(scope, binary, expectedType);
else if (auto ifElse = expr->as<AstExprIfElse>()) else if (auto ifElse = expr->as<AstExprIfElse>())
result = check(scope, ifElse, expectedType); result = check(scope, ifElse, expectedType);
else if (auto typeAssert = expr->as<AstExprTypeAssertion>()) else if (auto typeAssert = expr->as<AstExprTypeAssertion>())
@ -1002,6 +1021,37 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::
return result; return result;
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local)
{
std::optional<TypeId> resultTy;
if (auto def = dfg->getDef(local))
resultTy = scope->lookup(*def);
if (!resultTy)
{
if (auto ty = scope->lookup(local->local))
resultTy = *ty;
}
if (!resultTy)
return singletonTypes->errorRecoveryType(); // TODO: replace with ice, locals should never exist before its definition.
return *resultTy;
}
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global)
{
if (std::optional<TypeId> ty = scope->lookup(global->name))
return *ty;
/* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any
* global that is not already in-scope is definitely an unknown symbol.
*/
reportError(global->location, UnknownSymbol{global->name.value});
return singletonTypes->errorRecoveryType();
}
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName)
{ {
TypeId obj = check(scope, indexName->expr); TypeId obj = check(scope, indexName->expr);
@ -1036,54 +1086,32 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* in
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary)
{ {
TypeId operandType = check(scope, unary->expr); TypeId operandType = check_(scope, unary);
TypeId resultType = arena->addType(BlockedTypeVar{}); TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType});
return resultType; return resultType;
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary) TypeId ConstraintGraphBuilder::check_(const ScopePtr& scope, AstExprUnary* unary)
{ {
TypeId leftType = check(scope, binary->left); if (unary->op == AstExprUnary::Not)
TypeId rightType = check(scope, binary->right);
switch (binary->op)
{ {
case AstExprBinary::And: TypeId ty = check(scope, unary->expr, std::nullopt);
case AstExprBinary::Or:
{ return ty;
addConstraint(scope, binary->location, SubtypeConstraint{leftType, rightType});
return leftType;
}
case AstExprBinary::Add:
case AstExprBinary::Sub:
case AstExprBinary::Mul:
case AstExprBinary::Div:
case AstExprBinary::Mod:
case AstExprBinary::Pow:
case AstExprBinary::CompareNe:
case AstExprBinary::CompareEq:
case AstExprBinary::CompareLt:
case AstExprBinary::CompareLe:
case AstExprBinary::CompareGt:
case AstExprBinary::CompareGe:
{
TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType});
return resultType;
}
case AstExprBinary::Concat:
{
addConstraint(scope, binary->left->location, SubtypeConstraint{leftType, singletonTypes->stringType});
addConstraint(scope, binary->right->location, SubtypeConstraint{rightType, singletonTypes->stringType});
return singletonTypes->stringType;
}
default:
LUAU_ASSERT(0);
} }
LUAU_ASSERT(0); return check(scope, unary->expr);
return nullptr; }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType)
{
TypeId leftType = check(scope, binary->left, expectedType);
TypeId rightType = check(scope, binary->right, expectedType);
TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType});
return resultType;
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType) TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType)
@ -1106,10 +1134,182 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifEls
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert)
{ {
check(scope, typeAssert->expr); check(scope, typeAssert->expr, std::nullopt);
return resolveType(scope, typeAssert->annotation); return resolveType(scope, typeAssert->annotation);
} }
TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs)
{
std::vector<TypeId> types;
types.reserve(exprs.size);
for (size_t i = 0; i < exprs.size; ++i)
{
AstExpr* const expr = exprs.data[i];
types.push_back(checkLValue(scope, expr));
}
return arena->addTypePack(std::move(types));
}
static bool isUnsealedTable(TypeId ty)
{
ty = follow(ty);
const TableTypeVar* ttv = get<TableTypeVar>(ty);
return ttv && ttv->state == TableState::Unsealed;
};
/**
* If the expr is a dotted set of names, and if the root symbol refers to an
* unsealed table, return that table type, plus the indeces that follow as a
* vector.
*/
static std::optional<std::pair<Symbol, std::vector<const char*>>> extractDottedName(AstExpr* expr)
{
std::vector<const char*> names;
while (expr)
{
if (auto global = expr->as<AstExprGlobal>())
{
std::reverse(begin(names), end(names));
return std::pair{global->name, std::move(names)};
}
else if (auto local = expr->as<AstExprLocal>())
{
std::reverse(begin(names), end(names));
return std::pair{local->local, std::move(names)};
}
else if (auto indexName = expr->as<AstExprIndexName>())
{
names.push_back(indexName->index.value);
expr = indexName->expr;
}
else
return std::nullopt;
}
return std::nullopt;
}
/**
* Create a shallow copy of `ty` and its properties along `path`. Insert a new
* property (the last segment of `path`) into the tail table with the value `t`.
*
* On success, returns the new outermost table type. If the root table or any
* of its subkeys are not unsealed tables, the function fails and returns
* std::nullopt.
*
* TODO: Prove that we completely give up in the face of indexers and
* metatables.
*/
static std::optional<TypeId> updateTheTableType(NotNull<TypeArena> arena, TypeId ty, const std::vector<const char*>& path, TypeId replaceTy)
{
if (path.empty())
return std::nullopt;
// First walk the path and ensure that it's unsealed tables all the way
// to the end.
{
TypeId t = ty;
for (size_t i = 0; i < path.size() - 1; ++i)
{
if (!isUnsealedTable(t))
return std::nullopt;
const TableTypeVar* tbl = get<TableTypeVar>(t);
auto it = tbl->props.find(path[i]);
if (it == tbl->props.end())
return std::nullopt;
t = it->second.type;
}
// The last path segment should not be a property of the table at all.
// We are not changing property types. We are only admitting this one
// new property to be appended.
if (!isUnsealedTable(t))
return std::nullopt;
const TableTypeVar* tbl = get<TableTypeVar>(t);
auto it = tbl->props.find(path.back());
if (it != tbl->props.end())
return std::nullopt;
}
const TypeId res = shallowClone(ty, arena);
TypeId t = res;
for (size_t i = 0; i < path.size() - 1; ++i)
{
const std::string segment = path[i];
TableTypeVar* ttv = getMutable<TableTypeVar>(t);
LUAU_ASSERT(ttv);
auto propIt = ttv->props.find(segment);
if (propIt != ttv->props.end())
{
LUAU_ASSERT(isUnsealedTable(propIt->second.type));
t = shallowClone(follow(propIt->second.type), arena);
ttv->props[segment].type = t;
}
else
return std::nullopt;
}
TableTypeVar* ttv = getMutable<TableTypeVar>(t);
LUAU_ASSERT(ttv);
const std::string lastSegment = path.back();
LUAU_ASSERT(0 == ttv->props.count(lastSegment));
ttv->props[lastSegment] = Property{replaceTy};
return res;
}
/**
* This function is mostly about identifying properties that are being inserted into unsealed tables.
*
* If expr has the form name.a.b.c
*/
TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
{
if (auto indexExpr = expr->as<AstExprIndexExpr>())
{
if (auto constantString = indexExpr->index->as<AstExprConstantString>())
{
AstName syntheticIndex{constantString->value.data};
AstExprIndexName synthetic{
indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'};
return checkLValue(scope, &synthetic);
}
}
auto dottedPath = extractDottedName(expr);
if (!dottedPath)
return check(scope, expr);
const auto [sym, segments] = std::move(*dottedPath);
if (!sym.local)
return check(scope, expr);
auto lookupResult = scope->lookupEx(sym);
if (!lookupResult)
return check(scope, expr);
const auto [ty, symbolScope] = std::move(*lookupResult);
TypeId replaceTy = arena->freshType(scope.get());
std::optional<TypeId> updatedType = updateTheTableType(arena, ty, segments, replaceTy);
if (!updatedType)
return check(scope, expr);
std::optional<DefId> def = dfg->getDef(sym);
LUAU_ASSERT(def);
symbolScope->bindings[sym].typeId = *updatedType;
symbolScope->dcrRefinements[*def] = *updatedType;
return replaceTy;
}
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType) TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType)
{ {
TypeId ty = arena->addType(TableTypeVar{}); TypeId ty = arena->addType(TableTypeVar{});
@ -1275,6 +1475,9 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
argTypes.push_back(t); argTypes.push_back(t);
signatureScope->bindings[local] = Binding{t, local->location}; signatureScope->bindings[local] = Binding{t, local->location};
if (auto def = dfg->getDef(local))
signatureScope->dcrRefinements[*def] = t;
if (local->annotation) if (local->annotation)
{ {
TypeId argAnnotation = resolveType(signatureScope, local->annotation, /* topLevel */ true); TypeId argAnnotation = resolveType(signatureScope, local->annotation, /* topLevel */ true);

View file

@ -3,14 +3,16 @@
#include "Luau/Anyification.h" #include "Luau/Anyification.h"
#include "Luau/ApplyTypeFunction.h" #include "Luau/ApplyTypeFunction.h"
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DcrLogger.h"
#include "Luau/Instantiation.h" #include "Luau/Instantiation.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Metamethods.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/Quantify.h" #include "Luau/Quantify.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeUtils.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/Unifier.h" #include "Luau/Unifier.h"
#include "Luau/DcrLogger.h"
#include "Luau/VisitTypeVar.h" #include "Luau/VisitTypeVar.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
@ -438,6 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*fcc, constraint); success = tryDispatch(*fcc, constraint);
else if (auto hpc = get<HasPropConstraint>(*constraint)) else if (auto hpc = get<HasPropConstraint>(*constraint))
success = tryDispatch(*hpc, constraint); success = tryDispatch(*hpc, constraint);
else if (auto rc = get<RefinementConstraint>(*constraint))
success = tryDispatch(*rc, constraint);
else else
LUAU_ASSERT(false); LUAU_ASSERT(false);
@ -564,44 +568,192 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull<const Cons
TypeId rightType = follow(c.rightType); TypeId rightType = follow(c.rightType);
TypeId resultType = follow(c.resultType); TypeId resultType = follow(c.resultType);
if (isBlocked(leftType) || isBlocked(rightType)) bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or;
{
/* Compound assignments create constraints of the form
*
* A <: Binary<op, A, B>
*
* This constraint is the one that is meant to unblock A, so it doesn't
* make any sense to stop and wait for someone else to do it.
*/
if (leftType != resultType && rightType != resultType)
{
block(c.leftType, constraint);
block(c.rightType, constraint);
return false;
}
}
if (isNumber(leftType)) /* Compound assignments create constraints of the form
{ *
unify(leftType, rightType, constraint->scope); * A <: Binary<op, A, B>
asMutable(resultType)->ty.emplace<BoundTypeVar>(leftType); *
return true; * This constraint is the one that is meant to unblock A, so it doesn't
} * make any sense to stop and wait for someone else to do it.
*/
if (isBlocked(leftType) && leftType != resultType)
return block(c.leftType, constraint);
if (isBlocked(rightType) && rightType != resultType)
return block(c.rightType, constraint);
if (!force) if (!force)
{ {
if (get<FreeTypeVar>(leftType)) // Logical expressions may proceed if the LHS is free.
if (get<FreeTypeVar>(leftType) && !isLogical)
return block(leftType, constraint); return block(leftType, constraint);
} }
if (isBlocked(leftType)) // Logical expressions may proceed if the LHS is free.
if (isBlocked(leftType) || (get<FreeTypeVar>(leftType) && !isLogical))
{ {
asMutable(resultType)->ty.emplace<BoundTypeVar>(errorRecoveryType()); asMutable(resultType)->ty.emplace<BoundTypeVar>(errorRecoveryType());
// reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation}); unblock(resultType);
return true; return true;
} }
// TODO metatables, classes // For or expressions, the LHS will never have nil as a possible output.
// Consider:
// local foo = nil or 2
// `foo` will always be 2.
if (c.op == AstExprBinary::Op::Or)
leftType = stripNil(singletonTypes, *arena, leftType);
// Metatables go first, even if there is primitive behavior.
if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end())
{
// Metatables are not the same. The metamethod will not be invoked.
if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) &&
getMetatable(leftType, singletonTypes) != getMetatable(rightType, singletonTypes))
{
// TODO: Boolean singleton false? The result is _always_ boolean false.
asMutable(resultType)->ty.emplace<BoundTypeVar>(singletonTypes->booleanType);
unblock(resultType);
return true;
}
std::optional<TypeId> mm;
// The LHS metatable takes priority over the RHS metatable, where
// present.
if (std::optional<TypeId> leftMm = findMetatableEntry(singletonTypes, errors, leftType, it->second, constraint->location))
mm = leftMm;
else if (std::optional<TypeId> rightMm = findMetatableEntry(singletonTypes, errors, rightType, it->second, constraint->location))
mm = rightMm;
if (mm)
{
// TODO: Is a table with __call legal here?
// TODO: Overloads
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*mm)))
{
TypePackId inferredArgs;
// For >= and > we invoke __lt and __le respectively with
// swapped argument ordering.
if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt)
{
inferredArgs = arena->addTypePack({rightType, leftType});
}
else
{
inferredArgs = arena->addTypePack({leftType, rightType});
}
unify(inferredArgs, ftv->argTypes, constraint->scope);
TypeId mmResult;
// Comparison operations always evaluate to a boolean,
// regardless of what the metamethod returns.
switch (c.op)
{
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
mmResult = singletonTypes->booleanType;
break;
default:
mmResult = first(ftv->retTypes).value_or(errorRecoveryType());
}
asMutable(resultType)->ty.emplace<BoundTypeVar>(mmResult);
unblock(resultType);
return true;
}
}
// If there's no metamethod available, fall back to primitive behavior.
}
// If any is present, the expression must evaluate to any as well.
bool leftAny = get<AnyTypeVar>(leftType) || get<ErrorTypeVar>(leftType);
bool rightAny = get<AnyTypeVar>(rightType) || get<ErrorTypeVar>(rightType);
bool anyPresent = leftAny || rightAny;
switch (c.op)
{
// For arithmetic operators, if the LHS is a number, the RHS must be a
// number as well. The result will also be a number.
case AstExprBinary::Op::Add:
case AstExprBinary::Op::Sub:
case AstExprBinary::Op::Mul:
case AstExprBinary::Op::Div:
case AstExprBinary::Op::Pow:
case AstExprBinary::Op::Mod:
if (isNumber(leftType))
{
unify(leftType, rightType, constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(anyPresent ? singletonTypes->anyType : leftType);
unblock(resultType);
return true;
}
break;
// For concatenation, if the LHS is a string, the RHS must be a string as
// well. The result will also be a string.
case AstExprBinary::Op::Concat:
if (isString(leftType))
{
unify(leftType, rightType, constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(anyPresent ? singletonTypes->anyType : leftType);
unblock(resultType);
return true;
}
break;
// Inexact comparisons require that the types be both numbers or both
// strings, and evaluate to a boolean.
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
if ((isNumber(leftType) && isNumber(rightType)) || (isString(leftType) && isString(rightType)))
{
asMutable(resultType)->ty.emplace<BoundTypeVar>(singletonTypes->booleanType);
unblock(resultType);
return true;
}
break;
// == and ~= always evaluate to a boolean, and impose no other constraints
// on their parameters.
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
asMutable(resultType)->ty.emplace<BoundTypeVar>(singletonTypes->booleanType);
unblock(resultType);
return true;
// And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is
// truthy.
case AstExprBinary::Op::And:
asMutable(resultType)->ty.emplace<BoundTypeVar>(unionOfTypes(rightType, singletonTypes->booleanType, constraint->scope, false));
unblock(resultType);
return true;
// Or evaluates to the LHS type if the LHS is truthy, and the RHS type if
// LHS is falsey.
case AstExprBinary::Op::Or:
asMutable(resultType)->ty.emplace<BoundTypeVar>(unionOfTypes(rightType, leftType, constraint->scope, true));
unblock(resultType);
return true;
default:
iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location);
break;
}
// We failed to either evaluate a metamethod or invoke primitive behavior.
unify(leftType, errorRecoveryType(), constraint->scope);
unify(rightType, errorRecoveryType(), constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(errorRecoveryType());
unblock(resultType);
return true; return true;
} }
@ -943,6 +1095,31 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
return block(c.fn, constraint); return block(c.fn, constraint);
} }
// We don't support magic __call metamethods.
if (std::optional<TypeId> callMm = findMetatableEntry(singletonTypes, errors, fn, "__call", constraint->location))
{
std::vector<TypeId> args{fn};
for (TypeId arg : c.argsPack)
args.push_back(arg);
TypeId instantiatedType = arena->addType(BlockedTypeVar{});
TypeId inferredFnType =
arena->addType(FunctionTypeVar(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result));
// Alter the inner constraints.
LUAU_ASSERT(c.innerConstraints.size() == 2);
asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm};
asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType};
unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints));
asMutable(c.result)->ty.emplace<FreeTypePack>(constraint->scope);
unblock(c.result);
return true;
}
const FunctionTypeVar* ftv = get<FunctionTypeVar>(fn); const FunctionTypeVar* ftv = get<FunctionTypeVar>(fn);
bool usedMagic = false; bool usedMagic = false;
@ -1059,6 +1236,29 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
return true; return true;
} }
bool ConstraintSolver::tryDispatch(const RefinementConstraint& c, NotNull<const Constraint> constraint)
{
// TODO: Figure out exact details on when refinements need to be blocked.
// It's possible that it never needs to be, since we can just use intersection types with the discriminant type?
if (!constraint->scope->parent)
iceReporter.ice("No parent scope");
std::optional<TypeId> previousTy = constraint->scope->parent->lookup(c.def);
if (!previousTy)
iceReporter.ice("No previous type");
std::optional<TypeId> useTy = constraint->scope->lookup(c.def);
if (!useTy)
iceReporter.ice("The def is not bound to a type");
TypeId resultTy = follow(*useTy);
std::vector<TypeId> parts{*previousTy, c.discriminantType};
asMutable(resultTy)->ty.emplace<IntersectionTypeVar>(std::move(parts));
return true;
}
bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force) bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
{ {
auto block_ = [&](auto&& t) { auto block_ = [&](auto&& t) {
@ -1502,4 +1702,39 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const
return singletonTypes->errorRecoveryTypePack(); return singletonTypes->errorRecoveryTypePack();
} }
TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes)
{
a = follow(a);
b = follow(b);
if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b)))
{
Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant};
u.useScopes = true;
u.tryUnify(b, a);
if (u.errors.empty())
{
u.log.commit();
return a;
}
else
{
return singletonTypes->errorRecoveryType(singletonTypes->anyType);
}
}
if (*a == *b)
return a;
std::vector<TypeId> types = reduceUnion({a, b});
if (types.empty())
return singletonTypes->neverType;
if (types.size() == 1)
return types[0];
return arena->addType(UnionTypeVar{types});
}
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,440 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/Error.h"
LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
{
std::optional<DefId> DataFlowGraph::getDef(const AstExpr* expr) const
{
if (auto def = astDefs.find(expr))
return NotNull{*def};
return std::nullopt;
}
std::optional<DefId> DataFlowGraph::getDef(const AstLocal* local) const
{
if (auto def = localDefs.find(local))
return NotNull{*def};
return std::nullopt;
}
std::optional<DefId> DataFlowGraph::getDef(const Symbol& symbol) const
{
if (symbol.local)
return getDef(symbol.local);
else
return std::nullopt;
}
DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle)
{
LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
DataFlowGraphBuilder builder;
builder.handle = handle;
builder.visit(nullptr, block); // nullptr is the root DFG scope.
if (FFlag::DebugLuauFreezeArena)
builder.arena->allocator.freeze();
return std::move(builder.graph);
}
DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope)
{
return scopes.emplace_back(new DfgScope{scope}).get();
}
std::optional<DefId> DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e)
{
for (DfgScope* current = scope; current; current = current->parent)
{
if (auto loc = current->bindings.find(symbol))
{
graph.astDefs[e] = *loc;
return NotNull{*loc};
}
}
return std::nullopt;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b)
{
DfgScope* child = childScope(scope);
return visitBlockWithoutChildScope(child, b);
}
void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b)
{
for (AstStat* s : b->body)
visit(scope, s);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s)
{
if (auto b = s->as<AstStatBlock>())
return visit(scope, b);
else if (auto i = s->as<AstStatIf>())
return visit(scope, i);
else if (auto w = s->as<AstStatWhile>())
return visit(scope, w);
else if (auto r = s->as<AstStatRepeat>())
return visit(scope, r);
else if (auto b = s->as<AstStatBreak>())
return visit(scope, b);
else if (auto c = s->as<AstStatContinue>())
return visit(scope, c);
else if (auto r = s->as<AstStatReturn>())
return visit(scope, r);
else if (auto e = s->as<AstStatExpr>())
return visit(scope, e);
else if (auto l = s->as<AstStatLocal>())
return visit(scope, l);
else if (auto f = s->as<AstStatFor>())
return visit(scope, f);
else if (auto f = s->as<AstStatForIn>())
return visit(scope, f);
else if (auto a = s->as<AstStatAssign>())
return visit(scope, a);
else if (auto c = s->as<AstStatCompoundAssign>())
return visit(scope, c);
else if (auto f = s->as<AstStatFunction>())
return visit(scope, f);
else if (auto l = s->as<AstStatLocalFunction>())
return visit(scope, l);
else if (auto t = s->as<AstStatTypeAlias>())
return; // ok
else if (auto d = s->as<AstStatDeclareFunction>())
return; // ok
else if (auto d = s->as<AstStatDeclareGlobal>())
return; // ok
else if (auto d = s->as<AstStatDeclareFunction>())
return; // ok
else if (auto d = s->as<AstStatDeclareClass>())
return; // ok
else if (auto _ = s->as<AstStatError>())
return; // ok
else
handle->ice("Unknown AstStat in DataFlowGraphBuilder");
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i)
{
DfgScope* condScope = childScope(scope);
visitExpr(condScope, i->condition);
visit(condScope, i->thenbody);
if (i->elsebody)
visit(scope, i->elsebody);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* whileScope = childScope(scope);
visitExpr(whileScope, w->condition);
visit(whileScope, w->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* repeatScope = childScope(scope); // TODO: loop scope.
visitBlockWithoutChildScope(repeatScope, r->body);
visitExpr(repeatScope, r->condition);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b)
{
// TODO: Control flow analysis
return; // ok
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c)
{
// TODO: Control flow analysis
return; // ok
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r)
{
// TODO: Control flow analysis
for (AstExpr* e : r->list)
visitExpr(scope, e);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e)
{
visitExpr(scope, e->expr);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l)
{
// TODO: alias tracking
for (AstExpr* e : l->values)
visitExpr(scope, e);
for (AstLocal* local : l->vars)
{
DefId def = arena->freshDef();
graph.localDefs[local] = def;
scope->bindings[local] = def;
}
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
DefId def = arena->freshDef();
graph.localDefs[f->var] = def;
scope->bindings[f->var] = def;
// TODO(controlflow): entry point has a back edge from exit point
visit(forScope, f->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
for (AstLocal* local : f->vars)
{
DefId def = arena->freshDef();
graph.localDefs[local] = def;
forScope->bindings[local] = def;
}
// TODO(controlflow): entry point has a back edge from exit point
for (AstExpr* e : f->values)
visitExpr(forScope, e);
visit(forScope, f->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a)
{
for (AstExpr* r : a->values)
visitExpr(scope, r);
for (AstExpr* l : a->vars)
{
AstExpr* root = l;
bool isUpdatable = true;
while (true)
{
if (root->is<AstExprLocal>() || root->is<AstExprGlobal>())
break;
AstExprIndexName* indexName = root->as<AstExprIndexName>();
if (!indexName)
{
isUpdatable = false;
break;
}
root = indexName->expr;
}
if (isUpdatable)
{
// TODO global?
if (auto exprLocal = root->as<AstExprLocal>())
{
DefId def = arena->freshDef();
graph.astDefs[exprLocal] = def;
// Update the def in the scope that introduced the local. Not
// the current scope.
AstLocal* local = exprLocal->local;
DfgScope* s = scope;
while (s && !s->bindings.find(local))
s = s->parent;
LUAU_ASSERT(s && s->bindings.find(local));
s->bindings[local] = def;
}
}
visitExpr(scope, l); // TODO: they point to a new def!!
}
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c)
{
// TODO(typestates): The lhs is being read and written to. This might or might not be annoying.
visitExpr(scope, c->value);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f)
{
visitExpr(scope, f->name);
visitExpr(scope, f->func);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l)
{
DefId def = arena->freshDef();
graph.localDefs[l->name] = def;
scope->bindings[l->name] = def;
visitExpr(scope, l->func);
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e)
{
if (auto g = e->as<AstExprGroup>())
return visitExpr(scope, g->expr);
else if (auto c = e->as<AstExprConstantNil>())
return {}; // ok
else if (auto c = e->as<AstExprConstantBool>())
return {}; // ok
else if (auto c = e->as<AstExprConstantNumber>())
return {}; // ok
else if (auto c = e->as<AstExprConstantString>())
return {}; // ok
else if (auto l = e->as<AstExprLocal>())
return visitExpr(scope, l);
else if (auto g = e->as<AstExprGlobal>())
return visitExpr(scope, g);
else if (auto v = e->as<AstExprVarargs>())
return {}; // ok
else if (auto c = e->as<AstExprCall>())
return visitExpr(scope, c);
else if (auto i = e->as<AstExprIndexName>())
return visitExpr(scope, i);
else if (auto i = e->as<AstExprIndexExpr>())
return visitExpr(scope, i);
else if (auto f = e->as<AstExprFunction>())
return visitExpr(scope, f);
else if (auto t = e->as<AstExprTable>())
return visitExpr(scope, t);
else if (auto u = e->as<AstExprUnary>())
return visitExpr(scope, u);
else if (auto b = e->as<AstExprBinary>())
return visitExpr(scope, b);
else if (auto t = e->as<AstExprTypeAssertion>())
return visitExpr(scope, t);
else if (auto i = e->as<AstExprIfElse>())
return visitExpr(scope, i);
else if (auto i = e->as<AstExprInterpString>())
return visitExpr(scope, i);
else if (auto _ = e->as<AstExprError>())
return {}; // ok
else
handle->ice("Unknown AstExpr in DataFlowGraphBuilder");
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l)
{
return {use(scope, l->local, l)};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g)
{
return {use(scope, g->name, g)};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c)
{
visitExpr(scope, c->func);
for (AstExpr* arg : c->args)
visitExpr(scope, arg);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i)
{
std::optional<DefId> def = visitExpr(scope, i->expr).def;
if (!def)
return {};
// TODO: properties for the above def.
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i)
{
visitExpr(scope, i->expr);
visitExpr(scope, i->expr);
if (i->index->as<AstExprConstantString>())
{
// TODO: properties for the def
}
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f)
{
if (AstLocal* self = f->self)
{
DefId def = arena->freshDef();
graph.localDefs[self] = def;
scope->bindings[self] = def;
}
for (AstLocal* param : f->args)
{
DefId def = arena->freshDef();
graph.localDefs[param] = def;
scope->bindings[param] = def;
}
visit(scope, f->body);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t)
{
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u)
{
visitExpr(scope, u->expr);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b)
{
visitExpr(scope, b->left);
visitExpr(scope, b->right);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t)
{
ExpressionFlowGraph result = visitExpr(scope, t->expr);
// TODO: visit type
return result;
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i)
{
DfgScope* condScope = childScope(scope);
visitExpr(condScope, i->condition);
visitExpr(condScope, i->trueExpr);
visitExpr(scope, i->falseExpr);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i)
{
for (AstExpr* e : i->expressions)
visitExpr(scope, e);
return {};
}
} // namespace Luau

12
Analysis/src/Def.cpp Normal file
View file

@ -0,0 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Def.h"
namespace Luau
{
DefId DefArena::freshDef()
{
return NotNull{allocator.allocate<Def>(Undefined{})};
}
} // namespace Luau

View file

@ -7,8 +7,6 @@
#include <stdexcept> #include <stdexcept>
LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false)
static std::string wrongNumberOfArgsString( static std::string wrongNumberOfArgsString(
size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
{ {
@ -70,7 +68,7 @@ struct ErrorConverter
{ {
if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType))
{ {
if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr) if (fileResolver != nullptr)
{ {
std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule);
std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule);
@ -96,14 +94,7 @@ struct ErrorConverter
if (!tm.reason.empty()) if (!tm.reason.empty())
result += tm.reason + " "; result += tm.reason + " ";
if (FFlag::LuauTypeMismatchModuleNameResolution) result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
{
result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
}
else
{
result += Luau::toString(*tm.error);
}
} }
else if (!tm.reason.empty()) else if (!tm.reason.empty())
{ {

View file

@ -1,11 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Config.h" #include "Luau/Config.h"
#include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/FileResolver.h" #include "Luau/FileResolver.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
@ -15,7 +17,6 @@
#include "Luau/TypeChecker2.h" #include "Luau/TypeChecker2.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/BuiltinDefinitions.h"
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono>
@ -26,7 +27,6 @@ LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false)
LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100)
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauLogSolverToJson);
@ -489,23 +489,19 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
else else
typeCheckerForAutocomplete.finishTime = std::nullopt; typeCheckerForAutocomplete.finishTime = std::nullopt;
if (FFlag::LuauAutocompleteDynamicLimits) // TODO: This is a dirty ad hoc solution for autocomplete timeouts
{ // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit
// TODO: This is a dirty ad hoc solution for autocomplete timeouts // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle
// We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit if (FInt::LuauTarjanChildLimit > 0)
// so that we'll have type information for the whole file at lower quality instead of a full abort in the middle typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
if (FInt::LuauTarjanChildLimit > 0) else
typeCheckerForAutocomplete.instantiationChildLimit = typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt;
std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt;
if (FInt::LuauTypeInferIterationLimit > 0) if (FInt::LuauTypeInferIterationLimit > 0)
typeCheckerForAutocomplete.unifierIterationLimit = typeCheckerForAutocomplete.unifierIterationLimit =
std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult));
else else
typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt;
}
ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution
? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true) ? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true)
@ -519,10 +515,9 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
{ {
checkResult.timeoutHits.push_back(moduleName); checkResult.timeoutHits.push_back(moduleName);
if (FFlag::LuauAutocompleteDynamicLimits) sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
} }
else if (FFlag::LuauAutocompleteDynamicLimits && duration < autocompleteTimeLimit / 2.0) else if (duration < autocompleteTimeLimit / 2.0)
{ {
sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0);
} }
@ -871,13 +866,25 @@ ModulePtr Frontend::check(
} }
} }
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&iceHandler});
const NotNull<ModuleResolver> mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; const NotNull<ModuleResolver> mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver};
const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope};
Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}}; Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}};
ConstraintGraphBuilder cgb{ ConstraintGraphBuilder cgb{
sourceModule.name, result, &result->internalTypes, mr, singletonTypes, NotNull(&iceHandler), globalScope, logger.get()}; sourceModule.name,
result,
&result->internalTypes,
mr,
singletonTypes,
NotNull(&iceHandler),
globalScope,
logger.get(),
NotNull{&dfg},
};
cgb.visit(sourceModule.root); cgb.visit(sourceModule.root);
result->errors = std::move(cgb.errors); result->errors = std::move(cgb.errors);

View file

@ -60,36 +60,6 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos)
return contains(pos, *iter); return contains(pos, *iter);
} }
struct ForceNormal : TypeVarOnceVisitor
{
const TypeArena* typeArena = nullptr;
ForceNormal(const TypeArena* typeArena)
: typeArena(typeArena)
{
}
bool visit(TypeId ty) override
{
if (ty->owningArena != typeArena)
return false;
asMutable(ty)->normal = true;
return true;
}
bool visit(TypeId ty, const FreeTypeVar& ftv) override
{
visit(ty);
return true;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
return true;
}
};
struct ClonePublicInterface : Substitution struct ClonePublicInterface : Substitution
{ {
NotNull<SingletonTypes> singletonTypes; NotNull<SingletonTypes> singletonTypes;
@ -241,8 +211,6 @@ void Module::clonePublicInterface(NotNull<SingletonTypes> singletonTypes, Intern
moduleScope->varargPack = varargPack; moduleScope->varargPack = varargPack;
} }
ForceNormal forceNormal{&interfaceTypes};
if (exportedTypeBindings) if (exportedTypeBindings)
{ {
for (auto& [name, tf] : *exportedTypeBindings) for (auto& [name, tf] : *exportedTypeBindings)
@ -262,7 +230,6 @@ void Module::clonePublicInterface(NotNull<SingletonTypes> singletonTypes, Intern
{ {
auto t = asMutable(ty); auto t = asMutable(ty);
t->ty = AnyTypeVar{}; t->ty = AnyTypeVar{};
t->normal = true;
} }
} }
} }

View file

@ -16,11 +16,11 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
// This could theoretically be 2000 on amd64, but x86 requires this. // This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf);
namespace Luau namespace Luau
{ {
@ -1269,19 +1269,35 @@ std::optional<TypeId> Normalizer::intersectionOfFunctions(TypeId here, TypeId th
return std::nullopt; return std::nullopt;
if (hftv->genericPacks != tftv->genericPacks) if (hftv->genericPacks != tftv->genericPacks)
return std::nullopt; return std::nullopt;
if (hftv->retTypes != tftv->retTypes)
TypePackId argTypes;
TypePackId retTypes;
if (hftv->retTypes == tftv->retTypes)
{
std::optional<TypePackId> argTypesOpt = unionOfTypePacks(hftv->argTypes, tftv->argTypes);
if (!argTypesOpt)
return std::nullopt;
argTypes = *argTypesOpt;
retTypes = hftv->retTypes;
}
else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes)
{
std::optional<TypePackId> retTypesOpt = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes);
if (!retTypesOpt)
return std::nullopt;
argTypes = hftv->argTypes;
retTypes = *retTypesOpt;
}
else
return std::nullopt; return std::nullopt;
std::optional<TypePackId> argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); if (argTypes == hftv->argTypes && retTypes == hftv->retTypes)
if (!argTypes)
return std::nullopt;
if (*argTypes == hftv->argTypes)
return here; return here;
if (*argTypes == tftv->argTypes) if (argTypes == tftv->argTypes && retTypes == tftv->retTypes)
return there; return there;
FunctionTypeVar result{*argTypes, hftv->retTypes}; FunctionTypeVar result{argTypes, retTypes};
result.generics = hftv->generics; result.generics = hftv->generics;
result.genericPacks = hftv->genericPacks; result.genericPacks = hftv->genericPacks;
return arena->addType(std::move(result)); return arena->addType(std::move(result));
@ -1762,610 +1778,4 @@ bool isSubtype(
return ok; return ok;
} }
template<typename T>
static bool areNormal_(const T& t, const std::unordered_set<void*>& seen, InternalErrorReporter& ice)
{
int count = 0;
auto isNormal = [&](TypeId ty) {
++count;
if (count >= FInt::LuauNormalizeIterationLimit)
ice.ice("Luau::areNormal hit iteration limit");
return ty->normal;
};
return std::all_of(begin(t), end(t), isNormal);
}
static bool areNormal(const std::vector<TypeId>& types, const std::unordered_set<void*>& seen, InternalErrorReporter& ice)
{
return areNormal_(types, seen, ice);
}
static bool areNormal(TypePackId tp, const std::unordered_set<void*>& seen, InternalErrorReporter& ice)
{
tp = follow(tp);
if (get<FreeTypePack>(tp))
return false;
auto [head, tail] = flatten(tp);
if (!areNormal_(head, seen, ice))
return false;
if (!tail)
return true;
if (auto vtp = get<VariadicTypePack>(*tail))
return vtp->ty->normal || follow(vtp->ty)->normal || seen.find(asMutable(vtp->ty)) != seen.end();
return true;
}
#define CHECK_ITERATION_LIMIT(...) \
do \
{ \
if (iterationLimit > FInt::LuauNormalizeIterationLimit) \
{ \
limitExceeded = true; \
return __VA_ARGS__; \
} \
++iterationLimit; \
} while (false)
struct Normalize final : TypeVarVisitor
{
using TypeVarVisitor::Set;
Normalize(TypeArena& arena, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
: arena(arena)
, scope(scope)
, singletonTypes(singletonTypes)
, ice(ice)
{
}
TypeArena& arena;
NotNull<Scope> scope;
NotNull<SingletonTypes> singletonTypes;
InternalErrorReporter& ice;
int iterationLimit = 0;
bool limitExceeded = false;
bool visit(TypeId ty, const FreeTypeVar&) override
{
LUAU_ASSERT(!ty->normal);
return false;
}
bool visit(TypeId ty, const BoundTypeVar& btv) override
{
// A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses.
// So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack.
if (seen.find(asMutable(btv.boundTo)) != seen.end())
return false;
// It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases.
LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal);
if (!ty->normal)
asMutable(ty)->normal = btv.boundTo->normal;
return !ty->normal;
}
bool visit(TypeId ty, const PrimitiveTypeVar&) override
{
LUAU_ASSERT(ty->normal);
return false;
}
bool visit(TypeId ty, const GenericTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const ErrorTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const UnknownTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const NeverTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override
{
CHECK_ITERATION_LIMIT(false);
LUAU_ASSERT(!ty->normal);
ConstrainedTypeVar* ctv = const_cast<ConstrainedTypeVar*>(&ctvRef);
std::vector<TypeId> parts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId part : parts)
traverse(part);
std::vector<TypeId> newParts = normalizeUnion(parts);
ctv->parts = std::move(newParts);
return false;
}
bool visit(TypeId ty, const FunctionTypeVar& ftv) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
traverse(ftv.argTypes);
traverse(ftv.retTypes);
asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retTypes, seen, ice);
return false;
}
bool visit(TypeId ty, const TableTypeVar& ttv) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
bool normal = true;
auto checkNormal = [&](TypeId t) {
// if t is on the stack, it is possible that this type is normal.
// If t is not normal and it is not on the stack, this type is definitely not normal.
if (!t->normal && seen.find(asMutable(t)) == seen.end())
normal = false;
};
if (ttv.boundTo)
{
traverse(*ttv.boundTo);
asMutable(ty)->normal = (*ttv.boundTo)->normal;
return false;
}
for (const auto& [_name, prop] : ttv.props)
{
traverse(prop.type);
checkNormal(prop.type);
}
if (ttv.indexer)
{
traverse(ttv.indexer->indexType);
checkNormal(ttv.indexer->indexType);
traverse(ttv.indexer->indexResultType);
checkNormal(ttv.indexer->indexResultType);
}
// An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal.
if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal))
asMutable(ty)->normal = normal;
return false;
}
bool visit(TypeId ty, const MetatableTypeVar& mtv) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
traverse(mtv.table);
traverse(mtv.metatable);
asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal;
return false;
}
bool visit(TypeId ty, const ClassTypeVar& ctv) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const AnyTypeVar&) override
{
LUAU_ASSERT(ty->normal);
return false;
}
bool visit(TypeId ty, const UnionTypeVar& utvRef) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
UnionTypeVar* utv = &const_cast<UnionTypeVar&>(utvRef);
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId option : utv->options)
traverse(option);
std::vector<TypeId> newOptions = normalizeUnion(utv->options);
const bool normal = areNormal(newOptions, seen, ice);
LUAU_ASSERT(!newOptions.empty());
if (newOptions.size() == 1)
*asMutable(ty) = BoundTypeVar{newOptions[0]};
else
utv->options = std::move(newOptions);
asMutable(ty)->normal = normal;
return false;
}
bool visit(TypeId ty, const IntersectionTypeVar& itvRef) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
IntersectionTypeVar* itv = &const_cast<IntersectionTypeVar&>(itvRef);
std::vector<TypeId> oldParts = itv->parts;
IntersectionTypeVar newIntersection;
for (TypeId part : oldParts)
traverse(part);
std::vector<TypeId> tables;
for (TypeId part : oldParts)
{
part = follow(part);
if (get<TableTypeVar>(part))
tables.push_back(part);
else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, &newIntersection, part);
}
}
// Don't allocate a new table if there's just one in the intersection.
if (tables.size() == 1)
newIntersection.parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level});
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable);
for (TypeId part : tables)
{
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need
// to be rewritten to point at 'newTable' in the clone.
Replacer replacer{&arena, part, newTable};
combineIntoTable(replacer, ttv, part);
}
newIntersection.parts.push_back(newTable);
}
itv->parts = std::move(newIntersection.parts);
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
{
TypeId part = itv->parts[0];
*asMutable(ty) = BoundTypeVar{part};
}
return false;
}
std::vector<TypeId> normalizeUnion(const std::vector<TypeId>& options)
{
if (options.size() == 1)
return options;
std::vector<TypeId> result;
for (TypeId part : options)
{
// AnyTypeVar always win the battle no matter what we do, so we're done.
if (FFlag::LuauUnknownAndNeverType && get<AnyTypeVar>(follow(part)))
return {part};
combineIntoUnion(result, part);
}
return result;
}
void combineIntoUnion(std::vector<TypeId>& result, TypeId ty)
{
ty = follow(ty);
if (auto utv = get<UnionTypeVar>(ty))
{
for (TypeId t : utv)
{
// AnyTypeVar always win the battle no matter what we do, so we're done.
if (FFlag::LuauUnknownAndNeverType && get<AnyTypeVar>(t))
{
result = {t};
return;
}
combineIntoUnion(result, t);
}
return;
}
for (TypeId& part : result)
{
if (isSubtype(ty, part, scope, singletonTypes, ice))
return; // no need to do anything
else if (isSubtype(part, ty, scope, singletonTypes, ice))
{
part = ty; // replace the less general type by the more general one
return;
}
}
result.push_back(ty);
}
/**
* @param replacer knows how to clone a type such that any recursive references point at the new containing type.
* @param result is an intersection that is safe for us to mutate in-place.
*/
void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty)
{
// Note: this check guards against running out of stack space
// so if you increase the size of a stack frame, you'll need to decrease the limit.
CHECK_ITERATION_LIMIT();
ty = follow(ty);
if (auto itv = get<IntersectionTypeVar>(ty))
{
for (TypeId part : itv->parts)
combineIntoIntersection(replacer, result, part);
return;
}
// Let's say that the last part of our result intersection is always a table, if any table is part of this intersection
if (get<TableTypeVar>(ty))
{
if (result->parts.empty())
result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}}));
TypeId theTable = result->parts.back();
if (!get<TableTypeVar>(follow(theTable)))
{
result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}}));
theTable = result->parts.back();
}
TypeId newTable = replacer.smartClone(theTable);
result->parts.back() = newTable;
combineIntoTable(replacer, getMutable<TableTypeVar>(newTable), ty);
}
else if (auto ftv = get<FunctionTypeVar>(ty))
{
bool merged = false;
for (TypeId& part : result->parts)
{
if (isSubtype(part, ty, scope, singletonTypes, ice))
{
merged = true;
break; // no need to do anything
}
else if (isSubtype(ty, part, scope, singletonTypes, ice))
{
merged = true;
part = ty; // replace the less general type by the more general one
break;
}
}
if (!merged)
result->parts.push_back(ty);
}
else
result->parts.push_back(ty);
}
TableState combineTableStates(TableState lhs, TableState rhs)
{
if (lhs == rhs)
return lhs;
if (lhs == TableState::Free || rhs == TableState::Free)
return TableState::Free;
if (lhs == TableState::Unsealed || rhs == TableState::Unsealed)
return TableState::Unsealed;
return lhs;
}
/**
* @param replacer gives us a way to clone a type such that recursive references are rewritten to the new
* "containing" type.
* @param table always points into a table that is safe for us to mutate.
*/
void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty)
{
// Note: this check guards against running out of stack space
// so if you increase the size of a stack frame, you'll need to decrease the limit.
CHECK_ITERATION_LIMIT();
LUAU_ASSERT(table);
ty = follow(ty);
TableTypeVar* tyTable = getMutable<TableTypeVar>(ty);
LUAU_ASSERT(tyTable);
for (const auto& [propName, prop] : tyTable->props)
{
if (auto it = table->props.find(propName); it != table->props.end())
{
/**
* If we are going to recursively merge intersections of tables, we need to ensure that we never mutate
* a table that comes from somewhere else in the type graph.
*
* smarClone() does some nice things for us: It will perform a clone that is as shallow as possible
* while still rewriting any cyclic references back to the new 'root' table.
*
* replacer also keeps a mapping of types that have previously been copied, so we have the added
* advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is
* safe for us to mutate in-place.
*/
TypeId clone = replacer.smartClone(it->second.type);
it->second.type = combine(replacer, clone, prop.type);
}
else
table->props.insert({propName, prop});
}
if (tyTable->indexer)
{
if (table->indexer)
{
table->indexer->indexType = combine(replacer, replacer.smartClone(tyTable->indexer->indexType), table->indexer->indexType);
table->indexer->indexResultType =
combine(replacer, replacer.smartClone(tyTable->indexer->indexResultType), table->indexer->indexResultType);
}
else
{
table->indexer =
TableIndexer{replacer.smartClone(tyTable->indexer->indexType), replacer.smartClone(tyTable->indexer->indexResultType)};
}
}
table->state = combineTableStates(table->state, tyTable->state);
table->level = max(table->level, tyTable->level);
}
/**
* @param a is always cloned by the caller. It is safe to mutate in-place.
* @param b will never be mutated.
*/
TypeId combine(Replacer& replacer, TypeId a, TypeId b)
{
b = follow(b);
if (FFlag::LuauNormalizeCombineTableFix && a == b)
return a;
if (!get<IntersectionTypeVar>(a) && !get<TableTypeVar>(a))
{
if (!FFlag::LuauNormalizeCombineTableFix && a == b)
return a;
else
return arena.addType(IntersectionTypeVar{{a, b}});
}
if (auto itv = getMutable<IntersectionTypeVar>(a))
{
combineIntoIntersection(replacer, itv, b);
return a;
}
else if (auto ttv = getMutable<TableTypeVar>(a))
{
if (FFlag::LuauNormalizeCombineTableFix && !get<TableTypeVar>(b))
return arena.addType(IntersectionTypeVar{{a, b}});
combineIntoTable(replacer, ttv, b);
return a;
}
LUAU_ASSERT(!"Impossible");
LUAU_UNREACHABLE();
}
};
#undef CHECK_ITERATION_LIMIT
/**
* @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully)
*/
std::pair<TypeId, bool> normalize(
TypeId ty, NotNull<Scope> scope, TypeArena& arena, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{
CloneState state;
if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(ty, arena, state);
Normalize n{arena, scope, singletonTypes, ice};
n.traverse(ty);
return {ty, !n.limitExceeded};
}
// TODO: Think about using a temporary arena and cloning types out of it so that we
// reclaim memory used by wantonly allocated intermediate types here.
// The main wrinkle here is that we don't want clone() to copy a type if the source and dest
// arena are the same.
std::pair<TypeId, bool> normalize(TypeId ty, NotNull<Module> module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{
return normalize(ty, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice);
}
std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{
return normalize(ty, NotNull{module.get()}, singletonTypes, ice);
}
/**
* @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully)
*/
std::pair<TypePackId, bool> normalize(
TypePackId tp, NotNull<Scope> scope, TypeArena& arena, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{
CloneState state;
if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(tp, arena, state);
Normalize n{arena, scope, singletonTypes, ice};
n.traverse(tp);
return {tp, !n.limitExceeded};
}
std::pair<TypePackId, bool> normalize(TypePackId tp, NotNull<Module> module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{
return normalize(tp, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice);
}
std::pair<TypePackId, bool> normalize(TypePackId tp, const ModulePtr& module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{
return normalize(tp, NotNull{module.get()}, singletonTypes, ice);
}
} // namespace Luau } // namespace Luau

View file

@ -57,29 +57,6 @@ struct Quantifier final : TypeVarOnceVisitor
return false; return false;
} }
bool visit(TypeId ty, const ConstrainedTypeVar&) override
{
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty);
seenMutableType = true;
if (!level.subsumes(ctv->level))
return false;
std::vector<TypeId> opts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic
for (TypeId opt : opts)
traverse(opt);
if (opts.size() == 1)
*asMutable(ty) = BoundTypeVar{opts[0]};
else
*asMutable(ty) = UnionTypeVar{std::move(opts)};
return false;
}
bool visit(TypeId ty, const TableTypeVar&) override bool visit(TypeId ty, const TableTypeVar&) override
{ {
LUAU_ASSERT(getMutable<TableTypeVar>(ty)); LUAU_ASSERT(getMutable<TableTypeVar>(ty));

View file

@ -27,6 +27,44 @@ void Scope::addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun)
builtinTypeNames.insert(name); builtinTypeNames.insert(name);
} }
std::optional<TypeId> Scope::lookup(Symbol sym) const
{
auto r = const_cast<Scope*>(this)->lookupEx(sym);
if (r)
return r->first;
else
return std::nullopt;
}
std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(Symbol sym)
{
Scope* s = this;
while (true)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return std::pair{it->second.typeId, s};
if (s->parent)
s = s->parent.get();
else
return std::nullopt;
}
}
// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis.
std::optional<TypeId> Scope::lookup(DefId def) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
if (auto ty = current->dcrRefinements.find(def))
return *ty;
}
return std::nullopt;
}
std::optional<TypeFun> Scope::lookupType(const Name& name) std::optional<TypeFun> Scope::lookupType(const Name& name)
{ {
const Scope* scope = this; const Scope* scope = this;
@ -111,23 +149,6 @@ std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bo
return std::nullopt; return std::nullopt;
} }
std::optional<TypeId> Scope::lookup(Symbol sym)
{
Scope* s = this;
while (true)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return it->second.typeId;
if (s->parent)
s = s->parent.get();
else
return std::nullopt;
}
}
bool subsumesStrict(Scope* left, Scope* right) bool subsumesStrict(Scope* left, Scope* right)
{ {
while (right) while (right)

View file

@ -73,11 +73,6 @@ void Tarjan::visitChildren(TypeId ty, int index)
for (TypeId part : itv->parts) for (TypeId part : itv->parts)
visitChild(part); visitChild(part);
} }
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
for (TypeId part : ctv->parts)
visitChild(part);
}
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty)) else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{ {
for (TypeId a : petv->typeArguments) for (TypeId a : petv->typeArguments)
@ -97,6 +92,10 @@ void Tarjan::visitChildren(TypeId ty, int index)
if (ctv->metatable) if (ctv->metatable)
visitChild(*ctv->metatable); visitChild(*ctv->metatable);
} }
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(ty))
{
visitChild(ntv->ty);
}
} }
void Tarjan::visitChildren(TypePackId tp, int index) void Tarjan::visitChildren(TypePackId tp, int index)
@ -605,11 +604,6 @@ void Substitution::replaceChildren(TypeId ty)
for (TypeId& part : itv->parts) for (TypeId& part : itv->parts)
part = replace(part); part = replace(part);
} }
else if (ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty))
{
for (TypeId& part : ctv->parts)
part = replace(part);
}
else if (PendingExpansionTypeVar* petv = getMutable<PendingExpansionTypeVar>(ty)) else if (PendingExpansionTypeVar* petv = getMutable<PendingExpansionTypeVar>(ty))
{ {
for (TypeId& a : petv->typeArguments) for (TypeId& a : petv->typeArguments)
@ -629,6 +623,10 @@ void Substitution::replaceChildren(TypeId ty)
if (ctv->metatable) if (ctv->metatable)
ctv->metatable = replace(*ctv->metatable); ctv->metatable = replace(*ctv->metatable);
} }
else if (NegationTypeVar* ntv = getMutable<NegationTypeVar>(ty))
{
ntv->ty = replace(ntv->ty);
}
} }
void Substitution::replaceChildren(TypePackId tp) void Substitution::replaceChildren(TypePackId tp)

View file

@ -237,15 +237,6 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
formatAppend(result, "ConstrainedTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
for (TypeId part : ctv->parts)
visitChild(part, index);
}
else if (get<ErrorTypeVar>(ty)) else if (get<ErrorTypeVar>(ty))
{ {
formatAppend(result, "ErrorTypeVar %d", index); formatAppend(result, "ErrorTypeVar %d", index);

View file

@ -400,29 +400,6 @@ struct TypeVarStringifier
state.emit(state.getName(ty)); state.emit(state.getName(ty));
} }
void operator()(TypeId, const ConstrainedTypeVar& ctv)
{
state.result.invalid = true;
state.emit("[");
if (FFlag::DebugLuauVerboseTypeNames)
state.emit(ctv.level);
state.emit("[");
bool first = true;
for (TypeId ty : ctv.parts)
{
if (first)
first = false;
else
state.emit("|");
stringify(ty);
}
state.emit("]]");
}
void operator()(TypeId, const BlockedTypeVar& btv) void operator()(TypeId, const BlockedTypeVar& btv)
{ {
state.emit("*blocked-"); state.emit("*blocked-");
@ -871,6 +848,28 @@ struct TypeVarStringifier
{ {
state.emit("never"); state.emit("never");
} }
void operator()(TypeId ty, const UseTypeVar&)
{
stringify(follow(ty));
}
void operator()(TypeId, const NegationTypeVar& ntv)
{
state.emit("~");
// The precedence of `~` should be less than `|` and `&`.
TypeId followed = follow(ntv.ty);
bool parens = get<UnionTypeVar>(followed) || get<IntersectionTypeVar>(followed);
if (parens)
state.emit("(");
stringify(ntv.ty);
if (parens)
state.emit(")");
}
}; };
struct TypePackStringifier struct TypePackStringifier
@ -1442,7 +1441,7 @@ std::string generateName(size_t i)
std::string toString(const Constraint& constraint, ToStringOptions& opts) std::string toString(const Constraint& constraint, ToStringOptions& opts)
{ {
auto go = [&opts](auto&& c) { auto go = [&opts](auto&& c) -> std::string {
using T = std::decay_t<decltype(c)>; using T = std::decay_t<decltype(c)>;
// TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps
@ -1526,6 +1525,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
{ {
return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\"";
} }
else if constexpr (std::is_same_v<T, RefinementConstraint>)
{
return "TODO";
}
else else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch"); static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
}; };

View file

@ -251,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional<TypeId> newBoundTo)
PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel)
{ {
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty) || get<ConstrainedTypeVar>(ty)); LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty));
PendingType* newTy = queue(ty); PendingType* newTy = queue(ty);
if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy)) if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy))
@ -267,11 +267,6 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel)
{ {
ftv->level = newLevel; ftv->level = newLevel;
} }
else if (ConstrainedTypeVar* ctv = Luau::getMutable<ConstrainedTypeVar>(newTy))
{
if (FFlag::LuauUnknownAndNeverType)
ctv->level = newLevel;
}
return newTy; return newTy;
} }
@ -291,7 +286,7 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel)
PendingType* TxnLog::changeScope(TypeId ty, NotNull<Scope> newScope) PendingType* TxnLog::changeScope(TypeId ty, NotNull<Scope> newScope)
{ {
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty) || get<ConstrainedTypeVar>(ty)); LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty));
PendingType* newTy = queue(ty); PendingType* newTy = queue(ty);
if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy)) if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy))
@ -307,10 +302,6 @@ PendingType* TxnLog::changeScope(TypeId ty, NotNull<Scope> newScope)
{ {
ftv->scope = newScope; ftv->scope = newScope;
} }
else if (ConstrainedTypeVar* ctv = Luau::getMutable<ConstrainedTypeVar>(newTy))
{
ctv->scope = newScope;
}
return newTy; return newTy;
} }

View file

@ -104,16 +104,6 @@ public:
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*pending-expansion*")); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*pending-expansion*"));
} }
AstType* operator()(const ConstrainedTypeVar& ctv)
{
AstArray<AstType*> types;
types.size = ctv.parts.size();
types.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * ctv.parts.size()));
for (size_t i = 0; i < ctv.parts.size(); ++i)
types.data[i] = Luau::visit(*this, ctv.parts[i]->ty);
return allocator->alloc<AstTypeIntersection>(Location(), types);
}
AstType* operator()(const SingletonTypeVar& stv) AstType* operator()(const SingletonTypeVar& stv)
{ {
if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv)) if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv))
@ -348,6 +338,17 @@ public:
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"}); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"});
} }
AstType* operator()(const UseTypeVar& utv)
{
std::optional<TypeId> ty = utv.scope->lookup(utv.def);
LUAU_ASSERT(ty);
return Luau::visit(*this, (*ty)->ty);
}
AstType* operator()(const NegationTypeVar& ntv)
{
// FIXME: do the same thing we do with ErrorTypeVar
throw std::runtime_error("Cannot convert NegationTypeVar into AstNode");
}
private: private:
Allocator* allocator; Allocator* allocator;

View file

@ -5,6 +5,7 @@
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/Instantiation.h" #include "Luau/Instantiation.h"
#include "Luau/Metamethods.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
@ -62,6 +63,23 @@ struct StackPusher
} }
}; };
static std::optional<std::string> getIdentifierOfBaseVar(AstExpr* node)
{
if (AstExprGlobal* expr = node->as<AstExprGlobal>())
return expr->name.value;
if (AstExprLocal* expr = node->as<AstExprLocal>())
return expr->local->name.value;
if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return getIdentifierOfBaseVar(expr->expr);
if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return getIdentifierOfBaseVar(expr->expr);
return std::nullopt;
}
struct TypeChecker2 struct TypeChecker2
{ {
NotNull<SingletonTypes> singletonTypes; NotNull<SingletonTypes> singletonTypes;
@ -750,7 +768,7 @@ struct TypeChecker2
TypeId actualType = lookupType(string); TypeId actualType = lookupType(string);
TypeId stringType = singletonTypes->stringType; TypeId stringType = singletonTypes->stringType;
if (!isSubtype(stringType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
{ {
reportError(TypeMismatch{actualType, stringType}, string->location); reportError(TypeMismatch{actualType, stringType}, string->location);
} }
@ -783,26 +801,55 @@ struct TypeChecker2
TypePackId expectedRetType = lookupPack(call); TypePackId expectedRetType = lookupPack(call);
TypeId functionType = lookupType(call->func); TypeId functionType = lookupType(call->func);
LUAU_ASSERT(functionType); TypeId testFunctionType = functionType;
TypePack args;
if (get<AnyTypeVar>(functionType) || get<ErrorTypeVar>(functionType)) if (get<AnyTypeVar>(functionType) || get<ErrorTypeVar>(functionType))
return; return;
else if (std::optional<TypeId> callMm = findMetatableEntry(singletonTypes, module->errors, functionType, "__call", call->func->location))
// TODO: Lots of other types are callable: intersections of functions {
// and things with the __call metamethod. if (get<FunctionTypeVar>(follow(*callMm)))
if (!get<FunctionTypeVar>(functionType)) {
if (std::optional<TypeId> instantiatedCallMm = instantiation.substitute(*callMm))
{
args.head.push_back(functionType);
testFunctionType = follow(*instantiatedCallMm);
}
else
{
reportError(UnificationTooComplex{}, call->func->location);
return;
}
}
else
{
// TODO: This doesn't flag the __call metamethod as the problem
// very clearly.
reportError(CannotCallNonFunction{*callMm}, call->func->location);
return;
}
}
else if (get<FunctionTypeVar>(functionType))
{
if (std::optional<TypeId> instantiatedFunctionType = instantiation.substitute(functionType))
{
testFunctionType = *instantiatedFunctionType;
}
else
{
reportError(UnificationTooComplex{}, call->func->location);
return;
}
}
else
{ {
reportError(CannotCallNonFunction{functionType}, call->func->location); reportError(CannotCallNonFunction{functionType}, call->func->location);
return; return;
} }
TypeId instantiatedFunctionType = follow(instantiation.substitute(functionType).value_or(nullptr));
TypePack args;
for (AstExpr* arg : call->args) for (AstExpr* arg : call->args)
{ {
TypeId argTy = module->astTypes[arg]; TypeId argTy = lookupType(arg);
LUAU_ASSERT(argTy);
args.head.push_back(argTy); args.head.push_back(argTy);
} }
@ -810,7 +857,7 @@ struct TypeChecker2
FunctionTypeVar ftv{argsTp, expectedRetType}; FunctionTypeVar ftv{argsTp, expectedRetType};
TypeId expectedType = arena.addType(ftv); TypeId expectedType = arena.addType(ftv);
if (!isSubtype(instantiatedFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
{ {
CloneState cloneState; CloneState cloneState;
expectedType = clone(expectedType, module->internalTypes, cloneState); expectedType = clone(expectedType, module->internalTypes, cloneState);
@ -893,9 +940,204 @@ struct TypeChecker2
void visit(AstExprBinary* expr) void visit(AstExprBinary* expr)
{ {
// TODO!
visit(expr->left); visit(expr->left);
visit(expr->right); visit(expr->right);
NotNull<Scope> scope = stack.back();
bool isEquality = expr->op == AstExprBinary::Op::CompareEq || expr->op == AstExprBinary::Op::CompareNe;
bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe;
bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or;
TypeId leftType = lookupType(expr->left);
TypeId rightType = lookupType(expr->right);
if (expr->op == AstExprBinary::Op::Or)
{
leftType = stripNil(singletonTypes, module->internalTypes, leftType);
}
bool isStringOperation = isString(leftType) && isString(rightType);
if (get<AnyTypeVar>(leftType) || get<ErrorTypeVar>(leftType) || get<AnyTypeVar>(rightType) || get<ErrorTypeVar>(rightType))
return;
if ((get<BlockedTypeVar>(leftType) || get<FreeTypeVar>(leftType)) && !isEquality && !isLogical)
{
auto name = getIdentifierOfBaseVar(expr->left);
reportError(CannotInferBinaryOperation{expr->op, name,
isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation},
expr->location);
return;
}
if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end())
{
std::optional<TypeId> leftMt = getMetatable(leftType, singletonTypes);
std::optional<TypeId> rightMt = getMetatable(rightType, singletonTypes);
bool matches = leftMt == rightMt;
if (isEquality && !matches)
{
auto testUnion = [&matches, singletonTypes = this->singletonTypes](const UnionTypeVar* utv, std::optional<TypeId> otherMt) {
for (TypeId option : utv)
{
if (getMetatable(follow(option), singletonTypes) == otherMt)
{
matches = true;
break;
}
}
};
if (const UnionTypeVar* utv = get<UnionTypeVar>(leftType); utv && rightMt)
{
testUnion(utv, rightMt);
}
if (const UnionTypeVar* utv = get<UnionTypeVar>(rightType); utv && leftMt && !matches)
{
testUnion(utv, leftMt);
}
}
if (!matches && isComparison)
{
reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
return;
}
std::optional<TypeId> mm;
if (std::optional<TypeId> leftMm = findMetatableEntry(singletonTypes, module->errors, leftType, it->second, expr->left->location))
mm = leftMm;
else if (std::optional<TypeId> rightMm = findMetatableEntry(singletonTypes, module->errors, rightType, it->second, expr->right->location))
mm = rightMm;
if (mm)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(*mm))
{
TypePackId expectedArgs;
// For >= and > we invoke __lt and __le respectively with
// swapped argument ordering.
if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt)
{
expectedArgs = module->internalTypes.addTypePack({rightType, leftType});
}
else
{
expectedArgs = module->internalTypes.addTypePack({leftType, rightType});
}
reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs));
if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe ||
expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt)
{
TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType});
if (!isSubtype(ftv->retTypes, expectedRets, scope, singletonTypes, ice))
{
reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location);
}
}
else if (!first(ftv->retTypes))
{
reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location);
}
}
else
{
reportError(CannotCallNonFunction{*mm}, expr->location);
}
return;
}
// If this is a string comparison, or a concatenation of strings, we
// want to fall through to primitive behavior.
else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison)))
{
if (leftMt || rightMt)
{
if (isComparison)
{
reportError(GenericError{format(
"Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)},
expr->location);
}
else
{
reportError(GenericError{format(
"Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod",
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)},
expr->location);
}
return;
}
else if (!leftMt && !rightMt && (get<TableTypeVar>(leftType) || get<TableTypeVar>(rightType)))
{
if (isComparison)
{
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
}
else
{
reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable",
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())},
expr->location);
}
return;
}
}
}
switch (expr->op)
{
case AstExprBinary::Op::Add:
case AstExprBinary::Op::Sub:
case AstExprBinary::Op::Mul:
case AstExprBinary::Op::Div:
case AstExprBinary::Op::Pow:
case AstExprBinary::Op::Mod:
reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->numberType));
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType));
break;
case AstExprBinary::Op::Concat:
reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->stringType));
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType));
break;
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
if (isNumber(leftType))
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType));
else if (isString(leftType))
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType));
else
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(),
toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
break;
case AstExprBinary::Op::And:
case AstExprBinary::Op::Or:
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
break;
default:
// Unhandled AstExprBinary::Op possibility.
LUAU_ASSERT(false);
}
} }
void visit(AstExprTypeAssertion* expr) void visit(AstExprTypeAssertion* expr)

View file

@ -31,7 +31,6 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000)
LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300)
LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAG(LuauTypeNormalization2)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false.
@ -280,11 +279,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
iceHandler->moduleName = module.name; iceHandler->moduleName = module.name;
normalizer.arena = &currentModule->internalTypes; normalizer.arena = &currentModule->internalTypes;
if (FFlag::LuauAutocompleteDynamicLimits) unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
{ unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit;
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit;
}
ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr parentScope = environmentScope.value_or(globalScope);
ScopePtr moduleScope = std::make_shared<Scope>(parentScope); ScopePtr moduleScope = std::make_shared<Scope>(parentScope);
@ -773,16 +769,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement)
checkExpr(repScope, *statement.condition); checkExpr(repScope, *statement.condition);
} }
void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location)
{
Unifier state = mkUnifier(scope, location);
state.unifyLowerBound(subTy, superTy, demotedLevel);
state.log.commit();
reportErrors(state.errors);
}
struct Demoter : Substitution struct Demoter : Substitution
{ {
Demoter(TypeArena* arena) Demoter(TypeArena* arena)
@ -2091,39 +2077,6 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromTypeImpl(
return std::nullopt; return std::nullopt;
} }
std::vector<TypeId> TypeChecker::reduceUnion(const std::vector<TypeId>& types)
{
std::vector<TypeId> result;
for (TypeId t : types)
{
t = follow(t);
if (get<NeverTypeVar>(t))
continue;
if (get<ErrorTypeVar>(t) || get<AnyTypeVar>(t))
return {t};
if (const UnionTypeVar* utv = get<UnionTypeVar>(t))
{
for (TypeId ty : utv)
{
ty = follow(ty);
if (get<NeverTypeVar>(ty))
continue;
if (get<ErrorTypeVar>(ty) || get<AnyTypeVar>(ty))
return {ty};
if (result.end() == std::find(result.begin(), result.end(), ty))
result.push_back(ty);
}
}
else if (std::find(result.begin(), result.end(), t) == result.end())
result.push_back(t);
}
return result;
}
std::optional<TypeId> TypeChecker::tryStripUnionFromNil(TypeId ty) std::optional<TypeId> TypeChecker::tryStripUnionFromNil(TypeId ty)
{ {
if (const UnionTypeVar* utv = get<UnionTypeVar>(ty)) if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
@ -4597,7 +4550,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat
Instantiation instantiation{log, &currentModule->internalTypes, scope->level, /*scope*/ nullptr}; Instantiation instantiation{log, &currentModule->internalTypes, scope->level, /*scope*/ nullptr};
if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) if (instantiationChildLimit)
instantiation.childLimit = *instantiationChildLimit; instantiation.childLimit = *instantiationChildLimit;
std::optional<TypeId> instantiated = instantiation.substitute(ty); std::optional<TypeId> instantiated = instantiation.substitute(ty);

View file

@ -6,6 +6,8 @@
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include <algorithm>
namespace Luau namespace Luau
{ {
@ -146,18 +148,15 @@ std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro
return std::nullopt; return std::nullopt;
} }
goodOptions = reduceUnion(goodOptions);
if (goodOptions.empty()) if (goodOptions.empty())
return singletonTypes->neverType; return singletonTypes->neverType;
if (goodOptions.size() == 1) if (goodOptions.size() == 1)
return goodOptions[0]; return goodOptions[0];
// TODO: inefficient. return arena->addType(UnionTypeVar{std::move(goodOptions)});
TypeId result = arena->addType(UnionTypeVar{std::move(goodOptions)});
auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, singletonTypes, handle);
if (!ok && addErrors)
errors.push_back(TypeError{location, NormalizationTooComplex{}});
return ok ? ty : singletonTypes->anyType;
} }
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type)) else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type))
{ {
@ -264,4 +263,79 @@ std::vector<TypeId> flatten(TypeArena& arena, NotNull<SingletonTypes> singletonT
return result; return result;
} }
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types)
{
std::vector<TypeId> result;
for (TypeId t : types)
{
t = follow(t);
if (get<NeverTypeVar>(t))
continue;
if (get<ErrorTypeVar>(t) || get<AnyTypeVar>(t))
return {t};
if (const UnionTypeVar* utv = get<UnionTypeVar>(t))
{
for (TypeId ty : utv)
{
ty = follow(ty);
if (get<NeverTypeVar>(ty))
continue;
if (get<ErrorTypeVar>(ty) || get<AnyTypeVar>(ty))
return {ty};
if (result.end() == std::find(result.begin(), result.end(), ty))
result.push_back(ty);
}
}
else if (std::find(result.begin(), result.end(), t) == result.end())
result.push_back(t);
}
return result;
}
static std::optional<TypeId> tryStripUnionFromNil(TypeArena& arena, TypeId ty)
{
if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
{
if (!std::any_of(begin(utv), end(utv), isNil))
return ty;
std::vector<TypeId> result;
for (TypeId option : utv)
{
if (!isNil(option))
result.push_back(option);
}
if (result.empty())
return std::nullopt;
return result.size() == 1 ? result[0] : arena.addType(UnionTypeVar{std::move(result)});
}
return std::nullopt;
}
TypeId stripNil(NotNull<SingletonTypes> singletonTypes, TypeArena& arena, TypeId ty)
{
ty = follow(ty);
if (get<UnionTypeVar>(ty))
{
std::optional<TypeId> cleaned = tryStripUnionFromNil(arena, ty);
// If there is no union option without 'nil'
if (!cleaned)
return singletonTypes->nilType;
return follow(*cleaned);
}
return follow(ty);
}
} // namespace Luau } // namespace Luau

View file

@ -57,6 +57,13 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
return btv->boundTo; return btv->boundTo;
else if (auto ttv = get<TableTypeVar>(mapper(ty))) else if (auto ttv = get<TableTypeVar>(mapper(ty)))
return ttv->boundTo; return ttv->boundTo;
else if (auto utv = get<UseTypeVar>(mapper(ty)))
{
std::optional<TypeId> ty = utv->scope->lookup(utv->def);
if (!ty)
throw std::runtime_error("UseTypeVar must map to another TypeId");
return *ty;
}
else else
return std::nullopt; return std::nullopt;
}; };
@ -760,6 +767,8 @@ SingletonTypes::SingletonTypes()
, unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true})) , unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true}))
, neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true})) , neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true}))
, errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true})) , errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true}))
, falsyType(arena->addType(TypeVar{UnionTypeVar{{falseType, nilType}}, /*persistent*/ true}))
, truthyType(arena->addType(TypeVar{NegationTypeVar{falsyType}, /*persistent*/ true}))
, anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true}))
, neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true}))
, uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack)) , uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack))
@ -896,7 +905,6 @@ void persist(TypeId ty)
continue; continue;
asMutable(t)->persistent = true; asMutable(t)->persistent = true;
asMutable(t)->normal = true; // all persistent types are assumed to be normal
if (auto btv = get<BoundTypeVar>(t)) if (auto btv = get<BoundTypeVar>(t))
queue.push_back(btv->boundTo); queue.push_back(btv->boundTo);
@ -933,11 +941,6 @@ void persist(TypeId ty)
for (TypeId opt : itv->parts) for (TypeId opt : itv->parts)
queue.push_back(opt); queue.push_back(opt);
} }
else if (auto ctv = get<ConstrainedTypeVar>(t))
{
for (TypeId opt : ctv->parts)
queue.push_back(opt);
}
else if (auto mtv = get<MetatableTypeVar>(t)) else if (auto mtv = get<MetatableTypeVar>(t))
{ {
queue.push_back(mtv->table); queue.push_back(mtv->table);
@ -990,8 +993,6 @@ const TypeLevel* getLevel(TypeId ty)
return &ttv->level; return &ttv->level;
else if (auto ftv = get<FunctionTypeVar>(ty)) else if (auto ftv = get<FunctionTypeVar>(ty))
return &ftv->level; return &ftv->level;
else if (auto ctv = get<ConstrainedTypeVar>(ty))
return &ctv->level;
else else
return nullptr; return nullptr;
} }
@ -1056,11 +1057,6 @@ const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv)
return itv->parts; return itv->parts;
} }
const std::vector<TypeId>& getTypes(const ConstrainedTypeVar* ctv)
{
return ctv->parts;
}
UnionTypeVarIterator begin(const UnionTypeVar* utv) UnionTypeVarIterator begin(const UnionTypeVar* utv)
{ {
return UnionTypeVarIterator{utv}; return UnionTypeVarIterator{utv};
@ -1081,17 +1077,6 @@ IntersectionTypeVarIterator end(const IntersectionTypeVar* itv)
return IntersectionTypeVarIterator{}; return IntersectionTypeVarIterator{};
} }
ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv)
{
return ConstrainedTypeVarIterator{ctv};
}
ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv)
{
return ConstrainedTypeVarIterator{};
}
static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const char* data, size_t size) static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const char* data, size_t size)
{ {
const char* options = "cdiouxXeEfgGqs*"; const char* options = "cdiouxXeEfgGqs*";

View file

@ -13,16 +13,13 @@
#include <algorithm> #include <algorithm>
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000);
LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false);
LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false);
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
@ -95,15 +92,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor
return true; return true;
} }
bool visit(TypeId ty, const ConstrainedTypeVar&) override
{
if (!FFlag::LuauUnknownAndNeverType)
return visit(ty);
promote(ty, log.getMutable<ConstrainedTypeVar>(ty));
return true;
}
bool visit(TypeId ty, const FunctionTypeVar&) override bool visit(TypeId ty, const FunctionTypeVar&) override
{ {
// Type levels of types from other modules are already global, so we don't need to promote anything inside // Type levels of types from other modules are already global, so we don't need to promote anything inside
@ -368,26 +356,14 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i
void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection)
{ {
RecursionLimiter _ra(&sharedState.counters.recursionCount, RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
++sharedState.counters.iterationCount; ++sharedState.counters.iterationCount;
if (FFlag::LuauAutocompleteDynamicLimits) if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{ {
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) reportError(TypeError{location, UnificationTooComplex{}});
{ return;
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
else
{
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
} }
superTy = log.follow(superTy); superTy = log.follow(superTy);
@ -396,9 +372,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (superTy == subTy) if (superTy == subTy)
return; return;
if (log.get<ConstrainedTypeVar>(superTy))
return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy);
auto superFree = log.getMutable<FreeTypeVar>(superTy); auto superFree = log.getMutable<FreeTypeVar>(superTy);
auto subFree = log.getMutable<FreeTypeVar>(subTy); auto subFree = log.getMutable<FreeTypeVar>(subTy);
@ -520,9 +493,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
size_t errorCount = errors.size(); size_t errorCount = errors.size();
if (log.get<ConstrainedTypeVar>(subTy)) if (const UnionTypeVar* subUnion = log.getMutable<UnionTypeVar>(subTy))
tryUnifyWithConstrainedSubTypeVar(subTy, superTy);
else if (const UnionTypeVar* subUnion = log.getMutable<UnionTypeVar>(subTy))
{ {
tryUnifyUnionWithType(subTy, subUnion, superTy); tryUnifyUnionWithType(subTy, subUnion, superTy);
} }
@ -1011,10 +982,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized
log.concat(std::move(innerState.log)); log.concat(std::move(innerState.log));
if (result) if (result)
{ {
if (FFlag::LuauOverloadedFunctionSubtypingPerf)
{
innerState.log.clear();
innerState.tryUnify_(*result, ftv->retTypes);
}
if (FFlag::LuauOverloadedFunctionSubtypingPerf && innerState.errors.empty())
log.concat(std::move(innerState.log));
// Annoyingly, since we don't support intersection of generic type packs, // Annoyingly, since we don't support intersection of generic type packs,
// the intersection may fail. We rather arbitrarily use the first matching overload // the intersection may fail. We rather arbitrarily use the first matching overload
// in that case. // in that case.
if (std::optional<TypePackId> intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) else if (std::optional<TypePackId> intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes))
result = intersect; result = intersect;
} }
else else
@ -1214,26 +1192,14 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall
*/ */
void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall)
{ {
RecursionLimiter _ra(&sharedState.counters.recursionCount, RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
++sharedState.counters.iterationCount; ++sharedState.counters.iterationCount;
if (FFlag::LuauAutocompleteDynamicLimits) if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{ {
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) reportError(TypeError{location, UnificationTooComplex{}});
{ return;
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
else
{
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
} }
superTp = log.follow(superTp); superTp = log.follow(superTp);
@ -2314,186 +2280,6 @@ std::optional<TypeId> Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N
return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location);
} }
void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy)
{
const ConstrainedTypeVar* subConstrained = get<ConstrainedTypeVar>(subTy);
if (!subConstrained)
ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!");
const std::vector<TypeId>& subTyParts = subConstrained->parts;
// A | B <: T if A <: T and B <: T
bool failed = false;
std::optional<TypeError> unificationTooComplex;
const size_t count = subTyParts.size();
for (size_t i = 0; i < count; ++i)
{
TypeId type = subTyParts[i];
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(type, superTy);
if (i == count - 1)
log.concat(std::move(innerState.log));
++i;
if (auto e = hasUnificationTooComplex(innerState.errors))
unificationTooComplex = e;
if (!innerState.errors.empty())
{
failed = true;
break;
}
}
if (unificationTooComplex)
reportError(*unificationTooComplex);
else if (failed)
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
else
log.replace(subTy, BoundTypeVar{superTy});
}
void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy)
{
ConstrainedTypeVar* superC = log.getMutable<ConstrainedTypeVar>(superTy);
if (!superC)
ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!");
// subTy could be a
// table
// metatable
// class
// function
// primitive
// free
// generic
// intersection
// union
// Do we really just tack it on? I think we might!
// We can certainly do some deduplication.
// Is there any point to deducing Player|Instance when we could just reduce to Instance?
// Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type?
// Maybe we do a simplification step during quantification.
auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy);
if (it != superC->parts.end())
return;
superC->parts.push_back(subTy);
}
void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel)
{
// The duplication between this and regular typepack unification is tragic.
auto superIter = begin(superTy, &log);
auto superEndIter = end(superTy);
auto subIter = begin(subTy, &log);
auto subEndIter = end(subTy);
int count = FInt::LuauTypeInferLowerBoundsIterationLimit;
for (; subIter != subEndIter; ++subIter)
{
if (0 >= --count)
ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound");
if (superIter != superEndIter)
{
tryUnify_(*subIter, *superIter);
++superIter;
continue;
}
if (auto t = superIter.tail())
{
TypePackId tailPack = follow(*t);
if (log.get<FreeTypePack>(tailPack) && occursCheck(tailPack, subTy))
return;
FreeTypePack* freeTailPack = log.getMutable<FreeTypePack>(tailPack);
if (!freeTailPack)
return;
TypePack* tp = getMutable<TypePack>(log.replace(tailPack, TypePack{}));
for (; subIter != subEndIter; ++subIter)
{
tp->head.push_back(types->addType(ConstrainedTypeVar{demotedLevel, {follow(*subIter)}}));
}
tp->tail = subIter.tail();
}
return;
}
if (superIter != superEndIter)
{
if (auto subTail = subIter.tail())
{
TypePackId subTailPack = follow(*subTail);
if (get<FreeTypePack>(subTailPack))
{
TypePack* tp = getMutable<TypePack>(log.replace(subTailPack, TypePack{}));
for (; superIter != superEndIter; ++superIter)
tp->head.push_back(*superIter);
}
else if (const VariadicTypePack* subVariadic = log.getMutable<VariadicTypePack>(subTailPack))
{
while (superIter != superEndIter)
{
tryUnify_(subVariadic->ty, *superIter);
++superIter;
}
}
}
else
{
while (superIter != superEndIter)
{
if (!isOptional(*superIter))
{
errors.push_back(TypeError{location, CountMismatch{size(superTy), std::nullopt, size(subTy), CountMismatch::Return}});
return;
}
++superIter;
}
}
return;
}
// Both iters are at their respective tails
auto subTail = subIter.tail();
auto superTail = superIter.tail();
if (subTail && superTail)
tryUnify(*subTail, *superTail);
else if (subTail)
{
const FreeTypePack* freeSubTail = log.getMutable<FreeTypePack>(*subTail);
if (freeSubTail)
{
log.replace(*subTail, TypePack{});
}
}
else if (superTail)
{
const FreeTypePack* freeSuperTail = log.getMutable<FreeTypePack>(*superTail);
if (freeSuperTail)
{
log.replace(*superTail, TypePack{});
}
}
}
bool Unifier::occursCheck(TypeId needle, TypeId haystack) bool Unifier::occursCheck(TypeId needle, TypeId haystack)
{ {
sharedState.tempSeenTy.clear(); sharedState.tempSeenTy.clear();
@ -2503,8 +2289,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack)
bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack) bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack)
{ {
RecursionLimiter _ra(&sharedState.counters.recursionCount, RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
bool occurrence = false; bool occurrence = false;
@ -2547,11 +2332,6 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
for (TypeId ty : a->parts) for (TypeId ty : a->parts)
check(ty); check(ty);
} }
else if (auto a = log.getMutable<ConstrainedTypeVar>(haystack))
{
for (TypeId ty : a->parts)
check(ty);
}
return occurrence; return occurrence;
} }
@ -2579,8 +2359,7 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
if (!log.getMutable<Unifiable::Free>(needle)) if (!log.getMutable<Unifiable::Free>(needle))
ice("Expected needle pack to be free"); ice("Expected needle pack to be free");
RecursionLimiter _ra(&sharedState.counters.recursionCount, RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
while (!log.getMutable<ErrorTypeVar>(haystack)) while (!log.getMutable<ErrorTypeVar>(haystack))
{ {

View file

@ -14,7 +14,6 @@
#endif #endif
LUAU_FASTFLAG(DebugLuauTimeTracing) LUAU_FASTFLAG(DebugLuauTimeTracing)
LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution)
enum class ReportFormat enum class ReportFormat
{ {
@ -55,11 +54,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con
if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data)) if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data))
report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str());
else if (FFlag::LuauTypeMismatchModuleNameResolution) else
report(format, humanReadableName.c_str(), error.location, "TypeError", report(format, humanReadableName.c_str(), error.location, "TypeError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str());
else
report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str());
} }
static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning)

View file

@ -16,6 +16,7 @@
#include "isocline.h" #include "isocline.h"
#include <algorithm>
#include <memory> #include <memory>
#ifdef _WIN32 #ifdef _WIN32
@ -49,6 +50,8 @@ enum class CompileFormat
Binary, Binary,
Remarks, Remarks,
Codegen, Codegen,
CodegenVerbose,
CodegenNull,
Null Null
}; };
@ -673,21 +676,33 @@ static void reportError(const char* name, const Luau::CompileError& error)
report(name, error.getLocation(), "CompileError", error.what()); report(name, error.getLocation(), "CompileError", error.what());
} }
static std::string getCodegenAssembly(const char* name, const std::string& bytecode) static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options)
{ {
std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close); std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get(); lua_State* L = globalState.get();
setupState(L);
if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0)
return Luau::CodeGen::getAssemblyText(L, -1); return Luau::CodeGen::getAssembly(L, -1, options);
fprintf(stderr, "Error loading bytecode %s\n", name); fprintf(stderr, "Error loading bytecode %s\n", name);
return ""; return "";
} }
static bool compileFile(const char* name, CompileFormat format) static void annotateInstruction(void* context, std::string& text, int fid, int instid)
{
Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context;
bcb.annotateInstruction(text, fid, instid);
}
struct CompileStats
{
size_t lines;
size_t bytecode;
size_t codegen;
};
static bool compileFile(const char* name, CompileFormat format, CompileStats& stats)
{ {
std::optional<std::string> source = readFile(name); std::optional<std::string> source = readFile(name);
if (!source) if (!source)
@ -696,9 +711,12 @@ static bool compileFile(const char* name, CompileFormat format)
return false; return false;
} }
stats.lines += std::count(source->begin(), source->end(), '\n');
try try
{ {
Luau::BytecodeBuilder bcb; Luau::BytecodeBuilder bcb;
Luau::CodeGen::AssemblyOptions options = {format == CompileFormat::CodegenNull, format == CompileFormat::Codegen, annotateInstruction, &bcb};
if (format == CompileFormat::Text) if (format == CompileFormat::Text)
{ {
@ -711,8 +729,15 @@ static bool compileFile(const char* name, CompileFormat format)
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source); bcb.setDumpSource(*source);
} }
else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenVerbose)
{
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source);
}
Luau::compileOrThrow(bcb, *source, copts()); Luau::compileOrThrow(bcb, *source, copts());
stats.bytecode += bcb.getBytecode().size();
switch (format) switch (format)
{ {
@ -726,7 +751,11 @@ static bool compileFile(const char* name, CompileFormat format)
fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout);
break; break;
case CompileFormat::Codegen: case CompileFormat::Codegen:
printf("%s", getCodegenAssembly(name, bcb.getBytecode()).c_str()); case CompileFormat::CodegenVerbose:
printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str());
break;
case CompileFormat::CodegenNull:
stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size();
break; break;
case CompileFormat::Null: case CompileFormat::Null:
break; break;
@ -755,7 +784,7 @@ static void displayHelp(const char* argv0)
printf("\n"); printf("\n");
printf("Available modes:\n"); printf("Available modes:\n");
printf(" omitted: compile and run input files one by one\n"); printf(" omitted: compile and run input files one by one\n");
printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary, text, remarks, codegen or null)\n"); printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n");
printf("\n"); printf("\n");
printf("Available options:\n"); printf("Available options:\n");
printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n");
@ -812,6 +841,14 @@ int replMain(int argc, char** argv)
{ {
compileFormat = CompileFormat::Codegen; compileFormat = CompileFormat::Codegen;
} }
else if (strcmp(argv[1], "--compile=codegenverbose") == 0)
{
compileFormat = CompileFormat::CodegenVerbose;
}
else if (strcmp(argv[1], "--compile=codegennull") == 0)
{
compileFormat = CompileFormat::CodegenNull;
}
else if (strcmp(argv[1], "--compile=null") == 0) else if (strcmp(argv[1], "--compile=null") == 0)
{ {
compileFormat = CompileFormat::Null; compileFormat = CompileFormat::Null;
@ -923,10 +960,16 @@ int replMain(int argc, char** argv)
_setmode(_fileno(stdout), _O_BINARY); _setmode(_fileno(stdout), _O_BINARY);
#endif #endif
CompileStats stats = {};
int failed = 0; int failed = 0;
for (const std::string& path : files) for (const std::string& path : files)
failed += !compileFile(path.c_str(), compileFormat); failed += !compileFile(path.c_str(), compileFormat, stats);
if (compileFormat == CompileFormat::Null)
printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024));
else if (compileFormat == CompileFormat::CodegenNull)
printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024));
return failed ? 1 : 0; return failed ? 1 : 0;
} }

View file

@ -23,6 +23,19 @@ enum class RoundingModeX64
RoundToZero = 0b11, RoundToZero = 0b11,
}; };
enum class AlignmentDataX64
{
Nop,
Int3,
Ud2, // int3 will be used as a fall-back if it doesn't fit
};
enum class ABIX64
{
Windows,
SystemV,
};
class AssemblyBuilderX64 class AssemblyBuilderX64
{ {
public: public:
@ -80,6 +93,10 @@ public:
void int3(); void int3();
// Code alignment
void nop(uint32_t length = 1);
void align(uint32_t alignment, AlignmentDataX64 data = AlignmentDataX64::Nop);
// AVX // AVX
void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2);
@ -131,6 +148,8 @@ public:
void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3);
uint32_t getCodeSize() const;
// Resulting data and code that need to be copied over one after the other // Resulting data and code that need to be copied over one after the other
// The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code'
std::vector<uint8_t> data; std::vector<uint8_t> data;
@ -140,6 +159,8 @@ public:
const bool logText = false; const bool logText = false;
const ABIX64 abi;
private: private:
// Instruction archetypes // Instruction archetypes
void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev,
@ -177,7 +198,6 @@ private:
void commit(); void commit();
LUAU_NOINLINE void extend(); LUAU_NOINLINE void extend();
uint32_t getCodeSize();
// Data // Data
size_t allocateData(size_t size, size_t align); size_t allocateData(size_t size, size_t align);
@ -192,8 +212,8 @@ private:
LUAU_NOINLINE void log(const char* opcode, Label label); LUAU_NOINLINE void log(const char* opcode, Label label);
void log(OperandX64 op); void log(OperandX64 op);
const char* getSizeName(SizeX64 size); const char* getSizeName(SizeX64 size) const;
const char* getRegisterName(RegisterX64 reg); const char* getRegisterName(RegisterX64 reg) const;
uint32_t nextLabel = 1; uint32_t nextLabel = 1;
std::vector<Label> pendingLabels; std::vector<Label> pendingLabels;

View file

@ -11,6 +11,8 @@ namespace Luau
namespace CodeGen namespace CodeGen
{ {
constexpr uint32_t kCodeAlignment = 32;
struct CodeAllocator struct CodeAllocator
{ {
CodeAllocator(size_t blockSize, size_t maxTotalSize); CodeAllocator(size_t blockSize, size_t maxTotalSize);

View file

@ -17,8 +17,20 @@ void create(lua_State* L);
// Builds target function and all inner functions // Builds target function and all inner functions
void compile(lua_State* L, int idx); void compile(lua_State* L, int idx);
// Generates assembly text for target function and all inner functions using annotatorFn = void (*)(void* context, std::string& result, int fid, int instid);
std::string getAssemblyText(lua_State* L, int idx);
struct AssemblyOptions
{
bool outputBinary = false;
bool skipOutlinedCode = false;
// Optional annotator function can be provided to describe each instruction, it takes function id and sequential instruction id
annotatorFn annotator = nullptr;
void* annotatorContext = nullptr;
};
// Generates assembly for target function and all inner functions
std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {});
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -94,6 +94,11 @@ constexpr OperandX64 operator+(RegisterX64 reg, int32_t disp)
return OperandX64(SizeX64::none, noreg, 1, reg, disp); return OperandX64(SizeX64::none, noreg, 1, reg, disp);
} }
constexpr OperandX64 operator-(RegisterX64 reg, int32_t disp)
{
return OperandX64(SizeX64::none, noreg, 1, reg, -disp);
}
constexpr OperandX64 operator+(RegisterX64 base, RegisterX64 index) constexpr OperandX64 operator+(RegisterX64 base, RegisterX64 index)
{ {
LUAU_ASSERT(index.index != 4 && "sp cannot be used as index"); LUAU_ASSERT(index.index != 4 && "sp cannot be used as index");

View file

@ -113,5 +113,20 @@ constexpr RegisterX64 ymm13{SizeX64::ymmword, 13};
constexpr RegisterX64 ymm14{SizeX64::ymmword, 14}; constexpr RegisterX64 ymm14{SizeX64::ymmword, 14};
constexpr RegisterX64 ymm15{SizeX64::ymmword, 15}; constexpr RegisterX64 ymm15{SizeX64::ymmword, 15};
constexpr RegisterX64 byteReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::byte, reg.index};
}
constexpr RegisterX64 wordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::word, reg.index};
}
constexpr RegisterX64 dwordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::dword, reg.index};
}
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -29,7 +29,7 @@ static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(C
#define REX_X(reg) (((reg).index & 0x8) >> 2) #define REX_X(reg) (((reg).index & 0x8) >> 2)
#define REX_B(reg) (((reg).index & 0x8) >> 3) #define REX_B(reg) (((reg).index & 0x8) >> 3)
#define AVX_W(value) (!(value) ? 0x80 : 0x0) #define AVX_W(value) ((value) ? 0x80 : 0x0)
#define AVX_R(reg) ((~(reg).index & 0x8) << 4) #define AVX_R(reg) ((~(reg).index & 0x8) << 4)
#define AVX_X(reg) ((~(reg).index & 0x8) << 3) #define AVX_X(reg) ((~(reg).index & 0x8) << 3)
#define AVX_B(reg) ((~(reg).index & 0x8) << 2) #define AVX_B(reg) ((~(reg).index & 0x8) << 2)
@ -50,12 +50,23 @@ const unsigned AVX_66 = 0b01;
const unsigned AVX_F3 = 0b10; const unsigned AVX_F3 = 0b10;
const unsigned AVX_F2 = 0b11; const unsigned AVX_F2 = 0b11;
const unsigned kMaxAlign = 16; const unsigned kMaxAlign = 32;
const unsigned kMaxInstructionLength = 16;
const uint8_t kRoundingPrecisionInexact = 0b1000; const uint8_t kRoundingPrecisionInexact = 0b1000;
static ABIX64 getCurrentX64ABI()
{
#if defined(_WIN32)
return ABIX64::Windows;
#else
return ABIX64::SystemV;
#endif
}
AssemblyBuilderX64::AssemblyBuilderX64(bool logText) AssemblyBuilderX64::AssemblyBuilderX64(bool logText)
: logText(logText) : logText(logText)
, abi(getCurrentX64ABI())
{ {
data.resize(4096); data.resize(4096);
dataPos = data.size(); // data is filled backwards dataPos = data.size(); // data is filled backwards
@ -416,6 +427,153 @@ void AssemblyBuilderX64::int3()
log("int3"); log("int3");
place(0xcc); place(0xcc);
commit();
}
void AssemblyBuilderX64::nop(uint32_t length)
{
while (length != 0)
{
uint32_t step = length > 9 ? 9 : length;
length -= step;
switch (step)
{
case 1:
if (logText)
logAppend(" nop\n");
place(0x90);
break;
case 2:
if (logText)
logAppend(" xchg ax, ax ; %u-byte nop\n", step);
place(0x66);
place(0x90);
break;
case 3:
if (logText)
logAppend(" nop dword ptr[rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x00);
break;
case 4:
if (logText)
logAppend(" nop dword ptr[rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x40);
place(0x00);
break;
case 5:
if (logText)
logAppend(" nop dword ptr[rax+rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x44);
place(0x00);
place(0x00);
break;
case 6:
if (logText)
logAppend(" nop word ptr[rax+rax] ; %u-byte nop\n", step);
place(0x66);
place(0x0f);
place(0x1f);
place(0x44);
place(0x00);
place(0x00);
break;
case 7:
if (logText)
logAppend(" nop dword ptr[rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x80);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
break;
case 8:
if (logText)
logAppend(" nop dword ptr[rax+rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x84);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
break;
case 9:
if (logText)
logAppend(" nop word ptr[rax+rax] ; %u-byte nop\n", step);
place(0x66);
place(0x0f);
place(0x1f);
place(0x84);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
break;
}
commit();
}
}
void AssemblyBuilderX64::align(uint32_t alignment, AlignmentDataX64 data)
{
LUAU_ASSERT((alignment & (alignment - 1)) == 0);
uint32_t size = getCodeSize();
uint32_t pad = ((size + alignment - 1) & ~(alignment - 1)) - size;
switch (data)
{
case AlignmentDataX64::Nop:
if (logText)
logAppend("; align %u\n", alignment);
nop(pad);
break;
case AlignmentDataX64::Int3:
if (logText)
logAppend("; align %u using int3\n", alignment);
while (codePos + pad > codeEnd)
extend();
for (uint32_t i = 0; i < pad; ++i)
place(0xcc);
commit();
break;
case AlignmentDataX64::Ud2:
if (logText)
logAppend("; align %u using ud2\n", alignment);
while (codePos + pad > codeEnd)
extend();
uint32_t i = 0;
for (; i + 1 < pad; i += 2)
{
place(0x0f);
place(0x0b);
}
if (i < pad)
place(0xcc);
commit();
break;
}
} }
void AssemblyBuilderX64::vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2) void AssemblyBuilderX64::vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
@ -465,12 +623,12 @@ void AssemblyBuilderX64::vucomisd(OperandX64 src1, OperandX64 src2)
void AssemblyBuilderX64::vcvttsd2si(OperandX64 dst, OperandX64 src) void AssemblyBuilderX64::vcvttsd2si(OperandX64 dst, OperandX64 src)
{ {
placeAvx("vcvttsd2si", dst, src, 0x2c, dst.base.size == SizeX64::dword, AVX_0F, AVX_F2); placeAvx("vcvttsd2si", dst, src, 0x2c, dst.base.size == SizeX64::qword, AVX_0F, AVX_F2);
} }
void AssemblyBuilderX64::vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2) void AssemblyBuilderX64::vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{ {
placeAvx("vcvtsi2sd", dst, src1, src2, 0x2a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::dword, AVX_0F, AVX_F2); placeAvx("vcvtsi2sd", dst, src1, src2, 0x2a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2);
} }
void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode) void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode)
@ -626,6 +784,21 @@ OperandX64 AssemblyBuilderX64::bytes(const void* ptr, size_t size, size_t align)
return OperandX64(SizeX64::qword, noreg, 1, rip, int32_t(pos - data.size())); return OperandX64(SizeX64::qword, noreg, 1, rip, int32_t(pos - data.size()));
} }
void AssemblyBuilderX64::logAppend(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
text.append(buf);
}
uint32_t AssemblyBuilderX64::getCodeSize() const
{
return uint32_t(codePos - code.data());
}
void AssemblyBuilderX64::placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, void AssemblyBuilderX64::placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8,
uint8_t code8rev, uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg) uint8_t code8rev, uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg)
{ {
@ -1054,7 +1227,7 @@ void AssemblyBuilderX64::commit()
{ {
LUAU_ASSERT(codePos <= codeEnd); LUAU_ASSERT(codePos <= codeEnd);
if (codeEnd - codePos < 16) if (codeEnd - codePos < kMaxInstructionLength)
extend(); extend();
} }
@ -1067,11 +1240,6 @@ void AssemblyBuilderX64::extend()
codeEnd = code.data() + code.size(); codeEnd = code.data() + code.size();
} }
uint32_t AssemblyBuilderX64::getCodeSize()
{
return uint32_t(codePos - code.data());
}
size_t AssemblyBuilderX64::allocateData(size_t size, size_t align) size_t AssemblyBuilderX64::allocateData(size_t size, size_t align)
{ {
LUAU_ASSERT(align > 0 && align <= kMaxAlign && (align & (align - 1)) == 0); LUAU_ASSERT(align > 0 && align <= kMaxAlign && (align & (align - 1)) == 0);
@ -1174,8 +1342,10 @@ void AssemblyBuilderX64::log(OperandX64 op)
{ {
if (op.imm >= 0 && op.imm <= 9) if (op.imm >= 0 && op.imm <= 9)
logAppend("+%d", op.imm); logAppend("+%d", op.imm);
else else if (op.imm > 0)
logAppend("+0%Xh", op.imm); logAppend("+0%Xh", op.imm);
else
logAppend("-0%Xh", -op.imm);
} }
text.append("]"); text.append("]");
@ -1191,17 +1361,7 @@ void AssemblyBuilderX64::log(OperandX64 op)
} }
} }
void AssemblyBuilderX64::logAppend(const char* fmt, ...) const char* AssemblyBuilderX64::getSizeName(SizeX64 size) const
{
char buf[256];
va_list args;
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
text.append(buf);
}
const char* AssemblyBuilderX64::getSizeName(SizeX64 size)
{ {
static const char* sizeNames[] = {"none", "byte", "word", "dword", "qword", "xmmword", "ymmword"}; static const char* sizeNames[] = {"none", "byte", "word", "dword", "qword", "xmmword", "ymmword"};
@ -1209,7 +1369,7 @@ const char* AssemblyBuilderX64::getSizeName(SizeX64 size)
return sizeNames[unsigned(size)]; return sizeNames[unsigned(size)];
} }
const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg) const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg) const
{ {
static const char* names[][16] = {{"rip", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""}, static const char* names[][16] = {{"rip", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""},
{"al", "cl", "dl", "bl", "spl", "bpl", "sil", "dil", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b"}, {"al", "cl", "dl", "bl", "spl", "bpl", "sil", "dil", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b"},

View file

@ -1,7 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #include "Luau/Common.h"
#if defined(LUAU_BIG_ENDIAN)
#include <endian.h> #include <endian.h>
#endif #endif
@ -15,7 +17,7 @@ inline uint8_t* writeu8(uint8_t* target, uint8_t value)
inline uint8_t* writeu32(uint8_t* target, uint32_t value) inline uint8_t* writeu32(uint8_t* target, uint32_t value)
{ {
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #if defined(LUAU_BIG_ENDIAN)
value = htole32(value); value = htole32(value);
#endif #endif
@ -25,7 +27,7 @@ inline uint8_t* writeu32(uint8_t* target, uint32_t value)
inline uint8_t* writeu64(uint8_t* target, uint64_t value) inline uint8_t* writeu64(uint8_t* target, uint64_t value)
{ {
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #if defined(LUAU_BIG_ENDIAN)
value = htole64(value); value = htole64(value);
#endif #endif
@ -51,7 +53,7 @@ inline uint8_t* writeuleb128(uint8_t* target, uint64_t value)
inline uint8_t* writef32(uint8_t* target, float value) inline uint8_t* writef32(uint8_t* target, float value)
{ {
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #if defined(LUAU_BIG_ENDIAN)
static_assert(sizeof(float) == sizeof(uint32_t), "type size must match to reinterpret data"); static_assert(sizeof(float) == sizeof(uint32_t), "type size must match to reinterpret data");
uint32_t data; uint32_t data;
memcpy(&data, &value, sizeof(value)); memcpy(&data, &value, sizeof(value));
@ -65,7 +67,7 @@ inline uint8_t* writef32(uint8_t* target, float value)
inline uint8_t* writef64(uint8_t* target, double value) inline uint8_t* writef64(uint8_t* target, double value)
{ {
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #if defined(LUAU_BIG_ENDIAN)
static_assert(sizeof(double) == sizeof(uint64_t), "type size must match to reinterpret data"); static_assert(sizeof(double) == sizeof(uint64_t), "type size must match to reinterpret data");
uint64_t data; uint64_t data;
memcpy(&data, &value, sizeof(value)); memcpy(&data, &value, sizeof(value));

View file

@ -110,8 +110,8 @@ CodeAllocator::~CodeAllocator()
bool CodeAllocator::allocate( bool CodeAllocator::allocate(
uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart) uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart)
{ {
// 'Round up' to preserve 16 byte alignment // 'Round up' to preserve code alignment
size_t alignedDataSize = (dataSize + 15) & ~15; size_t alignedDataSize = (dataSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1);
size_t totalSize = alignedDataSize + codeSize; size_t totalSize = alignedDataSize + codeSize;
@ -187,8 +187,8 @@ bool CodeAllocator::allocateNewBlock(size_t& unwindInfoSize)
{ {
void* unwindInfo = createBlockUnwindInfo(context, block, blockSize, unwindInfoSize); void* unwindInfo = createBlockUnwindInfo(context, block, blockSize, unwindInfoSize);
// 'Round up' to preserve 16 byte alignment of the following data and code // 'Round up' to preserve alignment of the following data and code
unwindInfoSize = (unwindInfoSize + 15) & ~15; unwindInfoSize = (unwindInfoSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1);
LUAU_ASSERT(unwindInfoSize <= kMaxReservedDataSize); LUAU_ASSERT(unwindInfoSize <= kMaxReservedDataSize);

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/CodeBlockUnwind.h" #include "Luau/CodeBlockUnwind.h"
#include "Luau/CodeAllocator.h"
#include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilder.h"
#include <string.h> #include <string.h>
@ -58,7 +59,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
// All unwinding related data is placed together at the start of the block // All unwinding related data is placed together at the start of the block
size_t unwindSize = sizeof(RUNTIME_FUNCTION) + unwind->getSize(); size_t unwindSize = sizeof(RUNTIME_FUNCTION) + unwind->getSize();
unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment
LUAU_ASSERT(blockSize >= unwindSize); LUAU_ASSERT(blockSize >= unwindSize);
RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)block; RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)block;
@ -82,7 +83,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
// All unwinding related data is placed together at the start of the block // All unwinding related data is placed together at the start of the block
size_t unwindSize = unwind->getSize(); size_t unwindSize = unwind->getSize();
unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment
LUAU_ASSERT(blockSize >= unwindSize); LUAU_ASSERT(blockSize >= unwindSize);
char* unwindData = (char*)block; char* unwindData = (char*)block;

View file

@ -32,7 +32,9 @@ namespace Luau
namespace CodeGen namespace CodeGen
{ {
static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, Proto* proto) constexpr uint32_t kFunctionAlignment = 32;
static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, Proto* proto, AssemblyOptions options)
{ {
NativeProto* result = new NativeProto(); NativeProto* result = new NativeProto();
@ -54,142 +56,177 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
std::vector<Label> instLabels; std::vector<Label> instLabels;
instLabels.resize(proto->sizecode); instLabels.resize(proto->sizecode);
std::vector<Label> instFallbacks;
instFallbacks.resize(proto->sizecode);
build.align(kFunctionAlignment, AlignmentDataX64::Ud2);
Label start = build.setLabel(); Label start = build.setLabel();
for (int i = 0; i < proto->sizecode;) for (int i = 0, instid = 0; i < proto->sizecode; ++instid)
{ {
const Instruction* pc = &proto->code[i]; const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc)); LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
build.setLabel(instLabels[i]); build.setLabel(instLabels[i]);
if (build.logText) if (options.annotator)
build.logAppend("; #%d: %s\n", i, data.names[op]); options.annotator(options.annotatorContext, build.text, proto->bytecodeid, instid);
switch (op) switch (op)
{ {
case LOP_NOP: case LOP_NOP:
break; break;
case LOP_LOADNIL: case LOP_LOADNIL:
emitInstLoadNil(build, data, pc); emitInstLoadNil(build, pc);
break; break;
case LOP_LOADB: case LOP_LOADB:
emitInstLoadB(build, data, pc, i, instLabels.data()); emitInstLoadB(build, pc, i, instLabels.data());
break; break;
case LOP_LOADN: case LOP_LOADN:
emitInstLoadN(build, data, pc); emitInstLoadN(build, pc);
break; break;
case LOP_LOADK: case LOP_LOADK:
emitInstLoadK(build, data, pc, proto->k); emitInstLoadK(build, pc);
break;
case LOP_LOADKX:
emitInstLoadKX(build, pc);
break; break;
case LOP_MOVE: case LOP_MOVE:
emitInstMove(build, data, pc); emitInstMove(build, pc);
break;
case LOP_GETGLOBAL:
emitInstGetGlobal(build, pc, i, instFallbacks[i]);
break;
case LOP_SETGLOBAL:
emitInstSetGlobal(build, pc, i, instLabels.data(), instFallbacks[i]);
break; break;
case LOP_GETTABLE: case LOP_GETTABLE:
emitInstGetTable(build, pc, i); emitInstGetTable(build, pc, i, instFallbacks[i]);
break; break;
case LOP_SETTABLE: case LOP_SETTABLE:
emitInstSetTable(build, pc, i); emitInstSetTable(build, pc, i, instLabels.data(), instFallbacks[i]);
break;
case LOP_GETTABLEKS:
emitInstGetTableKS(build, pc, i, instFallbacks[i]);
break;
case LOP_SETTABLEKS:
emitInstSetTableKS(build, pc, i, instLabels.data(), instFallbacks[i]);
break; break;
case LOP_GETTABLEN: case LOP_GETTABLEN:
emitInstGetTableN(build, pc, i); emitInstGetTableN(build, pc, i, instFallbacks[i]);
break; break;
case LOP_SETTABLEN: case LOP_SETTABLEN:
emitInstSetTableN(build, pc, i); emitInstSetTableN(build, pc, i, instLabels.data(), instFallbacks[i]);
break; break;
case LOP_JUMP: case LOP_JUMP:
emitInstJump(build, data, pc, i, instLabels.data()); emitInstJump(build, pc, i, instLabels.data());
break; break;
case LOP_JUMPBACK: case LOP_JUMPBACK:
emitInstJumpBack(build, data, pc, i, instLabels.data()); emitInstJumpBack(build, pc, i, instLabels.data());
break; break;
case LOP_JUMPIF: case LOP_JUMPIF:
emitInstJumpIf(build, data, pc, i, instLabels.data(), /* not_ */ false); emitInstJumpIf(build, pc, i, instLabels.data(), /* not_ */ false);
break; break;
case LOP_JUMPIFNOT: case LOP_JUMPIFNOT:
emitInstJumpIf(build, data, pc, i, instLabels.data(), /* not_ */ true); emitInstJumpIf(build, pc, i, instLabels.data(), /* not_ */ true);
break; break;
case LOP_JUMPIFEQ: case LOP_JUMPIFEQ:
emitInstJumpIfEq(build, data, pc, i, instLabels.data(), /* not_ */ false); emitInstJumpIfEq(build, pc, i, instLabels.data(), /* not_ */ false, instFallbacks[i]);
break; break;
case LOP_JUMPIFLE: case LOP_JUMPIFLE:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::LessEqual); emitInstJumpIfCond(build, pc, i, instLabels.data(), Condition::LessEqual, instFallbacks[i]);
break; break;
case LOP_JUMPIFLT: case LOP_JUMPIFLT:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::Less); emitInstJumpIfCond(build, pc, i, instLabels.data(), Condition::Less, instFallbacks[i]);
break; break;
case LOP_JUMPIFNOTEQ: case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, data, pc, i, instLabels.data(), /* not_ */ true); emitInstJumpIfEq(build, pc, i, instLabels.data(), /* not_ */ true, instFallbacks[i]);
break; break;
case LOP_JUMPIFNOTLE: case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::NotLessEqual); emitInstJumpIfCond(build, pc, i, instLabels.data(), Condition::NotLessEqual, instFallbacks[i]);
break; break;
case LOP_JUMPIFNOTLT: case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::NotLess); emitInstJumpIfCond(build, pc, i, instLabels.data(), Condition::NotLess, instFallbacks[i]);
break; break;
case LOP_JUMPX: case LOP_JUMPX:
emitInstJumpX(build, data, pc, i, instLabels.data()); emitInstJumpX(build, pc, i, instLabels.data());
break; break;
case LOP_JUMPXEQKNIL: case LOP_JUMPXEQKNIL:
emitInstJumpxEqNil(build, data, pc, proto->k, i, instLabels.data()); emitInstJumpxEqNil(build, pc, i, instLabels.data());
break; break;
case LOP_JUMPXEQKB: case LOP_JUMPXEQKB:
emitInstJumpxEqB(build, data, pc, proto->k, i, instLabels.data()); emitInstJumpxEqB(build, pc, i, instLabels.data());
break; break;
case LOP_JUMPXEQKN: case LOP_JUMPXEQKN:
emitInstJumpxEqN(build, data, pc, proto->k, i, instLabels.data()); emitInstJumpxEqN(build, pc, proto->k, i, instLabels.data());
break; break;
case LOP_JUMPXEQKS: case LOP_JUMPXEQKS:
emitInstJumpxEqS(build, data, pc, proto->k, i, instLabels.data()); emitInstJumpxEqS(build, pc, i, instLabels.data());
break; break;
case LOP_ADD: case LOP_ADD:
emitInstAdd(build, pc, i); emitInstBinary(build, pc, i, TM_ADD, instFallbacks[i]);
break; break;
case LOP_SUB: case LOP_SUB:
emitInstSub(build, pc, i); emitInstBinary(build, pc, i, TM_SUB, instFallbacks[i]);
break; break;
case LOP_MUL: case LOP_MUL:
emitInstMul(build, pc, i); emitInstBinary(build, pc, i, TM_MUL, instFallbacks[i]);
break; break;
case LOP_DIV: case LOP_DIV:
emitInstDiv(build, pc, i); emitInstBinary(build, pc, i, TM_DIV, instFallbacks[i]);
break; break;
case LOP_MOD: case LOP_MOD:
emitInstMod(build, pc, i); emitInstBinary(build, pc, i, TM_MOD, instFallbacks[i]);
break; break;
case LOP_POW: case LOP_POW:
emitInstPow(build, pc, i); emitInstBinary(build, pc, i, TM_POW, instFallbacks[i]);
break; break;
case LOP_ADDK: case LOP_ADDK:
emitInstAddK(build, pc, proto->k, i); emitInstBinaryK(build, pc, i, TM_ADD, instFallbacks[i]);
break; break;
case LOP_SUBK: case LOP_SUBK:
emitInstSubK(build, pc, proto->k, i); emitInstBinaryK(build, pc, i, TM_SUB, instFallbacks[i]);
break; break;
case LOP_MULK: case LOP_MULK:
emitInstMulK(build, pc, proto->k, i); emitInstBinaryK(build, pc, i, TM_MUL, instFallbacks[i]);
break; break;
case LOP_DIVK: case LOP_DIVK:
emitInstDivK(build, pc, proto->k, i); emitInstBinaryK(build, pc, i, TM_DIV, instFallbacks[i]);
break; break;
case LOP_MODK: case LOP_MODK:
emitInstModK(build, pc, proto->k, i); emitInstBinaryK(build, pc, i, TM_MOD, instFallbacks[i]);
break; break;
case LOP_POWK: case LOP_POWK:
emitInstPowK(build, pc, proto->k, i); emitInstPowK(build, pc, proto->k, i, instFallbacks[i]);
break; break;
case LOP_NOT: case LOP_NOT:
emitInstNot(build, pc); emitInstNot(build, pc);
break; break;
case LOP_MINUS: case LOP_MINUS:
emitInstMinus(build, pc, i); emitInstMinus(build, pc, i, instFallbacks[i]);
break; break;
case LOP_LENGTH: case LOP_LENGTH:
emitInstLength(build, pc, i); emitInstLength(build, pc, i, instFallbacks[i]);
break;
case LOP_NEWTABLE:
emitInstNewTable(build, pc, i, instLabels.data());
break;
case LOP_DUPTABLE:
emitInstDupTable(build, pc, i, instLabels.data());
break;
case LOP_SETLIST:
emitInstSetList(build, pc, i, instLabels.data());
break; break;
case LOP_GETUPVAL: case LOP_GETUPVAL:
emitInstGetUpval(build, pc, i); emitInstGetUpval(build, pc, i);
break; break;
case LOP_SETUPVAL:
emitInstSetUpval(build, pc, i, instLabels.data());
break;
case LOP_CLOSEUPVALS:
emitInstCloseUpvals(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL: case LOP_FASTCALL:
emitInstFastCall(build, pc, i, instLabels.data()); emitInstFastCall(build, pc, i, instLabels.data());
break; break;
@ -200,7 +237,7 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
emitInstFastCall2(build, pc, i, instLabels.data()); emitInstFastCall2(build, pc, i, instLabels.data());
break; break;
case LOP_FASTCALL2K: case LOP_FASTCALL2K:
emitInstFastCall2K(build, pc, proto->k, i, instLabels.data()); emitInstFastCall2K(build, pc, i, instLabels.data());
break; break;
case LOP_FORNPREP: case LOP_FORNPREP:
emitInstForNPrep(build, pc, i, instLabels.data()); emitInstForNPrep(build, pc, i, instLabels.data());
@ -220,6 +257,12 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
case LOP_ORK: case LOP_ORK:
emitInstOrK(build, pc); emitInstOrK(build, pc);
break; break;
case LOP_GETIMPORT:
emitInstGetImport(build, pc, instFallbacks[i]);
break;
case LOP_CONCAT:
emitInstConcat(build, pc, i, instLabels.data());
break;
default: default:
emitFallback(build, data, op, i); emitFallback(build, data, op, i);
break; break;
@ -229,6 +272,145 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
LUAU_ASSERT(i <= proto->sizecode); LUAU_ASSERT(i <= proto->sizecode);
} }
size_t textSize = build.text.size();
uint32_t codeSize = build.getCodeSize();
if (options.annotator && !options.skipOutlinedCode)
build.logAppend("; outlined code\n");
for (int i = 0, instid = 0; i < proto->sizecode; ++instid)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
if (instFallbacks[i].id == 0)
{
i = nexti;
continue;
}
if (options.annotator && !options.skipOutlinedCode)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, instid);
build.setLabel(instFallbacks[i]);
switch (op)
{
case LOP_GETIMPORT:
emitInstGetImportFallback(build, pc, i);
break;
case LOP_GETTABLE:
emitInstGetTableFallback(build, pc, i);
break;
case LOP_SETTABLE:
emitInstSetTableFallback(build, pc, i);
break;
case LOP_GETTABLEN:
emitInstGetTableNFallback(build, pc, i);
break;
case LOP_SETTABLEN:
emitInstSetTableNFallback(build, pc, i);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEqFallback(build, pc, i, instLabels.data(), /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCondFallback(build, pc, i, instLabels.data(), Condition::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCondFallback(build, pc, i, instLabels.data(), Condition::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEqFallback(build, pc, i, instLabels.data(), /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCondFallback(build, pc, i, instLabels.data(), Condition::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCondFallback(build, pc, i, instLabels.data(), Condition::NotLess);
break;
case LOP_ADD:
emitInstBinaryFallback(build, pc, i, TM_ADD);
break;
case LOP_SUB:
emitInstBinaryFallback(build, pc, i, TM_SUB);
break;
case LOP_MUL:
emitInstBinaryFallback(build, pc, i, TM_MUL);
break;
case LOP_DIV:
emitInstBinaryFallback(build, pc, i, TM_DIV);
break;
case LOP_MOD:
emitInstBinaryFallback(build, pc, i, TM_MOD);
break;
case LOP_POW:
emitInstBinaryFallback(build, pc, i, TM_POW);
break;
case LOP_ADDK:
emitInstBinaryKFallback(build, pc, i, TM_ADD);
break;
case LOP_SUBK:
emitInstBinaryKFallback(build, pc, i, TM_SUB);
break;
case LOP_MULK:
emitInstBinaryKFallback(build, pc, i, TM_MUL);
break;
case LOP_DIVK:
emitInstBinaryKFallback(build, pc, i, TM_DIV);
break;
case LOP_MODK:
emitInstBinaryKFallback(build, pc, i, TM_MOD);
break;
case LOP_POWK:
emitInstBinaryKFallback(build, pc, i, TM_POW);
break;
case LOP_MINUS:
emitInstMinusFallback(build, pc, i);
break;
case LOP_LENGTH:
emitInstLengthFallback(build, pc, i);
break;
case LOP_GETGLOBAL:
// TODO: luaV_gettable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETGLOBAL:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_GETTABLEKS:
// Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access
// It is also required to perform cached slot update
// TODO: extra fast-paths could be lowered before the full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETTABLEKS:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
default:
LUAU_ASSERT(!"Expected fallback for instruction");
}
// Jump back to the next instruction handler
if (nexti < proto->sizecode)
build.jmp(instLabels[nexti]);
i = nexti;
}
// Truncate assembly output if we don't care for outlined code part
if (options.skipOutlinedCode)
{
build.text.resize(textSize);
build.logAppend("; skipping %u bytes of outlined code\n", build.getCodeSize() - codeSize);
}
result->instTargets = new uintptr_t[proto->sizecode]; result->instTargets = new uintptr_t[proto->sizecode];
for (int i = 0; i < proto->sizecode; i++) for (int i = 0; i < proto->sizecode; i++)
@ -392,7 +574,7 @@ void compile(lua_State* L, int idx)
// Skip protos that have been compiled during previous invocations of CodeGen::compile // Skip protos that have been compiled during previous invocations of CodeGen::compile
for (Proto* p : protos) for (Proto* p : protos)
if (p && getProtoExecData(p) == nullptr) if (p && getProtoExecData(p) == nullptr)
results.push_back(assembleFunction(build, *data, p)); results.push_back(assembleFunction(build, *data, p, {}));
build.finalize(); build.finalize();
@ -420,15 +602,15 @@ void compile(lua_State* L, int idx)
setProtoExecData(result->proto, result); setProtoExecData(result->proto, result);
} }
std::string getAssemblyText(lua_State* L, int idx) std::string getAssembly(lua_State* L, int idx, AssemblyOptions options)
{ {
LUAU_ASSERT(lua_isLfunction(L, idx)); LUAU_ASSERT(lua_isLfunction(L, idx));
const TValue* func = luaA_toobject(L, idx); const TValue* func = luaA_toobject(L, idx);
AssemblyBuilderX64 build(/* logText= */ true); AssemblyBuilderX64 build(/* logText= */ !options.outputBinary);
NativeState data; NativeState data;
initFallbackTable(data); initFallbackTable(data);
initInstructionNames(data);
std::vector<Proto*> protos; std::vector<Proto*> protos;
gatherFunctions(protos, clvalue(func)->l.p); gatherFunctions(protos, clvalue(func)->l.p);
@ -436,13 +618,16 @@ std::string getAssemblyText(lua_State* L, int idx)
for (Proto* p : protos) for (Proto* p : protos)
if (p) if (p)
{ {
NativeProto* nativeProto = assembleFunction(build, data, p); NativeProto* nativeProto = assembleFunction(build, data, p, options);
destroyNativeProto(nativeProto); destroyNativeProto(nativeProto);
} }
build.finalize(); build.finalize();
return build.text; if (options.outputBinary)
return std::string(build.code.begin(), build.code.end()) + std::string(build.data.begin(), build.data.end());
else
return build.text;
} }
} // namespace CodeGen } // namespace CodeGen

View file

@ -48,7 +48,7 @@ bool initEntryFunction(NativeState& data)
unwind.start(); unwind.start();
if (getCurrentX64ABI() == X64ABI::Windows) if (build.abi == ABIX64::Windows)
{ {
// Place arguments in home space // Place arguments in home space
build.mov(qword[rsp + 16], rArg2); build.mov(qword[rsp + 16], rArg2);
@ -121,7 +121,7 @@ bool initEntryFunction(NativeState& data)
build.pop(rbp); build.pop(rbp);
build.pop(rbx); build.pop(rbx);
if (getCurrentX64ABI() == X64ABI::Windows) if (build.abi == ABIX64::Windows)
{ {
build.pop(rsi); build.pop(rsi);
build.pop(rdi); build.pop(rdi);

View file

@ -126,20 +126,5 @@ inline int getOpLength(LuauOpcode op)
} }
} }
enum class X64ABI
{
Windows,
SystemV,
};
inline X64ABI getCurrentX64ABI()
{
#if defined(_WIN32)
return X64ABI::Windows;
#else
return X64ABI::SystemV;
#endif
}
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -64,6 +64,7 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs,
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos) void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos)
{ {
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra)); build.lea(rArg2, luauRegValue(ra));
build.lea(rArg3, luauRegValue(rb)); build.lea(rArg3, luauRegValue(rb));
@ -83,6 +84,27 @@ void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition c
cond == Condition::NotLessEqual || cond == Condition::NotLess || cond == Condition::NotEqual ? Condition::Zero : Condition::NotZero, label); cond == Condition::NotLessEqual || cond == Condition::NotLess || cond == Condition::NotEqual ? Condition::Zero : Condition::NotZero, label);
} }
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos)
{
RegisterX64 node = rdx;
LUAU_ASSERT(tmp != node);
LUAU_ASSERT(table != node);
build.mov(node, qword[table + offsetof(Table, node)]);
// compute cached slot
build.mov(tmp, sCode);
build.movzx(dwordReg(tmp), byte[tmp + pcpos * sizeof(Instruction) + kOffsetOfInstructionC]);
build.and_(byteReg(tmp), byte[table + offsetof(Table, nodemask8)]);
// LuaNode* n = &h->node[slot];
build.shl(dwordReg(tmp), kLuaNodeSizeLog2);
build.add(node, tmp);
return node;
}
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label) void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label)
{ {
LUAU_ASSERT(numi.size == SizeX64::dword); LUAU_ASSERT(numi.size == SizeX64::dword);
@ -105,7 +127,7 @@ void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, in
{ {
emitSetSavedPc(build, pcpos + 1); emitSetSavedPc(build, pcpos + 1);
if (getCurrentX64ABI() == X64ABI::Windows) if (build.abi == ABIX64::Windows)
build.mov(sArg5, tm); build.mov(sArg5, tm);
else else
build.mov(rArg5, tm); build.mov(rArg5, tm);
@ -168,16 +190,17 @@ void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p
emitUpdateBase(build); emitUpdateBase(build);
} }
void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip) // works for luaC_barriertable, luaC_barrierf
static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip, int contextOffset)
{ {
LUAU_ASSERT(tmp != table); LUAU_ASSERT(tmp != object);
// iscollectable(ra) // iscollectable(ra)
build.cmp(luauRegTag(ra), LUA_TSTRING); build.cmp(luauRegTag(ra), LUA_TSTRING);
build.jcc(Condition::Less, skip); build.jcc(Condition::Less, skip);
// isblack(obj2gco(h)) // isblack(obj2gco(o))
build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT)); build.test(byte[object + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(Condition::Zero, skip); build.jcc(Condition::Zero, skip);
// iswhite(gcvalue(ra)) // iswhite(gcvalue(ra))
@ -185,11 +208,52 @@ void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 ta
build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT)); build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT));
build.jcc(Condition::Zero, skip); build.jcc(Condition::Zero, skip);
LUAU_ASSERT(table != rArg3); LUAU_ASSERT(object != rArg3);
build.mov(rArg3, tmp); build.mov(rArg3, tmp);
build.mov(rArg2, table); build.mov(rArg2, object);
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barriertable)]); build.call(qword[rNativeContext + contextOffset]);
}
void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip)
{
callBarrierImpl(build, tmp, table, ra, skip, offsetof(NativeContext, luaC_barriertable));
}
void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip)
{
callBarrierImpl(build, tmp, object, ra, skip, offsetof(NativeContext, luaC_barrierf));
}
void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip)
{
// isblack(obj2gco(t))
build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(Condition::Zero, skip);
// Argument setup re-ordered to avoid conflicts with table register
if (table != rArg2)
build.mov(rArg2, table);
build.lea(rArg3, qword[rArg2 + offsetof(Table, gclist)]);
build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]);
}
void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip)
{
build.mov(rax, qword[rState + offsetof(lua_State, global)]);
build.mov(rdx, qword[rax + offsetof(global_State, totalbytes)]);
build.cmp(rdx, qword[rax + offsetof(global_State, GCthreshold)]);
build.jcc(Condition::Below, skip);
if (savepc)
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(rArg2, 1);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]);
emitUpdateBase(build);
} }
void emitExit(AssemblyBuilderX64& build, bool continueInVm) void emitExit(AssemblyBuilderX64& build, bool continueInVm)
@ -231,7 +295,7 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos)
// Call interrupt // Call interrupt
// TODO: This code should move to the end of the function, or even be outlined so that it can be shared by multiple interruptible instructions // TODO: This code should move to the end of the function, or even be outlined so that it can be shared by multiple interruptible instructions
build.mov(rArg1, rState); build.mov(rArg1, rState);
build.mov(rArg2d, -1); build.mov(dwordReg(rArg2), -1); // function accepts 'int' here and using qword reg would've forced 8 byte constant here
build.call(r8); build.call(r8);
// Check if we need to exit // Check if we need to exit

View file

@ -35,11 +35,11 @@ constexpr RegisterX64 rConstants = r12; // TValue* k
constexpr OperandX64 sClosure = qword[rbp + 0]; // Closure* cl constexpr OperandX64 sClosure = qword[rbp + 0]; // Closure* cl
constexpr OperandX64 sCode = qword[rbp + 8]; // Instruction* code constexpr OperandX64 sCode = qword[rbp + 8]; // Instruction* code
// TODO: These should be replaced with a portable call function that checks the ABI at runtime and reorders moves accordingly to avoid conflicts
#if defined(_WIN32) #if defined(_WIN32)
constexpr RegisterX64 rArg1 = rcx; constexpr RegisterX64 rArg1 = rcx;
constexpr RegisterX64 rArg2 = rdx; constexpr RegisterX64 rArg2 = rdx;
constexpr RegisterX64 rArg2d = edx;
constexpr RegisterX64 rArg3 = r8; constexpr RegisterX64 rArg3 = r8;
constexpr RegisterX64 rArg4 = r9; constexpr RegisterX64 rArg4 = r9;
constexpr RegisterX64 rArg5 = noreg; constexpr RegisterX64 rArg5 = noreg;
@ -51,7 +51,6 @@ constexpr OperandX64 sArg6 = qword[rsp + 40];
constexpr RegisterX64 rArg1 = rdi; constexpr RegisterX64 rArg1 = rdi;
constexpr RegisterX64 rArg2 = rsi; constexpr RegisterX64 rArg2 = rsi;
constexpr RegisterX64 rArg2d = esi;
constexpr RegisterX64 rArg3 = rdx; constexpr RegisterX64 rArg3 = rdx;
constexpr RegisterX64 rArg4 = rcx; constexpr RegisterX64 rArg4 = rcx;
constexpr RegisterX64 rArg5 = r8; constexpr RegisterX64 rArg5 = r8;
@ -62,6 +61,11 @@ constexpr OperandX64 sArg6 = noreg;
#endif #endif
constexpr unsigned kTValueSizeLog2 = 4; constexpr unsigned kTValueSizeLog2 = 4;
constexpr unsigned kLuaNodeSizeLog2 = 5;
constexpr unsigned kLuaNodeTagMask = 0xf;
constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field
constexpr unsigned kOffsetOfInstructionC = 3;
inline OperandX64 luauReg(int ri) inline OperandX64 luauReg(int ri)
{ {
@ -88,11 +92,26 @@ inline OperandX64 luauConstant(int ki)
return xmmword[rConstants + ki * sizeof(TValue)]; return xmmword[rConstants + ki * sizeof(TValue)];
} }
inline OperandX64 luauConstantTag(int ki)
{
return dword[rConstants + ki * sizeof(TValue) + offsetof(TValue, tt)];
}
inline OperandX64 luauConstantValue(int ki) inline OperandX64 luauConstantValue(int ki)
{ {
return qword[rConstants + ki * sizeof(TValue) + offsetof(TValue, value)]; return qword[rConstants + ki * sizeof(TValue) + offsetof(TValue, value)];
} }
inline OperandX64 luauNodeKeyValue(RegisterX64 node)
{
return qword[node + offsetof(LuaNode, key) + offsetof(TKey, value)];
}
inline OperandX64 luauNodeValue(RegisterX64 node)
{
return xmmword[node + offsetof(LuaNode, val)];
}
inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, OperandX64 op) inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, OperandX64 op)
{ {
LUAU_ASSERT(op.cat == CategoryX64::mem); LUAU_ASSERT(op.cat == CategoryX64::mem);
@ -101,6 +120,14 @@ inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, Opera
build.vmovups(luauReg(ri), tmp); build.vmovups(luauReg(ri), tmp);
} }
inline void setNodeValue(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 op, int ri)
{
LUAU_ASSERT(op.cat == CategoryX64::mem);
build.vmovups(tmp, luauReg(ri));
build.vmovups(op, tmp);
}
inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label) inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{ {
build.cmp(luauRegTag(ri), tag); build.cmp(luauRegTag(ri), tag);
@ -153,9 +180,37 @@ inline void jumpIfTableIsReadOnly(AssemblyBuilderX64& build, RegisterX64 table,
build.jcc(Condition::NotEqual, label); build.jcc(Condition::NotEqual, label);
} }
inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, lua_Type tag, Label& label)
{
tmp.size = SizeX64::dword;
build.mov(tmp, dword[node + offsetof(LuaNode, key) + kOffsetOfLuaNodeTag]);
build.and_(tmp, kLuaNodeTagMask);
build.cmp(tmp, tag);
build.jcc(Condition::NotEqual, label);
}
inline void jumpIfNodeValueTagIs(AssemblyBuilderX64& build, RegisterX64 node, lua_Type tag, Label& label)
{
build.cmp(dword[node + offsetof(LuaNode, val) + offsetof(TValue, tt)], tag);
build.jcc(Condition::Equal, label);
}
inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, Label& label)
{
jumpIfNodeKeyTagIsNot(build, tmp, node, LUA_TSTRING, label);
build.mov(tmp, expectedKey);
build.cmp(tmp, luauNodeKeyValue(node));
build.jcc(Condition::NotEqual, label);
jumpIfNodeValueTagIs(build, node, LUA_TNIL, label);
}
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label); void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label);
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos); void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos);
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos);
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label);
void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm); void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm);
@ -164,6 +219,9 @@ void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, i
void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos); void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos);
void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos); void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos);
void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip); void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip);
void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip);
void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip);
void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip);
void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitExit(AssemblyBuilderX64& build, bool continueInVm);
void emitUpdateBase(AssemblyBuilderX64& build); void emitUpdateBase(AssemblyBuilderX64& build);

View file

@ -11,19 +11,22 @@
#include "lobject.h" #include "lobject.h"
#include "ltm.h" #include "ltm.h"
// TODO: all uses of luauRegValue and luauConstantValue need to be audited; some need to be changed to luauReg/ConstantAddress (doesn't exist yet)
// (the problem with existing use is that it includes additional offsetof(TValue, value) which happens to be 0 but isn't guaranteed to be)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
{ {
void emitInstLoadNil(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc) void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
build.mov(luauRegTag(ra), LUA_TNIL); build.mov(luauRegTag(ra), LUA_TNIL);
} }
void emitInstLoadB(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr) void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
@ -34,7 +37,7 @@ void emitInstLoadB(AssemblyBuilderX64& build, NativeState& data, const Instructi
build.jmp(labelarr[pcpos + target + 1]); build.jmp(labelarr[pcpos + target + 1]);
} }
void emitInstLoadN(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc) void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
@ -43,7 +46,7 @@ void emitInstLoadN(AssemblyBuilderX64& build, NativeState& data, const Instructi
build.mov(luauRegTag(ra), LUA_TNUMBER); build.mov(luauRegTag(ra), LUA_TNUMBER);
} }
void emitInstLoadK(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k) void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
@ -51,7 +54,16 @@ void emitInstLoadK(AssemblyBuilderX64& build, NativeState& data, const Instructi
build.vmovups(luauReg(ra), xmm0); build.vmovups(luauReg(ra), xmm0);
} }
void emitInstMove(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc) void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
build.vmovups(xmm0, luauConstant(aux));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc); int rb = LUAU_INSN_B(*pc);
@ -60,19 +72,19 @@ void emitInstMove(AssemblyBuilderX64& build, NativeState& data, const Instructio
build.vmovups(luauReg(ra), xmm0); build.vmovups(luauReg(ra), xmm0);
} }
void emitInstJump(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr) void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
build.jmp(labelarr[pcpos + LUAU_INSN_D(*pc) + 1]); build.jmp(labelarr[pcpos + LUAU_INSN_D(*pc) + 1]);
} }
void emitInstJumpBack(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr) void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
emitInterrupt(build, pcpos); emitInterrupt(build, pcpos);
build.jmp(labelarr[pcpos + LUAU_INSN_D(*pc) + 1]); build.jmp(labelarr[pcpos + LUAU_INSN_D(*pc) + 1]);
} }
void emitInstJumpIf(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, bool not_) void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
@ -85,14 +97,13 @@ void emitInstJumpIf(AssemblyBuilderX64& build, NativeState& data, const Instruct
jumpIfTruthy(build, ra, target, exit); jumpIfTruthy(build, ra, target, exit);
} }
void emitInstJumpIfEq(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, bool not_) void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = pc[1]; int rb = pc[1];
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1];
Label& exit = labelarr[pcpos + 2]; Label& exit = labelarr[pcpos + 2];
Label any;
build.mov(eax, luauRegTag(ra)); build.mov(eax, luauRegTag(ra));
build.cmp(eax, luauRegTag(rb)); build.cmp(eax, luauRegTag(rb));
@ -100,47 +111,50 @@ void emitInstJumpIfEq(AssemblyBuilderX64& build, NativeState& data, const Instru
// fast-path: number // fast-path: number
build.cmp(eax, LUA_TNUMBER); build.cmp(eax, LUA_TNUMBER);
build.jcc(Condition::NotEqual, any); build.jcc(Condition::NotEqual, fallback);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), Condition::NotEqual, not_ ? target : exit); jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), Condition::NotEqual, not_ ? target : exit);
build.jmp(not_ ? exit : target);
// slow-path if (!not_)
// TODO: move to the end of the function build.jmp(target);
build.setLabel(any);
jumpOnAnyCmpFallback(build, ra, rb, not_ ? Condition::NotEqual : Condition::Equal, target, pcpos);
} }
void emitInstJumpIfCond(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, Condition cond) void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_)
{
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1];
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? Condition::NotEqual : Condition::Equal, target, pcpos);
}
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond, Label& fallback)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = pc[1]; int rb = pc[1];
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1];
Label& exit = labelarr[pcpos + 2];
Label any;
// fast-path: number // fast-path: number
jumpIfTagIsNot(build, ra, LUA_TNUMBER, any); jumpIfTagIsNot(build, ra, LUA_TNUMBER, fallback);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, any); jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), cond, target); jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), cond, target);
build.jmp(exit);
// slow-path
// TODO: move to the end of the function
build.setLabel(any);
jumpOnAnyCmpFallback(build, ra, rb, cond, target, pcpos);
} }
void emitInstJumpX(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr) void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond)
{
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1];
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], cond, target, pcpos);
}
void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
emitInterrupt(build, pcpos); emitInterrupt(build, pcpos);
build.jmp(labelarr[pcpos + LUAU_INSN_E(*pc) + 1]); build.jmp(labelarr[pcpos + LUAU_INSN_E(*pc) + 1]);
} }
void emitInstJumpxEqNil(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr) void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
bool not_ = (pc[1] & 0x80000000) != 0; bool not_ = (pc[1] & 0x80000000) != 0;
@ -151,7 +165,7 @@ void emitInstJumpxEqNil(AssemblyBuilderX64& build, NativeState& data, const Inst
build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target); build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target);
} }
void emitInstJumpxEqB(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr) void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1]; uint32_t aux = pc[1];
@ -166,7 +180,7 @@ void emitInstJumpxEqB(AssemblyBuilderX64& build, NativeState& data, const Instru
build.jcc((aux & 0x1) ^ not_ ? Condition::NotZero : Condition::Zero, target); build.jcc((aux & 0x1) ^ not_ ? Condition::NotZero : Condition::Zero, target);
} }
void emitInstJumpxEqN(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr) void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1]; uint32_t aux = pc[1];
@ -192,7 +206,7 @@ void emitInstJumpxEqN(AssemblyBuilderX64& build, NativeState& data, const Instru
} }
} }
void emitInstJumpxEqS(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr) void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1]; uint32_t aux = pc[1];
@ -208,14 +222,12 @@ void emitInstJumpxEqS(AssemblyBuilderX64& build, NativeState& data, const Instru
build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target); build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target);
} }
static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int rc, OperandX64 opc, int pcpos, TMS tm) static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int rc, OperandX64 opc, int pcpos, TMS tm, Label& fallback)
{ {
Label common, exit; jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, common);
if (rc != -1 && rc != rb) if (rc != -1 && rc != rb)
jumpIfTagIsNot(build, rc, LUA_TNUMBER, common); jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback);
// fast-path: number // fast-path: number
build.vmovsd(xmm0, luauRegValue(rb)); build.vmovsd(xmm0, luauRegValue(rb));
@ -254,81 +266,35 @@ static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int
if (ra != rb && ra != rc) if (ra != rb && ra != rc)
build.mov(luauRegTag(ra), LUA_TNUMBER); build.mov(luauRegTag(ra), LUA_TNUMBER);
build.jmp(exit);
// slow-path
// TODO: move to the end of the function
build.setLabel(common);
callArithHelper(build, ra, rb, opc, pcpos, tm);
build.setLabel(exit);
} }
void emitInstAdd(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback)
{ {
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, TM_ADD); emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, tm, fallback);
} }
void emitInstSub(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{ {
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, TM_SUB); callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, tm);
} }
void emitInstMul(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback)
{ {
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, TM_MUL); emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), pcpos, tm, fallback);
} }
void emitInstDiv(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{ {
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, TM_DIV); callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstantValue(LUAU_INSN_C(*pc)), pcpos, tm);
} }
void emitInstMod(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label& fallback)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, TM_MOD);
}
void emitInstPow(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, TM_POW);
}
void emitInstAddK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), pcpos, TM_ADD);
}
void emitInstSubK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), pcpos, TM_SUB);
}
void emitInstMulK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), pcpos, TM_MUL);
}
void emitInstDivK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), pcpos, TM_DIV);
}
void emitInstModK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), pcpos, TM_MOD);
}
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc); int rb = LUAU_INSN_B(*pc);
double kv = nvalue(&k[LUAU_INSN_C(*pc)]); double kv = nvalue(&k[LUAU_INSN_C(*pc)]);
Label common, exit; jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, common);
// fast-path: number // fast-path: number
build.vmovsd(xmm0, luauRegValue(rb)); build.vmovsd(xmm0, luauRegValue(rb));
@ -357,15 +323,6 @@ void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue
if (ra != rb) if (ra != rb)
build.mov(luauRegTag(ra), LUA_TNUMBER); build.mov(luauRegTag(ra), LUA_TNUMBER);
build.jmp(exit);
// slow-path
// TODO: move to the end of the function
build.setLabel(common);
callArithHelper(build, ra, rb, luauConstantValue(LUAU_INSN_C(*pc)), pcpos, TM_POW);
build.setLabel(exit);
} }
void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc) void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc)
@ -388,14 +345,12 @@ void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc)
build.mov(luauRegTag(ra), LUA_TBOOLEAN); build.mov(luauRegTag(ra), LUA_TBOOLEAN);
} }
void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc); int rb = LUAU_INSN_B(*pc);
Label any, exit; jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, any);
// fast-path: number // fast-path: number
build.vxorpd(xmm0, xmm0, xmm0); build.vxorpd(xmm0, xmm0, xmm0);
@ -404,29 +359,23 @@ void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
if (ra != rb) if (ra != rb)
build.mov(luauRegTag(ra), LUA_TNUMBER); build.mov(luauRegTag(ra), LUA_TNUMBER);
build.jmp(exit);
// slow-path
// TODO: move to the end of the function
build.setLabel(any);
callArithHelper(build, ra, rb, luauRegValue(rb), pcpos, TM_UNM);
build.setLabel(exit);
} }
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_B(*pc)), pcpos, TM_UNM);
}
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc); int rb = LUAU_INSN_B(*pc);
Label any, exit; jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, rb, LUA_TTABLE, any);
// fast-path: table without __len // fast-path: table without __len
build.mov(rArg1, luauRegValue(rb)); build.mov(rArg1, luauRegValue(rb));
jumpIfMetatablePresent(build, rArg1, any); jumpIfMetatablePresent(build, rArg1, fallback);
// First argument (Table*) is already in rArg1 // First argument (Table*) is already in rArg1
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]);
@ -434,14 +383,154 @@ void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
build.vcvtsi2sd(xmm0, xmm0, eax); build.vcvtsi2sd(xmm0, xmm0, eax);
build.vmovsd(luauRegValue(ra), xmm0); build.vmovsd(luauRegValue(ra), xmm0);
build.mov(luauRegTag(ra), LUA_TNUMBER); build.mov(luauRegTag(ra), LUA_TNUMBER);
build.jmp(exit); }
// slow-path void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
// TODO: move to the end of the function {
build.setLabel(any); callLengthHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), pcpos);
callLengthHelper(build, ra, rb, pcpos); }
build.setLabel(exit); void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
int b = LUAU_INSN_B(*pc);
uint32_t aux = pc[1];
Label& exit = labelarr[pcpos + 2];
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(rArg2, aux);
build.mov(rArg3, b == 0 ? 0 : 1 << (b - 1));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]);
build.mov(luauRegValue(ra), rax);
build.mov(luauRegTag(ra), LUA_TTABLE);
callCheckGc(build, pcpos, /* savepc = */ false, exit);
}
void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
Label& exit = labelarr[pcpos + 1];
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(rArg2, luauConstantValue(LUAU_INSN_D(*pc)));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]);
build.mov(luauRegValue(ra), rax);
build.mov(luauRegTag(ra), LUA_TTABLE);
callCheckGc(build, pcpos, /* savepc= */ false, exit);
}
void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int c = LUAU_INSN_C(*pc) - 1;
uint32_t index = pc[1];
Label& exit = labelarr[pcpos + 2];
OperandX64 last = index + c - 1;
// Using non-volatile 'rbx' for dynamic 'c' value (for LUA_MULTRET) to skip later recomputation
// We also keep 'c' scaled by sizeof(TValue) here as it helps in the loop below
RegisterX64 cscaled = rbx;
if (c == LUA_MULTRET)
{
RegisterX64 tmp = rax;
// c = L->top - rb
build.mov(cscaled, qword[rState + offsetof(lua_State, top)]);
build.lea(tmp, luauRegValue(rb));
build.sub(cscaled, tmp); // Using byte difference
// L->top = L->ci->top
build.mov(tmp, qword[rState + offsetof(lua_State, ci)]);
build.mov(tmp, qword[tmp + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], tmp);
// last = index + c - 1;
last = edx;
build.mov(last, dwordReg(cscaled));
build.shr(last, kTValueSizeLog2);
build.add(last, index - 1);
}
Label skipResize;
RegisterX64 table = rax;
build.mov(table, luauRegValue(ra));
// Resize if h->sizearray < last
build.cmp(dword[table + offsetof(Table, sizearray)], last);
build.jcc(Condition::NotBelow, skipResize);
// Argument setup reordered to avoid conflicts
LUAU_ASSERT(rArg3 != table);
build.mov(dwordReg(rArg3), last);
build.mov(rArg2, table);
build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_resizearray)]);
build.mov(table, luauRegValue(ra)); // Reload cloberred register value
build.setLabel(skipResize);
RegisterX64 arrayDst = rdx;
RegisterX64 offset = rcx;
build.mov(arrayDst, qword[table + offsetof(Table, array)]);
const int kUnrollSetListLimit = 4;
if (c != LUA_MULTRET && c <= kUnrollSetListLimit)
{
for (int i = 0; i < c; ++i)
{
// setobj2t(L, &array[index + i - 1], rb + i);
build.vmovups(xmm0, luauRegValue(rb + i));
build.vmovups(xmmword[arrayDst + (index + i - 1) * sizeof(TValue)], xmm0);
}
}
else
{
LUAU_ASSERT(c != 0);
build.xor_(offset, offset);
if (index != 1)
build.add(arrayDst, (index - 1) * sizeof(TValue));
Label repeatLoop, endLoop;
OperandX64 limit = c == LUA_MULTRET ? cscaled : OperandX64(c * sizeof(TValue));
// If c is static, we will always do at least one iteration
if (c == LUA_MULTRET)
{
build.cmp(offset, limit);
build.jcc(Condition::NotBelow, endLoop);
}
build.setLabel(repeatLoop);
// setobj2t(L, &array[index + i - 1], rb + i);
build.vmovups(xmm0, xmmword[offset + rBase + rb * sizeof(TValue)]); // luauReg(rb) unwrapped to add offset
build.vmovups(xmmword[offset + arrayDst], xmm0);
build.add(offset, sizeof(TValue));
build.cmp(offset, limit);
build.jcc(Condition::Below, repeatLoop);
build.setLabel(endLoop);
}
callBarrierTableFast(build, table, exit);
} }
void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
@ -468,6 +557,45 @@ void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
build.vmovups(luauReg(ra), xmm0); build.vmovups(luauReg(ra), xmm0);
} }
void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
int up = LUAU_INSN_B(*pc);
RegisterX64 upval = rax;
RegisterX64 tmp = rcx;
build.mov(tmp, sClosure);
build.mov(upval, qword[tmp + offsetof(Closure, l.uprefs) + sizeof(TValue) * up + offsetof(TValue, value.gc)]);
build.mov(tmp, qword[upval + offsetof(UpVal, v)]);
build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[tmp], xmm0);
callBarrierObject(build, tmp, upval, ra, labelarr[pcpos + 1]);
}
void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
Label& skip = labelarr[pcpos + 1];
// L->openupval != 0
build.mov(rax, qword[rState + offsetof(lua_State, openupval)]);
build.test(rax, rax);
build.jcc(Condition::Zero, skip);
// ra <= L->openuval->v
build.lea(rcx, qword[rBase + ra * sizeof(TValue)]);
build.cmp(rcx, qword[rax + offsetof(UpVal, v)]);
build.jcc(Condition::Above, skip);
build.mov(rArg2, rcx);
build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]);
}
static void emitInstFastCallN( static void emitInstFastCallN(
AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs, int pcpos, Label* labelarr) AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs, int pcpos, Label* labelarr)
{ {
@ -512,14 +640,14 @@ static void emitInstFastCallN(
} }
// TODO: we can skip saving pc for some well-behaved builtins which we didn't inline // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline
emitSetSavedPc(build, pcpos); // uses rax/rdx emitSetSavedPc(build, pcpos + 1); // uses rax/rdx
build.mov(rax, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); build.mov(rax, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]);
// 5th parameter (args) is left unset for LOP_FASTCALL1 // 5th parameter (args) is left unset for LOP_FASTCALL1
if (args.cat == CategoryX64::mem) if (args.cat == CategoryX64::mem)
{ {
if (getCurrentX64ABI() == X64ABI::Windows) if (build.abi == ABIX64::Windows)
{ {
build.lea(rcx, args); build.lea(rcx, args);
build.mov(sArg5, rcx); build.mov(sArg5, rcx);
@ -539,14 +667,14 @@ static void emitInstFastCallN(
build.sub(rcx, rdx); build.sub(rcx, rdx);
build.shr(rcx, kTValueSizeLog2); build.shr(rcx, kTValueSizeLog2);
if (getCurrentX64ABI() == X64ABI::Windows) if (build.abi == ABIX64::Windows)
build.mov(sArg6, rcx); build.mov(sArg6, rcx);
else else
build.mov(rArg6, rcx); build.mov(rArg6, rcx);
} }
else else
{ {
if (getCurrentX64ABI() == X64ABI::Windows) if (build.abi == ABIX64::Windows)
build.mov(sArg6, nparams); build.mov(sArg6, nparams);
else else
build.mov(rArg6, nparams); build.mov(rArg6, nparams);
@ -594,7 +722,7 @@ void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcp
emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegValue(pc[1]), pcpos, labelarr); emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegValue(pc[1]), pcpos, labelarr);
} }
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr) void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantValue(pc[1]), pcpos, labelarr); emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantValue(pc[1]), pcpos, labelarr);
} }
@ -762,14 +890,12 @@ void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc)
emitInstOrX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstant(LUAU_INSN_C(*pc))); emitInstOrX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstant(LUAU_INSN_C(*pc)));
} }
void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc); int rb = LUAU_INSN_B(*pc);
int c = LUAU_INSN_C(*pc); int c = LUAU_INSN_C(*pc);
Label fallback, exit;
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback); jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx; RegisterX64 table = rcx;
@ -783,27 +909,21 @@ void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcp
build.mov(rax, qword[table + offsetof(Table, array)]); build.mov(rax, qword[table + offsetof(Table, array)]);
setLuauReg(build, xmm0, ra, xmmword[rax + c * sizeof(TValue)]); setLuauReg(build, xmm0, ra, xmmword[rax + c * sizeof(TValue)]);
build.jmp(exit);
// slow-path
// TODO: move to the end of the function
build.setLabel(fallback);
TValue n;
setnvalue(&n, c + 1);
callGetTable(build, rb, build.bytes(&n, sizeof(n)), ra, pcpos);
build.setLabel(exit);
} }
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
TValue n;
setnvalue(&n, LUAU_INSN_C(*pc) + 1);
callGetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc), pcpos);
}
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc); int rb = LUAU_INSN_B(*pc);
int c = LUAU_INSN_C(*pc); int c = LUAU_INSN_C(*pc);
Label fallback, exit;
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback); jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx; RegisterX64 table = rcx;
@ -821,27 +941,22 @@ void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcp
build.vmovups(xmm0, luauReg(ra)); build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[rax + c * sizeof(TValue)], xmm0); build.vmovups(xmmword[rax + c * sizeof(TValue)], xmm0);
callBarrierTable(build, rax, table, ra, exit); callBarrierTable(build, rax, table, ra, labelarr[pcpos + 1]);
build.jmp(exit);
// slow-path
// TODO: move to the end of the function
build.setLabel(fallback);
TValue n;
setnvalue(&n, c + 1);
callSetTable(build, rb, build.bytes(&n, sizeof(n)), ra, pcpos);
build.setLabel(exit);
} }
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
TValue n;
setnvalue(&n, LUAU_INSN_C(*pc) + 1);
callSetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc), pcpos);
}
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc); int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc); int rc = LUAU_INSN_C(*pc);
Label fallback, exit;
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback); jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback); jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback);
@ -864,26 +979,19 @@ void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
build.mov(rdx, qword[table + offsetof(Table, array)]); build.mov(rdx, qword[table + offsetof(Table, array)]);
build.shl(eax, kTValueSizeLog2); build.shl(eax, kTValueSizeLog2);
setLuauReg(build, xmm0, ra, xmmword[rdx + rax]); setLuauReg(build, xmm0, ra, xmmword[rdx + rax]);
build.jmp(exit);
build.setLabel(fallback);
// slow-path
// TODO: move to the end of the function
callGetTable(build, rb, luauRegValue(rc), ra, pcpos);
build.setLabel(exit);
} }
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos) void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
callGetTable(build, LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos);
}
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc); int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc); int rc = LUAU_INSN_C(*pc);
Label fallback, exit;
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback); jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback); jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback);
@ -909,16 +1017,157 @@ void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
build.vmovups(xmm0, luauReg(ra)); build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[rdx + rax], xmm0); build.vmovups(xmmword[rdx + rax], xmm0);
callBarrierTable(build, rdx, table, ra, exit); callBarrierTable(build, rdx, table, ra, labelarr[pcpos + 1]);
build.jmp(exit); }
build.setLabel(fallback); void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
callSetTable(build, LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos);
}
// slow-path void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
// TODO: move to the end of the function {
callSetTable(build, rb, luauRegValue(rc), ra, pcpos); int ra = LUAU_INSN_A(*pc);
int k = LUAU_INSN_D(*pc);
build.setLabel(exit); jumpIfUnsafeEnv(build, rax, fallback);
// note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback
// this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime
build.cmp(luauConstantTag(k), LUA_TNIL);
build.jcc(Condition::Equal, fallback);
build.vmovups(xmm0, luauConstant(k));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstGetImportFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
emitSetSavedPc(build, pcpos + 1);
build.mov(rax, sClosure);
// luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false)
build.mov(rArg1, rState);
build.mov(rArg2, qword[rax + offsetof(Closure, env)]);
build.mov(rArg3, rConstants);
build.mov(rArg4, aux);
if (build.abi == ABIX64::Windows)
build.mov(sArg5, 0);
else
build.xor_(rArg5, rArg5);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_getimport)]);
emitUpdateBase(build);
// setobj2s(L, ra, L->top - 1)
build.mov(rax, qword[rState + offsetof(lua_State, top)]);
build.sub(rax, sizeof(TValue));
build.vmovups(xmm0, xmmword[rax]);
build.vmovups(luauReg(ra), xmm0);
// L->top--
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
uint32_t aux = pc[1];
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
RegisterX64 node = getTableNodeAtCachedSlot(build, rax, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
setLuauReg(build, xmm0, ra, luauNodeValue(node));
}
void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
uint32_t aux = pc[1];
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
// fast-path: set value at the expected slot
RegisterX64 node = getTableNodeAtCachedSlot(build, rax, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
jumpIfTableIsReadOnly(build, table, fallback);
setNodeValue(build, xmm0, luauNodeValue(node), ra);
callBarrierTable(build, rax, table, ra, labelarr[pcpos + 2]);
}
void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
RegisterX64 table = rcx;
build.mov(rax, sClosure);
build.mov(table, qword[rax + offsetof(Closure, env)]);
RegisterX64 node = getTableNodeAtCachedSlot(build, rax, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
setLuauReg(build, xmm0, ra, luauNodeValue(node));
}
void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
RegisterX64 table = rcx;
build.mov(rax, sClosure);
build.mov(table, qword[rax + offsetof(Closure, env)]);
RegisterX64 node = getTableNodeAtCachedSlot(build, rax, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
jumpIfTableIsReadOnly(build, table, fallback);
setNodeValue(build, xmm0, luauNodeValue(node), ra);
callBarrierTable(build, rax, table, ra, labelarr[pcpos + 2]);
}
void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
emitSetSavedPc(build, pcpos + 1);
// luaV_concat(L, c - b + 1, c)
build.mov(rArg1, rState);
build.mov(rArg2, rc - rb + 1);
build.mov(rArg3, rc);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]);
emitUpdateBase(build);
// setobj2s(L, ra, base + b)
build.vmovups(xmm0, luauReg(rb));
build.vmovups(luauReg(ra), xmm0);
callCheckGc(build, pcpos, /* savepc= */ false, labelarr[pcpos + 1]);
} }
} // namespace CodeGen } // namespace CodeGen

View file

@ -3,6 +3,8 @@
#include <stdint.h> #include <stdint.h>
#include "ltm.h"
typedef uint32_t Instruction; typedef uint32_t Instruction;
typedef struct lua_TValue TValue; typedef struct lua_TValue TValue;
@ -16,40 +18,43 @@ enum class Condition;
struct Label; struct Label;
struct NativeState; struct NativeState;
void emitInstLoadNil(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc); void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadB(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr); void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstLoadN(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc); void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadK(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k); void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMove(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc); void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstJump(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr); void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstJumpBack(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIf(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, bool not_); void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIfEq(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, bool not_); void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfCond(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, Condition cond); void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback);
void emitInstJumpX(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpxEqNil(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond, Label& fallback);
void emitInstJumpxEqB(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond);
void emitInstJumpxEqN(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqS(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstAdd(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstSub(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstMul(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstDiv(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback);
void emitInstMod(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstPow(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback);
void emitInstAddK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos); void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstSubK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos); void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label& fallback);
void emitInstMulK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstDivK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstModK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc); void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr); void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
@ -57,10 +62,21 @@ void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc); void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstGetImportFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -9,35 +9,12 @@
#include "lbuiltins.h" #include "lbuiltins.h"
#include "lgc.h" #include "lgc.h"
#include "ltable.h" #include "ltable.h"
#include "lfunc.h"
#include "lvm.h" #include "lvm.h"
#include <math.h> #include <math.h>
#define CODEGEN_SET_FALLBACK(op, flags) data.context.fallback[op] = {execute_##op, flags} #define CODEGEN_SET_FALLBACK(op, flags) data.context.fallback[op] = {execute_##op, flags}
#define CODEGEN_SET_NAME(op) data.names[op] = #op
// Similar to a dispatch table in lvmexecute.cpp
#define CODEGEN_SET_NAMES() \
CODEGEN_SET_NAME(LOP_NOP), CODEGEN_SET_NAME(LOP_BREAK), CODEGEN_SET_NAME(LOP_LOADNIL), CODEGEN_SET_NAME(LOP_LOADB), CODEGEN_SET_NAME(LOP_LOADN), \
CODEGEN_SET_NAME(LOP_LOADK), CODEGEN_SET_NAME(LOP_MOVE), CODEGEN_SET_NAME(LOP_GETGLOBAL), CODEGEN_SET_NAME(LOP_SETGLOBAL), \
CODEGEN_SET_NAME(LOP_GETUPVAL), CODEGEN_SET_NAME(LOP_SETUPVAL), CODEGEN_SET_NAME(LOP_CLOSEUPVALS), CODEGEN_SET_NAME(LOP_GETIMPORT), \
CODEGEN_SET_NAME(LOP_GETTABLE), CODEGEN_SET_NAME(LOP_SETTABLE), CODEGEN_SET_NAME(LOP_GETTABLEKS), CODEGEN_SET_NAME(LOP_SETTABLEKS), \
CODEGEN_SET_NAME(LOP_GETTABLEN), CODEGEN_SET_NAME(LOP_SETTABLEN), CODEGEN_SET_NAME(LOP_NEWCLOSURE), CODEGEN_SET_NAME(LOP_NAMECALL), \
CODEGEN_SET_NAME(LOP_CALL), CODEGEN_SET_NAME(LOP_RETURN), CODEGEN_SET_NAME(LOP_JUMP), CODEGEN_SET_NAME(LOP_JUMPBACK), \
CODEGEN_SET_NAME(LOP_JUMPIF), CODEGEN_SET_NAME(LOP_JUMPIFNOT), CODEGEN_SET_NAME(LOP_JUMPIFEQ), CODEGEN_SET_NAME(LOP_JUMPIFLE), \
CODEGEN_SET_NAME(LOP_JUMPIFLT), CODEGEN_SET_NAME(LOP_JUMPIFNOTEQ), CODEGEN_SET_NAME(LOP_JUMPIFNOTLE), CODEGEN_SET_NAME(LOP_JUMPIFNOTLT), \
CODEGEN_SET_NAME(LOP_ADD), CODEGEN_SET_NAME(LOP_SUB), CODEGEN_SET_NAME(LOP_MUL), CODEGEN_SET_NAME(LOP_DIV), CODEGEN_SET_NAME(LOP_MOD), \
CODEGEN_SET_NAME(LOP_POW), CODEGEN_SET_NAME(LOP_ADDK), CODEGEN_SET_NAME(LOP_SUBK), CODEGEN_SET_NAME(LOP_MULK), CODEGEN_SET_NAME(LOP_DIVK), \
CODEGEN_SET_NAME(LOP_MODK), CODEGEN_SET_NAME(LOP_POWK), CODEGEN_SET_NAME(LOP_AND), CODEGEN_SET_NAME(LOP_OR), CODEGEN_SET_NAME(LOP_ANDK), \
CODEGEN_SET_NAME(LOP_ORK), CODEGEN_SET_NAME(LOP_CONCAT), CODEGEN_SET_NAME(LOP_NOT), CODEGEN_SET_NAME(LOP_MINUS), \
CODEGEN_SET_NAME(LOP_LENGTH), CODEGEN_SET_NAME(LOP_NEWTABLE), CODEGEN_SET_NAME(LOP_DUPTABLE), CODEGEN_SET_NAME(LOP_SETLIST), \
CODEGEN_SET_NAME(LOP_FORNPREP), CODEGEN_SET_NAME(LOP_FORNLOOP), CODEGEN_SET_NAME(LOP_FORGLOOP), CODEGEN_SET_NAME(LOP_FORGPREP_INEXT), \
CODEGEN_SET_NAME(LOP_DEP_FORGLOOP_INEXT), CODEGEN_SET_NAME(LOP_FORGPREP_NEXT), CODEGEN_SET_NAME(LOP_DEP_FORGLOOP_NEXT), \
CODEGEN_SET_NAME(LOP_GETVARARGS), CODEGEN_SET_NAME(LOP_DUPCLOSURE), CODEGEN_SET_NAME(LOP_PREPVARARGS), CODEGEN_SET_NAME(LOP_LOADKX), \
CODEGEN_SET_NAME(LOP_JUMPX), CODEGEN_SET_NAME(LOP_FASTCALL), CODEGEN_SET_NAME(LOP_COVERAGE), CODEGEN_SET_NAME(LOP_CAPTURE), \
CODEGEN_SET_NAME(LOP_DEP_JUMPIFEQK), CODEGEN_SET_NAME(LOP_DEP_JUMPIFNOTEQK), CODEGEN_SET_NAME(LOP_FASTCALL1), \
CODEGEN_SET_NAME(LOP_FASTCALL2), CODEGEN_SET_NAME(LOP_FASTCALL2K), CODEGEN_SET_NAME(LOP_FORGPREP), CODEGEN_SET_NAME(LOP_JUMPXEQKNIL), \
CODEGEN_SET_NAME(LOP_JUMPXEQKB), CODEGEN_SET_NAME(LOP_JUMPXEQKN), CODEGEN_SET_NAME(LOP_JUMPXEQKS)
static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
@ -61,21 +38,11 @@ NativeState::~NativeState() = default;
void initFallbackTable(NativeState& data) void initFallbackTable(NativeState& data)
{ {
CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0); // TODO: lvmexecute_split.py could be taught to generate a subset of instructions we actually need
CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0); CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_SETUPVAL, 0);
CODEGEN_SET_FALLBACK(LOP_CLOSEUPVALS, 0);
CODEGEN_SET_FALLBACK(LOP_GETIMPORT, 0);
CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0);
CODEGEN_SET_FALLBACK(LOP_SETTABLEKS, 0);
CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0); CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0);
CODEGEN_SET_FALLBACK(LOP_CALL, kFallbackUpdateCi | kFallbackCheckInterrupt); CODEGEN_SET_FALLBACK(LOP_CALL, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_RETURN, kFallbackUpdateCi | kFallbackCheckInterrupt); CODEGEN_SET_FALLBACK(LOP_RETURN, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_CONCAT, 0);
CODEGEN_SET_FALLBACK(LOP_NEWTABLE, 0);
CODEGEN_SET_FALLBACK(LOP_DUPTABLE, 0);
CODEGEN_SET_FALLBACK(LOP_SETLIST, kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc); CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_FORGLOOP, kFallbackUpdatePc | kFallbackCheckInterrupt); CODEGEN_SET_FALLBACK(LOP_FORGLOOP, kFallbackUpdatePc | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP_INEXT, kFallbackUpdatePc); CODEGEN_SET_FALLBACK(LOP_FORGPREP_INEXT, kFallbackUpdatePc);
@ -83,9 +50,14 @@ void initFallbackTable(NativeState& data)
CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0); CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0);
CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0); CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_PREPVARARGS, 0); CODEGEN_SET_FALLBACK(LOP_PREPVARARGS, 0);
CODEGEN_SET_FALLBACK(LOP_LOADKX, 0);
CODEGEN_SET_FALLBACK(LOP_COVERAGE, 0); CODEGEN_SET_FALLBACK(LOP_COVERAGE, 0);
CODEGEN_SET_FALLBACK(LOP_BREAK, 0); CODEGEN_SET_FALLBACK(LOP_BREAK, 0);
// Fallbacks that are called from partial implementation of an instruction
CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0);
CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0);
CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0);
CODEGEN_SET_FALLBACK(LOP_SETTABLEKS, 0);
} }
void initHelperFunctions(NativeState& data) void initHelperFunctions(NativeState& data)
@ -105,18 +77,23 @@ void initHelperFunctions(NativeState& data)
data.context.luaV_prepareFORN = luaV_prepareFORN; data.context.luaV_prepareFORN = luaV_prepareFORN;
data.context.luaV_gettable = luaV_gettable; data.context.luaV_gettable = luaV_gettable;
data.context.luaV_settable = luaV_settable; data.context.luaV_settable = luaV_settable;
data.context.luaV_getimport = luaV_getimport;
data.context.luaV_concat = luaV_concat;
data.context.luaH_getn = luaH_getn; data.context.luaH_getn = luaH_getn;
data.context.luaH_new = luaH_new;
data.context.luaH_clone = luaH_clone;
data.context.luaH_resizearray = luaH_resizearray;
data.context.luaC_barriertable = luaC_barriertable; data.context.luaC_barriertable = luaC_barriertable;
data.context.luaC_barrierf = luaC_barrierf;
data.context.luaC_barrierback = luaC_barrierback;
data.context.luaC_step = luaC_step;
data.context.luaF_close = luaF_close;
data.context.libm_pow = pow; data.context.libm_pow = pow;
} }
void initInstructionNames(NativeState& data)
{
CODEGEN_SET_NAMES();
}
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -61,10 +61,20 @@ struct NativeContext
void (*luaV_prepareFORN)(lua_State* L, StkId plimit, StkId pstep, StkId pinit) = nullptr; void (*luaV_prepareFORN)(lua_State* L, StkId plimit, StkId pstep, StkId pinit) = nullptr;
void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr;
void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr;
void (*luaV_getimport)(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) = nullptr;
void (*luaV_concat)(lua_State* L, int total, int last) = nullptr;
int (*luaH_getn)(Table* t) = nullptr; int (*luaH_getn)(Table* t) = nullptr;
Table* (*luaH_new)(lua_State* L, int narray, int lnhash) = nullptr;
Table* (*luaH_clone)(lua_State* L, Table* tt) = nullptr;
void (*luaH_resizearray)(lua_State* L, Table* t, int nasize) = nullptr;
void (*luaC_barriertable)(lua_State* L, Table* t, GCObject* v) = nullptr; void (*luaC_barriertable)(lua_State* L, Table* t, GCObject* v) = nullptr;
void (*luaC_barrierf)(lua_State* L, GCObject* o, GCObject* v) = nullptr;
void (*luaC_barrierback)(lua_State* L, GCObject* o, GCObject** gclist) = nullptr;
size_t (*luaC_step)(lua_State* L, bool assist) = nullptr;
void (*luaF_close)(lua_State* L, StkId level) = nullptr;
double (*libm_pow)(double, double) = nullptr; double (*libm_pow)(double, double) = nullptr;
}; };
@ -77,9 +87,6 @@ struct NativeState
CodeAllocator codeAllocator; CodeAllocator codeAllocator;
std::unique_ptr<UnwindBuilder> unwindBuilder; std::unique_ptr<UnwindBuilder> unwindBuilder;
// For annotations in assembly text generation
const char* names[LOP__COUNT] = {};
uint8_t* gateData = nullptr; uint8_t* gateData = nullptr;
size_t gateDataSize = 0; size_t gateDataSize = 0;
@ -88,7 +95,6 @@ struct NativeState
void initFallbackTable(NativeState& data); void initFallbackTable(NativeState& data);
void initHelperFunctions(NativeState& data); void initHelperFunctions(NativeState& data);
void initInstructionNames(NativeState& data);
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau

View file

@ -20,6 +20,10 @@
#define LUAU_DEBUGBREAK() __builtin_trap() #define LUAU_DEBUGBREAK() __builtin_trap()
#endif #endif
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define LUAU_BIG_ENDIAN
#endif
namespace Luau namespace Luau
{ {

View file

@ -117,6 +117,8 @@ public:
std::string dumpEverything() const; std::string dumpEverything() const;
std::string dumpSourceRemarks() const; std::string dumpSourceRemarks() const;
void annotateInstruction(std::string& result, uint32_t fid, uint32_t instid) const;
static uint32_t getImportId(int32_t id0); static uint32_t getImportId(int32_t id0);
static uint32_t getImportId(int32_t id0, int32_t id1); static uint32_t getImportId(int32_t id0, int32_t id1);
static uint32_t getImportId(int32_t id0, int32_t id1, int32_t id2); static uint32_t getImportId(int32_t id0, int32_t id1, int32_t id2);
@ -179,6 +181,7 @@ private:
std::string dump; std::string dump;
std::string dumpname; std::string dumpname;
std::vector<int> dumpinstoffs;
}; };
struct DebugLocal struct DebugLocal
@ -251,11 +254,13 @@ private:
std::vector<std::string> dumpSource; std::vector<std::string> dumpSource;
std::vector<std::pair<int, std::string>> dumpRemarks; std::vector<std::pair<int, std::string>> dumpRemarks;
std::string (BytecodeBuilder::*dumpFunctionPtr)() const = nullptr; std::string (BytecodeBuilder::*dumpFunctionPtr)(std::vector<int>&) const = nullptr;
void validate() const; void validate() const;
void validateInstructions() const;
void validateVariadic() const;
std::string dumpCurrentFunction() const; std::string dumpCurrentFunction(std::vector<int>& dumpinstoffs) const;
void dumpInstruction(const uint32_t* opcode, std::string& output, int targetLabel) const; void dumpInstruction(const uint32_t* opcode, std::string& output, int targetLabel) const;
void writeFunction(std::string& ss, uint32_t id) const; void writeFunction(std::string& ss, uint32_t id) const;

View file

@ -4,8 +4,6 @@
#include "Luau/Bytecode.h" #include "Luau/Bytecode.h"
#include "Luau/Compiler.h" #include "Luau/Compiler.h"
LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinMT, false)
namespace Luau namespace Luau
{ {
namespace Compile namespace Compile
@ -66,13 +64,10 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op
if (builtin.isGlobal("select")) if (builtin.isGlobal("select"))
return LBF_SELECT_VARARG; return LBF_SELECT_VARARG;
if (FFlag::LuauCompileBuiltinMT) if (builtin.isGlobal("getmetatable"))
{ return LBF_GETMETATABLE;
if (builtin.isGlobal("getmetatable")) if (builtin.isGlobal("setmetatable"))
return LBF_GETMETATABLE; return LBF_SETMETATABLE;
if (builtin.isGlobal("setmetatable"))
return LBF_SETMETATABLE;
}
if (builtin.object == "math") if (builtin.object == "math")
{ {

View file

@ -269,7 +269,7 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues)
// this call is indirect to make sure we only gain link time dependency on dumpCurrentFunction when needed // this call is indirect to make sure we only gain link time dependency on dumpCurrentFunction when needed
if (dumpFunctionPtr) if (dumpFunctionPtr)
func.dump = (this->*dumpFunctionPtr)(); func.dump = (this->*dumpFunctionPtr)(func.dumpinstoffs);
insns.clear(); insns.clear();
lines.clear(); lines.clear();
@ -1078,6 +1078,12 @@ uint8_t BytecodeBuilder::getVersion()
#ifdef LUAU_ASSERTENABLED #ifdef LUAU_ASSERTENABLED
void BytecodeBuilder::validate() const void BytecodeBuilder::validate() const
{
validateInstructions();
validateVariadic();
}
void BytecodeBuilder::validateInstructions() const
{ {
#define VREG(v) LUAU_ASSERT(unsigned(v) < func.maxstacksize) #define VREG(v) LUAU_ASSERT(unsigned(v) < func.maxstacksize)
#define VREGRANGE(v, count) LUAU_ASSERT(unsigned(v + (count < 0 ? 0 : count)) <= func.maxstacksize) #define VREGRANGE(v, count) LUAU_ASSERT(unsigned(v + (count < 0 ? 0 : count)) <= func.maxstacksize)
@ -1090,26 +1096,27 @@ void BytecodeBuilder::validate() const
const Function& func = functions[currentFunction]; const Function& func = functions[currentFunction];
// first pass: tag instruction offsets so that we can validate jumps // tag instruction offsets so that we can validate jumps
std::vector<uint8_t> insnvalid(insns.size(), false); std::vector<uint8_t> insnvalid(insns.size(), 0);
for (size_t i = 0; i < insns.size();) for (size_t i = 0; i < insns.size();)
{ {
uint8_t op = LUAU_INSN_OP(insns[i]); uint32_t insn = insns[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
insnvalid[i] = true; insnvalid[i] = true;
i += getOpLength(LuauOpcode(op)); i += getOpLength(op);
LUAU_ASSERT(i <= insns.size()); LUAU_ASSERT(i <= insns.size());
} }
std::vector<uint8_t> openCaptures; std::vector<uint8_t> openCaptures;
// second pass: validate the rest of the bytecode // validate individual instructions
for (size_t i = 0; i < insns.size();) for (size_t i = 0; i < insns.size();)
{ {
uint32_t insn = insns[i]; uint32_t insn = insns[i];
uint8_t op = LUAU_INSN_OP(insn); LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
switch (op) switch (op)
{ {
@ -1452,7 +1459,7 @@ void BytecodeBuilder::validate() const
LUAU_ASSERT(!"Unsupported opcode"); LUAU_ASSERT(!"Unsupported opcode");
} }
i += getOpLength(LuauOpcode(op)); i += getOpLength(op);
LUAU_ASSERT(i <= insns.size()); LUAU_ASSERT(i <= insns.size());
} }
@ -1469,6 +1476,126 @@ void BytecodeBuilder::validate() const
#undef VCONSTANY #undef VCONSTANY
#undef VJUMP #undef VJUMP
} }
void BytecodeBuilder::validateVariadic() const
{
// validate MULTRET sequences: instructions that produce a variadic sequence and consume one must come in pairs
// we classify instructions into four groups: producers, consumers, neutral and others
// any producer (an instruction that produces more than one value) must be followed by 0 or more neutral instructions
// and a consumer (that consumes more than one value); these form a variadic sequence.
// except for producer, no instruction in the variadic sequence may be a jump target.
// from the execution perspective, producer adjusts L->top to point to one past the last result, neutral instructions
// leave L->top unmodified, and consumer adjusts L->top back to the stack frame end.
// consumers invalidate all values after L->top after they execute (which we currently don't validate)
bool variadicSeq = false;
std::vector<uint8_t> insntargets(insns.size(), 0);
for (size_t i = 0; i < insns.size();)
{
uint32_t insn = insns[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
int target = getJumpTarget(insn, uint32_t(i));
if (target >= 0 && !isFastCall(op))
{
LUAU_ASSERT(unsigned(target) < insns.size());
insntargets[target] = true;
}
i += getOpLength(op);
LUAU_ASSERT(i <= insns.size());
}
for (size_t i = 0; i < insns.size();)
{
uint32_t insn = insns[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
if (variadicSeq)
{
// no instruction inside the sequence, including the consumer, may be a jump target
// this guarantees uninterrupted L->top adjustment flow
LUAU_ASSERT(!insntargets[i]);
}
if (op == LOP_CALL)
{
// note: calls may end one variadic sequence and start a new one
if (LUAU_INSN_B(insn) == 0)
{
// consumer instruction ens a variadic sequence
LUAU_ASSERT(variadicSeq);
variadicSeq = false;
}
else
{
// CALL is not a neutral instruction so it can't be present in a variadic sequence unless it's a consumer
LUAU_ASSERT(!variadicSeq);
}
if (LUAU_INSN_C(insn) == 0)
{
// producer instruction starts a variadic sequence
LUAU_ASSERT(!variadicSeq);
variadicSeq = true;
}
}
else if (op == LOP_GETVARARGS && LUAU_INSN_B(insn) == 0)
{
// producer instruction starts a variadic sequence
LUAU_ASSERT(!variadicSeq);
variadicSeq = true;
}
else if ((op == LOP_RETURN && LUAU_INSN_B(insn) == 0) || (op == LOP_SETLIST && LUAU_INSN_C(insn) == 0))
{
// consumer instruction ends a variadic sequence
LUAU_ASSERT(variadicSeq);
variadicSeq = false;
}
else if (op == LOP_FASTCALL)
{
int callTarget = int(i + LUAU_INSN_C(insn) + 1);
LUAU_ASSERT(unsigned(callTarget) < insns.size() && LUAU_INSN_OP(insns[callTarget]) == LOP_CALL);
if (LUAU_INSN_B(insns[callTarget]) == 0)
{
// consumer instruction ends a variadic sequence; however, we can't terminate it yet because future analysis of CALL will do it
// during FASTCALL fallback, the instructions between this and CALL consumer are going to be executed before L->top so they must
// be neutral; as such, we will defer termination of variadic sequence until CALL analysis
LUAU_ASSERT(variadicSeq);
}
else
{
// FASTCALL is not a neutral instruction so it can't be present in a variadic sequence unless it's linked to CALL consumer
LUAU_ASSERT(!variadicSeq);
}
// note: if FASTCALL is linked to a CALL producer, the instructions between FASTCALL and CALL are technically not part of an executed
// variadic sequence since they are never executed if FASTCALL does anything, so it's okay to skip their validation until CALL
// (we can't simply start a variadic sequence here because that would trigger assertions during linked CALL validation)
}
else if (op == LOP_CLOSEUPVALS || op == LOP_NAMECALL || op == LOP_GETIMPORT || op == LOP_MOVE || op == LOP_GETUPVAL || op == LOP_GETGLOBAL ||
op == LOP_GETTABLEKS || op == LOP_COVERAGE)
{
// instructions inside a variadic sequence must be neutral (can't change L->top)
// while there are many neutral instructions like this, here we check that the instruction is one of the few
// that we'd expect to exist in FASTCALL fallback sequences or between consecutive CALLs for encoding reasons
}
else
{
LUAU_ASSERT(!variadicSeq);
}
i += getOpLength(op);
LUAU_ASSERT(i <= insns.size());
}
LUAU_ASSERT(!variadicSeq);
}
#endif #endif
void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, int targetLabel) const void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, int targetLabel) const
@ -1800,7 +1927,7 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result,
} }
} }
std::string BytecodeBuilder::dumpCurrentFunction() const std::string BytecodeBuilder::dumpCurrentFunction(std::vector<int>& dumpinstoffs) const
{ {
if ((dumpFlags & Dump_Code) == 0) if ((dumpFlags & Dump_Code) == 0)
return std::string(); return std::string();
@ -1850,11 +1977,15 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
if (labels[i] == 0) if (labels[i] == 0)
labels[i] = nextLabel++; labels[i] = nextLabel++;
dumpinstoffs.reserve(insns.size());
for (size_t i = 0; i < insns.size();) for (size_t i = 0; i < insns.size();)
{ {
const uint32_t* code = &insns[i]; const uint32_t* code = &insns[i];
uint8_t op = LUAU_INSN_OP(*code); uint8_t op = LUAU_INSN_OP(*code);
dumpinstoffs.push_back(int(result.size()));
if (op == LOP_PREPVARARGS) if (op == LOP_PREPVARARGS)
{ {
// Don't emit function header in bytecode - it's used for call dispatching and doesn't contain "interesting" information // Don't emit function header in bytecode - it's used for call dispatching and doesn't contain "interesting" information
@ -1897,6 +2028,8 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
LUAU_ASSERT(i <= insns.size()); LUAU_ASSERT(i <= insns.size());
} }
dumpinstoffs.push_back(int(result.size()));
return result; return result;
} }
@ -1986,4 +2119,20 @@ std::string BytecodeBuilder::dumpSourceRemarks() const
return result; return result;
} }
void BytecodeBuilder::annotateInstruction(std::string& result, uint32_t fid, uint32_t instid) const
{
if ((dumpFlags & Dump_Code) == 0)
return;
LUAU_ASSERT(fid < functions.size());
const Function& function = functions[fid];
const std::string& dump = function.dump;
const std::vector<int>& dumpinstoffs = function.dumpinstoffs;
LUAU_ASSERT(instid + 1 < dumpinstoffs.size());
formatAppend(result, "%.*s", dumpinstoffs[instid + 1] - dumpinstoffs[instid], dump.data() + dumpinstoffs[instid]);
}
} // namespace Luau } // namespace Luau

View file

@ -104,7 +104,9 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Constraint.h Analysis/include/Luau/Constraint.h
Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintGraphBuilder.h
Analysis/include/Luau/ConstraintSolver.h Analysis/include/Luau/ConstraintSolver.h
Analysis/include/Luau/DataFlowGraphBuilder.h
Analysis/include/Luau/DcrLogger.h Analysis/include/Luau/DcrLogger.h
Analysis/include/Luau/Def.h
Analysis/include/Luau/Documentation.h Analysis/include/Luau/Documentation.h
Analysis/include/Luau/Error.h Analysis/include/Luau/Error.h
Analysis/include/Luau/FileResolver.h Analysis/include/Luau/FileResolver.h
@ -114,6 +116,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/JsonEmitter.h Analysis/include/Luau/JsonEmitter.h
Analysis/include/Luau/Linter.h Analysis/include/Luau/Linter.h
Analysis/include/Luau/LValue.h Analysis/include/Luau/LValue.h
Analysis/include/Luau/Metamethods.h
Analysis/include/Luau/Module.h Analysis/include/Luau/Module.h
Analysis/include/Luau/ModuleResolver.h Analysis/include/Luau/ModuleResolver.h
Analysis/include/Luau/Normalize.h Analysis/include/Luau/Normalize.h
@ -154,7 +157,9 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Constraint.cpp Analysis/src/Constraint.cpp
Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintGraphBuilder.cpp
Analysis/src/ConstraintSolver.cpp Analysis/src/ConstraintSolver.cpp
Analysis/src/DataFlowGraphBuilder.cpp
Analysis/src/DcrLogger.cpp Analysis/src/DcrLogger.cpp
Analysis/src/Def.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp Analysis/src/Error.cpp
Analysis/src/Frontend.cpp Analysis/src/Frontend.cpp
@ -301,9 +306,9 @@ if(TARGET Luau.UnitTest)
tests/CodeAllocator.test.cpp tests/CodeAllocator.test.cpp
tests/Compiler.test.cpp tests/Compiler.test.cpp
tests/Config.test.cpp tests/Config.test.cpp
tests/ConstraintGraphBuilder.test.cpp
tests/ConstraintSolver.test.cpp tests/ConstraintSolver.test.cpp
tests/CostModel.test.cpp tests/CostModel.test.cpp
tests/DataFlowGraphBuilder.test.cpp
tests/Error.test.cpp tests/Error.test.cpp
tests/Frontend.test.cpp tests/Frontend.test.cpp
tests/JsonEmitter.test.cpp tests/JsonEmitter.test.cpp
@ -372,7 +377,7 @@ if(TARGET Luau.CLI.Test)
CLI/Profiler.h CLI/Profiler.h
CLI/Profiler.cpp CLI/Profiler.cpp
CLI/Repl.cpp CLI/Repl.cpp
tests/Repl.test.cpp tests/Repl.test.cpp
tests/main.cpp) tests/main.cpp)
endif() endif()

View file

@ -12,8 +12,6 @@
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauFasterGetInfo, false)
static const char* getfuncname(Closure* f); static const char* getfuncname(Closure* f);
static int currentpc(lua_State* L, CallInfo* ci) static int currentpc(lua_State* L, CallInfo* ci)
@ -105,8 +103,7 @@ static Closure* auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closur
ar->source = "=[C]"; ar->source = "=[C]";
ar->what = "C"; ar->what = "C";
ar->linedefined = -1; ar->linedefined = -1;
if (FFlag::LuauFasterGetInfo) ar->short_src = "[C]";
ar->short_src = "[C]";
} }
else else
{ {
@ -114,13 +111,7 @@ static Closure* auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closur
ar->source = getstr(source); ar->source = getstr(source);
ar->what = "Lua"; ar->what = "Lua";
ar->linedefined = f->l.p->linedefined; ar->linedefined = f->l.p->linedefined;
if (FFlag::LuauFasterGetInfo) ar->short_src = luaO_chunkid(ar->ssbuf, sizeof(ar->ssbuf), getstr(source), source->len);
ar->short_src = luaO_chunkid(ar->ssbuf, sizeof(ar->ssbuf), getstr(source), source->len);
}
if (!FFlag::LuauFasterGetInfo)
{
luaO_chunkid(ar->ssbuf, LUA_IDSIZE, ar->source, 0);
ar->short_src = ar->ssbuf;
} }
break; break;
} }
@ -195,25 +186,12 @@ int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar)
} }
if (f) if (f)
{ {
if (FFlag::LuauFasterGetInfo) // auxgetinfo fills ar and optionally requests to put closure on stack
if (Closure* fcl = auxgetinfo(L, what, ar, f, ci))
{ {
// auxgetinfo fills ar and optionally requests to put closure on stack luaC_threadbarrier(L);
if (Closure* fcl = auxgetinfo(L, what, ar, f, ci)) setclvalue(L, L->top, fcl);
{ incr_top(L);
luaC_threadbarrier(L);
setclvalue(L, L->top, fcl);
incr_top(L);
}
}
else
{
auxgetinfo(L, what, ar, f, ci);
if (strchr(what, 'f'))
{
luaC_threadbarrier(L);
setclvalue(L, L->top, f);
incr_top(L);
}
} }
} }
return f ? 1 : 0; return f ? 1 : 0;

View file

@ -13,8 +13,6 @@
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
LUAU_FASTFLAG(LuauFasterGetInfo)
const TValue luaO_nilobject_ = {{NULL}, {0}, LUA_TNIL}; const TValue luaO_nilobject_ = {{NULL}, {0}, LUA_TNIL};
int luaO_log2(unsigned int x) int luaO_log2(unsigned int x)
@ -121,48 +119,20 @@ const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t sr
{ {
if (*source == '=') if (*source == '=')
{ {
if (FFlag::LuauFasterGetInfo) if (srclen <= buflen)
{ return source + 1;
if (srclen <= buflen) // truncate the part after =
return source + 1; memcpy(buf, source + 1, buflen - 1);
// truncate the part after = buf[buflen - 1] = '\0';
memcpy(buf, source + 1, buflen - 1);
buf[buflen - 1] = '\0';
}
else
{
source++; // skip the `='
size_t len = strlen(source);
size_t dstlen = len < buflen ? len : buflen - 1;
memcpy(buf, source, dstlen);
buf[dstlen] = '\0';
}
} }
else if (*source == '@') else if (*source == '@')
{ {
if (FFlag::LuauFasterGetInfo) if (srclen <= buflen)
{ return source + 1;
if (srclen <= buflen) // truncate the part after @
return source + 1; memcpy(buf, "...", 3);
// truncate the part after @ memcpy(buf + 3, source + srclen - (buflen - 4), buflen - 4);
memcpy(buf, "...", 3); buf[buflen - 1] = '\0';
memcpy(buf + 3, source + srclen - (buflen - 4), buflen - 4);
buf[buflen - 1] = '\0';
}
else
{
size_t l;
source++; // skip the `@'
buflen -= sizeof("...");
l = strlen(source);
strcpy(buf, "");
if (l > buflen)
{
source += (l - buflen); // get last part of file name
strcat(buf, "...");
}
strcat(buf, source);
}
} }
else else
{ // buf = [string "string"] { // buf = [string "string"]

View file

@ -758,6 +758,8 @@ reentry:
Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)]; Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)];
LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep)); LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep));
VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM
// note: we save closure to stack early in case the code below wants to capture it by value // note: we save closure to stack early in case the code below wants to capture it by value
Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv); Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv);
setclvalue(L, ra, ncl); setclvalue(L, ra, ncl);
@ -2056,6 +2058,8 @@ reentry:
int b = LUAU_INSN_B(insn); int b = LUAU_INSN_B(insn);
uint32_t aux = *pc++; uint32_t aux = *pc++;
VM_PROTECT_PC(); // luaH_new may fail due to OOM
sethvalue(L, ra, luaH_new(L, aux, b == 0 ? 0 : (1 << (b - 1)))); sethvalue(L, ra, luaH_new(L, aux, b == 0 ? 0 : (1 << (b - 1))));
VM_PROTECT(luaC_checkGC(L)); VM_PROTECT(luaC_checkGC(L));
VM_NEXT(); VM_NEXT();
@ -2067,6 +2071,8 @@ reentry:
StkId ra = VM_REG(LUAU_INSN_A(insn)); StkId ra = VM_REG(LUAU_INSN_A(insn));
TValue* kv = VM_KV(LUAU_INSN_D(insn)); TValue* kv = VM_KV(LUAU_INSN_D(insn));
VM_PROTECT_PC(); // luaH_clone may fail due to OOM
sethvalue(L, ra, luaH_clone(L, hvalue(kv))); sethvalue(L, ra, luaH_clone(L, hvalue(kv)));
VM_PROTECT(luaC_checkGC(L)); VM_PROTECT(luaC_checkGC(L));
VM_NEXT(); VM_NEXT();
@ -2185,7 +2191,8 @@ reentry:
// protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP // protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP
if (ttisnil(ra)) if (ttisnil(ra))
{ {
VM_PROTECT(luaG_typeerror(L, ra, "call")); VM_PROTECT_PC(); // next call always errors
luaG_typeerror(L, ra, "call");
} }
} }
else if (fasttm(L, mt, TM_CALL)) else if (fasttm(L, mt, TM_CALL))
@ -2202,7 +2209,8 @@ reentry:
} }
else else
{ {
VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); VM_PROTECT_PC(); // next call always errors
luaG_typeerror(L, ra, "iterate over");
} }
} }
@ -2325,7 +2333,8 @@ reentry:
} }
else if (!ttisfunction(ra)) else if (!ttisfunction(ra))
{ {
VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); VM_PROTECT_PC(); // next call always errors
luaG_typeerror(L, ra, "iterate over");
} }
pc += LUAU_INSN_D(insn); pc += LUAU_INSN_D(insn);
@ -2353,7 +2362,8 @@ reentry:
} }
else if (!ttisfunction(ra)) else if (!ttisfunction(ra))
{ {
VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); VM_PROTECT_PC(); // next call always errors
luaG_typeerror(L, ra, "iterate over");
} }
pc += LUAU_INSN_D(insn); pc += LUAU_INSN_D(insn);
@ -2404,6 +2414,8 @@ reentry:
Closure* kcl = clvalue(kv); Closure* kcl = clvalue(kv);
VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM
// clone closure if the environment is not shared // clone closure if the environment is not shared
// note: we save closure to stack early in case the code below wants to capture it by value // note: we save closure to stack early in case the code below wants to capture it by value
Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p);
@ -2533,7 +2545,7 @@ reentry:
if (cl->env->safeenv && f) if (cl->env->safeenv && f)
{ {
VM_PROTECT_PC(); VM_PROTECT_PC(); // f may fail due to OOM
int n = f(L, ra, ra + 1, nresults, ra + 2, nparams); int n = f(L, ra, ra + 1, nresults, ra + 2, nparams);
@ -2611,7 +2623,7 @@ reentry:
if (cl->env->safeenv && f) if (cl->env->safeenv && f)
{ {
VM_PROTECT_PC(); VM_PROTECT_PC(); // f may fail due to OOM
int n = f(L, ra, arg, nresults, NULL, nparams); int n = f(L, ra, arg, nresults, NULL, nparams);
@ -2667,7 +2679,7 @@ reentry:
if (cl->env->safeenv && f) if (cl->env->safeenv && f)
{ {
VM_PROTECT_PC(); VM_PROTECT_PC(); // f may fail due to OOM
int n = f(L, ra, arg1, nresults, arg2, nparams); int n = f(L, ra, arg1, nresults, arg2, nparams);
@ -2723,7 +2735,7 @@ reentry:
if (cl->env->safeenv && f) if (cl->env->safeenv && f)
{ {
VM_PROTECT_PC(); VM_PROTECT_PC(); // f may fail due to OOM
int n = f(L, ra, arg1, nresults, arg2, nparams); int n = f(L, ra, arg1, nresults, arg2, nparams);

View file

@ -1,4 +1,4 @@
--!non-strict --!nonstrict
local bench = script and require(script.Parent.bench_support) or require("bench_support") local bench = script and require(script.Parent.bench_support) or require("bench_support")
local stretchTreeDepth = 18 -- about 16Mb local stretchTreeDepth = 18 -- about 16Mb

View file

@ -264,6 +264,88 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfImul")
SINGLE_COMPARE(imul(r12, rax, -13), 0x4c, 0x6b, 0xe0, 0xf3); SINGLE_COMPARE(imul(r12, rax, -13), 0x4c, 0x6b, 0xe0, 0xf3);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "NopForms")
{
SINGLE_COMPARE(nop(), 0x90);
SINGLE_COMPARE(nop(2), 0x66, 0x90);
SINGLE_COMPARE(nop(3), 0x0f, 0x1f, 0x00);
SINGLE_COMPARE(nop(4), 0x0f, 0x1f, 0x40, 0x00);
SINGLE_COMPARE(nop(5), 0x0f, 0x1f, 0x44, 0x00, 0x00);
SINGLE_COMPARE(nop(6), 0x66, 0x0f, 0x1f, 0x44, 0x00, 0x00);
SINGLE_COMPARE(nop(7), 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00);
SINGLE_COMPARE(nop(8), 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00);
SINGLE_COMPARE(nop(9), 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00);
SINGLE_COMPARE(nop(15), 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x44, 0x00, 0x00); // 9+6
}
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentForms")
{
check(
[](AssemblyBuilderX64& build) {
build.ret();
build.align(8, AlignmentDataX64::Nop);
},
{0xc3, 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00});
check(
[](AssemblyBuilderX64& build) {
build.ret();
build.align(32, AlignmentDataX64::Nop);
},
{0xc3, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84,
0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x1f, 0x40, 0x00});
check(
[](AssemblyBuilderX64& build) {
build.ret();
build.align(8, AlignmentDataX64::Int3);
},
{0xc3, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc});
check(
[](AssemblyBuilderX64& build) {
build.ret();
build.align(8, AlignmentDataX64::Ud2);
},
{0xc3, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0xcc});
}
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentOverflow")
{
// Test that alignment correctly resizes the code buffer
{
AssemblyBuilderX64 build(/* logText */ false);
build.ret();
build.align(8192, AlignmentDataX64::Nop);
build.finalize();
}
{
AssemblyBuilderX64 build(/* logText */ false);
build.ret();
build.align(8192, AlignmentDataX64::Int3);
build.finalize();
}
{
AssemblyBuilderX64 build(/* logText */ false);
for (int i = 0; i < 8192; i++)
build.int3();
build.finalize();
}
{
AssemblyBuilderX64 build(/* logText */ false);
build.ret();
build.align(8192, AlignmentDataX64::Ud2);
build.finalize();
}
}
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow")
{ {
// Jump back // Jump back
@ -330,67 +412,67 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall")
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms")
{ {
SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xa9, 0x58, 0xc6); SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x58, 0xc6);
SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmmword[r9]), 0xc4, 0x41, 0xa9, 0x58, 0x01); SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmmword[r9]), 0xc4, 0x41, 0x29, 0x58, 0x01);
SINGLE_COMPARE(vaddpd(ymm8, ymm10, ymm14), 0xc4, 0x41, 0xad, 0x58, 0xc6); SINGLE_COMPARE(vaddpd(ymm8, ymm10, ymm14), 0xc4, 0x41, 0x2d, 0x58, 0xc6);
SINGLE_COMPARE(vaddpd(ymm8, ymm10, ymmword[r9]), 0xc4, 0x41, 0xad, 0x58, 0x01); SINGLE_COMPARE(vaddpd(ymm8, ymm10, ymmword[r9]), 0xc4, 0x41, 0x2d, 0x58, 0x01);
SINGLE_COMPARE(vaddps(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xa8, 0x58, 0xc6); SINGLE_COMPARE(vaddps(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x28, 0x58, 0xc6);
SINGLE_COMPARE(vaddps(xmm8, xmm10, xmmword[r9]), 0xc4, 0x41, 0xa8, 0x58, 0x01); SINGLE_COMPARE(vaddps(xmm8, xmm10, xmmword[r9]), 0xc4, 0x41, 0x28, 0x58, 0x01);
SINGLE_COMPARE(vaddsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x58, 0xc6); SINGLE_COMPARE(vaddsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x58, 0xc6);
SINGLE_COMPARE(vaddsd(xmm8, xmm10, qword[r9]), 0xc4, 0x41, 0xab, 0x58, 0x01); SINGLE_COMPARE(vaddsd(xmm8, xmm10, qword[r9]), 0xc4, 0x41, 0x2b, 0x58, 0x01);
SINGLE_COMPARE(vaddss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xaa, 0x58, 0xc6); SINGLE_COMPARE(vaddss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2a, 0x58, 0xc6);
SINGLE_COMPARE(vaddss(xmm8, xmm10, dword[r9]), 0xc4, 0x41, 0xaa, 0x58, 0x01); SINGLE_COMPARE(vaddss(xmm8, xmm10, dword[r9]), 0xc4, 0x41, 0x2a, 0x58, 0x01);
SINGLE_COMPARE(vaddps(xmm1, xmm2, xmm3), 0xc4, 0xe1, 0xe8, 0x58, 0xcb); SINGLE_COMPARE(vaddps(xmm1, xmm2, xmm3), 0xc4, 0xe1, 0x68, 0x58, 0xcb);
SINGLE_COMPARE(vaddps(xmm9, xmm12, xmmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x98, 0x58, 0x4c, 0x71, 0x1c); SINGLE_COMPARE(vaddps(xmm9, xmm12, xmmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x18, 0x58, 0x4c, 0x71, 0x1c);
SINGLE_COMPARE(vaddps(ymm1, ymm2, ymm3), 0xc4, 0xe1, 0xec, 0x58, 0xcb); SINGLE_COMPARE(vaddps(ymm1, ymm2, ymm3), 0xc4, 0xe1, 0x6c, 0x58, 0xcb);
SINGLE_COMPARE(vaddps(ymm9, ymm12, ymmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x9c, 0x58, 0x4c, 0x71, 0x1c); SINGLE_COMPARE(vaddps(ymm9, ymm12, ymmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x1c, 0x58, 0x4c, 0x71, 0x1c);
// Coverage for other instructions that follow the same pattern // Coverage for other instructions that follow the same pattern
SINGLE_COMPARE(vsubsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x5c, 0xc6); SINGLE_COMPARE(vsubsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5c, 0xc6);
SINGLE_COMPARE(vmulsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x59, 0xc6); SINGLE_COMPARE(vmulsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x59, 0xc6);
SINGLE_COMPARE(vdivsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x5e, 0xc6); SINGLE_COMPARE(vdivsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5e, 0xc6);
SINGLE_COMPARE(vxorpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xa9, 0x57, 0xc6); SINGLE_COMPARE(vxorpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x57, 0xc6);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXUnaryMergeInstructionForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXUnaryMergeInstructionForms")
{ {
SINGLE_COMPARE(vsqrtpd(xmm8, xmm10), 0xc4, 0x41, 0xf9, 0x51, 0xc2); SINGLE_COMPARE(vsqrtpd(xmm8, xmm10), 0xc4, 0x41, 0x79, 0x51, 0xc2);
SINGLE_COMPARE(vsqrtpd(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf9, 0x51, 0x01); SINGLE_COMPARE(vsqrtpd(xmm8, xmmword[r9]), 0xc4, 0x41, 0x79, 0x51, 0x01);
SINGLE_COMPARE(vsqrtpd(ymm8, ymm10), 0xc4, 0x41, 0xfd, 0x51, 0xc2); SINGLE_COMPARE(vsqrtpd(ymm8, ymm10), 0xc4, 0x41, 0x7d, 0x51, 0xc2);
SINGLE_COMPARE(vsqrtpd(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfd, 0x51, 0x01); SINGLE_COMPARE(vsqrtpd(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7d, 0x51, 0x01);
SINGLE_COMPARE(vsqrtps(xmm8, xmm10), 0xc4, 0x41, 0xf8, 0x51, 0xc2); SINGLE_COMPARE(vsqrtps(xmm8, xmm10), 0xc4, 0x41, 0x78, 0x51, 0xc2);
SINGLE_COMPARE(vsqrtps(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf8, 0x51, 0x01); SINGLE_COMPARE(vsqrtps(xmm8, xmmword[r9]), 0xc4, 0x41, 0x78, 0x51, 0x01);
SINGLE_COMPARE(vsqrtsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x51, 0xc6); SINGLE_COMPARE(vsqrtsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x51, 0xc6);
SINGLE_COMPARE(vsqrtsd(xmm8, xmm10, qword[r9]), 0xc4, 0x41, 0xab, 0x51, 0x01); SINGLE_COMPARE(vsqrtsd(xmm8, xmm10, qword[r9]), 0xc4, 0x41, 0x2b, 0x51, 0x01);
SINGLE_COMPARE(vsqrtss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xaa, 0x51, 0xc6); SINGLE_COMPARE(vsqrtss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2a, 0x51, 0xc6);
SINGLE_COMPARE(vsqrtss(xmm8, xmm10, dword[r9]), 0xc4, 0x41, 0xaa, 0x51, 0x01); SINGLE_COMPARE(vsqrtss(xmm8, xmm10, dword[r9]), 0xc4, 0x41, 0x2a, 0x51, 0x01);
// Coverage for other instructions that follow the same pattern // Coverage for other instructions that follow the same pattern
SINGLE_COMPARE(vucomisd(xmm1, xmm4), 0xc4, 0xe1, 0xf9, 0x2e, 0xcc); SINGLE_COMPARE(vucomisd(xmm1, xmm4), 0xc4, 0xe1, 0x79, 0x2e, 0xcc);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXMoveInstructionForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXMoveInstructionForms")
{ {
SINGLE_COMPARE(vmovsd(qword[r9], xmm10), 0xc4, 0x41, 0xfb, 0x11, 0x11); SINGLE_COMPARE(vmovsd(qword[r9], xmm10), 0xc4, 0x41, 0x7b, 0x11, 0x11);
SINGLE_COMPARE(vmovsd(xmm8, qword[r9]), 0xc4, 0x41, 0xfb, 0x10, 0x01); SINGLE_COMPARE(vmovsd(xmm8, qword[r9]), 0xc4, 0x41, 0x7b, 0x10, 0x01);
SINGLE_COMPARE(vmovsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x10, 0xc6); SINGLE_COMPARE(vmovsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x10, 0xc6);
SINGLE_COMPARE(vmovss(dword[r9], xmm10), 0xc4, 0x41, 0xfa, 0x11, 0x11); SINGLE_COMPARE(vmovss(dword[r9], xmm10), 0xc4, 0x41, 0x7a, 0x11, 0x11);
SINGLE_COMPARE(vmovss(xmm8, dword[r9]), 0xc4, 0x41, 0xfa, 0x10, 0x01); SINGLE_COMPARE(vmovss(xmm8, dword[r9]), 0xc4, 0x41, 0x7a, 0x10, 0x01);
SINGLE_COMPARE(vmovss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xaa, 0x10, 0xc6); SINGLE_COMPARE(vmovss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2a, 0x10, 0xc6);
SINGLE_COMPARE(vmovapd(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf9, 0x28, 0x01); SINGLE_COMPARE(vmovapd(xmm8, xmmword[r9]), 0xc4, 0x41, 0x79, 0x28, 0x01);
SINGLE_COMPARE(vmovapd(xmmword[r9], xmm10), 0xc4, 0x41, 0xf9, 0x29, 0x11); SINGLE_COMPARE(vmovapd(xmmword[r9], xmm10), 0xc4, 0x41, 0x79, 0x29, 0x11);
SINGLE_COMPARE(vmovapd(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfd, 0x28, 0x01); SINGLE_COMPARE(vmovapd(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7d, 0x28, 0x01);
SINGLE_COMPARE(vmovaps(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf8, 0x28, 0x01); SINGLE_COMPARE(vmovaps(xmm8, xmmword[r9]), 0xc4, 0x41, 0x78, 0x28, 0x01);
SINGLE_COMPARE(vmovaps(xmmword[r9], xmm10), 0xc4, 0x41, 0xf8, 0x29, 0x11); SINGLE_COMPARE(vmovaps(xmmword[r9], xmm10), 0xc4, 0x41, 0x78, 0x29, 0x11);
SINGLE_COMPARE(vmovaps(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfc, 0x28, 0x01); SINGLE_COMPARE(vmovaps(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7c, 0x28, 0x01);
SINGLE_COMPARE(vmovupd(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf9, 0x10, 0x01); SINGLE_COMPARE(vmovupd(xmm8, xmmword[r9]), 0xc4, 0x41, 0x79, 0x10, 0x01);
SINGLE_COMPARE(vmovupd(xmmword[r9], xmm10), 0xc4, 0x41, 0xf9, 0x11, 0x11); SINGLE_COMPARE(vmovupd(xmmword[r9], xmm10), 0xc4, 0x41, 0x79, 0x11, 0x11);
SINGLE_COMPARE(vmovupd(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfd, 0x10, 0x01); SINGLE_COMPARE(vmovupd(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7d, 0x10, 0x01);
SINGLE_COMPARE(vmovups(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf8, 0x10, 0x01); SINGLE_COMPARE(vmovups(xmm8, xmmword[r9]), 0xc4, 0x41, 0x78, 0x10, 0x01);
SINGLE_COMPARE(vmovups(xmmword[r9], xmm10), 0xc4, 0x41, 0xf8, 0x11, 0x11); SINGLE_COMPARE(vmovups(xmmword[r9], xmm10), 0xc4, 0x41, 0x78, 0x11, 0x11);
SINGLE_COMPARE(vmovups(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfc, 0x10, 0x01); SINGLE_COMPARE(vmovups(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7c, 0x10, 0x01);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms")
@ -407,10 +489,10 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms")
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms")
{ {
SINGLE_COMPARE(vroundsd(xmm7, xmm12, xmm3, RoundingModeX64::RoundToNegativeInfinity), 0xc4, 0xe3, 0x99, 0x0b, 0xfb, 0x09); SINGLE_COMPARE(vroundsd(xmm7, xmm12, xmm3, RoundingModeX64::RoundToNegativeInfinity), 0xc4, 0xe3, 0x19, 0x0b, 0xfb, 0x09);
SINGLE_COMPARE( SINGLE_COMPARE(
vroundsd(xmm8, xmm13, xmmword[r13 + rdx], RoundingModeX64::RoundToPositiveInfinity), 0xc4, 0x43, 0x91, 0x0b, 0x44, 0x15, 0x00, 0x0a); vroundsd(xmm8, xmm13, xmmword[r13 + rdx], RoundingModeX64::RoundToPositiveInfinity), 0xc4, 0x43, 0x11, 0x0b, 0x44, 0x15, 0x00, 0x0a);
SINGLE_COMPARE(vroundsd(xmm9, xmm14, xmmword[rcx + r10], RoundingModeX64::RoundToZero), 0xc4, 0x23, 0x89, 0x0b, 0x0c, 0x11, 0x0b); SINGLE_COMPARE(vroundsd(xmm9, xmm14, xmmword[rcx + r10], RoundingModeX64::RoundToZero), 0xc4, 0x23, 0x09, 0x0b, 0x0c, 0x11, 0x0b);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions")
@ -423,6 +505,10 @@ TEST_CASE("LogTest")
AssemblyBuilderX64 build(/* logText= */ true); AssemblyBuilderX64 build(/* logText= */ true);
build.push(r12); build.push(r12);
build.align(8);
build.align(8, AlignmentDataX64::Int3);
build.align(8, AlignmentDataX64::Ud2);
build.add(rax, rdi); build.add(rax, rdi);
build.add(rcx, 8); build.add(rcx, 8);
build.sub(dword[rax], 0x1fdc); build.sub(dword[rax], 0x1fdc);
@ -445,14 +531,29 @@ TEST_CASE("LogTest")
build.imul(rcx, rdx); build.imul(rcx, rdx);
build.imul(rcx, rdx, 8); build.imul(rcx, rdx, 8);
build.vroundsd(xmm1, xmm2, xmm3, RoundingModeX64::RoundToNearestEven); build.vroundsd(xmm1, xmm2, xmm3, RoundingModeX64::RoundToNearestEven);
build.add(rdx, qword[rcx - 12]);
build.pop(r12); build.pop(r12);
build.ret(); build.ret();
build.int3(); build.int3();
build.nop();
build.nop(2);
build.nop(3);
build.nop(4);
build.nop(5);
build.nop(6);
build.nop(7);
build.nop(8);
build.nop(9);
build.finalize(); build.finalize();
bool same = "\n" + build.text == R"( bool same = "\n" + build.text == R"(
push r12 push r12
; align 8
nop word ptr[rax+rax] ; 6-byte nop
; align 8 using int3
; align 8 using ud2
add rax,rdi add rax,rdi
add rcx,8 add rcx,8
sub dword ptr [rax],1FDCh sub dword ptr [rax],1FDCh
@ -473,9 +574,19 @@ TEST_CASE("LogTest")
imul rcx,rdx imul rcx,rdx
imul rcx,rdx,8 imul rcx,rdx,8
vroundsd xmm1,xmm2,xmm3,8 vroundsd xmm1,xmm2,xmm3,8
add rdx,qword ptr [rcx-0Ch]
pop r12 pop r12
ret ret
int3 int3
nop
xchg ax, ax ; 2-byte nop
nop dword ptr[rax] ; 3-byte nop
nop dword ptr[rax] ; 4-byte nop
nop dword ptr[rax+rax] ; 5-byte nop
nop word ptr[rax+rax] ; 6-byte nop
nop dword ptr[rax] ; 7-byte nop
nop dword ptr[rax+rax] ; 8-byte nop
nop word ptr[rax+rax] ; 9-byte nop
)"; )";
CHECK(same); CHECK(same);
} }
@ -497,10 +608,10 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "Constants")
{ {
0x48, 0x33, 0xc0, 0x48, 0x33, 0xc0,
0x48, 0x03, 0x05, 0xee, 0xff, 0xff, 0xff, 0x48, 0x03, 0x05, 0xee, 0xff, 0xff, 0xff,
0xc4, 0xe1, 0xfa, 0x10, 0x15, 0xe1, 0xff, 0xff, 0xff, 0xc4, 0xe1, 0x7a, 0x10, 0x15, 0xe1, 0xff, 0xff, 0xff,
0xc4, 0xe1, 0xfb, 0x10, 0x1d, 0xcc, 0xff, 0xff, 0xff, 0xc4, 0xe1, 0x7b, 0x10, 0x1d, 0xcc, 0xff, 0xff, 0xff,
0xc4, 0xe1, 0xf8, 0x28, 0x25, 0xab, 0xff, 0xff, 0xff, 0xc4, 0xe1, 0x78, 0x28, 0x25, 0xab, 0xff, 0xff, 0xff,
0xc4, 0xe1, 0xf9, 0x10, 0x2d, 0x92, 0xff, 0xff, 0xff, 0xc4, 0xe1, 0x79, 0x10, 0x2d, 0x92, 0xff, 0xff, 0xff,
0xc3 0xc3
}, },
{ {

View file

@ -41,9 +41,9 @@ TEST_CASE("CodeAllocation")
REQUIRE(allocator.allocate(data.data(), data.size(), code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); REQUIRE(allocator.allocate(data.data(), data.size(), code.data(), code.size(), nativeData, sizeNativeData, nativeEntry));
CHECK(nativeData != nullptr); CHECK(nativeData != nullptr);
CHECK(sizeNativeData == 16 + 128); CHECK(sizeNativeData == kCodeAlignment + 128);
CHECK(nativeEntry != nullptr); CHECK(nativeEntry != nullptr);
CHECK(nativeEntry == nativeData + 16); CHECK(nativeEntry == nativeData + kCodeAlignment);
} }
TEST_CASE("CodeAllocationFailure") TEST_CASE("CodeAllocationFailure")
@ -118,15 +118,16 @@ TEST_CASE("CodeAllocationWithUnwindCallbacks")
REQUIRE(allocator.allocate(data.data(), data.size(), code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); REQUIRE(allocator.allocate(data.data(), data.size(), code.data(), code.size(), nativeData, sizeNativeData, nativeEntry));
CHECK(nativeData != nullptr); CHECK(nativeData != nullptr);
CHECK(sizeNativeData == 16 + 128); CHECK(sizeNativeData == kCodeAlignment + 128);
CHECK(nativeEntry != nullptr); CHECK(nativeEntry != nullptr);
CHECK(nativeEntry == nativeData + 16); CHECK(nativeEntry == nativeData + kCodeAlignment);
CHECK(nativeData == info.block + 16); CHECK(nativeData == info.block + kCodeAlignment);
} }
CHECK(info.destroyCalled); CHECK(info.destroyCalled);
} }
#if !defined(LUAU_BIG_ENDIAN)
TEST_CASE("WindowsUnwindCodesX64") TEST_CASE("WindowsUnwindCodesX64")
{ {
UnwindBuilderWin unwind; UnwindBuilderWin unwind;
@ -156,6 +157,7 @@ TEST_CASE("WindowsUnwindCodesX64")
REQUIRE(data.size() == expected.size()); REQUIRE(data.size() == expected.size());
CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0);
} }
#endif
TEST_CASE("Dwarf2UnwindCodesX64") TEST_CASE("Dwarf2UnwindCodesX64")
{ {

View file

@ -798,8 +798,6 @@ RETURN R0 1
TEST_CASE("TableSizePredictionSetMetatable") TEST_CASE("TableSizePredictionSetMetatable")
{ {
ScopedFastFlag sff("LuauCompileBuiltinMT", true);
CHECK_EQ("\n" + compileFunction0(R"( CHECK_EQ("\n" + compileFunction0(R"(
local t = setmetatable({}, nil) local t = setmetatable({}, nil)
t.field1 = 1 t.field1 = 1

View file

@ -536,9 +536,15 @@ TEST_CASE("Debugger")
{ {
static int breakhits = 0; static int breakhits = 0;
static lua_State* interruptedthread = nullptr; static lua_State* interruptedthread = nullptr;
static bool singlestep = false;
static int stephits = 0;
SUBCASE("") { singlestep = false; }
SUBCASE("SingleStep") { singlestep = true; }
breakhits = 0; breakhits = 0;
interruptedthread = nullptr; interruptedthread = nullptr;
stephits = 0;
lua_CompileOptions copts = defaultOptions(); lua_CompileOptions copts = defaultOptions();
copts.debugLevel = 2; copts.debugLevel = 2;
@ -548,6 +554,13 @@ TEST_CASE("Debugger")
[](lua_State* L) { [](lua_State* L) {
lua_Callbacks* cb = lua_callbacks(L); lua_Callbacks* cb = lua_callbacks(L);
lua_singlestep(L, singlestep);
// this will only be called in single-step mode
cb->debugstep = [](lua_State* L, lua_Debug* ar) {
stephits++;
};
// for breakpoints to work we should make sure debugbreak is installed // for breakpoints to work we should make sure debugbreak is installed
cb->debugbreak = [](lua_State* L, lua_Debug* ar) { cb->debugbreak = [](lua_State* L, lua_Debug* ar) {
breakhits++; breakhits++;
@ -667,6 +680,9 @@ TEST_CASE("Debugger")
nullptr, &copts, /* skipCodegen */ true); // Native code doesn't support debugging yet nullptr, &copts, /* skipCodegen */ true); // Native code doesn't support debugging yet
CHECK(breakhits == 12); // 2 hits per breakpoint CHECK(breakhits == 12); // 2 hits per breakpoint
if (singlestep)
CHECK(stephits > 100); // note; this will depend on number of instructions which can vary, so we just make sure the callback gets hit often
} }
TEST_CASE("SameHash") TEST_CASE("SameHash")
@ -1528,4 +1544,51 @@ TEST_CASE("SafeEnv")
runConformance("safeenv.lua"); runConformance("safeenv.lua");
} }
TEST_CASE("HugeFunction")
{
std::string source;
// add non-executed block that requires JUMPKX and generates a lot of constants that take available short (15-bit) constant space
source += "if ... then\n";
source += "local _ = {\n";
for (int i = 0; i < 40000; ++i)
{
source += "0.";
source += std::to_string(i);
source += ",";
}
source += "}\n";
source += "end\n";
// use failed fast-calls with imports and constants to exercise all of the more complex fallback sequences
source += "return bit32.lshift('84', -1)";
StateRef globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get();
if (codegen && Luau::CodeGen::isSupported())
Luau::CodeGen::create(L);
luaL_openlibs(L);
luaL_sandbox(L);
luaL_sandboxthread(L);
size_t bytecodeSize = 0;
char* bytecode = luau_compile(source.data(), source.size(), nullptr, &bytecodeSize);
int result = luau_load(L, "=HugeFunction", bytecode, bytecodeSize, 0);
free(bytecode);
REQUIRE(result == 0);
if (codegen && Luau::CodeGen::isSupported())
Luau::CodeGen::compile(L, -1);
int status = lua_resume(L, nullptr, 0);
REQUIRE(status == 0);
CHECK(lua_tonumber(L, -1) == 42);
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -1,126 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/NotNull.h"
#include "Luau/ToString.h"
#include "ConstraintGraphBuilderFixture.h"
#include "Fixture.h"
#include "doctest.h"
using namespace Luau;
TEST_SUITE_BEGIN("ConstraintGraphBuilder");
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world")
{
AstStatBlock* block = parse(R"(
local a = "hello"
local b = a
)");
cgb.visit(block);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(2 == constraints.size());
ToStringOptions opts;
CHECK("string <: a" == toString(*constraints[0], opts));
CHECK("a <: b" == toString(*constraints[1], opts));
}
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives")
{
AstStatBlock* block = parse(R"(
local s = "hello"
local n = 555
local b = true
local n2 = nil
)");
cgb.visit(block);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(3 == constraints.size());
ToStringOptions opts;
CHECK("string <: a" == toString(*constraints[0], opts));
CHECK("number <: b" == toString(*constraints[1], opts));
CHECK("boolean <: c" == toString(*constraints[2], opts));
}
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive")
{
AstStatBlock* block = parse(R"(
local function a() return nil end
local b = a()
)");
cgb.visit(block);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
ToStringOptions opts;
REQUIRE(4 <= constraints.size());
CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts));
CHECK("call *blocked-1* with { result = *blocked-tp-1* }" == toString(*constraints[1], opts));
CHECK("*blocked-tp-1* <: b" == toString(*constraints[2], opts));
CHECK("nil <: a..." == toString(*constraints[3], opts));
}
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application")
{
AstStatBlock* block = parse(R"(
local a = "hello"
local b = a("world")
)");
cgb.visit(block);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(3 == constraints.size());
ToStringOptions opts;
CHECK("string <: a" == toString(*constraints[0], opts));
CHECK("call a with { result = *blocked-tp-1* }" == toString(*constraints[1], opts));
CHECK("*blocked-tp-1* <: b" == toString(*constraints[2], opts));
}
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition")
{
AstStatBlock* block = parse(R"(
local function f(a)
return a
end
)");
cgb.visit(block);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(2 == constraints.size());
ToStringOptions opts;
CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts));
CHECK("a <: b..." == toString(*constraints[1], opts));
}
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function")
{
AstStatBlock* block = parse(R"(
local function f(a)
return f(a)
end
)");
cgb.visit(block);
auto constraints = collectConstraints(NotNull(cgb.rootScope));
REQUIRE(3 == constraints.size());
ToStringOptions opts;
CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts));
CHECK("call (a) -> (b...) with { result = *blocked-tp-1* }" == toString(*constraints[1], opts));
CHECK("*blocked-tp-1* <: b..." == toString(*constraints[2], opts));
}
TEST_SUITE_END();

View file

@ -7,11 +7,28 @@ namespace Luau
ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture()
: Fixture() : Fixture()
, mainModule(new Module) , mainModule(new Module)
, cgb("MainModule", mainModule, &arena, NotNull(&moduleResolver), singletonTypes, NotNull(&ice), frontend.getGlobalScope(), &logger)
, forceTheFlag{"DebugLuauDeferredConstraintResolution", true} , forceTheFlag{"DebugLuauDeferredConstraintResolution", true}
{ {
BlockedTypeVar::nextIndex = 0; BlockedTypeVar::nextIndex = 0;
BlockedTypePack::nextIndex = 0; BlockedTypePack::nextIndex = 0;
} }
void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code)
{
AstStatBlock* root = parse(code);
dfg = std::make_unique<DataFlowGraph>(DataFlowGraphBuilder::build(root, NotNull{&ice}));
cgb = std::make_unique<ConstraintGraphBuilder>("MainModule", mainModule, &arena, NotNull(&moduleResolver), singletonTypes, NotNull(&ice),
frontend.getGlobalScope(), &logger, NotNull{dfg.get()});
cgb->visit(root);
rootScope = cgb->rootScope;
constraints = Luau::collectConstraints(NotNull{cgb->rootScope});
}
void ConstraintGraphBuilderFixture::solve(const std::string& code)
{
generateConstraints(code);
ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, "MainModule", NotNull(&moduleResolver), {}, &logger};
cs.run();
}
} // namespace Luau } // namespace Luau

View file

@ -2,6 +2,7 @@
#pragma once #pragma once
#include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/Module.h" #include "Luau/Module.h"
@ -16,12 +17,22 @@ struct ConstraintGraphBuilderFixture : Fixture
{ {
TypeArena arena; TypeArena arena;
ModulePtr mainModule; ModulePtr mainModule;
ConstraintGraphBuilder cgb;
DcrLogger logger; DcrLogger logger;
UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
std::unique_ptr<DataFlowGraph> dfg;
std::unique_ptr<ConstraintGraphBuilder> cgb;
Scope* rootScope = nullptr;
std::vector<NotNull<Constraint>> constraints;
ScopedFastFlag forceTheFlag; ScopedFastFlag forceTheFlag;
ConstraintGraphBuilderFixture(); ConstraintGraphBuilderFixture();
void generateConstraints(const std::string& code);
void solve(const std::string& code);
}; };
} // namespace Luau } // namespace Luau

View file

@ -1,15 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintSolver.h"
#include "ConstraintGraphBuilderFixture.h" #include "ConstraintGraphBuilderFixture.h"
#include "Fixture.h" #include "Fixture.h"
#include "doctest.h" #include "doctest.h"
using namespace Luau; using namespace Luau;
static TypeId requireBinding(NotNull<Scope> scope, const char* name) static TypeId requireBinding(Scope* scope, const char* name)
{ {
auto b = linearSearchForBinding(scope, name); auto b = linearSearchForBinding(scope, name);
LUAU_ASSERT(b.has_value()); LUAU_ASSERT(b.has_value());
@ -20,22 +17,11 @@ TEST_SUITE_BEGIN("ConstraintSolver");
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello")
{ {
AstStatBlock* block = parse(R"( solve(R"(
local a = 55 local a = 55
local b = a local b = a
)"); )");
cgb.visit(block);
NotNull<Scope> rootScope{cgb.rootScope};
InternalErrorReporter iceHandler;
UnifierSharedState sharedState{&iceHandler};
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
NullModuleResolver resolver;
ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger};
cs.run();
TypeId bType = requireBinding(rootScope, "b"); TypeId bType = requireBinding(rootScope, "b");
CHECK("number" == toString(bType)); CHECK("number" == toString(bType));
@ -43,22 +29,12 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello")
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function")
{ {
AstStatBlock* block = parse(R"( solve(R"(
local function id(a) local function id(a)
return a return a
end end
)"); )");
cgb.visit(block);
NotNull<Scope> rootScope{cgb.rootScope};
InternalErrorReporter iceHandler;
UnifierSharedState sharedState{&iceHandler};
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
NullModuleResolver resolver;
ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger};
cs.run();
TypeId idType = requireBinding(rootScope, "id"); TypeId idType = requireBinding(rootScope, "id");
CHECK("<a>(a) -> a" == toString(idType)); CHECK("<a>(a) -> a" == toString(idType));
@ -66,7 +42,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function")
TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization")
{ {
AstStatBlock* block = parse(R"( solve(R"(
local function a(c) local function a(c)
local function d(e) local function d(e)
return c return c
@ -78,21 +54,9 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization")
local b = a(5) local b = a(5)
)"); )");
cgb.visit(block);
NotNull<Scope> rootScope{cgb.rootScope};
ToStringOptions opts;
NullModuleResolver resolver;
InternalErrorReporter iceHandler;
UnifierSharedState sharedState{&iceHandler};
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger};
cs.run();
TypeId idType = requireBinding(rootScope, "b"); TypeId idType = requireBinding(rootScope, "b");
ToStringOptions opts;
CHECK("<a>(a) -> number" == toString(idType, opts)); CHECK("<a>(a) -> number" == toString(idType, opts));
} }

View file

@ -0,0 +1,104 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/Error.h"
#include "Luau/Parser.h"
#include "AstQueryDsl.h"
#include "ScopedFlags.h"
#include "doctest.h"
using namespace Luau;
class DataFlowGraphFixture
{
// Only needed to fix the operator== reflexivity of an empty Symbol.
ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true};
InternalErrorReporter handle;
Allocator allocator;
AstNameTable names{allocator};
AstStatBlock* module;
std::optional<DataFlowGraph> graph;
public:
void dfg(const std::string& code)
{
ParseResult parseResult = Parser::parse(code.c_str(), code.size(), names, allocator);
if (!parseResult.errors.empty())
throw ParseErrors(std::move(parseResult.errors));
module = parseResult.root;
graph = DataFlowGraphBuilder::build(module, NotNull{&handle});
}
template<typename T, int N>
std::optional<DefId> getDef(const std::vector<Nth>& nths = {nth<T>(N)})
{
T* node = query<T, N>(module, nths);
REQUIRE(node);
return graph->getDef(node);
}
template<typename T, int N>
DefId requireDef(const std::vector<Nth>& nths = {nth<T>(N)})
{
auto loc = getDef<T, N>(nths);
REQUIRE(loc);
return NotNull{*loc};
}
};
TEST_SUITE_BEGIN("DataFlowGraphBuilder");
TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_locals_in_local_stat")
{
dfg(R"(
local x = 5
local y = x
)");
REQUIRE(getDef<AstExprLocal, 1>());
}
TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions")
{
dfg(R"(
local function f(x)
local y = x
end
)");
REQUIRE(getDef<AstExprLocal, 1>());
}
TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases")
{
dfg(R"(
local x = 5
local y = x
local z = y
)");
DefId x = requireDef<AstExprLocal, 1>();
DefId y = requireDef<AstExprLocal, 2>();
REQUIRE(x != y); // TODO: they should be equal but it's not just locals that can alias, so we'll support this later.
}
TEST_CASE_FIXTURE(DataFlowGraphFixture, "independent_locals")
{
dfg(R"(
local x = 5
local y = 5
local a = x
local b = y
)");
DefId x = requireDef<AstExprLocal, 1>();
DefId y = requireDef<AstExprLocal, 2>();
REQUIRE(x != y);
}
TEST_SUITE_END();

View file

@ -226,24 +226,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_tables")
CHECK_EQ(clonedTtv->state, TableState::Free); CHECK_EQ(clonedTtv->state, TableState::Free);
} }
TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection")
{
TypeArena src;
TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {singletonTypes->numberType, singletonTypes->stringType}});
TypeArena dest;
CloneState cloneState;
TypeId cloned = clone(constrained, dest, cloneState);
CHECK_NE(constrained, cloned);
const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(cloned);
REQUIRE_EQ(2, ctv->parts.size());
CHECK_EQ(singletonTypes->numberType, ctv->parts[0]);
CHECK_EQ(singletonTypes->stringType, ctv->parts[1]);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property")
{ {
fileResolver.source["Module/A"] = R"( fileResolver.source["Module/A"] = R"(

View file

@ -391,26 +391,6 @@ TEST_SUITE_END();
TEST_SUITE_BEGIN("Normalize"); TEST_SUITE_BEGIN("Normalize");
TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_subtype_relationship")
{
check(R"(
local t: {x: number} | {x: number?}
)");
ModulePtr tempModule{new Module};
tempModule->scopes.emplace_back(Location(), std::make_shared<Scope>(singletonTypes->anyTypePack));
// HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze
// the arena that the type lives in.
ModulePtr mainModule = getMainModule();
unfreeze(mainModule->internalTypes);
TypeId tType = requireType("t");
normalize(tType, tempModule, singletonTypes, *typeChecker.iceHandler);
CHECK_EQ("{| x: number? |}", toString(tType, {true}));
}
TEST_CASE_FIXTURE(Fixture, "higher_order_function") TEST_CASE_FIXTURE(Fixture, "higher_order_function")
{ {
check(R"( check(R"(

View file

@ -10,7 +10,7 @@ using namespace Luau;
TEST_SUITE_BEGIN("SymbolTests"); TEST_SUITE_BEGIN("SymbolTests");
TEST_CASE("hashing_globals") TEST_CASE("equality_and_hashing_of_globals")
{ {
std::string s1 = "name"; std::string s1 = "name";
std::string s2 = "name"; std::string s2 = "name";
@ -37,7 +37,7 @@ TEST_CASE("hashing_globals")
REQUIRE_EQ(1, theMap.size()); REQUIRE_EQ(1, theMap.size());
} }
TEST_CASE("hashing_locals") TEST_CASE("equality_and_hashing_of_locals")
{ {
std::string s1 = "name"; std::string s1 = "name";
std::string s2 = "name"; std::string s2 = "name";
@ -64,4 +64,24 @@ TEST_CASE("hashing_locals")
REQUIRE_EQ(2, theMap.size()); REQUIRE_EQ(2, theMap.size());
} }
TEST_CASE("equality_of_empty_symbols")
{
ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true};
std::string s1 = "name";
std::string s2 = "name";
AstName one{s1.data()};
AstLocal two{AstName{s2.data()}, Location(), nullptr, 0, 0, nullptr};
Symbol global{one};
Symbol local{&two};
Symbol empty1{};
Symbol empty2{};
CHECK(empty1 != global);
CHECK(empty1 != local);
CHECK(empty1 == empty2);
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -79,8 +79,8 @@ n1 [label="AnyTypeVar 1"];
TEST_CASE_FIXTURE(Fixture, "bound") TEST_CASE_FIXTURE(Fixture, "bound")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
local a = 444 function a(): number return 444 end
local b = a local b = a()
)"); )");
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
@ -367,27 +367,6 @@ n3 [label="number"];
toDot(*ty, opts)); toDot(*ty, opts));
} }
TEST_CASE_FIXTURE(Fixture, "constrained")
{
// ConstrainedTypeVars never appear in the final type graph, so we have to create one directly
// to dotify it.
TypeVar t{ConstrainedTypeVar{TypeLevel{}, {typeChecker.numberType, typeChecker.stringType, typeChecker.nilType}}};
ToDotOptions opts;
opts.showPointers = false;
CHECK_EQ(R"(digraph graphname {
n1 [label="ConstrainedTypeVar 1"];
n1 -> n2;
n2 [label="number"];
n1 -> n3;
n3 [label="string"];
n1 -> n4;
n4 [label="nil"];
})",
toDot(&t, opts));
}
TEST_CASE_FIXTURE(Fixture, "singletontypes") TEST_CASE_FIXTURE(Fixture, "singletontypes")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(

View file

@ -846,8 +846,16 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni
type FutureIntersection = A & B type FutureIntersection = A & B
)"); )");
// TODO: shared self causes this test to break in bizarre ways. if (FFlag::DebugLuauDeferredConstraintResolution)
LUAU_REQUIRE_ERRORS(result); {
// To be quite honest, I don't know exactly why DCR fixes this.
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
// TODO: shared self causes this test to break in bizarre ways.
LUAU_REQUIRE_ERRORS(result);
}
} }
TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok")

View file

@ -1692,4 +1692,47 @@ foo(string.find("hello", "e"))
CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo' expects 0 to 2 arguments, but 3 are specified"); CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo' expects 0 to 2 arguments, but 3 are specified");
} }
TEST_CASE_FIXTURE(Fixture, "luau_subtyping_is_np_hard")
{
ScopedFastFlag sffs[]{
{"LuauSubtypeNormalizer", true},
{"LuauTypeNormalization2", true},
{"LuauOverloadedFunctionSubtypingPerf", true},
};
CheckResult result = check(R"(
--!strict
-- An example of coding up graph coloring in the Luau type system.
-- This codes a three-node, two color problem.
-- A three-node triangle is uncolorable,
-- but a three-node line is colorable.
type Red = "red"
type Blue = "blue"
type Color = Red | Blue
type Coloring = (Color) -> (Color) -> (Color) -> boolean
type Uncolorable = (Color) -> (Color) -> (Color) -> false
type Line = Coloring
& ((Red) -> (Red) -> (Color) -> false)
& ((Blue) -> (Blue) -> (Color) -> false)
& ((Color) -> (Red) -> (Red) -> false)
& ((Color) -> (Blue) -> (Blue) -> false)
type Triangle = Line
& ((Red) -> (Color) -> (Red) -> false)
& ((Blue) -> (Color) -> (Blue) -> false)
local x : Triangle
local y : Line
local z : Uncolorable
z = x -- OK, so the triangle is uncolorable
z = y -- Not OK, so the line is colorable
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") -> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & ((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> false'; none of the intersection parts are compatible");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -781,7 +781,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables")
CheckResult result = check(R"( CheckResult result = check(R"(
local a : string? = nil local a : string? = nil
local b : number? = nil local b : number? = nil
local x = setmetatable({}, { p = 5, q = a }); local x = setmetatable({}, { p = 5, q = a });
local y = setmetatable({}, { q = b, r = "hi" }); local y = setmetatable({}, { q = b, r = "hi" });
local z = setmetatable({}, { p = 5, q = nil, r = "hi" }); local z = setmetatable({}, { p = 5, q = nil, r = "hi" });

View file

@ -13,6 +13,8 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
TEST_SUITE_BEGIN("TypeInferOperators"); TEST_SUITE_BEGIN("TypeInferOperators");
TEST_CASE_FIXTURE(Fixture, "or_joins_types") TEST_CASE_FIXTURE(Fixture, "or_joins_types")
@ -33,7 +35,7 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras")
local x:number|string = s local x:number|string = s
local y = x or "s" local y = x or "s"
)"); )");
CHECK_EQ(0, result.errors.size()); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*requireType("s")), "number | string"); CHECK_EQ(toString(*requireType("s")), "number | string");
CHECK_EQ(toString(*requireType("y")), "number | string"); CHECK_EQ(toString(*requireType("y")), "number | string");
} }
@ -44,7 +46,7 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union")
local s = "a" or "b" local s = "a" or "b"
local x:string = s local x:string = s
)"); )");
CHECK_EQ(0, result.errors.size()); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(*requireType("s"), *typeChecker.stringType); CHECK_EQ(*requireType("s"), *typeChecker.stringType);
} }
@ -54,7 +56,7 @@ TEST_CASE_FIXTURE(Fixture, "and_adds_boolean")
local s = "a" and 10 local s = "a" and 10
local x:boolean|number = s local x:boolean|number = s
)"); )");
CHECK_EQ(0, result.errors.size()); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*requireType("s")), "boolean | number"); CHECK_EQ(toString(*requireType("s")), "boolean | number");
} }
@ -64,7 +66,7 @@ TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union")
local s = "a" and true local s = "a" and true
local x:boolean = s local x:boolean = s
)"); )");
CHECK_EQ(0, result.errors.size()); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(*requireType("x"), *typeChecker.booleanType); CHECK_EQ(*requireType("x"), *typeChecker.booleanType);
} }
@ -73,7 +75,7 @@ TEST_CASE_FIXTURE(Fixture, "and_or_ternary")
CheckResult result = check(R"( CheckResult result = check(R"(
local s = (1/2) > 0.5 and "a" or 10 local s = (1/2) > 0.5 and "a" or 10
)"); )");
CHECK_EQ(0, result.errors.size()); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*requireType("s")), "number | string"); CHECK_EQ(toString(*requireType("s")), "number | string");
} }
@ -81,7 +83,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
function add(a: number, b: string) function add(a: number, b: string)
return a + (tonumber(b) :: number), a .. b return a + (tonumber(b) :: number), tostring(a) .. b
end end
local n, s = add(2,"3") local n, s = add(2,"3")
)"); )");
@ -558,15 +560,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables
LUAU_REQUIRE_ERROR_COUNT(3, result); LUAU_REQUIRE_ERROR_COUNT(3, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]); TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); REQUIRE(tm);
REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); CHECK_EQ(*tm->wantedType, *typeChecker.numberType);
CHECK_EQ(*tm->givenType, *typeChecker.stringType);
GenericError* gen1 = get<GenericError>(result.errors[1]);
REQUIRE(gen1);
if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(gen1->message, "Operator + is not applicable for '{ value: number }' and 'number' because neither type has a metatable");
else
CHECK_EQ(gen1->message, "Binary operator '+' not supported by types 'foo' and 'number'");
TypeMismatch* tm2 = get<TypeMismatch>(result.errors[2]); TypeMismatch* tm2 = get<TypeMismatch>(result.errors[2]);
REQUIRE(tm2);
CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); CHECK_EQ(*tm2->wantedType, *typeChecker.numberType);
CHECK_EQ(*tm2->givenType, *requireType("foo")); CHECK_EQ(*tm2->givenType, *requireType("foo"));
GenericError* gen2 = get<GenericError>(result.errors[1]);
REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'");
} }
// CLI-29033 // CLI-29033
@ -611,12 +619,10 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown")
{ {
std::vector<std::string> ops = {"+", "-", "*", "/", "%", "^", ".."}; std::vector<std::string> ops = {"+", "-", "*", "/", "%", "^", ".."};
std::string src = R"( std::string src = "function foo(a, b)\n";
function foo(a, b)
)";
for (const auto& op : ops) for (const auto& op : ops)
src += "local _ = a " + op + "b\n"; src += "local _ = a " + op + " b\n";
src += "end"; src += "end";
@ -651,7 +657,11 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato
GenericError* ge = get<GenericError>(result.errors[0]); GenericError* ge = get<GenericError>(result.errors[0]);
REQUIRE(ge); REQUIRE(ge);
CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message);
if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("Types 'boolean' and 'boolean' cannot be compared with relational operator <", ge->message);
else
CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message);
} }
TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2")
@ -666,7 +676,10 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato
GenericError* ge = get<GenericError>(result.errors[0]); GenericError* ge = get<GenericError>(result.errors[0]);
REQUIRE(ge); REQUIRE(ge);
CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("Types 'number | string' and 'number | string' cannot be compared with relational operator <", ge->message);
else
CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message);
} }
TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union")
@ -891,4 +904,63 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value")
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
CheckResult result = check(R"(
local mm = {
__add = function(self, other)
return
end,
}
local x = setmetatable({}, mm)
local y = x + 123
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(requireType("y") == singletonTypes->errorRecoveryType());
const GenericError* ge = get<GenericError>(result.errors[0]);
REQUIRE(ge);
CHECK(ge->message == "Metamethod '__add' must return a value");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean")
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
CheckResult result = check(R"(
local mm1 = {
__lt = function(self, other)
return 123
end,
}
local mm2 = {
__lt = function(self, other)
return
end,
}
local o1 = setmetatable({}, mm1)
local v1 = o1 < o1
local o2 = setmetatable({}, mm2)
local v2 = o2 < o2
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK(requireType("v1") == singletonTypes->booleanType);
CHECK(requireType("v2") == singletonTypes->booleanType);
CHECK(toString(result.errors[0]) == "Metamethod '__lt' must return type 'boolean'");
CHECK(toString(result.errors[1]) == "Metamethod '__lt' must return type 'boolean'");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -8,6 +8,7 @@
#include "doctest.h" #include "doctest.h"
LUAU_FASTFLAG(LuauSpecialTypesAsterisked) LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
using namespace Luau; using namespace Luau;
@ -49,7 +50,6 @@ struct RefinementClassFixture : Fixture
{"Y", Property{typeChecker.numberType}}, {"Y", Property{typeChecker.numberType}},
{"Z", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}},
}; };
normalize(vec3, scope, arena, singletonTypes, *typeChecker.iceHandler);
TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"});
@ -57,21 +57,17 @@ struct RefinementClassFixture : Fixture
TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypePackId isARets = arena.addTypePack({typeChecker.booleanType});
TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets});
getMutable<FunctionTypeVar>(isA)->magicFunction = magicFunctionInstanceIsA; getMutable<FunctionTypeVar>(isA)->magicFunction = magicFunctionInstanceIsA;
normalize(isA, scope, arena, singletonTypes, *typeChecker.iceHandler);
getMutable<ClassTypeVar>(inst)->props = { getMutable<ClassTypeVar>(inst)->props = {
{"Name", Property{typeChecker.stringType}}, {"Name", Property{typeChecker.stringType}},
{"IsA", Property{isA}}, {"IsA", Property{isA}},
}; };
normalize(inst, scope, arena, singletonTypes, *typeChecker.iceHandler);
TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"});
normalize(folder, scope, arena, singletonTypes, *typeChecker.iceHandler);
TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"});
getMutable<ClassTypeVar>(part)->props = { getMutable<ClassTypeVar>(part)->props = {
{"Position", Property{vec3}}, {"Position", Property{vec3}},
}; };
normalize(part, scope, arena, singletonTypes, *typeChecker.iceHandler);
typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3};
typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst};
@ -102,8 +98,16 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); {
CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({5, 26})));
}
else
{
CHECK_EQ("string", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26})));
}
} }
TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint")
@ -120,8 +124,16 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); {
CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("(string?) & ~~~(false?)", toString(requireTypeAtPosition({5, 26})));
}
else
{
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("string", toString(requireTypeAtPosition({5, 26})));
}
} }
TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through")
@ -138,8 +150,16 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); {
CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("(string?) & ~~~(false?)", toString(requireTypeAtPosition({5, 26})));
}
else
{
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("string", toString(requireTypeAtPosition({5, 26})));
}
} }
TEST_CASE_FIXTURE(Fixture, "and_constraint") TEST_CASE_FIXTURE(Fixture, "and_constraint")
@ -963,19 +983,27 @@ TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement")
TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
local function is_true(b: true) end
local function is_false(b: false) end
local function f(x: boolean) local function f(x: boolean)
if x then if x then
is_true(x) local foo = x
else else
is_false(x) local foo = x
end end
end end
)"); )");
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("boolean & ~(false?)", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("boolean & ~~(false?)", toString(requireTypeAtPosition({5, 28})));
}
else
{
CHECK_EQ("true", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("false", toString(requireTypeAtPosition({5, 28})));
}
} }
TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false")

View file

@ -11,6 +11,8 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
TEST_SUITE_BEGIN("TableTests"); TEST_SUITE_BEGIN("TableTests");
@ -44,7 +46,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_table")
const TableTypeVar* tType = get<TableTypeVar>(requireType("t")); const TableTypeVar* tType = get<TableTypeVar>(requireType("t"));
REQUIRE(tType != nullptr); REQUIRE(tType != nullptr);
CHECK(tType->props.find("foo") != tType->props.end()); CHECK(1 == tType->props.count("foo"));
} }
TEST_CASE_FIXTURE(Fixture, "augment_nested_table") TEST_CASE_FIXTURE(Fixture, "augment_nested_table")
@ -101,7 +103,11 @@ TEST_CASE_FIXTURE(Fixture, "updating_sealed_table_prop_is_ok")
TEST_CASE_FIXTURE(Fixture, "cannot_change_type_of_unsealed_table_prop") TEST_CASE_FIXTURE(Fixture, "cannot_change_type_of_unsealed_table_prop")
{ {
CheckResult result = check("local t = {} t.prop = 999 t.prop = 'hello'"); CheckResult result = check(R"(
local t = {}
t.prop = 999
t.prop = 'hello'
)");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
} }
@ -858,11 +864,12 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(*typeChecker.stringType, *requireType("a")); CHECK("string" == toString(*typeChecker.stringType));
TableTypeVar* tableType = getMutable<TableTypeVar>(requireType("t")); TableTypeVar* tableType = getMutable<TableTypeVar>(requireType("t"));
REQUIRE(tableType != nullptr); REQUIRE(tableType != nullptr);
REQUIRE(tableType->indexer == std::nullopt); REQUIRE(tableType->indexer == std::nullopt);
REQUIRE(0 != tableType->props.count("a"));
TypeId propertyA = tableType->props["a"].type; TypeId propertyA = tableType->props["a"].type;
REQUIRE(propertyA != nullptr); REQUIRE(propertyA != nullptr);
@ -2390,9 +2397,12 @@ TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer")
TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer")
{ {
CheckResult result = check("local a = {a=1, b=2} a['a'] = nil"); CheckResult result = check(R"(
local a = {a=1, b=2}
a['a'] = nil
)");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ CHECK_EQ(result.errors[0], (TypeError{Location{Position{2, 17}, Position{2, 20}}, TypeMismatch{
typeChecker.numberType, typeChecker.numberType,
typeChecker.nilType, typeChecker.nilType,
}})); }}));
@ -2701,6 +2711,62 @@ local baz = foo[bar]
CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}});
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_basic")
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
CheckResult result = check(R"(
local a = setmetatable({
a = 1,
}, {
__call = function(self, b: number)
return self.a * b
end,
})
local foo = a(12)
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK(requireType("foo") == singletonTypes->numberType);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable")
{
CheckResult result = check(R"(
local a = setmetatable({}, {
__call = 123,
})
local foo = a()
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(result.errors[0] == TypeError{
Location{{5, 20}, {5, 21}},
CannotCallNonFunction{singletonTypes->numberType},
});
}
TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_generic")
{
CheckResult result = check(R"(
local a = setmetatable({}, {
__call = function<T>(self, b: T)
return b
end,
})
local foo = a(12)
local bar = a("bar")
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK(requireType("foo") == singletonTypes->numberType);
CHECK(requireType("bar") == singletonTypes->stringType);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(

View file

@ -1046,7 +1046,6 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer")
ScopedFastFlag sffs[]{ ScopedFastFlag sffs[]{
{"LuauSubtypeNormalizer", true}, {"LuauSubtypeNormalizer", true},
{"LuauTypeNormalization2", true}, {"LuauTypeNormalization2", true},
{"LuauAutocompleteDynamicLimits", true},
}; };
CheckResult result = check(R"( CheckResult result = check(R"(

View file

@ -467,6 +467,8 @@ type I<S..., R...> = W<number, (string, S...), R...>
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit")
{ {
ScopedFastFlag sff("LuauFunctionReturnStringificationFixup", true);
CheckResult result = check(R"( CheckResult result = check(R"(
type X<T...> = (T...) -> (T...) type X<T...> = (T...) -> (T...)
@ -490,6 +492,8 @@ type F = X<(string, ...number)>
TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi")
{ {
ScopedFastFlag sff("LuauFunctionReturnStringificationFixup", true);
CheckResult result = check(R"( CheckResult result = check(R"(
type Y<T..., U...> = (T...) -> (U...) type Y<T..., U...> = (T...) -> (U...)

View file

@ -436,7 +436,6 @@ TEST_CASE("proof_that_isBoolean_uses_all_of")
TEST_CASE("content_reassignment") TEST_CASE("content_reassignment")
{ {
TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; TypeVar myAny{AnyTypeVar{}, /*presistent*/ true};
myAny.normal = true;
myAny.documentationSymbol = "@global/any"; myAny.documentationSymbol = "@global/any";
TypeArena arena; TypeArena arena;
@ -446,7 +445,6 @@ TEST_CASE("content_reassignment")
CHECK(get<AnyTypeVar>(futureAny) != nullptr); CHECK(get<AnyTypeVar>(futureAny) != nullptr);
CHECK(!futureAny->persistent); CHECK(!futureAny->persistent);
CHECK(futureAny->normal);
CHECK(futureAny->documentationSymbol == "@global/any"); CHECK(futureAny->documentationSymbol == "@global/any");
CHECK(futureAny->owningArena == &arena); CHECK(futureAny->owningArena == &arena);
} }

View file

@ -93,7 +93,10 @@ assert((function() local a = 1 a = a * 2 return a end)() == 2)
assert((function() local a = 1 a = a / 2 return a end)() == 0.5) assert((function() local a = 1 a = a / 2 return a end)() == 0.5)
assert((function() local a = 5 a = a % 2 return a end)() == 1) assert((function() local a = 5 a = a % 2 return a end)() == 1)
assert((function() local a = 3 a = a ^ 2 return a end)() == 9) assert((function() local a = 3 a = a ^ 2 return a end)() == 9)
assert((function() local a = 3 a = a ^ 3 return a end)() == 27)
assert((function() local a = 9 a = a ^ 0.5 return a end)() == 3) assert((function() local a = 9 a = a ^ 0.5 return a end)() == 3)
assert((function() local a = -2 a = a ^ 2 return a end)() == 4)
assert((function() local a = -2 a = a ^ 0.5 return tostring(a) end)() == "nan")
assert((function() local a = '1' a = a .. '2' return a end)() == "12") assert((function() local a = '1' a = a .. '2' return a end)() == "12")
assert((function() local a = '1' a = a .. '2' .. '3' return a end)() == "123") assert((function() local a = '1' a = a .. '2' .. '3' return a end)() == "123")
@ -706,7 +709,11 @@ end
assert(chainTest(100) == "v0,v100") assert(chainTest(100) == "v0,v100")
-- this validates import fallbacks -- this validates import fallbacks
assert(idontexist == nil)
assert(math.idontexist == nil)
assert(pcall(function() return idontexist.a end) == false) assert(pcall(function() return idontexist.a end) == false)
assert(pcall(function() return math.pow.a end) == false)
assert(pcall(function() return math.a.b end) == false)
-- make sure that NaN is preserved by the bytecode compiler -- make sure that NaN is preserved by the bytecode compiler
local realnan = tostring(math.abs(0)/math.abs(0)) local realnan = tostring(math.abs(0)/math.abs(0))

View file

@ -226,4 +226,14 @@ assert((function () return nil end)(4) == nil)
assert((function () local a; return a end)(4) == nil) assert((function () local a; return a end)(4) == nil)
assert((function (a) return a end)() == nil) assert((function (a) return a end)() == nil)
-- C-stack overflow while handling C-stack overflow
if not limitedstack then
local function loop ()
assert(pcall(loop))
end
local err, msg = xpcall(loop, loop)
assert(not err and string.find(msg, "error"))
end
return('OK') return('OK')

View file

@ -16,6 +16,7 @@ D = os.date("*t", t)
assert(os.date(string.rep("%d", 1000), t) == assert(os.date(string.rep("%d", 1000), t) ==
string.rep(os.date("%d", t), 1000)) string.rep(os.date("%d", t), 1000))
assert(os.date(string.rep("%", 200)) == string.rep("%", 100)) assert(os.date(string.rep("%", 200)) == string.rep("%", 100))
assert(os.date("", -1) == nil)
local function checkDateTable (t) local function checkDateTable (t)
local D = os.date("!*t", t) local D = os.date("!*t", t)

View file

@ -405,5 +405,7 @@ assert(ecall(function() (""):foo() end) == "attempt to call missing method 'foo'
assert(ecall(function() (42):foo() end) == "attempt to index number with 'foo'") assert(ecall(function() (42):foo() end) == "attempt to index number with 'foo'")
assert(ecall(function() ({foo=42}):foo() end) == "attempt to call a number value") assert(ecall(function() ({foo=42}):foo() end) == "attempt to call a number value")
assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = {} ud:foo() end) == "attempt to call missing method 'foo' of userdata") assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = {} ud:foo() end) == "attempt to call missing method 'foo' of userdata")
assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = function() end ud:foo() end) == "attempt to call missing method 'foo' of userdata")
assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = function() error("nope") end ud:foo() end) == "nope")
return('OK') return('OK')

View file

@ -13,6 +13,11 @@ assert(getmetatable(a) == "xuxu")
ud=newproxy(true); getmetatable(ud).__metatable = "xuxu" ud=newproxy(true); getmetatable(ud).__metatable = "xuxu"
assert(getmetatable(ud) == "xuxu") assert(getmetatable(ud) == "xuxu")
assert(pcall(getmetatable) == false)
assert(pcall(function() return getmetatable() end) == false)
assert(select(2, pcall(getmetatable, {})) == nil)
assert(select(2, pcall(getmetatable, ud)) == "xuxu")
local res,err = pcall(tostring, a) local res,err = pcall(tostring, a)
assert(not res and err == "'__tostring' must return a string") assert(not res and err == "'__tostring' must return a string")
-- cannot change a protected metatable -- cannot change a protected metatable
@ -475,6 +480,9 @@ function testfenv()
assert(_G.X == 20) assert(_G.X == 20)
assert(_G == getfenv(0)) assert(_G == getfenv(0))
assert(pcall(getfenv, 10) == false)
assert(pcall(setfenv, setfenv, {}) == false)
end end
testfenv() -- DONT MOVE THIS LINE testfenv() -- DONT MOVE THIS LINE

View file

@ -193,4 +193,24 @@ do
assert(x == 15) assert(x == 15)
end end
-- pairs/ipairs/next may be substituted through getfenv
-- however, they *must* be substituted with functions - we don't support them falling back to generalized iteration
function testgetfenv()
local env = getfenv(1)
env.pairs = function() return "nope" end
env.ipairs = function() return "nope" end
env.next = {1, 2, 3}
local ok, err = pcall(function() for k, v in pairs({}) do end end)
assert(not ok and err:match("attempt to iterate over a string value"))
local ok, err = pcall(function() for k, v in ipairs({}) do end end)
assert(not ok and err:match("attempt to iterate over a string value"))
local ok, err = pcall(function() for k, v in next, {} do end end)
assert(not ok and err:match("attempt to iterate over a table value"))
end
testgetfenv() -- DONT MOVE THIS LINE
return"OK" return"OK"

View file

@ -283,6 +283,13 @@ assert(math.fmod(-3, 2) == -1)
assert(math.fmod(3, -2) == 1) assert(math.fmod(3, -2) == 1)
assert(math.fmod(-3, -2) == -1) assert(math.fmod(-3, -2) == -1)
-- pow
assert(math.pow(2, 0) == 1)
assert(math.pow(2, 2) == 4)
assert(math.pow(4, 0.5) == 2)
assert(math.pow(-2, 2) == 4)
assert(tostring(math.pow(-2, 0.5)) == "nan")
-- most of the tests above go through fastcall path -- most of the tests above go through fastcall path
-- to make sure the basic implementations are also correct we test these functions with string->number coercions -- to make sure the basic implementations are also correct we test these functions with string->number coercions
assert(math.abs("-4") == 4) assert(math.abs("-4") == 4)

View file

@ -74,4 +74,6 @@ checkerror("wrap around", table.move, {}, 1, maxI, 2)
checkerror("wrap around", table.move, {}, 1, 2, maxI) checkerror("wrap around", table.move, {}, 1, 2, maxI)
checkerror("wrap around", table.move, {}, minI, -2, 2) checkerror("wrap around", table.move, {}, minI, -2, 2)
checkerror("readonly", table.move, table.freeze({}), 1, 1, 1)
return"OK" return"OK"

View file

@ -48,6 +48,7 @@ assert(string.find("", "") == 1)
assert(string.find('', 'aaa', 1) == nil) assert(string.find('', 'aaa', 1) == nil)
assert(('alo(.)alo'):find('(.)', 1, 1) == 4) assert(('alo(.)alo'):find('(.)', 1, 1) == 4)
assert(string.find('', '1', 2) == nil) assert(string.find('', '1', 2) == nil)
assert(string.find('123', '2', 0) == 2)
print('+') print('+')
assert(string.len("") == 0) assert(string.len("") == 0)
@ -88,6 +89,8 @@ assert(string.lower("\0ABCc%$") == "\0abcc%$")
assert(string.rep('teste', 0) == '') assert(string.rep('teste', 0) == '')
assert(string.rep('tés\00', 2) == 'tés\0têtés\000') assert(string.rep('tés\00', 2) == 'tés\0têtés\000')
assert(string.rep('', 10) == '') assert(string.rep('', 10) == '')
assert(string.rep('', 1e9) == '')
assert(pcall(string.rep, 'x', 2e9) == false)
assert(string.reverse"" == "") assert(string.reverse"" == "")
assert(string.reverse"\0\1\2\3" == "\3\2\1\0") assert(string.reverse"\0\1\2\3" == "\3\2\1\0")
@ -126,6 +129,13 @@ assert(string.format("-%.20s.20s", string.rep("%", 2000)) == "-"..string.rep("%"
assert(string.format('"-%20s.20s"', string.rep("%", 2000)) == assert(string.format('"-%20s.20s"', string.rep("%", 2000)) ==
string.format("%q", "-"..string.rep("%", 2000)..".20s")) string.format("%q", "-"..string.rep("%", 2000)..".20s"))
assert(string.format("%o %u %x %X", -1, -1, -1, -1) == "1777777777777777777777 18446744073709551615 ffffffffffffffff FFFFFFFFFFFFFFFF")
assert(string.format("%e %E", 1.5, -1.5) == "1.500000e+00 -1.500000E+00")
assert(pcall(string.format, "%##################d", 1) == false)
assert(pcall(string.format, "%.123d", 1) == false)
assert(pcall(string.format, "%?", 1) == false)
-- longest number that can be formated -- longest number that can be formated
assert(string.len(string.format('%99.99f', -1e308)) >= 100) assert(string.len(string.format('%99.99f', -1e308)) >= 100)
@ -179,6 +189,26 @@ assert(table.concat(a, ",", 2) == "b,c")
assert(table.concat(a, ",", 3) == "c") assert(table.concat(a, ",", 3) == "c")
assert(table.concat(a, ",", 4) == "") assert(table.concat(a, ",", 4) == "")
-- string.split
do
local function eq(a, b)
if #a ~= #b then
return false
end
for i=1,#a do
if a[i] ~= b[i] then
return false
end
end
return true
end
assert(eq(string.split("abc", ""), {'a', 'b', 'c'}))
assert(eq(string.split("abc", "b"), {'a', 'c'}))
assert(eq(string.split("abc", "d"), {'abc'}))
assert(eq(string.split("abc", "c"), {'ab', ''}))
end
--[[ --[[
local locales = { "ptb", "ISO-8859-1", "pt_BR" } local locales = { "ptb", "ISO-8859-1", "pt_BR" }
local function trylocale (w) local function trylocale (w)

Some files were not shown because too many files have changed in this diff Show more