Merge branch 'Roblox:master' into master

This commit is contained in:
Nikita Gordeev 2022-11-06 11:00:19 +07:00 committed by GitHub
commit dc0263b0fe
Signed by: DevComp
GPG key ID: 4AEE18F83AFDEB23
162 changed files with 10184 additions and 5681 deletions

View file

@ -73,9 +73,11 @@ jobs:
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt
- name: Checkout benchmark results
uses: actions/checkout@v3

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <Luau/NotNull.h>
#include "Luau/TypeArena.h"
#include "Luau/TypeVar.h"
@ -26,5 +27,6 @@ TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false);
TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest);
} // namespace Luau

View file

@ -0,0 +1,68 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Def.h"
#include "Luau/TypedAllocator.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <memory>
namespace Luau
{
struct Negation;
struct Conjunction;
struct Disjunction;
struct Equivalence;
struct Proposition;
using Connective = Variant<Negation, Conjunction, Disjunction, Equivalence, Proposition>;
using ConnectiveId = Connective*; // Can and most likely is nullptr.
struct Negation
{
ConnectiveId connective;
};
struct Conjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Disjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Equivalence
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Proposition
{
DefId def;
TypeId discriminantTy;
};
template<typename T>
const T* get(ConnectiveId connective)
{
return get_if<T>(connective);
}
struct ConnectiveArena
{
TypedAllocator<Connective> allocator;
ConnectiveId negation(ConnectiveId connective);
ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId proposition(DefId def, TypeId discriminantTy);
};
} // namespace Luau

View file

@ -2,9 +2,10 @@
#pragma once
#include "Luau/Ast.h" // Used for some of the enumerations
#include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Variant.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <string>
#include <memory>
@ -131,9 +132,16 @@ struct HasPropConstraint
std::string prop;
};
using ConstraintV =
Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint, BinaryConstraint,
IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint, HasPropConstraint>;
// result ~ if isSingleton D then ~D else unknown where D = discriminantType
struct SingletonOrTopTypeConstraint
{
TypeId resultType;
TypeId discriminantType;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint,
HasPropConstraint, SingletonOrTopTypeConstraint>;
struct Constraint
{
@ -143,7 +151,7 @@ struct Constraint
Constraint& operator=(const Constraint&) = delete;
NotNull<Scope> scope;
Location location;
Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations.
ConstraintV c;
std::vector<NotNull<Constraint>> dependencies;

View file

@ -1,13 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <memory>
#include <vector>
#include <unordered_map>
#include "Luau/Ast.h"
#include "Luau/Connective.h"
#include "Luau/Constraint.h"
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
#include "Luau/NotNull.h"
@ -15,6 +12,10 @@
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <memory>
#include <vector>
#include <unordered_map>
namespace Luau
{
@ -23,6 +24,34 @@ using ScopePtr = std::shared_ptr<Scope>;
struct DcrLogger;
struct Inference
{
TypeId ty = nullptr;
ConnectiveId connective = nullptr;
Inference() = default;
explicit Inference(TypeId ty, ConnectiveId connective = nullptr)
: ty(ty)
, connective(connective)
{
}
};
struct InferencePack
{
TypePackId tp = nullptr;
std::vector<ConnectiveId> connectives;
InferencePack() = default;
explicit InferencePack(TypePackId tp, const std::vector<ConnectiveId>& connectives = {})
: tp(tp)
, connectives(connectives)
{
}
};
struct ConstraintGraphBuilder
{
// A list of all the scopes in the module. This vector holds ownership of the
@ -48,6 +77,8 @@ struct ConstraintGraphBuilder
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
// Defining scopes for AST nodes.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
NotNull<const DataFlowGraph> dfg;
ConnectiveArena connectiveArena;
int recursionCount = 0;
@ -63,7 +94,8 @@ struct ConstraintGraphBuilder
DcrLogger* logger;
ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull<ModuleResolver> moduleResolver,
NotNull<SingletonTypes> singletonTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, DcrLogger* logger);
NotNull<SingletonTypes> singletonTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, DcrLogger* logger,
NotNull<DataFlowGraph> dfg);
/**
* Fabricates a new free type belonging to a given scope.
@ -88,15 +120,19 @@ struct ConstraintGraphBuilder
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to.
* @param cv the constraint variant to add.
* @return the pointer to the inserted constraint
*/
void addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv);
NotNull<Constraint> addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv);
/**
* Adds a constraint to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param c the constraint to add.
* @return the pointer to the inserted constraint
*/
void addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective);
/**
* The entry point to the ConstraintGraphBuilder. This will construct a set
@ -126,8 +162,10 @@ struct ConstraintGraphBuilder
void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
void visit(const ScopePtr& scope, AstStatError* error);
TypePackId checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<TypeId>& expectedTypes = {});
TypePackId checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<TypeId>& expectedTypes = {});
InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<TypeId>& expectedTypes = {});
InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<TypeId>& expectedTypes = {});
InferencePack checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector<TypeId>& expectedTypes);
/**
* Checks an expression that is expected to evaluate to one type.
@ -137,15 +175,24 @@ struct ConstraintGraphBuilder
* surrounding context. Used to implement bidirectional type checking.
* @return the type of the expression.
*/
TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {});
Inference check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}, bool forceSingleton = false);
TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
TypeId check(const ScopePtr& scope, AstExprIndexName* indexName);
TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
TypeId check(const ScopePtr& scope, AstExprUnary* unary);
TypeId check(const ScopePtr& scope, AstExprBinary* binary);
TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprLocal* local);
Inference check(const ScopePtr& scope, AstExprGlobal* global);
Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
Inference check(const ScopePtr& scope, AstExprUnary* unary);
Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, ConnectiveId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
TypeId checkLValue(const ScopePtr& scope, AstExpr* expr);
struct FunctionSignature
{
@ -191,7 +238,7 @@ struct ConstraintGraphBuilder
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(const ScopePtr& scope, AstArray<AstGenericType> generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(const ScopePtr& scope, AstArray<AstGenericTypePack> packs);
TypeId flattenPack(const ScopePtr& scope, Location location, TypePackId tp);
Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);
void reportError(Location location, TypeErrorData err);
void reportCodeTooComplex(Location location);

View file

@ -110,6 +110,7 @@ struct ConstraintSolver
bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do
// also handles __iter metamethod
@ -215,6 +216,8 @@ private:
TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const;
TypeId unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes);
ToStringOptions opts;
};

View file

@ -0,0 +1,115 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
// Do not include LValue. It should never be used here.
#include "Luau/Ast.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/Symbol.h"
#include <unordered_map>
namespace Luau
{
struct DataFlowGraph
{
DataFlowGraph(DataFlowGraph&&) = default;
DataFlowGraph& operator=(DataFlowGraph&&) = default;
// TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt.
// We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId.
std::optional<DefId> getDef(const AstExpr* expr) const;
std::optional<DefId> getDef(const AstLocal* local) const;
/// Retrieve the Def that corresponds to the given Symbol.
///
/// We do not perform dataflow analysis on globals, so this function always
/// yields nullopt when passed a global Symbol.
std::optional<DefId> getDef(const Symbol& symbol) const;
private:
DataFlowGraph() = default;
DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena arena;
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
DenseHashMap<const AstLocal*, const Def*> localDefs{nullptr};
friend struct DataFlowGraphBuilder;
};
struct DfgScope
{
DfgScope* parent;
DenseHashMap<Symbol, const Def*> bindings{Symbol{}};
};
struct ExpressionFlowGraph
{
std::optional<DefId> def;
};
// Currently unsound. We do not presently track the control flow of the program.
// Additionally, we do not presently track assignments.
struct DataFlowGraphBuilder
{
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
private:
DataFlowGraphBuilder() = default;
DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete;
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph;
NotNull<DefArena> arena{&graph.arena};
struct InternalErrorReporter* handle;
std::vector<std::unique_ptr<DfgScope>> scopes;
DfgScope* childScope(DfgScope* scope);
std::optional<DefId> use(DfgScope* scope, Symbol symbol, AstExpr* e);
void visit(DfgScope* scope, AstStatBlock* b);
void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b);
// TODO: visit type aliases
void visit(DfgScope* scope, AstStat* s);
void visit(DfgScope* scope, AstStatIf* i);
void visit(DfgScope* scope, AstStatWhile* w);
void visit(DfgScope* scope, AstStatRepeat* r);
void visit(DfgScope* scope, AstStatBreak* b);
void visit(DfgScope* scope, AstStatContinue* c);
void visit(DfgScope* scope, AstStatReturn* r);
void visit(DfgScope* scope, AstStatExpr* e);
void visit(DfgScope* scope, AstStatLocal* l);
void visit(DfgScope* scope, AstStatFor* f);
void visit(DfgScope* scope, AstStatForIn* f);
void visit(DfgScope* scope, AstStatAssign* a);
void visit(DfgScope* scope, AstStatCompoundAssign* c);
void visit(DfgScope* scope, AstStatFunction* f);
void visit(DfgScope* scope, AstStatLocalFunction* l);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i);
// TODO: visitLValue
// TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope)
// TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope)
};
} // namespace Luau

View file

@ -0,0 +1,78 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/NotNull.h"
#include "Luau/TypedAllocator.h"
#include "Luau/Variant.h"
namespace Luau
{
using Def = Variant<struct Undefined, struct Phi>;
/**
* We statically approximate a value at runtime using a symbolic value, which we call a Def.
*
* DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that
* can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program.
*
* It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate.
*/
using DefId = NotNull<const Def>;
/**
* A "single-object" value.
*
* Leaky implementation note: sometimes "multiple-object" values, but none of which were interesting enough to warrant creating a phi node instead.
* That can happen because there's no point in creating a phi node that points to either resultant in `if math.random() > 0.5 then 5 else "hello"`.
* This might become of utmost importance if we wanted to do some backward reasoning, e.g. if `5` is taken, then `cond` must be `truthy`.
*/
struct Undefined
{
};
/**
* A phi node is a union of defs.
*
* We need this because we're statically evaluating a program, and sometimes a place may be assigned with
* different defs, and when that happens, we need a special data type that merges in all the defs
* that will flow into that specific place. For example, consider this simple program:
*
* ```
* x-1
* if cond() then
* x-2 = 5
* else
* x-3 = "hello"
* end
* x-4 : {x-2, x-3}
* ```
*
* At x-4, we know for a fact statically that either `5` or `"hello"` can flow into the variable `x` after the branch, but
* we cannot make any definitive decisions about which one, so we just take in both.
*/
struct Phi
{
std::vector<DefId> operands;
};
template<typename T>
T* getMutable(DefId def)
{
return get_if<T>(def.get());
}
template<typename T>
const T* get(DefId def)
{
return getMutable<T>(def);
}
struct DefArena
{
TypedAllocator<Def> allocator;
DefId freshDef();
};
} // namespace Luau

View file

@ -7,6 +7,8 @@
#include "Luau/Variant.h"
#include "Luau/TypeArena.h"
LUAU_FASTFLAG(LuauIceExceptionInheritanceChange)
namespace Luau
{
struct TypeError;
@ -302,12 +304,20 @@ struct NormalizationTooComplex
}
};
struct TypePackMismatch
{
TypePackId wantedTp;
TypePackId givenTp;
bool operator==(const TypePackMismatch& rhs) const;
};
using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods,
DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire,
IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError,
CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning,
DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty,
TypesAreUnrelated, NormalizationTooComplex>;
TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch>;
struct TypeError
{
@ -374,6 +384,10 @@ struct InternalErrorReporter
class InternalCompilerError : public std::exception
{
public:
explicit InternalCompilerError(const std::string& message)
: message(message)
{
}
explicit InternalCompilerError(const std::string& message, const std::string& moduleName)
: message(message)
, moduleName(moduleName)
@ -388,8 +402,14 @@ public:
virtual const char* what() const throw();
const std::string message;
const std::string moduleName;
const std::optional<std::string> moduleName;
const std::optional<Location> location;
};
// These two function overloads only exist to facilitate fast flagging a change to InternalCompilerError
// Both functions can be removed when FFlagLuauIceExceptionInheritanceChange is removed and calling code
// can directly throw InternalCompilerError.
[[noreturn]] void throwRuntimeError(const std::string& message);
[[noreturn]] void throwRuntimeError(const std::string& message, const std::string& moduleName);
} // namespace Luau

View file

@ -14,6 +14,8 @@ struct TypeVar;
using TypeId = const TypeVar*;
struct Field;
// Deprecated. Do not use in new work.
using LValue = Variant<Symbol, Field>;
struct Field

View file

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include <unordered_map>
namespace Luau
{
static const std::unordered_map<AstExprBinary::Op, const char*> kBinaryOpMetamethods{
{AstExprBinary::Op::CompareEq, "__eq"},
{AstExprBinary::Op::CompareNe, "__eq"},
{AstExprBinary::Op::CompareGe, "__lt"},
{AstExprBinary::Op::CompareGt, "__le"},
{AstExprBinary::Op::CompareLe, "__le"},
{AstExprBinary::Op::CompareLt, "__lt"},
{AstExprBinary::Op::Add, "__add"},
{AstExprBinary::Op::Sub, "__sub"},
{AstExprBinary::Op::Mul, "__mul"},
{AstExprBinary::Op::Div, "__div"},
{AstExprBinary::Op::Pow, "__pow"},
{AstExprBinary::Op::Mod, "__mod"},
{AstExprBinary::Op::Concat, "__concat"},
};
static const std::unordered_map<AstExprUnary::Op, const char*> kUnaryOpMetamethods{
{AstExprUnary::Op::Minus, "__unm"},
{AstExprUnary::Op::Len, "__len"},
};
} // namespace Luau

View file

@ -17,19 +17,8 @@ struct SingletonTypes;
using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(
TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice,
bool anyIsTop = true);
std::pair<TypeId, bool> normalize(
TypeId ty, NotNull<Scope> scope, TypeArena& arena, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, NotNull<Module> module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(
TypePackId ty, NotNull<Scope> scope, TypeArena& arena, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, NotNull<Module> module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, const ModulePtr& module, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
class TypeIds
{
@ -115,16 +104,89 @@ struct std::equal_to<const Luau::TypeIds*>
namespace Luau
{
// A normalized string type is either `string` (represented by `nullopt`)
// or a union of string singletons.
using NormalizedStringType = std::optional<std::map<std::string, TypeId>>;
/** A normalized string type is either `string` (represented by `nullopt`) or a
* union of string singletons.
*
* When FFlagLuauNegatedStringSingletons is unset, the representation is as
* follows:
*
* * The `string` data type is represented by the option `singletons` having the
* value `std::nullopt`.
* * The type `never` is represented by `singletons` being populated with an
* empty map.
* * A union of string singletons is represented by a map populated by the names
* and TypeIds of the singletons contained therein.
*
* When FFlagLuauNegatedStringSingletons is set, the representation is as
* follows:
*
* * A union of string singletons is finite and includes the singletons named by
* the `singletons` field.
* * An intersection of negated string singletons is cofinite and includes the
* singletons excluded by the `singletons` field. It is implied that cofinite
* values are exclusions from `string` itself.
* * The `string` data type is a cofinite set minus zero elements.
* * The `never` data type is a finite set plus zero elements.
*/
struct NormalizedStringType
{
// When false, this type represents a union of singleton string types.
// eg "a" | "b" | "c"
//
// When true, this type represents string intersected with negated string
// singleton types.
// eg string & ~"a" & ~"b" & ...
bool isCofinite = false;
// A normalized function type is either `never` (represented by `nullopt`)
// TODO: This field cannot be nullopt when FFlagLuauNegatedStringSingletons
// is set. When clipping that flag, we can remove the wrapping optional.
std::optional<std::map<std::string, TypeId>> singletons;
void resetToString();
void resetToNever();
bool isNever() const;
bool isString() const;
/// Returns true if the string has finite domain.
///
/// Important subtlety: This method returns true for `never`. The empty set
/// is indeed an empty set.
bool isUnion() const;
/// Returns true if the string has infinite domain.
bool isIntersection() const;
bool includes(const std::string& str) const;
static const NormalizedStringType never;
NormalizedStringType() = default;
NormalizedStringType(bool isCofinite, std::optional<std::map<std::string, TypeId>> singletons);
};
bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr);
// A normalized function type can be `never`, the top function type `function`,
// or an intersection of function types.
// NOTE: type normalization can fail on function types with generics
// (e.g. because we do not support unions and intersections of generic type packs),
// so this type may contain `error`.
using NormalizedFunctionType = std::optional<TypeIds>;
//
// NOTE: type normalization can fail on function types with generics (e.g.
// because we do not support unions and intersections of generic type packs), so
// this type may contain `error`.
struct NormalizedFunctionType
{
NormalizedFunctionType();
bool isTop = false;
// TODO: Remove this wrapping optional when clipping
// FFlagLuauNegatedFunctionTypes.
std::optional<TypeIds> parts;
void resetToNever();
void resetToTop();
bool isNever() const;
};
// A normalized generic/free type is a union, where each option is of the form (X & T) where
// * X is either a free type or a generic
@ -166,7 +228,7 @@ struct NormalizedType
// The string part of the type.
// This may be the `string` type, or a union of singletons.
NormalizedStringType strings = std::map<std::string, TypeId>{};
NormalizedStringType strings;
// The thread part of the type.
// This type is either never or thread.
@ -184,12 +246,14 @@ struct NormalizedType
NormalizedType(NotNull<SingletonTypes> singletonTypes);
NormalizedType(const NormalizedType&) = delete;
NormalizedType(NormalizedType&&) = default;
NormalizedType() = delete;
~NormalizedType() = default;
NormalizedType(const NormalizedType&) = delete;
NormalizedType& operator=(const NormalizedType&) = delete;
NormalizedType(NormalizedType&&) = default;
NormalizedType& operator=(NormalizedType&&) = default;
NormalizedType& operator=(NormalizedType&) = delete;
};
class Normalizer
@ -240,8 +304,14 @@ public:
bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1);
// ------- Negations
std::optional<NormalizedType> negateNormal(const NormalizedType& here);
TypeIds negateAll(const TypeIds& theres);
TypeId negate(TypeId there);
void subtractPrimitive(NormalizedType& here, TypeId ty);
void subtractSingleton(NormalizedType& here, TypeId ty);
// ------- Normalizing intersections
void intersectTysWithTy(TypeIds& here, TypeId there);
TypeId intersectionOfTops(TypeId here, TypeId there);
TypeId intersectionOfBools(TypeId here, TypeId there);
void intersectClasses(TypeIds& heres, const TypeIds& theres);

View file

@ -2,6 +2,7 @@
#pragma once
#include "Luau/Common.h"
#include "Luau/Error.h"
#include <stdexcept>
#include <exception>
@ -9,10 +10,20 @@
namespace Luau
{
struct RecursionLimitException : public std::exception
struct RecursionLimitException : public InternalCompilerError
{
RecursionLimitException()
: InternalCompilerError("Internal recursion counter limit exceeded")
{
LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange);
}
};
struct RecursionLimitException_DEPRECATED : public std::exception
{
const char* what() const noexcept
{
LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange);
return "Internal recursion counter limit exceeded";
}
};
@ -42,7 +53,14 @@ struct RecursionLimiter : RecursionCounter
{
if (limit > 0 && *count > limit)
{
throw RecursionLimitException();
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw RecursionLimitException();
}
else
{
throw RecursionLimitException_DEPRECATED();
}
}
}
};

View file

@ -54,7 +54,9 @@ struct Scope
DenseHashSet<Name> builtinTypeNames{""};
void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun);
std::optional<TypeId> lookup(Symbol sym);
std::optional<TypeId> lookup(Symbol sym) const;
std::optional<TypeId> lookup(DefId def) const;
std::optional<std::pair<TypeId, Scope*>> lookupEx(Symbol sym);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
@ -66,6 +68,7 @@ struct Scope
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const;
RefinementMap refinements;
DenseHashMap<const Def*, TypeId> dcrRefinements{nullptr};
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.

View file

@ -6,10 +6,11 @@
#include <string>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
{
// TODO Rename this to Name once the old type alias is gone.
struct Symbol
{
Symbol()
@ -40,9 +41,12 @@ struct Symbol
{
if (local)
return local == rhs.local;
if (global.value)
else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
return false;
else if (FFlag::DebugLuauDeferredConstraintResolution)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else
return false;
}
bool operator!=(const Symbol& rhs) const
@ -58,8 +62,8 @@ struct Symbol
return global < rhs.global;
else if (local)
return true;
else
return false;
return false;
}
AstName astName() const

View file

@ -117,6 +117,8 @@ inline std::string toStringNamedFunction(const std::string& funcName, const Func
return toStringNamedFunction(funcName, ftv, opts);
}
std::optional<std::string> getFunctionNameAsString(const AstExpr& expr);
// It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class
// These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression
std::string dump(TypeId ty);

View file

@ -48,7 +48,17 @@ struct HashBoolNamePair
size_t operator()(const std::pair<bool, Name>& pair) const;
};
class TimeLimitError : public std::exception
class TimeLimitError : public InternalCompilerError
{
public:
explicit TimeLimitError(const std::string& moduleName)
: InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName)
{
LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange);
}
};
class TimeLimitError_DEPRECATED : public std::exception
{
public:
virtual const char* what() const throw();
@ -192,18 +202,12 @@ struct TypeChecker
ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location);
void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors);
std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
std::optional<TypeId> getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
// Reduces the union to its simplest possible shape.
// (A | B) | B | C yields A | B | C
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty);
TypeId stripFromNilAndReport(TypeId ty, const Location& location);
@ -242,6 +246,7 @@ public:
[[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message);
[[noreturn]] void throwTimeLimitError();
ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0);
ScopePtr childScope(const ScopePtr& parent, const Location& location);

View file

@ -29,4 +29,23 @@ std::pair<size_t, std::optional<size_t>> getParameterExtents(const TxnLog* log,
// various other things to get there.
std::vector<TypeId> flatten(TypeArena& arena, NotNull<SingletonTypes> singletonTypes, TypePackId pack, size_t length);
/**
* Reduces a union by decomposing to the any/error type if it appears in the
* type list, and by merging child unions. Also strips out duplicate (by pointer
* identity) types.
* @param types the input type list to reduce.
* @returns the reduced type list.
*/
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
/**
* Tries to remove nil from a union type, if there's another option. T | nil
* reduces to T, but nil itself does not reduce.
* @param singletonTypes the singleton types to use
* @param arena the type arena to allocate the new type in, if necessary
* @param ty the type to remove nil from
* @returns a type with nil removed, or nil itself if that were the only option.
*/
TypeId stripNil(NotNull<SingletonTypes> singletonTypes, TypeArena& arena, TypeId ty);
} // namespace Luau

View file

@ -2,22 +2,23 @@
#pragma once
#include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Predicate.h"
#include "Luau/Unifiable.h"
#include "Luau/Variant.h"
#include "Luau/Common.h"
#include "Luau/NotNull.h"
#include <set>
#include <string>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include <deque>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
LUAU_FASTINT(LuauTypeMaximumStringifierLength)
@ -114,6 +115,7 @@ struct PrimitiveTypeVar
Number,
String,
Thread,
Function,
};
Type type;
@ -131,24 +133,6 @@ struct PrimitiveTypeVar
}
};
struct ConstrainedTypeVar
{
explicit ConstrainedTypeVar(TypeLevel level)
: level(level)
{
}
explicit ConstrainedTypeVar(TypeLevel level, const std::vector<TypeId>& parts)
: parts(parts)
, level(level)
{
}
std::vector<TypeId> parts;
TypeLevel level;
Scope* scope = nullptr;
};
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BooleanSingleton
@ -496,11 +480,13 @@ struct AnyTypeVar
{
};
// T | U
struct UnionTypeVar
{
std::vector<TypeId> options;
};
// T & U
struct IntersectionTypeVar
{
std::vector<TypeId> parts;
@ -519,12 +505,19 @@ struct NeverTypeVar
{
};
// ~T
// TODO: Some simplification step that overwrites the type graph to make sure negation
// types disappear from the user's view, and (?) a debug flag to disable that
struct NegationTypeVar
{
TypeId ty;
};
using ErrorTypeVar = Unifiable::Error;
using TypeVariant =
Unifiable::Variant<TypeId, PrimitiveTypeVar, ConstrainedTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar,
TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar>;
Unifiable::Variant<TypeId, PrimitiveTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar,
MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar, NegationTypeVar>;
struct TypeVar final
{
@ -541,7 +534,6 @@ struct TypeVar final
TypeVar(const TypeVariant& ty, bool persistent)
: ty(ty)
, persistent(persistent)
, normal(persistent) // We assume that all persistent types are irreducable.
{
}
@ -549,7 +541,6 @@ struct TypeVar final
void reassign(const TypeVar& rhs)
{
ty = rhs.ty;
normal = rhs.normal;
documentationSymbol = rhs.documentationSymbol;
}
@ -560,10 +551,6 @@ struct TypeVar final
// Persistent TypeVars do not get cloned.
bool persistent = false;
// Normalization sets this for types that are fully normalized.
// This implies that they are transitively immutable.
bool normal = false;
std::optional<std::string> documentationSymbol;
// Pointer to the type arena that allocated this type.
@ -650,12 +637,15 @@ public:
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
const TypeId functionType;
const TypeId trueType;
const TypeId falseType;
const TypeId anyType;
const TypeId unknownType;
const TypeId neverType;
const TypeId errorType;
const TypeId falsyType; // No type binding!
const TypeId truthyType; // No type binding!
const TypePackId anyTypePack;
const TypePackId neverTypePack;
@ -703,7 +693,6 @@ T* getMutable(TypeId tv)
const std::vector<TypeId>& getTypes(const UnionTypeVar* utv);
const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv);
const std::vector<TypeId>& getTypes(const ConstrainedTypeVar* ctv);
template<typename T>
struct TypeIterator;
@ -716,10 +705,6 @@ using IntersectionTypeVarIterator = TypeIterator<IntersectionTypeVar>;
IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv);
IntersectionTypeVarIterator end(const IntersectionTypeVar* itv);
using ConstrainedTypeVarIterator = TypeIterator<ConstrainedTypeVar>;
ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv);
ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv);
/* Traverses the type T yielding each TypeId.
* If the iterator encounters a nested type T, it will instead yield each TypeId within.
*/
@ -793,7 +778,6 @@ struct TypeIterator
// with templates portability in this area, so not worth it. Thanks MSVC.
friend UnionTypeVarIterator end(const UnionTypeVar*);
friend IntersectionTypeVarIterator end(const IntersectionTypeVar*);
friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*);
private:
TypeIterator() = default;

View file

@ -61,7 +61,6 @@ struct Unifier
ErrorVec errors;
Location location;
Variance variance = Covariant;
bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once.
bool normalize; // Normalize unions and intersections if necessary
bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels
CountMismatch::Context ctx = CountMismatch::Arg;
@ -96,6 +95,8 @@ private:
void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy);
void tryUnifyNegationWithType(TypeId subTy, TypeId superTy);
TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args);
@ -119,12 +120,7 @@ private:
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy);
void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy);
public:
void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel);
// Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error"
bool occursCheck(TypeId needle, TypeId haystack);
bool occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
@ -134,6 +130,7 @@ public:
Unifier makeChildUnifier();
void reportError(TypeError err);
LUAU_NOINLINE void reportError(Location location, TypeErrorData data);
private:
bool isNonstrictMode() const;

View file

@ -58,13 +58,15 @@ public:
constexpr int tid = getTypeId<T>();
typeId = tid;
new (&storage) TT(value);
new (&storage) TT(std::forward<T>(value));
}
Variant(const Variant& other)
{
static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy<Ts>...};
typeId = other.typeId;
tableCopy[typeId](&storage, &other.storage);
table[typeId](&storage, &other.storage);
}
Variant(Variant&& other)
@ -105,7 +107,7 @@ public:
tableDtor[typeId](&storage);
typeId = tid;
new (&storage) TT(std::forward<Args>(args)...);
new (&storage) TT{std::forward<Args>(args)...};
return *reinterpret_cast<T*>(&storage);
}
@ -192,7 +194,6 @@ private:
return *static_cast<const T*>(lhs) == *static_cast<const T*>(rhs);
}
static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};

View file

@ -103,10 +103,6 @@ struct GenericTypeVarVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv)
{
return visit(ty);
@ -159,6 +155,10 @@ struct GenericTypeVarVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const NegationTypeVar& ntv)
{
return visit(ty);
}
virtual bool visit(TypePackId tp)
{
@ -216,14 +216,6 @@ struct GenericTypeVarVisitor
visit(ty, *gtv);
else if (auto etv = get<ErrorTypeVar>(ty))
visit(ty, *etv);
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
if (visit(ty, *ctv))
{
for (TypeId part : ctv->parts)
traverse(part);
}
}
else if (auto ptv = get<PrimitiveTypeVar>(ty))
visit(ty, *ptv);
else if (auto ftv = get<FunctionTypeVar>(ty))
@ -325,6 +317,8 @@ struct GenericTypeVarVisitor
traverse(a);
}
}
else if (auto ntv = get<NegationTypeVar>(ty))
visit(ty, *ntv);
else if (!FFlag::LuauCompleteVisitor)
return visit_detail::unsee(seen, ty);
else

View file

@ -37,8 +37,6 @@ bool Anyification::isDirty(TypeId ty)
return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed);
else if (log->getMutable<FreeTypeVar>(ty))
return true;
else if (get<ConstrainedTypeVar>(ty))
return true;
else
return false;
}
@ -65,20 +63,8 @@ TypeId Anyification::clean(TypeId ty)
clone.syntheticName = ttv->syntheticName;
clone.tags = ttv->tags;
TypeId res = addType(std::move(clone));
asMutable(res)->normal = ty->normal;
return res;
}
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
std::vector<TypeId> copy = ctv->parts;
for (TypeId& ty : copy)
ty = replace(ty);
TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)});
auto [t, ok] = normalize(res, scope, *arena, singletonTypes, *iceHandler);
if (!ok)
normalizationTooComplex = true;
return t;
}
else
return anyType;
}

View file

@ -11,6 +11,8 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(LuauCheckOverloadedDocSymbol, false)
namespace Luau
{
@ -427,6 +429,38 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos)
return findVisitor.result;
}
static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional<DocumentationSymbol> documentationSymbol)
{
LUAU_ASSERT(FFlag::LuauCheckOverloadedDocSymbol);
if (!documentationSymbol)
return std::nullopt;
// This might be an overloaded function.
if (get<IntersectionTypeVar>(follow(ty)))
{
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = *it;
}
}
if (matchingOverload)
{
std::string overloadSymbol = *documentationSymbol + "/overload/";
// Default toString options are fine for this purpose.
overloadSymbol += toString(matchingOverload);
return overloadSymbol;
}
}
return documentationSymbol;
}
std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position)
{
std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position);
@ -436,31 +470,38 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
if (std::optional<Binding> binding = findBindingAtPosition(module, source, position))
{
if (binding->documentationSymbol)
if (FFlag::LuauCheckOverloadedDocSymbol)
{
// This might be an overloaded function binding.
if (get<IntersectionTypeVar>(follow(binding->typeId)))
return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol);
}
else
{
if (binding->documentationSymbol)
{
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
// This might be an overloaded function binding.
if (get<IntersectionTypeVar>(follow(binding->typeId)))
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
matchingOverload = *it;
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = *it;
}
}
if (matchingOverload)
{
std::string overloadSymbol = *binding->documentationSymbol + "/overload/";
// Default toString options are fine for this purpose.
overloadSymbol += toString(matchingOverload);
return overloadSymbol;
}
}
if (matchingOverload)
{
std::string overloadSymbol = *binding->documentationSymbol + "/overload/";
// Default toString options are fine for this purpose.
overloadSymbol += toString(matchingOverload);
return overloadSymbol;
}
}
}
return binding->documentationSymbol;
return binding->documentationSymbol;
}
}
if (targetExpr)
@ -474,14 +515,20 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
{
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())
{
return propIt->second.documentationSymbol;
if (FFlag::LuauCheckOverloadedDocSymbol)
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
else
return propIt->second.documentationSymbol;
}
}
else if (const ClassTypeVar* ctv = get<ClassTypeVar>(parentTy))
{
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
{
return propIt->second.documentationSymbol;
if (FFlag::LuauCheckOverloadedDocSymbol)
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
else
return propIt->second.documentationSymbol;
}
}
}

View file

@ -10,6 +10,7 @@
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
#include "Luau/TypeUtils.h"
#include <algorithm>
@ -41,6 +42,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context);
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context);
static bool dcrMagicFunctionPack(MagicFunctionCallContext context);
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{
@ -333,6 +335,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker)
ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone");
attachMagicFunction(ttv->props["pack"].type, magicFunctionPack);
attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack);
}
attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire);
@ -660,7 +663,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
options.push_back(vtp->ty);
}
options = typechecker.reduceUnion(options);
options = reduceUnion(options);
// table.pack() -> {| n: number, [number]: nil |}
// table.pack(1) -> {| n: number, [number]: number |}
@ -679,6 +682,46 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
return WithPredicate<TypePackId>{arena.addTypePack({packedTable})};
}
static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
{
TypeArena* arena = context.solver->arena;
const auto& [paramTypes, paramTail] = flatten(context.arguments);
std::vector<TypeId> options;
options.reserve(paramTypes.size());
for (auto type : paramTypes)
options.push_back(type);
if (paramTail)
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(*paramTail))
options.push_back(vtp->ty);
}
options = reduceUnion(options);
// table.pack() -> {| n: number, [number]: nil |}
// table.pack(1) -> {| n: number, [number]: number |}
// table.pack(1, "foo") -> {| n: number, [number]: number | string |}
TypeId result = nullptr;
if (options.empty())
result = context.solver->singletonTypes->nilType;
else if (options.size() == 1)
result = options[0];
else
result = arena->addType(UnionTypeVar{std::move(options)});
TypeId numberType = context.solver->singletonTypes->numberType;
TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed});
TypePackId tableTypePack = arena->addTypePack({packedTable});
asMutable(context.result)->ty.emplace<BoundTypePack>(tableTypePack);
return true;
}
static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
{
// require(foo.parent.bar) will technically work, but it depends on legacy goop that

View file

@ -1,6 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Clone.h"
#include "Luau/RecursionCounter.h"
#include "Luau/TxnLog.h"
#include "Luau/TypePack.h"
@ -51,7 +51,6 @@ struct TypeCloner
void operator()(const BlockedTypeVar& t);
void operator()(const PendingExpansionTypeVar& t);
void operator()(const PrimitiveTypeVar& t);
void operator()(const ConstrainedTypeVar& t);
void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t);
@ -63,6 +62,7 @@ struct TypeCloner
void operator()(const LazyTypeVar& t);
void operator()(const UnknownTypeVar& t);
void operator()(const NeverTypeVar& t);
void operator()(const NegationTypeVar& t);
};
struct TypePackCloner
@ -198,21 +198,6 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t)
defaultClone(t);
}
void TypeCloner::operator()(const ConstrainedTypeVar& t)
{
TypeId res = dest.addType(ConstrainedTypeVar{t.level});
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(res);
LUAU_ASSERT(ctv);
seenTypes[typeId] = res;
std::vector<TypeId> parts;
for (TypeId part : t.parts)
parts.push_back(clone(part, dest, cloneState));
ctv->parts = std::move(parts);
}
void TypeCloner::operator()(const SingletonTypeVar& t)
{
defaultClone(t);
@ -352,6 +337,15 @@ void TypeCloner::operator()(const NeverTypeVar& t)
defaultClone(t);
}
void TypeCloner::operator()(const NegationTypeVar& t)
{
TypeId result = dest.addType(AnyTypeVar{});
seenTypes[typeId] = result;
TypeId ty = clone(t.ty, dest, cloneState);
asMutable(result)->ty = NegationTypeVar{ty};
}
} // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
@ -390,7 +384,6 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
if (!res->persistent)
{
asMutable(res)->documentationSymbol = typeId->documentationSymbol;
asMutable(res)->normal = typeId->normal;
}
}
@ -478,11 +471,6 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl
clone.parts = itv->parts;
result = dest.addType(std::move(clone));
}
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
ConstrainedTypeVar clone{ctv->level, ctv->parts};
result = dest.addType(std::move(clone));
}
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{
PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments};
@ -497,6 +485,10 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl
{
result = dest.addType(*ty);
}
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(ty))
{
result = dest.addType(NegationTypeVar{ntv->ty});
}
else
return result;
@ -504,4 +496,9 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl
return result;
}
TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest)
{
return shallowClone(ty, *dest, TxnLog::empty());
}
} // namespace Luau

View file

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Connective.h"
namespace Luau
{
ConnectiveId ConnectiveArena::negation(ConnectiveId connective)
{
return NotNull{allocator.allocate(Negation{connective})};
}
ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Conjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Disjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Equivalence{lhs, rhs})};
}
ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy)
{
return NotNull{allocator.allocate(Proposition{def, discriminantTy})};
}
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -3,14 +3,16 @@
#include "Luau/Anyification.h"
#include "Luau/ApplyTypeFunction.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/DcrLogger.h"
#include "Luau/Instantiation.h"
#include "Luau/Location.h"
#include "Luau/Metamethods.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Quantify.h"
#include "Luau/ToString.h"
#include "Luau/TypeUtils.h"
#include "Luau/TypeVar.h"
#include "Luau/Unifier.h"
#include "Luau/DcrLogger.h"
#include "Luau/VisitTypeVar.h"
#include "Luau/TypeUtils.h"
@ -438,6 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*fcc, constraint);
else if (auto hpc = get<HasPropConstraint>(*constraint))
success = tryDispatch(*hpc, constraint);
else if (auto sottc = get<SingletonOrTopTypeConstraint>(*constraint))
success = tryDispatch(*sottc, constraint);
else
LUAU_ASSERT(false);
@ -540,6 +544,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull<const Const
}
case AstExprUnary::Len:
{
// __len must return a number.
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->numberType);
return true;
}
@ -548,13 +553,46 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull<const Const
if (isNumber(operandType) || get<AnyTypeVar>(operandType) || get<ErrorTypeVar>(operandType))
{
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(c.operandType);
return true;
}
break;
else if (std::optional<TypeId> mm = findMetatableEntry(singletonTypes, errors, operandType, "__unm", constraint->location))
{
const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*mm));
if (!ftv)
{
if (std::optional<TypeId> callMm = findMetatableEntry(singletonTypes, errors, follow(*mm), "__call", constraint->location))
{
ftv = get<FunctionTypeVar>(follow(*callMm));
}
}
if (!ftv)
{
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->errorRecoveryType());
return true;
}
TypePackId argsPack = arena->addTypePack({operandType});
unify(ftv->argTypes, argsPack, constraint->scope);
TypeId result = singletonTypes->errorRecoveryType();
if (ftv)
{
result = first(ftv->retTypes).value_or(singletonTypes->errorRecoveryType());
}
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(result);
}
else
{
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->errorRecoveryType());
}
return true;
}
}
LUAU_ASSERT(false); // TODO metatable handling
LUAU_ASSERT(false);
return false;
}
@ -564,44 +602,192 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull<const Cons
TypeId rightType = follow(c.rightType);
TypeId resultType = follow(c.resultType);
if (isBlocked(leftType) || isBlocked(rightType))
{
/* Compound assignments create constraints of the form
*
* A <: Binary<op, A, B>
*
* This constraint is the one that is meant to unblock A, so it doesn't
* make any sense to stop and wait for someone else to do it.
*/
if (leftType != resultType && rightType != resultType)
{
block(c.leftType, constraint);
block(c.rightType, constraint);
return false;
}
}
bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or;
if (isNumber(leftType))
{
unify(leftType, rightType, constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(leftType);
return true;
}
/* Compound assignments create constraints of the form
*
* A <: Binary<op, A, B>
*
* This constraint is the one that is meant to unblock A, so it doesn't
* make any sense to stop and wait for someone else to do it.
*/
if (isBlocked(leftType) && leftType != resultType)
return block(c.leftType, constraint);
if (isBlocked(rightType) && rightType != resultType)
return block(c.rightType, constraint);
if (!force)
{
if (get<FreeTypeVar>(leftType))
// Logical expressions may proceed if the LHS is free.
if (get<FreeTypeVar>(leftType) && !isLogical)
return block(leftType, constraint);
}
if (isBlocked(leftType))
// Logical expressions may proceed if the LHS is free.
if (isBlocked(leftType) || (get<FreeTypeVar>(leftType) && !isLogical))
{
asMutable(resultType)->ty.emplace<BoundTypeVar>(errorRecoveryType());
// reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation});
unblock(resultType);
return true;
}
// TODO metatables, classes
// For or expressions, the LHS will never have nil as a possible output.
// Consider:
// local foo = nil or 2
// `foo` will always be 2.
if (c.op == AstExprBinary::Op::Or)
leftType = stripNil(singletonTypes, *arena, leftType);
// Metatables go first, even if there is primitive behavior.
if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end())
{
// Metatables are not the same. The metamethod will not be invoked.
if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) &&
getMetatable(leftType, singletonTypes) != getMetatable(rightType, singletonTypes))
{
// TODO: Boolean singleton false? The result is _always_ boolean false.
asMutable(resultType)->ty.emplace<BoundTypeVar>(singletonTypes->booleanType);
unblock(resultType);
return true;
}
std::optional<TypeId> mm;
// The LHS metatable takes priority over the RHS metatable, where
// present.
if (std::optional<TypeId> leftMm = findMetatableEntry(singletonTypes, errors, leftType, it->second, constraint->location))
mm = leftMm;
else if (std::optional<TypeId> rightMm = findMetatableEntry(singletonTypes, errors, rightType, it->second, constraint->location))
mm = rightMm;
if (mm)
{
// TODO: Is a table with __call legal here?
// TODO: Overloads
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*mm)))
{
TypePackId inferredArgs;
// For >= and > we invoke __lt and __le respectively with
// swapped argument ordering.
if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt)
{
inferredArgs = arena->addTypePack({rightType, leftType});
}
else
{
inferredArgs = arena->addTypePack({leftType, rightType});
}
unify(inferredArgs, ftv->argTypes, constraint->scope);
TypeId mmResult;
// Comparison operations always evaluate to a boolean,
// regardless of what the metamethod returns.
switch (c.op)
{
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
mmResult = singletonTypes->booleanType;
break;
default:
mmResult = first(ftv->retTypes).value_or(errorRecoveryType());
}
asMutable(resultType)->ty.emplace<BoundTypeVar>(mmResult);
unblock(resultType);
return true;
}
}
// If there's no metamethod available, fall back to primitive behavior.
}
// If any is present, the expression must evaluate to any as well.
bool leftAny = get<AnyTypeVar>(leftType) || get<ErrorTypeVar>(leftType);
bool rightAny = get<AnyTypeVar>(rightType) || get<ErrorTypeVar>(rightType);
bool anyPresent = leftAny || rightAny;
switch (c.op)
{
// For arithmetic operators, if the LHS is a number, the RHS must be a
// number as well. The result will also be a number.
case AstExprBinary::Op::Add:
case AstExprBinary::Op::Sub:
case AstExprBinary::Op::Mul:
case AstExprBinary::Op::Div:
case AstExprBinary::Op::Pow:
case AstExprBinary::Op::Mod:
if (isNumber(leftType))
{
unify(leftType, rightType, constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(anyPresent ? singletonTypes->anyType : leftType);
unblock(resultType);
return true;
}
break;
// For concatenation, if the LHS is a string, the RHS must be a string as
// well. The result will also be a string.
case AstExprBinary::Op::Concat:
if (isString(leftType))
{
unify(leftType, rightType, constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(anyPresent ? singletonTypes->anyType : leftType);
unblock(resultType);
return true;
}
break;
// Inexact comparisons require that the types be both numbers or both
// strings, and evaluate to a boolean.
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
if ((isNumber(leftType) && isNumber(rightType)) || (isString(leftType) && isString(rightType)))
{
asMutable(resultType)->ty.emplace<BoundTypeVar>(singletonTypes->booleanType);
unblock(resultType);
return true;
}
break;
// == and ~= always evaluate to a boolean, and impose no other constraints
// on their parameters.
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
asMutable(resultType)->ty.emplace<BoundTypeVar>(singletonTypes->booleanType);
unblock(resultType);
return true;
// And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is
// truthy.
case AstExprBinary::Op::And:
asMutable(resultType)->ty.emplace<BoundTypeVar>(unionOfTypes(rightType, singletonTypes->booleanType, constraint->scope, false));
unblock(resultType);
return true;
// Or evaluates to the LHS type if the LHS is truthy, and the RHS type if
// LHS is falsey.
case AstExprBinary::Op::Or:
asMutable(resultType)->ty.emplace<BoundTypeVar>(unionOfTypes(rightType, leftType, constraint->scope, true));
unblock(resultType);
return true;
default:
iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location);
break;
}
// We failed to either evaluate a metamethod or invoke primitive behavior.
unify(leftType, errorRecoveryType(), constraint->scope);
unify(rightType, errorRecoveryType(), constraint->scope);
asMutable(resultType)->ty.emplace<BoundTypeVar>(errorRecoveryType());
unblock(resultType);
return true;
}
@ -710,6 +896,10 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull<const Constr
ttv->name = c.name;
else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(target))
mtv->syntheticName = c.name;
else if (get<IntersectionTypeVar>(target) || get<UnionTypeVar>(target))
{
// nothing (yet)
}
else
return block(c.namedType, constraint);
@ -943,6 +1133,31 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
return block(c.fn, constraint);
}
// We don't support magic __call metamethods.
if (std::optional<TypeId> callMm = findMetatableEntry(singletonTypes, errors, fn, "__call", constraint->location))
{
std::vector<TypeId> args{fn};
for (TypeId arg : c.argsPack)
args.push_back(arg);
TypeId instantiatedType = arena->addType(BlockedTypeVar{});
TypeId inferredFnType =
arena->addType(FunctionTypeVar(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result));
// Alter the inner constraints.
LUAU_ASSERT(c.innerConstraints.size() == 2);
asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm};
asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType};
unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints));
asMutable(c.result)->ty.emplace<FreeTypePack>(constraint->scope);
unblock(c.result);
return true;
}
const FunctionTypeVar* ftv = get<FunctionTypeVar>(fn);
bool usedMagic = false;
@ -1059,6 +1274,22 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
return true;
}
bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint)
{
if (isBlocked(c.discriminantType))
return false;
TypeId followed = follow(c.discriminantType);
// `nil` is a singleton type too! There's only one value of type `nil`.
if (get<SingletonTypeVar>(followed) || isNil(followed))
*asMutable(c.resultType) = NegationTypeVar{c.discriminantType};
else
*asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType};
return true;
}
bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
{
auto block_ = [&](auto&& t) {
@ -1502,4 +1733,39 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const
return singletonTypes->errorRecoveryTypePack();
}
TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes)
{
a = follow(a);
b = follow(b);
if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b)))
{
Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant};
u.useScopes = true;
u.tryUnify(b, a);
if (u.errors.empty())
{
u.log.commit();
return a;
}
else
{
return singletonTypes->errorRecoveryType(singletonTypes->anyType);
}
}
if (*a == *b)
return a;
std::vector<TypeId> types = reduceUnion({a, b});
if (types.empty())
return singletonTypes->neverType;
if (types.size() == 1)
return types[0];
return arena->addType(UnionTypeVar{types});
}
} // namespace Luau

View file

@ -0,0 +1,440 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/Error.h"
LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
{
std::optional<DefId> DataFlowGraph::getDef(const AstExpr* expr) const
{
if (auto def = astDefs.find(expr))
return NotNull{*def};
return std::nullopt;
}
std::optional<DefId> DataFlowGraph::getDef(const AstLocal* local) const
{
if (auto def = localDefs.find(local))
return NotNull{*def};
return std::nullopt;
}
std::optional<DefId> DataFlowGraph::getDef(const Symbol& symbol) const
{
if (symbol.local)
return getDef(symbol.local);
else
return std::nullopt;
}
DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle)
{
LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
DataFlowGraphBuilder builder;
builder.handle = handle;
builder.visit(nullptr, block); // nullptr is the root DFG scope.
if (FFlag::DebugLuauFreezeArena)
builder.arena->allocator.freeze();
return std::move(builder.graph);
}
DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope)
{
return scopes.emplace_back(new DfgScope{scope}).get();
}
std::optional<DefId> DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e)
{
for (DfgScope* current = scope; current; current = current->parent)
{
if (auto loc = current->bindings.find(symbol))
{
graph.astDefs[e] = *loc;
return NotNull{*loc};
}
}
return std::nullopt;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b)
{
DfgScope* child = childScope(scope);
return visitBlockWithoutChildScope(child, b);
}
void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b)
{
for (AstStat* s : b->body)
visit(scope, s);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s)
{
if (auto b = s->as<AstStatBlock>())
return visit(scope, b);
else if (auto i = s->as<AstStatIf>())
return visit(scope, i);
else if (auto w = s->as<AstStatWhile>())
return visit(scope, w);
else if (auto r = s->as<AstStatRepeat>())
return visit(scope, r);
else if (auto b = s->as<AstStatBreak>())
return visit(scope, b);
else if (auto c = s->as<AstStatContinue>())
return visit(scope, c);
else if (auto r = s->as<AstStatReturn>())
return visit(scope, r);
else if (auto e = s->as<AstStatExpr>())
return visit(scope, e);
else if (auto l = s->as<AstStatLocal>())
return visit(scope, l);
else if (auto f = s->as<AstStatFor>())
return visit(scope, f);
else if (auto f = s->as<AstStatForIn>())
return visit(scope, f);
else if (auto a = s->as<AstStatAssign>())
return visit(scope, a);
else if (auto c = s->as<AstStatCompoundAssign>())
return visit(scope, c);
else if (auto f = s->as<AstStatFunction>())
return visit(scope, f);
else if (auto l = s->as<AstStatLocalFunction>())
return visit(scope, l);
else if (auto t = s->as<AstStatTypeAlias>())
return; // ok
else if (auto d = s->as<AstStatDeclareFunction>())
return; // ok
else if (auto d = s->as<AstStatDeclareGlobal>())
return; // ok
else if (auto d = s->as<AstStatDeclareFunction>())
return; // ok
else if (auto d = s->as<AstStatDeclareClass>())
return; // ok
else if (auto _ = s->as<AstStatError>())
return; // ok
else
handle->ice("Unknown AstStat in DataFlowGraphBuilder");
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i)
{
DfgScope* condScope = childScope(scope);
visitExpr(condScope, i->condition);
visit(condScope, i->thenbody);
if (i->elsebody)
visit(scope, i->elsebody);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* whileScope = childScope(scope);
visitExpr(whileScope, w->condition);
visit(whileScope, w->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* repeatScope = childScope(scope); // TODO: loop scope.
visitBlockWithoutChildScope(repeatScope, r->body);
visitExpr(repeatScope, r->condition);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b)
{
// TODO: Control flow analysis
return; // ok
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c)
{
// TODO: Control flow analysis
return; // ok
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r)
{
// TODO: Control flow analysis
for (AstExpr* e : r->list)
visitExpr(scope, e);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e)
{
visitExpr(scope, e->expr);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l)
{
// TODO: alias tracking
for (AstExpr* e : l->values)
visitExpr(scope, e);
for (AstLocal* local : l->vars)
{
DefId def = arena->freshDef();
graph.localDefs[local] = def;
scope->bindings[local] = def;
}
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
DefId def = arena->freshDef();
graph.localDefs[f->var] = def;
scope->bindings[f->var] = def;
// TODO(controlflow): entry point has a back edge from exit point
visit(forScope, f->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
for (AstLocal* local : f->vars)
{
DefId def = arena->freshDef();
graph.localDefs[local] = def;
forScope->bindings[local] = def;
}
// TODO(controlflow): entry point has a back edge from exit point
for (AstExpr* e : f->values)
visitExpr(forScope, e);
visit(forScope, f->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a)
{
for (AstExpr* r : a->values)
visitExpr(scope, r);
for (AstExpr* l : a->vars)
{
AstExpr* root = l;
bool isUpdatable = true;
while (true)
{
if (root->is<AstExprLocal>() || root->is<AstExprGlobal>())
break;
AstExprIndexName* indexName = root->as<AstExprIndexName>();
if (!indexName)
{
isUpdatable = false;
break;
}
root = indexName->expr;
}
if (isUpdatable)
{
// TODO global?
if (auto exprLocal = root->as<AstExprLocal>())
{
DefId def = arena->freshDef();
graph.astDefs[exprLocal] = def;
// Update the def in the scope that introduced the local. Not
// the current scope.
AstLocal* local = exprLocal->local;
DfgScope* s = scope;
while (s && !s->bindings.find(local))
s = s->parent;
LUAU_ASSERT(s && s->bindings.find(local));
s->bindings[local] = def;
}
}
visitExpr(scope, l); // TODO: they point to a new def!!
}
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c)
{
// TODO(typestates): The lhs is being read and written to. This might or might not be annoying.
visitExpr(scope, c->value);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f)
{
visitExpr(scope, f->name);
visitExpr(scope, f->func);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l)
{
DefId def = arena->freshDef();
graph.localDefs[l->name] = def;
scope->bindings[l->name] = def;
visitExpr(scope, l->func);
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e)
{
if (auto g = e->as<AstExprGroup>())
return visitExpr(scope, g->expr);
else if (auto c = e->as<AstExprConstantNil>())
return {}; // ok
else if (auto c = e->as<AstExprConstantBool>())
return {}; // ok
else if (auto c = e->as<AstExprConstantNumber>())
return {}; // ok
else if (auto c = e->as<AstExprConstantString>())
return {}; // ok
else if (auto l = e->as<AstExprLocal>())
return visitExpr(scope, l);
else if (auto g = e->as<AstExprGlobal>())
return visitExpr(scope, g);
else if (auto v = e->as<AstExprVarargs>())
return {}; // ok
else if (auto c = e->as<AstExprCall>())
return visitExpr(scope, c);
else if (auto i = e->as<AstExprIndexName>())
return visitExpr(scope, i);
else if (auto i = e->as<AstExprIndexExpr>())
return visitExpr(scope, i);
else if (auto f = e->as<AstExprFunction>())
return visitExpr(scope, f);
else if (auto t = e->as<AstExprTable>())
return visitExpr(scope, t);
else if (auto u = e->as<AstExprUnary>())
return visitExpr(scope, u);
else if (auto b = e->as<AstExprBinary>())
return visitExpr(scope, b);
else if (auto t = e->as<AstExprTypeAssertion>())
return visitExpr(scope, t);
else if (auto i = e->as<AstExprIfElse>())
return visitExpr(scope, i);
else if (auto i = e->as<AstExprInterpString>())
return visitExpr(scope, i);
else if (auto _ = e->as<AstExprError>())
return {}; // ok
else
handle->ice("Unknown AstExpr in DataFlowGraphBuilder");
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l)
{
return {use(scope, l->local, l)};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g)
{
return {use(scope, g->name, g)};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c)
{
visitExpr(scope, c->func);
for (AstExpr* arg : c->args)
visitExpr(scope, arg);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i)
{
std::optional<DefId> def = visitExpr(scope, i->expr).def;
if (!def)
return {};
// TODO: properties for the above def.
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i)
{
visitExpr(scope, i->expr);
visitExpr(scope, i->expr);
if (i->index->as<AstExprConstantString>())
{
// TODO: properties for the def
}
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f)
{
if (AstLocal* self = f->self)
{
DefId def = arena->freshDef();
graph.localDefs[self] = def;
scope->bindings[self] = def;
}
for (AstLocal* param : f->args)
{
DefId def = arena->freshDef();
graph.localDefs[param] = def;
scope->bindings[param] = def;
}
visit(scope, f->body);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t)
{
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u)
{
visitExpr(scope, u->expr);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b)
{
visitExpr(scope, b->left);
visitExpr(scope, b->right);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t)
{
ExpressionFlowGraph result = visitExpr(scope, t->expr);
// TODO: visit type
return result;
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i)
{
DfgScope* condScope = childScope(scope);
visitExpr(condScope, i->condition);
visitExpr(condScope, i->trueExpr);
visitExpr(scope, i->falseExpr);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i)
{
for (AstExpr* e : i->expressions)
visitExpr(scope, e);
return {};
}
} // namespace Luau

12
Analysis/src/Def.cpp Normal file
View file

@ -0,0 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Def.h"
namespace Luau
{
DefId DefArena::freshDef()
{
return NotNull{allocator.allocate<Def>(Undefined{})};
}
} // namespace Luau

View file

@ -13,47 +13,47 @@ declare bit32: {
bor: (...number) -> number,
bxor: (...number) -> number,
btest: (number, ...number) -> boolean,
rrotate: (number, number) -> number,
lrotate: (number, number) -> number,
lshift: (number, number) -> number,
arshift: (number, number) -> number,
rshift: (number, number) -> number,
bnot: (number) -> number,
extract: (number, number, number?) -> number,
replace: (number, number, number, number?) -> number,
countlz: (number) -> number,
countrz: (number) -> number,
rrotate: (x: number, disp: number) -> number,
lrotate: (x: number, disp: number) -> number,
lshift: (x: number, disp: number) -> number,
arshift: (x: number, disp: number) -> number,
rshift: (x: number, disp: number) -> number,
bnot: (x: number) -> number,
extract: (n: number, field: number, width: number?) -> number,
replace: (n: number, v: number, field: number, width: number?) -> number,
countlz: (n: number) -> number,
countrz: (n: number) -> number,
}
declare math: {
frexp: (number) -> (number, number),
ldexp: (number, number) -> number,
fmod: (number, number) -> number,
modf: (number) -> (number, number),
pow: (number, number) -> number,
exp: (number) -> number,
frexp: (n: number) -> (number, number),
ldexp: (s: number, e: number) -> number,
fmod: (x: number, y: number) -> number,
modf: (n: number) -> (number, number),
pow: (x: number, y: number) -> number,
exp: (n: number) -> number,
ceil: (number) -> number,
floor: (number) -> number,
abs: (number) -> number,
sqrt: (number) -> number,
ceil: (n: number) -> number,
floor: (n: number) -> number,
abs: (n: number) -> number,
sqrt: (n: number) -> number,
log: (number, number?) -> number,
log10: (number) -> number,
log: (n: number, base: number?) -> number,
log10: (n: number) -> number,
rad: (number) -> number,
deg: (number) -> number,
rad: (n: number) -> number,
deg: (n: number) -> number,
sin: (number) -> number,
cos: (number) -> number,
tan: (number) -> number,
sinh: (number) -> number,
cosh: (number) -> number,
tanh: (number) -> number,
atan: (number) -> number,
acos: (number) -> number,
asin: (number) -> number,
atan2: (number, number) -> number,
sin: (n: number) -> number,
cos: (n: number) -> number,
tan: (n: number) -> number,
sinh: (n: number) -> number,
cosh: (n: number) -> number,
tanh: (n: number) -> number,
atan: (n: number) -> number,
acos: (n: number) -> number,
asin: (n: number) -> number,
atan2: (y: number, x: number) -> number,
min: (number, ...number) -> number,
max: (number, ...number) -> number,
@ -61,13 +61,13 @@ declare math: {
pi: number,
huge: number,
randomseed: (number) -> (),
randomseed: (seed: number) -> (),
random: (number?, number?) -> number,
sign: (number) -> number,
clamp: (number, number, number) -> number,
noise: (number, number?, number?) -> number,
round: (number) -> number,
sign: (n: number) -> number,
clamp: (n: number, min: number, max: number) -> number,
noise: (x: number, y: number?, z: number?) -> number,
round: (n: number) -> number,
}
type DateTypeArg = {
@ -93,9 +93,9 @@ type DateTypeResult = {
}
declare os: {
time: (DateTypeArg?) -> number,
date: (string?, number?) -> DateTypeResult | string,
difftime: (DateTypeResult | number, DateTypeResult | number) -> number,
time: (time: DateTypeArg?) -> number,
date: (formatString: string?, time: number?) -> DateTypeResult | string,
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
@ -145,51 +145,51 @@ declare function loadstring<A...>(src: string, chunkname: string?): (((A...) ->
declare function newproxy(mt: boolean?): any
declare coroutine: {
create: <A..., R...>((A...) -> R...) -> thread,
resume: <A..., R...>(thread, A...) -> (boolean, R...),
create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread,
status: (thread) -> "dead" | "running" | "normal" | "suspended",
status: (co: thread) -> "dead" | "running" | "normal" | "suspended",
-- FIXME: This technically returns a function, but we can't represent this yet.
wrap: <A..., R...>((A...) -> R...) -> any,
wrap: <A..., R...>(f: (A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: (thread) -> (boolean, any)
close: (co: thread) -> (boolean, any)
}
declare table: {
concat: <V>({V}, string?, number?, number?) -> string,
insert: (<V>({V}, V) -> ()) & (<V>({V}, number, V) -> ()),
maxn: <V>({V}) -> number,
remove: <V>({V}, number?) -> V?,
sort: <V>({V}, ((V, V) -> boolean)?) -> (),
create: <V>(number, V?) -> {V},
find: <V>({V}, V, number?) -> number?,
concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
maxn: <V>(t: {V}) -> number,
remove: <V>(t: {V}, number?) -> V?,
sort: <V>(t: {V}, comp: ((V, V) -> boolean)?) -> (),
create: <V>(count: number, value: V?) -> {V},
find: <V>(haystack: {V}, needle: V, init: number?) -> number?,
unpack: <V>({V}, number?, number?) -> ...V,
unpack: <V>(list: {V}, i: number?, j: number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>({V}) -> number,
foreach: <K, V>({[K]: V}, (K, V) -> ()) -> (),
getn: <V>(t: {V}) -> number,
foreach: <K, V>(t: {[K]: V}, f: (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>({V}, number, number, number, {V}?) -> {V},
clear: <K, V>({[K]: V}) -> (),
move: <V>(src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V},
clear: <K, V>(table: {[K]: V}) -> (),
isfrozen: <K, V>({[K]: V}) -> boolean,
isfrozen: <K, V>(t: {[K]: V}) -> boolean,
}
declare debug: {
info: (<R...>(thread, number, string) -> R...) & (<R...>(number, string) -> R...) & (<A..., R1..., R2...>((A...) -> R1..., string) -> R2...),
traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string),
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
declare utf8: {
char: (...number) -> string,
charpattern: string,
codes: (string) -> ((string, number) -> (number, number), string, number),
codepoint: (string, number?, number?) -> ...number,
len: (string, number?, number?) -> (number?, number?),
offset: (string, number?, number?) -> number,
codes: (str: string) -> ((string, number) -> (number, number), string, number),
codepoint: (str: string, i: number?, j: number?) -> ...number,
len: (s: string, i: number?, j: number?) -> (number?, number?),
offset: (s: string, n: number?, i: number?) -> number,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.

View file

@ -7,7 +7,7 @@
#include <stdexcept>
LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false)
LUAU_FASTFLAGVARIABLE(LuauIceExceptionInheritanceChange, false)
static std::string wrongNumberOfArgsString(
size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
@ -70,7 +70,7 @@ struct ErrorConverter
{
if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType))
{
if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr)
if (fileResolver != nullptr)
{
std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule);
std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule);
@ -96,14 +96,7 @@ struct ErrorConverter
if (!tm.reason.empty())
result += tm.reason + " ";
if (FFlag::LuauTypeMismatchModuleNameResolution)
{
result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
}
else
{
result += Luau::toString(*tm.error);
}
result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
}
else if (!tm.reason.empty())
{
@ -469,6 +462,11 @@ struct ErrorConverter
{
return "Code is too complex to typecheck! Consider simplifying the code around this area";
}
std::string operator()(const TypePackMismatch& e) const
{
return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'";
}
};
struct InvalidNameChecker
@ -727,6 +725,11 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const
return left == rhs.left && right == rhs.right;
}
bool TypePackMismatch::operator==(const TypePackMismatch& rhs) const
{
return *wantedTp == *rhs.wantedTp && *givenTp == *rhs.givenTp;
}
std::string toString(const TypeError& error)
{
return toString(error, TypeErrorToStringOptions{});
@ -878,6 +881,11 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState)
else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
{
}
else if constexpr (std::is_same_v<T, TypePackMismatch>)
{
e.wantedTp = clone(e.wantedTp);
e.givenTp = clone(e.givenTp);
}
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
@ -922,4 +930,30 @@ const char* InternalCompilerError::what() const throw()
return this->message.data();
}
// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted.
void throwRuntimeError(const std::string& message)
{
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw InternalCompilerError(message);
}
else
{
throw std::runtime_error(message);
}
}
// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted.
void throwRuntimeError(const std::string& message, const std::string& moduleName)
{
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw InternalCompilerError(message, moduleName);
}
else
{
throw std::runtime_error(message);
}
}
} // namespace Luau

View file

@ -1,11 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Frontend.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/Config.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/DcrLogger.h"
#include "Luau/FileResolver.h"
#include "Luau/Parser.h"
@ -15,7 +17,6 @@
#include "Luau/TypeChecker2.h"
#include "Luau/TypeInfer.h"
#include "Luau/Variant.h"
#include "Luau/BuiltinDefinitions.h"
#include <algorithm>
#include <chrono>
@ -26,10 +27,11 @@ LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false)
LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100)
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAG(DebugLuauLogSolverToJson);
LUAU_FASTFLAGVARIABLE(LuauFixMarkDirtyReverseDeps, false)
LUAU_FASTFLAGVARIABLE(LuauPersistTypesAfterGeneratingDocSyms, false)
namespace Luau
{
@ -110,24 +112,57 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c
CloneState cloneState;
for (const auto& [name, ty] : checkedModule->declaredGlobals)
if (FFlag::LuauPersistTypesAfterGeneratingDocSyms)
{
TypeId globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
std::vector<TypeId> typesToPersist;
typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size());
persist(globalTy);
for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
typesToPersist.push_back(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
globalScope->exportedTypeBindings[name] = globalTy;
typesToPersist.push_back(globalTy.type);
}
for (TypeId ty : typesToPersist)
{
persist(ty);
}
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
else
{
TypeFun globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
globalScope->exportedTypeBindings[name] = globalTy;
for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
persist(globalTy.type);
persist(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
globalScope->exportedTypeBindings[name] = globalTy;
persist(globalTy.type);
}
}
return LoadDefinitionFileResult{true, parseResult, checkedModule};
@ -159,24 +194,57 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
CloneState cloneState;
for (const auto& [name, ty] : checkedModule->declaredGlobals)
if (FFlag::LuauPersistTypesAfterGeneratingDocSyms)
{
TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
std::vector<TypeId> typesToPersist;
typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size());
persist(globalTy);
for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
typesToPersist.push_back(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy;
typesToPersist.push_back(globalTy.type);
}
for (TypeId ty : typesToPersist)
{
persist(ty);
}
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
else
{
TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy;
for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
persist(globalTy.type);
persist(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy;
persist(globalTy.type);
}
}
return LoadDefinitionFileResult{true, parseResult, checkedModule};
@ -425,13 +493,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
{
auto it2 = moduleResolverForAutocomplete.modules.find(name);
if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr)
throw std::runtime_error("Frontend::modules does not have data for " + name);
throwRuntimeError("Frontend::modules does not have data for " + name, name);
}
else
{
auto it2 = moduleResolver.modules.find(name);
if (it2 == moduleResolver.modules.end() || it2->second == nullptr)
throw std::runtime_error("Frontend::modules does not have data for " + name);
throwRuntimeError("Frontend::modules does not have data for " + name, name);
}
return CheckResult{
@ -488,23 +556,19 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
else
typeCheckerForAutocomplete.finishTime = std::nullopt;
if (FFlag::LuauAutocompleteDynamicLimits)
{
// TODO: This is a dirty ad hoc solution for autocomplete timeouts
// We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit
// so that we'll have type information for the whole file at lower quality instead of a full abort in the middle
if (FInt::LuauTarjanChildLimit > 0)
typeCheckerForAutocomplete.instantiationChildLimit =
std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt;
// TODO: This is a dirty ad hoc solution for autocomplete timeouts
// We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit
// so that we'll have type information for the whole file at lower quality instead of a full abort in the middle
if (FInt::LuauTarjanChildLimit > 0)
typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt;
if (FInt::LuauTypeInferIterationLimit > 0)
typeCheckerForAutocomplete.unifierIterationLimit =
std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt;
}
if (FInt::LuauTypeInferIterationLimit > 0)
typeCheckerForAutocomplete.unifierIterationLimit =
std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt;
ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution
? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true)
@ -518,10 +582,9 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
{
checkResult.timeoutHits.push_back(moduleName);
if (FFlag::LuauAutocompleteDynamicLimits)
sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
}
else if (FFlag::LuauAutocompleteDynamicLimits && duration < autocompleteTimeLimit / 2.0)
else if (duration < autocompleteTimeLimit / 2.0)
{
sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0);
}
@ -543,7 +606,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
stats.filesNonstrict += mode == Mode::Nonstrict;
if (module == nullptr)
throw std::runtime_error("Frontend::check produced a nullptr module for " + moduleName);
throwRuntimeError("Frontend::check produced a nullptr module for " + moduleName, moduleName);
if (!frontendOptions.retainFullTypeGraphs)
{
@ -807,13 +870,26 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
if (0 == reverseDeps.count(name))
continue;
if (FFlag::LuauFixMarkDirtyReverseDeps)
{
if (0 == reverseDeps.count(next))
continue;
sourceModules.erase(name);
sourceModules.erase(next);
const std::vector<ModuleName>& dependents = reverseDeps[name];
queue.insert(queue.end(), dependents.begin(), dependents.end());
const std::vector<ModuleName>& dependents = reverseDeps[next];
queue.insert(queue.end(), dependents.begin(), dependents.end());
}
else
{
if (0 == reverseDeps.count(name))
continue;
sourceModules.erase(name);
const std::vector<ModuleName>& dependents = reverseDeps[name];
queue.insert(queue.end(), dependents.begin(), dependents.end());
}
}
}
@ -857,13 +933,25 @@ ModulePtr Frontend::check(
}
}
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&iceHandler});
const NotNull<ModuleResolver> mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver};
const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope};
Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}};
ConstraintGraphBuilder cgb{
sourceModule.name, result, &result->internalTypes, mr, singletonTypes, NotNull(&iceHandler), globalScope, logger.get()};
sourceModule.name,
result,
&result->internalTypes,
mr,
singletonTypes,
NotNull(&iceHandler),
globalScope,
logger.get(),
NotNull{&dfg},
};
cgb.visit(sourceModule.root);
result->errors = std::move(cgb.errors);
@ -986,11 +1074,11 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const
double timestamp = getTimestamp();
auto parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions);
Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions);
stats.timeParse += getTimestamp() - timestamp;
stats.files++;
stats.lines += std::count(src.begin(), src.end(), '\n') + (src.size() && src.back() != '\n');
stats.lines += parseResult.lines;
if (!parseResult.errors.empty())
sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end());

View file

@ -188,6 +188,8 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }";
else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
stream << "NormalizationTooComplex { }";
else if constexpr (std::is_same_v<T, TypePackMismatch>)
stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }";
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}

View file

@ -60,36 +60,6 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos)
return contains(pos, *iter);
}
struct ForceNormal : TypeVarOnceVisitor
{
const TypeArena* typeArena = nullptr;
ForceNormal(const TypeArena* typeArena)
: typeArena(typeArena)
{
}
bool visit(TypeId ty) override
{
if (ty->owningArena != typeArena)
return false;
asMutable(ty)->normal = true;
return true;
}
bool visit(TypeId ty, const FreeTypeVar& ftv) override
{
visit(ty);
return true;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
return true;
}
};
struct ClonePublicInterface : Substitution
{
NotNull<SingletonTypes> singletonTypes;
@ -241,8 +211,6 @@ void Module::clonePublicInterface(NotNull<SingletonTypes> singletonTypes, Intern
moduleScope->varargPack = varargPack;
}
ForceNormal forceNormal{&interfaceTypes};
if (exportedTypeBindings)
{
for (auto& [name, tf] : *exportedTypeBindings)
@ -262,7 +230,6 @@ void Module::clonePublicInterface(NotNull<SingletonTypes> singletonTypes, Intern
{
auto t = asMutable(ty);
t->ty = AnyTypeVar{};
t->normal = true;
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -57,29 +57,6 @@ struct Quantifier final : TypeVarOnceVisitor
return false;
}
bool visit(TypeId ty, const ConstrainedTypeVar&) override
{
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty);
seenMutableType = true;
if (!level.subsumes(ctv->level))
return false;
std::vector<TypeId> opts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic
for (TypeId opt : opts)
traverse(opt);
if (opts.size() == 1)
*asMutable(ty) = BoundTypeVar{opts[0]};
else
*asMutable(ty) = UnionTypeVar{std::move(opts)};
return false;
}
bool visit(TypeId ty, const TableTypeVar&) override
{
LUAU_ASSERT(getMutable<TableTypeVar>(ty));

View file

@ -27,6 +27,44 @@ void Scope::addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun)
builtinTypeNames.insert(name);
}
std::optional<TypeId> Scope::lookup(Symbol sym) const
{
auto r = const_cast<Scope*>(this)->lookupEx(sym);
if (r)
return r->first;
else
return std::nullopt;
}
std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(Symbol sym)
{
Scope* s = this;
while (true)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return std::pair{it->second.typeId, s};
if (s->parent)
s = s->parent.get();
else
return std::nullopt;
}
}
// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis.
std::optional<TypeId> Scope::lookup(DefId def) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
if (auto ty = current->dcrRefinements.find(def))
return *ty;
}
return std::nullopt;
}
std::optional<TypeFun> Scope::lookupType(const Name& name)
{
const Scope* scope = this;
@ -111,23 +149,6 @@ std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bo
return std::nullopt;
}
std::optional<TypeId> Scope::lookup(Symbol sym)
{
Scope* s = this;
while (true)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return it->second.typeId;
if (s->parent)
s = s->parent.get();
else
return std::nullopt;
}
}
bool subsumesStrict(Scope* left, Scope* right)
{
while (right)

View file

@ -73,11 +73,6 @@ void Tarjan::visitChildren(TypeId ty, int index)
for (TypeId part : itv->parts)
visitChild(part);
}
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
for (TypeId part : ctv->parts)
visitChild(part);
}
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{
for (TypeId a : petv->typeArguments)
@ -97,6 +92,10 @@ void Tarjan::visitChildren(TypeId ty, int index)
if (ctv->metatable)
visitChild(*ctv->metatable);
}
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(ty))
{
visitChild(ntv->ty);
}
}
void Tarjan::visitChildren(TypePackId tp, int index)
@ -605,11 +604,6 @@ void Substitution::replaceChildren(TypeId ty)
for (TypeId& part : itv->parts)
part = replace(part);
}
else if (ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty))
{
for (TypeId& part : ctv->parts)
part = replace(part);
}
else if (PendingExpansionTypeVar* petv = getMutable<PendingExpansionTypeVar>(ty))
{
for (TypeId& a : petv->typeArguments)
@ -629,6 +623,10 @@ void Substitution::replaceChildren(TypeId ty)
if (ctv->metatable)
ctv->metatable = replace(*ctv->metatable);
}
else if (NegationTypeVar* ntv = getMutable<NegationTypeVar>(ty))
{
ntv->ty = replace(ntv->ty);
}
}
void Substitution::replaceChildren(TypePackId tp)

View file

@ -237,15 +237,6 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty);
finishNode();
}
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
formatAppend(result, "ConstrainedTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
for (TypeId part : ctv->parts)
visitChild(part, index);
}
else if (get<ErrorTypeVar>(ty))
{
formatAppend(result, "ErrorTypeVar %d", index);

View file

@ -10,11 +10,12 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauLvaluelessPath)
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false)
LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false)
LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false)
LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false)
LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false)
/*
* Prefix generic typenames with gen-
@ -224,6 +225,20 @@ struct StringifierState
result.name += s;
}
void emitLevel(Scope* scope)
{
size_t count = 0;
for (Scope* s = scope; s; s = s->parent.get())
++count;
emit(count);
emit("-");
char buffer[16];
uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF);
snprintf(buffer, sizeof(buffer), "0x%x", s);
emit(buffer);
}
void emit(TypeLevel level)
{
emit(std::to_string(level.level));
@ -295,10 +310,7 @@ struct TypeVarStringifier
if (tv->ty.valueless_by_exception())
{
state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("* VALUELESS BY EXCEPTION *");
else
state.emit("< VALUELESS BY EXCEPTION >");
state.emit("* VALUELESS BY EXCEPTION *");
return;
}
@ -376,7 +388,10 @@ struct TypeVarStringifier
if (FFlag::DebugLuauVerboseTypeNames)
{
state.emit("-");
state.emit(ftv.level);
if (FFlag::DebugLuauDeferredConstraintResolution)
state.emitLevel(ftv.scope);
else
state.emit(ftv.level);
}
}
@ -398,29 +413,15 @@ struct TypeVarStringifier
}
else
state.emit(state.getName(ty));
}
void operator()(TypeId, const ConstrainedTypeVar& ctv)
{
state.result.invalid = true;
state.emit("[");
if (FFlag::DebugLuauVerboseTypeNames)
state.emit(ctv.level);
state.emit("[");
bool first = true;
for (TypeId ty : ctv.parts)
{
if (first)
first = false;
state.emit("-");
if (FFlag::DebugLuauDeferredConstraintResolution)
state.emitLevel(gtv.scope);
else
state.emit("|");
stringify(ty);
state.emit(gtv.level);
}
state.emit("]]");
}
void operator()(TypeId, const BlockedTypeVar& btv)
@ -456,9 +457,12 @@ struct TypeVarStringifier
case PrimitiveTypeVar::Thread:
state.emit("thread");
return;
case PrimitiveTypeVar::Function:
state.emit("function");
return;
default:
LUAU_ASSERT(!"Unknown primitive type");
throw std::runtime_error("Unknown primitive type " + std::to_string(ptv.type));
throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type));
}
}
@ -475,7 +479,7 @@ struct TypeVarStringifier
else
{
LUAU_ASSERT(!"Unknown singleton type");
throw std::runtime_error("Unknown singleton type");
throwRuntimeError("Unknown singleton type");
}
}
@ -484,10 +488,7 @@ struct TypeVarStringifier
if (state.hasSeen(&ftv))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
state.emit("*CYCLE*");
return;
}
@ -595,10 +596,7 @@ struct TypeVarStringifier
if (state.hasSeen(&ttv))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
state.emit("*CYCLE*");
return;
}
@ -732,10 +730,7 @@ struct TypeVarStringifier
if (state.hasSeen(&uv))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
state.emit("*CYCLE*");
return;
}
@ -802,10 +797,7 @@ struct TypeVarStringifier
if (state.hasSeen(&uv))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
state.emit("*CYCLE*");
return;
}
@ -850,10 +842,7 @@ struct TypeVarStringifier
void operator()(TypeId, const ErrorTypeVar& tv)
{
state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
}
void operator()(TypeId, const LazyTypeVar& ltv)
@ -871,6 +860,23 @@ struct TypeVarStringifier
{
state.emit("never");
}
void operator()(TypeId, const NegationTypeVar& ntv)
{
state.emit("~");
// The precedence of `~` should be less than `|` and `&`.
TypeId followed = follow(ntv.ty);
bool parens = get<UnionTypeVar>(followed) || get<IntersectionTypeVar>(followed);
if (parens)
state.emit("(");
stringify(ntv.ty);
if (parens)
state.emit(")");
}
};
struct TypePackStringifier
@ -907,10 +913,7 @@ struct TypePackStringifier
if (tp->ty.valueless_by_exception())
{
state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("* VALUELESS TP BY EXCEPTION *");
else
state.emit("< VALUELESS TP BY EXCEPTION >");
state.emit("* VALUELESS TP BY EXCEPTION *");
return;
}
@ -934,10 +937,7 @@ struct TypePackStringifier
if (state.hasSeen(&tp))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLETP*");
else
state.emit("<CYCLETP>");
state.emit("*CYCLETP*");
return;
}
@ -982,10 +982,7 @@ struct TypePackStringifier
void operator()(TypePackId, const Unifiable::Error& error)
{
state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
}
void operator()(TypePackId, const VariadicTypePack& pack)
@ -993,10 +990,7 @@ struct TypePackStringifier
state.emit("...");
if (FFlag::DebugLuauVerboseTypeNames && pack.hidden)
{
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*hidden*");
else
state.emit("<hidden>");
state.emit("*hidden*");
}
stringify(pack.ty);
}
@ -1031,7 +1025,10 @@ struct TypePackStringifier
if (FFlag::DebugLuauVerboseTypeNames)
{
state.emit("-");
state.emit(pack.level);
if (FFlag::DebugLuauDeferredConstraintResolution)
state.emitLevel(pack.scope);
else
state.emit(pack.level);
}
state.emit("...");
@ -1204,10 +1201,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
{
result.truncated = true;
if (FFlag::LuauSpecialTypesAsterisked)
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
result.name += "... *TRUNCATED*";
}
return result;
@ -1280,10 +1274,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
{
if (FFlag::LuauSpecialTypesAsterisked)
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
result.name += "... *TRUNCATED*";
}
return result;
@ -1442,7 +1433,7 @@ std::string generateName(size_t i)
std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
auto go = [&opts](auto&& c) {
auto go = [&opts](auto&& c) -> std::string {
using T = std::decay_t<decltype(c)>;
// TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps
@ -1526,6 +1517,13 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\"";
}
else if constexpr (std::is_same_v<T, SingletonOrTopTypeConstraint>)
{
std::string result = tos(c.resultType, opts);
std::string discriminant = tos(c.discriminantType, opts);
return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant;
}
else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
};
@ -1545,6 +1543,8 @@ std::string dump(const Constraint& c)
std::string toString(const LValue& lvalue)
{
LUAU_ASSERT(!FFlag::LuauLvaluelessPath);
std::string s;
for (const LValue* current = &lvalue; current; current = baseof(*current))
{
@ -1559,4 +1559,37 @@ std::string toString(const LValue& lvalue)
return s;
}
std::optional<std::string> getFunctionNameAsString(const AstExpr& expr)
{
LUAU_ASSERT(FFlag::LuauLvaluelessPath);
const AstExpr* curr = &expr;
std::string s;
for (;;)
{
if (auto local = curr->as<AstExprLocal>())
return local->local->name.value + s;
if (auto global = curr->as<AstExprGlobal>())
return global->name.value + s;
if (auto indexname = curr->as<AstExprIndexName>())
{
curr = indexname->expr;
s = "." + std::string(indexname->index.value) + s;
}
else if (auto group = curr->as<AstExprGroup>())
{
curr = group->expr;
}
else
{
return std::nullopt;
}
}
return s;
}
} // namespace Luau

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TopoSortStatements.h"
#include "Luau/Error.h"
/* Decide the order in which we typecheck Lua statements in a block.
*
* Algorithm:
@ -149,7 +150,7 @@ Identifier mkName(const AstStatFunction& function)
auto name = mkName(*function.name);
LUAU_ASSERT(bool(name));
if (!name)
throw std::runtime_error("Internal error: Function declaration has a bad name");
throwRuntimeError("Internal error: Function declaration has a bad name");
return *name;
}
@ -255,7 +256,7 @@ struct ArcCollector : public AstVisitor
{
auto name = mkName(*node->name);
if (!name)
throw std::runtime_error("Internal error: AstStatFunction has a bad name");
throwRuntimeError("Internal error: AstStatFunction has a bad name");
add(*name);
return true;

View file

@ -251,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional<TypeId> newBoundTo)
PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel)
{
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty) || get<ConstrainedTypeVar>(ty));
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty));
PendingType* newTy = queue(ty);
if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy))
@ -267,11 +267,6 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel)
{
ftv->level = newLevel;
}
else if (ConstrainedTypeVar* ctv = Luau::getMutable<ConstrainedTypeVar>(newTy))
{
if (FFlag::LuauUnknownAndNeverType)
ctv->level = newLevel;
}
return newTy;
}
@ -291,7 +286,7 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel)
PendingType* TxnLog::changeScope(TypeId ty, NotNull<Scope> newScope)
{
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty) || get<ConstrainedTypeVar>(ty));
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty));
PendingType* newTy = queue(ty);
if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy))
@ -307,10 +302,6 @@ PendingType* TxnLog::changeScope(TypeId ty, NotNull<Scope> newScope)
{
ftv->scope = newScope;
}
else if (ConstrainedTypeVar* ctv = Luau::getMutable<ConstrainedTypeVar>(newTy))
{
ctv->scope = newScope;
}
return newTy;
}

View file

@ -104,16 +104,6 @@ public:
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*pending-expansion*"));
}
AstType* operator()(const ConstrainedTypeVar& ctv)
{
AstArray<AstType*> types;
types.size = ctv.parts.size();
types.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * ctv.parts.size()));
for (size_t i = 0; i < ctv.parts.size(); ++i)
types.data[i] = Luau::visit(*this, ctv.parts[i]->ty);
return allocator->alloc<AstTypeIntersection>(Location(), types);
}
AstType* operator()(const SingletonTypeVar& stv)
{
if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv))
@ -348,6 +338,11 @@ public:
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"});
}
AstType* operator()(const NegationTypeVar& ntv)
{
// FIXME: do the same thing we do with ErrorTypeVar
throwRuntimeError("Cannot convert NegationTypeVar into AstNode");
}
private:
Allocator* allocator;

View file

@ -5,6 +5,7 @@
#include "Luau/AstQuery.h"
#include "Luau/Clone.h"
#include "Luau/Instantiation.h"
#include "Luau/Metamethods.h"
#include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/TxnLog.h"
@ -62,6 +63,23 @@ struct StackPusher
}
};
static std::optional<std::string> getIdentifierOfBaseVar(AstExpr* node)
{
if (AstExprGlobal* expr = node->as<AstExprGlobal>())
return expr->name.value;
if (AstExprLocal* expr = node->as<AstExprLocal>())
return expr->local->name.value;
if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return getIdentifierOfBaseVar(expr->expr);
if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return getIdentifierOfBaseVar(expr->expr);
return std::nullopt;
}
struct TypeChecker2
{
NotNull<SingletonTypes> singletonTypes;
@ -283,7 +301,6 @@ struct TypeChecker2
UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant};
u.anyIsTop = true;
u.tryUnify(actualRetType, expectedRetType);
const bool ok = u.errors.empty() && u.log.empty();
@ -313,16 +330,21 @@ struct TypeChecker2
if (value)
visit(value);
if (i != local->values.size - 1)
TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr;
if (i != local->values.size - 1 || maybeValueType)
{
AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr;
if (var && var->annotation)
{
TypeId varType = lookupAnnotation(var->annotation);
TypeId annotationType = lookupAnnotation(var->annotation);
TypeId valueType = value ? lookupType(value) : nullptr;
if (valueType && !isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
reportError(TypeMismatch{varType, valueType}, value->location);
if (valueType)
{
ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType);
if (!errors.empty())
reportErrors(std::move(errors));
}
}
}
else
@ -588,7 +610,7 @@ struct TypeChecker2
visit(rhs);
TypeId rhsType = lookupType(rhs);
if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{lhsType, rhsType}, rhs->location);
}
@ -739,7 +761,7 @@ struct TypeChecker2
TypeId actualType = lookupType(number);
TypeId numberType = singletonTypes->numberType;
if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{actualType, numberType}, number->location);
}
@ -750,7 +772,7 @@ struct TypeChecker2
TypeId actualType = lookupType(string);
TypeId stringType = singletonTypes->stringType;
if (!isSubtype(stringType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{actualType, stringType}, string->location);
}
@ -783,26 +805,55 @@ struct TypeChecker2
TypePackId expectedRetType = lookupPack(call);
TypeId functionType = lookupType(call->func);
LUAU_ASSERT(functionType);
TypeId testFunctionType = functionType;
TypePack args;
if (get<AnyTypeVar>(functionType) || get<ErrorTypeVar>(functionType))
return;
// TODO: Lots of other types are callable: intersections of functions
// and things with the __call metamethod.
if (!get<FunctionTypeVar>(functionType))
else if (std::optional<TypeId> callMm = findMetatableEntry(singletonTypes, module->errors, functionType, "__call", call->func->location))
{
if (get<FunctionTypeVar>(follow(*callMm)))
{
if (std::optional<TypeId> instantiatedCallMm = instantiation.substitute(*callMm))
{
args.head.push_back(functionType);
testFunctionType = follow(*instantiatedCallMm);
}
else
{
reportError(UnificationTooComplex{}, call->func->location);
return;
}
}
else
{
// TODO: This doesn't flag the __call metamethod as the problem
// very clearly.
reportError(CannotCallNonFunction{*callMm}, call->func->location);
return;
}
}
else if (get<FunctionTypeVar>(functionType))
{
if (std::optional<TypeId> instantiatedFunctionType = instantiation.substitute(functionType))
{
testFunctionType = *instantiatedFunctionType;
}
else
{
reportError(UnificationTooComplex{}, call->func->location);
return;
}
}
else
{
reportError(CannotCallNonFunction{functionType}, call->func->location);
return;
}
TypeId instantiatedFunctionType = follow(instantiation.substitute(functionType).value_or(nullptr));
TypePack args;
for (AstExpr* arg : call->args)
{
TypeId argTy = module->astTypes[arg];
LUAU_ASSERT(argTy);
TypeId argTy = lookupType(arg);
args.head.push_back(argTy);
}
@ -810,7 +861,7 @@ struct TypeChecker2
FunctionTypeVar ftv{argsTp, expectedRetType};
TypeId expectedType = arena.addType(ftv);
if (!isSubtype(instantiatedFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice))
{
CloneState cloneState;
expectedType = clone(expectedType, module->internalTypes, cloneState);
@ -829,7 +880,7 @@ struct TypeChecker2
getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true);
if (ty)
{
if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{resultType, *ty}, indexName->location);
}
@ -862,7 +913,7 @@ struct TypeChecker2
TypeId inferredArgTy = *argIt;
TypeId annotatedArgTy = lookupAnnotation(arg->annotation);
if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location);
}
@ -887,15 +938,264 @@ struct TypeChecker2
void visit(AstExprUnary* expr)
{
// TODO!
visit(expr->expr);
NotNull<Scope> scope = stack.back();
TypeId operandType = lookupType(expr->expr);
if (get<AnyTypeVar>(operandType) || get<ErrorTypeVar>(operandType) || get<NeverTypeVar>(operandType))
return;
if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end())
{
std::optional<TypeId> mm = findMetatableEntry(singletonTypes, module->errors, operandType, it->second, expr->location);
if (mm)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*mm)))
{
TypePackId expectedArgs = module->internalTypes.addTypePack({operandType});
reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs));
if (std::optional<TypeId> ret = first(ftv->retTypes))
{
if (expr->op == AstExprUnary::Op::Len)
{
reportErrors(tryUnify(scope, expr->location, follow(*ret), singletonTypes->numberType));
}
}
else
{
reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location);
}
}
return;
}
}
if (expr->op == AstExprUnary::Op::Len)
{
DenseHashSet<TypeId> seen{nullptr};
int recursionCount = 0;
if (!hasLength(operandType, seen, &recursionCount))
{
reportError(NotATable{operandType}, expr->location);
}
}
else if (expr->op == AstExprUnary::Op::Minus)
{
reportErrors(tryUnify(scope, expr->location, operandType, singletonTypes->numberType));
}
else if (expr->op == AstExprUnary::Op::Not)
{
}
else
{
LUAU_ASSERT(!"Unhandled unary operator");
}
}
void visit(AstExprBinary* expr)
{
// TODO!
visit(expr->left);
visit(expr->right);
NotNull<Scope> scope = stack.back();
bool isEquality = expr->op == AstExprBinary::Op::CompareEq || expr->op == AstExprBinary::Op::CompareNe;
bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe;
bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or;
TypeId leftType = lookupType(expr->left);
TypeId rightType = lookupType(expr->right);
if (expr->op == AstExprBinary::Op::Or)
{
leftType = stripNil(singletonTypes, module->internalTypes, leftType);
}
bool isStringOperation = isString(leftType) && isString(rightType);
if (get<AnyTypeVar>(leftType) || get<ErrorTypeVar>(leftType) || get<AnyTypeVar>(rightType) || get<ErrorTypeVar>(rightType))
return;
if ((get<BlockedTypeVar>(leftType) || get<FreeTypeVar>(leftType)) && !isEquality && !isLogical)
{
auto name = getIdentifierOfBaseVar(expr->left);
reportError(CannotInferBinaryOperation{expr->op, name,
isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation},
expr->location);
return;
}
if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end())
{
std::optional<TypeId> leftMt = getMetatable(leftType, singletonTypes);
std::optional<TypeId> rightMt = getMetatable(rightType, singletonTypes);
bool matches = leftMt == rightMt;
if (isEquality && !matches)
{
auto testUnion = [&matches, singletonTypes = this->singletonTypes](const UnionTypeVar* utv, std::optional<TypeId> otherMt) {
for (TypeId option : utv)
{
if (getMetatable(follow(option), singletonTypes) == otherMt)
{
matches = true;
break;
}
}
};
if (const UnionTypeVar* utv = get<UnionTypeVar>(leftType); utv && rightMt)
{
testUnion(utv, rightMt);
}
if (const UnionTypeVar* utv = get<UnionTypeVar>(rightType); utv && leftMt && !matches)
{
testUnion(utv, leftMt);
}
}
if (!matches && isComparison)
{
reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
return;
}
std::optional<TypeId> mm;
if (std::optional<TypeId> leftMm = findMetatableEntry(singletonTypes, module->errors, leftType, it->second, expr->left->location))
mm = leftMm;
else if (std::optional<TypeId> rightMm = findMetatableEntry(singletonTypes, module->errors, rightType, it->second, expr->right->location))
mm = rightMm;
if (mm)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(*mm))
{
TypePackId expectedArgs;
// For >= and > we invoke __lt and __le respectively with
// swapped argument ordering.
if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt)
{
expectedArgs = module->internalTypes.addTypePack({rightType, leftType});
}
else
{
expectedArgs = module->internalTypes.addTypePack({leftType, rightType});
}
reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs));
if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe ||
expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt)
{
TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType});
if (!isSubtype(ftv->retTypes, expectedRets, scope, singletonTypes, ice))
{
reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location);
}
}
else if (!first(ftv->retTypes))
{
reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location);
}
}
else
{
reportError(CannotCallNonFunction{*mm}, expr->location);
}
return;
}
// If this is a string comparison, or a concatenation of strings, we
// want to fall through to primitive behavior.
else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison)))
{
if (leftMt || rightMt)
{
if (isComparison)
{
reportError(GenericError{format(
"Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)},
expr->location);
}
else
{
reportError(GenericError{format(
"Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod",
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)},
expr->location);
}
return;
}
else if (!leftMt && !rightMt && (get<TableTypeVar>(leftType) || get<TableTypeVar>(rightType)))
{
if (isComparison)
{
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
}
else
{
reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable",
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())},
expr->location);
}
return;
}
}
}
switch (expr->op)
{
case AstExprBinary::Op::Add:
case AstExprBinary::Op::Sub:
case AstExprBinary::Op::Mul:
case AstExprBinary::Op::Div:
case AstExprBinary::Op::Pow:
case AstExprBinary::Op::Mod:
reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->numberType));
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType));
break;
case AstExprBinary::Op::Concat:
reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->stringType));
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType));
break;
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
if (isNumber(leftType))
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType));
else if (isString(leftType))
reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType));
else
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(),
toString(rightType).c_str(), toString(expr->op).c_str())},
expr->location);
break;
case AstExprBinary::Op::And:
case AstExprBinary::Op::Or:
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
break;
default:
// Unhandled AstExprBinary::Op possibility.
LUAU_ASSERT(false);
}
}
void visit(AstExprTypeAssertion* expr)
@ -907,10 +1207,10 @@ struct TypeChecker2
TypeId computedType = lookupType(expr->expr);
// Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case.
if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice))
return;
if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice))
return;
reportError(TypesAreUnrelated{computedType, annotationType}, expr->location);
@ -998,9 +1298,8 @@ struct TypeChecker2
Scope* scope = findInnermostScope(ty->location);
LUAU_ASSERT(scope);
// TODO: Imported types
std::optional<TypeFun> alias = scope->lookupType(ty->name.value);
std::optional<TypeFun> alias =
(ty->prefix) ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value);
if (alias.has_value())
{
@ -1212,7 +1511,6 @@ struct TypeChecker2
UnifierSharedState sharedState{&ice};
Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant};
u.anyIsTop = true;
u.tryUnify(subTy, superTy);
return std::move(u.errors);

View file

@ -31,12 +31,12 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000)
LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300)
LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTFLAG(LuauTypeNormalization2)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false.
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false)
LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false)
LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false)
LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false)
LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false)
@ -44,15 +44,15 @@ LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false)
LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false)
LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false)
LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false)
LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false)
LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false)
namespace Luau
{
const char* TimeLimitError::what() const throw()
const char* TimeLimitError_DEPRECATED::what() const throw()
{
LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange);
return "Typeinfer failed to complete in allotted time";
}
@ -265,6 +265,11 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona
reportErrorCodeTooComplex(module.root->location);
return std::move(currentModule);
}
catch (const RecursionLimitException_DEPRECATED&)
{
reportErrorCodeTooComplex(module.root->location);
return std::move(currentModule);
}
}
ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope)
@ -280,11 +285,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
iceHandler->moduleName = module.name;
normalizer.arena = &currentModule->internalTypes;
if (FFlag::LuauAutocompleteDynamicLimits)
{
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit;
}
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit;
ScopePtr parentScope = environmentScope.value_or(globalScope);
ScopePtr moduleScope = std::make_shared<Scope>(parentScope);
@ -312,6 +314,10 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
{
currentModule->timeout = true;
}
catch (const TimeLimitError_DEPRECATED&)
{
currentModule->timeout = true;
}
if (FFlag::DebugLuauSharedSelf)
{
@ -419,7 +425,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program)
ice("Unknown AstStat");
if (finishTime && TimeTrace::getClock() > *finishTime)
throw TimeLimitError();
throwTimeLimitError();
}
// This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly.
@ -446,6 +452,11 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
reportErrorCodeTooComplex(block.location);
return;
}
catch (const RecursionLimitException_DEPRECATED&)
{
reportErrorCodeTooComplex(block.location);
return;
}
}
struct InplaceDemoter : TypeVarOnceVisitor
@ -773,16 +784,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement)
checkExpr(repScope, *statement.condition);
}
void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location)
{
Unifier state = mkUnifier(scope, location);
state.unifyLowerBound(subTy, superTy, demotedLevel);
state.log.commit();
reportErrors(state.errors);
}
struct Demoter : Substitution
{
Demoter(TypeArena* arena)
@ -2091,39 +2092,6 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromTypeImpl(
return std::nullopt;
}
std::vector<TypeId> TypeChecker::reduceUnion(const std::vector<TypeId>& types)
{
std::vector<TypeId> result;
for (TypeId t : types)
{
t = follow(t);
if (get<NeverTypeVar>(t))
continue;
if (get<ErrorTypeVar>(t) || get<AnyTypeVar>(t))
return {t};
if (const UnionTypeVar* utv = get<UnionTypeVar>(t))
{
for (TypeId ty : utv)
{
ty = follow(ty);
if (get<NeverTypeVar>(ty))
continue;
if (get<ErrorTypeVar>(ty) || get<AnyTypeVar>(ty))
return {ty};
if (result.end() == std::find(result.begin(), result.end(), ty))
result.push_back(ty);
}
}
else if (std::find(result.begin(), result.end(), t) == result.end())
result.push_back(t);
}
return result;
}
std::optional<TypeId> TypeChecker::tryStripUnionFromNil(TypeId ty)
{
if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
@ -2503,11 +2471,8 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op)
TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes)
{
if (FFlag::LuauUnionOfTypesFollow)
{
a = follow(a);
b = follow(b);
}
a = follow(a);
b = follow(b);
if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b)))
{
@ -3643,8 +3608,17 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
location = {state.location.begin, argLocations.back().end};
std::string namePath;
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue);
if (FFlag::LuauLvaluelessPath)
{
if (std::optional<std::string> path = getFunctionNameAsString(funName))
namePath = *path;
}
else
{
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue);
}
auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack);
state.reportError(TypeError{location,
@ -3753,11 +3727,28 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
bool isVariadic = tail && Luau::isVariadic(*tail);
std::string namePath;
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue);
state.reportError(TypeError{
state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}});
if (FFlag::LuauLvaluelessPath)
{
if (std::optional<std::string> path = getFunctionNameAsString(funName))
namePath = *path;
}
else
{
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue);
}
if (FFlag::LuauArgMismatchReportFunctionLocation)
{
state.reportError(TypeError{
funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}});
}
else
{
state.reportError(TypeError{
state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}});
}
return;
}
++paramIter;
@ -4597,7 +4588,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat
Instantiation instantiation{log, &currentModule->internalTypes, scope->level, /*scope*/ nullptr};
if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit)
if (instantiationChildLimit)
instantiation.childLimit = *instantiationChildLimit;
std::optional<TypeId> instantiated = instantiation.substitute(ty);
@ -4694,6 +4685,19 @@ void TypeChecker::ice(const std::string& message)
iceHandler->ice(message);
}
// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted.
void TypeChecker::throwTimeLimitError()
{
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw TimeLimitError(iceHandler->moduleName);
}
else
{
throw TimeLimitError_DEPRECATED();
}
}
void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec)
{
// Remove errors with names that were generated by recovery from a parse error

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypePack.h"
#include "Luau/Error.h"
#include "Luau/TxnLog.h"
#include <stdexcept>
@ -234,7 +235,7 @@ TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper)
cycleTester = nullptr;
if (tp == cycleTester)
throw std::runtime_error("Luau::follow detected a TypeVar cycle!!");
throwRuntimeError("Luau::follow detected a TypeVar cycle!!");
}
}
}

View file

@ -6,6 +6,8 @@
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
#include <algorithm>
namespace Luau
{
@ -146,18 +148,15 @@ std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro
return std::nullopt;
}
goodOptions = reduceUnion(goodOptions);
if (goodOptions.empty())
return singletonTypes->neverType;
if (goodOptions.size() == 1)
return goodOptions[0];
// TODO: inefficient.
TypeId result = arena->addType(UnionTypeVar{std::move(goodOptions)});
auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, singletonTypes, handle);
if (!ok && addErrors)
errors.push_back(TypeError{location, NormalizationTooComplex{}});
return ok ? ty : singletonTypes->anyType;
return arena->addType(UnionTypeVar{std::move(goodOptions)});
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type))
{
@ -264,4 +263,79 @@ std::vector<TypeId> flatten(TypeArena& arena, NotNull<SingletonTypes> singletonT
return result;
}
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types)
{
std::vector<TypeId> result;
for (TypeId t : types)
{
t = follow(t);
if (get<NeverTypeVar>(t))
continue;
if (get<ErrorTypeVar>(t) || get<AnyTypeVar>(t))
return {t};
if (const UnionTypeVar* utv = get<UnionTypeVar>(t))
{
for (TypeId ty : utv)
{
ty = follow(ty);
if (get<NeverTypeVar>(ty))
continue;
if (get<ErrorTypeVar>(ty) || get<AnyTypeVar>(ty))
return {ty};
if (result.end() == std::find(result.begin(), result.end(), ty))
result.push_back(ty);
}
}
else if (std::find(result.begin(), result.end(), t) == result.end())
result.push_back(t);
}
return result;
}
static std::optional<TypeId> tryStripUnionFromNil(TypeArena& arena, TypeId ty)
{
if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
{
if (!std::any_of(begin(utv), end(utv), isNil))
return ty;
std::vector<TypeId> result;
for (TypeId option : utv)
{
if (!isNil(option))
result.push_back(option);
}
if (result.empty())
return std::nullopt;
return result.size() == 1 ? result[0] : arena.addType(UnionTypeVar{std::move(result)});
}
return std::nullopt;
}
TypeId stripNil(NotNull<SingletonTypes> singletonTypes, TypeArena& arena, TypeId ty)
{
ty = follow(ty);
if (get<UnionTypeVar>(ty))
{
std::optional<TypeId> cleaned = tryStripUnionFromNil(arena, ty);
// If there is no union option without 'nil'
if (!cleaned)
return singletonTypes->nilType;
return follow(*cleaned);
}
return follow(ty);
}
} // namespace Luau

View file

@ -66,7 +66,7 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
{
TypeId res = ltv->thunk();
if (get<LazyTypeVar>(res))
throw std::runtime_error("Lazy TypeVar cannot resolve to another Lazy TypeVar");
throwRuntimeError("Lazy TypeVar cannot resolve to another Lazy TypeVar");
*asMutable(ty) = BoundTypeVar(res);
}
@ -104,7 +104,7 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
cycleTester = nullptr;
if (t == cycleTester)
throw std::runtime_error("Luau::follow detected a TypeVar cycle!!");
throwRuntimeError("Luau::follow detected a TypeVar cycle!!");
}
}
}
@ -754,12 +754,15 @@ SingletonTypes::SingletonTypes()
, stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}))
, booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}))
, threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}))
, functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true}))
, trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}))
, falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}))
, anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true}))
, unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true}))
, neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true}))
, errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true}))
, falsyType(arena->addType(TypeVar{UnionTypeVar{{falseType, nilType}}, /*persistent*/ true}))
, truthyType(arena->addType(TypeVar{NegationTypeVar{falsyType}, /*persistent*/ true}))
, anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true}))
, neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true}))
, uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack))
@ -896,7 +899,6 @@ void persist(TypeId ty)
continue;
asMutable(t)->persistent = true;
asMutable(t)->normal = true; // all persistent types are assumed to be normal
if (auto btv = get<BoundTypeVar>(t))
queue.push_back(btv->boundTo);
@ -933,17 +935,13 @@ void persist(TypeId ty)
for (TypeId opt : itv->parts)
queue.push_back(opt);
}
else if (auto ctv = get<ConstrainedTypeVar>(t))
{
for (TypeId opt : ctv->parts)
queue.push_back(opt);
}
else if (auto mtv = get<MetatableTypeVar>(t))
{
queue.push_back(mtv->table);
queue.push_back(mtv->metatable);
}
else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t))
else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t) ||
get<NegationTypeVar>(t))
{
}
else
@ -990,8 +988,6 @@ const TypeLevel* getLevel(TypeId ty)
return &ttv->level;
else if (auto ftv = get<FunctionTypeVar>(ty))
return &ftv->level;
else if (auto ctv = get<ConstrainedTypeVar>(ty))
return &ctv->level;
else
return nullptr;
}
@ -1056,11 +1052,6 @@ const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv)
return itv->parts;
}
const std::vector<TypeId>& getTypes(const ConstrainedTypeVar* ctv)
{
return ctv->parts;
}
UnionTypeVarIterator begin(const UnionTypeVar* utv)
{
return UnionTypeVarIterator{utv};
@ -1081,17 +1072,6 @@ IntersectionTypeVarIterator end(const IntersectionTypeVar* itv)
return IntersectionTypeVarIterator{};
}
ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv)
{
return ConstrainedTypeVarIterator{ctv};
}
ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv)
{
return ConstrainedTypeVarIterator{};
}
static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const char* data, size_t size)
{
const char* options = "cdiouxXeEfgGqs*";

View file

@ -8,23 +8,23 @@
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h"
#include "Luau/ToString.h"
#include <algorithm>
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferTypePackLoopLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000);
LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauReportTypeMismatchForTypePackUnificationFailure, false)
LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false);
LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false);
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauNegatedFunctionTypes)
namespace Luau
{
@ -95,15 +95,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor
return true;
}
bool visit(TypeId ty, const ConstrainedTypeVar&) override
{
if (!FFlag::LuauUnknownAndNeverType)
return visit(ty);
promote(ty, log.getMutable<ConstrainedTypeVar>(ty));
return true;
}
bool visit(TypeId ty, const FunctionTypeVar&) override
{
// Type levels of types from other modules are already global, so we don't need to promote anything inside
@ -285,7 +276,7 @@ TypeId Widen::clean(TypeId ty)
TypePackId Widen::clean(TypePackId)
{
throw std::runtime_error("Widen attempted to clean a dirty type pack?");
throwRuntimeError("Widen attempted to clean a dirty type pack?");
}
bool Widen::ignoreChildren(TypeId ty)
@ -368,26 +359,14 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i
void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount,
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
++sharedState.counters.iterationCount;
if (FFlag::LuauAutocompleteDynamicLimits)
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
else
{
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
reportError(location, UnificationTooComplex{});
return;
}
superTy = log.follow(superTy);
@ -396,9 +375,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (superTy == subTy)
return;
if (log.get<ConstrainedTypeVar>(superTy))
return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy);
auto superFree = log.getMutable<FreeTypeVar>(superTy);
auto subFree = log.getMutable<FreeTypeVar>(subTy);
@ -430,7 +406,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (subGeneric && !subsumes(useScopes, subGeneric, superFree))
{
// TODO: a more informative error message? CLI-39912
reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}});
reportError(location, GenericError{"Generic subtype escaping scope"});
return;
}
@ -459,7 +435,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (superGeneric && !subsumes(useScopes, superGeneric, subFree))
{
// TODO: a more informative error message? CLI-39912
reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}});
reportError(location, GenericError{"Generic supertype escaping scope"});
return;
}
@ -476,15 +452,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
return tryUnifyWithAny(subTy, superTy);
if (get<AnyTypeVar>(subTy))
{
if (anyIsTop)
{
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
return;
}
else
return tryUnifyWithAny(superTy, subTy);
}
return tryUnifyWithAny(superTy, subTy);
if (log.get<ErrorTypeVar>(subTy))
return tryUnifyWithAny(superTy, subTy);
@ -504,7 +472,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (auto error = sharedState.cachedUnifyError.find({subTy, superTy}))
{
reportError(TypeError{location, *error});
reportError(location, *error);
return;
}
}
@ -520,9 +488,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
size_t errorCount = errors.size();
if (log.get<ConstrainedTypeVar>(subTy))
tryUnifyWithConstrainedSubTypeVar(subTy, superTy);
else if (const UnionTypeVar* subUnion = log.getMutable<UnionTypeVar>(subTy))
if (const UnionTypeVar* subUnion = log.getMutable<UnionTypeVar>(subTy))
{
tryUnifyUnionWithType(subTy, subUnion, superTy);
}
@ -548,6 +514,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
else if ((log.getMutable<PrimitiveTypeVar>(superTy) || log.getMutable<SingletonTypeVar>(superTy)) && log.getMutable<SingletonTypeVar>(subTy))
tryUnifySingletons(subTy, superTy);
else if (auto ptv = get<PrimitiveTypeVar>(superTy);
FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get<FunctionTypeVar>(subTy))
{
// Ok. Do nothing. forall functions F, F <: function
}
else if (log.getMutable<FunctionTypeVar>(superTy) && log.getMutable<FunctionTypeVar>(subTy))
tryUnifyFunctions(subTy, superTy, isFunctionCall);
@ -580,8 +552,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
else if (log.getMutable<ClassTypeVar>(subTy))
tryUnifyWithClass(subTy, superTy, /*reversed*/ true);
else if (log.get<NegationTypeVar>(superTy))
tryUnifyTypeWithNegation(subTy, superTy);
else if (log.get<NegationTypeVar>(subTy))
tryUnifyNegationWithType(subTy, superTy);
else
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
if (cacheEnabled)
cacheResult(subTy, superTy, errorCount);
@ -655,9 +633,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion,
else if (failed)
{
if (firstFailedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}});
reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption});
else
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
}
}
@ -756,7 +734,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
const NormalizedType* subNorm = normalizer->normalize(subTy);
const NormalizedType* superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
@ -765,9 +743,9 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
else if (!found)
{
if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}});
reportError(location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption});
else
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}});
reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible"});
}
}
@ -796,7 +774,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I
if (unificationTooComplex)
reportError(*unificationTooComplex);
else if (firstFailedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}});
reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption});
}
void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall)
@ -854,11 +832,11 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
}
else if (!found)
{
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}});
reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"});
}
}
@ -870,43 +848,37 @@ void Unifier::tryUnifyNormalizedTypes(
if (get<UnknownTypeVar>(superNorm.tops) || get<AnyTypeVar>(superNorm.tops) || get<AnyTypeVar>(subNorm.tops))
return;
else if (get<UnknownTypeVar>(subNorm.tops))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<ErrorTypeVar>(subNorm.errors))
if (!get<ErrorTypeVar>(superNorm.errors))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.booleans))
{
if (!get<PrimitiveTypeVar>(superNorm.booleans))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(subNorm.booleans))
{
if (!get<PrimitiveTypeVar>(superNorm.booleans) && stv != get<SingletonTypeVar>(superNorm.booleans))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
if (get<PrimitiveTypeVar>(subNorm.nils))
if (!get<PrimitiveTypeVar>(superNorm.nils))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.numbers))
if (!get<PrimitiveTypeVar>(superNorm.numbers))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (subNorm.strings && superNorm.strings)
{
for (auto [name, ty] : *subNorm.strings)
if (!superNorm.strings->count(name))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
}
else if (!subNorm.strings && superNorm.strings)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
if (!isSubtype(subNorm.strings, superNorm.strings))
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.threads))
if (!get<PrimitiveTypeVar>(superNorm.errors))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
for (TypeId subClass : subNorm.classes)
{
@ -922,7 +894,7 @@ void Unifier::tryUnifyNormalizedTypes(
}
}
if (!found)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
for (TypeId subTable : subNorm.tables)
@ -947,21 +919,19 @@ void Unifier::tryUnifyNormalizedTypes(
return reportError(*e);
}
if (!found)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
if (subNorm.functions)
if (!subNorm.functions.isNever())
{
if (!superNorm.functions)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
if (superNorm.functions->empty())
return;
for (TypeId superFun : *superNorm.functions)
if (superNorm.functions.isNever())
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
for (TypeId superFun : *superNorm.functions.parts)
{
Unifier innerState = makeChildUnifier();
const FunctionTypeVar* superFtv = get<FunctionTypeVar>(superFun);
if (!superFtv)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes);
innerState.tryUnify_(tgt, superFtv->retTypes);
if (innerState.errors.empty())
@ -969,7 +939,7 @@ void Unifier::tryUnifyNormalizedTypes(
else if (auto e = hasUnificationTooComplex(innerState.errors))
return reportError(*e);
else
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
}
@ -987,15 +957,15 @@ void Unifier::tryUnifyNormalizedTypes(
TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args)
{
if (!overloads || overloads->empty())
if (overloads.isNever())
{
reportError(TypeError{location, CannotCallNonFunction{function}});
reportError(location, CannotCallNonFunction{function});
return singletonTypes->errorRecoveryTypePack();
}
std::optional<TypePackId> result;
const FunctionTypeVar* firstFun = nullptr;
for (TypeId overload : *overloads)
for (TypeId overload : *overloads.parts)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(overload))
{
@ -1011,10 +981,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized
log.concat(std::move(innerState.log));
if (result)
{
if (FFlag::LuauOverloadedFunctionSubtypingPerf)
{
innerState.log.clear();
innerState.tryUnify_(*result, ftv->retTypes);
}
if (FFlag::LuauOverloadedFunctionSubtypingPerf && innerState.errors.empty())
log.concat(std::move(innerState.log));
// Annoyingly, since we don't support intersection of generic type packs,
// the intersection may fail. We rather arbitrarily use the first matching overload
// in that case.
if (std::optional<TypePackId> intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes))
else if (std::optional<TypePackId> intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes))
result = intersect;
}
else
@ -1036,12 +1013,12 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized
// TODO: better error reporting?
// The logic for error reporting overload resolution
// is currently over in TypeInfer.cpp, should we move it?
reportError(TypeError{location, GenericError{"No matching overload."}});
reportError(location, GenericError{"No matching overload."});
return singletonTypes->errorRecoveryTypePack(firstFun->retTypes);
}
else
{
reportError(TypeError{location, CannotCallNonFunction{function}});
reportError(location, CannotCallNonFunction{function});
return singletonTypes->errorRecoveryTypePack();
}
}
@ -1214,26 +1191,14 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall
*/
void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount,
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
++sharedState.counters.iterationCount;
if (FFlag::LuauAutocompleteDynamicLimits)
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
else
{
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
reportError(location, UnificationTooComplex{});
return;
}
superTp = log.follow(superTp);
@ -1405,7 +1370,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
size_t actualSize = size(subTp);
if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult)
std::swap(expectedSize, actualSize);
reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}});
reportError(location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx});
while (superIter.good())
{
@ -1426,7 +1391,10 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
}
else
{
reportError(TypeError{location, GenericError{"Failed to unify type packs"}});
if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure)
reportError(location, TypePackMismatch{subTp, superTp});
else
reportError(location, GenericError{"Failed to unify type packs"});
}
}
@ -1438,7 +1406,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy)
ice("passed non primitive types to unifyPrimitives");
if (superPrim->type != subPrim->type)
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
}
void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
@ -1459,7 +1427,7 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
if (superPrim && superPrim->type == PrimitiveTypeVar::String && get<StringSingleton>(subSingleton) && variance == Covariant)
return;
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
}
void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall)
@ -1475,7 +1443,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0);
if (FFlag::LuauInstantiateInSubtyping && variance == Covariant && shouldInstantiate)
// TODO: This is unsound when the context is invariant, but the annotation burden without allowing it and without
// read-only properties is too high for lua-apps. Read-only properties _should_ resolve their issue by allowing
// generic methods in tables to be marked read-only.
if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate)
{
Instantiation instantiation{&log, types, scope->level, scope};
@ -1492,21 +1463,21 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
}
else
{
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
}
}
else if (numGenerics != subFunction->generics.size())
{
numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size());
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}});
reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters"});
}
if (numGenericPacks != subFunction->genericPacks.size())
{
numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size());
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}});
reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"});
}
for (size_t i = 0; i < numGenerics; i++)
@ -1533,11 +1504,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e);
else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError(
TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()});
else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()});
innerState.ctx = CountMismatch::FunctionResult;
innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes);
@ -1547,13 +1517,12 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e);
else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes))
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()});
else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError(
TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()});
else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()});
}
log.concat(std::move(innerState.log));
@ -1610,6 +1579,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
{
TableTypeVar* superTable = log.getMutable<TableTypeVar>(superTy);
TableTypeVar* subTable = log.getMutable<TableTypeVar>(subTy);
TableTypeVar* instantiatedSubTable = subTable;
if (!superTable || !subTable)
ice("passed non-table types to unifyTables");
@ -1627,13 +1597,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (instantiated.has_value())
{
subTable = log.getMutable<TableTypeVar>(*instantiated);
instantiatedSubTable = subTable;
if (!subTable)
ice("instantiation made a table type into a non-table type in tryUnifyTables");
}
else
{
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
}
}
}
@ -1651,7 +1622,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty())
{
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}});
reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)});
return;
}
}
@ -1669,7 +1640,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!extraProperties.empty())
{
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}});
reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra});
return;
}
}
@ -1730,7 +1701,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
// txn log.
TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy);
TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy);
if (superTable != newSuperTable || subTable != newSubTable)
if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable))
{
if (errors.empty())
return tryUnifyTables(subTy, superTy, isIntersection);
@ -1792,7 +1763,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
// txn log.
TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy);
TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy);
if (superTable != newSuperTable || subTable != newSubTable)
if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable))
{
if (errors.empty())
return tryUnifyTables(subTy, superTy, isIntersection);
@ -1850,13 +1821,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty())
{
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}});
reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)});
return;
}
if (!extraProperties.empty())
{
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}});
reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra});
return;
}
@ -1892,14 +1863,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
std::swap(subTy, superTy);
if (auto ttv = log.get<TableTypeVar>(superTy); !ttv || ttv->state != TableState::Free)
return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}});
return reportError(location, TypeMismatch{osuperTy, osubTy});
auto fail = [&](std::optional<TypeError> e) {
std::string reason = "The former's metatable does not satisfy the requirements.";
if (e)
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}});
reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e});
else
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}});
reportError(location, TypeMismatch{osuperTy, osubTy, reason});
};
// Given t1 where t1 = { lower: (t1) -> (a, b...) }
@ -1931,7 +1902,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
}
}
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}});
reportError(location, TypeMismatch{osuperTy, osubTy});
return;
}
@ -1972,7 +1943,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e);
else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}});
reportError(location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()});
log.concat(std::move(innerState.log));
}
@ -2049,9 +2020,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
auto fail = [&]() {
if (!reversed)
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
else
reportError(TypeError{location, TypeMismatch{subTy, superTy}});
reportError(location, TypeMismatch{subTy, superTy});
};
const ClassTypeVar* superClass = get<ClassTypeVar>(superTy);
@ -2096,7 +2067,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
if (!classProp)
{
ok = false;
reportError(TypeError{location, UnknownProperty{superTy, propName}});
reportError(location, UnknownProperty{superTy, propName});
}
else
{
@ -2120,7 +2091,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
{
ok = false;
std::string msg = "Class " + superClass->name + " does not have an indexer";
reportError(TypeError{location, GenericError{msg}});
reportError(location, GenericError{msg});
}
if (!ok)
@ -2132,6 +2103,34 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
return fail();
}
void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy)
{
const NegationTypeVar* ntv = get<NegationTypeVar>(superTy);
if (!ntv)
ice("tryUnifyTypeWithNegation superTy must be a negation type");
const NormalizedType* subNorm = normalizer->normalize(subTy);
const NormalizedType* superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, UnificationTooComplex{});
// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy});
}
void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy)
{
const NegationTypeVar* ntv = get<NegationTypeVar>(subTy);
if (!ntv)
ice("tryUnifyNegationWithType subTy must be a negation type");
// TODO: ~T </: U iff T <: U
reportError(location, TypeMismatch{superTy, subTy});
}
static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)
{
while (true)
@ -2192,7 +2191,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
}
else if (get<Unifiable::Generic>(tail))
{
reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}});
reportError(location, GenericError{"Cannot unify variadic and generic packs"});
}
else if (get<Unifiable::Error>(tail))
{
@ -2206,7 +2205,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
}
else
{
reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}});
reportError(location, GenericError{"Failed to unify variadic packs"});
}
}
@ -2314,186 +2313,6 @@ std::optional<TypeId> Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N
return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location);
}
void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy)
{
const ConstrainedTypeVar* subConstrained = get<ConstrainedTypeVar>(subTy);
if (!subConstrained)
ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!");
const std::vector<TypeId>& subTyParts = subConstrained->parts;
// A | B <: T if A <: T and B <: T
bool failed = false;
std::optional<TypeError> unificationTooComplex;
const size_t count = subTyParts.size();
for (size_t i = 0; i < count; ++i)
{
TypeId type = subTyParts[i];
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(type, superTy);
if (i == count - 1)
log.concat(std::move(innerState.log));
++i;
if (auto e = hasUnificationTooComplex(innerState.errors))
unificationTooComplex = e;
if (!innerState.errors.empty())
{
failed = true;
break;
}
}
if (unificationTooComplex)
reportError(*unificationTooComplex);
else if (failed)
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
else
log.replace(subTy, BoundTypeVar{superTy});
}
void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy)
{
ConstrainedTypeVar* superC = log.getMutable<ConstrainedTypeVar>(superTy);
if (!superC)
ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!");
// subTy could be a
// table
// metatable
// class
// function
// primitive
// free
// generic
// intersection
// union
// Do we really just tack it on? I think we might!
// We can certainly do some deduplication.
// Is there any point to deducing Player|Instance when we could just reduce to Instance?
// Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type?
// Maybe we do a simplification step during quantification.
auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy);
if (it != superC->parts.end())
return;
superC->parts.push_back(subTy);
}
void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel)
{
// The duplication between this and regular typepack unification is tragic.
auto superIter = begin(superTy, &log);
auto superEndIter = end(superTy);
auto subIter = begin(subTy, &log);
auto subEndIter = end(subTy);
int count = FInt::LuauTypeInferLowerBoundsIterationLimit;
for (; subIter != subEndIter; ++subIter)
{
if (0 >= --count)
ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound");
if (superIter != superEndIter)
{
tryUnify_(*subIter, *superIter);
++superIter;
continue;
}
if (auto t = superIter.tail())
{
TypePackId tailPack = follow(*t);
if (log.get<FreeTypePack>(tailPack) && occursCheck(tailPack, subTy))
return;
FreeTypePack* freeTailPack = log.getMutable<FreeTypePack>(tailPack);
if (!freeTailPack)
return;
TypePack* tp = getMutable<TypePack>(log.replace(tailPack, TypePack{}));
for (; subIter != subEndIter; ++subIter)
{
tp->head.push_back(types->addType(ConstrainedTypeVar{demotedLevel, {follow(*subIter)}}));
}
tp->tail = subIter.tail();
}
return;
}
if (superIter != superEndIter)
{
if (auto subTail = subIter.tail())
{
TypePackId subTailPack = follow(*subTail);
if (get<FreeTypePack>(subTailPack))
{
TypePack* tp = getMutable<TypePack>(log.replace(subTailPack, TypePack{}));
for (; superIter != superEndIter; ++superIter)
tp->head.push_back(*superIter);
}
else if (const VariadicTypePack* subVariadic = log.getMutable<VariadicTypePack>(subTailPack))
{
while (superIter != superEndIter)
{
tryUnify_(subVariadic->ty, *superIter);
++superIter;
}
}
}
else
{
while (superIter != superEndIter)
{
if (!isOptional(*superIter))
{
errors.push_back(TypeError{location, CountMismatch{size(superTy), std::nullopt, size(subTy), CountMismatch::Return}});
return;
}
++superIter;
}
}
return;
}
// Both iters are at their respective tails
auto subTail = subIter.tail();
auto superTail = superIter.tail();
if (subTail && superTail)
tryUnify(*subTail, *superTail);
else if (subTail)
{
const FreeTypePack* freeSubTail = log.getMutable<FreeTypePack>(*subTail);
if (freeSubTail)
{
log.replace(*subTail, TypePack{});
}
}
else if (superTail)
{
const FreeTypePack* freeSuperTail = log.getMutable<FreeTypePack>(*superTail);
if (freeSuperTail)
{
log.replace(*superTail, TypePack{});
}
}
}
bool Unifier::occursCheck(TypeId needle, TypeId haystack)
{
sharedState.tempSeenTy.clear();
@ -2503,8 +2322,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack)
bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount,
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
bool occurrence = false;
@ -2529,7 +2347,7 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
if (needle == haystack)
{
reportError(TypeError{location, OccursCheckFailed{}});
reportError(location, OccursCheckFailed{});
log.replace(needle, *singletonTypes->errorRecoveryType());
return true;
@ -2547,11 +2365,6 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
for (TypeId ty : a->parts)
check(ty);
}
else if (auto a = log.getMutable<ConstrainedTypeVar>(haystack))
{
for (TypeId ty : a->parts)
check(ty);
}
return occurrence;
}
@ -2579,14 +2392,13 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
if (!log.getMutable<Unifiable::Free>(needle))
ice("Expected needle pack to be free");
RecursionLimiter _ra(&sharedState.counters.recursionCount,
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
while (!log.getMutable<ErrorTypeVar>(haystack))
{
if (needle == haystack)
{
reportError(TypeError{location, OccursCheckFailed{}});
reportError(location, OccursCheckFailed{});
log.replace(needle, *singletonTypes->errorRecoveryTypePack());
return true;
@ -2607,18 +2419,31 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
Unifier Unifier::makeChildUnifier()
{
Unifier u = Unifier{normalizer, mode, scope, location, variance, &log};
u.anyIsTop = anyIsTop;
u.normalize = normalize;
u.useScopes = useScopes;
return u;
}
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
//
// Note: report error accepts its arguments by value intentionally to reduce the stack usage of functions which call `reportError`.
void Unifier::reportError(Location location, TypeErrorData data)
{
errors.emplace_back(std::move(location), std::move(data));
}
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
//
// Note: to conserve stack space in calling functions it is generally preferred to call `Unifier::reportError(Location location, TypeErrorData data)`
// instead of this method.
void Unifier::reportError(TypeError err)
{
errors.push_back(std::move(err));
}
bool Unifier::isNonstrictMode() const
{
return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck);
@ -2629,7 +2454,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId
if (auto e = hasUnificationTooComplex(innerErrors))
reportError(*e);
else if (!innerErrors.empty())
reportError(TypeError{location, TypeMismatch{wantedType, givenType}});
reportError(location, TypeMismatch{wantedType, givenType});
}
void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType)

View file

@ -58,6 +58,8 @@ struct Comment
struct ParseResult
{
AstStatBlock* root;
size_t lines = 0;
std::vector<HotComment> hotcomments;
std::vector<ParseError> errors;

View file

@ -302,8 +302,8 @@ private:
AstStatError* reportStatError(const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements,
const char* format, ...) LUAU_PRINTF_ATTR(5, 6);
AstExprError* reportExprError(const Location& location, const AstArray<AstExpr*>& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5);
AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, bool isMissing, const char* format, ...)
LUAU_PRINTF_ATTR(5, 6);
AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, const char* format, ...)
LUAU_PRINTF_ATTR(4, 5);
// `parseErrorLocation` is associated with the parser error
// `astErrorLocation` is associated with the AstTypeError created
// It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely

View file

@ -641,8 +641,8 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT
return brokenDoubleBrace;
}
Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset);
consume();
Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset - 1);
return lexemeOutput;
}

View file

@ -23,9 +23,9 @@ LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false)
LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAnnotationLocationChange, false)
LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false)
LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false)
bool lua_telemetry_parsed_out_of_range_bin_integer = false;
bool lua_telemetry_parsed_out_of_range_hex_integer = false;
@ -164,15 +164,16 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n
try
{
AstStatBlock* root = p.parseChunk();
size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n');
return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)};
return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)};
}
catch (ParseError& err)
{
// when catching a fatal error, append it to the list of non-fatal errors and return
p.parseErrors.push_back(err);
return ParseResult{nullptr, {}, p.parseErrors};
return ParseResult{nullptr, 0, {}, p.parseErrors};
}
}
@ -811,9 +812,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr)
{
return AstDeclaredClassProp{fnName.name,
reportTypeAnnotationError(Location(start, end), {}, /*isMissing*/ false, "'self' must be present as the unannotated first parameter"),
true};
return AstDeclaredClassProp{
fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true};
}
// Skip the first index.
@ -824,8 +824,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
if (args[i].annotation)
vars.push_back(args[i].annotation);
else
vars.push_back(reportTypeAnnotationError(
Location(start, end), {}, /*isMissing*/ false, "All declaration parameters aside from 'self' must be annotated"));
vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated"));
}
if (vararg && !varargAnnotation)
@ -1537,7 +1536,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
if (isUnion && isIntersection)
{
return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), /*isMissing*/ false,
return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts),
"Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
}
@ -1623,18 +1622,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
return {allocator.alloc<AstTypeSingletonString>(start, svalue)};
}
else
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")};
return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")};
}
else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple)
{
parseInterpString();
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Interpolated string literals cannot be used as types")};
return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")};
}
else if (lexer.current().type == Lexeme::BrokenString)
{
nextLexeme();
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Malformed string")};
return {reportTypeAnnotationError(start, {}, "Malformed string")};
}
else if (lexer.current().type == Lexeme::Name)
{
@ -1693,33 +1692,20 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
{
nextLexeme();
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false,
return {reportTypeAnnotationError(start, {},
"Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> "
"...any'"),
{}};
}
else
{
if (FFlag::LuauTypeAnnotationLocationChange)
{
// For a missing type annotation, capture 'space' between last token and the next one
Location astErrorlocation(lexer.previousLocation().end, start.begin);
// The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message).
// Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau.
Location parseErrorLocation(lexer.previousLocation().end, start.end);
return {
reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()),
{}};
}
else
{
Location location = lexer.current().location;
// For a missing type annotation, capture 'space' between last token and the next one
location = Location(lexer.previousLocation().end, lexer.current().location.begin);
return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}};
}
// For a missing type annotation, capture 'space' between last token and the next one
Location astErrorlocation(lexer.previousLocation().end, start.begin);
// The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message).
// Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau.
Location parseErrorLocation(lexer.previousLocation().end, start.end);
return {
reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}};
}
}
@ -2325,9 +2311,13 @@ AstExpr* Parser::parseTableConstructor()
MatchLexeme matchBrace = lexer.current();
expectAndConsume('{', "table literal");
unsigned lastElementIndent = 0;
while (lexer.current().type != '}')
{
if (FFlag::LuauTableConstructorRecovery)
lastElementIndent = lexer.current().location.begin.column;
if (lexer.current().type == '[')
{
MatchLexeme matchLocationBracket = lexer.current();
@ -2372,10 +2362,14 @@ AstExpr* Parser::parseTableConstructor()
{
nextLexeme();
}
else
else if (FFlag::LuauTableConstructorRecovery && (lexer.current().type == '[' || lexer.current().type == Lexeme::Name) &&
lexer.current().location.begin.column == lastElementIndent)
{
if (lexer.current().type != '}')
break;
report(lexer.current().location, "Expected ',' after table constructor element");
}
else if (lexer.current().type != '}')
{
break;
}
}
@ -3033,27 +3027,18 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray<A
return allocator.alloc<AstExprError>(location, expressions, unsigned(parseErrors.size() - 1));
}
AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, bool isMissing, const char* format, ...)
AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, const char* format, ...)
{
if (FFlag::LuauTypeAnnotationLocationChange)
{
// Missing type annotations should be using `reportMissingTypeAnnotationError` when LuauTypeAnnotationLocationChange is enabled
// Note: `isMissing` can be removed once FFlag::LuauTypeAnnotationLocationChange is removed since it will always be true.
LUAU_ASSERT(!isMissing);
}
va_list args;
va_start(args, format);
report(location, format, args);
va_end(args);
return allocator.alloc<AstTypeError>(location, types, isMissing, unsigned(parseErrors.size() - 1));
return allocator.alloc<AstTypeError>(location, types, false, unsigned(parseErrors.size() - 1));
}
AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...)
{
LUAU_ASSERT(FFlag::LuauTypeAnnotationLocationChange);
va_list args;
va_start(args, format);
report(parseErrorLocation, format, args);

View file

@ -14,7 +14,6 @@
#endif
LUAU_FASTFLAG(DebugLuauTimeTracing)
LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution)
enum class ReportFormat
{
@ -55,11 +54,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con
if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data))
report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str());
else if (FFlag::LuauTypeMismatchModuleNameResolution)
else
report(format, humanReadableName.c_str(), error.location, "TypeError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str());
else
report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str());
}
static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning)

View file

@ -49,6 +49,8 @@ enum class CompileFormat
Binary,
Remarks,
Codegen,
CodegenVerbose,
CodegenNull,
Null
};
@ -673,21 +675,33 @@ static void reportError(const char* name, const Luau::CompileError& error)
report(name, error.getLocation(), "CompileError", error.what());
}
static std::string getCodegenAssembly(const char* name, const std::string& bytecode)
static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options)
{
std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get();
setupState(L);
if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0)
return Luau::CodeGen::getAssemblyText(L, -1);
return Luau::CodeGen::getAssembly(L, -1, options);
fprintf(stderr, "Error loading bytecode %s\n", name);
return "";
}
static bool compileFile(const char* name, CompileFormat format)
static void annotateInstruction(void* context, std::string& text, int fid, int instpos)
{
Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context;
bcb.annotateInstruction(text, fid, instpos);
}
struct CompileStats
{
size_t lines;
size_t bytecode;
size_t codegen;
};
static bool compileFile(const char* name, CompileFormat format, CompileStats& stats)
{
std::optional<std::string> source = readFile(name);
if (!source)
@ -696,9 +710,13 @@ static bool compileFile(const char* name, CompileFormat format)
return false;
}
// NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example)
// This function is much more complicated because it supports many output human-readable formats through internal interfaces
try
{
Luau::BytecodeBuilder bcb;
Luau::CodeGen::AssemblyOptions options = {format == CompileFormat::CodegenNull, format == CompileFormat::Codegen, annotateInstruction, &bcb};
if (format == CompileFormat::Text)
{
@ -711,8 +729,24 @@ static bool compileFile(const char* name, CompileFormat format)
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source);
}
else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenVerbose)
{
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source);
}
Luau::compileOrThrow(bcb, *source, copts());
Luau::Allocator allocator;
Luau::AstNameTable names(allocator);
Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator);
if (!result.errors.empty())
throw Luau::ParseErrors(result.errors);
stats.lines += result.lines;
Luau::compileOrThrow(bcb, result, names, copts());
stats.bytecode += bcb.getBytecode().size();
switch (format)
{
@ -726,7 +760,11 @@ static bool compileFile(const char* name, CompileFormat format)
fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout);
break;
case CompileFormat::Codegen:
printf("%s", getCodegenAssembly(name, bcb.getBytecode()).c_str());
case CompileFormat::CodegenVerbose:
printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str());
break;
case CompileFormat::CodegenNull:
stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size();
break;
case CompileFormat::Null:
break;
@ -755,7 +793,7 @@ static void displayHelp(const char* argv0)
printf("\n");
printf("Available modes:\n");
printf(" omitted: compile and run input files one by one\n");
printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary, text, remarks, codegen or null)\n");
printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n");
printf("\n");
printf("Available options:\n");
printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n");
@ -813,6 +851,14 @@ int replMain(int argc, char** argv)
{
compileFormat = CompileFormat::Codegen;
}
else if (strcmp(argv[1], "--compile=codegenverbose") == 0)
{
compileFormat = CompileFormat::CodegenVerbose;
}
else if (strcmp(argv[1], "--compile=codegennull") == 0)
{
compileFormat = CompileFormat::CodegenNull;
}
else if (strcmp(argv[1], "--compile=null") == 0)
{
compileFormat = CompileFormat::Null;
@ -924,10 +970,17 @@ int replMain(int argc, char** argv)
_setmode(_fileno(stdout), _O_BINARY);
#endif
CompileStats stats = {};
int failed = 0;
for (const std::string& path : files)
failed += !compileFile(path.c_str(), compileFormat);
failed += !compileFile(path.c_str(), compileFormat, stats);
if (compileFormat == CompileFormat::Null)
printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024));
else if (compileFormat == CompileFormat::CodegenNull)
printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024),
int(stats.codegen / 1024));
return failed ? 1 : 0;
}

View file

@ -143,6 +143,11 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924)
set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-)
endif()
if (NOT MSVC)
# disable support for math_errno which allows compilers to lower sqrt() into a single CPU instruction
target_compile_options(Luau.VM PRIVATE -fno-math-errno)
endif()
if(MSVC AND LUAU_BUILD_CLI)
# the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger
set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152)

View file

@ -0,0 +1,52 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/RegisterA64.h"
namespace Luau
{
namespace CodeGen
{
enum class AddressKindA64 : uint8_t
{
imm, // reg + imm
reg, // reg + reg
// TODO:
// reg + reg << shift
// reg + sext(reg) << shift
// reg + uext(reg) << shift
// pc + offset
};
struct AddressA64
{
AddressA64(RegisterA64 base, int off = 0)
: kind(AddressKindA64::imm)
, base(base)
, offset(xzr)
, data(off)
{
LUAU_ASSERT(base.kind == KindA64::x);
LUAU_ASSERT(off >= 0 && off < 4096);
}
AddressA64(RegisterA64 base, RegisterA64 offset)
: kind(AddressKindA64::reg)
, base(base)
, offset(offset)
, data(0)
{
LUAU_ASSERT(base.kind == KindA64::x);
LUAU_ASSERT(offset.kind == KindA64::x);
}
AddressKindA64 kind;
RegisterA64 base;
RegisterA64 offset;
int data;
};
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,144 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/RegisterA64.h"
#include "Luau/AddressA64.h"
#include "Luau/ConditionA64.h"
#include "Luau/Label.h"
#include <string>
#include <vector>
namespace Luau
{
namespace CodeGen
{
class AssemblyBuilderA64
{
public:
explicit AssemblyBuilderA64(bool logText);
~AssemblyBuilderA64();
// Moves
void mov(RegisterA64 dst, RegisterA64 src);
void mov(RegisterA64 dst, uint16_t src, int shift = 0);
void movk(RegisterA64 dst, uint16_t src, int shift = 0);
// Arithmetics
void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
void add(RegisterA64 dst, RegisterA64 src1, int src2);
void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
void sub(RegisterA64 dst, RegisterA64 src1, int src2);
void neg(RegisterA64 dst, RegisterA64 src);
// Comparisons
// Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm
// TODO: add cmp
// Binary
// Note: shifted-register support and bitfield operations are omitted for simplicity
// TODO: support immediate arguments (they have odd encoding and forbid many values)
// TODO: support not variants for and/or/eor (required to support not...)
void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void clz(RegisterA64 dst, RegisterA64 src);
void rbit(RegisterA64 dst, RegisterA64 src);
// Load
// Note: paired loads are currently omitted for simplicity
void ldr(RegisterA64 dst, AddressA64 src);
void ldrb(RegisterA64 dst, AddressA64 src);
void ldrh(RegisterA64 dst, AddressA64 src);
void ldrsb(RegisterA64 dst, AddressA64 src);
void ldrsh(RegisterA64 dst, AddressA64 src);
void ldrsw(RegisterA64 dst, AddressA64 src);
// Store
void str(RegisterA64 src, AddressA64 dst);
void strb(RegisterA64 src, AddressA64 dst);
void strh(RegisterA64 src, AddressA64 dst);
// Control flow
// Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks
void b(ConditionA64 cond, Label& label);
void cbz(RegisterA64 src, Label& label);
void cbnz(RegisterA64 src, Label& label);
void ret();
// Run final checks
bool finalize();
// Places a label at current location and returns it
Label setLabel();
// Assigns label position to the current location
void setLabel(Label& label);
void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3);
uint32_t getCodeSize() const;
// Resulting data and code that need to be copied over one after the other
// The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code'
std::vector<uint8_t> data;
std::vector<uint32_t> code;
std::string text;
const bool logText = false;
private:
// Instruction archetypes
void place0(const char* name, uint32_t word);
void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0);
void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op);
void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2);
void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op);
void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op);
void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0);
void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size);
void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond);
void placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond);
void place(uint32_t word);
void placeLabel(Label& label);
void commit();
LUAU_NOINLINE void extend();
// Data
size_t allocateData(size_t size, size_t align);
// Logging of assembly in text form
LUAU_NOINLINE void log(const char* opcode);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label);
LUAU_NOINLINE void log(const char* opcode, Label label);
LUAU_NOINLINE void log(Label label);
LUAU_NOINLINE void log(RegisterA64 reg);
LUAU_NOINLINE void log(AddressA64 addr);
uint32_t nextLabel = 1;
std::vector<Label> pendingLabels;
std::vector<uint32_t> labelLocations;
bool finalized = false;
size_t dataPos = 0;
uint32_t* codePos = nullptr;
uint32_t* codeEnd = nullptr;
};
} // namespace CodeGen
} // namespace Luau

View file

@ -2,8 +2,8 @@
#pragma once
#include "Luau/Common.h"
#include "Luau/Condition.h"
#include "Luau/Label.h"
#include "Luau/ConditionX64.h"
#include "Luau/OperandX64.h"
#include "Luau/RegisterX64.h"
@ -23,6 +23,19 @@ enum class RoundingModeX64
RoundToZero = 0b11,
};
enum class AlignmentDataX64
{
Nop,
Int3,
Ud2, // int3 will be used as a fall-back if it doesn't fit
};
enum class ABIX64
{
Windows,
SystemV,
};
class AssemblyBuilderX64
{
public:
@ -71,7 +84,7 @@ public:
void ret();
// Control flow
void jcc(Condition cond, Label& label);
void jcc(ConditionX64 cond, Label& label);
void jmp(Label& label);
void jmp(OperandX64 op);
@ -80,6 +93,10 @@ public:
void int3();
// Code alignment
void nop(uint32_t length = 1);
void align(uint32_t alignment, AlignmentDataX64 data = AlignmentDataX64::Nop);
// AVX
void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2);
@ -131,6 +148,8 @@ public:
void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3);
uint32_t getCodeSize() const;
// Resulting data and code that need to be copied over one after the other
// The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code'
std::vector<uint8_t> data;
@ -140,6 +159,8 @@ public:
const bool logText = false;
const ABIX64 abi;
private:
// Instruction archetypes
void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev,
@ -177,7 +198,6 @@ private:
void commit();
LUAU_NOINLINE void extend();
uint32_t getCodeSize();
// Data
size_t allocateData(size_t size, size_t align);
@ -192,8 +212,8 @@ private:
LUAU_NOINLINE void log(const char* opcode, Label label);
void log(OperandX64 op);
const char* getSizeName(SizeX64 size);
const char* getRegisterName(RegisterX64 reg);
const char* getSizeName(SizeX64 size) const;
const char* getRegisterName(RegisterX64 reg) const;
uint32_t nextLabel = 1;
std::vector<Label> pendingLabels;

View file

@ -11,6 +11,8 @@ namespace Luau
namespace CodeGen
{
constexpr uint32_t kCodeAlignment = 32;
struct CodeAllocator
{
CodeAllocator(size_t blockSize, size_t maxTotalSize);

View file

@ -17,8 +17,20 @@ void create(lua_State* L);
// Builds target function and all inner functions
void compile(lua_State* L, int idx);
// Generates assembly text for target function and all inner functions
std::string getAssemblyText(lua_State* L, int idx);
using annotatorFn = void (*)(void* context, std::string& result, int fid, int instpos);
struct AssemblyOptions
{
bool outputBinary = false;
bool skipOutlinedCode = false;
// Optional annotator function can be provided to describe each instruction, it takes function id and sequential instruction id
annotatorFn annotator = nullptr;
void* annotatorContext = nullptr;
};
// Generates assembly for target function and all inner functions
std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {});
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,37 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
namespace Luau
{
namespace CodeGen
{
enum class ConditionA64
{
Equal,
NotEqual,
CarrySet,
CarryClear,
Minus,
Plus,
Overflow,
NoOverflow,
UnsignedGreater,
UnsignedLessEqual,
GreaterEqual,
Less,
Greater,
LessEqual,
Always,
Count
};
} // namespace CodeGen
} // namespace Luau

View file

@ -6,7 +6,7 @@ namespace Luau
namespace CodeGen
{
enum class Condition
enum class ConditionX64 : uint8_t
{
Overflow,
NoOverflow,

View file

@ -61,7 +61,7 @@ struct OperandX64
constexpr OperandX64 operator[](OperandX64&& addr) const
{
LUAU_ASSERT(cat == CategoryX64::mem);
LUAU_ASSERT(memSize != SizeX64::none && index == noreg && scale == 1 && base == noreg && imm == 0);
LUAU_ASSERT(index == noreg && scale == 1 && base == noreg && imm == 0);
LUAU_ASSERT(addr.memSize == SizeX64::none);
addr.cat = CategoryX64::mem;
@ -70,13 +70,13 @@ struct OperandX64
}
};
constexpr OperandX64 addr{SizeX64::none, noreg, 1, noreg, 0};
constexpr OperandX64 byte{SizeX64::byte, noreg, 1, noreg, 0};
constexpr OperandX64 word{SizeX64::word, noreg, 1, noreg, 0};
constexpr OperandX64 dword{SizeX64::dword, noreg, 1, noreg, 0};
constexpr OperandX64 qword{SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 xmmword{SizeX64::xmmword, noreg, 1, noreg, 0};
constexpr OperandX64 ymmword{SizeX64::ymmword, noreg, 1, noreg, 0};
constexpr OperandX64 ptr{sizeof(void*) == 4 ? SizeX64::dword : SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 operator*(RegisterX64 reg, uint8_t scale)
{
@ -94,6 +94,11 @@ constexpr OperandX64 operator+(RegisterX64 reg, int32_t disp)
return OperandX64(SizeX64::none, noreg, 1, reg, disp);
}
constexpr OperandX64 operator-(RegisterX64 reg, int32_t disp)
{
return OperandX64(SizeX64::none, noreg, 1, reg, -disp);
}
constexpr OperandX64 operator+(RegisterX64 base, RegisterX64 index)
{
LUAU_ASSERT(index.index != 4 && "sp cannot be used as index");

View file

@ -0,0 +1,105 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <stdint.h>
namespace Luau
{
namespace CodeGen
{
enum class KindA64 : uint8_t
{
none,
w, // 32-bit GPR
x, // 64-bit GPR
};
struct RegisterA64
{
KindA64 kind : 3;
uint8_t index : 5;
constexpr bool operator==(RegisterA64 rhs) const
{
return kind == rhs.kind && index == rhs.index;
}
constexpr bool operator!=(RegisterA64 rhs) const
{
return !(*this == rhs);
}
};
constexpr RegisterA64 w0{KindA64::w, 0};
constexpr RegisterA64 w1{KindA64::w, 1};
constexpr RegisterA64 w2{KindA64::w, 2};
constexpr RegisterA64 w3{KindA64::w, 3};
constexpr RegisterA64 w4{KindA64::w, 4};
constexpr RegisterA64 w5{KindA64::w, 5};
constexpr RegisterA64 w6{KindA64::w, 6};
constexpr RegisterA64 w7{KindA64::w, 7};
constexpr RegisterA64 w8{KindA64::w, 8};
constexpr RegisterA64 w9{KindA64::w, 9};
constexpr RegisterA64 w10{KindA64::w, 10};
constexpr RegisterA64 w11{KindA64::w, 11};
constexpr RegisterA64 w12{KindA64::w, 12};
constexpr RegisterA64 w13{KindA64::w, 13};
constexpr RegisterA64 w14{KindA64::w, 14};
constexpr RegisterA64 w15{KindA64::w, 15};
constexpr RegisterA64 w16{KindA64::w, 16};
constexpr RegisterA64 w17{KindA64::w, 17};
constexpr RegisterA64 w18{KindA64::w, 18};
constexpr RegisterA64 w19{KindA64::w, 19};
constexpr RegisterA64 w20{KindA64::w, 20};
constexpr RegisterA64 w21{KindA64::w, 21};
constexpr RegisterA64 w22{KindA64::w, 22};
constexpr RegisterA64 w23{KindA64::w, 23};
constexpr RegisterA64 w24{KindA64::w, 24};
constexpr RegisterA64 w25{KindA64::w, 25};
constexpr RegisterA64 w26{KindA64::w, 26};
constexpr RegisterA64 w27{KindA64::w, 27};
constexpr RegisterA64 w28{KindA64::w, 28};
constexpr RegisterA64 w29{KindA64::w, 29};
constexpr RegisterA64 w30{KindA64::w, 30};
constexpr RegisterA64 wzr{KindA64::w, 31};
constexpr RegisterA64 x0{KindA64::x, 0};
constexpr RegisterA64 x1{KindA64::x, 1};
constexpr RegisterA64 x2{KindA64::x, 2};
constexpr RegisterA64 x3{KindA64::x, 3};
constexpr RegisterA64 x4{KindA64::x, 4};
constexpr RegisterA64 x5{KindA64::x, 5};
constexpr RegisterA64 x6{KindA64::x, 6};
constexpr RegisterA64 x7{KindA64::x, 7};
constexpr RegisterA64 x8{KindA64::x, 8};
constexpr RegisterA64 x9{KindA64::x, 9};
constexpr RegisterA64 x10{KindA64::x, 10};
constexpr RegisterA64 x11{KindA64::x, 11};
constexpr RegisterA64 x12{KindA64::x, 12};
constexpr RegisterA64 x13{KindA64::x, 13};
constexpr RegisterA64 x14{KindA64::x, 14};
constexpr RegisterA64 x15{KindA64::x, 15};
constexpr RegisterA64 x16{KindA64::x, 16};
constexpr RegisterA64 x17{KindA64::x, 17};
constexpr RegisterA64 x18{KindA64::x, 18};
constexpr RegisterA64 x19{KindA64::x, 19};
constexpr RegisterA64 x20{KindA64::x, 20};
constexpr RegisterA64 x21{KindA64::x, 21};
constexpr RegisterA64 x22{KindA64::x, 22};
constexpr RegisterA64 x23{KindA64::x, 23};
constexpr RegisterA64 x24{KindA64::x, 24};
constexpr RegisterA64 x25{KindA64::x, 25};
constexpr RegisterA64 x26{KindA64::x, 26};
constexpr RegisterA64 x27{KindA64::x, 27};
constexpr RegisterA64 x28{KindA64::x, 28};
constexpr RegisterA64 x29{KindA64::x, 29};
constexpr RegisterA64 x30{KindA64::x, 30};
constexpr RegisterA64 xzr{KindA64::x, 31};
constexpr RegisterA64 sp{KindA64::none, 31};
} // namespace CodeGen
} // namespace Luau

View file

@ -113,5 +113,25 @@ constexpr RegisterX64 ymm13{SizeX64::ymmword, 13};
constexpr RegisterX64 ymm14{SizeX64::ymmword, 14};
constexpr RegisterX64 ymm15{SizeX64::ymmword, 15};
constexpr RegisterX64 byteReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::byte, reg.index};
}
constexpr RegisterX64 wordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::word, reg.index};
}
constexpr RegisterX64 dwordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::dword, reg.index};
}
constexpr RegisterX64 qwordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::qword, reg.index};
}
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,607 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AssemblyBuilderA64.h"
#include "ByteUtils.h"
#include <stdarg.h>
namespace Luau
{
namespace CodeGen
{
static const uint8_t codeForCondition[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
static const char* textForCondition[] = {
"b.eq", "b.ne", "b.cs", "b.cc", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "b.al"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
const unsigned kMaxAlign = 32;
AssemblyBuilderA64::AssemblyBuilderA64(bool logText)
: logText(logText)
{
data.resize(4096);
dataPos = data.size(); // data is filled backwards
code.resize(1024);
codePos = code.data();
codeEnd = code.data() + code.size();
}
AssemblyBuilderA64::~AssemblyBuilderA64()
{
LUAU_ASSERT(finalized);
}
void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src)
{
placeSR2("mov", dst, src, 0b01'01010);
}
void AssemblyBuilderA64::mov(RegisterA64 dst, uint16_t src, int shift)
{
placeI16("mov", dst, src, 0b10'100101, shift);
}
void AssemblyBuilderA64::movk(RegisterA64 dst, uint16_t src, int shift)
{
placeI16("movk", dst, src, 0b11'100101, shift);
}
void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
placeSR3("add", dst, src1, src2, 0b00'01011, shift);
}
void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, int src2)
{
placeI12("add", dst, src1, src2, 0b00'10001);
}
void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
placeSR3("sub", dst, src1, src2, 0b10'01011, shift);
}
void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, int src2)
{
placeI12("sub", dst, src1, src2, 0b10'10001);
}
void AssemblyBuilderA64::neg(RegisterA64 dst, RegisterA64 src)
{
placeSR2("neg", dst, src, 0b10'01011);
}
void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("and", dst, src1, src2, 0b00'01010);
}
void AssemblyBuilderA64::orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("orr", dst, src1, src2, 0b01'01010);
}
void AssemblyBuilderA64::eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("eor", dst, src1, src2, 0b10'01010);
}
void AssemblyBuilderA64::lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("lsl", dst, src1, src2, 0b11010110, 0b0010'00);
}
void AssemblyBuilderA64::lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("lsr", dst, src1, src2, 0b11010110, 0b0010'01);
}
void AssemblyBuilderA64::asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("asr", dst, src1, src2, 0b11010110, 0b0010'10);
}
void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("ror", dst, src1, src2, 0b11010110, 0b0010'11);
}
void AssemblyBuilderA64::clz(RegisterA64 dst, RegisterA64 src)
{
placeR1("clz", dst, src, 0b10'11010110'00000'00010'0);
}
void AssemblyBuilderA64::rbit(RegisterA64 dst, RegisterA64 src)
{
placeR1("rbit", dst, src, 0b10'11010110'00000'0000'00);
}
void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldr", dst, src, 0b11100001, 0b10 | uint8_t(dst.kind == KindA64::x));
}
void AssemblyBuilderA64::ldrb(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::w);
placeA("ldrb", dst, src, 0b11100001, 0b00);
}
void AssemblyBuilderA64::ldrh(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::w);
placeA("ldrh", dst, src, 0b11100001, 0b01);
}
void AssemblyBuilderA64::ldrsb(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldrsb", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b00);
}
void AssemblyBuilderA64::ldrsh(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldrsh", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b01);
}
void AssemblyBuilderA64::ldrsw(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x);
placeA("ldrsw", dst, src, 0b11100010, 0b10);
}
void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w);
placeA("str", src, dst, 0b11100000, 0b10 | uint8_t(src.kind == KindA64::x));
}
void AssemblyBuilderA64::strb(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::w);
placeA("strb", src, dst, 0b11100000, 0b00);
}
void AssemblyBuilderA64::strh(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::w);
placeA("strh", src, dst, 0b11100000, 0b01);
}
void AssemblyBuilderA64::b(ConditionA64 cond, Label& label)
{
placeBC(textForCondition[int(cond)], label, 0b0101010'0, codeForCondition[int(cond)]);
}
void AssemblyBuilderA64::cbz(RegisterA64 src, Label& label)
{
placeBR("cbz", label, 0b011010'0, src);
}
void AssemblyBuilderA64::cbnz(RegisterA64 src, Label& label)
{
placeBR("cbnz", label, 0b011010'1, src);
}
void AssemblyBuilderA64::ret()
{
place0("ret", 0b1101011'0'0'10'11111'0000'0'0'11110'00000);
}
bool AssemblyBuilderA64::finalize()
{
bool success = true;
code.resize(codePos - code.data());
// Resolve jump targets
for (Label fixup : pendingLabels)
{
// If this assertion fires, a label was used in jmp without calling setLabel
LUAU_ASSERT(labelLocations[fixup.id - 1] != ~0u);
int value = int(labelLocations[fixup.id - 1]) - int(fixup.location);
// imm19 encoding word offset, at bit offset 5
// note that 18 bits of word offsets = 20 bits of byte offsets = +-1MB
if (value > -(1 << 18) && value < (1 << 18))
code[fixup.location] |= (value & ((1 << 19) - 1)) << 5;
else
success = false; // overflow
}
size_t dataSize = data.size() - dataPos;
// Shrink data
if (dataSize > 0)
memmove(&data[0], &data[dataPos], dataSize);
data.resize(dataSize);
finalized = true;
return success;
}
Label AssemblyBuilderA64::setLabel()
{
Label label{nextLabel++, getCodeSize()};
labelLocations.push_back(~0u);
if (logText)
log(label);
return label;
}
void AssemblyBuilderA64::setLabel(Label& label)
{
if (label.id == 0)
{
label.id = nextLabel++;
labelLocations.push_back(~0u);
}
label.location = getCodeSize();
labelLocations[label.id - 1] = label.location;
if (logText)
log(label);
}
void AssemblyBuilderA64::logAppend(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
text.append(buf);
}
uint32_t AssemblyBuilderA64::getCodeSize() const
{
return uint32_t(codePos - code.data());
}
void AssemblyBuilderA64::place0(const char* name, uint32_t op)
{
if (logText)
log(name);
place(op);
commit();
}
void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift)
{
if (logText)
log(name, dst, src1, src2, shift);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);
LUAU_ASSERT(shift >= 0 && shift < 64); // right shift requires changing some encoding bits
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (shift << 10) | (src2.index << 16) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op)
{
if (logText)
log(name, dst, src);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (0x1f << 5) | (src.index << 16) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2)
{
if (logText)
log(name, dst, src1, src2);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (op2 << 10) | (src2.index << 16) | (op << 21) | sf);
commit();
}
void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op)
{
if (logText)
log(name, dst, src);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src.index << 5) | (op << 10) | sf);
commit();
}
void AssemblyBuilderA64::placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op)
{
if (logText)
log(name, dst, src1, src2);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind);
LUAU_ASSERT(src2 >= 0 && src2 < (1 << 12));
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (src2 << 10) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift)
{
if (logText)
log(name, dst, src, shift);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(src >= 0 && src <= 0xffff);
LUAU_ASSERT(shift == 0 || shift == 16 || shift == 32 || shift == 48);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src << 5) | ((shift >> 4) << 21) | (op << 23) | sf);
commit();
}
void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size)
{
if (logText)
log(name, dst, src);
switch (src.kind)
{
case AddressKindA64::imm:
LUAU_ASSERT(src.data % (1 << size) == 0);
place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30));
break;
case AddressKindA64::reg:
place(dst.index | (src.base.index << 5) | (0b10 << 10) | (0b011 << 13) | (src.offset.index << 16) | (1 << 21) | (op << 22) | (size << 30));
break;
}
commit();
}
void AssemblyBuilderA64::placeBC(const char* name, Label& label, uint8_t op, uint8_t cond)
{
placeLabel(label);
if (logText)
log(name, label);
place(cond | (op << 24));
commit();
}
void AssemblyBuilderA64::placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond)
{
placeLabel(label);
if (logText)
log(name, cond, label);
LUAU_ASSERT(cond.kind == KindA64::w || cond.kind == KindA64::x);
uint32_t sf = (cond.kind == KindA64::x) ? 0x80000000 : 0;
place(cond.index | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::place(uint32_t word)
{
LUAU_ASSERT(codePos < codeEnd);
*codePos++ = word;
}
void AssemblyBuilderA64::placeLabel(Label& label)
{
if (label.location == ~0u)
{
if (label.id == 0)
{
label.id = nextLabel++;
labelLocations.push_back(~0u);
}
pendingLabels.push_back({label.id, getCodeSize()});
}
else
{
// note: if label has an assigned location we can in theory avoid patching it later, but
// we need to handle potential overflow of 19-bit offsets
LUAU_ASSERT(label.id != 0);
labelLocations[label.id - 1] = label.location;
pendingLabels.push_back({label.id, getCodeSize()});
}
}
void AssemblyBuilderA64::commit()
{
LUAU_ASSERT(codePos <= codeEnd);
if (codeEnd == codePos)
extend();
}
void AssemblyBuilderA64::extend()
{
uint32_t count = getCodeSize();
code.resize(code.size() * 2);
codePos = code.data() + count;
codeEnd = code.data() + code.size();
}
size_t AssemblyBuilderA64::allocateData(size_t size, size_t align)
{
LUAU_ASSERT(align > 0 && align <= kMaxAlign && (align & (align - 1)) == 0);
if (dataPos < size)
{
size_t oldSize = data.size();
data.resize(data.size() * 2);
memcpy(&data[oldSize], &data[0], oldSize);
memset(&data[0], 0, oldSize);
dataPos += oldSize;
}
dataPos = (dataPos - size) & ~(align - 1);
return dataPos;
}
void AssemblyBuilderA64::log(const char* opcode)
{
logAppend(" %s\n", opcode);
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src1);
text.append(",");
log(src2);
if (shift > 0)
logAppend(" LSL #%d", shift);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src1);
text.append(",");
logAppend("#%d", src2);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, AddressA64 src)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, int src, int shift)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
logAppend("#%d", src);
if (shift > 0)
logAppend(" LSL #%d", shift);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 src, Label label)
{
logAppend(" %-12s", opcode);
log(src);
text.append(",");
logAppend(".L%d\n", label.id);
}
void AssemblyBuilderA64::log(const char* opcode, Label label)
{
logAppend(" %-12s.L%d\n", opcode, label.id);
}
void AssemblyBuilderA64::log(Label label)
{
logAppend(".L%d:\n", label.id);
}
void AssemblyBuilderA64::log(RegisterA64 reg)
{
switch (reg.kind)
{
case KindA64::w:
if (reg.index == 31)
logAppend("wzr");
else
logAppend("w%d", reg.index);
break;
case KindA64::x:
if (reg.index == 31)
logAppend("xzr");
else
logAppend("x%d", reg.index);
break;
case KindA64::none:
LUAU_ASSERT(!"Unexpected register kind");
}
}
void AssemblyBuilderA64::log(AddressA64 addr)
{
text.append("[");
switch (addr.kind)
{
case AddressKindA64::imm:
log(addr.base);
if (addr.data != 0)
logAppend(",#%d", addr.data);
break;
case AddressKindA64::reg:
log(addr.base);
text.append(",");
log(addr.offset);
if (addr.data != 0)
logAppend(" LSL #%d", addr.data);
break;
}
text.append("]");
}
} // namespace CodeGen
} // namespace Luau

View file

@ -13,13 +13,13 @@ namespace CodeGen
{
// TODO: more assertions on operand sizes
const uint8_t codeForCondition[] = {
static const uint8_t codeForCondition[] = {
0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd, 0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(Condition::Count), "all conditions have to be covered");
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
const char* textForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna", "jnae",
"jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(Condition::Count), "all conditions have to be covered");
static const char* textForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna",
"jnae", "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
#define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7))
#define OP_PLUS_CC(op, cc) ((op) + uint8_t(cc))
@ -29,7 +29,7 @@ static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(C
#define REX_X(reg) (((reg).index & 0x8) >> 2)
#define REX_B(reg) (((reg).index & 0x8) >> 3)
#define AVX_W(value) (!(value) ? 0x80 : 0x0)
#define AVX_W(value) ((value) ? 0x80 : 0x0)
#define AVX_R(reg) ((~(reg).index & 0x8) << 4)
#define AVX_X(reg) ((~(reg).index & 0x8) << 3)
#define AVX_B(reg) ((~(reg).index & 0x8) << 2)
@ -50,12 +50,23 @@ const unsigned AVX_66 = 0b01;
const unsigned AVX_F3 = 0b10;
const unsigned AVX_F2 = 0b11;
const unsigned kMaxAlign = 16;
const unsigned kMaxAlign = 32;
const unsigned kMaxInstructionLength = 16;
const uint8_t kRoundingPrecisionInexact = 0b1000;
static ABIX64 getCurrentX64ABI()
{
#if defined(_WIN32)
return ABIX64::Windows;
#else
return ABIX64::SystemV;
#endif
}
AssemblyBuilderX64::AssemblyBuilderX64(bool logText)
: logText(logText)
, abi(getCurrentX64ABI())
{
data.resize(4096);
dataPos = data.size(); // data is filled backwards
@ -317,7 +328,10 @@ void AssemblyBuilderX64::lea(OperandX64 lhs, OperandX64 rhs)
if (logText)
log("lea", lhs, rhs);
LUAU_ASSERT(rhs.cat == CategoryX64::mem);
LUAU_ASSERT(lhs.cat == CategoryX64::reg && rhs.cat == CategoryX64::mem && rhs.memSize == SizeX64::none);
LUAU_ASSERT(rhs.base == rip || rhs.base.size == lhs.base.size);
LUAU_ASSERT(rhs.index == noreg || rhs.index.size == lhs.base.size);
rhs.memSize = lhs.base.size;
placeBinaryRegAndRegMem(lhs, rhs, 0x8d, 0x8d);
}
@ -352,7 +366,7 @@ void AssemblyBuilderX64::ret()
commit();
}
void AssemblyBuilderX64::jcc(Condition cond, Label& label)
void AssemblyBuilderX64::jcc(ConditionX64 cond, Label& label)
{
placeJcc(textForCondition[size_t(cond)], label, codeForCondition[size_t(cond)]);
}
@ -416,6 +430,153 @@ void AssemblyBuilderX64::int3()
log("int3");
place(0xcc);
commit();
}
void AssemblyBuilderX64::nop(uint32_t length)
{
while (length != 0)
{
uint32_t step = length > 9 ? 9 : length;
length -= step;
switch (step)
{
case 1:
if (logText)
logAppend(" nop\n");
place(0x90);
break;
case 2:
if (logText)
logAppend(" xchg ax, ax ; %u-byte nop\n", step);
place(0x66);
place(0x90);
break;
case 3:
if (logText)
logAppend(" nop dword ptr[rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x00);
break;
case 4:
if (logText)
logAppend(" nop dword ptr[rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x40);
place(0x00);
break;
case 5:
if (logText)
logAppend(" nop dword ptr[rax+rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x44);
place(0x00);
place(0x00);
break;
case 6:
if (logText)
logAppend(" nop word ptr[rax+rax] ; %u-byte nop\n", step);
place(0x66);
place(0x0f);
place(0x1f);
place(0x44);
place(0x00);
place(0x00);
break;
case 7:
if (logText)
logAppend(" nop dword ptr[rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x80);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
break;
case 8:
if (logText)
logAppend(" nop dword ptr[rax+rax] ; %u-byte nop\n", step);
place(0x0f);
place(0x1f);
place(0x84);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
break;
case 9:
if (logText)
logAppend(" nop word ptr[rax+rax] ; %u-byte nop\n", step);
place(0x66);
place(0x0f);
place(0x1f);
place(0x84);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
place(0x00);
break;
}
commit();
}
}
void AssemblyBuilderX64::align(uint32_t alignment, AlignmentDataX64 data)
{
LUAU_ASSERT((alignment & (alignment - 1)) == 0);
uint32_t size = getCodeSize();
uint32_t pad = ((size + alignment - 1) & ~(alignment - 1)) - size;
switch (data)
{
case AlignmentDataX64::Nop:
if (logText)
logAppend("; align %u\n", alignment);
nop(pad);
break;
case AlignmentDataX64::Int3:
if (logText)
logAppend("; align %u using int3\n", alignment);
while (codePos + pad > codeEnd)
extend();
for (uint32_t i = 0; i < pad; ++i)
place(0xcc);
commit();
break;
case AlignmentDataX64::Ud2:
if (logText)
logAppend("; align %u using ud2\n", alignment);
while (codePos + pad > codeEnd)
extend();
uint32_t i = 0;
for (; i + 1 < pad; i += 2)
{
place(0x0f);
place(0x0b);
}
if (i < pad)
place(0xcc);
commit();
break;
}
}
void AssemblyBuilderX64::vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
@ -465,12 +626,12 @@ void AssemblyBuilderX64::vucomisd(OperandX64 src1, OperandX64 src2)
void AssemblyBuilderX64::vcvttsd2si(OperandX64 dst, OperandX64 src)
{
placeAvx("vcvttsd2si", dst, src, 0x2c, dst.base.size == SizeX64::dword, AVX_0F, AVX_F2);
placeAvx("vcvttsd2si", dst, src, 0x2c, dst.base.size == SizeX64::qword, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vcvtsi2sd", dst, src1, src2, 0x2a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::dword, AVX_0F, AVX_F2);
placeAvx("vcvtsi2sd", dst, src1, src2, 0x2a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode)
@ -623,7 +784,22 @@ OperandX64 AssemblyBuilderX64::bytes(const void* ptr, size_t size, size_t align)
{
size_t pos = allocateData(size, align);
memcpy(&data[pos], ptr, size);
return OperandX64(SizeX64::qword, noreg, 1, rip, int32_t(pos - data.size()));
return OperandX64(SizeX64::none, noreg, 1, rip, int32_t(pos - data.size()));
}
void AssemblyBuilderX64::logAppend(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
text.append(buf);
}
uint32_t AssemblyBuilderX64::getCodeSize() const
{
return uint32_t(codePos - code.data());
}
void AssemblyBuilderX64::placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8,
@ -899,7 +1075,7 @@ void AssemblyBuilderX64::placeVex(OperandX64 dst, OperandX64 src1, OperandX64 sr
place(AVX_3_3(setW, src1.base, dst.base.size == SizeX64::ymmword, prefix));
}
uint8_t getScaleEncoding(uint8_t scale)
static uint8_t getScaleEncoding(uint8_t scale)
{
static const uint8_t scales[9] = {0xff, 0, 1, 0xff, 2, 0xff, 0xff, 0xff, 3};
@ -1054,7 +1230,7 @@ void AssemblyBuilderX64::commit()
{
LUAU_ASSERT(codePos <= codeEnd);
if (codeEnd - codePos < 16)
if (codeEnd - codePos < kMaxInstructionLength)
extend();
}
@ -1067,11 +1243,6 @@ void AssemblyBuilderX64::extend()
codeEnd = code.data() + code.size();
}
uint32_t AssemblyBuilderX64::getCodeSize()
{
return uint32_t(codePos - code.data());
}
size_t AssemblyBuilderX64::allocateData(size_t size, size_t align)
{
LUAU_ASSERT(align > 0 && align <= kMaxAlign && (align & (align - 1)) == 0);
@ -1174,8 +1345,10 @@ void AssemblyBuilderX64::log(OperandX64 op)
{
if (op.imm >= 0 && op.imm <= 9)
logAppend("+%d", op.imm);
else
else if (op.imm > 0)
logAppend("+0%Xh", op.imm);
else
logAppend("-0%Xh", -op.imm);
}
text.append("]");
@ -1191,17 +1364,7 @@ void AssemblyBuilderX64::log(OperandX64 op)
}
}
void AssemblyBuilderX64::logAppend(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
text.append(buf);
}
const char* AssemblyBuilderX64::getSizeName(SizeX64 size)
const char* AssemblyBuilderX64::getSizeName(SizeX64 size) const
{
static const char* sizeNames[] = {"none", "byte", "word", "dword", "qword", "xmmword", "ymmword"};
@ -1209,7 +1372,7 @@ const char* AssemblyBuilderX64::getSizeName(SizeX64 size)
return sizeNames[unsigned(size)];
}
const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg)
const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg) const
{
static const char* names[][16] = {{"rip", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""},
{"al", "cl", "dl", "bl", "spl", "bpl", "sil", "dil", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b"},

View file

@ -1,7 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#include "Luau/Common.h"
#if defined(LUAU_BIG_ENDIAN)
#include <endian.h>
#endif
@ -15,7 +17,7 @@ inline uint8_t* writeu8(uint8_t* target, uint8_t value)
inline uint8_t* writeu32(uint8_t* target, uint32_t value)
{
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#if defined(LUAU_BIG_ENDIAN)
value = htole32(value);
#endif
@ -25,7 +27,7 @@ inline uint8_t* writeu32(uint8_t* target, uint32_t value)
inline uint8_t* writeu64(uint8_t* target, uint64_t value)
{
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#if defined(LUAU_BIG_ENDIAN)
value = htole64(value);
#endif
@ -51,7 +53,7 @@ inline uint8_t* writeuleb128(uint8_t* target, uint64_t value)
inline uint8_t* writef32(uint8_t* target, float value)
{
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#if defined(LUAU_BIG_ENDIAN)
static_assert(sizeof(float) == sizeof(uint32_t), "type size must match to reinterpret data");
uint32_t data;
memcpy(&data, &value, sizeof(value));
@ -65,7 +67,7 @@ inline uint8_t* writef32(uint8_t* target, float value)
inline uint8_t* writef64(uint8_t* target, double value)
{
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#if defined(LUAU_BIG_ENDIAN)
static_assert(sizeof(double) == sizeof(uint64_t), "type size must match to reinterpret data");
uint64_t data;
memcpy(&data, &value, sizeof(value));

View file

@ -110,8 +110,8 @@ CodeAllocator::~CodeAllocator()
bool CodeAllocator::allocate(
uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart)
{
// 'Round up' to preserve 16 byte alignment
size_t alignedDataSize = (dataSize + 15) & ~15;
// 'Round up' to preserve code alignment
size_t alignedDataSize = (dataSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1);
size_t totalSize = alignedDataSize + codeSize;
@ -187,8 +187,8 @@ bool CodeAllocator::allocateNewBlock(size_t& unwindInfoSize)
{
void* unwindInfo = createBlockUnwindInfo(context, block, blockSize, unwindInfoSize);
// 'Round up' to preserve 16 byte alignment of the following data and code
unwindInfoSize = (unwindInfoSize + 15) & ~15;
// 'Round up' to preserve alignment of the following data and code
unwindInfoSize = (unwindInfoSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1);
LUAU_ASSERT(unwindInfoSize <= kMaxReservedDataSize);

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/CodeBlockUnwind.h"
#include "Luau/CodeAllocator.h"
#include "Luau/UnwindBuilder.h"
#include <string.h>
@ -58,7 +59,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
// All unwinding related data is placed together at the start of the block
size_t unwindSize = sizeof(RUNTIME_FUNCTION) + unwind->getSize();
unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes
unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment
LUAU_ASSERT(blockSize >= unwindSize);
RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)block;
@ -82,7 +83,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
// All unwinding related data is placed together at the start of the block
size_t unwindSize = unwind->getSize();
unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes
unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment
LUAU_ASSERT(blockSize >= unwindSize);
char* unwindData = (char*)block;

View file

@ -32,7 +32,360 @@ namespace Luau
namespace CodeGen
{
static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, Proto* proto)
constexpr uint32_t kFunctionAlignment = 32;
struct InstructionOutline
{
int pcpos;
int length;
};
static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers)
{
if (build.logText)
build.logAppend("; exitContinueVm\n");
helpers.exitContinueVm = build.setLabel();
emitExit(build, /* continueInVm */ true);
if (build.logText)
build.logAppend("; exitNoContinueVm\n");
helpers.exitNoContinueVm = build.setLabel();
emitExit(build, /* continueInVm */ false);
if (build.logText)
build.logAppend("; continueCallInVm\n");
helpers.continueCallInVm = build.setLabel();
emitContinueCallInVm(build);
}
static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i,
Label* labelarr, Label& fallback)
{
int skip = 0;
switch (op)
{
case LOP_NOP:
break;
case LOP_LOADNIL:
emitInstLoadNil(build, pc);
break;
case LOP_LOADB:
emitInstLoadB(build, pc, i, labelarr);
break;
case LOP_LOADN:
emitInstLoadN(build, pc);
break;
case LOP_LOADK:
emitInstLoadK(build, pc);
break;
case LOP_LOADKX:
emitInstLoadKX(build, pc);
break;
case LOP_MOVE:
emitInstMove(build, pc);
break;
case LOP_GETGLOBAL:
emitInstGetGlobal(build, pc, i, fallback);
break;
case LOP_SETGLOBAL:
emitInstSetGlobal(build, pc, i, labelarr, fallback);
break;
case LOP_CALL:
emitInstCall(build, helpers, pc, i, labelarr);
break;
case LOP_RETURN:
emitInstReturn(build, helpers, pc, i, labelarr);
break;
case LOP_GETTABLE:
emitInstGetTable(build, pc, i, fallback);
break;
case LOP_SETTABLE:
emitInstSetTable(build, pc, i, labelarr, fallback);
break;
case LOP_GETTABLEKS:
emitInstGetTableKS(build, pc, i, fallback);
break;
case LOP_SETTABLEKS:
emitInstSetTableKS(build, pc, i, labelarr, fallback);
break;
case LOP_GETTABLEN:
emitInstGetTableN(build, pc, i, fallback);
break;
case LOP_SETTABLEN:
emitInstSetTableN(build, pc, i, labelarr, fallback);
break;
case LOP_JUMP:
emitInstJump(build, pc, i, labelarr);
break;
case LOP_JUMPBACK:
emitInstJumpBack(build, pc, i, labelarr);
break;
case LOP_JUMPIF:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFNOT:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::LessEqual, fallback);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::Less, fallback);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLessEqual, fallback);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLess, fallback);
break;
case LOP_JUMPX:
emitInstJumpX(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKNIL:
emitInstJumpxEqNil(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKB:
emitInstJumpxEqB(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKN:
emitInstJumpxEqN(build, pc, proto->k, i, labelarr);
break;
case LOP_JUMPXEQKS:
emitInstJumpxEqS(build, pc, i, labelarr);
break;
case LOP_ADD:
emitInstBinary(build, pc, i, TM_ADD, fallback);
break;
case LOP_SUB:
emitInstBinary(build, pc, i, TM_SUB, fallback);
break;
case LOP_MUL:
emitInstBinary(build, pc, i, TM_MUL, fallback);
break;
case LOP_DIV:
emitInstBinary(build, pc, i, TM_DIV, fallback);
break;
case LOP_MOD:
emitInstBinary(build, pc, i, TM_MOD, fallback);
break;
case LOP_POW:
emitInstBinary(build, pc, i, TM_POW, fallback);
break;
case LOP_ADDK:
emitInstBinaryK(build, pc, i, TM_ADD, fallback);
break;
case LOP_SUBK:
emitInstBinaryK(build, pc, i, TM_SUB, fallback);
break;
case LOP_MULK:
emitInstBinaryK(build, pc, i, TM_MUL, fallback);
break;
case LOP_DIVK:
emitInstBinaryK(build, pc, i, TM_DIV, fallback);
break;
case LOP_MODK:
emitInstBinaryK(build, pc, i, TM_MOD, fallback);
break;
case LOP_POWK:
emitInstPowK(build, pc, proto->k, i, fallback);
break;
case LOP_NOT:
emitInstNot(build, pc);
break;
case LOP_MINUS:
emitInstMinus(build, pc, i, fallback);
break;
case LOP_LENGTH:
emitInstLength(build, pc, i, fallback);
break;
case LOP_NEWTABLE:
emitInstNewTable(build, pc, i, labelarr);
break;
case LOP_DUPTABLE:
emitInstDupTable(build, pc, i, labelarr);
break;
case LOP_SETLIST:
emitInstSetList(build, pc, i, labelarr);
break;
case LOP_GETUPVAL:
emitInstGetUpval(build, pc, i);
break;
case LOP_SETUPVAL:
emitInstSetUpval(build, pc, i, labelarr);
break;
case LOP_CLOSEUPVALS:
emitInstCloseUpvals(build, pc, i, labelarr);
break;
case LOP_FASTCALL:
skip = emitInstFastCall(build, pc, i, labelarr);
break;
case LOP_FASTCALL1:
skip = emitInstFastCall1(build, pc, i, labelarr);
break;
case LOP_FASTCALL2:
skip = emitInstFastCall2(build, pc, i, labelarr);
break;
case LOP_FASTCALL2K:
skip = emitInstFastCall2K(build, pc, i, labelarr);
break;
case LOP_FORNPREP:
emitInstForNPrep(build, pc, i, labelarr);
break;
case LOP_FORNLOOP:
emitInstForNLoop(build, pc, i, labelarr);
break;
case LOP_FORGLOOP:
emitinstForGLoop(build, pc, i, labelarr, fallback);
break;
case LOP_FORGPREP_NEXT:
emitInstForGPrepNext(build, pc, i, labelarr, fallback);
break;
case LOP_FORGPREP_INEXT:
emitInstForGPrepInext(build, pc, i, labelarr, fallback);
break;
case LOP_AND:
emitInstAnd(build, pc);
break;
case LOP_ANDK:
emitInstAndK(build, pc);
break;
case LOP_OR:
emitInstOr(build, pc);
break;
case LOP_ORK:
emitInstOrK(build, pc);
break;
case LOP_GETIMPORT:
emitInstGetImport(build, pc, fallback);
break;
case LOP_CONCAT:
emitInstConcat(build, pc, i, labelarr);
break;
default:
emitFallback(build, data, op, i);
break;
}
return skip;
}
static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr)
{
switch (op)
{
case LOP_GETIMPORT:
emitInstGetImportFallback(build, pc, i);
break;
case LOP_GETTABLE:
emitInstGetTableFallback(build, pc, i);
break;
case LOP_SETTABLE:
emitInstSetTableFallback(build, pc, i);
break;
case LOP_GETTABLEN:
emitInstGetTableNFallback(build, pc, i);
break;
case LOP_SETTABLEN:
emitInstSetTableNFallback(build, pc, i);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLess);
break;
case LOP_ADD:
emitInstBinaryFallback(build, pc, i, TM_ADD);
break;
case LOP_SUB:
emitInstBinaryFallback(build, pc, i, TM_SUB);
break;
case LOP_MUL:
emitInstBinaryFallback(build, pc, i, TM_MUL);
break;
case LOP_DIV:
emitInstBinaryFallback(build, pc, i, TM_DIV);
break;
case LOP_MOD:
emitInstBinaryFallback(build, pc, i, TM_MOD);
break;
case LOP_POW:
emitInstBinaryFallback(build, pc, i, TM_POW);
break;
case LOP_ADDK:
emitInstBinaryKFallback(build, pc, i, TM_ADD);
break;
case LOP_SUBK:
emitInstBinaryKFallback(build, pc, i, TM_SUB);
break;
case LOP_MULK:
emitInstBinaryKFallback(build, pc, i, TM_MUL);
break;
case LOP_DIVK:
emitInstBinaryKFallback(build, pc, i, TM_DIV);
break;
case LOP_MODK:
emitInstBinaryKFallback(build, pc, i, TM_MOD);
break;
case LOP_POWK:
emitInstBinaryKFallback(build, pc, i, TM_POW);
break;
case LOP_MINUS:
emitInstMinusFallback(build, pc, i);
break;
case LOP_LENGTH:
emitInstLengthFallback(build, pc, i);
break;
case LOP_FORGLOOP:
emitinstForGLoopFallback(build, pc, i, labelarr);
break;
case LOP_FORGPREP_NEXT:
case LOP_FORGPREP_INEXT:
emitInstForGPrepXnextFallback(build, pc, i, labelarr);
break;
case LOP_GETGLOBAL:
// TODO: luaV_gettable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETGLOBAL:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_GETTABLEKS:
// Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access
// It is also required to perform cached slot update
// TODO: extra fast-paths could be lowered before the full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETTABLEKS:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
default:
LUAU_ASSERT(!"Expected fallback for instruction");
}
}
static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options)
{
NativeProto* result = new NativeProto();
@ -54,6 +407,14 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
std::vector<Label> instLabels;
instLabels.resize(proto->sizecode);
std::vector<Label> instFallbacks;
instFallbacks.resize(proto->sizecode);
std::vector<InstructionOutline> instOutlines;
instOutlines.reserve(64);
build.align(kFunctionAlignment, AlignmentDataX64::Ud2);
Label start = build.setLabel();
for (int i = 0; i < proto->sizecode;)
@ -61,172 +422,90 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
build.setLabel(instLabels[i]);
if (build.logText)
build.logAppend("; #%d: %s\n", i, data.names[op]);
if (options.annotator)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
switch (op)
int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), instFallbacks[i]);
if (skip != 0)
instOutlines.push_back({nexti, skip});
i = nexti + skip;
LUAU_ASSERT(i <= proto->sizecode);
}
size_t textSize = build.text.size();
uint32_t codeSize = build.getCodeSize();
if (options.annotator && !options.skipOutlinedCode)
build.logAppend("; outlined instructions\n");
for (auto [pcpos, length] : instOutlines)
{
int i = pcpos;
while (i < pcpos + length)
{
case LOP_NOP:
break;
case LOP_LOADNIL:
emitInstLoadNil(build, data, pc);
break;
case LOP_LOADB:
emitInstLoadB(build, data, pc, i, instLabels.data());
break;
case LOP_LOADN:
emitInstLoadN(build, data, pc);
break;
case LOP_LOADK:
emitInstLoadK(build, data, pc, proto->k);
break;
case LOP_MOVE:
emitInstMove(build, data, pc);
break;
case LOP_GETTABLE:
emitInstGetTable(build, pc, i);
break;
case LOP_SETTABLE:
emitInstSetTable(build, pc, i);
break;
case LOP_GETTABLEN:
emitInstGetTableN(build, pc, i);
break;
case LOP_SETTABLEN:
emitInstSetTableN(build, pc, i);
break;
case LOP_JUMP:
emitInstJump(build, data, pc, i, instLabels.data());
break;
case LOP_JUMPBACK:
emitInstJumpBack(build, data, pc, i, instLabels.data());
break;
case LOP_JUMPIF:
emitInstJumpIf(build, data, pc, i, instLabels.data(), /* not_ */ false);
break;
case LOP_JUMPIFNOT:
emitInstJumpIf(build, data, pc, i, instLabels.data(), /* not_ */ true);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEq(build, data, pc, i, instLabels.data(), /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, data, pc, i, instLabels.data(), /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, data, pc, i, instLabels.data(), Condition::NotLess);
break;
case LOP_JUMPX:
emitInstJumpX(build, data, pc, i, instLabels.data());
break;
case LOP_JUMPXEQKNIL:
emitInstJumpxEqNil(build, data, pc, proto->k, i, instLabels.data());
break;
case LOP_JUMPXEQKB:
emitInstJumpxEqB(build, data, pc, proto->k, i, instLabels.data());
break;
case LOP_JUMPXEQKN:
emitInstJumpxEqN(build, data, pc, proto->k, i, instLabels.data());
break;
case LOP_JUMPXEQKS:
emitInstJumpxEqS(build, data, pc, proto->k, i, instLabels.data());
break;
case LOP_ADD:
emitInstAdd(build, pc, i);
break;
case LOP_SUB:
emitInstSub(build, pc, i);
break;
case LOP_MUL:
emitInstMul(build, pc, i);
break;
case LOP_DIV:
emitInstDiv(build, pc, i);
break;
case LOP_MOD:
emitInstMod(build, pc, i);
break;
case LOP_POW:
emitInstPow(build, pc, i);
break;
case LOP_ADDK:
emitInstAddK(build, pc, proto->k, i);
break;
case LOP_SUBK:
emitInstSubK(build, pc, proto->k, i);
break;
case LOP_MULK:
emitInstMulK(build, pc, proto->k, i);
break;
case LOP_DIVK:
emitInstDivK(build, pc, proto->k, i);
break;
case LOP_MODK:
emitInstModK(build, pc, proto->k, i);
break;
case LOP_POWK:
emitInstPowK(build, pc, proto->k, i);
break;
case LOP_NOT:
emitInstNot(build, pc);
break;
case LOP_MINUS:
emitInstMinus(build, pc, i);
break;
case LOP_LENGTH:
emitInstLength(build, pc, i);
break;
case LOP_GETUPVAL:
emitInstGetUpval(build, pc, i);
break;
case LOP_FASTCALL:
emitInstFastCall(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL1:
emitInstFastCall1(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL2:
emitInstFastCall2(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL2K:
emitInstFastCall2K(build, pc, proto->k, i, instLabels.data());
break;
case LOP_FORNPREP:
emitInstForNPrep(build, pc, i, instLabels.data());
break;
case LOP_FORNLOOP:
emitInstForNLoop(build, pc, i, instLabels.data());
break;
case LOP_AND:
emitInstAnd(build, pc);
break;
case LOP_ANDK:
emitInstAndK(build, pc);
break;
case LOP_OR:
emitInstOr(build, pc);
break;
case LOP_ORK:
emitInstOrK(build, pc);
break;
default:
emitFallback(build, data, op, i);
break;
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
build.setLabel(instLabels[i]);
if (options.annotator && !options.skipOutlinedCode)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), instFallbacks[i]);
LUAU_ASSERT(skip == 0);
i += getOpLength(op);
}
i += getOpLength(op);
LUAU_ASSERT(i <= proto->sizecode);
if (i < proto->sizecode)
build.jmp(instLabels[i]);
}
if (options.annotator && !options.skipOutlinedCode)
build.logAppend("; outlined code\n");
for (int i = 0, instid = 0; i < proto->sizecode; ++instid)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
if (instFallbacks[i].id == 0)
{
i = nexti;
continue;
}
if (options.annotator && !options.skipOutlinedCode)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, instid);
build.setLabel(instFallbacks[i]);
emitInstFallback(build, data, op, pc, i, instLabels.data());
// Jump back to the next instruction handler
if (nexti < proto->sizecode)
build.jmp(instLabels[nexti]);
i = nexti;
}
// Truncate assembly output if we don't care for outlined code part
if (options.skipOutlinedCode)
{
build.text.resize(textSize);
build.logAppend("; skipping %u bytes of outlined code\n", build.getCodeSize() - codeSize);
}
result->instTargets = new uintptr_t[proto->sizecode];
@ -386,13 +665,16 @@ void compile(lua_State* L, int idx)
std::vector<Proto*> protos;
gatherFunctions(protos, clvalue(func)->l.p);
ModuleHelpers helpers;
assembleHelpers(build, helpers);
std::vector<NativeProto*> results;
results.reserve(protos.size());
// Skip protos that have been compiled during previous invocations of CodeGen::compile
for (Proto* p : protos)
if (p && getProtoExecData(p) == nullptr)
results.push_back(assembleFunction(build, *data, p));
results.push_back(assembleFunction(build, *data, helpers, p, {}));
build.finalize();
@ -413,6 +695,9 @@ void compile(lua_State* L, int idx)
{
for (int i = 0; i < result->proto->sizecode; i++)
result->instTargets[i] += uintptr_t(codeStart + result->location);
LUAU_ASSERT(result->proto->sizecode);
result->entryTarget = result->instTargets[0];
}
// Link native proto objects to Proto; the memory is now managed by VM and will be freed via onDestroyFunction
@ -420,29 +705,35 @@ void compile(lua_State* L, int idx)
setProtoExecData(result->proto, result);
}
std::string getAssemblyText(lua_State* L, int idx)
std::string getAssembly(lua_State* L, int idx, AssemblyOptions options)
{
LUAU_ASSERT(lua_isLfunction(L, idx));
const TValue* func = luaA_toobject(L, idx);
AssemblyBuilderX64 build(/* logText= */ true);
AssemblyBuilderX64 build(/* logText= */ !options.outputBinary);
NativeState data;
initFallbackTable(data);
initInstructionNames(data);
std::vector<Proto*> protos;
gatherFunctions(protos, clvalue(func)->l.p);
ModuleHelpers helpers;
assembleHelpers(build, helpers);
for (Proto* p : protos)
if (p)
{
NativeProto* nativeProto = assembleFunction(build, data, p);
NativeProto* nativeProto = assembleFunction(build, data, helpers, p, options);
destroyNativeProto(nativeProto);
}
build.finalize();
return build.text;
if (options.outputBinary)
return std::string(build.code.begin(), build.code.end()) + std::string(build.data.begin(), build.data.end());
else
return build.text;
}
} // namespace CodeGen

View file

@ -0,0 +1,130 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "CodeGenUtils.h"
#include "ldo.h"
#include "ltable.h"
#include "FallbacksProlog.h"
#include <string.h>
namespace Luau
{
namespace CodeGen
{
bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra)
{
// then we advance index through the hash portion
while (unsigned(index - h->sizearray) < unsigned(1 << h->lsizenode))
{
LuaNode* n = &h->node[index - h->sizearray];
if (!ttisnil(gval(n)))
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
getnodekey(L, ra + 3, n);
setobj(L, ra + 4, gval(n));
return true;
}
index++;
}
return false;
}
bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux)
{
TValue* base = L->base;
TValue* ra = VM_REG(insnA);
// note: it's safe to push arguments past top for complicated reasons (see lvmexecute.cpp)
setobj2s(L, ra + 3 + 2, ra + 2);
setobj2s(L, ra + 3 + 1, ra + 1);
setobj2s(L, ra + 3, ra);
L->top = ra + 3 + 3; // func + 2 args (state and index)
LUAU_ASSERT(L->top <= L->stack_last);
luaD_call(L, ra + 3, uint8_t(aux));
L->top = L->ci->top;
// recompute ra since stack might have been reallocated
base = L->base;
ra = VM_REG(insnA);
// copy first variable back into the iteration index
setobj2s(L, ra + 2, ra + 3);
return !ttisnil(ra + 3);
}
void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc)
{
if (!ttisfunction(ra))
{
Closure* cl = clvalue(L->ci->func);
L->ci->savedpc = cl->l.p->code + pc;
luaG_typeerror(L, ra, "iterate over");
}
}
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults)
{
// slow-path: not a function call
if (LUAU_UNLIKELY(!ttisfunction(ra)))
{
luaV_tryfuncTM(L, ra);
argtop++; // __call adds an extra self
}
Closure* ccl = clvalue(ra);
CallInfo* ci = incr_ci(L);
ci->func = ra;
ci->base = ra + 1;
ci->top = argtop + ccl->stacksize; // note: technically UB since we haven't reallocated the stack yet
ci->savedpc = NULL;
ci->flags = 0;
ci->nresults = nresults;
L->base = ci->base;
L->top = argtop;
// note: this reallocs stack, but we don't need to VM_PROTECT this
// this is because we're going to modify base/savedpc manually anyhow
// crucially, we can't use ra/argtop after this line
luaD_checkstack(L, ccl->stacksize);
return ccl;
}
void callEpilogC(lua_State* L, int nresults, int n)
{
// ci is our callinfo, cip is our parent
CallInfo* ci = L->ci;
CallInfo* cip = ci - 1;
// copy return values into parent stack (but only up to nresults!), fill the rest with nil
// note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally
StkId res = ci->func;
StkId vali = L->top - n;
StkId valend = L->top;
int i;
for (i = nresults; i != 0 && vali < valend; i--)
setobj2s(L, res++, vali++);
while (i-- > 0)
setnilvalue(res++);
// pop the stack frame
L->ci = cip;
L->base = cip->base;
L->top = (nresults == LUA_MULTRET) ? res : cip->top;
}
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,20 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "lobject.h"
namespace Luau
{
namespace CodeGen
{
bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra);
bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux);
void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc);
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults);
void callEpilogC(lua_State* L, int nresults, int n);
} // namespace CodeGen
} // namespace Luau

View file

@ -48,7 +48,7 @@ bool initEntryFunction(NativeState& data)
unwind.start();
if (getCurrentX64ABI() == X64ABI::Windows)
if (build.abi == ABIX64::Windows)
{
// Place arguments in home space
build.mov(qword[rsp + 16], rArg2);
@ -87,7 +87,7 @@ bool initEntryFunction(NativeState& data)
unwind.allocStack(stacksize + localssize);
// Setup frame pointer
build.lea(rbp, qword[rsp + stacksize]);
build.lea(rbp, addr[rsp + stacksize]);
unwind.setupFrameReg(rbp, stacksize);
unwind.finish();
@ -113,7 +113,7 @@ bool initEntryFunction(NativeState& data)
Label returnOff = build.setLabel();
// Cleanup and exit
build.lea(rsp, qword[rbp + localssize]);
build.lea(rsp, addr[rbp + localssize]);
build.pop(r15);
build.pop(r14);
build.pop(r13);
@ -121,7 +121,7 @@ bool initEntryFunction(NativeState& data)
build.pop(rbp);
build.pop(rbx);
if (getCurrentX64ABI() == X64ABI::Windows)
if (build.abi == ABIX64::Windows)
{
build.pop(rsi);
build.pop(rdi);

View file

@ -126,20 +126,5 @@ inline int getOpLength(LuauOpcode op)
}
}
enum class X64ABI
{
Windows,
SystemV,
};
inline X64ABI getCurrentX64ABI()
{
#if defined(_WIN32)
return X64ABI::Windows;
#else
return X64ABI::SystemV;
#endif
}
} // namespace CodeGen
} // namespace Luau

View file

@ -14,7 +14,7 @@ namespace Luau
namespace CodeGen
{
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label)
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label)
{
// Refresher on comi/ucomi EFLAGS:
// CF only: less
@ -35,52 +35,75 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs,
// And because of NaN, integer check interchangeability like 'not less or equal' <-> 'greater' does not hold
switch (cond)
{
case Condition::NotLessEqual:
case ConditionX64::NotLessEqual:
// (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN
build.jcc(Condition::NotAboveEqual, label);
build.jcc(ConditionX64::NotAboveEqual, label);
break;
case Condition::LessEqual:
case ConditionX64::LessEqual:
// (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN
build.jcc(Condition::AboveEqual, label);
build.jcc(ConditionX64::AboveEqual, label);
break;
case Condition::NotLess:
case ConditionX64::NotLess:
// (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN
build.jcc(Condition::NotAbove, label);
build.jcc(ConditionX64::NotAbove, label);
break;
case Condition::Less:
case ConditionX64::Less:
// (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN
build.jcc(Condition::Above, label);
build.jcc(ConditionX64::Above, label);
break;
case Condition::NotEqual:
case ConditionX64::NotEqual:
// ZF=0 or PF=1 means != or NaN
build.jcc(Condition::NotZero, label);
build.jcc(Condition::Parity, label);
build.jcc(ConditionX64::NotZero, label);
build.jcc(ConditionX64::Parity, label);
break;
default:
LUAU_ASSERT(!"Unsupported condition");
}
}
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos)
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra));
build.lea(rArg3, luauRegValue(rb));
if (cond == Condition::NotLessEqual || cond == Condition::LessEqual)
build.mov(rArg1, rState);
build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegAddress(rb));
if (cond == ConditionX64::NotLessEqual || cond == ConditionX64::LessEqual)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]);
else if (cond == Condition::NotLess || cond == Condition::Less)
else if (cond == ConditionX64::NotLess || cond == ConditionX64::Less)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]);
else if (cond == Condition::NotEqual || cond == Condition::Equal)
else if (cond == ConditionX64::NotEqual || cond == ConditionX64::Equal)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]);
else
LUAU_ASSERT(!"Unsupported condition");
emitUpdateBase(build);
build.test(eax, eax);
build.jcc(
cond == Condition::NotLessEqual || cond == Condition::NotLess || cond == Condition::NotEqual ? Condition::Zero : Condition::NotZero, label);
build.jcc(cond == ConditionX64::NotLessEqual || cond == ConditionX64::NotLess || cond == ConditionX64::NotEqual ? ConditionX64::Zero
: ConditionX64::NotZero,
label);
}
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos)
{
RegisterX64 node = rdx;
LUAU_ASSERT(tmp != node);
LUAU_ASSERT(table != node);
build.mov(node, qword[table + offsetof(Table, node)]);
// compute cached slot
build.mov(tmp, sCode);
build.movzx(dwordReg(tmp), byte[tmp + pcpos * sizeof(Instruction) + kOffsetOfInstructionC]);
build.and_(byteReg(tmp), byte[table + offsetof(Table, nodemask8)]);
// LuaNode* n = &h->node[slot];
build.shl(dwordReg(tmp), kLuaNodeSizeLog2);
build.add(node, tmp);
return node;
}
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label)
@ -98,21 +121,21 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi
build.vucomisd(tmp, numd); // Sets ZF=1 if equal or NaN
// We don't need non-integer values
// But to skip the PF=1 check, we proceed with NaN because 0x80000000 index is out of bounds
build.jcc(Condition::NotZero, label);
build.jcc(ConditionX64::NotZero, label);
}
void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm)
{
emitSetSavedPc(build, pcpos + 1);
if (getCurrentX64ABI() == X64ABI::Windows)
if (build.abi == ABIX64::Windows)
build.mov(sArg5, tm);
else
build.mov(rArg5, tm);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra));
build.lea(rArg3, luauRegValue(rb));
build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegAddress(rb));
build.lea(rArg4, c);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]);
@ -124,8 +147,8 @@ void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb, int pcpos)
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra));
build.lea(rArg3, luauRegValue(rb));
build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegAddress(rb));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_dolen)]);
emitUpdateBase(build);
@ -136,9 +159,9 @@ void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, i
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(limit));
build.lea(rArg3, luauRegValue(step));
build.lea(rArg4, luauRegValue(init));
build.lea(rArg2, luauRegAddress(limit));
build.lea(rArg3, luauRegAddress(step));
build.lea(rArg4, luauRegAddress(init));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]);
}
@ -147,9 +170,9 @@ void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(rb));
build.lea(rArg2, luauRegAddress(rb));
build.lea(rArg3, c);
build.lea(rArg4, luauRegValue(ra));
build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_gettable)]);
emitUpdateBase(build);
@ -160,36 +183,78 @@ void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(rb));
build.lea(rArg2, luauRegAddress(rb));
build.lea(rArg3, c);
build.lea(rArg4, luauRegValue(ra));
build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_settable)]);
emitUpdateBase(build);
}
void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip)
// works for luaC_barriertable, luaC_barrierf
static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip, int contextOffset)
{
LUAU_ASSERT(tmp != table);
LUAU_ASSERT(tmp != object);
// iscollectable(ra)
build.cmp(luauRegTag(ra), LUA_TSTRING);
build.jcc(Condition::Less, skip);
build.jcc(ConditionX64::Less, skip);
// isblack(obj2gco(h))
build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(Condition::Zero, skip);
// isblack(obj2gco(o))
build.test(byte[object + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(ConditionX64::Zero, skip);
// iswhite(gcvalue(ra))
build.mov(tmp, luauRegValue(ra));
build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT));
build.jcc(Condition::Zero, skip);
build.jcc(ConditionX64::Zero, skip);
LUAU_ASSERT(table != rArg3);
LUAU_ASSERT(object != rArg3);
build.mov(rArg3, tmp);
build.mov(rArg2, table);
build.mov(rArg2, object);
build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barriertable)]);
build.call(qword[rNativeContext + contextOffset]);
}
void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip)
{
callBarrierImpl(build, tmp, table, ra, skip, offsetof(NativeContext, luaC_barriertable));
}
void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip)
{
callBarrierImpl(build, tmp, object, ra, skip, offsetof(NativeContext, luaC_barrierf));
}
void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip)
{
// isblack(obj2gco(t))
build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(ConditionX64::Zero, skip);
// Argument setup re-ordered to avoid conflicts with table register
if (table != rArg2)
build.mov(rArg2, table);
build.lea(rArg3, addr[rArg2 + offsetof(Table, gclist)]);
build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]);
}
void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip)
{
build.mov(rax, qword[rState + offsetof(lua_State, global)]);
build.mov(rdx, qword[rax + offsetof(global_State, totalbytes)]);
build.cmp(rdx, qword[rax + offsetof(global_State, GCthreshold)]);
build.jcc(ConditionX64::Below, skip);
if (savepc)
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), 1);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]);
emitUpdateBase(build);
}
void emitExit(AssemblyBuilderX64& build, bool continueInVm)
@ -224,20 +289,20 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos)
build.mov(r8, qword[rState + offsetof(lua_State, global)]);
build.mov(r8, qword[r8 + offsetof(global_State, cb.interrupt)]);
build.test(r8, r8);
build.jcc(Condition::Zero, skip);
build.jcc(ConditionX64::Zero, skip);
emitSetSavedPc(build, pcpos + 1); // uses rax/rdx
// Call interrupt
// TODO: This code should move to the end of the function, or even be outlined so that it can be shared by multiple interruptible instructions
build.mov(rArg1, rState);
build.mov(rArg2d, -1);
build.mov(dwordReg(rArg2), -1); // function accepts 'int' here and using qword reg would've forced 8 byte constant here
build.call(r8);
// Check if we need to exit
build.mov(al, byte[rState + offsetof(lua_State, status)]);
build.test(al, al);
build.jcc(Condition::Zero, skip);
build.jcc(ConditionX64::Zero, skip);
build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.sub(qword[rax + offsetof(CallInfo, savedpc)], sizeof(Instruction));
@ -265,17 +330,6 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo
build.mov(rArg4, rConstants);
build.call(qword[rNativeContext + offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback)]);
// Some instructions may interrupt the execution
if (opinfo.flags & kFallbackCheckInterrupt)
{
Label skip;
build.test(rax, rax);
build.jcc(Condition::NotZero, skip);
emitExit(build, /* continueInVm */ false);
build.setLabel(skip);
}
emitUpdateBase(build);
// Some instructions may jump to a different instruction or a completely different function
@ -295,50 +349,17 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo
build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]);
build.jmp(qword[rax * 2 + rcx]);
}
else if (opinfo.flags & kFallbackUpdateCi)
{
// Need to update state of the current function before we jump away
build.mov(rcx, qword[rState + offsetof(lua_State, ci)]); // L->ci
build.mov(rcx, qword[rcx + offsetof(CallInfo, func)]); // L->ci->func
build.mov(rcx, qword[rcx + offsetof(TValue, value.gc)]); // L->ci->func->value.gc aka cl
build.mov(sClosure, rcx);
build.mov(rsi, qword[rcx + offsetof(Closure, l.p)]); // cl->l.p aka proto
build.mov(rConstants, qword[rsi + offsetof(Proto, k)]); // proto->k
build.mov(rcx, qword[rsi + offsetof(Proto, code)]); // proto->code
build.mov(sCode, rcx);
}
// We'll need original instruction pointer later to handle return to interpreter
if (op == LOP_CALL)
build.mov(r9, rax);
void emitContinueCallInVm(AssemblyBuilderX64& build)
{
RegisterX64 proto = rcx; // Sync with emitInstCall
// Get instruction index from instruction pointer
// To get instruction index from instruction pointer, we need to divide byte offset by 4
// But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out
build.sub(rax, sCode);
build.mov(rdx, qword[proto + offsetof(Proto, code)]);
build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.mov(qword[rax + offsetof(CallInfo, savedpc)], rdx);
// We need to check if the new function can be executed natively
Label returnToInterpreter;
build.mov(rdx, qword[rsi + offsetofProtoExecData]);
build.test(rdx, rdx);
build.jcc(Condition::Zero, returnToInterpreter);
// Get new instruction location and jump to it
build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]);
build.jmp(qword[rax * 2 + rcx]);
build.setLabel(returnToInterpreter);
// If we are returning to the interpreter to make a call, we need to update the current instruction
if (op == LOP_CALL)
{
build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.mov(qword[rax + offsetof(CallInfo, savedpc)], r9);
}
// Continue in the interpreter
emitExit(build, /* continueInVm */ true);
}
emitExit(build, /* continueInVm */ true);
}
} // namespace CodeGen

View file

@ -35,11 +35,11 @@ constexpr RegisterX64 rConstants = r12; // TValue* k
constexpr OperandX64 sClosure = qword[rbp + 0]; // Closure* cl
constexpr OperandX64 sCode = qword[rbp + 8]; // Instruction* code
// TODO: These should be replaced with a portable call function that checks the ABI at runtime and reorders moves accordingly to avoid conflicts
#if defined(_WIN32)
constexpr RegisterX64 rArg1 = rcx;
constexpr RegisterX64 rArg2 = rdx;
constexpr RegisterX64 rArg2d = edx;
constexpr RegisterX64 rArg3 = r8;
constexpr RegisterX64 rArg4 = r9;
constexpr RegisterX64 rArg5 = noreg;
@ -51,7 +51,6 @@ constexpr OperandX64 sArg6 = qword[rsp + 40];
constexpr RegisterX64 rArg1 = rdi;
constexpr RegisterX64 rArg2 = rsi;
constexpr RegisterX64 rArg2d = esi;
constexpr RegisterX64 rArg3 = rdx;
constexpr RegisterX64 rArg4 = rcx;
constexpr RegisterX64 rArg5 = r8;
@ -62,12 +61,30 @@ constexpr OperandX64 sArg6 = noreg;
#endif
constexpr unsigned kTValueSizeLog2 = 4;
constexpr unsigned kLuaNodeSizeLog2 = 5;
constexpr unsigned kLuaNodeTagMask = 0xf;
constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field
constexpr unsigned kOffsetOfInstructionC = 3;
// Leaf functions that are placed in every module to perform common instruction sequences
struct ModuleHelpers
{
Label exitContinueVm;
Label exitNoContinueVm;
Label continueCallInVm;
};
inline OperandX64 luauReg(int ri)
{
return xmmword[rBase + ri * sizeof(TValue)];
}
inline OperandX64 luauRegAddress(int ri)
{
return addr[rBase + ri * sizeof(TValue)];
}
inline OperandX64 luauRegValue(int ri)
{
return qword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)];
@ -88,11 +105,37 @@ inline OperandX64 luauConstant(int ki)
return xmmword[rConstants + ki * sizeof(TValue)];
}
inline OperandX64 luauConstantAddress(int ki)
{
return addr[rConstants + ki * sizeof(TValue)];
}
inline OperandX64 luauConstantTag(int ki)
{
return dword[rConstants + ki * sizeof(TValue) + offsetof(TValue, tt)];
}
inline OperandX64 luauConstantValue(int ki)
{
return qword[rConstants + ki * sizeof(TValue) + offsetof(TValue, value)];
}
inline OperandX64 luauNodeKeyValue(RegisterX64 node)
{
return qword[node + offsetof(LuaNode, key) + offsetof(TKey, value)];
}
// Note: tag has dirty upper bits
inline OperandX64 luauNodeKeyTag(RegisterX64 node)
{
return dword[node + offsetof(LuaNode, key) + kOffsetOfLuaNodeTag];
}
inline OperandX64 luauNodeValue(RegisterX64 node)
{
return xmmword[node + offsetof(LuaNode, val)];
}
inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, OperandX64 op)
{
LUAU_ASSERT(op.cat == CategoryX64::mem);
@ -101,16 +144,24 @@ inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, Opera
build.vmovups(luauReg(ri), tmp);
}
inline void setNodeValue(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 op, int ri)
{
LUAU_ASSERT(op.cat == CategoryX64::mem);
build.vmovups(tmp, luauReg(ri));
build.vmovups(op, tmp);
}
inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{
build.cmp(luauRegTag(ri), tag);
build.jcc(Condition::Equal, label);
build.jcc(ConditionX64::Equal, label);
}
inline void jumpIfTagIsNot(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{
build.cmp(luauRegTag(ri), tag);
build.jcc(Condition::NotEqual, label);
build.jcc(ConditionX64::NotEqual, label);
}
// Note: fallthrough label should be placed after this condition
@ -120,7 +171,7 @@ inline void jumpIfFalsy(AssemblyBuilderX64& build, int ri, Label& target, Label&
jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, fallthrough); // true if not nil or boolean
build.cmp(luauRegValueBoolean(ri), 0);
build.jcc(Condition::Equal, target); // true if boolean value is 'true'
build.jcc(ConditionX64::Equal, target); // true if boolean value is 'true'
}
// Note: fallthrough label should be placed after this condition
@ -130,13 +181,13 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label
jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, target); // true if not nil or boolean
build.cmp(luauRegValueBoolean(ri), 0);
build.jcc(Condition::NotEqual, target); // true if boolean value is 'true'
build.jcc(ConditionX64::NotEqual, target); // true if boolean value is 'true'
}
inline void jumpIfMetatablePresent(AssemblyBuilderX64& build, RegisterX64 table, Label& target)
{
build.cmp(qword[table + offsetof(Table, metatable)], 0);
build.jcc(Condition::NotEqual, target);
build.jcc(ConditionX64::NotEqual, target);
}
inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& label)
@ -144,18 +195,46 @@ inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& l
build.mov(tmp, sClosure);
build.mov(tmp, qword[tmp + offsetof(Closure, env)]);
build.test(byte[tmp + offsetof(Table, safeenv)], 1);
build.jcc(Condition::Zero, label); // Not a safe environment
build.jcc(ConditionX64::Zero, label); // Not a safe environment
}
inline void jumpIfTableIsReadOnly(AssemblyBuilderX64& build, RegisterX64 table, Label& label)
{
build.cmp(byte[table + offsetof(Table, readonly)], 0);
build.jcc(Condition::NotEqual, label);
build.jcc(ConditionX64::NotEqual, label);
}
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label);
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos);
inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, lua_Type tag, Label& label)
{
tmp.size = SizeX64::dword;
build.mov(tmp, luauNodeKeyTag(node));
build.and_(tmp, kLuaNodeTagMask);
build.cmp(tmp, tag);
build.jcc(ConditionX64::NotEqual, label);
}
inline void jumpIfNodeValueTagIs(AssemblyBuilderX64& build, RegisterX64 node, lua_Type tag, Label& label)
{
build.cmp(dword[node + offsetof(LuaNode, val) + offsetof(TValue, tt)], tag);
build.jcc(ConditionX64::Equal, label);
}
inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, Label& label)
{
jumpIfNodeKeyTagIsNot(build, tmp, node, LUA_TSTRING, label);
build.mov(tmp, expectedKey);
build.cmp(tmp, luauNodeKeyValue(node));
build.jcc(ConditionX64::NotEqual, label);
jumpIfNodeValueTagIs(build, node, LUA_TNIL, label);
}
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label);
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos);
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos);
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label);
void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm);
@ -164,6 +243,9 @@ void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, i
void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos);
void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int pcpos);
void callBarrierTable(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int ra, Label& skip);
void callBarrierObject(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip);
void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& skip);
void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip);
void emitExit(AssemblyBuilderX64& build, bool continueInVm);
void emitUpdateBase(AssemblyBuilderX64& build);
@ -171,5 +253,8 @@ void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos); // Note: only uses ra
void emitInterrupt(AssemblyBuilderX64& build, int pcpos);
void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos);
void emitContinueCallInVm(AssemblyBuilderX64& build);
void emitExitFromLastReturn(AssemblyBuilderX64& build);
} // namespace CodeGen
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -3,6 +3,8 @@
#include <stdint.h>
#include "ltm.h"
typedef uint32_t Instruction;
typedef struct lua_TValue TValue;
@ -12,55 +14,76 @@ namespace CodeGen
{
class AssemblyBuilderX64;
enum class Condition;
enum class ConditionX64 : uint8_t;
struct Label;
struct NativeState;
struct ModuleHelpers;
void emitInstLoadNil(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc);
void emitInstLoadB(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstLoadN(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc);
void emitInstLoadK(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k);
void emitInstMove(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc);
void emitInstJump(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpBack(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIf(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfEq(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfCond(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr, Condition cond);
void emitInstJumpX(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqNil(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstJumpxEqB(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstJumpxEqN(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstJumpxEqS(AssemblyBuilderX64& build, NativeState& data, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstAdd(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSub(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstMul(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstDiv(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstMod(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstPow(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstAddK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstSubK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstMulK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstDivK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstModK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos);
void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback);
void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback);
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond);
void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback);
void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback);
void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label& fallback);
void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstGetImportFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
} // namespace CodeGen
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -10,84 +10,15 @@ typedef uint32_t Instruction;
typedef struct lua_TValue TValue;
typedef TValue* StkId;
const Instruction* execute_LOP_NOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADNIL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MOVE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETUPVAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETUPVAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CLOSEUPVALS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETIMPORT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLEN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLEN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_RETURN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIF(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFEQ(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTEQ(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFLT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTLT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ADD(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SUB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MUL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DIV(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MOD(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_POW(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ADDK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SUBK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MULK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DIVK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MODK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_POWK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_AND(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_OR(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ANDK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ORK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CONCAT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NOT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MINUS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LENGTH(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NEWTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DUPTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORNPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORNLOOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGLOOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP_INEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_FORGLOOP_INEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP_NEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_FORGLOOP_NEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPBACK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADKX(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPX(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CAPTURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_JUMPIFEQK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_JUMPIFNOTEQK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL1(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL2(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL2K(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKNIL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);

View file

@ -3,46 +3,20 @@
#include "Luau/UnwindBuilder.h"
#include "CodeGenUtils.h"
#include "CustomExecUtils.h"
#include "Fallbacks.h"
#include "lbuiltins.h"
#include "lgc.h"
#include "ltable.h"
#include "lfunc.h"
#include "lvm.h"
#include <math.h>
#include <string.h>
#define CODEGEN_SET_FALLBACK(op, flags) data.context.fallback[op] = {execute_##op, flags}
#define CODEGEN_SET_NAME(op) data.names[op] = #op
// Similar to a dispatch table in lvmexecute.cpp
#define CODEGEN_SET_NAMES() \
CODEGEN_SET_NAME(LOP_NOP), CODEGEN_SET_NAME(LOP_BREAK), CODEGEN_SET_NAME(LOP_LOADNIL), CODEGEN_SET_NAME(LOP_LOADB), CODEGEN_SET_NAME(LOP_LOADN), \
CODEGEN_SET_NAME(LOP_LOADK), CODEGEN_SET_NAME(LOP_MOVE), CODEGEN_SET_NAME(LOP_GETGLOBAL), CODEGEN_SET_NAME(LOP_SETGLOBAL), \
CODEGEN_SET_NAME(LOP_GETUPVAL), CODEGEN_SET_NAME(LOP_SETUPVAL), CODEGEN_SET_NAME(LOP_CLOSEUPVALS), CODEGEN_SET_NAME(LOP_GETIMPORT), \
CODEGEN_SET_NAME(LOP_GETTABLE), CODEGEN_SET_NAME(LOP_SETTABLE), CODEGEN_SET_NAME(LOP_GETTABLEKS), CODEGEN_SET_NAME(LOP_SETTABLEKS), \
CODEGEN_SET_NAME(LOP_GETTABLEN), CODEGEN_SET_NAME(LOP_SETTABLEN), CODEGEN_SET_NAME(LOP_NEWCLOSURE), CODEGEN_SET_NAME(LOP_NAMECALL), \
CODEGEN_SET_NAME(LOP_CALL), CODEGEN_SET_NAME(LOP_RETURN), CODEGEN_SET_NAME(LOP_JUMP), CODEGEN_SET_NAME(LOP_JUMPBACK), \
CODEGEN_SET_NAME(LOP_JUMPIF), CODEGEN_SET_NAME(LOP_JUMPIFNOT), CODEGEN_SET_NAME(LOP_JUMPIFEQ), CODEGEN_SET_NAME(LOP_JUMPIFLE), \
CODEGEN_SET_NAME(LOP_JUMPIFLT), CODEGEN_SET_NAME(LOP_JUMPIFNOTEQ), CODEGEN_SET_NAME(LOP_JUMPIFNOTLE), CODEGEN_SET_NAME(LOP_JUMPIFNOTLT), \
CODEGEN_SET_NAME(LOP_ADD), CODEGEN_SET_NAME(LOP_SUB), CODEGEN_SET_NAME(LOP_MUL), CODEGEN_SET_NAME(LOP_DIV), CODEGEN_SET_NAME(LOP_MOD), \
CODEGEN_SET_NAME(LOP_POW), CODEGEN_SET_NAME(LOP_ADDK), CODEGEN_SET_NAME(LOP_SUBK), CODEGEN_SET_NAME(LOP_MULK), CODEGEN_SET_NAME(LOP_DIVK), \
CODEGEN_SET_NAME(LOP_MODK), CODEGEN_SET_NAME(LOP_POWK), CODEGEN_SET_NAME(LOP_AND), CODEGEN_SET_NAME(LOP_OR), CODEGEN_SET_NAME(LOP_ANDK), \
CODEGEN_SET_NAME(LOP_ORK), CODEGEN_SET_NAME(LOP_CONCAT), CODEGEN_SET_NAME(LOP_NOT), CODEGEN_SET_NAME(LOP_MINUS), \
CODEGEN_SET_NAME(LOP_LENGTH), CODEGEN_SET_NAME(LOP_NEWTABLE), CODEGEN_SET_NAME(LOP_DUPTABLE), CODEGEN_SET_NAME(LOP_SETLIST), \
CODEGEN_SET_NAME(LOP_FORNPREP), CODEGEN_SET_NAME(LOP_FORNLOOP), CODEGEN_SET_NAME(LOP_FORGLOOP), CODEGEN_SET_NAME(LOP_FORGPREP_INEXT), \
CODEGEN_SET_NAME(LOP_DEP_FORGLOOP_INEXT), CODEGEN_SET_NAME(LOP_FORGPREP_NEXT), CODEGEN_SET_NAME(LOP_DEP_FORGLOOP_NEXT), \
CODEGEN_SET_NAME(LOP_GETVARARGS), CODEGEN_SET_NAME(LOP_DUPCLOSURE), CODEGEN_SET_NAME(LOP_PREPVARARGS), CODEGEN_SET_NAME(LOP_LOADKX), \
CODEGEN_SET_NAME(LOP_JUMPX), CODEGEN_SET_NAME(LOP_FASTCALL), CODEGEN_SET_NAME(LOP_COVERAGE), CODEGEN_SET_NAME(LOP_CAPTURE), \
CODEGEN_SET_NAME(LOP_DEP_JUMPIFEQK), CODEGEN_SET_NAME(LOP_DEP_JUMPIFNOTEQK), CODEGEN_SET_NAME(LOP_FASTCALL1), \
CODEGEN_SET_NAME(LOP_FASTCALL2), CODEGEN_SET_NAME(LOP_FASTCALL2K), CODEGEN_SET_NAME(LOP_FORGPREP), CODEGEN_SET_NAME(LOP_JUMPXEQKNIL), \
CODEGEN_SET_NAME(LOP_JUMPXEQKB), CODEGEN_SET_NAME(LOP_JUMPXEQKN), CODEGEN_SET_NAME(LOP_JUMPXEQKS)
static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
return -1;
}
namespace Luau
{
@ -61,41 +35,27 @@ NativeState::~NativeState() = default;
void initFallbackTable(NativeState& data)
{
CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0);
CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0);
CODEGEN_SET_FALLBACK(LOP_SETUPVAL, 0);
CODEGEN_SET_FALLBACK(LOP_CLOSEUPVALS, 0);
CODEGEN_SET_FALLBACK(LOP_GETIMPORT, 0);
CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0);
CODEGEN_SET_FALLBACK(LOP_SETTABLEKS, 0);
CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, kFallbackUpdatePc);
// When fallback is completely removed, remove it from includeInsts list in lvmexecute_split.py
CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0);
CODEGEN_SET_FALLBACK(LOP_CALL, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_RETURN, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_CONCAT, 0);
CODEGEN_SET_FALLBACK(LOP_NEWTABLE, 0);
CODEGEN_SET_FALLBACK(LOP_DUPTABLE, 0);
CODEGEN_SET_FALLBACK(LOP_SETLIST, kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_FORGLOOP, kFallbackUpdatePc | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP_INEXT, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_FORGPREP_NEXT, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0);
CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_PREPVARARGS, 0);
CODEGEN_SET_FALLBACK(LOP_LOADKX, 0);
CODEGEN_SET_FALLBACK(LOP_COVERAGE, 0);
CODEGEN_SET_FALLBACK(LOP_BREAK, 0);
// Fallbacks that are called from partial implementation of an instruction
CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0);
CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0);
CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0);
CODEGEN_SET_FALLBACK(LOP_SETTABLEKS, 0);
}
void initHelperFunctions(NativeState& data)
{
static_assert(sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0]) == sizeof(luauF_table) / sizeof(luauF_table[0]),
"fast call tables are not of the same length");
// Replace missing fast call functions with an empty placeholder that forces LOP_CALL fallback
for (size_t i = 0; i < sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0]); i++)
data.context.luauF_table[i] = luauF_table[i] ? luauF_table[i] : luauF_missing;
static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length");
memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table));
data.context.luaV_lessthan = luaV_lessthan;
data.context.luaV_lessequal = luaV_lessequal;
@ -105,17 +65,28 @@ void initHelperFunctions(NativeState& data)
data.context.luaV_prepareFORN = luaV_prepareFORN;
data.context.luaV_gettable = luaV_gettable;
data.context.luaV_settable = luaV_settable;
data.context.luaV_getimport = luaV_getimport;
data.context.luaV_concat = luaV_concat;
data.context.luaH_getn = luaH_getn;
data.context.luaH_new = luaH_new;
data.context.luaH_clone = luaH_clone;
data.context.luaH_resizearray = luaH_resizearray;
data.context.luaC_barriertable = luaC_barriertable;
data.context.luaC_barrierf = luaC_barrierf;
data.context.luaC_barrierback = luaC_barrierback;
data.context.luaC_step = luaC_step;
data.context.luaF_close = luaF_close;
data.context.libm_pow = pow;
}
void initInstructionNames(NativeState& data)
{
CODEGEN_SET_NAMES();
data.context.forgLoopNodeIter = forgLoopNodeIter;
data.context.forgLoopNonTableFallback = forgLoopNonTableFallback;
data.context.forgPrepXnextFallback = forgPrepXnextFallback;
data.context.callProlog = callProlog;
data.context.callEpilogC = callEpilogC;
}
} // namespace CodeGen

View file

@ -3,13 +3,16 @@
#include "Luau/Bytecode.h"
#include "Luau/CodeAllocator.h"
#include "Luau/Label.h"
#include <memory>
#include <stdint.h>
#include "ldebug.h"
#include "lobject.h"
#include "ltm.h"
#include "lstate.h"
typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams);
@ -23,8 +26,6 @@ class UnwindBuilder;
using FallbackFn = const Instruction*(lua_State* L, const Instruction* pc, StkId base, TValue* k);
constexpr uint8_t kFallbackUpdatePc = 1 << 0;
constexpr uint8_t kFallbackUpdateCi = 1 << 1;
constexpr uint8_t kFallbackCheckInterrupt = 1 << 2;
struct NativeFallback
{
@ -34,7 +35,8 @@ struct NativeFallback
struct NativeProto
{
uintptr_t* instTargets = nullptr;
uintptr_t entryTarget = 0;
uintptr_t* instTargets = nullptr; // TODO: NativeProto should be variable-size with all target embedded
Proto* proto = nullptr;
uint32_t location = 0;
@ -61,12 +63,29 @@ struct NativeContext
void (*luaV_prepareFORN)(lua_State* L, StkId plimit, StkId pstep, StkId pinit) = nullptr;
void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr;
void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr;
void (*luaV_getimport)(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) = nullptr;
void (*luaV_concat)(lua_State* L, int total, int last) = nullptr;
int (*luaH_getn)(Table* t) = nullptr;
Table* (*luaH_new)(lua_State* L, int narray, int lnhash) = nullptr;
Table* (*luaH_clone)(lua_State* L, Table* tt) = nullptr;
void (*luaH_resizearray)(lua_State* L, Table* t, int nasize) = nullptr;
void (*luaC_barriertable)(lua_State* L, Table* t, GCObject* v) = nullptr;
void (*luaC_barrierf)(lua_State* L, GCObject* o, GCObject* v) = nullptr;
void (*luaC_barrierback)(lua_State* L, GCObject* o, GCObject** gclist) = nullptr;
size_t (*luaC_step)(lua_State* L, bool assist) = nullptr;
void (*luaF_close)(lua_State* L, StkId level) = nullptr;
double (*libm_pow)(double, double) = nullptr;
// Helper functions
bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr;
bool (*forgLoopNonTableFallback)(lua_State* L, int insnA, int aux) = nullptr;
void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr;
Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr;
void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr;
};
struct NativeState
@ -77,9 +96,6 @@ struct NativeState
CodeAllocator codeAllocator;
std::unique_ptr<UnwindBuilder> unwindBuilder;
// For annotations in assembly text generation
const char* names[LOP__COUNT] = {};
uint8_t* gateData = nullptr;
size_t gateDataSize = 0;
@ -88,7 +104,6 @@ struct NativeState
void initFallbackTable(NativeState& data);
void initHelperFunctions(NativeState& data);
void initInstructionNames(NativeState& data);
} // namespace CodeGen
} // namespace Luau

View file

@ -20,6 +20,10 @@
#define LUAU_DEBUGBREAK() __builtin_trap()
#endif
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define LUAU_BIG_ENDIAN
#endif
namespace Luau
{

View file

@ -117,6 +117,8 @@ public:
std::string dumpEverything() const;
std::string dumpSourceRemarks() const;
void annotateInstruction(std::string& result, uint32_t fid, uint32_t instpos) const;
static uint32_t getImportId(int32_t id0);
static uint32_t getImportId(int32_t id0, int32_t id1);
static uint32_t getImportId(int32_t id0, int32_t id1, int32_t id2);
@ -179,6 +181,7 @@ private:
std::string dump;
std::string dumpname;
std::vector<int> dumpinstoffs;
};
struct DebugLocal
@ -251,11 +254,13 @@ private:
std::vector<std::string> dumpSource;
std::vector<std::pair<int, std::string>> dumpRemarks;
std::string (BytecodeBuilder::*dumpFunctionPtr)() const = nullptr;
std::string (BytecodeBuilder::*dumpFunctionPtr)(std::vector<int>&) const = nullptr;
void validate() const;
void validateInstructions() const;
void validateVariadic() const;
std::string dumpCurrentFunction() const;
std::string dumpCurrentFunction(std::vector<int>& dumpinstoffs) const;
void dumpInstruction(const uint32_t* opcode, std::string& output, int targetLabel) const;
void writeFunction(std::string& ss, uint32_t id) const;

View file

@ -4,8 +4,6 @@
#include "Luau/Bytecode.h"
#include "Luau/Compiler.h"
LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinMT, false)
namespace Luau
{
namespace Compile
@ -66,13 +64,10 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op
if (builtin.isGlobal("select"))
return LBF_SELECT_VARARG;
if (FFlag::LuauCompileBuiltinMT)
{
if (builtin.isGlobal("getmetatable"))
return LBF_GETMETATABLE;
if (builtin.isGlobal("setmetatable"))
return LBF_SETMETATABLE;
}
if (builtin.isGlobal("getmetatable"))
return LBF_GETMETATABLE;
if (builtin.isGlobal("setmetatable"))
return LBF_SETMETATABLE;
if (builtin.object == "math")
{

View file

@ -269,7 +269,7 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues)
// this call is indirect to make sure we only gain link time dependency on dumpCurrentFunction when needed
if (dumpFunctionPtr)
func.dump = (this->*dumpFunctionPtr)();
func.dump = (this->*dumpFunctionPtr)(func.dumpinstoffs);
insns.clear();
lines.clear();
@ -1078,6 +1078,12 @@ uint8_t BytecodeBuilder::getVersion()
#ifdef LUAU_ASSERTENABLED
void BytecodeBuilder::validate() const
{
validateInstructions();
validateVariadic();
}
void BytecodeBuilder::validateInstructions() const
{
#define VREG(v) LUAU_ASSERT(unsigned(v) < func.maxstacksize)
#define VREGRANGE(v, count) LUAU_ASSERT(unsigned(v + (count < 0 ? 0 : count)) <= func.maxstacksize)
@ -1090,26 +1096,27 @@ void BytecodeBuilder::validate() const
const Function& func = functions[currentFunction];
// first pass: tag instruction offsets so that we can validate jumps
std::vector<uint8_t> insnvalid(insns.size(), false);
// tag instruction offsets so that we can validate jumps
std::vector<uint8_t> insnvalid(insns.size(), 0);
for (size_t i = 0; i < insns.size();)
{
uint8_t op = LUAU_INSN_OP(insns[i]);
uint32_t insn = insns[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
insnvalid[i] = true;
i += getOpLength(LuauOpcode(op));
i += getOpLength(op);
LUAU_ASSERT(i <= insns.size());
}
std::vector<uint8_t> openCaptures;
// second pass: validate the rest of the bytecode
// validate individual instructions
for (size_t i = 0; i < insns.size();)
{
uint32_t insn = insns[i];
uint8_t op = LUAU_INSN_OP(insn);
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
switch (op)
{
@ -1452,7 +1459,7 @@ void BytecodeBuilder::validate() const
LUAU_ASSERT(!"Unsupported opcode");
}
i += getOpLength(LuauOpcode(op));
i += getOpLength(op);
LUAU_ASSERT(i <= insns.size());
}
@ -1469,6 +1476,126 @@ void BytecodeBuilder::validate() const
#undef VCONSTANY
#undef VJUMP
}
void BytecodeBuilder::validateVariadic() const
{
// validate MULTRET sequences: instructions that produce a variadic sequence and consume one must come in pairs
// we classify instructions into four groups: producers, consumers, neutral and others
// any producer (an instruction that produces more than one value) must be followed by 0 or more neutral instructions
// and a consumer (that consumes more than one value); these form a variadic sequence.
// except for producer, no instruction in the variadic sequence may be a jump target.
// from the execution perspective, producer adjusts L->top to point to one past the last result, neutral instructions
// leave L->top unmodified, and consumer adjusts L->top back to the stack frame end.
// consumers invalidate all values after L->top after they execute (which we currently don't validate)
bool variadicSeq = false;
std::vector<uint8_t> insntargets(insns.size(), 0);
for (size_t i = 0; i < insns.size();)
{
uint32_t insn = insns[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
int target = getJumpTarget(insn, uint32_t(i));
if (target >= 0 && !isFastCall(op))
{
LUAU_ASSERT(unsigned(target) < insns.size());
insntargets[target] = true;
}
i += getOpLength(op);
LUAU_ASSERT(i <= insns.size());
}
for (size_t i = 0; i < insns.size();)
{
uint32_t insn = insns[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
if (variadicSeq)
{
// no instruction inside the sequence, including the consumer, may be a jump target
// this guarantees uninterrupted L->top adjustment flow
LUAU_ASSERT(!insntargets[i]);
}
if (op == LOP_CALL)
{
// note: calls may end one variadic sequence and start a new one
if (LUAU_INSN_B(insn) == 0)
{
// consumer instruction ens a variadic sequence
LUAU_ASSERT(variadicSeq);
variadicSeq = false;
}
else
{
// CALL is not a neutral instruction so it can't be present in a variadic sequence unless it's a consumer
LUAU_ASSERT(!variadicSeq);
}
if (LUAU_INSN_C(insn) == 0)
{
// producer instruction starts a variadic sequence
LUAU_ASSERT(!variadicSeq);
variadicSeq = true;
}
}
else if (op == LOP_GETVARARGS && LUAU_INSN_B(insn) == 0)
{
// producer instruction starts a variadic sequence
LUAU_ASSERT(!variadicSeq);
variadicSeq = true;
}
else if ((op == LOP_RETURN && LUAU_INSN_B(insn) == 0) || (op == LOP_SETLIST && LUAU_INSN_C(insn) == 0))
{
// consumer instruction ends a variadic sequence
LUAU_ASSERT(variadicSeq);
variadicSeq = false;
}
else if (op == LOP_FASTCALL)
{
int callTarget = int(i + LUAU_INSN_C(insn) + 1);
LUAU_ASSERT(unsigned(callTarget) < insns.size() && LUAU_INSN_OP(insns[callTarget]) == LOP_CALL);
if (LUAU_INSN_B(insns[callTarget]) == 0)
{
// consumer instruction ends a variadic sequence; however, we can't terminate it yet because future analysis of CALL will do it
// during FASTCALL fallback, the instructions between this and CALL consumer are going to be executed before L->top so they must
// be neutral; as such, we will defer termination of variadic sequence until CALL analysis
LUAU_ASSERT(variadicSeq);
}
else
{
// FASTCALL is not a neutral instruction so it can't be present in a variadic sequence unless it's linked to CALL consumer
LUAU_ASSERT(!variadicSeq);
}
// note: if FASTCALL is linked to a CALL producer, the instructions between FASTCALL and CALL are technically not part of an executed
// variadic sequence since they are never executed if FASTCALL does anything, so it's okay to skip their validation until CALL
// (we can't simply start a variadic sequence here because that would trigger assertions during linked CALL validation)
}
else if (op == LOP_CLOSEUPVALS || op == LOP_NAMECALL || op == LOP_GETIMPORT || op == LOP_MOVE || op == LOP_GETUPVAL || op == LOP_GETGLOBAL ||
op == LOP_GETTABLEKS || op == LOP_COVERAGE)
{
// instructions inside a variadic sequence must be neutral (can't change L->top)
// while there are many neutral instructions like this, here we check that the instruction is one of the few
// that we'd expect to exist in FASTCALL fallback sequences or between consecutive CALLs for encoding reasons
}
else
{
LUAU_ASSERT(!variadicSeq);
}
i += getOpLength(op);
LUAU_ASSERT(i <= insns.size());
}
LUAU_ASSERT(!variadicSeq);
}
#endif
void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, int targetLabel) const
@ -1800,7 +1927,7 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result,
}
}
std::string BytecodeBuilder::dumpCurrentFunction() const
std::string BytecodeBuilder::dumpCurrentFunction(std::vector<int>& dumpinstoffs) const
{
if ((dumpFlags & Dump_Code) == 0)
return std::string();
@ -1850,11 +1977,15 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
if (labels[i] == 0)
labels[i] = nextLabel++;
dumpinstoffs.resize(insns.size() + 1, -1);
for (size_t i = 0; i < insns.size();)
{
const uint32_t* code = &insns[i];
uint8_t op = LUAU_INSN_OP(*code);
dumpinstoffs[i] = int(result.size());
if (op == LOP_PREPVARARGS)
{
// Don't emit function header in bytecode - it's used for call dispatching and doesn't contain "interesting" information
@ -1897,6 +2028,8 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
LUAU_ASSERT(i <= insns.size());
}
dumpinstoffs[insns.size()] = int(result.size());
return result;
}
@ -1986,4 +2119,26 @@ std::string BytecodeBuilder::dumpSourceRemarks() const
return result;
}
void BytecodeBuilder::annotateInstruction(std::string& result, uint32_t fid, uint32_t instpos) const
{
if ((dumpFlags & Dump_Code) == 0)
return;
LUAU_ASSERT(fid < functions.size());
const Function& function = functions[fid];
const std::string& dump = function.dump;
const std::vector<int>& dumpinstoffs = function.dumpinstoffs;
uint32_t next = instpos + 1;
LUAU_ASSERT(next < dumpinstoffs.size());
// Skip locations of multi-dword instructions
while (next < dumpinstoffs.size() && dumpinstoffs[next] == -1)
next++;
formatAppend(result, "%.*s", dumpinstoffs[next] - dumpinstoffs[instpos], dump.data() + dumpinstoffs[instpos]);
}
} // namespace Luau

View file

@ -75,7 +75,7 @@ endif
# configuration-specific flags
ifeq ($(config),release)
CXXFLAGS+=-O2 -DNDEBUG
CXXFLAGS+=-O2 -DNDEBUG -fno-math-errno
endif
ifeq ($(config),coverage)
@ -102,7 +102,7 @@ ifeq ($(config),fuzz)
endif
ifeq ($(config),profile)
CXXFLAGS+=-O2 -DNDEBUG -gdwarf-4 -DCALLGRIND=1
CXXFLAGS+=-O2 -DNDEBUG -fno-math-errno -gdwarf-4 -DCALLGRIND=1
endif
ifeq ($(protobuf),download)

View file

@ -55,22 +55,28 @@ target_sources(Luau.Compiler PRIVATE
# Luau.CodeGen Sources
target_sources(Luau.CodeGen PRIVATE
CodeGen/include/Luau/AddressA64.h
CodeGen/include/Luau/AssemblyBuilderA64.h
CodeGen/include/Luau/AssemblyBuilderX64.h
CodeGen/include/Luau/CodeAllocator.h
CodeGen/include/Luau/CodeBlockUnwind.h
CodeGen/include/Luau/CodeGen.h
CodeGen/include/Luau/Condition.h
CodeGen/include/Luau/ConditionA64.h
CodeGen/include/Luau/ConditionX64.h
CodeGen/include/Luau/Label.h
CodeGen/include/Luau/OperandX64.h
CodeGen/include/Luau/RegisterA64.h
CodeGen/include/Luau/RegisterX64.h
CodeGen/include/Luau/UnwindBuilder.h
CodeGen/include/Luau/UnwindBuilderDwarf2.h
CodeGen/include/Luau/UnwindBuilderWin.h
CodeGen/src/AssemblyBuilderA64.cpp
CodeGen/src/AssemblyBuilderX64.cpp
CodeGen/src/CodeAllocator.cpp
CodeGen/src/CodeBlockUnwind.cpp
CodeGen/src/CodeGen.cpp
CodeGen/src/CodeGenUtils.cpp
CodeGen/src/CodeGenX64.cpp
CodeGen/src/EmitBuiltinsX64.cpp
CodeGen/src/EmitCommonX64.cpp
@ -82,6 +88,7 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/src/ByteUtils.h
CodeGen/src/CustomExecUtils.h
CodeGen/src/CodeGenUtils.h
CodeGen/src/CodeGenX64.h
CodeGen/src/EmitBuiltinsX64.h
CodeGen/src/EmitCommonX64.h
@ -101,10 +108,13 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/BuiltinDefinitions.h
Analysis/include/Luau/Clone.h
Analysis/include/Luau/Config.h
Analysis/include/Luau/Connective.h
Analysis/include/Luau/Constraint.h
Analysis/include/Luau/ConstraintGraphBuilder.h
Analysis/include/Luau/ConstraintSolver.h
Analysis/include/Luau/DataFlowGraphBuilder.h
Analysis/include/Luau/DcrLogger.h
Analysis/include/Luau/Def.h
Analysis/include/Luau/Documentation.h
Analysis/include/Luau/Error.h
Analysis/include/Luau/FileResolver.h
@ -114,6 +124,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/JsonEmitter.h
Analysis/include/Luau/Linter.h
Analysis/include/Luau/LValue.h
Analysis/include/Luau/Metamethods.h
Analysis/include/Luau/Module.h
Analysis/include/Luau/ModuleResolver.h
Analysis/include/Luau/Normalize.h
@ -151,10 +162,13 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/BuiltinDefinitions.cpp
Analysis/src/Clone.cpp
Analysis/src/Config.cpp
Analysis/src/Connective.cpp
Analysis/src/Constraint.cpp
Analysis/src/ConstraintGraphBuilder.cpp
Analysis/src/ConstraintSolver.cpp
Analysis/src/DataFlowGraphBuilder.cpp
Analysis/src/DcrLogger.cpp
Analysis/src/Def.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp
Analysis/src/Frontend.cpp
@ -292,6 +306,7 @@ if(TARGET Luau.UnitTest)
tests/AstQueryDsl.cpp
tests/ConstraintGraphBuilderFixture.cpp
tests/Fixture.cpp
tests/AssemblyBuilderA64.test.cpp
tests/AssemblyBuilderX64.test.cpp
tests/AstJsonEncoder.test.cpp
tests/AstQuery.test.cpp
@ -301,9 +316,9 @@ if(TARGET Luau.UnitTest)
tests/CodeAllocator.test.cpp
tests/Compiler.test.cpp
tests/Config.test.cpp
tests/ConstraintGraphBuilder.test.cpp
tests/ConstraintSolver.test.cpp
tests/CostModel.test.cpp
tests/DataFlowGraphBuilder.test.cpp
tests/Error.test.cpp
tests/Frontend.test.cpp
tests/JsonEmitter.test.cpp
@ -334,6 +349,7 @@ if(TARGET Luau.UnitTest)
tests/TypeInfer.intersectionTypes.test.cpp
tests/TypeInfer.loops.test.cpp
tests/TypeInfer.modules.test.cpp
tests/TypeInfer.negations.test.cpp
tests/TypeInfer.oop.test.cpp
tests/TypeInfer.operators.test.cpp
tests/TypeInfer.primitives.test.cpp

View file

@ -1239,7 +1239,12 @@ static int luauF_setmetatable(lua_State* L, StkId res, TValue* arg0, int nresult
return -1;
}
luau_FastFunction luauF_table[256] = {
static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
return -1;
}
const luau_FastFunction luauF_table[256] = {
NULL,
luauF_assert,
@ -1317,4 +1322,20 @@ luau_FastFunction luauF_table[256] = {
luauF_getmetatable,
luauF_setmetatable,
// When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback.
// This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing.
// Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest.
#define MISSING8 luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing
MISSING8,
MISSING8,
MISSING8,
MISSING8,
MISSING8,
MISSING8,
MISSING8,
MISSING8,
#undef MISSING8
};

View file

@ -6,4 +6,4 @@
typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams);
extern luau_FastFunction luauF_table[256];
extern const luau_FastFunction luauF_table[256];

View file

@ -12,8 +12,6 @@
#include <string.h>
#include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauFasterGetInfo, false)
static const char* getfuncname(Closure* f);
static int currentpc(lua_State* L, CallInfo* ci)
@ -105,8 +103,7 @@ static Closure* auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closur
ar->source = "=[C]";
ar->what = "C";
ar->linedefined = -1;
if (FFlag::LuauFasterGetInfo)
ar->short_src = "[C]";
ar->short_src = "[C]";
}
else
{
@ -114,13 +111,7 @@ static Closure* auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closur
ar->source = getstr(source);
ar->what = "Lua";
ar->linedefined = f->l.p->linedefined;
if (FFlag::LuauFasterGetInfo)
ar->short_src = luaO_chunkid(ar->ssbuf, sizeof(ar->ssbuf), getstr(source), source->len);
}
if (!FFlag::LuauFasterGetInfo)
{
luaO_chunkid(ar->ssbuf, LUA_IDSIZE, ar->source, 0);
ar->short_src = ar->ssbuf;
ar->short_src = luaO_chunkid(ar->ssbuf, sizeof(ar->ssbuf), getstr(source), source->len);
}
break;
}
@ -195,25 +186,12 @@ int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar)
}
if (f)
{
if (FFlag::LuauFasterGetInfo)
// auxgetinfo fills ar and optionally requests to put closure on stack
if (Closure* fcl = auxgetinfo(L, what, ar, f, ci))
{
// auxgetinfo fills ar and optionally requests to put closure on stack
if (Closure* fcl = auxgetinfo(L, what, ar, f, ci))
{
luaC_threadbarrier(L);
setclvalue(L, L->top, fcl);
incr_top(L);
}
}
else
{
auxgetinfo(L, what, ar, f, ci);
if (strchr(what, 'f'))
{
luaC_threadbarrier(L);
setclvalue(L, L->top, f);
incr_top(L);
}
luaC_threadbarrier(L);
setclvalue(L, L->top, fcl);
incr_top(L);
}
}
return f ? 1 : 0;

View file

@ -13,8 +13,6 @@
#include <stdio.h>
#include <stdlib.h>
LUAU_FASTFLAG(LuauFasterGetInfo)
const TValue luaO_nilobject_ = {{NULL}, {0}, LUA_TNIL};
int luaO_log2(unsigned int x)
@ -121,48 +119,20 @@ const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t sr
{
if (*source == '=')
{
if (FFlag::LuauFasterGetInfo)
{
if (srclen <= buflen)
return source + 1;
// truncate the part after =
memcpy(buf, source + 1, buflen - 1);
buf[buflen - 1] = '\0';
}
else
{
source++; // skip the `='
size_t len = strlen(source);
size_t dstlen = len < buflen ? len : buflen - 1;
memcpy(buf, source, dstlen);
buf[dstlen] = '\0';
}
if (srclen <= buflen)
return source + 1;
// truncate the part after =
memcpy(buf, source + 1, buflen - 1);
buf[buflen - 1] = '\0';
}
else if (*source == '@')
{
if (FFlag::LuauFasterGetInfo)
{
if (srclen <= buflen)
return source + 1;
// truncate the part after @
memcpy(buf, "...", 3);
memcpy(buf + 3, source + srclen - (buflen - 4), buflen - 4);
buf[buflen - 1] = '\0';
}
else
{
size_t l;
source++; // skip the `@'
buflen -= sizeof("...");
l = strlen(source);
strcpy(buf, "");
if (l > buflen)
{
source += (l - buflen); // get last part of file name
strcat(buf, "...");
}
strcat(buf, source);
}
if (srclen <= buflen)
return source + 1;
// truncate the part after @
memcpy(buf, "...", 3);
memcpy(buf + 3, source + srclen - (buflen - 4), buflen - 4);
buf[buflen - 1] = '\0';
}
else
{ // buf = [string "string"]

View file

@ -16,8 +16,6 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauNoTopRestoreInFastCall, false)
// Disable c99-designator to avoid the warning in CGOTO dispatch table
#ifdef __clang__
#if __has_warning("-Wc99-designator")
@ -758,6 +756,8 @@ reentry:
Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)];
LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep));
VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM
// note: we save closure to stack early in case the code below wants to capture it by value
Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv);
setclvalue(L, ra, ncl);
@ -2056,6 +2056,8 @@ reentry:
int b = LUAU_INSN_B(insn);
uint32_t aux = *pc++;
VM_PROTECT_PC(); // luaH_new may fail due to OOM
sethvalue(L, ra, luaH_new(L, aux, b == 0 ? 0 : (1 << (b - 1))));
VM_PROTECT(luaC_checkGC(L));
VM_NEXT();
@ -2067,6 +2069,8 @@ reentry:
StkId ra = VM_REG(LUAU_INSN_A(insn));
TValue* kv = VM_KV(LUAU_INSN_D(insn));
VM_PROTECT_PC(); // luaH_clone may fail due to OOM
sethvalue(L, ra, luaH_clone(L, hvalue(kv)));
VM_PROTECT(luaC_checkGC(L));
VM_NEXT();
@ -2088,12 +2092,17 @@ reentry:
Table* h = hvalue(ra);
// TODO: we really don't need this anymore
if (!ttistable(ra))
return; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode
int last = index + c - 1;
if (last > h->sizearray)
{
VM_PROTECT_PC(); // luaH_resizearray may fail due to OOM
luaH_resizearray(L, h, last);
}
TValue* array = h->array;
@ -2185,7 +2194,8 @@ reentry:
// protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP
if (ttisnil(ra))
{
VM_PROTECT(luaG_typeerror(L, ra, "call"));
VM_PROTECT_PC(); // next call always errors
luaG_typeerror(L, ra, "call");
}
}
else if (fasttm(L, mt, TM_CALL))
@ -2202,7 +2212,8 @@ reentry:
}
else
{
VM_PROTECT(luaG_typeerror(L, ra, "iterate over"));
VM_PROTECT_PC(); // next call always errors
luaG_typeerror(L, ra, "iterate over");
}
}
@ -2325,7 +2336,8 @@ reentry:
}
else if (!ttisfunction(ra))
{
VM_PROTECT(luaG_typeerror(L, ra, "iterate over"));
VM_PROTECT_PC(); // next call always errors
luaG_typeerror(L, ra, "iterate over");
}
pc += LUAU_INSN_D(insn);
@ -2353,7 +2365,8 @@ reentry:
}
else if (!ttisfunction(ra))
{
VM_PROTECT(luaG_typeerror(L, ra, "iterate over"));
VM_PROTECT_PC(); // next call always errors
luaG_typeerror(L, ra, "iterate over");
}
pc += LUAU_INSN_D(insn);
@ -2404,6 +2417,8 @@ reentry:
Closure* kcl = clvalue(kv);
VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM
// clone closure if the environment is not shared
// note: we save closure to stack early in case the code below wants to capture it by value
Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p);
@ -2530,10 +2545,11 @@ reentry:
nparams = (nparams == LUA_MULTRET) ? int(L->top - ra - 1) : nparams;
luau_FastFunction f = luauF_table[bfid];
LUAU_ASSERT(f);
if (cl->env->safeenv && f)
if (cl->env->safeenv)
{
VM_PROTECT_PC();
VM_PROTECT_PC(); // f may fail due to OOM
int n = f(L, ra, ra + 1, nresults, ra + 2, nparams);
@ -2608,24 +2624,18 @@ reentry:
int nresults = LUAU_INSN_C(call) - 1;
luau_FastFunction f = luauF_table[bfid];
LUAU_ASSERT(f);
if (cl->env->safeenv && f)
if (cl->env->safeenv)
{
VM_PROTECT_PC();
VM_PROTECT_PC(); // f may fail due to OOM
int n = f(L, ra, arg, nresults, NULL, nparams);
if (n >= 0)
{
if (FFlag::LuauNoTopRestoreInFastCall)
{
if (nresults == LUA_MULTRET)
L->top = ra + n;
}
else
{
L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top;
}
if (nresults == LUA_MULTRET)
L->top = ra + n;
pc += skip + 1; // skip instructions that compute function as well as CALL
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
@ -2664,24 +2674,18 @@ reentry:
int nresults = LUAU_INSN_C(call) - 1;
luau_FastFunction f = luauF_table[bfid];
LUAU_ASSERT(f);
if (cl->env->safeenv && f)
if (cl->env->safeenv)
{
VM_PROTECT_PC();
VM_PROTECT_PC(); // f may fail due to OOM
int n = f(L, ra, arg1, nresults, arg2, nparams);
if (n >= 0)
{
if (FFlag::LuauNoTopRestoreInFastCall)
{
if (nresults == LUA_MULTRET)
L->top = ra + n;
}
else
{
L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top;
}
if (nresults == LUA_MULTRET)
L->top = ra + n;
pc += skip + 1; // skip instructions that compute function as well as CALL
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
@ -2720,24 +2724,18 @@ reentry:
int nresults = LUAU_INSN_C(call) - 1;
luau_FastFunction f = luauF_table[bfid];
LUAU_ASSERT(f);
if (cl->env->safeenv && f)
if (cl->env->safeenv)
{
VM_PROTECT_PC();
VM_PROTECT_PC(); // f may fail due to OOM
int n = f(L, ra, arg1, nresults, arg2, nparams);
if (n >= 0)
{
if (FFlag::LuauNoTopRestoreInFastCall)
{
if (nresults == LUA_MULTRET)
L->top = ra + n;
}
else
{
L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top;
}
if (nresults == LUA_MULTRET)
L->top = ra + n;
pc += skip + 1; // skip instructions that compute function as well as CALL
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));

View file

@ -1,4 +1,4 @@
--!non-strict
--!nonstrict
local bench = script and require(script.Parent.bench_support) or require("bench_support")
local stretchTreeDepth = 18 -- about 16Mb

456
bench/tests/voxelgen.lua Normal file
View file

@ -0,0 +1,456 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
-- Based on voxel terrain generator by Stickmasterluke
local kSelectedBiomes = {
['Mountains'] = true,
['Canyons'] = true,
['Dunes'] = true,
['Arctic'] = true,
['Lavaflow'] = true,
['Hills'] = true,
['Plains'] = true,
['Marsh'] = true,
['Water'] = true,
}
---------Directly used in Generation---------
local masterSeed = 618033988
local mapWidth = 32
local mapHeight = 32
local biomeSize = 16
local generateCaves = true
local waterLevel = .48
local surfaceThickness = .018
local biomes = {}
---------------------------------------------
local rock = "Rock"
local snow = "Snow"
local ice = "Glacier"
local grass = "Grass"
local ground = "Ground"
local mud = "Mud"
local slate = "Slate"
local concrete = "Concrete"
local lava = "CrackedLava"
local basalt = "Basalt"
local air = "Air"
local sand = "Sand"
local sandstone = "Sandstone"
local water = "Water"
math.randomseed(6180339)
local theseed={}
for i=1,999 do
table.insert(theseed,math.random())
end
local function getPerlin(x,y,z,seed,scale,raw)
local seed = seed or 0
local scale = scale or 1
if not raw then
return math.noise(x/scale+(seed*17)+masterSeed,y/scale-masterSeed,z/scale-seed*seed)*.5 + .5 -- accounts for bleeding from interpolated line
else
return math.noise(x/scale+(seed*17)+masterSeed,y/scale-masterSeed,z/scale-seed*seed)
end
end
local function getNoise(x,y,z,seed1)
local x = x or 0
local y = y or 0
local z = z or 0
local seed1 = seed1 or 7
local wtf=x+y+z+seed1+masterSeed + (masterSeed-x)*(seed1+z) + (seed1-y)*(masterSeed+z) -- + x*(y+z) + z*(masterSeed+seed1) + seed1*(x+y) --x+y+z+seed1+masterSeed + x*y*masterSeed-y*z+(z+masterSeed)*x --((x+y)*(y-seed1)*seed1)-(x+z)*seed2+x*11+z*23-y*17
return theseed[(math.floor(wtf%(#theseed)))+1]
end
local function thresholdFilter(value, bottom, size)
if value <= bottom then
return 0
elseif value >= bottom+size then
return 1
else
return (value-bottom)/size
end
end
local function ridgedFilter(value) --absolute and flip for ridges. and normalize
return value<.5 and value*2 or 2-value*2
end
local function ridgedFlippedFilter(value) --unflipped
return value < .5 and 1-value*2 or value*2-1
end
local function advancedRidgedFilter(value, cutoff)
local cutoff = cutoff or .5
value = value - cutoff
return 1 - (value < 0 and -value or value) * 1/(1-cutoff)
end
local function fractalize(operation,x,y,z, operationCount, scale, offset, gain)
local operationCount = operationCount or 3
local scale = scale or .5
local offset = 0
local gain = gain or 1
local totalValue = 0
local totalScale = 0
for i=1, operationCount do
local thisScale = scale^(i-1)
totalScale = totalScale + thisScale
totalValue = totalValue + (offset + gain * operation(x,y,z,i))*thisScale
end
return totalValue/totalScale
end
local function mountainsOperation(x,y,z,i)
return ridgedFilter(getPerlin(x,y,z,100+i,(1/i)*160))
end
local canyonBandingMaterial = {rock,mud,sand,sand,sandstone,sandstone,sandstone,sandstone,sandstone,sandstone,}
local function findBiomeInfo(choiceBiome,x,y,z,verticalGradientTurbulence)
local choiceBiomeValue = .5
local choiceBiomeSurface = grass
local choiceBiomeFill = rock
if choiceBiome == 'City' then
choiceBiomeValue = .55
choiceBiomeSurface = concrete
choiceBiomeFill = slate
elseif choiceBiome == 'Water' then
choiceBiomeValue = .36+getPerlin(x,y,z,2,50)*.08
choiceBiomeSurface =
(1-verticalGradientTurbulence < .44 and slate)
or sand
elseif choiceBiome == 'Marsh' then
local preLedge = getPerlin(x+getPerlin(x,0,z,5,7,true)*10+getPerlin(x,0,z,6,30,true)*50,0,z+getPerlin(x,0,z,9,7,true)*10+getPerlin(x,0,z,10,30,true)*50,2,70) --could use some turbulence
local grassyLedge = thresholdFilter(preLedge,.65,0)
local largeGradient = getPerlin(x,y,z,4,100)
local smallGradient = getPerlin(x,y,z,3,20)
local smallGradientThreshold = thresholdFilter(smallGradient,.5,0)
choiceBiomeValue = waterLevel-.04
+preLedge*grassyLedge*.025
+largeGradient*.035
+smallGradient*.025
choiceBiomeSurface =
(grassyLedge >= 1 and grass)
or (1-verticalGradientTurbulence < waterLevel-.01 and mud)
or (1-verticalGradientTurbulence < waterLevel+.01 and ground)
or grass
choiceBiomeFill = slate
elseif choiceBiome == 'Plains' then
local rivulet = ridgedFlippedFilter(getPerlin(x+getPerlin(x,y,z,17,40)*25,0,z+getPerlin(x,y,z,19,40)*25,2,200))
local rivuletThreshold = thresholdFilter(rivulet,.01,0)
local rockMap = thresholdFilter(ridgedFlippedFilter(getPerlin(x,0,z,101,7)),.3,.7) --rocks
* thresholdFilter(getPerlin(x,0,z,102,50),.6,.05) --zoning
choiceBiomeValue = .5 --.51
+getPerlin(x,y,z,2,100)*.02 --.05
+rivulet*.05 --.02
+rockMap*.05 --.03
+rivuletThreshold*.005
local verticalGradient = 1-((y-1)/(mapHeight-1))
local surfaceGradient = verticalGradient*.5 + choiceBiomeValue*.5
local thinSurface = surfaceGradient > .5-surfaceThickness*.4 and surfaceGradient < .5+surfaceThickness*.4
choiceBiomeSurface =
(rockMap>0 and rock)
or (not thinSurface and mud)
or (thinSurface and rivuletThreshold <=0 and water)
or (1-verticalGradientTurbulence < waterLevel-.01 and sand)
or grass
choiceBiomeFill =
(rockMap>0 and rock)
or sandstone
elseif choiceBiome == 'Canyons' then
local canyonNoise = ridgedFlippedFilter(getPerlin(x,0,z,2,200))
local canyonNoiseTurbed = ridgedFlippedFilter(getPerlin(x+getPerlin(x,0,z,5,20,true)*20,0,z+getPerlin(x,0,z,9,20,true)*20,2,200))
local sandbank = thresholdFilter(canyonNoiseTurbed,0,.05)
local canyonTop = thresholdFilter(canyonNoiseTurbed,.125,0)
local mesaSlope = thresholdFilter(canyonNoise,.33,.12)
local mesaTop = thresholdFilter(canyonNoiseTurbed,.49,0)
choiceBiomeValue = .42
+getPerlin(x,y,z,2,70)*.05
+canyonNoise*.05
+sandbank*.04 --canyon bottom slope
+thresholdFilter(canyonNoiseTurbed,.05,0)*.08 --canyon cliff
+thresholdFilter(canyonNoiseTurbed,.05,.075)*.04 --canyon cliff top slope
+canyonTop*.01 --canyon cliff top ledge
+thresholdFilter(canyonNoiseTurbed,.0575,.2725)*.01 --plane slope
+mesaSlope*.06 --mesa slope
+thresholdFilter(canyonNoiseTurbed,.45,0)*.14 --mesa cliff
+thresholdFilter(canyonNoiseTurbed,.45,.04)*.025 --mesa cap
+mesaTop*.02 --mesa top ledge
choiceBiomeSurface =
(1-verticalGradientTurbulence < waterLevel+.015 and sand) --this for biome blending in to lakes
or (sandbank>0 and sandbank<1 and sand) --this for canyonbase sandbanks
--or (canyonTop>0 and canyonTop<=1 and mesaSlope<=0 and grass) --this for grassy canyon tops
--or (mesaTop>0 and mesaTop<=1 and grass) --this for grassy mesa tops
or sandstone
choiceBiomeFill = canyonBandingMaterial[math.ceil((1-getNoise(1,y,2))*10)]
elseif choiceBiome == 'Hills' then
local rivulet = ridgedFlippedFilter(getPerlin(x+getPerlin(x,y,z,17,20)*20,0,z+getPerlin(x,y,z,19,20)*20,2,200))^(1/2)
local largeHills = getPerlin(x,y,z,3,60)
choiceBiomeValue = .48
+largeHills*.05
+(.05
+largeHills*.1
+getPerlin(x,y,z,4,25)*.125)
*rivulet
local surfaceMaterialGradient = (1-verticalGradientTurbulence)*.9 + rivulet*.1
choiceBiomeSurface =
(surfaceMaterialGradient < waterLevel-.015 and mud)
or (surfaceMaterialGradient < waterLevel and ground)
or grass
choiceBiomeFill = slate
elseif choiceBiome == 'Dunes' then
local duneTurbulence = getPerlin(x,0,z,227,20)*24
local layer1 = ridgedFilter(getPerlin(x,0,z,201,40))
local layer2 = ridgedFilter(getPerlin(x/10+duneTurbulence,0,z+duneTurbulence,200,48))
choiceBiomeValue = .4+.1*(layer1 + layer2)
choiceBiomeSurface = sand
choiceBiomeFill = sandstone
elseif choiceBiome == 'Mountains' then
local rivulet = ridgedFlippedFilter(getPerlin(x+getPerlin(x,y,z,17,20)*20,0,z+getPerlin(x,y,z,19,20)*20,2,200))
choiceBiomeValue = -.4 --.3
+fractalize(mountainsOperation,x,y/20,z, 8, .65)*1.2
+rivulet*.2
choiceBiomeSurface =
(verticalGradientTurbulence < .275 and snow)
or (verticalGradientTurbulence < .35 and rock)
or (verticalGradientTurbulence < .4 and ground)
or (1-verticalGradientTurbulence < waterLevel and rock)
or (1-verticalGradientTurbulence < waterLevel+.01 and mud)
or (1-verticalGradientTurbulence < waterLevel+.015 and ground)
or grass
elseif choiceBiome == 'Lavaflow' then
local crackX = x+getPerlin(x,y*.25,z,21,8,true)*5
local crackY = y+getPerlin(x,y*.25,z,22,8,true)*5
local crackZ = z+getPerlin(x,y*.25,z,23,8,true)*5
local crack1 = ridgedFilter(getPerlin(crackX+getPerlin(x,y,z,22,30,true)*30,crackY,crackZ+getPerlin(x,y,z,24,30,true)*30,2,120))
local crack2 = ridgedFilter(getPerlin(crackX,crackY,crackZ,3,40))*(crack1*.25+.75)
local crack3 = ridgedFilter(getPerlin(crackX,crackY,crackZ,4,20))*(crack2*.25+.75)
local generalHills = thresholdFilter(getPerlin(x,y,z,9,40),.25,.5)*getPerlin(x,y,z,10,60)
local cracks = math.max(0,1-thresholdFilter(crack1,.975,0)-thresholdFilter(crack2,.925,0)-thresholdFilter(crack3,.9,0))
local spires = thresholdFilter(getPerlin(crackX/40,crackY/300,crackZ/30,123,1),.6,.4)
choiceBiomeValue = waterLevel+.02
+cracks*(.5+generalHills*.5)*.02
+generalHills*.05
+spires*.3
+((1-verticalGradientTurbulence > waterLevel+.01 or spires>0) and .04 or 0) --This lets it lip over water
choiceBiomeFill = (spires>0 and rock) or (cracks<1 and lava) or basalt
choiceBiomeSurface = (choiceBiomeFill == lava and 1-verticalGradientTurbulence < waterLevel and basalt) or choiceBiomeFill
elseif choiceBiome == 'Arctic' then
local preBoundary = getPerlin(x+getPerlin(x,0,z,5,8,true)*5,y/8,z+getPerlin(x,0,z,9,8,true)*5,2,20)
--local cliffs = thresholdFilter(preBoundary,.5,0)
local boundary = ridgedFilter(preBoundary)
local roughChunks = getPerlin(x,y/4,z,436,2)
local boundaryMask = thresholdFilter(boundary,.8,.1) --,.7,.25)
local boundaryTypeMask = getPerlin(x,0,z,6,74)-.5
local boundaryComp = 0
if boundaryTypeMask < 0 then --divergent
boundaryComp = (boundary > (1+boundaryTypeMask*.5) and -.17 or 0)
--* boundaryTypeMask*-2
else --convergent
boundaryComp = boundaryMask*.1*roughChunks
* boundaryTypeMask
end
choiceBiomeValue = .55
+boundary*.05*boundaryTypeMask --.1 --soft slope up or down to boundary
+boundaryComp --convergent/divergent effects
+getPerlin(x,0,z,123,25)*.025 --*cliffs --gentle rolling slopes
choiceBiomeSurface = (1-verticalGradientTurbulence < waterLevel-.1 and ice) or (boundaryMask>.6 and boundaryTypeMask>.1 and roughChunks>.5 and ice) or snow
choiceBiomeFill = ice
end
return choiceBiomeValue, choiceBiomeSurface, choiceBiomeFill
end
function findBiomeTransitionValue(biome,weight,value,averageValue)
if biome == 'Arctic' then
return (weight>.2 and 1 or 0)*value
elseif biome == 'Canyons' then
return (weight>.7 and 1 or 0)*value
elseif biome == 'Mountains' then
local weight = weight^3 --This improves the ease of mountains transitioning to other biomes
return averageValue*(1-weight)+value*weight
else
return averageValue*(1-weight)+value*weight
end
end
function generate()
local mapWidth = mapWidth
local biomeSize = biomeSize
local biomeBlendPercent = .25 --(biomeSize==50 or biomeSize == 100) and .5 or .25
local biomeBlendPercentInverse = 1-biomeBlendPercent
local biomeBlendDistortion = biomeBlendPercent
local smoothScale = .5/mapHeight
biomes = {}
for i,v in pairs(kSelectedBiomes) do
if v then
table.insert(biomes,i)
end
end
if #biomes<=0 then
table.insert(biomes,'Hills')
end
table.sort(biomes)
--local oMap = {}
--local mMap = {}
for x = 1, mapWidth do
local oMapX = {}
--oMap[x] = oMapX
local mMapX = {}
--mMap[x] = mMapX
for z = 1, mapWidth do
local biomeNoCave = false
local cellToBiomeX = x/biomeSize + getPerlin(x,0,z,233,biomeSize*.3)*.25 + getPerlin(x,0,z,235,biomeSize*.05)*.075
local cellToBiomeZ = z/biomeSize + getPerlin(x,0,z,234,biomeSize*.3)*.25 + getPerlin(x,0,z,236,biomeSize*.05)*.075
local closestDistance = 1000000
local biomePoints = {}
for vx=-1,1 do
for vz=-1,1 do
local gridPointX = math.floor(cellToBiomeX+vx+.5)
local gridPointZ = math.floor(cellToBiomeZ+vz+.5)
--local pointX, pointZ = getBiomePoint(gridPointX,gridPointZ)
local pointX = gridPointX+(getNoise(gridPointX,gridPointZ,53)-.5)*.75 --de-uniforming grid for vornonoi
local pointZ = gridPointZ+(getNoise(gridPointX,gridPointZ,73)-.5)*.75
local dist = math.sqrt((pointX-cellToBiomeX)^2 + (pointZ-cellToBiomeZ)^2)
if dist < closestDistance then
closestDistance = dist
end
table.insert(biomePoints,{
x = pointX,
z = pointZ,
dist = dist,
biomeNoise = getNoise(gridPointX,gridPointZ),
weight = 0
})
end
end
local weightTotal = 0
local weightPoints = {}
for _,point in pairs(biomePoints) do
local weight = point.dist == closestDistance and 1 or ((closestDistance / point.dist)-biomeBlendPercentInverse)/biomeBlendPercent
if weight > 0 then
local weight = weight^2.1 --this smooths the biome transition from linear to cubic InOut
weightTotal = weightTotal + weight
local biome = biomes[math.ceil(#biomes*(1-point.biomeNoise))] --inverting the noise so that it is limited as (0,1]. One less addition operation when finding a random list index
weightPoints[biome] = {
weight = weightPoints[biome] and weightPoints[biome].weight + weight or weight
}
end
end
for biome,info in pairs(weightPoints) do
info.weight = info.weight / weightTotal
if biome == 'Arctic' then --biomes that don't have caves that breach the surface
biomeNoCave = true
end
end
for y = 1, mapHeight do
local oMapY = oMapX[y] or {}
oMapX[y] = oMapY
local mMapY = mMapX[y] or {}
mMapX[y] = mMapY
--[[local oMapY = {}
oMapX[y] = oMapY
local mMapY = {}
mMapX[z] = mMapY]]
local verticalGradient = 1-((y-1)/(mapHeight-1))
local caves = 0
local verticalGradientTurbulence = verticalGradient*.9 + .1*getPerlin(x,y,z,107,15)
local choiceValue = 0
local choiceSurface = lava
local choiceFill = rock
if verticalGradient > .65 or verticalGradient < .1 then
--under surface of every biome; don't get biome data; waste of time.
choiceValue = .5
elseif #biomes == 1 then
choiceValue, choiceSurface, choiceFill = findBiomeInfo(biomes[1],x,y,z,verticalGradientTurbulence)
else
local averageValue = 0
--local findChoiceMaterial = -getNoise(x,y,z,19)
for biome,info in pairs(weightPoints) do
local biomeValue, biomeSurface, biomeFill = findBiomeInfo(biome,x,y,z,verticalGradientTurbulence)
info.biomeValue = biomeValue
info.biomeSurface = biomeSurface
info.biomeFill = biomeFill
local value = biomeValue * info.weight
averageValue = averageValue + value
--[[if findChoiceMaterial < 0 and findChoiceMaterial + weight >= 0 then
choiceMaterial = biomeMaterial
end
findChoiceMaterial = findChoiceMaterial + weight]]
end
for biome,info in pairs(weightPoints) do
local value = findBiomeTransitionValue(biome,info.weight,info.biomeValue,averageValue)
if value > choiceValue then
choiceValue = value
choiceSurface = info.biomeSurface
choiceFill = info.biomeFill
end
end
end
local preCaveComp = verticalGradient*.5 + choiceValue*.5
local surface = preCaveComp > .5-surfaceThickness and preCaveComp < .5+surfaceThickness
if generateCaves --user wants caves
and (not biomeNoCave or verticalGradient > .65) --biome allows caves or deep enough
and not (surface and (1-verticalGradient) < waterLevel+.005) --caves only breach surface above waterlevel
and not (surface and (1-verticalGradient) > waterLevel+.58) then --caves don't go too high so that they don't cut up mountain tops
local ridged2 = ridgedFilter(getPerlin(x,y,z,4,30))
local caves2 = thresholdFilter(ridged2,.84,.01)
local ridged3 = ridgedFilter(getPerlin(x,y,z,5,30))
local caves3 = thresholdFilter(ridged3,.84,.01)
local ridged4 = ridgedFilter(getPerlin(x,y,z,6,30))
local caves4 = thresholdFilter(ridged4,.84,.01)
local caveOpenings = (surface and 1 or 0) * thresholdFilter(getPerlin(x,0,z,143,62),.35,0) --.45
caves = caves2 * caves3 * caves4 - caveOpenings
caves = caves < 0 and 0 or caves > 1 and 1 or caves
end
local comp = preCaveComp - caves
local smoothedResult = thresholdFilter(comp,.5,smoothScale)
---below water level -above surface -no terrain
if 1-verticalGradient < waterLevel and preCaveComp <= .5 and smoothedResult <= 0 then
smoothedResult = 1
choiceSurface = water
choiceFill = water
surface = true
end
oMapY[z] = (y == 1 and 1) or smoothedResult
mMapY[z] = (y == 1 and lava) or (smoothedResult <= 0 and air) or (surface and choiceSurface) or choiceFill
end
end
-- local regionStart = Vector3.new(mapWidth*-2+(x-1)*4,mapHeight*-2,mapWidth*-2)
-- local regionEnd = Vector3.new(mapWidth*-2+x*4,mapHeight*2,mapWidth*2)
-- local mapRegion = Region3.new(regionStart, regionEnd)
-- terrain:WriteVoxels(mapRegion, 4, {mMapX}, {oMapX})
end
end
bench.runCode(generate, "voxelgen")

View file

@ -66,8 +66,8 @@ Syntactic subtyping is a syntax-directed recursive algorithm. The interesting ca
* Reflexivity: `T` is a subtype of `T`
* Intersection L: `(T₁ & … & Tⱼ)` is a subtype of `U` whenever some of the `Tᵢ` are subtypes of `U`
* Union L: `(T₁ | … | Tⱼ)` is a subtype of `U` whenever all of the `Tᵢ` are subtypes of `U`
* Intersection R: `T` is a subtype of `(U₁ & … & Uⱼ)` whenever `T` is a subtype of some of the `Uᵢ`
* Union R: `T` is a subtype of `(U₁ | … | Uⱼ)` whenever `T` is a subtype of all of the `Uᵢ`.
* Intersection R: `T` is a subtype of `(U₁ & … & Uⱼ)` whenever `T` is a subtype of all of the `Uᵢ`
* Union R: `T` is a subtype of `(U₁ | … | Uⱼ)` whenever `T` is a subtype of some of the `Uᵢ`.
For example:
@ -262,6 +262,10 @@ Semantic subtyping has removed one source of false positives, but we still have
The quest to remove spurious red squiggles continues!
## Acknowledgments
Thanks to Giuseppe Castagna and Ben Greenman for helpful comments on drafts of this post.
## Further reading
If you want to find out more about Luau and semantic subtyping, you might want to check out…
@ -274,8 +278,9 @@ If you want to find out more about Luau and semantic subtyping, you might want t
* Giuseppe Castagna, *Covariance and Contravariance*, Logical Methods in Computer Science 16(1), 2022. <https://arxiv.org/abs/1809.01427>
* Giuseppe Castagna and Alain Frisch, *A gentle introduction to semantic subtyping*, Proc. Principles and practice of declarative programming (PPDP), pp 198208, 2005. <https://doi.org/10.1145/1069774.1069793>
* Giuseppe Castagna, Mickaël Laurent, Kim Nguyễn, Matthew Lutze, *On Type-Cases, Union Elimination, and Occurrence Typing*, Principles of Programming Languages (POPL), 2022. <https://doi.org/10.1145/3498674>
* Sam Tobin-Hochstadt and Matthias Felleisen, *Logical types for untyped languages*. International Conference on Functional Programming (ICFP), 2010. https://doi.org/10.1145/1863543.1863561
* José Valim, *My Future with Elixir: set-theoretic types*, 2022. https://elixir-lang.org/blog/2022/10/05/my-future-with-elixir-set-theoretic-types/
* Giuseppe Castagna, *Programming with union, intersection, and negation types*, 2022. <https://arxiv.org/abs/2111.03354>
* Sam Tobin-Hochstadt and Matthias Felleisen, *Logical types for untyped languages*. International Conference on Functional Programming (ICFP), 2010. <https://doi.org/10.1145/1863543.1863561>
* José Valim, *My Future with Elixir: set-theoretic types*, 2022. <https://elixir-lang.org/blog/2022/10/05/my-future-with-elixir-set-theoretic-types/>
Some other languages which support semantic subtyping…

Some files were not shown because too many files have changed in this diff Show more