Merge branch 'master' into petrih-lightuserdatatag

Petri Häkkinen 2023-12-08 09:55:59 +02:00
commit 29d59ad40c
328 changed files with 10534 additions and 2802 deletions

.github/codecov.yml vendored (+3)

@ -2,3 +2,6 @@ comment: false
coverage:
status:
patch: false
project:
default:
informational: true

@ -20,7 +20,7 @@ jobs:
unix:
strategy:
matrix:
os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}]
os: [{name: ubuntu, version: ubuntu-latest}, {name: macos, version: macos-latest}]
name: ${{matrix.os.name}}
runs-on: ${{matrix.os.version}}
steps:
@ -83,7 +83,7 @@ jobs:
Debug/luau-compile tests/conformance/assert.lua
coverage:
runs-on: ubuntu-20.04
runs-on: ubuntu-20.04 # needed for clang++-10 to avoid gcov compatibility issues
steps:
- uses: actions/checkout@v2
- name: install
@ -99,7 +99,7 @@ jobs:
token: ${{ secrets.CODECOV_TOKEN }}
web:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2

@ -12,7 +12,7 @@ permissions:
jobs:
create-release:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
outputs:
upload_url: ${{ steps.create_release.outputs.upload_url }}
steps:
@ -29,7 +29,7 @@ jobs:
build:
needs: ["create-release"]
strategy:
matrix:
matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility
os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
name: ${{matrix.os.name}}
runs-on: ${{matrix.os.version}}
@ -56,7 +56,7 @@ jobs:
web:
needs: ["create-release"]
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2

@ -13,7 +13,7 @@ on:
jobs:
build:
strategy:
matrix:
matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility
os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
name: ${{matrix.os.name}}
runs-on: ${{matrix.os.version}}
@ -35,7 +35,7 @@ jobs:
path: Release\luau*.exe
web:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2

@ -49,8 +49,8 @@ struct InstantiationConstraint
TypeId superType;
};
// iteratee is iterable
// iterators is the iteration types.
// variables ~ iterate iterator
// Unpack the iterator, figure out what types it iterates over, and bind those types to variables.
struct IterableConstraint
{
TypePackId iterator;
@ -190,6 +190,11 @@ struct UnpackConstraint
{
TypePackId resultPack;
TypePackId sourcePack;
// UnpackConstraint is sometimes used to resolve the types of assignments.
// When this is the case, any LocalTypes in resultPack can have their
// domains extended by the corresponding type from sourcePack.
bool resultIsLValue = false;
};
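For context, a minimal sketch of how a solver might consume this flag when dispatching the constraint. It assumes the LocalType, simplifyUnion, and BoundType machinery introduced elsewhere in this diff; the real logic lives in ConstraintSolver.cpp further down.

// Sketch only: extend a LocalType's domain with one assigned type.
void extendLocalDomain(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, TypeId resultTy, TypeId srcTy)
{
    if (LocalType* lt = getMutable<LocalType>(resultTy))
    {
        // The local's domain grows to the union of everything assigned to it.
        lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, srcTy).result;

        // Each dispatched assignment decrements blockCount; once it reaches
        // zero, the accumulated union becomes the local's final type.
        if (--lt->blockCount == 0)
            asMutable(resultTy)->ty.emplace<BoundType>(lt->domain);
    }
}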
// resultType ~ refine type mode discriminant

@ -57,7 +57,7 @@ struct InferencePack
}
};
struct ConstraintGraphBuilder
struct ConstraintGenerator
{
// A list of all the scopes in the module. This vector holds ownership of the
// scope pointers; the scopes themselves borrow pointers to other scopes to
@ -68,7 +68,7 @@ struct ConstraintGraphBuilder
NotNull<BuiltinTypes> builtinTypes;
const NotNull<TypeArena> arena;
// The root scope of the module we're generating constraints for.
// This is null when the CGB is initially constructed.
// This is null when the CG is initially constructed.
Scope* rootScope;
struct InferredBinding
@ -78,11 +78,13 @@ struct ConstraintGraphBuilder
TypeIds types;
};
// During constraint generation, we only populate the Scope::bindings
// property for annotated symbols. Unannotated symbols must be handled in a
// postprocessing step because we have not yet allocated the types that will
// be assigned to those unannotated symbols, so we queue them up here.
std::map<Symbol, InferredBinding> inferredBindings;
// Some locals have multiple type states. We wish for Scope::bindings to
// map each local name onto the union of every type that the local can have
// over its lifetime, so we use this map to accumulate the set of types it
// might have.
//
// See the functions recordInferredBinding and fillInInferredBindings.
DenseHashMap<Symbol, InferredBinding> inferredBindings{{}};
// Constraints that go straight to the solver.
std::vector<ConstraintPtr> constraints;
@ -116,13 +118,13 @@ struct ConstraintGraphBuilder
DcrLogger* logger;
ConstraintGraphBuilder(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver,
ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, DcrLogger* logger, NotNull<DataFlowGraph> dfg,
std::vector<RequireCycle> requireCycles);
/**
* The entry point to the ConstraintGraphBuilder. This will construct a set
* The entry point to the ConstraintGenerator. This will construct a set
* of scopes, constraints, and free types that can be solved later.
* @param block the root block to generate constraints for.
*/
@ -148,6 +150,8 @@ private:
*/
ScopePtr childScope(AstNode* node, const ScopePtr& parent);
std::optional<TypeId> lookup(Scope* scope, DefId def);
/**
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to.
@ -221,6 +225,7 @@ private:
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprLocal* local);
Inference check(const ScopePtr& scope, AstExprGlobal* global);
Inference checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, std::string index);
Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional<TypeId> expectedType, bool generalize);
@ -232,14 +237,16 @@ private:
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExpr* expr);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprLocal* local);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprGlobal* global);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprIndexName* indexName);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
TypeId updateProperty(const ScopePtr& scope, AstExpr* expr);
void updateLValueType(AstExpr* lvalue, TypeId ty);
/**
* Generate constraints to assign assignedTy to the expression expr
* @returns the type of the expression. This may or may not be assignedTy itself.
*/
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExpr* expr, TypeId assignedTy);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprLocal* local, TypeId assignedTy);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId assignedTy);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId assignedTy);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId assignedTy);
TypeId updateProperty(const ScopePtr& scope, AstExpr* expr, TypeId assignedTy);
struct FunctionSignature
{
@ -324,12 +331,16 @@ private:
/** Scan the program for global definitions.
*
* ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for
* ConstraintGenerator needs to differentiate between globals and accesses to undefined symbols. Doing this "for
* real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an
* initial scan of the AST and note what globals are defined.
*/
void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);
// Record the fact that a particular local has a particular type in at least
// one of its states.
void recordInferredBinding(AstLocal* local, TypeId ty);
void fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block);
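A plausible sketch of this pair, inferred from the declarations above. The diff elides parts of InferredBinding, so treat this as an assumption rather than the real ConstraintGenerator.cpp definitions.

void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty)
{
    // Accumulate one more observed type state for this local.
    inferredBindings[Symbol{local}].types.insert(ty);
}

void ConstraintGenerator::fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block)
{
    // Publish each local's accumulated states into Scope::bindings as a union.
    for (const auto& [symbol, ib] : inferredBindings)
    {
        std::vector<TypeId> options(ib.types.begin(), ib.types.end());
        if (options.empty())
            continue;

        TypeId ty = options.size() == 1 ? options.front() : arena->addType(UnionType{std::move(options)});
        globalScope->bindings[symbol] = Binding{ty};
    }
}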
/** Given a function type annotation, return a vector describing the expected types of the calls to the function

@ -3,14 +3,18 @@
#pragma once
#include "Luau/Constraint.h"
#include "Luau/DenseHash.h"
#include "Luau/Error.h"
#include "Luau/Location.h"
#include "Luau/Module.h"
#include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h"
#include <utility>
#include <vector>
namespace Luau
@ -74,6 +78,10 @@ struct ConstraintSolver
std::unordered_map<BlockedConstraintId, std::vector<NotNull<const Constraint>>, HashBlockedConstraintId> blocked;
// Memoized instantiations of type aliases.
DenseHashMap<InstantiationSignature, TypeId, HashInstantiationSignature> instantiatedAliases{{}};
// Breadcrumbs for where a free type's upper bound was expanded. We use
// these to provide more helpful error messages when a free type is solved
// as never unexpectedly.
DenseHashMap<TypeId, std::vector<std::pair<Location, TypeId>>> upperBoundContributors{nullptr};
// A mapping from free types to the number of unresolved constraints that mention them.
DenseHashMap<TypeId, size_t> unresolvedConstraints{{}};
@ -140,7 +148,7 @@ struct ConstraintSolver
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
TypeId subjectType, const std::string& propName, bool suppressSimplification = false);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
TypeId subjectType, const std::string& propName, bool suppressSimplification, std::unordered_set<TypeId>& seen);
TypeId subjectType, const std::string& propName, bool suppressSimplification, DenseHashSet<TypeId>& seen);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/**

@ -3,6 +3,7 @@
// Do not include LValue. It should never be used here.
#include "Luau/Ast.h"
#include "Luau/ControlFlow.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/Symbol.h"
@ -34,7 +35,7 @@ struct DataFlowGraph
DataFlowGraph& operator=(DataFlowGraph&&) = default;
DefId getDef(const AstExpr* expr) const;
// Look up for the rvalue breadcrumb for a compound assignment.
// Looks up the rvalue def for a compound assignment.
std::optional<DefId> getRValueDefForCompoundAssign(const AstExpr* expr) const;
DefId getDef(const AstLocal* local) const;
@ -64,7 +65,7 @@ private:
// Compound assignments are in a weird situation where the local being assigned to is also being used at its
// previous type implicitly in an rvalue position. This map provides the previous binding.
DenseHashMap<const AstExpr*, const Def*> compoundAssignBreadcrumbs{nullptr};
DenseHashMap<const AstExpr*, const Def*> compoundAssignDefs{nullptr};
DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr};
@ -74,11 +75,21 @@ private:
struct DfgScope
{
DfgScope* parent;
DenseHashMap<Symbol, const Def*> bindings{Symbol{}};
DenseHashMap<const Def*, std::unordered_map<std::string, const Def*>> props{nullptr};
bool isLoopScope;
using Bindings = DenseHashMap<Symbol, const Def*>;
using Props = DenseHashMap<const Def*, std::unordered_map<std::string, const Def*>>;
Bindings bindings{Symbol{}};
Props props{nullptr};
std::optional<DefId> lookup(Symbol symbol) const;
std::optional<DefId> lookup(DefId def, const std::string& key) const;
void inherit(const DfgScope* childScope);
bool canUpdateDefinition(Symbol symbol) const;
bool canUpdateDefinition(DefId def, const std::string& key) const;
};
struct DataFlowResult
@ -106,31 +117,38 @@ private:
std::vector<std::unique_ptr<DfgScope>> scopes;
DfgScope* childScope(DfgScope* scope);
DfgScope* childScope(DfgScope* scope, bool isLoopScope = false);
void visit(DfgScope* scope, AstStatBlock* b);
void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b);
void join(DfgScope* p, DfgScope* a, DfgScope* b);
void joinBindings(DfgScope::Bindings& p, const DfgScope::Bindings& a, const DfgScope::Bindings& b);
void joinProps(DfgScope::Props& p, const DfgScope::Props& a, const DfgScope::Props& b);
void visit(DfgScope* scope, AstStat* s);
void visit(DfgScope* scope, AstStatIf* i);
void visit(DfgScope* scope, AstStatWhile* w);
void visit(DfgScope* scope, AstStatRepeat* r);
void visit(DfgScope* scope, AstStatBreak* b);
void visit(DfgScope* scope, AstStatContinue* c);
void visit(DfgScope* scope, AstStatReturn* r);
void visit(DfgScope* scope, AstStatExpr* e);
void visit(DfgScope* scope, AstStatLocal* l);
void visit(DfgScope* scope, AstStatFor* f);
void visit(DfgScope* scope, AstStatForIn* f);
void visit(DfgScope* scope, AstStatAssign* a);
void visit(DfgScope* scope, AstStatCompoundAssign* c);
void visit(DfgScope* scope, AstStatFunction* f);
void visit(DfgScope* scope, AstStatLocalFunction* l);
void visit(DfgScope* scope, AstStatTypeAlias* t);
void visit(DfgScope* scope, AstStatDeclareGlobal* d);
void visit(DfgScope* scope, AstStatDeclareFunction* d);
void visit(DfgScope* scope, AstStatDeclareClass* d);
void visit(DfgScope* scope, AstStatError* error);
DefId lookup(DfgScope* scope, Symbol symbol);
DefId lookup(DfgScope* scope, DefId def, const std::string& key);
ControlFlow visit(DfgScope* scope, AstStatBlock* b);
ControlFlow visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b);
ControlFlow visit(DfgScope* scope, AstStat* s);
ControlFlow visit(DfgScope* scope, AstStatIf* i);
ControlFlow visit(DfgScope* scope, AstStatWhile* w);
ControlFlow visit(DfgScope* scope, AstStatRepeat* r);
ControlFlow visit(DfgScope* scope, AstStatBreak* b);
ControlFlow visit(DfgScope* scope, AstStatContinue* c);
ControlFlow visit(DfgScope* scope, AstStatReturn* r);
ControlFlow visit(DfgScope* scope, AstStatExpr* e);
ControlFlow visit(DfgScope* scope, AstStatLocal* l);
ControlFlow visit(DfgScope* scope, AstStatFor* f);
ControlFlow visit(DfgScope* scope, AstStatForIn* f);
ControlFlow visit(DfgScope* scope, AstStatAssign* a);
ControlFlow visit(DfgScope* scope, AstStatCompoundAssign* c);
ControlFlow visit(DfgScope* scope, AstStatFunction* f);
ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l);
ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t);
ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d);
ControlFlow visit(DfgScope* scope, AstStatError* error);
DataFlowResult visitExpr(DfgScope* scope, AstExpr* e);
DataFlowResult visitExpr(DfgScope* scope, AstExprGroup* group);

@ -79,8 +79,8 @@ struct DefArena
TypedAllocator<Def> allocator;
DefId freshCell(bool subscripted = false);
// TODO: implement once we have cases where we need to merge in definitions
// DefId phi(const std::vector<DefId>& defs);
DefId phi(DefId a, DefId b);
DefId phi(const std::vector<DefId>& defs);
};
} // namespace Luau
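A tiny usage sketch for the now-implemented phi (hypothetical caller; the real call sites are DataFlowGraphBuilder::joinBindings and joinProps later in this diff):

// Given `if cond then x = f() else x = g() end`, each branch produces its own
// def for x; the join point merges them into a single phi node.
DefId joinBranches(DefArena& arena, DefId thenDef, DefId elseDef)
{
    return arena.phi(thenDef, elseDef); // x's def after the if/else
}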

@ -322,6 +322,7 @@ struct TypePackMismatch
{
TypePackId wantedTp;
TypePackId givenTp;
std::string reason;
bool operator==(const TypePackMismatch& rhs) const;
};
@ -371,13 +372,21 @@ struct CheckedFunctionCallError
bool operator==(const CheckedFunctionCallError& rhs) const;
};
struct NonStrictFunctionDefinitionError
{
std::string functionName;
std::string argument;
TypeId argumentType;
bool operator==(const NonStrictFunctionDefinitionError& rhs) const;
};
using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods,
DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire,
IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError,
CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning,
DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty,
TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe, UninhabitedTypeFamily,
UninhabitedTypePackFamily, WhereClauseNeeded, PackWhereClauseNeeded, CheckedFunctionCallError>;
UninhabitedTypePackFamily, WhereClauseNeeded, PackWhereClauseNeeded, CheckedFunctionCallError, NonStrictFunctionDefinitionError>;
struct TypeErrorSummary
{

@ -71,7 +71,7 @@ struct SourceNode
ModuleName name;
std::string humanReadableName;
std::unordered_set<ModuleName> requireSet;
DenseHashSet<ModuleName> requireSet{{}};
std::vector<std::pair<ModuleName, Location>> requireLocations;
bool dirtySourceModule = true;
bool dirtyModule = true;
@ -206,7 +206,7 @@ private:
std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete, std::function<bool(const ModuleName&)> canSkip = {});
void addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vector<ModuleName>& buildQueue, bool cycleDetected,
std::unordered_set<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions);
DenseHashSet<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions);
void checkBuildQueueItem(BuildQueueItem& item);
void checkBuildQueueItems(std::vector<BuildQueueItem>& items);
void recordItemResult(const BuildQueueItem& item);

@ -102,6 +102,8 @@ struct Module
DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
DenseHashMap<TypeId, std::vector<std::pair<Location, TypeId>>> upperBoundContributors{nullptr};
// Map AST nodes to the scope they create. Cannot be NotNull<Scope> because
// we need a sentinel value for the map.
DenseHashMap<const AstNode*, Scope*> astScopes{nullptr};

@ -2,6 +2,7 @@
#pragma once
#include "Luau/NotNull.h"
#include "Luau/Set.h"
#include "Luau/TypeFwd.h"
#include "Luau/UnifierSharedState.h"
@ -9,7 +10,6 @@
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace Luau
@ -29,7 +29,7 @@ bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> sc
class TypeIds
{
private:
std::unordered_set<TypeId> types;
DenseHashMap<TypeId, bool> types{nullptr};
std::vector<TypeId> order;
std::size_t hash = 0;
@ -254,6 +254,10 @@ struct NormalizedType
// This type is either never or thread.
TypeId threads;
// The buffer part of the type.
// This type is either never or buffer.
TypeId buffers;
// The (meta)table part of the type.
// Each element of this set is a (meta)table type, or the top `table` type.
// An empty set denotes never.
@ -277,6 +281,7 @@ struct NormalizedType
NormalizedType& operator=(NormalizedType&&) = default;
// IsType functions
bool isUnknown() const;
/// Returns true if the type is exactly a number. Behaves like Type::isNumber()
bool isExactlyNumber() const;
@ -298,6 +303,7 @@ struct NormalizedType
bool hasNumbers() const;
bool hasStrings() const;
bool hasThreads() const;
bool hasBuffers() const;
bool hasTables() const;
bool hasFunctions() const;
bool hasTyvars() const;
@ -358,7 +364,7 @@ public:
void unionTablesWithTable(TypeIds& heres, TypeId there);
void unionTables(TypeIds& heres, const TypeIds& theres);
bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars = -1);
// ------- Negations
std::optional<NormalizedType> negateNormal(const NormalizedType& here);
@ -380,15 +386,15 @@ public:
std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress);
bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes);
bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes);
bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool intersectNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes);
bool intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes);
bool normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType);
// Check for inhabitance
bool isInhabited(TypeId ty);
bool isInhabited(TypeId ty, std::unordered_set<TypeId> seen);
bool isInhabited(const NormalizedType* norm, std::unordered_set<TypeId> seen = {});
bool isInhabited(TypeId ty, Set<TypeId> seen);
bool isInhabited(const NormalizedType* norm, Set<TypeId> seen = {nullptr});
// Check for intersections being inhabited
bool isIntersectionInhabited(TypeId left, TypeId right);

@ -56,7 +56,7 @@ struct Scope
void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun);
std::optional<TypeId> lookup(Symbol sym) const;
std::optional<TypeId> lookupLValue(DefId def) const;
std::optional<TypeId> lookupUnrefinedType(DefId def) const;
std::optional<TypeId> lookup(DefId def) const;
std::optional<std::pair<TypeId, Scope*>> lookupEx(DefId def);
std::optional<std::pair<Binding*, Scope*>> lookupEx(Symbol sym);
@ -80,6 +80,7 @@ struct Scope
// types here.
DenseHashMap<const Def*, TypeId> rvalueRefinements{nullptr};
void inheritAssignments(const ScopePtr& childScope);
void inheritRefinements(const ScopePtr& childScope);
// For mutually recursive type aliases, it's important that

Analysis/include/Luau/Set.h Normal file (+171)

@ -0,0 +1,171 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
namespace Luau
{
template<typename T>
using SetHashDefault = std::conditional_t<std::is_pointer_v<T>, DenseHashPointer, std::hash<T>>;
// This is an implementation of `unordered_set` using `DenseHashMap<T, bool>` to support erasure.
// This lets us work around `DenseHashSet` limitations and get a more traditional set interface.
template<typename T, typename Hash = SetHashDefault<T>>
class Set
{
private:
using Impl = DenseHashMap<T, bool, Hash>;
Impl mapping;
size_t entryCount = 0;
public:
class const_iterator;
using iterator = const_iterator;
Set(const T& empty_key)
: mapping{empty_key}
{
}
bool insert(const T& element)
{
bool& entry = mapping[element];
bool fresh = !entry;
if (fresh)
{
entry = true;
entryCount++;
}
return fresh;
}
template<class Iterator>
void insert(Iterator begin, Iterator end)
{
for (Iterator it = begin; it != end; ++it)
insert(*it);
}
void erase(const T& element)
{
bool& entry = mapping[element];
if (entry)
{
entry = false;
entryCount--;
}
}
void clear()
{
mapping.clear();
entryCount = 0;
}
size_t size() const
{
return entryCount;
}
bool empty() const
{
return entryCount == 0;
}
size_t count(const T& element) const
{
const bool* entry = mapping.find(element);
return (entry && *entry) ? 1 : 0;
}
bool contains(const T& element) const
{
return count(element) != 0;
}
const_iterator begin() const
{
return const_iterator(mapping.begin(), mapping.end());
}
const_iterator end() const
{
return const_iterator(mapping.end(), mapping.end());
}
bool operator==(const Set<T>& there) const
{
// if the sets are unequal sizes, then they cannot possibly be equal.
if (size() != there.size())
return false;
// otherwise, we'll need to check that every element we have here is in `there`.
for (auto [elem, present] : mapping)
{
// if it's not, we'll return `false`
if (present && !there.contains(elem))
return false;
}
// otherwise, we've proven the two equal!
return true;
}
class const_iterator
{
public:
const_iterator(typename Impl::const_iterator impl, typename Impl::const_iterator end)
: impl(impl)
, end(end)
{
// skip erased (tombstoned) entries so a fresh iterator starts on a live element
while (this->impl != this->end && this->impl->second == false)
++this->impl;
}
const T& operator*() const
{
return impl->first;
}
const T* operator->() const
{
return &impl->first;
}
bool operator==(const const_iterator& other) const
{
return impl == other.impl;
}
bool operator!=(const const_iterator& other) const
{
return impl != other.impl;
}
const_iterator& operator++()
{
do
{
impl++;
} while (impl != end && impl->second == false);
// keep iterating past pairs where the value is `false`
return *this;
}
const_iterator operator++(int)
{
const_iterator res = *this;
++*this;
return res;
}
private:
typename Impl::const_iterator impl;
typename Impl::const_iterator end;
};
};
} // namespace Luau
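A quick usage sketch for the new container (illustrative only; the -1 sentinel is an assumption for this example):

#include "Luau/Set.h"

void example()
{
    // Like DenseHashMap, Set needs an "empty key" sentinel that is never
    // stored as a real element.
    Luau::Set<int> seen{-1};

    seen.insert(1);
    seen.insert(2);
    bool fresh = seen.insert(1); // false: already present; size() stays 2
    (void)fresh;

    seen.erase(2); // tombstoned in the underlying map; size() is now 1

    for (int v : seen) // iteration skips erased entries
        (void)v;
}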

@ -2,11 +2,10 @@
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
#include <set>
namespace Luau
{
@ -16,7 +15,7 @@ struct SimplifyResult
{
TypeId result;
std::set<TypeId> blockedTypes;
DenseHashSet<TypeId> blockedTypes;
};
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);

@ -1,10 +1,11 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Set.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypePairHash.h"
#include "Luau/UnifierSharedState.h"
#include "Luau/TypePath.h"
#include "Luau/DenseHash.h"
#include <vector>
#include <optional>
@ -22,15 +23,40 @@ struct NormalizedType;
struct NormalizedClassType;
struct NormalizedStringType;
struct NormalizedFunctionType;
struct TypeArena;
struct Scope;
struct TableIndexer;
enum class SubtypingVariance
{
// Used for an empty key. Should never appear in actual code.
Invalid,
Covariant,
// This is used to identify cases where we have a covariant + a
// contravariant reason and we need to merge them.
Contravariant,
Invariant,
};
struct SubtypingReasoning
{
// The path, relative to the _root subtype_, where subtyping failed.
Path subPath;
// The path, relative to the _root supertype_, where subtyping failed.
Path superPath;
SubtypingVariance variance = SubtypingVariance::Covariant;
bool operator==(const SubtypingReasoning& other) const;
};
struct SubtypingReasoningHash
{
size_t operator()(const SubtypingReasoning& r) const;
};
using SubtypingReasonings = DenseHashSet<SubtypingReasoning, SubtypingReasoningHash>;
static const SubtypingReasoning kEmptyReasoning = SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Invalid};
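A sketch of how a failing check might populate the reasoning set (hypothetical; note that kEmptyReasoning doubles as the DenseHashSet empty key, which is why SubtypingVariance::Invalid exists):

SubtypingResult failAt(const TypePath::Path& subPath, const TypePath::Path& superPath)
{
    SubtypingResult result;
    result.isSubtype = false;
    // Record where, relative to each root type, the check failed.
    result.reasoning.insert(SubtypingReasoning{subPath, superPath, SubtypingVariance::Covariant});
    return result;
}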
struct SubtypingResult
{
bool isSubtype = false;
@ -40,7 +66,7 @@ struct SubtypingResult
/// The reason for isSubtype to be false. May not be present even if
/// isSubtype is false, depending on the input types.
std::optional<SubtypingReasoning> reasoning;
SubtypingReasonings reasoning{kEmptyReasoning};
SubtypingResult& andAlso(const SubtypingResult& other);
SubtypingResult& orElse(const SubtypingResult& other);
@ -92,9 +118,9 @@ struct Subtyping
Variance variance = Variance::Covariant;
using SeenSet = std::unordered_set<std::pair<TypeId, TypeId>, TypeIdPairHash>;
using SeenSet = Set<std::pair<TypeId, TypeId>, TypePairHash>;
SeenSet seenTypes;
SeenSet seenTypes{{}};
Subtyping(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> typeArena, NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter, NotNull<Scope> scope);

@ -21,7 +21,6 @@
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
@ -87,6 +86,24 @@ struct FreeType
TypeId upperBound = nullptr;
};
/** A type that tracks the domain of a local variable.
*
* We consider each local's domain to be the union of all types assigned to it.
* We accomplish this with LocalType. Each time we dispatch an assignment to a
* local, we accumulate this union and decrement blockCount.
*
* When blockCount reaches 0, we can consider the LocalType to be "fully baked"
* and replace it with the union we've built.
*/
struct LocalType
{
TypeId domain;
int blockCount = 0;
// Used for debugging
std::string name;
};
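A construction-side sketch (an assumption based on the aggregate declared above; the consumption side appears in ConstraintSolver.cpp below):

// A local with two as-yet-undispatched assignments starts with an empty
// (never) domain and blockCount = 2.
TypeId makeLocalType(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes)
{
    return arena->addType(LocalType{
        /* domain */ builtinTypes->neverType,
        /* blockCount */ 2,
        /* name */ "x", // debugging aid only
    });
}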
struct GenericType
{
// By default, generics are global, with a synthetic name
@ -141,6 +158,7 @@ struct PrimitiveType
Thread,
Function,
Table,
Buffer,
};
Type type;
@ -373,7 +391,15 @@ struct Property
bool deprecated = false;
std::string deprecatedSuggestion;
// If this property was inferred from an expression, this field will be
// populated with the source location of the corresponding table property.
std::optional<Location> location = std::nullopt;
// If this property was built from an explicit type annotation, this field
// will be populated with the source location of that table property.
std::optional<Location> typeLocation = std::nullopt;
Tags tags;
std::optional<std::string> documentationSymbol;
@ -381,7 +407,7 @@ struct Property
// TODO: Kill all constructors in favor of `Property::rw(TypeId read, TypeId write)` and friends.
Property();
Property(TypeId readTy, bool deprecated = false, const std::string& deprecatedSuggestion = "", std::optional<Location> location = std::nullopt,
const Tags& tags = {}, const std::optional<std::string>& documentationSymbol = std::nullopt);
const Tags& tags = {}, const std::optional<std::string>& documentationSymbol = std::nullopt, std::optional<Location> typeLocation = std::nullopt);
// DEPRECATED: Should only be called in non-RWP! We assert that the `readTy` is not nullopt.
// TODO: Kill once we don't have non-RWP.
@ -615,7 +641,7 @@ struct NegationType
using ErrorType = Unifiable::Error;
using TypeVariant =
Unifiable::Variant<TypeId, FreeType, GenericType, PrimitiveType, BlockedType, PendingExpansionType, SingletonType, FunctionType, TableType,
Unifiable::Variant<TypeId, FreeType, LocalType, GenericType, PrimitiveType, BlockedType, PendingExpansionType, SingletonType, FunctionType, TableType,
MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType, TypeFamilyInstanceType>;
struct Type final
@ -739,6 +765,7 @@ bool isBoolean(TypeId ty);
bool isNumber(TypeId ty);
bool isString(TypeId ty);
bool isThread(TypeId ty);
bool isBuffer(TypeId ty);
bool isOptional(TypeId ty);
bool isTableIntersection(TypeId ty);
bool isOverloadedFunction(TypeId ty);
@ -797,6 +824,7 @@ public:
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
const TypeId bufferType;
const TypeId functionType;
const TypeId classType;
const TypeId tableType;
@ -965,7 +993,7 @@ private:
using SavedIterInfo = std::pair<const T*, size_t>;
std::deque<SavedIterInfo> stack;
std::unordered_set<const T*> seen; // Only needed to protect the iterator from hanging the thread.
DenseHashSet<const T*> seen{nullptr}; // Only needed to protect the iterator from hanging the thread.
void advance()
{
@ -992,7 +1020,7 @@ private:
{
// If we're about to descend into a cyclic type, we should skip over this.
// Ideally this should never happen, but alas it does from time to time. :(
if (seen.find(inner) != seen.end())
if (seen.contains(inner))
advance();
else
{

@ -2,8 +2,6 @@
#pragma once
#include "Luau/Ast.h"
#include "Luau/Module.h"
#include "Luau/NotNull.h"
namespace Luau
@ -13,6 +11,8 @@ struct BuiltinTypes;
struct DcrLogger;
struct TypeCheckLimits;
struct UnifierSharedState;
struct SourceModule;
struct Module;
void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> sharedState, NotNull<TypeCheckLimits> limits, DcrLogger* logger,
const SourceModule& sourceModule, Module* module);

@ -377,6 +377,7 @@ public:
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
const TypeId bufferType;
const TypeId anyType;
const TypeId unknownType;
const TypeId neverType;

@ -12,32 +12,28 @@ namespace Luau
const void* ptr(TypeOrPack ty);
template<typename T>
const T* get(TypeOrPack ty)
template<typename T, typename std::enable_if_t<TypeOrPack::is_part_of_v<T>, bool> = true>
const T* get(const TypeOrPack& tyOrTp)
{
if constexpr (std::is_same_v<T, TypeId>)
return ty.get_if<TypeId>();
else if constexpr (std::is_same_v<T, TypePackId>)
return ty.get_if<TypePackId>();
else if constexpr (TypeVariant::is_part_of_v<T>)
{
if (auto innerTy = ty.get_if<TypeId>())
return get<T>(*innerTy);
else
return nullptr;
}
else if constexpr (TypePackVariant::is_part_of_v<T>)
{
if (auto innerTp = ty.get_if<TypePackId>())
return get<T>(*innerTp);
else
return nullptr;
}
return tyOrTp.get_if<T>();
}
template<typename T, typename std::enable_if_t<TypeVariant::is_part_of_v<T>, bool> = true>
const T* get(const TypeOrPack& tyOrTp)
{
if (const TypeId* ty = get<TypeId>(tyOrTp))
return get<T>(*ty);
else
{
static_assert(always_false_v<T>, "invalid T to get from TypeOrPack");
LUAU_UNREACHABLE();
}
return nullptr;
}
template<typename T, typename std::enable_if_t<TypePackVariant::is_part_of_v<T>, bool> = true>
const T* get(const TypeOrPack& tyOrTp)
{
if (const TypePackId* tp = get<TypePackId>(tyOrTp))
return get<T>(*tp);
else
return nullptr;
}
TypeOrPack follow(TypeOrPack ty);
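The three overloads dispatch on whether T is one of TypeOrPack's own alternatives, a member of TypeVariant, or a member of TypePackVariant. A usage sketch (hypothetical):

const char* describe(const TypeOrPack& tyOrTp)
{
    if (get<FunctionType>(tyOrTp))
        return "function type"; // resolved via the TypeVariant overload
    else if (get<TypePack>(tyOrTp))
        return "concrete type pack"; // resolved via the TypePackVariant overload
    else if (get<TypeId>(tyOrTp))
        return "some other type"; // TypeOrPack's own TypeId alternative
    else
        return "some other pack";
}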

@ -4,7 +4,6 @@
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h"
#include "Luau/NotNull.h"
#include "Luau/TypeOrPack.h"
#include <optional>
#include <string>
@ -153,6 +152,16 @@ struct Path
}
};
struct PathHash
{
size_t operator()(const Property& prop) const;
size_t operator()(const Index& idx) const;
size_t operator()(const TypeField& field) const;
size_t operator()(const PackField& field) const;
size_t operator()(const Component& component) const;
size_t operator()(const Path& path) const;
};
/// The canonical "empty" Path, meaning a Path with no components.
static const Path kEmpty{};
@ -184,7 +193,7 @@ using Path = TypePath::Path;
/// Converts a Path to a string for debugging purposes. This output may not be
/// terribly clear to end users of the Luau type system.
std::string toString(const TypePath::Path& path);
std::string toString(const TypePath::Path& path, bool prefixDot = false);
std::optional<TypeOrPack> traverse(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
std::optional<TypeOrPack> traverse(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);

@ -6,7 +6,6 @@
#include "Luau/NotNull.h"
#include "Luau/TypePairHash.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeChecker2.h"
#include "Luau/TypeFwd.h"
#include <optional>
@ -37,6 +36,8 @@ struct Unifier2
DenseHashSet<std::pair<TypeId, TypeId>, TypePairHash> seenTypePairings{{nullptr, nullptr}};
DenseHashSet<std::pair<TypePackId, TypePackId>, TypePairHash> seenTypePackPairings{{nullptr, nullptr}};
DenseHashMap<TypeId, std::vector<TypeId>> expandedFreeTypes{nullptr};
int recursionCount = 0;
int recursionLimit = 0;
@ -60,7 +61,7 @@ struct Unifier2
bool unify(TypeId subTy, const UnionType* superUnion);
bool unify(const IntersectionType* subIntersection, TypeId superTy);
bool unify(TypeId subTy, const IntersectionType* superIntersection);
bool unify(const TableType* subTable, const TableType* superTable);
bool unify(TableType* subTable, const TableType* superTable);
bool unify(const MetatableType* subMetatable, const MetatableType* superMetatable);
// TODO think about this one carefully. We don't do unions or intersections of type packs

@ -97,6 +97,10 @@ struct GenericTypeVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const LocalType& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const GenericType& gtv)
{
return visit(ty);
@ -241,6 +245,11 @@ struct GenericTypeVisitor
else
visit(ty, *ftv);
}
else if (auto lt = get<LocalType>(ty))
{
if (visit(ty, *lt))
traverse(lt->domain);
}
else if (auto gtv = get<GenericType>(ty))
visit(ty, *gtv);
else if (auto etv = get<ErrorType>(ty))

@ -8,7 +8,6 @@
#include <math.h>
LUAU_FASTFLAG(LuauFloorDivision);
LUAU_FASTFLAG(LuauClipExtraHasEndProps);
namespace Luau
@ -519,7 +518,6 @@ struct AstJsonEncoder : public AstVisitor
case AstExprBinary::Div:
return writeString("Div");
case AstExprBinary::FloorDiv:
LUAU_ASSERT(FFlag::LuauFloorDivision);
return writeString("FloorDiv");
case AstExprBinary::Mod:
return writeString("Mod");

@ -5,6 +5,7 @@
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Frontend.h"
#include "Luau/ToString.h"
#include "Luau/Subtyping.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
@ -12,6 +13,7 @@
#include <unordered_set>
#include <utility>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(DebugLuauReadWriteProperties);
LUAU_FASTFLAG(LuauClipExtraHasEndProps);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteDoEnd, false);
@ -143,13 +145,24 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
InternalErrorReporter iceReporter;
UnifierSharedState unifierState(&iceReporter);
Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}};
Unifier unifier(NotNull<Normalizer>{&normalizer}, scope, Location(), Variance::Covariant);
// Cost of normalization can be too high for autocomplete response time requirements
unifier.normalize = false;
unifier.checkInhabited = false;
if (FFlag::DebugLuauDeferredConstraintResolution)
{
Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&iceReporter}, scope};
return subtyping.isSubtype(subTy, superTy).isSubtype;
}
else
{
Unifier unifier(NotNull<Normalizer>{&normalizer}, scope, Location(), Variance::Covariant);
// Cost of normalization can be too high for autocomplete response time requirements
unifier.normalize = false;
unifier.checkInhabited = false;
return unifier.canUnify(subTy, superTy).empty();
}
return unifier.canUnify(subTy, superTy).empty();
}
static TypeCorrectKind checkTypeCorrectKind(

@ -7,7 +7,7 @@
#include "Luau/Common.h"
#include "Luau/ToString.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintGenerator.h"
#include "Luau/NotNull.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeFamily.h"

@ -12,9 +12,8 @@ LUAU_FASTFLAG(DebugLuauReadWriteProperties)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
LUAU_FASTFLAGVARIABLE(LuauCloneCyclicUnions, false)
LUAU_FASTFLAGVARIABLE(LuauStacklessTypeClone2, false)
LUAU_FASTFLAGVARIABLE(LuauStacklessTypeClone3, false)
LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000)
namespace Luau
@ -118,6 +117,8 @@ private:
ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto it = types->find(ty); it != types->end())
return it->second;
else if (ty->persistent)
return ty;
return std::nullopt;
}
@ -126,6 +127,8 @@ private:
tp = follow(tp);
if (auto it = packs->find(tp); it != packs->end())
return it->second;
else if (tp->persistent)
return tp;
return std::nullopt;
}
@ -258,6 +261,11 @@ private:
t->upperBound = shallowClone(t->upperBound);
}
void cloneChildren(LocalType* t)
{
t->domain = shallowClone(t->domain);
}
void cloneChildren(GenericType* t)
{
// TODO: clone upper bounds.
@ -501,6 +509,7 @@ struct TypeCloner
void defaultClone(const T& t);
void operator()(const FreeType& t);
void operator()(const LocalType& t);
void operator()(const GenericType& t);
void operator()(const BoundType& t);
void operator()(const ErrorType& t);
@ -628,6 +637,11 @@ void TypeCloner::operator()(const FreeType& t)
defaultClone(t);
}
void TypeCloner::operator()(const LocalType& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const GenericType& t)
{
defaultClone(t);
@ -778,33 +792,19 @@ void TypeCloner::operator()(const AnyType& t)
void TypeCloner::operator()(const UnionType& t)
{
if (FFlag::LuauCloneCyclicUnions)
{
// We're just using this FreeType as a placeholder until we've finished
// cloning the parts of this union so it is okay that its bounds are
// nullptr. We'll never indirect them.
TypeId result = dest.addType(FreeType{nullptr, /*lowerBound*/ nullptr, /*upperBound*/ nullptr});
seenTypes[typeId] = result;
// We're just using this FreeType as a placeholder until we've finished
// cloning the parts of this union so it is okay that its bounds are
// nullptr. We'll never indirect them.
TypeId result = dest.addType(FreeType{nullptr, /*lowerBound*/ nullptr, /*upperBound*/ nullptr});
seenTypes[typeId] = result;
std::vector<TypeId> options;
options.reserve(t.options.size());
std::vector<TypeId> options;
options.reserve(t.options.size());
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, cloneState));
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, cloneState));
asMutable(result)->ty.emplace<UnionType>(std::move(options));
}
else
{
std::vector<TypeId> options;
options.reserve(t.options.size());
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, cloneState));
TypeId result = dest.addType(UnionType{std::move(options)});
seenTypes[typeId] = result;
}
asMutable(result)->ty.emplace<UnionType>(std::move(options));
}
void TypeCloner::operator()(const IntersectionType& t)
@ -879,7 +879,7 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
if (tp->persistent)
return tp;
if (FFlag::LuauStacklessTypeClone2)
if (FFlag::LuauStacklessTypeClone3)
{
TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
return cloner.clone(tp);
@ -905,7 +905,7 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
if (typeId->persistent)
return typeId;
if (FFlag::LuauStacklessTypeClone2)
if (FFlag::LuauStacklessTypeClone3)
{
TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
return cloner.clone(typeId);
@ -934,7 +934,7 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{
if (FFlag::LuauStacklessTypeClone2)
if (FFlag::LuauStacklessTypeClone3)
{
TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};

@ -2,13 +2,11 @@
#include "Luau/Anyification.h"
#include "Luau/ApplyTypeFunction.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/DcrLogger.h"
#include "Luau/Instantiation.h"
#include "Luau/Location.h"
#include "Luau/Metamethods.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Quantify.h"
#include "Luau/Simplify.h"
@ -17,12 +15,12 @@
#include "Luau/Type.h"
#include "Luau/TypeFamily.h"
#include "Luau/TypeUtils.h"
#include "Luau/Unifier.h"
#include "Luau/Unifier2.h"
#include "Luau/VisitType.h"
#include <algorithm>
#include <utility>
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false);
LUAU_FASTFLAG(LuauFloorDivision);
namespace Luau
{
@ -995,6 +993,27 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
return block(c.fn, constraint);
}
auto [argsHead, argsTail] = flatten(argsPack);
bool blocked = false;
for (TypeId t : argsHead)
{
if (isBlocked(t))
{
block(t, constraint);
blocked = true;
}
}
if (argsTail && isBlocked(*argsTail))
{
block(*argsTail, constraint);
blocked = true;
}
if (blocked)
return false;
auto collapse = [](const auto* t) -> std::optional<TypeId> {
auto it = begin(t);
auto endIt = end(t);
@ -1020,10 +1039,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
// We don't support magic __call metamethods.
if (std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location))
{
auto [head, tail] = flatten(c.argsPack);
head.insert(head.begin(), fn);
argsHead.insert(argsHead.begin(), fn);
argsPack = arena->addTypePack(TypePack{std::move(head), tail});
if (argsTail && isBlocked(*argsTail))
return block(*argsTail, constraint);
argsPack = arena->addTypePack(TypePack{std::move(argsHead), argsTail});
fn = follow(*callMm);
asMutable(c.result)->ty.emplace<FreeTypePack>(constraint->scope);
}
@ -1103,6 +1124,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
const bool occursCheckPassed = u2.unify(fn, inferredTy);
for (const auto& [expanded, additions] : u2.expandedFreeTypes)
{
for (TypeId addition : additions)
upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition));
}
if (occursCheckPassed && c.callSite)
(*c.astOverloadResolvedTypes)[c.callSite] = inferredTy;
@ -1132,23 +1159,14 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull<con
bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint)
{
TypeId subjectType = follow(c.subjectType);
const TypeId subjectType = follow(c.subjectType);
const TypeId resultType = follow(c.resultType);
LUAU_ASSERT(get<BlockedType>(c.resultType));
LUAU_ASSERT(get<BlockedType>(resultType));
if (isBlocked(subjectType) || get<PendingExpansionType>(subjectType))
return block(subjectType, constraint);
if (get<FreeType>(subjectType))
{
TableType& ttv = asMutable(subjectType)->ty.emplace<TableType>(TableState::Free, TypeLevel{}, constraint->scope);
ttv.props[c.prop] = Property{c.resultType};
TypeId res = freshType(arena, builtinTypes, constraint->scope);
asMutable(c.resultType)->ty.emplace<BoundType>(res);
unblock(c.resultType, constraint->location);
return true;
}
auto [blocked, result] = lookupTableProp(subjectType, c.prop, c.suppressSimplification);
if (!blocked.empty())
{
@ -1158,8 +1176,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
return false;
}
bindBlockedType(c.resultType, result.value_or(builtinTypes->anyType), c.subjectType, constraint->location);
unblock(c.resultType, constraint->location);
bindBlockedType(resultType, result.value_or(builtinTypes->anyType), c.subjectType, constraint->location);
unblock(resultType, constraint->location);
return true;
}
@ -1245,9 +1263,6 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
if (isBlocked(subjectType))
return block(subjectType, constraint);
if (!force && get<FreeType>(subjectType))
return block(subjectType, constraint);
std::optional<TypeId> existingPropType = subjectType;
for (const std::string& segment : c.path)
{
@ -1283,25 +1298,13 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
if (get<FreeType>(subjectType))
{
TypeId ty = freshType(arena, builtinTypes, constraint->scope);
// Mint a chain of free tables per c.path
for (auto it = rbegin(c.path); it != rend(c.path); ++it)
{
TableType t{TableState::Free, TypeLevel{}, constraint->scope};
t.props[*it] = {ty};
ty = arena->addType(std::move(t));
}
LUAU_ASSERT(ty);
bind(subjectType, ty);
if (follow(c.resultType) != follow(ty))
bind(c.resultType, ty);
unblock(subjectType, constraint->location);
unblock(c.resultType, constraint->location);
return true;
/*
* This should never occur because lookupTableProp() will add bounds to
* any free types it encounters. There will always be an
* existingPropType if the subject is free.
*/
LUAU_ASSERT(false);
return false;
}
else if (auto ttv = getMutable<TableType>(subjectType))
{
@ -1310,7 +1313,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
LUAU_ASSERT(!subjectType->persistent);
ttv->props[c.path[0]] = Property{c.propType};
bind(c.resultType, c.subjectType);
bind(c.resultType, subjectType);
unblock(c.resultType, constraint->location);
return true;
}
@ -1319,26 +1322,12 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
LUAU_ASSERT(!subjectType->persistent);
updateTheTableType(builtinTypes, NotNull{arena}, subjectType, c.path, c.propType);
bind(c.resultType, c.subjectType);
unblock(subjectType, constraint->location);
unblock(c.resultType, constraint->location);
return true;
}
else
{
bind(c.resultType, subjectType);
unblock(c.resultType, constraint->location);
return true;
}
}
else
{
// Other kinds of types don't change shape when properties are assigned
// to them. (if they allow properties at all!)
bind(c.resultType, subjectType);
unblock(c.resultType, constraint->location);
return true;
}
bind(c.resultType, subjectType);
unblock(c.resultType, constraint->location);
return true;
}
bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull<const Constraint> constraint, bool force)
@ -1434,32 +1423,57 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Cons
TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack));
auto destIter = begin(resultPack);
auto destEnd = end(resultPack);
auto resultIter = begin(resultPack);
auto resultEnd = end(resultPack);
size_t i = 0;
while (destIter != destEnd)
while (resultIter != resultEnd)
{
if (i >= srcPack.head.size())
break;
TypeId srcTy = follow(srcPack.head[i]);
if (isBlocked(*destIter))
TypeId srcTy = follow(srcPack.head[i]);
TypeId resultTy = follow(*resultIter);
if (resultTy)
{
if (follow(srcTy) == *destIter)
if (auto lt = getMutable<LocalType>(resultTy); c.resultIsLValue && lt)
{
// Cyclic type dependency. (????)
TypeId f = freshType(arena, builtinTypes, constraint->scope);
asMutable(*destIter)->ty.emplace<BoundType>(f);
lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, srcTy).result;
LUAU_ASSERT(lt->blockCount > 0);
--lt->blockCount;
LUAU_ASSERT(0 <= lt->blockCount);
if (0 == lt->blockCount)
asMutable(resultTy)->ty.emplace<BoundType>(lt->domain);
}
else if (get<BlockedType>(resultTy))
{
if (follow(srcTy) == resultTy)
{
// It is sometimes the case that we find that a blocked type
// is only blocked on itself. This doesn't actually
// constitute any meaningful constraint, so we replace it
// with a free type.
TypeId f = freshType(arena, builtinTypes, constraint->scope);
asMutable(resultTy)->ty.emplace<BoundType>(f);
}
else
asMutable(resultTy)->ty.emplace<BoundType>(srcTy);
}
else
asMutable(*destIter)->ty.emplace<BoundType>(srcTy);
unblock(*destIter, constraint->location);
{
LUAU_ASSERT(c.resultIsLValue);
unify(constraint->scope, constraint->location, resultTy, srcTy);
}
unblock(resultTy, constraint->location);
}
else
unify(constraint->scope, constraint->location, *destIter, srcTy);
unify(constraint->scope, constraint->location, resultTy, srcTy);
++destIter;
++resultIter;
++i;
}
@ -1467,15 +1481,25 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Cons
// sourcePack is long enough to fill every value. Replace every remaining
// result TypeId with `nil`.
while (destIter != destEnd)
while (resultIter != resultEnd)
{
if (isBlocked(*destIter))
TypeId resultTy = follow(*resultIter);
if (auto lt = getMutable<LocalType>(resultTy); c.resultIsLValue && lt)
{
asMutable(*destIter)->ty.emplace<BoundType>(builtinTypes->nilType);
unblock(*destIter, constraint->location);
lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, builtinTypes->nilType).result;
LUAU_ASSERT(0 <= lt->blockCount);
--lt->blockCount;
if (0 == lt->blockCount)
asMutable(resultTy)->ty.emplace<BoundType>(lt->domain);
}
else if (get<BlockedType>(*resultIter) || get<PendingExpansionType>(*resultIter))
{
asMutable(*resultIter)->ty.emplace<BoundType>(builtinTypes->nilType);
unblock(*resultIter, constraint->location);
}
++destIter;
++resultIter;
}
return true;
@ -1490,7 +1514,7 @@ namespace
*/
struct FindRefineConstraintBlockers : TypeOnceVisitor
{
std::unordered_set<TypeId> found;
DenseHashSet<TypeId> found{nullptr};
bool visit(TypeId ty, const BlockedType&) override
{
found.insert(ty);
@ -1855,6 +1879,7 @@ bool ConstraintSolver::tryDispatchIterableFunction(
TypeId retIndex;
if (isNil(firstIndexTy) || isOptional(firstIndexTy))
{
// FIXME freshType is suspect here
firstIndex = arena->addType(UnionType{{freshType(arena, builtinTypes, constraint->scope), builtinTypes->nilType}});
retIndex = firstIndex;
}
@ -1896,7 +1921,7 @@ bool ConstraintSolver::tryDispatchIterableFunction(
modifiedNextRetHead.push_back(*it);
TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail());
auto psc = pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, modifiedNextRetPack});
auto psc = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, modifiedNextRetPack});
inheritBlocks(constraint, psc);
return true;
@ -1905,15 +1930,16 @@ bool ConstraintSolver::tryDispatchIterableFunction(
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(
TypeId subjectType, const std::string& propName, bool suppressSimplification)
{
std::unordered_set<TypeId> seen;
DenseHashSet<TypeId> seen{nullptr};
return lookupTableProp(subjectType, propName, suppressSimplification, seen);
}
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(
TypeId subjectType, const std::string& propName, bool suppressSimplification, std::unordered_set<TypeId>& seen)
TypeId subjectType, const std::string& propName, bool suppressSimplification, DenseHashSet<TypeId>& seen)
{
if (!seen.insert(subjectType).second)
if (seen.contains(subjectType))
return {};
seen.insert(subjectType);
subjectType = follow(subjectType);
@ -1994,14 +2020,23 @@ std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTa
}
else if (auto ft = get<FreeType>(subjectType))
{
Scope* scope = ft->scope;
const TypeId upperBound = follow(ft->upperBound);
TableType* tt = &asMutable(subjectType)->ty.emplace<TableType>();
tt->state = TableState::Free;
tt->scope = scope;
if (get<TableType>(upperBound))
return lookupTableProp(upperBound, propName, suppressSimplification, seen);
// TODO: The upper bound could be an intersection that contains suitable tables or classes.
NotNull<Scope> scope{ft->scope};
const TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, scope});
TableType* tt = getMutable<TableType>(newUpperBound);
LUAU_ASSERT(tt);
TypeId propType = freshType(arena, builtinTypes, scope);
tt->props[propName] = Property{propType};
unify(scope, Location{}, subjectType, newUpperBound);
return {{}, propType};
}
else if (auto utv = get<UnionType>(subjectType))
@ -2073,7 +2108,15 @@ bool ConstraintSolver::tryUnify(NotNull<const Constraint> constraint, TID subTy,
bool success = u2.unify(subTy, superTy);
if (!success)
if (success)
{
for (const auto& [expanded, additions] : u2.expandedFreeTypes)
{
for (TypeId addition : additions)
upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition));
}
}
else
{
// Unification only fails when doing so would fail the occurs check.
// ie create a self-bound type or a cyclic type pack
@ -2285,7 +2328,12 @@ void ConstraintSolver::unblock(const std::vector<TypePackId>& packs, Location lo
bool ConstraintSolver::isBlocked(TypeId ty)
{
return nullptr != get<BlockedType>(follow(ty)) || nullptr != get<PendingExpansionType>(follow(ty));
ty = follow(ty);
if (auto lt = get<LocalType>(ty))
return lt->blockCount > 0;
return nullptr != get<BlockedType>(ty) || nullptr != get<PendingExpansionType>(ty);
}
bool ConstraintSolver::isBlocked(TypePackId tp)
@ -2320,6 +2368,12 @@ ErrorVec ConstraintSolver::unify(NotNull<Scope> scope, Location location, TypePa
u.unify(subPack, superPack);
for (const auto& [expanded, additions] : u.expandedFreeTypes)
{
for (TypeId addition : additions)
upperBoundContributors[expanded].push_back(std::make_pair(location, addition));
}
unblock(subPack, Location{});
unblock(superPack, Location{});

@ -11,10 +11,13 @@
LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauLoopControlFlowAnalysis)
namespace Luau
{
bool doesCallError(const AstExprCall* call); // TypeInfer.cpp
const RefinementKey* RefinementKeyArena::leaf(DefId def)
{
return allocator.allocate(RefinementKey{nullptr, def, std::nullopt});
@ -34,7 +37,7 @@ DefId DataFlowGraph::getDef(const AstExpr* expr) const
std::optional<DefId> DataFlowGraph::getRValueDefForCompoundAssign(const AstExpr* expr) const
{
auto def = compoundAssignBreadcrumbs.find(expr);
auto def = compoundAssignDefs.find(expr);
return def ? std::optional<DefId>(*def) : std::nullopt;
}
@ -82,9 +85,9 @@ std::optional<DefId> DfgScope::lookup(DefId def, const std::string& key) const
{
for (const DfgScope* current = this; current; current = current->parent)
{
if (auto map = props.find(def))
if (auto props = current->props.find(def))
{
if (auto it = map->find(key); it != map->end())
if (auto it = props->find(key); it != props->end())
return NotNull{it->second};
}
}
@ -92,6 +95,47 @@ std::optional<DefId> DfgScope::lookup(DefId def, const std::string& key) const
return std::nullopt;
}
void DfgScope::inherit(const DfgScope* childScope)
{
for (const auto& [k, a] : childScope->bindings)
{
if (lookup(k))
bindings[k] = a;
}
for (const auto& [k1, a1] : childScope->props)
{
for (const auto& [k2, a2] : a1)
props[k1][k2] = a2;
}
}
bool DfgScope::canUpdateDefinition(Symbol symbol) const
{
for (const DfgScope* current = this; current; current = current->parent)
{
if (current->bindings.find(symbol))
return true;
else if (current->isLoopScope)
return false;
}
return true;
}
bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const
{
for (const DfgScope* current = this; current; current = current->parent)
{
if (auto props = current->props.find(def))
return true;
else if (current->isLoopScope)
return false;
}
return true;
}
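A small Luau example (my construction, with cond() standing in for an arbitrary call) of the rule these two functions encode: an assignment inside a loop may execute zero or many times, so it must not replace a definition introduced outside the loop scope.

local x = "start"
while cond() do
    x = 1        -- x is bound outside the loop, so canUpdateDefinition returns
end              -- false and the assignment is treated as a read instead
local y = x      -- y conservatively sees a def that may hold either value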
DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle)
{
LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
@ -110,24 +154,138 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalE
return std::move(builder.graph);
}
DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope)
DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope, bool isLoopScope)
{
return scopes.emplace_back(new DfgScope{scope}).get();
return scopes.emplace_back(new DfgScope{scope, isLoopScope}).get();
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b)
void DataFlowGraphBuilder::join(DfgScope* p, DfgScope* a, DfgScope* b)
{
joinBindings(p->bindings, a->bindings, b->bindings);
joinProps(p->props, a->props, b->props);
}
void DataFlowGraphBuilder::joinBindings(DfgScope::Bindings& p, const DfgScope::Bindings& a, const DfgScope::Bindings& b)
{
for (const auto& [sym, def1] : a)
{
if (auto def2 = b.find(sym))
p[sym] = defArena->phi(NotNull{def1}, NotNull{*def2});
else if (auto def2 = p.find(sym))
p[sym] = defArena->phi(NotNull{def1}, NotNull{*def2});
}
for (const auto& [sym, def1] : b)
{
if (auto def2 = p.find(sym))
p[sym] = defArena->phi(NotNull{def1}, NotNull{*def2});
}
}
void DataFlowGraphBuilder::joinProps(DfgScope::Props& p, const DfgScope::Props& a, const DfgScope::Props& b)
{
auto phinodify = [this](auto& p, const auto& a, const auto& b) mutable {
for (const auto& [k, defA] : a)
{
if (auto it = b.find(k); it != b.end())
p[k] = defArena->phi(NotNull{it->second}, NotNull{defA});
else if (auto it = p.find(k); it != p.end())
p[k] = defArena->phi(NotNull{it->second}, NotNull{defA});
else
p[k] = defA;
}
for (const auto& [k, defB] : b)
{
if (auto it = a.find(k); it != a.end())
continue;
else if (auto it = p.find(k); it != p.end())
p[k] = defArena->phi(NotNull{it->second}, NotNull{defB});
else
p[k] = defB;
}
};
for (const auto& [def, a1] : a)
{
p.try_insert(def, {});
if (auto a2 = b.find(def))
phinodify(p[def], a1, *a2);
else if (auto a2 = p.find(def))
phinodify(p[def], a1, *a2);
}
for (const auto& [def, a1] : b)
{
p.try_insert(def, {});
if (a.find(def))
continue;
else if (auto a2 = p.find(def))
phinodify(p[def], a1, *a2);
}
}
DefId DataFlowGraphBuilder::lookup(DfgScope* scope, Symbol symbol)
{
if (auto found = scope->lookup(symbol))
return *found;
else
{
DefId result = defArena->freshCell();
if (symbol.local)
scope->bindings[symbol] = result;
else
moduleScope->bindings[symbol] = result;
return result;
}
}
DefId DataFlowGraphBuilder::lookup(DfgScope* scope, DefId def, const std::string& key)
{
if (auto found = scope->lookup(def, key))
return *found;
else if (auto phi = get<Phi>(def))
{
std::vector<DefId> defs;
for (DefId operand : phi->operands)
defs.push_back(lookup(scope, operand, key));
DefId result = defArena->phi(defs);
scope->props[def][key] = result;
return result;
}
else if (get<Cell>(def))
{
DefId result = defArena->freshCell();
scope->props[def][key] = result;
return result;
}
else
handle->ice("Inexhaustive lookup cases in DataFlowGraphBuilder::lookup");
}
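A hypothetical Luau snippet for the phi case of this lookup: reading a property through a phi def produces a phi of the per-operand property defs, which is then cached in scope->props.

local t
if cond() then
    t = { x = 1 }
else
    t = { x = "s" }
end
local v = t.x   -- t's def is phi(t1, t2), so lookup(t, "x") builds
                -- phi(t1.x, t2.x) and records it for future lookups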
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b)
{
DfgScope* child = childScope(scope);
return visitBlockWithoutChildScope(child, b);
ControlFlow cf = visitBlockWithoutChildScope(child, b);
scope->inherit(child);
return cf;
}
void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b)
ControlFlow DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b)
{
for (AstStat* s : b->body)
visit(scope, s);
std::optional<ControlFlow> firstControlFlow;
for (AstStat* stat : b->body)
{
ControlFlow cf = visit(scope, stat);
if (cf != ControlFlow::None && !firstControlFlow)
firstControlFlow = cf;
}
return firstControlFlow.value_or(ControlFlow::None);
}
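A Luau sketch of what the block visitor now computes: the block's control flow is that of the first diverging statement, while later statements are still visited for data flow.

local function f()
    print("reached")
    error("diverges")   -- first non-None control flow: the block reports Throws
    print("dead")       -- still visited, but cannot change the result
end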
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s)
{
if (auto b = s->as<AstStatBlock>())
return visit(scope, b);
@ -173,56 +331,85 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s)
handle->ice("Unknown AstStat in DataFlowGraphBuilder::visit");
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i)
{
// TODO: type states and control flow analysis
visitExpr(scope, i->condition);
visit(scope, i->thenbody);
DfgScope* thenScope = childScope(scope);
DfgScope* elseScope = childScope(scope);
ControlFlow thencf = visit(thenScope, i->thenbody);
ControlFlow elsecf = ControlFlow::None;
if (i->elsebody)
visit(scope, i->elsebody);
elsecf = visit(elseScope, i->elsebody);
if (thencf != ControlFlow::None && elsecf == ControlFlow::None)
join(scope, scope, elseScope);
else if (thencf == ControlFlow::None && elsecf != ControlFlow::None)
join(scope, thenScope, scope);
else if ((thencf | elsecf) == ControlFlow::None)
join(scope, thenScope, elseScope);
if (FFlag::LuauLoopControlFlowAnalysis && thencf == elsecf)
return thencf;
else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws))
return ControlFlow::Returns;
else
return ControlFlow::None;
}
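Here is a Luau example (my own) of the join above: when both branches fall through, their scopes are merged back into the parent with phi nodes.

local x = 0
if cond() then
    x = 1
else
    x = "s"
end
-- neither branch diverged, so join(scope, thenScope, elseScope) rebinds x
-- in the outer scope to phi(def_then, def_else)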
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* whileScope = childScope(scope);
DfgScope* whileScope = childScope(scope, /*isLoopScope=*/true);
visitExpr(whileScope, w->condition);
visit(whileScope, w->body);
scope->inherit(whileScope);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* repeatScope = childScope(scope); // TODO: loop scope.
DfgScope* repeatScope = childScope(scope, /*isLoopScope=*/true);
visitBlockWithoutChildScope(repeatScope, r->body);
visitExpr(repeatScope, r->condition);
scope->inherit(repeatScope);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b)
{
// TODO: Control flow analysis
return; // ok
return ControlFlow::Breaks;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c)
{
// TODO: Control flow analysis
return; // ok
return ControlFlow::Continues;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r)
{
// TODO: Control flow analysis
for (AstExpr* e : r->list)
visitExpr(scope, e);
return ControlFlow::Returns;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e)
{
visitExpr(scope, e->expr);
if (auto call = e->expr->as<AstExprCall>(); call && doesCallError(call))
return ControlFlow::Throws;
else
return ControlFlow::None;
}
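A small illustration of the doesCallError check referenced here (exactly which calls it recognizes is an assumption):

error("boom")   -- a bare expression-statement call that always errors yields
                -- ControlFlow::Throws
print("hi")     -- an ordinary call yields ControlFlow::None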
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l)
{
// We're going to need a `visitExprList` and `visitVariadicExpr` (function calls and `...`)
std::vector<DefId> defs;
@ -243,11 +430,13 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l)
graph.localDefs[local] = def;
scope->bindings[local] = def;
}
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
DfgScope* forScope = childScope(scope, /*isLoopScope=*/true);
visitExpr(scope, f->from);
visitExpr(scope, f->to);
@ -263,11 +452,15 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f)
// TODO(controlflow): entry point has a back edge from exit point
visit(forScope, f->body);
scope->inherit(forScope);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
DfgScope* forScope = childScope(scope, /*isLoopScope=*/true);
for (AstLocal* local : f->vars)
{
@ -285,9 +478,13 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f)
visitExpr(forScope, e);
visit(forScope, f->body);
scope->inherit(forScope);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a)
{
std::vector<DefId> defs;
defs.reserve(a->values.size);
@ -299,9 +496,11 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a)
AstExpr* v = a->vars.data[i];
visitLValue(scope, v, i < defs.size() ? defs[i] : defArena->freshCell());
}
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c)
{
// TODO: This needs revisiting because it is incorrect. The `c->var` part is both read and written to,
// but `c->var` has only one pointer address, so we need a way to store both defs.
@ -312,9 +511,11 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c)
// We can't just visit `c->var` as an rvalue and then separately traverse `c->var` as an lvalue, since that's O(n^2).
DefId def = visitExpr(scope, c->value).def;
visitLValue(scope, c->var, def, /* isCompoundAssignment */ true);
return ControlFlow::None;
}
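A one-line Luau example of the read-then-write behavior the TODO describes:

local x = 1
x += 1   -- x is read (its old def is saved via compoundAssignDefs) and then
         -- redefined with a fresh cell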
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f)
{
// In the old solver, we assumed that the name of the function is always a function in the body
// but this isn't true, e.g. the following example will print `5`, not a function address.
@ -329,34 +530,42 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f)
DefId prototype = defArena->freshCell();
visitLValue(scope, f->name, prototype);
visitExpr(scope, f->func);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l)
{
DefId def = defArena->freshCell();
graph.localDefs[l->name] = def;
scope->bindings[l->name] = def;
visitExpr(scope, l->func);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t)
{
DfgScope* unreachable = childScope(scope);
visitGenerics(unreachable, t->generics);
visitGenericPacks(unreachable, t->genericPacks);
visitType(unreachable, t->type);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d)
{
DefId def = defArena->freshCell();
graph.declaredDefs[d] = def;
scope->bindings[d->name] = def;
visitType(scope, d->type);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d)
{
DefId def = defArena->freshCell();
graph.declaredDefs[d] = def;
@ -367,9 +576,11 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d)
visitGenericPacks(unreachable, d->genericPacks);
visitTypeList(unreachable, d->params);
visitTypeList(unreachable, d->retTypes);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d)
{
// This declaration does not "introduce" any bindings in value namespace,
// so there's no symbolic value to begin with. We'll traverse the properties
@ -377,19 +588,30 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d)
DfgScope* unreachable = childScope(scope);
for (AstDeclaredClassProp prop : d->props)
visitType(unreachable, prop.ty);
return ControlFlow::None;
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error)
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error)
{
DfgScope* unreachable = childScope(scope);
for (AstStat* s : error->statements)
visit(unreachable, s);
for (AstExpr* e : error->expressions)
visitExpr(unreachable, e);
return ControlFlow::None;
}
DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e)
{
// Some subexpressions may be visited twice. If we've already seen this one, just return the cached result.
if (auto def = graph.astDefs.find(e))
{
auto key = graph.astRefinementKeys.find(e);
return {NotNull{*def}, key ? *key : nullptr};
}
auto go = [&]() -> DataFlowResult {
if (auto g = e->as<AstExprGroup>())
return visitExpr(scope, g);
@ -447,6 +669,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGroup* gr
DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l)
{
// DfgScope::lookup is intentional here: we want to be able to ice.
if (auto def = scope->lookup(l->local))
{
const RefinementKey* key = keyArena->leaf(*def);
@ -458,11 +681,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l)
DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g)
{
if (auto def = scope->lookup(g->name))
return {*def, keyArena->leaf(*def)};
DefId def = defArena->freshCell();
moduleScope->bindings[g->name] = def;
DefId def = lookup(scope, g->name);
return {def, keyArena->leaf(def)};
}
@ -481,11 +700,9 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName
auto [parentDef, parentKey] = visitExpr(scope, i->expr);
std::string index = i->index.value;
auto& propDef = moduleScope->props[parentDef][index];
if (!propDef)
propDef = defArena->freshCell();
return {NotNull{propDef}, keyArena->node(parentKey, NotNull{propDef}, index)};
DefId def = lookup(scope, parentDef, index);
return {def, keyArena->node(parentKey, def, index)};
}
DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i)
@ -496,11 +713,9 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr
if (auto string = i->index->as<AstExprConstantString>())
{
std::string index{string->value.data, string->value.size};
auto& propDef = moduleScope->props[parentDef][index];
if (!propDef)
propDef = defArena->freshCell();
return {NotNull{propDef}, keyArena->node(parentKey, NotNull{propDef}, index)};
DefId def = lookup(scope, parentDef, index);
return {def, keyArena->node(parentKey, def, index)};
}
return {defArena->freshCell(/* subscripted= */true), nullptr};
@ -628,41 +843,56 @@ void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomi
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment)
{
// We need to keep the previous breadcrumb around for a compound assignment.
// We need to keep the previous def around for a compound assignment.
if (isCompoundAssignment)
{
if (auto def = scope->lookup(l->local))
graph.compoundAssignBreadcrumbs[l] = *def;
graph.compoundAssignDefs[l] = *def;
}
// In order to avoid alias tracking, we need to clip the reference to the parent def.
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[l] = updated;
scope->bindings[l->local] = updated;
if (scope->canUpdateDefinition(l->local))
{
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[l] = updated;
scope->bindings[l->local] = updated;
}
else
visitExpr(scope, static_cast<AstExpr*>(l));
}
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment)
{
// We need to keep the previous breadcrumb around for a compound assignment.
// We need to keep the previous def around for a compound assignment.
if (isCompoundAssignment)
{
if (auto def = scope->lookup(g->name))
graph.compoundAssignBreadcrumbs[g] = *def;
DefId def = lookup(scope, g->name);
graph.compoundAssignDefs[g] = def;
}
// In order to avoid alias tracking, we need to clip the reference to the parent def.
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[g] = updated;
scope->bindings[g->name] = updated;
if (scope->canUpdateDefinition(g->name))
{
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[g] = updated;
scope->bindings[g->name] = updated;
}
else
visitExpr(scope, static_cast<AstExpr*>(g));
}
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef)
{
DefId parentDef = visitExpr(scope, i->expr).def;
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[i] = updated;
scope->props[parentDef][i->index.value] = updated;
if (scope->canUpdateDefinition(parentDef, i->index.value))
{
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[i] = updated;
scope->props[parentDef][i->index.value] = updated;
}
else
visitExpr(scope, static_cast<AstExpr*>(i));
}
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef)
@ -672,9 +902,14 @@ void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, Def
if (auto string = i->index->as<AstExprConstantString>())
{
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[i] = updated;
scope->props[parentDef][string->value.data] = updated;
if (scope->canUpdateDefinition(parentDef, string->value.data))
{
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[i] = updated;
scope->props[parentDef][string->value.data] = updated;
}
else
visitExpr(scope, static_cast<AstExpr*>(i));
}
graph.astDefs[i] = defArena->freshCell();
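A short Luau sketch contrasting the two assignment paths above, with k() standing in for an arbitrary key expression:

local t = {}
t["x"] = 1    -- constant string key: a fresh def is stored under
              -- props[t]["x"], exactly like t.x = 1
t[k()] = 2    -- dynamic key: no per-key def can be tracked, so the
              -- expression only receives a fresh cell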

View file

@ -1,7 +1,11 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Def.h"
#include "Luau/Common.h"
#include <algorithm>
#include <deque>
namespace Luau
{
@ -9,9 +13,10 @@ bool containsSubscriptedDefinition(DefId def)
{
if (auto cell = get<Cell>(def))
return cell->subscripted;
LUAU_ASSERT(!"Phi nodes not implemented yet");
return false;
else if (auto phi = get<Phi>(def))
return std::any_of(phi->operands.begin(), phi->operands.end(), containsSubscriptedDefinition);
else
return false;
}
DefId DefArena::freshCell(bool subscripted)
@ -19,4 +24,35 @@ DefId DefArena::freshCell(bool subscripted)
return NotNull{allocator.allocate(Def{Cell{subscripted}})};
}
static void collectOperands(DefId def, std::vector<DefId>& operands)
{
if (std::find(operands.begin(), operands.end(), def) != operands.end())
return;
else if (get<Cell>(def))
operands.push_back(def);
else if (auto phi = get<Phi>(def))
{
for (const Def* operand : phi->operands)
collectOperands(NotNull{operand}, operands);
}
}
DefId DefArena::phi(DefId a, DefId b)
{
return phi({a, b});
}
DefId DefArena::phi(const std::vector<DefId>& defs)
{
std::vector<DefId> operands;
for (DefId operand : defs)
collectOperands(operand, operands);
// There's no need to allocate a Phi node for a singleton set.
if (operands.size() == 1)
return operands[0];
else
return NotNull{allocator.allocate(Def{Phi{std::move(operands)}})};
}
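An illustrative Luau case (my construction) for collectOperands: nested joins can produce a phi whose operands are themselves phis, and the flattening plus the singleton short-circuit keep the representation small.

local x
if a() then
    x = 1
elseif b() then
    x = 2
else
    x = 3
end
-- the inner join yields phi(def2, def3); the outer join then requests
-- phi(def1, phi(def2, def3)), which collectOperands flattens into
-- phi(def1, def2, def3) with duplicates removed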
} // namespace Luau

View file

@ -2,11 +2,12 @@
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAGVARIABLE(LuauBufferDefinitions, false)
LUAU_FASTFLAGVARIABLE(LuauBufferTypeck, false)
namespace Luau
{
static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC(
static const std::string kBuiltinDefinitionBufferSrc_DEPRECATED = R"BUILTIN_SRC(
-- TODO: this will be replaced with a built-in primitive type
declare class buffer end
@ -40,6 +41,36 @@ declare buffer: {
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC(
declare buffer: {
create: (size: number) -> buffer,
fromstring: (str: string) -> buffer,
tostring: (b: buffer) -> string,
len: (b: buffer) -> number,
copy: (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (),
fill: (b: buffer, offset: number, value: number, count: number?) -> (),
readi8: (b: buffer, offset: number) -> number,
readu8: (b: buffer, offset: number) -> number,
readi16: (b: buffer, offset: number) -> number,
readu16: (b: buffer, offset: number) -> number,
readi32: (b: buffer, offset: number) -> number,
readu32: (b: buffer, offset: number) -> number,
readf32: (b: buffer, offset: number) -> number,
readf64: (b: buffer, offset: number) -> number,
writei8: (b: buffer, offset: number, value: number) -> (),
writeu8: (b: buffer, offset: number, value: number) -> (),
writei16: (b: buffer, offset: number, value: number) -> (),
writeu16: (b: buffer, offset: number, value: number) -> (),
writei32: (b: buffer, offset: number, value: number) -> (),
writeu32: (b: buffer, offset: number, value: number) -> (),
writef32: (b: buffer, offset: number, value: number) -> (),
writef64: (b: buffer, offset: number, value: number) -> (),
readstring: (b: buffer, offset: number, count: number) -> string,
writestring: (b: buffer, offset: number, value: string, count: number?) -> (),
}
)BUILTIN_SRC";
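For reference, a small usage example of the buffer API declared above; reads and writes use little-endian byte order and offsets are zero-based.

local b = buffer.create(8)          -- 8 bytes, zero-initialized
buffer.writeu32(b, 0, 0xDEADBEEF)
print(buffer.readu32(b, 0))         --> 3735928559
buffer.writestring(b, 4, "hi")
print(buffer.readstring(b, 4, 2))   --> hi
print(buffer.len(b))                --> 8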
static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC(
declare bit32: {
@ -236,8 +267,10 @@ std::string getBuiltinDefinitionSource()
{
std::string result = kBuiltinDefinitionLuaSrc;
if (FFlag::LuauBufferDefinitions)
if (FFlag::LuauBufferTypeck)
result = kBuiltinDefinitionBufferSrc + result;
else if (FFlag::LuauBufferDefinitions)
result = kBuiltinDefinitionBufferSrc_DEPRECATED + result;
return result;
}

View file

@ -490,7 +490,12 @@ struct ErrorConverter
std::string operator()(const TypePackMismatch& e) const
{
return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'";
std::string ss = "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'";
if (!e.reason.empty())
ss += "; " + e.reason;
return ss;
}
std::string operator()(const DynamicPropertyLookupOnClassesUnsafe& e) const
@ -528,6 +533,12 @@ struct ErrorConverter
return "Function '" + e.checkedFunctionName + "' expects '" + toString(e.expected) + "' at argument #" + std::to_string(e.argumentIndex) +
", but got '" + Luau::toString(e.passed) + "'";
}
std::string operator()(const NonStrictFunctionDefinitionError& e) const
{
return "Argument " + e.argument + " with type '" + toString(e.argumentType) + "' in function '" + e.functionName +
"' is used in a way that will run time error";
}
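A Luau sketch of the kind of program this diagnostic targets, assuming math.abs and string.len are declared as checked functions:

--!nonstrict
local function f(x)
    math.abs(x)     -- errors at runtime unless x is a number
    string.len(x)   -- errors at runtime unless x is a string
end
-- the union of the two error contexts covers every possible value of x, so
-- the checker reports that x is used in a way that will cause a runtime error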
};
struct InvalidNameChecker
@ -856,6 +867,11 @@ bool CheckedFunctionCallError::operator==(const CheckedFunctionCallError& rhs) c
argumentIndex == rhs.argumentIndex;
}
bool NonStrictFunctionDefinitionError::operator==(const NonStrictFunctionDefinitionError& rhs) const
{
return functionName == rhs.functionName && argument == rhs.argument && argumentType == rhs.argumentType;
}
std::string toString(const TypeError& error)
{
return toString(error, TypeErrorToStringOptions{});
@ -1027,6 +1043,10 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
e.expected = clone(e.expected);
e.passed = clone(e.passed);
}
else if constexpr (std::is_same_v<T, NonStrictFunctionDefinitionError>)
{
e.argumentType = clone(e.argumentType);
}
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}

View file

@ -5,7 +5,7 @@
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/Config.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintGenerator.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DcrLogger.h"
@ -38,6 +38,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false)
LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckLimitControls, false)
LUAU_FASTFLAGVARIABLE(CorrectEarlyReturnInMarkDirty, false)
LUAU_FASTFLAGVARIABLE(LuauDefinitionFileSetModuleName, false)
namespace Luau
{
@ -165,6 +166,11 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, Scop
LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend");
Luau::SourceModule sourceModule;
if (FFlag::LuauDefinitionFileSetModuleName)
{
sourceModule.name = packageName;
sourceModule.humanReadableName = packageName;
}
Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments);
if (parseResult.errors.size() > 0)
return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr};
@ -251,7 +257,7 @@ namespace
static ErrorVec accumulateErrors(
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name)
{
std::unordered_set<ModuleName> seen;
DenseHashSet<ModuleName> seen{{}};
std::vector<ModuleName> queue{name};
ErrorVec result;
@ -261,7 +267,7 @@ static ErrorVec accumulateErrors(
ModuleName next = std::move(queue.back());
queue.pop_back();
if (seen.count(next))
if (seen.contains(next))
continue;
seen.insert(next);
@ -442,7 +448,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
std::vector<ModuleName> buildQueue;
bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete);
std::unordered_set<Luau::ModuleName> seen;
DenseHashSet<Luau::ModuleName> seen{{}};
std::vector<BuildQueueItem> buildQueueItems;
addBuildQueueItems(buildQueueItems, buildQueue, cycleDetected, seen, frontendOptions);
LUAU_ASSERT(!buildQueueItems.empty());
@ -495,12 +501,12 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
std::vector<ModuleName> currModuleQueue;
std::swap(currModuleQueue, moduleQueue);
std::unordered_set<Luau::ModuleName> seen;
DenseHashSet<Luau::ModuleName> seen{{}};
std::vector<BuildQueueItem> buildQueueItems;
for (const ModuleName& name : currModuleQueue)
{
if (seen.count(name))
if (seen.contains(name))
continue;
if (!isDirty(name, frontendOptions.forAutocomplete))
@ -511,7 +517,7 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
std::vector<ModuleName> queue;
bool cycleDetected = parseGraph(queue, name, frontendOptions.forAutocomplete, [&seen](const ModuleName& name) {
return seen.count(name);
return seen.contains(name);
});
addBuildQueueItems(buildQueueItems, queue, cycleDetected, seen, frontendOptions);
@ -836,11 +842,11 @@ bool Frontend::parseGraph(
}
void Frontend::addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vector<ModuleName>& buildQueue, bool cycleDetected,
std::unordered_set<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions)
DenseHashSet<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions)
{
for (const ModuleName& moduleName : buildQueue)
{
if (seen.count(moduleName))
if (seen.contains(moduleName))
continue;
seen.insert(moduleName);
@ -1048,6 +1054,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
module->astResolvedTypes.clear();
module->astResolvedTypePacks.clear();
module->astScopes.clear();
module->upperBoundContributors.clear();
if (!FFlag::DebugLuauDeferredConstraintResolution)
module->scopes.clear();
@ -1255,13 +1262,13 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
ConstraintGraphBuilder cgb{result, NotNull{&normalizer}, moduleResolver, builtinTypes, iceHandler, parentScope, std::move(prepareModuleScope),
ConstraintGenerator cg{result, NotNull{&normalizer}, moduleResolver, builtinTypes, iceHandler, parentScope, std::move(prepareModuleScope),
logger.get(), NotNull{&dfg}, requireCycles};
cgb.visitModuleRoot(sourceModule.root);
result->errors = std::move(cgb.errors);
cg.visitModuleRoot(sourceModule.root);
result->errors = std::move(cg.errors);
ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), result->humanReadableName, moduleResolver,
ConstraintSolver cs{NotNull{&normalizer}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), result->humanReadableName, moduleResolver,
requireCycles, logger.get(), limits};
if (options.randomizeConstraintResolutionSeed)
@ -1283,8 +1290,9 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
for (TypeError& e : cs.errors)
result->errors.emplace_back(std::move(e));
result->scopes = std::move(cgb.scopes);
result->scopes = std::move(cg.scopes);
result->type = sourceModule.type;
result->upperBoundContributors = std::move(cs.upperBoundContributors);
result->clonePublicInterface(builtinTypes, *iceHandler);

View file

@ -3,6 +3,7 @@
#include "Luau/GlobalTypes.h"
LUAU_FASTFLAG(LuauInitializeStringMetatableInGlobalTypes)
LUAU_FASTFLAG(LuauBufferTypeck)
namespace Luau
{
@ -18,6 +19,8 @@ GlobalTypes::GlobalTypes(NotNull<BuiltinTypes> builtinTypes)
globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType});
globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType});
globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType});
if (FFlag::LuauBufferTypeck)
globalScope->addBuiltinTypeBinding("buffer", TypeFun{{}, builtinTypes->bufferType});
globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType});
globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType});

View file

@ -7,6 +7,8 @@
#include "Luau/TypeArena.h"
#include "Luau/TypeCheckLimits.h"
#include <algorithm>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau

View file

@ -204,6 +204,9 @@ static void errorToString(std::ostream& stream, const T& err)
else if constexpr (std::is_same_v<T, CheckedFunctionCallError>)
stream << "CheckedFunctionCallError { expected = '" << toString(err.expected) << "', passed = '" << toString(err.passed)
<< "', checkedFunctionName = " << err.checkedFunctionName << ", argumentIndex = " << std::to_string(err.argumentIndex) << " }";
else if constexpr (std::is_same_v<T, NonStrictFunctionDefinitionError>)
stream << "NonStrictFunctionDefinitionError { functionName = '" + err.functionName + "', argument = '" + err.argument +
"', argumentType = '" + toString(err.argumentType) + "' }";
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}

View file

@ -14,8 +14,7 @@
LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAGVARIABLE(LuauLintDeprecatedFenv, false)
LUAU_FASTFLAGVARIABLE(LuauLintTableIndexer, false)
LUAU_FASTFLAG(LuauBufferTypeck)
namespace Luau
{
@ -1108,7 +1107,7 @@ private:
TypeKind getTypeKind(const std::string& name)
{
if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" ||
name == "function" || name == "thread")
name == "function" || name == "thread" || (FFlag::LuauBufferTypeck && name == "buffer"))
return Kind_Primitive;
if (name == "vector")
@ -2093,7 +2092,7 @@ private:
// getfenv/setfenv are deprecated, however they are still used in some test frameworks and don't have a great general replacement
// for now we warn about the deprecation only when they are used with a numeric first argument; this produces fewer warnings and makes use
// of getfenv/setfenv a little more localized
if (FFlag::LuauLintDeprecatedFenv && !node->self && node->args.size >= 1)
if (!node->self && node->args.size >= 1)
{
if (AstExprGlobal* fenv = node->func->as<AstExprGlobal>(); fenv && (fenv->name == "getfenv" || fenv->name == "setfenv"))
{
@ -2185,7 +2184,7 @@ private:
bool visit(AstExprUnary* node) override
{
if (FFlag::LuauLintTableIndexer && node->op == AstExprUnary::Len)
if (node->op == AstExprUnary::Len)
checkIndexer(node, node->expr, "#");
return true;
@ -2195,7 +2194,7 @@ private:
{
if (AstExprGlobal* func = node->func->as<AstExprGlobal>())
{
if (FFlag::LuauLintTableIndexer && func->name == "ipairs" && node->args.size == 1)
if (func->name == "ipairs" && node->args.size == 1)
checkIndexer(node, node->args.data[0], "ipairs");
}
else if (AstExprIndexName* func = node->func->as<AstExprIndexName>())
@ -2209,8 +2208,6 @@ private:
void checkIndexer(AstExpr* node, AstExpr* expr, const char* op)
{
LUAU_ASSERT(FFlag::LuauLintTableIndexer);
std::optional<Luau::TypeId> ty = context->getType(expr);
if (!ty)
return;
@ -2220,7 +2217,8 @@ private:
return;
if (!tty->indexer && !tty->props.empty() && tty->state != TableState::Generic)
emitWarning(*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table without an array part is likely a bug", op);
emitWarning(
*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table without an array part is likely a bug", op);
else if (tty->indexer && isString(tty->indexer->indexType)) // note: to avoid complexity of subtype tests we just check if the key is a string
emitWarning(*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table with string keys is likely a bug", op);
}
@ -2653,13 +2651,17 @@ private:
case ConstantNumberParseResult::Ok:
case ConstantNumberParseResult::Malformed:
break;
case ConstantNumberParseResult::Imprecise:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Number literal exceeded available precision and was truncated to closest representable number");
break;
case ConstantNumberParseResult::BinOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Binary number literal exceeded available precision and has been truncated to 2^64");
"Binary number literal exceeded available precision and was truncated to 2^64");
break;
case ConstantNumberParseResult::HexOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Hexadecimal number literal exceeded available precision and has been truncated to 2^64");
"Hexadecimal number literal exceeded available precision and was truncated to 2^64");
break;
}

View file

@ -3,7 +3,7 @@
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintGenerator.h"
#include "Luau/Normalize.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"

View file

@ -3,7 +3,9 @@
#include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/Simplify.h"
#include "Luau/Type.h"
#include "Luau/Simplify.h"
#include "Luau/Subtyping.h"
#include "Luau/Normalize.h"
#include "Luau/Error.h"
@ -12,6 +14,7 @@
#include "Luau/Def.h"
#include <iostream>
#include <iterator>
namespace Luau
{
@ -64,24 +67,60 @@ struct NonStrictContext
NonStrictContext(NonStrictContext&&) = default;
NonStrictContext& operator=(NonStrictContext&&) = default;
void unionContexts(const NonStrictContext& other)
static NonStrictContext disjunction(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
{
// TODO: unimplemented
// disjunction implements union over the domain of keys.
// the default value for a defId not in the map is `never`, and `never | T`
// is `T`, so keys present on only one side are copied through unchanged.
NonStrictContext disj{};
for (auto [def, leftTy] : left.context)
{
if (std::optional<TypeId> rightTy = right.find(def))
disj.context[def] = simplifyUnion(builtinTypes, arena, leftTy, *rightTy).result;
else
disj.context[def] = leftTy;
}
for (auto [def, rightTy] : right.context)
{
if (!left.find(def).has_value())
disj.context[def] = rightTy;
}
return disj;
}
void intersectContexts(const NonStrictContext& other)
static NonStrictContext conjunction(
NotNull<BuiltinTypes> builtins, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
{
// TODO: unimplemented
NonStrictContext conj{};
for (auto [def, leftTy] : left.context)
{
if (std::optional<TypeId> rightTy = right.find(def))
conj.context[def] = simplifyIntersection(builtins, arena, leftTy, *rightTy).result;
}
return conj;
}
void removeFromContext(const std::vector<DefId>& defs)
// Returns true if the removal was successful
bool remove(const DefId& def)
{
// TODO: unimplemented
return context.erase(def.get()) == 1;
}
std::optional<TypeId> find(const DefId& def) const
{
const Def* d = def.get();
return find(d);
}
private:
std::optional<TypeId> find(const Def* d) const
{
auto it = context.find(d);
if (it != context.end())
return {it->second};
@ -101,6 +140,7 @@ struct NonStrictTypeChecker
NotNull<const DataFlowGraph> dfg;
DenseHashSet<TypeId> noTypeFamilyErrors{nullptr};
std::vector<NotNull<Scope>> stack;
DenseHashMap<TypeId, TypeId> cachedNegations{nullptr};
const NotNull<TypeCheckLimits> limits;
@ -180,153 +220,282 @@ struct NonStrictTypeChecker
return builtinTypes->anyType;
}
void visit(AstStat* stat)
{
NonStrictContext fresh{};
visit(stat, fresh);
}
void visit(AstStat* stat, NonStrictContext& context)
NonStrictContext visit(AstStat* stat)
{
auto pusher = pushStack(stat);
if (auto s = stat->as<AstStatBlock>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatIf>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatWhile>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatRepeat>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatBreak>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatContinue>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatReturn>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatExpr>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatLocal>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatFor>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatForIn>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatAssign>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatCompoundAssign>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatFunction>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatLocalFunction>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatTypeAlias>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatDeclareFunction>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatDeclareGlobal>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatDeclareClass>())
return visit(s, context);
return visit(s);
else if (auto s = stat->as<AstStatError>())
return visit(s, context);
return visit(s);
else
LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown node type");
{
LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown statement type");
ice->ice("NonStrictTypeChecker encountered an unknown statement type");
}
}
void visit(AstStatBlock* block, NonStrictContext& context)
NonStrictContext visit(AstStatBlock* block)
{
auto StackPusher = pushStack(block);
for (AstStat* statement : block->body)
visit(statement, context);
NonStrictContext ctx;
for (auto it = block->body.rbegin(); it != block->body.rend(); it++)
{
AstStat* stat = *it;
if (AstStatLocal* local = stat->as<AstStatLocal>())
{
// Iterating in reverse order: `local x; B` generates the context of B
// without x, so the declaration removes x's defs from the accumulated context.
visit(local);
for (auto local : local->vars)
ctx.remove(dfg->getDef(local));
}
else
ctx = NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, visit(stat), ctx);
}
return ctx;
}
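A Luau sketch of the backwards traversal above: contexts flow from the end of the block to its start, and a local declaration strips its definitions so they do not escape the block.

--!nonstrict
local function f(y)
    local x = y     -- visiting in reverse, this line removes x's def from ctx
    math.abs(x)     -- after this call contributed its constraint on x
end
-- the context returned for the block body no longer mentions x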
void visit(AstStatIf* ifStatement, NonStrictContext& context) {}
void visit(AstStatWhile* whileStatement, NonStrictContext& context) {}
void visit(AstStatRepeat* repeatStatement, NonStrictContext& context) {}
void visit(AstStatBreak* breakStatement, NonStrictContext& context) {}
void visit(AstStatContinue* continueStatement, NonStrictContext& context) {}
void visit(AstStatReturn* returnStatement, NonStrictContext& context) {}
void visit(AstStatExpr* expr, NonStrictContext& context)
NonStrictContext visit(AstStatIf* ifStatement)
{
visit(expr->expr, context);
NonStrictContext condB = visit(ifStatement->condition);
NonStrictContext branchContext;
// If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error
if (ifStatement->elsebody)
{
NonStrictContext thenBody = visit(ifStatement->thenbody);
NonStrictContext elseBody = visit(ifStatement->elsebody);
branchContext = NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenBody, elseBody);
}
return NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, condB, branchContext);
}
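A Luau example of the branch logic above (cond() is a stand-in): a constraint arising in only one branch is dropped, because the conjunction keeps only values that error in both branches.

--!nonstrict
local function f(x)
    if cond() then
        math.abs(x)     -- then-context constrains x to ~number
    else
        string.len(x)   -- else-context constrains x to ~string
    end
    -- conjunction: x errors only if it is neither a number nor a string;
    -- no single branch can condemn x on its own
end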
void visit(AstStatLocal* local, NonStrictContext& context) {}
void visit(AstStatFor* forStatement, NonStrictContext& context) {}
void visit(AstStatForIn* forInStatement, NonStrictContext& context) {}
void visit(AstStatAssign* assign, NonStrictContext& context) {}
void visit(AstStatCompoundAssign* compoundAssign, NonStrictContext& context) {}
void visit(AstStatFunction* statFn, NonStrictContext& context) {}
void visit(AstStatLocalFunction* localFn, NonStrictContext& context) {}
void visit(AstStatTypeAlias* typeAlias, NonStrictContext& context) {}
void visit(AstStatDeclareFunction* declFn, NonStrictContext& context) {}
void visit(AstStatDeclareGlobal* declGlobal, NonStrictContext& context) {}
void visit(AstStatDeclareClass* declClass, NonStrictContext& context) {}
void visit(AstStatError* error, NonStrictContext& context) {}
void visit(AstExpr* expr, NonStrictContext& context)
NonStrictContext visit(AstStatWhile* whileStatement)
{
return {};
}
NonStrictContext visit(AstStatRepeat* repeatStatement)
{
return {};
}
NonStrictContext visit(AstStatBreak* breakStatement)
{
return {};
}
NonStrictContext visit(AstStatContinue* continueStatement)
{
return {};
}
NonStrictContext visit(AstStatReturn* returnStatement)
{
return {};
}
NonStrictContext visit(AstStatExpr* expr)
{
return visit(expr->expr);
}
NonStrictContext visit(AstStatLocal* local)
{
for (AstExpr* rhs : local->values)
visit(rhs);
return {};
}
NonStrictContext visit(AstStatFor* forStatement)
{
return {};
}
NonStrictContext visit(AstStatForIn* forInStatement)
{
return {};
}
NonStrictContext visit(AstStatAssign* assign)
{
return {};
}
NonStrictContext visit(AstStatCompoundAssign* compoundAssign)
{
return {};
}
NonStrictContext visit(AstStatFunction* statFn)
{
return visit(statFn->func);
}
NonStrictContext visit(AstStatLocalFunction* localFn)
{
return visit(localFn->func);
}
NonStrictContext visit(AstStatTypeAlias* typeAlias)
{
return {};
}
NonStrictContext visit(AstStatDeclareFunction* declFn)
{
return {};
}
NonStrictContext visit(AstStatDeclareGlobal* declGlobal)
{
return {};
}
NonStrictContext visit(AstStatDeclareClass* declClass)
{
return {};
}
NonStrictContext visit(AstStatError* error)
{
return {};
}
NonStrictContext visit(AstExpr* expr)
{
auto pusher = pushStack(expr);
if (auto e = expr->as<AstExprGroup>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprConstantNil>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprConstantBool>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprConstantNumber>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprConstantString>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprLocal>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprGlobal>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprVarargs>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprCall>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprIndexName>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprIndexExpr>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprFunction>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprTable>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprUnary>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprBinary>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprTypeAssertion>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprIfElse>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprInterpString>())
return visit(e, context);
return visit(e);
else if (auto e = expr->as<AstExprError>())
return visit(e, context);
return visit(e);
else
{
LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown expression type");
ice->ice("NonStrictTypeChecker encountered an unknown expression type");
}
}
void visit(AstExprGroup* group, NonStrictContext& context) {}
void visit(AstExprConstantNil* expr, NonStrictContext& context) {}
void visit(AstExprConstantBool* expr, NonStrictContext& context) {}
void visit(AstExprConstantNumber* expr, NonStrictContext& context) {}
void visit(AstExprConstantString* expr, NonStrictContext& context) {}
void visit(AstExprLocal* local, NonStrictContext& context) {}
void visit(AstExprGlobal* global, NonStrictContext& context) {}
void visit(AstExprVarargs* global, NonStrictContext& context) {}
void visit(AstExprCall* call, NonStrictContext& context)
NonStrictContext visit(AstExprGroup* group)
{
return {};
}
NonStrictContext visit(AstExprConstantNil* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantBool* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantNumber* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantString* expr)
{
return {};
}
NonStrictContext visit(AstExprLocal* local)
{
return {};
}
NonStrictContext visit(AstExprGlobal* global)
{
return {};
}
NonStrictContext visit(AstExprVarargs* global)
{
return {};
}
NonStrictContext visit(AstExprCall* call)
{
NonStrictContext fresh{};
TypeId* originalCallTy = module->astOriginalCallTypes.find(call);
if (!originalCallTy)
return;
return fresh;
TypeId fnTy = *originalCallTy;
// TODO: how should we link this to the passed in context here
NonStrictContext fresh{};
if (auto fn = get<FunctionType>(follow(fnTy)))
{
if (fn->isCheckedFunction)
@ -353,9 +522,7 @@ struct NonStrictTypeChecker
AstExpr* arg = call->args.data[i];
TypeId expectedArgType = argTypes[i];
DefId def = dfg->getDef(arg);
// TODO: Cache negations created here!!!
// See Jira Ticket: https://roblox.atlassian.net/browse/CLI-87539
TypeId runTimeErrorTy = arena.addType(NegationType{expectedArgType});
TypeId runTimeErrorTy = getOrCreateNegation(expectedArgType);
fresh.context[def.get()] = runTimeErrorTy;
}
@ -369,21 +536,72 @@ struct NonStrictTypeChecker
}
}
}
return fresh;
}
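To make the checked-function path concrete, a hypothetical call and the context it generates, assuming math.abs is declared @checked:

--!nonstrict
math.abs(x)   -- the expected argument type is number, so the generated
              -- context maps def(x) to ~number: the set of values for which
              -- this call is guaranteed to raise a runtime error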
void visit(AstExprIndexName* indexName, NonStrictContext& context) {}
void visit(AstExprIndexExpr* indexExpr, NonStrictContext& context) {}
void visit(AstExprFunction* exprFn, NonStrictContext& context)
NonStrictContext visit(AstExprIndexName* indexName)
{
auto pusher = pushStack(exprFn);
return {};
}
NonStrictContext visit(AstExprIndexExpr* indexExpr)
{
return {};
}
NonStrictContext visit(AstExprFunction* exprFn)
{
// TODO: should a function being used as an expression generate a context without the arguments?
auto pusher = pushStack(exprFn);
NonStrictContext remainder = visit(exprFn->body);
for (AstLocal* local : exprFn->args)
{
if (std::optional<TypeId> ty = willRunTimeErrorFunctionDefinition(local, remainder))
reportError(NonStrictFunctionDefinitionError{exprFn->debugname.value, local->name.value, *ty}, local->location);
remainder.remove(dfg->getDef(local));
}
return remainder;
}
NonStrictContext visit(AstExprTable* table)
{
return {};
}
NonStrictContext visit(AstExprUnary* unary)
{
return {};
}
NonStrictContext visit(AstExprBinary* binary)
{
return {};
}
NonStrictContext visit(AstExprTypeAssertion* typeAssertion)
{
return {};
}
NonStrictContext visit(AstExprIfElse* ifElse)
{
NonStrictContext condB = visit(ifElse->condition);
NonStrictContext thenB = visit(ifElse->trueExpr);
NonStrictContext elseB = visit(ifElse->falseExpr);
return NonStrictContext::disjunction(
builtinTypes, NotNull{&arena}, condB, NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenB, elseB));
}
NonStrictContext visit(AstExprInterpString* interpString)
{
return {};
}
NonStrictContext visit(AstExprError* error)
{
return {};
}
void visit(AstExprTable* table, NonStrictContext& context) {}
void visit(AstExprUnary* unary, NonStrictContext& context) {}
void visit(AstExprBinary* binary, NonStrictContext& context) {}
void visit(AstExprTypeAssertion* typeAssertion, NonStrictContext& context) {}
void visit(AstExprIfElse* ifElse, NonStrictContext& context) {}
void visit(AstExprInterpString* interpString, NonStrictContext& context) {}
void visit(AstExprError* error, NonStrictContext& context) {}
void reportError(TypeErrorData data, const Location& location)
{
@ -402,16 +620,37 @@ struct NonStrictTypeChecker
SubtypingResult r = subtyping.isSubtype(actualType, *contextTy);
if (r.normalizationTooComplex)
reportError(NormalizationTooComplex{}, fragment->location);
if (!r.isSubtype && !r.isErrorSuppressing)
reportError(TypeMismatch{actualType, *contextTy}, fragment->location);
if (r.isSubtype)
return {actualType};
}
return {};
}
std::optional<TypeId> willRunTimeErrorFunctionDefinition(AstLocal* fragment, const NonStrictContext& context)
{
DefId def = dfg->getDef(fragment);
if (std::optional<TypeId> contextTy = context.find(def))
{
SubtypingResult r1 = subtyping.isSubtype(builtinTypes->unknownType, *contextTy);
SubtypingResult r2 = subtyping.isSubtype(*contextTy, builtinTypes->unknownType);
if (r1.normalizationTooComplex || r2.normalizationTooComplex)
reportError(NormalizationTooComplex{}, fragment->location);
bool isUnknown = r1.isSubtype && r2.isSubtype;
if (isUnknown)
return {builtinTypes->unknownType};
}
return {};
}
private:
TypeId getOrCreateNegation(TypeId baseType)
{
TypeId& cachedResult = cachedNegations[baseType];
if (!cachedResult)
cachedResult = arena.addType(NegationType{baseType});
return cachedResult;
};
};
void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,

View file

@ -8,7 +8,10 @@
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Set.h"
#include "Luau/Subtyping.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/Unifier.h"
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
@ -16,9 +19,10 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
// This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCyclicUnions, false);
LUAU_FASTFLAG(LuauTransitiveSubtyping)
LUAU_FASTFLAG(DebugLuauReadWriteProperties)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauBufferTypeck)
namespace Luau
{
@ -32,9 +36,14 @@ TypeIds::TypeIds(std::initializer_list<TypeId> tys)
void TypeIds::insert(TypeId ty)
{
ty = follow(ty);
auto [_, fresh] = types.insert(ty);
if (fresh)
// get a reference to the slot for `ty` in `types`
bool& entry = types[ty];
// if `ty` is fresh, we can set it to `true`, add it to the order and hash and be done.
if (!entry)
{
entry = true;
order.push_back(ty);
hash ^= std::hash<TypeId>{}(ty);
}
@ -75,25 +84,26 @@ TypeIds::const_iterator TypeIds::end() const
TypeIds::iterator TypeIds::erase(TypeIds::const_iterator it)
{
TypeId ty = *it;
types.erase(ty);
types[ty] = false;
hash ^= std::hash<TypeId>{}(ty);
return order.erase(it);
}
size_t TypeIds::size() const
{
return types.size();
return order.size();
}
bool TypeIds::empty() const
{
return types.empty();
return order.empty();
}
size_t TypeIds::count(TypeId ty) const
{
ty = follow(ty);
return types.count(ty);
const bool* val = types.find(ty);
return (val && *val) ? 1 : 0;
}
void TypeIds::retain(const TypeIds& there)
@ -122,7 +132,29 @@ bool TypeIds::isNever() const
bool TypeIds::operator==(const TypeIds& there) const
{
return hash == there.hash && types == there.types;
// we can early return if the hashes don't match.
if (hash != there.hash)
return false;
// otherwise, we have to check equality of the sets themselves.
// if the sets have unequal sizes, they cannot possibly be equal.
// it is important to use `order` here and not `types`: the maps can differ
// in size because entries are never removed from `types`; erase simply
// writes `false` into the map.
if (order.size() != there.order.size())
return false;
// otherwise, we'll need to check that every element we have here is in `there`.
for (auto ty : order)
{
// if it's not, we'll return `false`
if (there.count(ty) == 0)
return false;
}
// otherwise, we've proven the two equal!
return true;
}
NormalizedStringType::NormalizedStringType() {}
@ -237,19 +269,56 @@ NormalizedType::NormalizedType(NotNull<BuiltinTypes> builtinTypes)
, numbers(builtinTypes->neverType)
, strings{NormalizedStringType::never}
, threads(builtinTypes->neverType)
, buffers(builtinTypes->neverType)
{
}
bool NormalizedType::isUnknown() const
{
if (get<UnknownType>(tops))
return true;
// Otherwise, we can still be unknown!
bool hasAllPrimitives = isPrim(booleans, PrimitiveType::Boolean) && isPrim(nils, PrimitiveType::NilType) && isNumber(numbers) &&
strings.isString() && isPrim(threads, PrimitiveType::Thread) && isThread(threads);
// Check for the top class type
bool isTopClass = false;
for (auto [t, disj] : classes.classes)
{
if (auto ct = get<ClassType>(t))
{
if (ct->name == "class" && disj.empty())
{
isTopClass = true;
break;
}
}
}
// Check for the top table type
bool isTopTable = false;
for (auto t : tables)
{
if (isPrim(t, PrimitiveType::Table))
{
isTopTable = true;
break;
}
}
// any = unknown or error ==> we need to make sure we have all the unknown components, but not errors
return get<NeverType>(errors) && hasAllPrimitives && isTopClass && isTopTable && functions.isTop;
}
bool NormalizedType::isExactlyNumber() const
{
return hasNumbers() && !hasTops() && !hasBooleans() && !hasClasses() && !hasErrors() && !hasNils() && !hasStrings() && !hasThreads() &&
!hasTables() && !hasFunctions() && !hasTyvars();
(!FFlag::LuauBufferTypeck || !hasBuffers()) && !hasTables() && !hasFunctions() && !hasTyvars();
}
bool NormalizedType::isSubtypeOfString() const
{
return hasStrings() && !hasTops() && !hasBooleans() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasThreads() &&
!hasTables() && !hasFunctions() && !hasTyvars();
(!FFlag::LuauBufferTypeck || !hasBuffers()) && !hasTables() && !hasFunctions() && !hasTyvars();
}
bool NormalizedType::shouldSuppressErrors() const
@ -306,6 +375,12 @@ bool NormalizedType::hasThreads() const
return !get<NeverType>(threads);
}
bool NormalizedType::hasBuffers() const
{
LUAU_ASSERT(FFlag::LuauBufferTypeck);
return !get<NeverType>(buffers);
}
bool NormalizedType::hasTables() const
{
return !tables.isNever();
@ -326,18 +401,18 @@ static bool isShallowInhabited(const NormalizedType& norm)
// This test is just a shallow check, for example it returns `true` for `{ p : never }`
return !get<NeverType>(norm.tops) || !get<NeverType>(norm.booleans) || !norm.classes.isNever() || !get<NeverType>(norm.errors) ||
!get<NeverType>(norm.nils) || !get<NeverType>(norm.numbers) || !norm.strings.isNever() || !get<NeverType>(norm.threads) ||
!norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty();
(FFlag::LuauBufferTypeck && !get<NeverType>(norm.buffers)) || !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty();
}
bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set<TypeId> seen)
bool Normalizer::isInhabited(const NormalizedType* norm, Set<TypeId> seen)
{
// If normalization failed, the type is complex, and so is more likely than not to be inhabited.
if (!norm)
return true;
if (!get<NeverType>(norm->tops) || !get<NeverType>(norm->booleans) || !get<NeverType>(norm->errors) || !get<NeverType>(norm->nils) ||
!get<NeverType>(norm->numbers) || !get<NeverType>(norm->threads) || !norm->classes.isNever() || !norm->strings.isNever() ||
!norm->functions.isNever())
!get<NeverType>(norm->numbers) || !get<NeverType>(norm->threads) || (FFlag::LuauBufferTypeck && !get<NeverType>(norm->buffers)) ||
!norm->classes.isNever() || !norm->strings.isNever() || !norm->functions.isNever())
return true;
for (const auto& [_, intersect] : norm->tyvars)
@ -363,7 +438,7 @@ bool Normalizer::isInhabited(TypeId ty)
return *result;
}
bool result = isInhabited(ty, {});
bool result = isInhabited(ty, {nullptr});
if (cacheInhabitance)
cachedIsInhabited[ty] = result;
@ -371,7 +446,7 @@ bool Normalizer::isInhabited(TypeId ty)
return result;
}
bool Normalizer::isInhabited(TypeId ty, std::unordered_set<TypeId> seen)
bool Normalizer::isInhabited(TypeId ty, Set<TypeId> seen)
{
// TODO: use log.follow(ty), CLI-64291
ty = follow(ty);
@ -425,7 +500,7 @@ bool Normalizer::isIntersectionInhabited(TypeId left, TypeId right)
return *result;
}
std::unordered_set<TypeId> seen = {};
Set<TypeId> seen{nullptr};
seen.insert(left);
seen.insert(right);
@ -561,6 +636,18 @@ static bool isNormalizedThread(TypeId ty)
return false;
}
static bool isNormalizedBuffer(TypeId ty)
{
LUAU_ASSERT(FFlag::LuauBufferTypeck);
if (get<NeverType>(ty))
return true;
else if (const PrimitiveType* ptv = get<PrimitiveType>(ty))
return ptv->type == PrimitiveType::Buffer;
else
return false;
}
static bool areNormalizedFunctions(const NormalizedFunctionType& tys)
{
for (TypeId ty : tys.parts)
@ -647,8 +734,7 @@ static bool areNormalizedClasses(const NormalizedClassType& tys)
static bool isPlainTyvar(TypeId ty)
{
return (get<FreeType>(ty) || get<GenericType>(ty) || get<BlockedType>(ty) ||
get<PendingExpansionType>(ty) || get<TypeFamilyInstanceType>(ty));
return (get<FreeType>(ty) || get<GenericType>(ty) || get<BlockedType>(ty) || get<PendingExpansionType>(ty) || get<TypeFamilyInstanceType>(ty));
}
static bool isNormalizedTyvar(const NormalizedTyvars& tyvars)
@ -682,6 +768,8 @@ static void assertInvariant(const NormalizedType& norm)
LUAU_ASSERT(isNormalizedNumber(norm.numbers));
LUAU_ASSERT(isNormalizedString(norm.strings));
LUAU_ASSERT(isNormalizedThread(norm.threads));
if (FFlag::LuauBufferTypeck)
LUAU_ASSERT(isNormalizedBuffer(norm.buffers));
LUAU_ASSERT(areNormalizedFunctions(norm.functions));
LUAU_ASSERT(areNormalizedTables(norm.tables));
LUAU_ASSERT(isNormalizedTyvar(norm.tyvars));
@ -708,9 +796,14 @@ const NormalizedType* Normalizer::normalize(TypeId ty)
return found->second.get();
NormalizedType norm{builtinTypes};
std::unordered_set<TypeId> seenSetTypes;
Set<TypeId> seenSetTypes{nullptr};
if (!unionNormalWithTy(norm, ty, seenSetTypes))
return nullptr;
if (norm.isUnknown())
{
clearNormal(norm);
norm.tops = builtinTypes->unknownType;
}
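// e.g. (illustrative) a union that spells out every component of `unknown`
// now normalizes to the single type `unknown` instead of the expanded union.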
std::unique_ptr<NormalizedType> uniq = std::make_unique<NormalizedType>(std::move(norm));
const NormalizedType* result = uniq.get();
cachedNormals[ty] = std::move(uniq);
@ -724,7 +817,7 @@ bool Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections
NormalizedType norm{builtinTypes};
norm.tops = builtinTypes->anyType;
// Now we need to intersect the two types
std::unordered_set<TypeId> seenSetTypes;
Set<TypeId> seenSetTypes{nullptr};
for (auto ty : intersections)
{
if (!intersectNormalWithTy(norm, ty, seenSetTypes))
@ -747,6 +840,8 @@ void Normalizer::clearNormal(NormalizedType& norm)
norm.numbers = builtinTypes->neverType;
norm.strings.resetToNever();
norm.threads = builtinTypes->neverType;
if (FFlag::LuauBufferTypeck)
norm.buffers = builtinTypes->neverType;
norm.tables.clear();
norm.functions.resetToNever();
norm.tyvars.clear();
@ -1432,6 +1527,8 @@ bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there,
here.numbers = (get<NeverType>(there.numbers) ? here.numbers : there.numbers);
unionStrings(here.strings, there.strings);
here.threads = (get<NeverType>(there.threads) ? here.threads : there.threads);
if (FFlag::LuauBufferTypeck)
here.buffers = (get<NeverType>(there.buffers) ? here.buffers : there.buffers);
unionFunctions(here.functions, there.functions);
unionTables(here.tables, there.tables);
return true;
@ -1460,7 +1557,7 @@ bool Normalizer::withinResourceLimits()
}
// See above for an explanation of `ignoreSmallerTyvars`.
bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes, int ignoreSmallerTyvars)
bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars)
{
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
@ -1488,12 +1585,9 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, std::unor
}
else if (const UnionType* utv = get<UnionType>(there))
{
if (FFlag::LuauNormalizeCyclicUnions)
{
if (seenSetTypes.count(there))
return true;
seenSetTypes.insert(there);
}
if (seenSetTypes.count(there))
return true;
seenSetTypes.insert(there);
for (UnionTypeIterator it = begin(utv); it != end(utv); ++it)
{
@ -1520,8 +1614,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, std::unor
}
else if (FFlag::LuauTransitiveSubtyping && get<UnknownType>(here.tops))
return true;
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) ||
get<PendingExpansionType>(there) || get<TypeFamilyInstanceType>(there))
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFamilyInstanceType>(there))
{
if (tyvarIndex(there) <= ignoreSmallerTyvars)
return true;
@ -1529,6 +1623,12 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, std::unor
inter.tops = builtinTypes->unknownType;
here.tyvars.insert_or_assign(there, std::make_unique<NormalizedType>(std::move(inter)));
}
else if (auto lt = get<LocalType>(there))
{
// FIXME? This is somewhat questionable.
// Maybe we should assert because this should never happen?
unionNormalWithTy(here, lt->domain, seenSetTypes, ignoreSmallerTyvars);
}
else if (get<FunctionType>(there))
unionFunctionsWithFunction(here.functions, there);
else if (get<TableType>(there) || get<MetatableType>(there))
@ -1549,6 +1649,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, std::unor
here.strings.resetToString();
else if (ptv->type == PrimitiveType::Thread)
here.threads = there;
else if (FFlag::LuauBufferTypeck && ptv->type == PrimitiveType::Buffer)
here.buffers = there;
else if (ptv->type == PrimitiveType::Function)
{
here.functions.resetToTop();
@ -1668,6 +1770,8 @@ std::optional<NormalizedType> Normalizer::negateNormal(const NormalizedType& her
result.strings.isCofinite = !result.strings.isCofinite;
result.threads = get<NeverType>(here.threads) ? builtinTypes->threadType : builtinTypes->neverType;
if (FFlag::LuauBufferTypeck)
result.buffers = get<NeverType>(here.buffers) ? builtinTypes->bufferType : builtinTypes->neverType;
/*
* Things get weird and so, so complicated if we allow negations of
@ -1757,6 +1861,10 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty)
case PrimitiveType::Thread:
here.threads = builtinTypes->neverType;
break;
case PrimitiveType::Buffer:
if (FFlag::LuauBufferTypeck)
here.buffers = builtinTypes->neverType;
break;
case PrimitiveType::Function:
here.functions.resetToNever();
break;
@ -2550,7 +2658,7 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali
}
}
bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes)
bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes)
{
for (auto it = here.begin(); it != here.end();)
{
@ -2587,6 +2695,8 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th
here.numbers = (get<NeverType>(there.numbers) ? there.numbers : here.numbers);
intersectStrings(here.strings, there.strings);
here.threads = (get<NeverType>(there.threads) ? there.threads : here.threads);
if (FFlag::LuauBufferTypeck)
here.buffers = (get<NeverType>(there.buffers) ? there.buffers : here.buffers);
intersectFunctions(here.functions, there.functions);
intersectTables(here.tables, there.tables);
@ -2628,7 +2738,7 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th
return true;
}
bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes)
bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes)
{
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
@ -2661,8 +2771,8 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, std::
return false;
return true;
}
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) ||
get<PendingExpansionType>(there) || get<TypeFamilyInstanceType>(there))
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFamilyInstanceType>(there) || get<LocalType>(there))
{
NormalizedType thereNorm{builtinTypes};
NormalizedType topNorm{builtinTypes};
@ -2670,6 +2780,10 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, std::
thereNorm.tyvars.insert_or_assign(there, std::make_unique<NormalizedType>(std::move(topNorm)));
return intersectNormals(here, thereNorm);
}
else if (auto lt = get<LocalType>(there))
{
return intersectNormalWithTy(here, lt->domain, seenSetTypes);
}
NormalizedTyvars tyvars = std::move(here.tyvars);
@ -2708,6 +2822,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, std::
NormalizedStringType strings = std::move(here.strings);
NormalizedFunctionType functions = std::move(here.functions);
TypeId threads = here.threads;
TypeId buffers = here.buffers;
TypeIds tables = std::move(here.tables);
clearNormal(here);
@ -2722,6 +2837,8 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, std::
here.strings = std::move(strings);
else if (ptv->type == PrimitiveType::Thread)
here.threads = threads;
else if (FFlag::LuauBufferTypeck && ptv->type == PrimitiveType::Buffer)
here.buffers = buffers;
else if (ptv->type == PrimitiveType::Function)
here.functions = std::move(functions);
else if (ptv->type == PrimitiveType::Table)
@ -2892,6 +3009,8 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
}
if (!get<NeverType>(norm.threads))
result.push_back(builtinTypes->threadType);
if (FFlag::LuauBufferTypeck && !get<NeverType>(norm.buffers))
result.push_back(builtinTypes->bufferType);
result.insert(result.end(), norm.tables.begin(), norm.tables.end());
for (auto& [tyvar, intersect] : norm.tyvars)
@ -2915,32 +3034,58 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
{
if (!FFlag::LuauTransitiveSubtyping)
if (!FFlag::LuauTransitiveSubtyping && !FFlag::DebugLuauDeferredConstraintResolution)
return isConsistentSubtype(subTy, superTy, scope, builtinTypes, ice);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subTy, superTy);
return !u.failure;
// Subtyping under DCR is not implemented using unification!
if (FFlag::DebugLuauDeferredConstraintResolution)
{
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}, scope};
return subtyping.isSubtype(subTy, superTy).isSubtype;
}
else
{
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subTy, superTy);
return !u.failure;
}
}
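// Hypothetical usage sketch of the API above (variable names assumed):
//   InternalErrorReporter ice;
//   bool ok = isSubtype(numberTy, unknownTy, NotNull{scope.get()}, builtinTypes, ice);
// With DebugLuauDeferredConstraintResolution set, this now routes through
// Subtyping::isSubtype; otherwise it falls back to Unifier::tryUnify.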
bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
{
if (!FFlag::LuauTransitiveSubtyping)
if (!FFlag::LuauTransitiveSubtyping && !FFlag::DebugLuauDeferredConstraintResolution)
return isConsistentSubtype(subPack, superPack, scope, builtinTypes, ice);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subPack, superPack);
return !u.failure;
// Subtyping under DCR is not implemented using unification!
if (FFlag::DebugLuauDeferredConstraintResolution)
{
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}, scope};
return subtyping.isSubtype(subPack, superPack).isSubtype;
}
else
{
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subPack, superPack);
return !u.failure;
}
}
bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
{
LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
@ -2954,6 +3099,8 @@ bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, Not
bool isConsistentSubtype(
TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
{
LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};

View file

@ -72,7 +72,7 @@ std::optional<std::pair<Binding*, Scope*>> Scope::lookupEx(Symbol sym)
}
}
std::optional<TypeId> Scope::lookupLValue(DefId def) const
std::optional<TypeId> Scope::lookupUnrefinedType(DefId def) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
@ -83,7 +83,6 @@ std::optional<TypeId> Scope::lookupLValue(DefId def) const
return std::nullopt;
}
// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis.
std::optional<TypeId> Scope::lookup(DefId def) const
{
for (const Scope* current = this; current; current = current->parent.get())
@ -181,6 +180,16 @@ std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bo
return std::nullopt;
}
// Updates the `this` scope with the assignments from the `childScope`, including ones that don't exist in `this`.
void Scope::inheritAssignments(const ScopePtr& childScope)
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
for (const auto& [k, a] : childScope->lvalueTypes)
lvalueTypes[k] = a;
}
// Updates the `this` scope with the refinements from the `childScope`, excluding ones that don't exist in `this`.
void Scope::inheritRefinements(const ScopePtr& childScope)
{

View file

@ -2,6 +2,7 @@
#include "Luau/Simplify.h"
#include "Luau/DenseHash.h"
#include "Luau/Normalize.h" // TypeIds
#include "Luau/RecursionCounter.h"
#include "Luau/ToString.h"
@ -21,7 +22,7 @@ struct TypeSimplifier
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
std::set<TypeId> blockedTypes;
DenseHashSet<TypeId> blockedTypes{nullptr};
int recursionDepth = 0;

View file

@ -19,8 +19,12 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
auto go = [ty, &dest, alwaysClone](auto&& a) {
using T = std::decay_t<decltype(a)>;
// The pointer identities of free and local types are very important.
// We decline to copy them.
if constexpr (std::is_same_v<T, FreeType>)
return ty;
else if constexpr (std::is_same_v<T, LocalType>)
return ty;
else if constexpr (std::is_same_v<T, BoundType>)
{
// This should never happen, but visit() cannot see it.

View file

@ -16,6 +16,8 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity, false);
namespace Luau
{
@ -47,16 +49,77 @@ struct VarianceFlipper
bool SubtypingReasoning::operator==(const SubtypingReasoning& other) const
{
return subPath == other.subPath && superPath == other.superPath;
return subPath == other.subPath && superPath == other.superPath && variance == other.variance;
}
size_t SubtypingReasoningHash::operator()(const SubtypingReasoning& r) const
{
return TypePath::PathHash()(r.subPath) ^ (TypePath::PathHash()(r.superPath) << 1) ^ (static_cast<size_t>(r.variance) << 1);
}
template<typename TID>
static void assertReasoningValid(TID subTy, TID superTy, const SubtypingResult& result, NotNull<BuiltinTypes> builtinTypes)
{
if (!FFlag::DebugLuauSubtypingCheckPathValidity)
return;
for (const SubtypingReasoning& reasoning : result.reasoning)
{
LUAU_ASSERT(traverse(subTy, reasoning.subPath, builtinTypes));
LUAU_ASSERT(traverse(superTy, reasoning.superPath, builtinTypes));
}
}
template<>
void assertReasoningValid<TableIndexer>(TableIndexer subIdx, TableIndexer superIdx, const SubtypingResult& result, NotNull<BuiltinTypes> builtinTypes)
{
// Empty method to satisfy the compiler.
}
static SubtypingReasonings mergeReasonings(const SubtypingReasonings& a, const SubtypingReasonings& b)
{
SubtypingReasonings result{kEmptyReasoning};
for (const SubtypingReasoning& r : a)
{
if (r.variance == SubtypingVariance::Invariant)
result.insert(r);
else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant)
{
SubtypingReasoning inverseReasoning = SubtypingReasoning{
r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant};
if (b.contains(inverseReasoning))
result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant});
else
result.insert(r);
}
}
for (const SubtypingReasoning& r : b)
{
if (r.variance == SubtypingVariance::Invariant)
result.insert(r);
else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant)
{
SubtypingReasoning inverseReasoning = SubtypingReasoning{
r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant};
if (a.contains(inverseReasoning))
result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant});
else
result.insert(r);
}
}
return result;
}
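// Illustrative merge (sketch): if `a` holds a covariant reasoning for some
// (subPath, superPath) pair and `b` holds the contravariant reasoning for
// the same pair, the result records that pair once as invariant; entries
// with no flipped-variance mirror in the other set are carried over as-is.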
SubtypingResult& SubtypingResult::andAlso(const SubtypingResult& other)
{
// If this result is a subtype, we take the other result's reasoning. If
// this result is not a subtype, we keep the current reasoning, even if the
// other isn't a subtype.
if (isSubtype)
reasoning = other.reasoning;
// If the other result is not a subtype, we want to join all of its
// reasonings to this one. If this result already has reasonings of its own,
// those need to be attributed here.
if (!other.isSubtype)
reasoning = mergeReasonings(reasoning, other.reasoning);
isSubtype &= other.isSubtype;
// `|=` is intentional here, we want to preserve error related flags.
@ -69,10 +132,17 @@ SubtypingResult& SubtypingResult::andAlso(const SubtypingResult& other)
SubtypingResult& SubtypingResult::orElse(const SubtypingResult& other)
{
// If the other result is not a subtype, we take the other result's
// reasoning.
if (!other.isSubtype)
reasoning = other.reasoning;
// If this result is a subtype, we do not join the reasoning lists. If this
// result is not a subtype, but the other is a subtype, we want to _clear_
// our reasoning list. If both results are not subtypes, we join the
// reasoning lists.
if (!isSubtype)
{
if (other.isSubtype)
reasoning.clear();
else
reasoning = mergeReasonings(reasoning, other.reasoning);
}
isSubtype |= other.isSubtype;
isErrorSuppressing |= other.isErrorSuppressing;
@ -89,20 +159,26 @@ SubtypingResult& SubtypingResult::withBothComponent(TypePath::Component componen
SubtypingResult& SubtypingResult::withSubComponent(TypePath::Component component)
{
if (!reasoning)
reasoning = SubtypingReasoning{Path(), Path()};
reasoning->subPath = reasoning->subPath.push_front(component);
if (reasoning.empty())
reasoning.insert(SubtypingReasoning{Path(component), TypePath::kEmpty});
else
{
for (auto& r : reasoning)
r.subPath = r.subPath.push_front(component);
}
return *this;
}
SubtypingResult& SubtypingResult::withSuperComponent(TypePath::Component component)
{
if (!reasoning)
reasoning = SubtypingReasoning{Path(), Path()};
reasoning->superPath = reasoning->superPath.push_front(component);
if (reasoning.empty())
reasoning.insert(SubtypingReasoning{TypePath::kEmpty, Path(component)});
else
{
for (auto& r : reasoning)
r.superPath = r.superPath.push_front(component);
}
return *this;
}
@ -114,20 +190,26 @@ SubtypingResult& SubtypingResult::withBothPath(TypePath::Path path)
SubtypingResult& SubtypingResult::withSubPath(TypePath::Path path)
{
if (!reasoning)
reasoning = SubtypingReasoning{Path(), Path()};
reasoning->subPath = path.append(reasoning->subPath);
if (reasoning.empty())
reasoning.insert(SubtypingReasoning{path, TypePath::kEmpty});
else
{
for (auto& r : reasoning)
r.subPath = path.append(r.subPath);
}
return *this;
}
SubtypingResult& SubtypingResult::withSuperPath(TypePath::Path path)
{
if (!reasoning)
reasoning = SubtypingReasoning{Path(), Path()};
reasoning->superPath = path.append(reasoning->superPath);
if (reasoning.empty())
reasoning.insert(SubtypingReasoning{TypePath::kEmpty, path});
else
{
for (auto& r : reasoning)
r.superPath = path.append(r.superPath);
}
return *this;
}
@ -202,7 +284,10 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy)
result.isSubtype = false;
}
result.andAlso(isCovariantWith(env, lowerBound, upperBound));
SubtypingResult boundsResult = isCovariantWith(env, lowerBound, upperBound);
boundsResult.reasoning.clear();
result.andAlso(boundsResult);
}
/* TODO: We presently don't store subtype test results in the persistent
@ -281,7 +366,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
return {true};
std::pair<TypeId, TypeId> typePair{subTy, superTy};
if (!seenTypes.insert(typePair).second)
if (!seenTypes.insert(typePair))
{
/* TODO: Caching results for recursive types is really tricky to think
* about.
@ -321,14 +406,34 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
if (auto subUnion = get<UnionType>(subTy))
result = isCovariantWith(env, subUnion, superTy);
else if (auto superUnion = get<UnionType>(superTy))
{
result = isCovariantWith(env, subTy, superUnion);
if (!result.isSubtype && !result.isErrorSuppressing && !result.normalizationTooComplex)
{
SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy));
if (semantic.isSubtype)
{
semantic.reasoning.clear();
result = semantic;
}
}
}
else if (auto superIntersection = get<IntersectionType>(superTy))
result = isCovariantWith(env, subTy, superIntersection);
else if (auto subIntersection = get<IntersectionType>(subTy))
{
result = isCovariantWith(env, subIntersection, superTy);
if (!result.isSubtype && !result.isErrorSuppressing && !result.normalizationTooComplex)
result = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy));
{
SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy));
if (semantic.isSubtype)
{
// Clear the semantic reasoning, as any reasonings within
// potentially contain invalid paths.
semantic.reasoning.clear();
result = semantic;
}
}
}
else if (get<AnyType>(superTy))
result = {true};
@ -356,9 +461,31 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
else if (auto p = get2<NegationType, NegationType>(subTy, superTy))
result = isCovariantWith(env, p.first->ty, p.second->ty).withBothComponent(TypePath::TypeField::Negated);
else if (auto subNegation = get<NegationType>(subTy))
{
result = isCovariantWith(env, subNegation, superTy);
if (!result.isSubtype && !result.isErrorSuppressing && !result.normalizationTooComplex)
{
SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy));
if (semantic.isSubtype)
{
semantic.reasoning.clear();
result = semantic;
}
}
}
else if (auto superNegation = get<NegationType>(superTy))
{
result = isCovariantWith(env, subTy, superNegation);
if (!result.isSubtype && !result.isErrorSuppressing && !result.normalizationTooComplex)
{
SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy));
if (semantic.isSubtype)
{
semantic.reasoning.clear();
result = semantic;
}
}
}
else if (auto subGeneric = get<GenericType>(subTy); subGeneric && variance == Variance::Covariant)
{
bool ok = bindGeneric(env, subTy, superTy);
@ -394,6 +521,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
else if (auto p = get2<SingletonType, TableType>(subTy, superTy))
result = isCovariantWith(env, p);
assertReasoningValid(subTy, superTy, result, builtinTypes);
return cache(env, result, subTy, superTy);
}
@ -481,7 +610,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
for (size_t i = headSize; i < subHead.size(); ++i)
results.push_back(isCovariantWith(env, subHead[i], vt->ty)
.withSubComponent(TypePath::Index{i})
.withSuperComponent(TypePath::TypeField::Variadic));
.withSuperPath(TypePath::PathBuilder().tail().variadic().build()));
}
else if (auto gt = get<GenericTypePack>(*superTail))
{
@ -609,19 +738,38 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
iceReporter->ice("Subtyping test encountered the unexpected type pack: " + toString(*superTail));
}
return SubtypingResult::all(results);
SubtypingResult result = SubtypingResult::all(results);
assertReasoningValid(subTp, superTp, result, builtinTypes);
return result;
}
template<typename SubTy, typename SuperTy>
SubtypingResult Subtyping::isContravariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy)
{
SubtypingResult result = isCovariantWith(env, superTy, subTy);
// If we don't swap the paths here, we will end up producing an invalid path
// whenever we involve contravariance. We'll end up appending path
// components that should belong to the supertype to the subtype, and vice
// versa.
if (result.reasoning)
std::swap(result.reasoning->subPath, result.reasoning->superPath);
if (result.reasoning.empty())
result.reasoning.insert(SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Contravariant});
else
{
// If we don't swap the paths here, we will end up producing an invalid path
// whenever we involve contravariance. We'll end up appending path
// components that should belong to the supertype to the subtype, and vice
// versa.
for (auto& reasoning : result.reasoning)
{
std::swap(reasoning.subPath, reasoning.superPath);
// Also swap covariant/contravariant, since those are also the other way
// around.
if (reasoning.variance == SubtypingVariance::Covariant)
reasoning.variance = SubtypingVariance::Contravariant;
else if (reasoning.variance == SubtypingVariance::Contravariant)
reasoning.variance = SubtypingVariance::Covariant;
}
}
assertReasoningValid(subTy, superTy, result, builtinTypes);
return result;
}
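// Sketch of why the swap is needed: when testing
// (number) -> () <: (string) -> (), the argument check runs covariantly on
// the flipped pair, so a failing reasoning initially points its subPath at
// the supertype's argument and vice versa; swapping the paths (and the
// recorded variance) re-attaches each path to the type it actually
// describes (illustrative reading of the code above).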
@ -629,7 +777,17 @@ SubtypingResult Subtyping::isContravariantWith(SubtypingEnvironment& env, SubTy&
template<typename SubTy, typename SuperTy>
SubtypingResult Subtyping::isInvariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy)
{
return isCovariantWith(env, subTy, superTy).andAlso(isContravariantWith(env, subTy, superTy));
SubtypingResult result = isCovariantWith(env, subTy, superTy).andAlso(isContravariantWith(env, subTy, superTy));
if (result.reasoning.empty())
result.reasoning.insert(SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Invariant});
else
{
for (auto& reasoning : result.reasoning)
reasoning.variance = SubtypingVariance::Invariant;
}
assertReasoningValid(subTy, superTy, result, builtinTypes);
return result;
}
template<typename SubTy, typename SuperTy>
@ -641,13 +799,13 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TryP
template<typename SubTy, typename SuperTy>
SubtypingResult Subtyping::isContravariantWith(SubtypingEnvironment& env, const TryPair<const SubTy*, const SuperTy*>& pair)
{
return isCovariantWith(env, pair.second, pair.first);
return isContravariantWith(env, pair.first, pair.second);
}
template<typename SubTy, typename SuperTy>
SubtypingResult Subtyping::isInvariantWith(SubtypingEnvironment& env, const TryPair<const SubTy*, const SuperTy*>& pair)
{
return isCovariantWith(env, pair).andAlso(isContravariantWith(pair));
return isInvariantWith(env, pair.first, pair.second);
}
/*
@ -733,17 +891,17 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Nega
if (is<NeverType>(negatedTy))
{
// ¬never ~ unknown
result = isCovariantWith(env, builtinTypes->unknownType, superTy);
result = isCovariantWith(env, builtinTypes->unknownType, superTy).withSubComponent(TypePath::TypeField::Negated);
}
else if (is<UnknownType>(negatedTy))
{
// ¬unknown ~ never
result = isCovariantWith(env, builtinTypes->neverType, superTy);
result = isCovariantWith(env, builtinTypes->neverType, superTy).withSubComponent(TypePath::TypeField::Negated);
}
else if (is<AnyType>(negatedTy))
{
// ¬any ~ any
result = isCovariantWith(env, negatedTy, superTy);
result = isCovariantWith(env, negatedTy, superTy).withSubComponent(TypePath::TypeField::Negated);
}
else if (auto u = get<UnionType>(negatedTy))
{
@ -753,8 +911,13 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Nega
for (TypeId ty : u)
{
NegationType negatedTmp{ty};
subtypings.push_back(isCovariantWith(env, &negatedTmp, superTy));
if (auto negatedPart = get<NegationType>(follow(ty)))
subtypings.push_back(isCovariantWith(env, negatedPart->ty, superTy).withSubComponent(TypePath::TypeField::Negated));
else
{
NegationType negatedTmp{ty};
subtypings.push_back(isCovariantWith(env, &negatedTmp, superTy));
}
}
result = SubtypingResult::all(subtypings);
@ -768,7 +931,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Nega
for (TypeId ty : i)
{
if (auto negatedPart = get<NegationType>(follow(ty)))
subtypings.push_back(isCovariantWith(env, negatedPart->ty, superTy));
subtypings.push_back(isCovariantWith(env, negatedPart->ty, superTy).withSubComponent(TypePath::TypeField::Negated));
else
{
NegationType negatedTmp{ty};
@ -786,10 +949,10 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Nega
// subtype of other stuff.
else
{
result = {false};
result = SubtypingResult{false}.withSubComponent(TypePath::TypeField::Negated);
}
return result.withSubComponent(TypePath::TypeField::Negated);
return result;
}
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const NegationType* superNegation)
@ -830,7 +993,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Type
}
}
result = SubtypingResult::all(subtypings);
return SubtypingResult::all(subtypings);
}
else if (auto i = get<IntersectionType>(negatedTy))
{
@ -849,7 +1012,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Type
}
}
result = SubtypingResult::any(subtypings);
return SubtypingResult::any(subtypings);
}
else if (auto p = get2<PrimitiveType, PrimitiveType>(subTy, negatedTy))
{
@ -931,8 +1094,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl
{
std::vector<SubtypingResult> results;
if (auto it = subTable->props.find(name); it != subTable->props.end())
results.push_back(isInvariantWith(env, it->second.type(), prop.type())
.withBothComponent(TypePath::Property(name)));
results.push_back(isInvariantWith(env, it->second.type(), prop.type()).withBothComponent(TypePath::Property(name)));
if (subTable->indexer)
{
@ -967,7 +1129,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Meta
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const TableType* superTable)
{
if (auto subTable = get<TableType>(subMt->table))
if (auto subTable = get<TableType>(follow(subMt->table)))
{
// Metatables cannot erase properties from the table they're attached to, so
// the subtyping rule for this is just if the table component is a subtype
@ -1067,7 +1229,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl
{
return isInvariantWith(env, subIndexer.indexType, superIndexer.indexType)
.withBothComponent(TypePath::TypeField::IndexLookup)
.andAlso(isInvariantWith(env, superIndexer.indexResultType, subIndexer.indexResultType).withBothComponent(TypePath::TypeField::IndexResult));
.andAlso(isInvariantWith(env, subIndexer.indexResultType, superIndexer.indexResultType).withBothComponent(TypePath::TypeField::IndexResult));
}
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NormalizedType* subNorm, const NormalizedType* superNorm)
@ -1194,12 +1356,11 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Type
{
std::vector<SubtypingResult> results;
size_t i = 0;
for (TypeId subTy : subTypes)
{
results.emplace_back();
for (TypeId superTy : superTypes)
results.back().orElse(isCovariantWith(env, subTy, superTy).withBothComponent(TypePath::Index{i++}));
results.back().orElse(isCovariantWith(env, subTy, superTy));
}
return SubtypingResult::all(results);

View file

@ -261,6 +261,14 @@ void StateDot::visitChildren(TypeId ty, int index)
visitChild(t.upperBound, index, "[upperBound]");
}
}
else if constexpr (std::is_same_v<T, LocalType>)
{
formatAppend(result, "LocalType");
finishNodeLabel(ty);
finishNode();
visitChild(t.domain, 1, "[domain]");
}
else if constexpr (std::is_same_v<T, AnyType>)
{
formatAppend(result, "AnyType %d", index);

View file

@ -100,6 +100,16 @@ struct FindCyclicTypes final : TypeVisitor
return false;
}
bool visit(TypeId ty, const LocalType& lt) override
{
if (!visited.insert(ty).second)
return false;
traverse(lt.domain);
return false;
}
bool visit(TypeId ty, const TableType& ttv) override
{
if (!visited.insert(ty).second)
@ -500,6 +510,15 @@ struct TypeStringifier
}
}
void operator()(TypeId ty, const LocalType& lt)
{
state.emit("l-");
state.emit(lt.name);
state.emit("=[");
stringify(lt.domain);
state.emit("]");
}
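// e.g. (illustrative) a LocalType named "x" with domain `number` prints as
// "l-x=[number]".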
void operator()(TypeId, const BoundType& btv)
{
stringify(btv.boundTo);
@ -562,6 +581,9 @@ struct TypeStringifier
case PrimitiveType::Thread:
state.emit("thread");
return;
case PrimitiveType::Buffer:
state.emit("buffer");
return;
case PrimitiveType::Function:
state.emit("function");
return;
@ -1699,7 +1721,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
std::string iteratorStr = tos(c.iterator);
std::string variableStr = tos(c.variables);
return variableStr + " ~ Iterate<" + iteratorStr + ">";
return variableStr + " ~ iterate " + iteratorStr;
}
else if constexpr (std::is_same_v<T, NameConstraint>)
{

View file

@ -10,7 +10,6 @@
#include <limits>
#include <math.h>
LUAU_FASTFLAG(LuauFloorDivision)
namespace
{
@ -474,8 +473,6 @@ struct Printer
case AstExprBinary::Pow:
case AstExprBinary::CompareLt:
case AstExprBinary::CompareGt:
LUAU_ASSERT(FFlag::LuauFloorDivision || a->op != AstExprBinary::FloorDiv);
writer.maybeSpace(a->right->location.begin, 2);
writer.symbol(toString(a->op));
break;
@ -761,8 +758,6 @@ struct Printer
writer.symbol("/=");
break;
case AstExprBinary::FloorDiv:
LUAU_ASSERT(FFlag::LuauFloorDivision);
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("//=");
break;

View file

@ -27,6 +27,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(DebugLuauReadWriteProperties)
LUAU_FASTFLAGVARIABLE(LuauInitializeStringMetatableInGlobalTypes, false)
LUAU_FASTFLAG(LuauBufferTypeck)
namespace Luau
{
@ -214,6 +215,13 @@ bool isThread(TypeId ty)
return isPrim(ty, PrimitiveType::Thread);
}
bool isBuffer(TypeId ty)
{
LUAU_ASSERT(FFlag::LuauBufferTypeck);
return isPrim(ty, PrimitiveType::Buffer);
}
bool isOptional(TypeId ty)
{
if (isNil(ty))
@ -604,10 +612,11 @@ FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector<TypeId> ge
Property::Property() {}
Property::Property(TypeId readTy, bool deprecated, const std::string& deprecatedSuggestion, std::optional<Location> location, const Tags& tags,
const std::optional<std::string>& documentationSymbol)
const std::optional<std::string>& documentationSymbol, std::optional<Location> typeLocation)
: deprecated(deprecated)
, deprecatedSuggestion(deprecatedSuggestion)
, location(location)
, typeLocation(typeLocation)
, tags(tags)
, documentationSymbol(documentationSymbol)
, readTy(readTy)
@ -925,6 +934,7 @@ BuiltinTypes::BuiltinTypes()
, stringType(arena->addType(Type{PrimitiveType{PrimitiveType::String}, /*persistent*/ true}))
, booleanType(arena->addType(Type{PrimitiveType{PrimitiveType::Boolean}, /*persistent*/ true}))
, threadType(arena->addType(Type{PrimitiveType{PrimitiveType::Thread}, /*persistent*/ true}))
, bufferType(arena->addType(Type{PrimitiveType{PrimitiveType::Buffer}, /*persistent*/ true}))
, functionType(arena->addType(Type{PrimitiveType{PrimitiveType::Function}, /*persistent*/ true}))
, classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}}, /*persistent*/ true}))
, tableType(arena->addType(Type{PrimitiveType{PrimitiveType::Table}, /*persistent*/ true}))

View file

@ -13,8 +13,6 @@
#include <string>
LUAU_FASTFLAG(LuauParseDeclareClassIndexer);
static char* allocateString(Luau::Allocator& allocator, std::string_view contents)
{
char* result = (char*)allocator.allocate(contents.size() + 1);
@ -106,6 +104,8 @@ public:
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("string"), std::nullopt, Location());
case PrimitiveType::Thread:
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("thread"), std::nullopt, Location());
case PrimitiveType::Buffer:
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("buffer"), std::nullopt, Location());
default:
return nullptr;
}
@ -230,7 +230,7 @@ public:
}
AstTableIndexer* indexer = nullptr;
if (FFlag::LuauParseDeclareClassIndexer && ctv.indexer)
if (ctv.indexer)
{
RecursionCounter counter(&count);
@ -329,10 +329,14 @@ public:
{
return Luau::visit(*this, bound.boundTo->ty);
}
AstType* operator()(const FreeType& ftv)
AstType* operator()(const FreeType& ft)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"), std::nullopt, Location());
}
AstType* operator()(const LocalType& lt)
{
return Luau::visit(*this, lt.domain->ty);
}
AstType* operator()(const UnionType& uv)
{
AstArray<AstType*> unionTypes;

View file

@ -3,9 +3,9 @@
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/DcrLogger.h"
#include "Luau/DenseHash.h"
#include "Luau/Error.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Instantiation.h"
@ -20,12 +20,12 @@
#include "Luau/TypePack.h"
#include "Luau/TypePath.h"
#include "Luau/TypeUtils.h"
#include "Luau/TypeOrPack.h"
#include "Luau/VisitType.h"
#include <algorithm>
LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(LuauFloorDivision);
namespace Luau
{
@ -1660,14 +1660,48 @@ struct TypeChecker2
if (argIt == end(inferredFtv->argTypes))
break;
TypeId inferredArgTy = *argIt;
if (arg->annotation)
{
TypeId inferredArgTy = *argIt;
TypeId annotatedArgTy = lookupAnnotation(arg->annotation);
testIsSubtype(inferredArgTy, annotatedArgTy, arg->location);
}
// Some Luau constructs can result in an argument type being
// reduced to never by inference. In this case, we want to
// report an error at the function, instead of reporting an
// error at every callsite.
if (is<NeverType>(follow(inferredArgTy)))
{
// If the annotation simplified to never, we don't want to
// even look at contributors.
bool explicitlyNever = false;
if (arg->annotation)
{
TypeId annotatedArgTy = lookupAnnotation(arg->annotation);
explicitlyNever = is<NeverType>(annotatedArgTy);
}
// Not following here is deliberate: the contribution map is
// keyed by type pointer, but that type pointer has, at some
// point, been transmuted to a bound type pointing to never.
if (const auto contributors = module->upperBoundContributors.find(inferredArgTy); contributors && !explicitlyNever)
{
// It's unfortunate that we can't link error messages
// together. For now, this will work.
reportError(
GenericError{format(
"Parameter '%s' has been reduced to never. This function is not callable with any possible value.", arg->name.value)},
arg->location);
for (const auto& [site, component] : *contributors)
reportError(ExtraInformation{format("Parameter '%s' is required to be a subtype of '%s' here.", arg->name.value,
toString(component).c_str())},
site);
}
}
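// Illustrative (hypothetical) Luau source that could hit this diagnostic:
//   local function f(x) return x + 1, x:upper() end
// If inference contributes both `number` and `string` as upper bounds for
// `x`, their intersection reduces to never, and the error is reported once
// at the parameter together with each contributing site.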
++argIt;
}
}
@ -1819,8 +1853,6 @@ struct TypeChecker2
bool typesHaveIntersection = normalizer.isIntersectionInhabited(leftType, rightType);
if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end())
{
LUAU_ASSERT(FFlag::LuauFloorDivision || expr->op != AstExprBinary::Op::FloorDiv);
std::optional<TypeId> leftMt = getMetatable(leftType, builtinTypes);
std::optional<TypeId> rightMt = getMetatable(rightType, builtinTypes);
bool matches = leftMt == rightMt;
@ -2009,8 +2041,6 @@ struct TypeChecker2
case AstExprBinary::Op::FloorDiv:
case AstExprBinary::Op::Pow:
case AstExprBinary::Op::Mod:
LUAU_ASSERT(FFlag::LuauFloorDivision || expr->op != AstExprBinary::Op::FloorDiv);
testIsSubtype(leftType, builtinTypes->numberType, expr->left->location);
testIsSubtype(rightType, builtinTypes->numberType, expr->right->location);
@ -2413,6 +2443,72 @@ struct TypeChecker2
}
}
template<typename TID>
std::optional<std::string> explainReasonings(TID subTy, TID superTy, Location location, const SubtypingResult& r)
{
if (r.reasoning.empty())
return std::nullopt;
std::vector<std::string> reasons;
for (const SubtypingReasoning& reasoning : r.reasoning)
{
if (reasoning.subPath.empty() && reasoning.superPath.empty())
continue;
std::optional<TypeOrPack> subLeaf = traverse(subTy, reasoning.subPath, builtinTypes);
std::optional<TypeOrPack> superLeaf = traverse(superTy, reasoning.superPath, builtinTypes);
if (!subLeaf || !superLeaf)
ice->ice("Subtyping test returned a reasoning with an invalid path", location);
if (!get2<TypeId, TypeId>(*subLeaf, *superLeaf) && !get2<TypePackId, TypePackId>(*subLeaf, *superLeaf))
ice->ice("Subtyping test returned a reasoning where one path ends at a type and the other ends at a pack.", location);
std::string relation = "a subtype of";
if (reasoning.variance == SubtypingVariance::Invariant)
relation = "exactly";
else if (reasoning.variance == SubtypingVariance::Contravariant)
relation = "a supertype of";
std::string reason;
if (reasoning.subPath == reasoning.superPath)
reason = "at " + toString(reasoning.subPath) + ", " + toString(*subLeaf) + " is not " + relation + " " + toString(*superLeaf);
else
reason = "type " + toString(subTy) + toString(reasoning.subPath, /* prefixDot */ true) + " (" + toString(*subLeaf) + ") is not " +
relation + " " + toString(superTy) + toString(reasoning.superPath, /* prefixDot */ true) + " (" + toString(*superLeaf) + ")";
reasons.push_back(reason);
}
// DenseHashSet ordering is entirely undefined, so we want to
// sort the reasons here to achieve a stable error
// stringification.
std::sort(reasons.begin(), reasons.end());
std::string allReasons;
bool first = true;
for (const std::string& reason : reasons)
{
if (first)
first = false;
else
allReasons += "\n\t";
allReasons += reason;
}
return allReasons;
}
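// Shape of a produced reason (illustrative, values elided):
//   "type <subTy>.<subPath> (<subLeaf>) is not <relation> <superTy>.<superPath> (<superLeaf>)"
// where <relation> is "a subtype of", "exactly", or "a supertype of", or,
// when both paths coincide:
//   "at <path>, <subLeaf> is not <relation> <superLeaf>"
// with multiple reasons sorted and joined by a newline and tab.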
void explainError(TypeId subTy, TypeId superTy, Location location, const SubtypingResult& result)
{
reportError(TypeMismatch{superTy, subTy, explainReasonings(subTy, superTy, location, result).value_or("")}, location);
}
void explainError(TypePackId subTy, TypePackId superTy, Location location, const SubtypingResult& result)
{
reportError(TypePackMismatch{superTy, subTy, explainReasonings(subTy, superTy, location, result).value_or("")}, location);
}
bool testIsSubtype(TypeId subTy, TypeId superTy, Location location)
{
SubtypingResult r = subtyping->isSubtype(subTy, superTy);
@ -2421,27 +2517,7 @@ struct TypeChecker2
reportError(NormalizationTooComplex{}, location);
if (!r.isSubtype && !r.isErrorSuppressing)
{
if (r.reasoning)
{
std::optional<TypeOrPack> subLeaf = traverse(subTy, r.reasoning->subPath, builtinTypes);
std::optional<TypeOrPack> superLeaf = traverse(superTy, r.reasoning->superPath, builtinTypes);
if (!subLeaf || !superLeaf)
ice->ice("Subtyping test returned a reasoning with an invalid path", location);
if (!get2<TypeId, TypeId>(*subLeaf, *superLeaf) && !get2<TypePackId, TypePackId>(*subLeaf, *superLeaf))
ice->ice("Subtyping test returned a reasoning where one path ends at a type and the other ends at a pack.", location);
std::string reason = "type " + toString(subTy) + toString(r.reasoning->subPath) + " (" + toString(*subLeaf) +
") is not a subtype of " + toString(superTy) + toString(r.reasoning->superPath) + " (" + toString(*superLeaf) +
")";
reportError(TypeMismatch{superTy, subTy, reason}, location);
}
else
reportError(TypeMismatch{superTy, subTy}, location);
}
explainError(subTy, superTy, location, r);
return r.isSubtype;
}
@ -2454,7 +2530,7 @@ struct TypeChecker2
reportError(NormalizationTooComplex{}, location);
if (!r.isSubtype && !r.isErrorSuppressing)
reportError(TypePackMismatch{superTy, subTy}, location);
explainError(subTy, superTy, location, r);
return r.isSubtype;
}
@ -2502,7 +2578,7 @@ struct TypeChecker2
if (!normalizer.isInhabited(ty))
return;
std::unordered_set<TypeId> seen;
DenseHashSet<TypeId> seen{nullptr};
bool found = hasIndexTypeFromType(ty, prop, location, seen, astIndexExprType);
foundOneProp |= found;
if (!found)
@ -2563,14 +2639,14 @@ struct TypeChecker2
}
}
bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location, std::unordered_set<TypeId>& seen, TypeId astIndexExprType)
bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location, DenseHashSet<TypeId>& seen, TypeId astIndexExprType)
{
// If we have already encountered this type, we must assume that some
// other codepath will do the right thing and signal false if the
// property is not present.
const bool isUnseen = seen.insert(ty).second;
if (!isUnseen)
if (seen.contains(ty))
return true;
seen.insert(ty);
if (get<ErrorType>(ty) || get<AnyType>(ty) || get<NeverType>(ty))
return true;

View file

@ -751,8 +751,11 @@ TypeFamilyReductionResult<TypeId> andFamilyFn(const std::vector<TypeId>& typePar
// And evaluates to a boolean if the LHS is falsy, and the RHS type if the LHS is truthy.
SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->falsyType);
SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result);
std::vector<TypeId> blockedTypes(filteredLhs.blockedTypes.begin(), filteredLhs.blockedTypes.end());
blockedTypes.insert(blockedTypes.end(), overallResult.blockedTypes.begin(), overallResult.blockedTypes.end());
std::vector<TypeId> blockedTypes{};
for (auto ty : filteredLhs.blockedTypes)
blockedTypes.push_back(ty);
for (auto ty : overallResult.blockedTypes)
blockedTypes.push_back(ty);
return {overallResult.result, false, std::move(blockedTypes), {}};
}
@ -776,8 +779,11 @@ TypeFamilyReductionResult<TypeId> orFamilyFn(const std::vector<TypeId>& typePara
// Or evaluates to the LHS type if the LHS is truthy, and the RHS type if the LHS is falsy.
SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->truthyType);
SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result);
std::vector<TypeId> blockedTypes(filteredLhs.blockedTypes.begin(), filteredLhs.blockedTypes.end());
blockedTypes.insert(blockedTypes.end(), overallResult.blockedTypes.begin(), overallResult.blockedTypes.end());
std::vector<TypeId> blockedTypes{};
for (auto ty : filteredLhs.blockedTypes)
blockedTypes.push_back(ty);
for (auto ty : overallResult.blockedTypes)
blockedTypes.push_back(ty);
return {overallResult.result, false, std::move(blockedTypes), {}};
}

View file

@ -35,14 +35,11 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false)
LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure)
LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false)
LUAU_FASTFLAGVARIABLE(LuauLoopControlFlowAnalysis, false)
LUAU_FASTFLAGVARIABLE(LuauVariadicOverloadFix, false)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
LUAU_FASTFLAG(LuauParseDeclareClassIndexer)
LUAU_FASTFLAG(LuauFloorDivision);
LUAU_FASTFLAG(LuauBufferTypeck)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false)
namespace Luau
{
@ -204,7 +201,7 @@ static bool isMetamethod(const Name& name)
return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" ||
name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" ||
name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len" ||
(FFlag::LuauFloorDivision && name == "__idiv");
name == "__idiv";
}
size_t HashBoolNamePair::operator()(const std::pair<bool, Name>& pair) const
@ -224,6 +221,7 @@ TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver,
, stringType(builtinTypes->stringType)
, booleanType(builtinTypes->booleanType)
, threadType(builtinTypes->threadType)
, bufferType(builtinTypes->bufferType)
, anyType(builtinTypes->anyType)
, unknownType(builtinTypes->unknownType)
, neverType(builtinTypes->neverType)
@ -1628,13 +1626,6 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty
TypeId& bindingType = bindingsMap[name].type;
if (!FFlag::LuauOccursIsntAlwaysFailure)
{
if (unify(ty, bindingType, aliasScope, typealias.location))
bindingType = ty;
return ControlFlow::None;
}
unify(ty, bindingType, aliasScope, typealias.location);
// It is possible for this unification to succeed but for
@ -1764,7 +1755,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass&
if (!ctv->metatable)
ice("No metatable for declared class");
if (const auto& indexer = declaredClass.indexer; FFlag::LuauParseDeclareClassIndexer && indexer)
if (const auto& indexer = declaredClass.indexer)
ctv->indexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType));
TableType* metatable = getMutable<TableType>(*ctv->metatable);
@ -2562,7 +2553,6 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op)
case AstExprBinary::Div:
return "__div";
case AstExprBinary::FloorDiv:
LUAU_ASSERT(FFlag::LuauFloorDivision);
return "__idiv";
case AstExprBinary::Mod:
return "__mod";
@ -2765,10 +2755,26 @@ TypeId TypeChecker::checkRelationalOperation(
{
reportErrors(state.errors);
if (!isEquality && state.errors.empty() && (get<UnionType>(leftType) || isBoolean(leftType)))
if (FFlag::LuauRemoveBadRelationalOperatorWarning)
{
reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(),
toString(expr.op).c_str())});
// The original version of this check also produced this error when we had a union type.
// However, the old solver does not readily have the ability to discern if the union is comparable.
// This is the case when the lhs is, e.g., a union of singletons and the rhs is the combined type.
// The new solver has much more powerful logic for resolving relational operators, but for now,
// we need to be conservative in the old solver to deliver a reasonable developer experience.
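// Illustrative (hypothetical) case the old check flagged spuriously:
//   local x: "a" | "b" = "a"
//   if x < "b" then end -- lhs is a union of string singletons
// Under this flag, only a boolean lhs is reported here.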
if (!isEquality && state.errors.empty() && isBoolean(leftType))
{
reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s",
toString(leftType).c_str(), toString(expr.op).c_str())});
}
}
else
{
if (!isEquality && state.errors.empty() && (get<UnionType>(leftType) || isBoolean(leftType)))
{
reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s",
toString(leftType).c_str(), toString(expr.op).c_str())});
}
}
return booleanType;
@ -3060,8 +3066,6 @@ TypeId TypeChecker::checkBinaryOperation(
case AstExprBinary::FloorDiv:
case AstExprBinary::Mod:
case AstExprBinary::Pow:
LUAU_ASSERT(FFlag::LuauFloorDivision || expr.op != AstExprBinary::FloorDiv);
reportErrors(tryUnify(lhsType, numberType, scope, expr.left->location));
reportErrors(tryUnify(rhsType, numberType, scope, expr.right->location));
return numberType;
@ -3412,15 +3416,12 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
}
}
if (FFlag::LuauAllowIndexClassParameters)
if (const ClassType* exprClass = get<ClassType>(exprType))
{
if (const ClassType* exprClass = get<ClassType>(exprType))
{
if (isNonstrictMode())
return unknownType;
reportError(TypeError{expr.location, DynamicPropertyLookupOnClassesUnsafe{exprType}});
return errorRecoveryType(scope);
}
if (isNonstrictMode())
return unknownType;
reportError(TypeError{expr.location, DynamicPropertyLookupOnClassesUnsafe{exprType}});
return errorRecoveryType(scope);
}
}
@ -4026,13 +4027,9 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
if (argIndex < argLocations.size())
location = argLocations[argIndex];
if (FFlag::LuauVariadicOverloadFix)
{
state.location = location;
state.tryUnify(*argIter, vtp->ty);
}
else
unify(*argIter, vtp->ty, scope, location);
state.location = location;
state.tryUnify(*argIter, vtp->ty);
++argIter;
++argIndex;
}
@ -5403,7 +5400,7 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
std::optional<TableIndexer> tableIndexer;
for (const auto& prop : table->props)
props[prop.name.value] = {resolveType(scope, *prop.type)};
props[prop.name.value] = {resolveType(scope, *prop.type), /* deprecated: */ false, {}, std::nullopt, {}, std::nullopt, prop.location};
if (const auto& indexer = table->indexer)
tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType));
@ -6025,6 +6022,8 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
return refine(isBoolean, booleanType);
else if (typeguardP.kind == "thread")
return refine(isThread, threadType);
else if (FFlag::LuauBufferTypeck && typeguardP.kind == "buffer")
return refine(isBuffer, bufferType);
else if (typeguardP.kind == "table")
{
return refine([](TypeId ty) -> bool {

View file

@ -6,12 +6,12 @@
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/TypeOrPack.h"
#include <functional>
#include <optional>
#include <sstream>
#include <type_traits>
#include <unordered_set>
LUAU_FASTFLAG(DebugLuauReadWriteProperties);
@ -104,6 +104,41 @@ bool Path::operator==(const Path& other) const
return components == other.components;
}
size_t PathHash::operator()(const Property& prop) const
{
return std::hash<std::string>()(prop.name) ^ static_cast<size_t>(prop.isRead);
}
size_t PathHash::operator()(const Index& idx) const
{
return idx.index;
}
size_t PathHash::operator()(const TypeField& field) const
{
return static_cast<size_t>(field);
}
size_t PathHash::operator()(const PackField& field) const
{
return static_cast<size_t>(field);
}
size_t PathHash::operator()(const Component& component) const
{
return visit(*this, component);
}
size_t PathHash::operator()(const Path& path) const
{
size_t hash = 0;
for (const Component& component : path.components)
hash ^= (*this)(component);
return hash;
}
Path PathBuilder::build()
{
return Path(std::move(components));
@ -216,8 +251,6 @@ struct TraversalState
TypeOrPack current;
NotNull<BuiltinTypes> builtinTypes;
DenseHashSet<const void*> seen{nullptr};
int steps = 0;
void updateCurrent(TypeId ty)
@ -232,18 +265,6 @@ struct TraversalState
current = follow(tp);
}
bool haveCycle()
{
const void* currentPtr = ptr(current);
if (seen.contains(currentPtr))
return true;
else
seen.insert(currentPtr);
return false;
}
bool tooLong()
{
return ++steps > DFInt::LuauTypePathMaximumTraverseSteps;
@ -251,7 +272,7 @@ struct TraversalState
bool checkInvariants()
{
return haveCycle() || tooLong();
return tooLong();
}
bool traverse(const TypePath::Property& property)
@ -277,18 +298,36 @@ struct TraversalState
{
prop = lookupClassProp(c, property.name);
}
else if (auto m = getMetatable(*currentType, builtinTypes))
// For a metatable type, the table takes priority; check that before
// falling through to the metatable entry below.
else if (auto m = get<MetatableType>(*currentType))
{
// Weird: rather than use findMetatableEntry, which requires a lot
// of stuff that we don't have and don't want to pull in, we use the
// path traversal logic to grab __index and then re-enter the lookup
// logic there.
updateCurrent(*m);
TypeOrPack pinned = current;
updateCurrent(m->table);
if (!traverse(TypePath::Property{"__index"}))
return false;
if (traverse(property))
return true;
return traverse(property);
// The property wasn't found on the metatable's table component, so restore
// the pinned current type; the metatable fallback in the next branch
// handles the __index lookup.
current = pinned;
}
if (!prop)
{
if (auto m = getMetatable(*currentType, builtinTypes))
{
// Weird: rather than use findMetatableEntry, which requires a lot
// of stuff that we don't have and don't want to pull in, we use the
// path traversal logic to grab __index and then re-enter the lookup
// logic there.
updateCurrent(*m);
if (!traverse(TypePath::Property{"__index"}))
return false;
return traverse(property);
}
}
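// Net effect (sketch): for a metatable type over {x: number} with
// __index = {y: string}, Property{"x"} resolves on the table directly, while
// Property{"y"} falls back through __index via the branch above.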
if (prop)
@ -465,7 +504,7 @@ struct TraversalState
} // namespace
std::string toString(const TypePath::Path& path)
std::string toString(const TypePath::Path& path, bool prefixDot)
{
std::stringstream result;
bool first = true;
@ -491,7 +530,7 @@ std::string toString(const TypePath::Path& path)
}
else if constexpr (std::is_same_v<T, TypePath::TypeField>)
{
if (!first)
if (!first || prefixDot)
result << '.';
switch (c)
@ -523,7 +562,7 @@ std::string toString(const TypePath::Path& path)
}
else if constexpr (std::is_same_v<T, TypePath::PackField>)
{
if (!first)
if (!first || prefixDot)
result << '.';
switch (c)
@ -580,7 +619,14 @@ std::optional<TypeOrPack> traverse(TypeId root, const Path& path, NotNull<Builti
return std::nullopt;
}
std::optional<TypeOrPack> traverse(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
std::optional<TypeOrPack> traverse(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes)
{
TraversalState state(follow(root), builtinTypes);
if (traverse(state, path))
return state.current;
else
return std::nullopt;
}
std::optional<TypeId> traverseForType(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes)
{

View file

@ -18,12 +18,11 @@
LUAU_FASTINT(LuauTypeInferTypePackLoopLimit)
LUAU_FASTFLAG(LuauErrorRecoveryType)
LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false)
LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false)
LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false)
LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false)
namespace Luau
{
@ -1514,7 +1513,7 @@ struct WeirdIter
auto freePack = log.getMutable<FreeTypePack>(packId);
level = freePack->level;
if (FFlag::LuauMaintainScopesInUnifier && freePack->scope != nullptr)
if (freePack->scope != nullptr)
scope = freePack->scope;
log.replace(packId, BoundTypePack(newTail));
packId = newTail;
@ -1679,11 +1678,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
auto superIter = WeirdIter(superTp, log);
auto subIter = WeirdIter(subTp, log);
if (FFlag::LuauMaintainScopesInUnifier)
{
superIter.scope = scope.get();
subIter.scope = scope.get();
}
superIter.scope = scope.get();
subIter.scope = scope.get();
auto mkFreshType = [this](Scope* scope, TypeLevel level) {
if (FFlag::DebugLuauDeferredConstraintResolution)
@ -2877,7 +2873,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack, bool reversed)
bool occurs = occursCheck(sharedState.tempSeenTy, needle, haystack);
if (occurs && FFlag::LuauOccursIsntAlwaysFailure)
if (occurs)
{
Unifier innerState = makeChildUnifier();
if (const UnionType* ut = get<UnionType>(haystack))
@ -2935,15 +2931,7 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
ice("Expected needle to be free");
if (needle == haystack)
{
if (!FFlag::LuauOccursIsntAlwaysFailure)
{
reportError(location, OccursCheckFailed{});
log.replace(needle, *builtinTypes->errorRecoveryType());
}
return true;
}
if (log.getMutable<FreeType>(haystack) || (hideousFixMeGenericsAreActuallyFree && log.is<GenericType>(haystack)))
return false;
@ -2967,10 +2955,13 @@ bool Unifier::occursCheck(TypePackId needle, TypePackId haystack, bool reversed)
bool occurs = occursCheck(sharedState.tempSeenTp, needle, haystack);
if (occurs && FFlag::LuauOccursIsntAlwaysFailure)
if (occurs)
{
reportError(location, OccursCheckFailed{});
log.replace(needle, *builtinTypes->errorRecoveryTypePack());
if (FFlag::LuauUnifierShouldNotCopyError)
log.replace(needle, BoundTypePack{builtinTypes->errorRecoveryTypePack()});
else
log.replace(needle, *builtinTypes->errorRecoveryTypePack());
}
return occurs;
@ -2997,15 +2988,7 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
while (!log.getMutable<ErrorType>(haystack))
{
if (needle == haystack)
{
if (!FFlag::LuauOccursIsntAlwaysFailure)
{
reportError(location, OccursCheckFailed{});
log.replace(needle, *builtinTypes->errorRecoveryTypePack());
}
return true;
}
if (auto a = get<TypePack>(haystack); a && a->tail)
{

View file

@ -5,9 +5,6 @@
#include "Luau/Instantiation.h"
#include "Luau/Scope.h"
#include "Luau/Simplify.h"
#include "Luau/Substitution.h"
#include "Luau/ToString.h"
#include "Luau/TxnLog.h"
#include "Luau/Type.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeCheckLimits.h"
@ -16,7 +13,6 @@
#include <algorithm>
#include <optional>
#include <unordered_set>
LUAU_FASTINT(LuauTypeInferRecursionLimit)
@ -49,7 +45,10 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
FreeType* superFree = getMutable<FreeType>(superTy);
if (subFree)
{
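// Narrow the free subtype's upper bound and record which type expanded it
// in expandedFreeTypes.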
subFree->upperBound = mkIntersection(subFree->upperBound, superTy);
expandedFreeTypes[subTy].push_back(superTy);
}
if (superFree)
superFree->lowerBound = mkUnion(superFree->lowerBound, subTy);
@ -114,7 +113,7 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
return argResult && retResult;
}
auto subTable = get<TableType>(subTy);
auto subTable = getMutable<TableType>(subTy);
auto superTable = get<TableType>(superTy);
if (subTable && superTable)
{
@ -211,7 +210,7 @@ bool Unifier2::unify(TypeId subTy, const IntersectionType* superIntersection)
return result;
}
bool Unifier2::unify(const TableType* subTable, const TableType* superTable)
bool Unifier2::unify(TableType* subTable, const TableType* superTable)
{
bool result = true;
@ -257,6 +256,21 @@ bool Unifier2::unify(const TableType* subTable, const TableType* superTable)
result &= unify(subTable->indexer->indexResultType, superTable->indexer->indexResultType);
}
if (!subTable->indexer && subTable->state == TableState::Unsealed && superTable->indexer)
{
/*
* Unsealed tables are always created from literal table expressions. We
* can't be completely certain whether such a table has an indexer just
* by the content of the expression itself, so we need to be a bit more
* flexible here.
*
* If we are trying to reconcile an unsealed table with a table that has
* an indexer, we therefore conclude that the unsealed table has the
* same indexer.
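*
* e.g. unifying the literal {1, 2, 3} against { [number]: number }
* adopts that numeric indexer for the literal's unsealed table.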
*/
subTable->indexer = *superTable->indexer;
}
return result;
}

View file

@ -3,6 +3,7 @@
#include "Luau/Location.h"
#include <iterator>
#include <optional>
#include <functional>
#include <string>
@ -91,10 +92,21 @@ struct AstArray
{
return data;
}
const T* end() const
{
return data + size;
}
std::reverse_iterator<const T*> rbegin() const
{
return std::make_reverse_iterator(end());
}
std::reverse_iterator<const T*> rend() const
{
return std::make_reverse_iterator(begin());
}
};
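// A minimal sketch of the new reverse iteration (assumes an AstArray<int>
// aggregate-initialized over a local buffer):
//
//     int values[3] = {1, 2, 3};
//     AstArray<int> arr{values, 3};
//     for (auto it = arr.rbegin(); it != arr.rend(); ++it)
//         printf("%d\n", *it); // prints 3, 2, 1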
struct AstTypeList
@ -249,6 +261,7 @@ public:
enum class ConstantNumberParseResult
{
Ok,
Imprecise,
Malformed,
BinOverflow,
HexOverflow,

View file

@ -1,7 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <string>
namespace Luau
{
@ -9,7 +8,11 @@ struct Position
{
unsigned int line, column;
Position(unsigned int line, unsigned int column);
Position(unsigned int line, unsigned int column)
: line(line)
, column(column)
{
}
bool operator==(const Position& rhs) const;
bool operator!=(const Position& rhs) const;
@ -25,10 +28,29 @@ struct Location
{
Position begin, end;
Location();
Location(const Position& begin, const Position& end);
Location(const Position& begin, unsigned int length);
Location(const Location& begin, const Location& end);
Location()
: begin(0, 0)
, end(0, 0)
{
}
Location(const Position& begin, const Position& end)
: begin(begin)
, end(end)
{
}
Location(const Position& begin, unsigned int length)
: begin(begin)
, end(begin.line, begin.column + length)
{
}
Location(const Location& begin, const Location& end)
: begin(begin.begin)
, end(end.end)
{
}
bool operator==(const Location& rhs) const;
bool operator!=(const Location& rhs) const;
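// A quick usage sketch of the now-inline constructors (hypothetical values):
//
//     Position start{2, 4};
//     Location span(start, 3u); // begin = (2,4), end = (2,7): length counts columns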

View file

@ -3,7 +3,6 @@
#include "Luau/Common.h"
LUAU_FASTFLAG(LuauFloorDivision);
namespace Luau
{
@ -282,7 +281,6 @@ std::string toString(AstExprBinary::Op op)
case AstExprBinary::Div:
return "/";
case AstExprBinary::FloorDiv:
LUAU_ASSERT(FFlag::LuauFloorDivision);
return "//";
case AstExprBinary::Mod:
return "%";

View file

@ -7,7 +7,6 @@
#include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauFloorDivision, false)
LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false)
LUAU_FASTFLAGVARIABLE(LuauCheckedFunctionSyntax, false)
@ -142,7 +141,7 @@ std::string Lexeme::toString() const
return "'::'";
case FloorDiv:
return FFlag::LuauFloorDivision ? "'//'" : "<unknown>";
return "'//'";
case AddAssign:
return "'+='";
@ -157,7 +156,7 @@ std::string Lexeme::toString() const
return "'/='";
case FloorDivAssign:
return FFlag::LuauFloorDivision ? "'//='" : "<unknown>";
return "'//='";
case ModAssign:
return "'%='";
@ -909,44 +908,29 @@ Lexeme Lexer::readNext()
case '/':
{
if (FFlag::LuauFloorDivision)
consume();
char ch = peekch();
if (ch == '=')
{
consume();
char ch = peekch();
if (ch == '=')
{
consume();
return Lexeme(Location(start, 2), Lexeme::DivAssign);
}
else if (ch == '/')
{
consume();
if (peekch() == '=')
{
consume();
return Lexeme(Location(start, 3), Lexeme::FloorDivAssign);
}
else
return Lexeme(Location(start, 2), Lexeme::FloorDiv);
}
else
return Lexeme(Location(start, 1), '/');
return Lexeme(Location(start, 2), Lexeme::DivAssign);
}
else
else if (ch == '/')
{
consume();
if (peekch() == '=')
{
consume();
return Lexeme(Location(start, 2), Lexeme::DivAssign);
return Lexeme(Location(start, 3), Lexeme::FloorDivAssign);
}
else
return Lexeme(Location(start, 1), '/');
return Lexeme(Location(start, 2), Lexeme::FloorDiv);
}
else
return Lexeme(Location(start, 1), '/');
}
case '*':

View file

@ -1,16 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Location.h"
#include <string>
namespace Luau
{
Position::Position(unsigned int line, unsigned int column)
: line(line)
, column(column)
{
}
bool Position::operator==(const Position& rhs) const
{
return this->column == rhs.column && this->line == rhs.line;
@ -61,30 +54,6 @@ void Position::shift(const Position& start, const Position& oldEnd, const Positi
}
}
Location::Location()
: begin(0, 0)
, end(0, 0)
{
}
Location::Location(const Position& begin, const Position& end)
: begin(begin)
, end(end)
{
}
Location::Location(const Position& begin, unsigned int length)
: begin(begin)
, end(begin.line, begin.column + length)
{
}
Location::Location(const Location& begin, const Location& end)
: begin(begin.begin)
, end(end.end)
{
}
bool Location::operator==(const Location& rhs) const
{
return this->begin == rhs.begin && this->end == rhs.end;

View file

@ -16,14 +16,9 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
// Warning: If you are introducing new syntax, ensure that it is behind a separate
// flag so that we don't break production games by reverting syntax changes.
// See docs/SyntaxChanges.md for an explanation.
LUAU_FASTFLAGVARIABLE(LuauParseDeclareClassIndexer, false)
LUAU_FASTFLAGVARIABLE(LuauClipExtraHasEndProps, false)
LUAU_FASTFLAG(LuauFloorDivision)
LUAU_FASTFLAG(LuauCheckedFunctionSyntax)
LUAU_FASTFLAGVARIABLE(LuauBetterTypeUnionLimits, false)
LUAU_FASTFLAGVARIABLE(LuauBetterTypeRecLimits, false)
namespace Luau
{
@ -924,7 +919,7 @@ AstStat* Parser::parseDeclaration(const Location& start)
{
props.push_back(parseDeclaredClassMethod());
}
else if (lexer.current().type == '[' && (!FFlag::LuauParseDeclareClassIndexer || lexer.lookahead().type == Lexeme::RawString ||
else if (lexer.current().type == '[' && (lexer.lookahead().type == Lexeme::RawString ||
lexer.lookahead().type == Lexeme::QuotedString))
{
const Lexeme begin = lexer.current();
@ -944,7 +939,7 @@ AstStat* Parser::parseDeclaration(const Location& start)
else
report(begin.location, "String literal contains malformed escape sequence or \\0");
}
else if (lexer.current().type == '[' && FFlag::LuauParseDeclareClassIndexer)
else if (lexer.current().type == '[')
{
if (indexer)
{
@ -1544,8 +1539,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
unsigned int oldRecursionCount = recursionCounter;
parts.push_back(parseSimpleType(/* allowPack= */ false).type);
if (FFlag::LuauBetterTypeUnionLimits)
recursionCounter = oldRecursionCount;
recursionCounter = oldRecursionCount;
isUnion = true;
}
@ -1554,7 +1548,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
Location loc = lexer.current().location;
nextLexeme();
if (!FFlag::LuauBetterTypeUnionLimits || !hasOptional)
if (!hasOptional)
parts.push_back(allocator.alloc<AstTypeReference>(loc, std::nullopt, nameNil, std::nullopt, loc));
isUnion = true;
@ -1566,8 +1560,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
unsigned int oldRecursionCount = recursionCounter;
parts.push_back(parseSimpleType(/* allowPack= */ false).type);
if (FFlag::LuauBetterTypeUnionLimits)
recursionCounter = oldRecursionCount;
recursionCounter = oldRecursionCount;
isIntersection = true;
}
@ -1579,7 +1572,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
else
break;
if (FFlag::LuauBetterTypeUnionLimits && parts.size() > unsigned(FInt::LuauTypeLengthLimit) + hasOptional)
if (parts.size() > unsigned(FInt::LuauTypeLengthLimit) + hasOptional)
ParseError::raise(parts.back()->location, "Exceeded allowed type length; simplify your type annotation to make the code compile");
}
@ -1607,10 +1600,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
AstTypeOrPack Parser::parseTypeOrPack()
{
unsigned int oldRecursionCount = recursionCounter;
// recursion counter is incremented in parseSimpleType
if (!FFlag::LuauBetterTypeRecLimits)
incrementRecursionCounter("type annotation");
Location begin = lexer.current().location;
@ -1630,10 +1620,7 @@ AstTypeOrPack Parser::parseTypeOrPack()
AstType* Parser::parseType(bool inDeclarationContext)
{
unsigned int oldRecursionCount = recursionCounter;
// recursion counter is incremented in parseSimpleType
if (!FFlag::LuauBetterTypeRecLimits)
incrementRecursionCounter("type annotation");
Location begin = lexer.current().location;
@ -1839,11 +1826,7 @@ std::optional<AstExprBinary::Op> Parser::parseBinaryOp(const Lexeme& l)
else if (l.type == '/')
return AstExprBinary::Div;
else if (l.type == Lexeme::FloorDiv)
{
LUAU_ASSERT(FFlag::LuauFloorDivision);
return AstExprBinary::FloorDiv;
}
else if (l.type == '%')
return AstExprBinary::Mod;
else if (l.type == '^')
@ -1881,11 +1864,7 @@ std::optional<AstExprBinary::Op> Parser::parseCompoundOp(const Lexeme& l)
else if (l.type == Lexeme::DivAssign)
return AstExprBinary::Div;
else if (l.type == Lexeme::FloorDivAssign)
{
LUAU_ASSERT(FFlag::LuauFloorDivision);
return AstExprBinary::FloorDiv;
}
else if (l.type == Lexeme::ModAssign)
return AstExprBinary::Mod;
else if (l.type == Lexeme::PowAssign)
@ -2187,6 +2166,9 @@ static ConstantNumberParseResult parseInteger(double& result, const char* data,
return base == 2 ? ConstantNumberParseResult::BinOverflow : ConstantNumberParseResult::HexOverflow;
}
if (value >= (1ull << 53) && static_cast<unsigned long long>(result) != value)
return ConstantNumberParseResult::Imprecise;
return ConstantNumberParseResult::Ok;
}
@ -2203,8 +2185,24 @@ static ConstantNumberParseResult parseDouble(double& result, const char* data)
char* end = nullptr;
double value = strtod(data, &end);
// trailing non-numeric characters
if (*end != 0)
return ConstantNumberParseResult::Malformed;
result = value;
return *end == 0 ? ConstantNumberParseResult::Ok : ConstantNumberParseResult::Malformed;
// for linting, we detect integer constants that are parsed imprecisely
// since the check is expensive we only perform it when the number is larger than the precise integer range
if (value >= double(1ull << 53) && strspn(data, "0123456789") == strlen(data))
{
char repr[512];
snprintf(repr, sizeof(repr), "%.0f", value);
if (strcmp(repr, data) != 0)
return ConstantNumberParseResult::Imprecise;
}
return ConstantNumberParseResult::Ok;
}
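// Worked example of the imprecision check (sketch): 2^53 + 1 has no exact
// double representation, so the %.0f round-trip disagrees with the source text:
//
//     double value = strtod("9007199254740993", nullptr); // -> 9007199254740992
//     char repr[512];
//     snprintf(repr, sizeof(repr), "%.0f", value);        // "9007199254740992"
//     strcmp(repr, "9007199254740993") != 0;              // -> Imprecise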
// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp

295
CLI/Bytecode.cpp Normal file
View file

@ -0,0 +1,295 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "lua.h"
#include "lualib.h"
#include "Luau/CodeGen.h"
#include "Luau/Compiler.h"
#include "Luau/BytecodeBuilder.h"
#include "Luau/Parser.h"
#include "Luau/BytecodeSummary.h"
#include "FileUtils.h"
#include "Flags.h"
#include <memory>
using Luau::CodeGen::FunctionBytecodeSummary;
struct GlobalOptions
{
int optimizationLevel = 1;
int debugLevel = 1;
} globalOptions;
static Luau::CompileOptions copts()
{
Luau::CompileOptions result = {};
result.optimizationLevel = globalOptions.optimizationLevel;
result.debugLevel = globalOptions.debugLevel;
return result;
}
static void displayHelp(const char* argv0)
{
printf("Usage: %s [options] [file list]\n", argv0);
printf("\n");
printf("Available options:\n");
printf(" -h, --help: Display this usage message.\n");
printf(" -O<n>: compile with optimization level n (default 1, n should be between 0 and 2).\n");
printf(" -g<n>: compile with debug level n (default 1, n should be between 0 and 2).\n");
printf(" --fflags=<fflags>: flags to be enabled.\n");
printf(" --summary-file=<filename>: file in which bytecode analysis summary will be recorded (default 'bytecode-summary.json').\n");
exit(0);
}
static bool parseArgs(int argc, char** argv, std::string& summaryFile)
{
for (int i = 1; i < argc; i++)
{
if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0)
{
displayHelp(argv[0]);
}
else if (strncmp(argv[i], "-O", 2) == 0)
{
int level = atoi(argv[i] + 2);
if (level < 0 || level > 2)
{
fprintf(stderr, "Error: Optimization level must be between 0 and 2 inclusive.\n");
return false;
}
globalOptions.optimizationLevel = level;
}
else if (strncmp(argv[i], "-g", 2) == 0)
{
int level = atoi(argv[i] + 2);
if (level < 0 || level > 2)
{
fprintf(stderr, "Error: Debug level must be between 0 and 2 inclusive.\n");
return false;
}
globalOptions.debugLevel = level;
}
else if (strncmp(argv[i], "--summary-file=", 15) == 0)
{
summaryFile = argv[i] + 15;
if (summaryFile.size() == 0)
{
fprintf(stderr, "Error: filename missing for '--summary-file'.\n\n");
return false;
}
}
else if (strncmp(argv[i], "--fflags=", 9) == 0)
{
setLuauFlags(argv[i] + 9);
}
else if (argv[i][0] == '-')
{
fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]);
displayHelp(argv[0]);
}
}
return true;
}
static void report(const char* name, const Luau::Location& location, const char* type, const char* message)
{
fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message);
}
static void reportError(const char* name, const Luau::ParseError& error)
{
report(name, error.getLocation(), "SyntaxError", error.what());
}
static void reportError(const char* name, const Luau::CompileError& error)
{
report(name, error.getLocation(), "CompileError", error.what());
}
static bool analyzeFile(const char* name, const unsigned nestingLimit, std::vector<FunctionBytecodeSummary>& summaries)
{
std::optional<std::string> source = readFile(name);
if (!source)
{
fprintf(stderr, "Error opening %s\n", name);
return false;
}
try
{
Luau::BytecodeBuilder bcb;
compileOrThrow(bcb, source.value(), copts());
const std::string& bytecode = bcb.getBytecode();
std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get();
if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0)
{
summaries = Luau::CodeGen::summarizeBytecode(L, -1, nestingLimit);
return true;
}
else
{
fprintf(stderr, "Error loading bytecode %s\n", name);
return false;
}
}
catch (Luau::ParseErrors& e)
{
for (auto& error : e.getErrors())
reportError(name, error);
return false;
}
catch (Luau::CompileError& e)
{
reportError(name, e);
return false;
}
return true;
}
static std::string escapeFilename(const std::string& filename)
{
std::string escaped;
escaped.reserve(filename.size());
for (const char ch : filename)
{
switch (ch)
{
case '\\':
escaped.push_back('/');
break;
case '"':
escaped.push_back('\\');
escaped.push_back(ch);
break;
default:
escaped.push_back(ch);
}
}
return escaped;
}
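// Sketch: escapeFilename turns C:\dir\a"b.luau into C:/dir/a\"b.luau, which is
// safe to embed as a key or value in the JSON emitted below.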
static void serializeFunctionSummary(const FunctionBytecodeSummary& summary, FILE* fp)
{
const unsigned nestingLimit = summary.getNestingLimit();
const unsigned opLimit = summary.getOpLimit();
fprintf(fp, " {\n");
fprintf(fp, " \"source\": \"%s\",\n", summary.getSource().c_str());
fprintf(fp, " \"name\": \"%s\",\n", summary.getName().c_str());
fprintf(fp, " \"line\": %d,\n", summary.getLine());
fprintf(fp, " \"nestingLimit\": %u,\n", nestingLimit);
fprintf(fp, " \"counts\": [");
for (unsigned nesting = 0; nesting <= nestingLimit; ++nesting)
{
fprintf(fp, "\n [");
for (unsigned i = 0; i < opLimit; ++i)
{
fprintf(fp, "%d", summary.getCount(nesting, uint8_t(i)));
if (i < opLimit - 1)
fprintf(fp, ", ");
}
fprintf(fp, "]");
if (nesting < nestingLimit)
fprintf(fp, ",");
}
fprintf(fp, "\n ]");
fprintf(fp, "\n }");
}
static void serializeScriptSummary(const std::string& file, const std::vector<FunctionBytecodeSummary>& scriptSummary, FILE* fp)
{
std::string escaped(escapeFilename(file));
const size_t functionCount = scriptSummary.size();
fprintf(fp, " \"%s\": [\n", escaped.c_str());
for (size_t i = 0; i < functionCount; ++i)
{
serializeFunctionSummary(scriptSummary[i], fp);
fprintf(fp, i == (functionCount - 1) ? "\n" : ",\n");
}
fprintf(fp, " ]");
}
static bool serializeSummaries(
const std::vector<std::string>& files, const std::vector<std::vector<FunctionBytecodeSummary>>& scriptSummaries, const std::string& summaryFile)
{
FILE* fp = fopen(summaryFile.c_str(), "w");
const size_t fileCount = files.size();
if (!fp)
{
fprintf(stderr, "Unable to open '%s'.\n", summaryFile.c_str());
return false;
}
fprintf(fp, "{\n");
for (size_t i = 0; i < fileCount; ++i)
{
serializeScriptSummary(files[i], scriptSummaries[i], fp);
fprintf(fp, i < (fileCount - 1) ? ",\n" : "\n");
}
fprintf(fp, "}");
fclose(fp);
return true;
}
static int assertionHandler(const char* expr, const char* file, int line, const char* function)
{
printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr);
return 1;
}
int main(int argc, char** argv)
{
Luau::assertHandler() = assertionHandler;
setLuauFlagsDefault();
std::string summaryFile("bytecode-summary.json");
unsigned nestingLimit = 0;
if (!parseArgs(argc, argv, summaryFile))
return 1;
const std::vector<std::string> files = getSourceFiles(argc, argv);
size_t fileCount = files.size();
std::vector<std::vector<FunctionBytecodeSummary>> scriptSummaries;
scriptSummaries.resize(fileCount); // resize, not reserve: analyzeFile writes into scriptSummaries[i]
for (size_t i = 0; i < fileCount; ++i)
{
if (!analyzeFile(files[i].c_str(), nestingLimit, scriptSummaries[i]))
return 1;
}
if (!serializeSummaries(files, scriptSummaries, summaryFile))
return 1;
fprintf(stdout, "Bytecode summary written to '%s'\n", summaryFile.c_str());
return 0;
}

View file

@ -37,13 +37,19 @@ enum class RecordStats
{
None,
Total,
Split
File,
Function
};
struct GlobalOptions
{
int optimizationLevel = 1;
int debugLevel = 1;
std::string vectorLib;
std::string vectorCtor;
std::string vectorType;
} globalOptions;
static Luau::CompileOptions copts()
@ -52,6 +58,11 @@ static Luau::CompileOptions copts()
result.optimizationLevel = globalOptions.optimizationLevel;
result.debugLevel = globalOptions.debugLevel;
// globalOptions outlive the CompileOptions, so it's safe to use string data pointers here
result.vectorLib = globalOptions.vectorLib.c_str();
result.vectorCtor = globalOptions.vectorCtor.c_str();
result.vectorType = globalOptions.vectorType.c_str();
return result;
}
@ -120,6 +131,7 @@ struct CompileStats
{
size_t lines;
size_t bytecode;
size_t bytecodeInstructionCount;
size_t codegen;
double readTime;
@ -136,6 +148,7 @@ struct CompileStats
fprintf(fp, "{\
\"lines\": %zu, \
\"bytecode\": %zu, \
\"bytecodeInstructionCount\": %zu, \
\"codegen\": %zu, \
\"readTime\": %f, \
\"miscTime\": %f, \
@ -153,16 +166,37 @@ struct CompileStats
\"maxBlockInstructions\": %u, \
\"regAllocErrors\": %d, \
\"loweringErrors\": %d\
}}",
lines, bytecode, codegen, readTime, miscTime, parseTime, compileTime, codegenTime, lowerStats.totalFunctions, lowerStats.skippedFunctions,
lowerStats.spillsToSlot, lowerStats.spillsToRestore, lowerStats.maxSpillSlotsUsed, lowerStats.blocksPreOpt, lowerStats.blocksPostOpt,
lowerStats.maxBlockInstructions, lowerStats.regAllocErrors, lowerStats.loweringErrors);
}, \
\"blockLinearizationStats\": {\
\"constPropInstructionCount\": %u, \
\"timeSeconds\": %f\
}",
lines, bytecode, bytecodeInstructionCount, codegen, readTime, miscTime, parseTime, compileTime, codegenTime, lowerStats.totalFunctions,
lowerStats.skippedFunctions, lowerStats.spillsToSlot, lowerStats.spillsToRestore, lowerStats.maxSpillSlotsUsed, lowerStats.blocksPreOpt,
lowerStats.blocksPostOpt, lowerStats.maxBlockInstructions, lowerStats.regAllocErrors, lowerStats.loweringErrors,
lowerStats.blockLinearizationStats.constPropInstructionCount, lowerStats.blockLinearizationStats.timeSeconds);
if (lowerStats.collectFunctionStats)
{
fprintf(fp, ", \"functions\": [");
auto functionCount = lowerStats.functions.size();
for (size_t i = 0; i < functionCount; ++i)
{
const Luau::CodeGen::FunctionStats& fstat = lowerStats.functions[i];
fprintf(fp, "{\"name\": \"%s\", \"line\": %d, \"bcodeCount\": %u, \"irCount\": %u, \"asmCount\": %u}", fstat.name.c_str(), fstat.line,
fstat.bcodeCount, fstat.irCount, fstat.asmCount);
if (i < functionCount - 1)
fprintf(fp, ", ");
}
fprintf(fp, "]");
}
fprintf(fp, "}");
}
CompileStats& operator+=(const CompileStats& that)
{
this->lines += that.lines;
this->bytecode += that.bytecode;
this->bytecodeInstructionCount += that.bytecodeInstructionCount;
this->codegen += that.codegen;
this->readTime += that.readTime;
this->miscTime += that.miscTime;
@ -257,6 +291,7 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
Luau::compileOrThrow(bcb, result, names, copts());
stats.bytecode += bcb.getBytecode().size();
stats.bytecodeInstructionCount = bcb.getTotalInstructionCount();
stats.compileTime += recordDeltaTime(currts);
switch (format)
@ -312,7 +347,11 @@ static void displayHelp(const char* argv0)
printf(" -g<n>: compile with debug level n (default 1, n should be between 0 and 2).\n");
printf(" --target=<target>: compile code for specific architecture (a64, x64, a64_nf, x64_ms).\n");
printf(" --timetrace: record compiler time tracing information into trace.json\n");
printf(" --record-stats=<style>: records compilation stats in stats.json (total, split).\n");
printf(" --stats-file=<filename>: file in which compilation stats will be recored (default 'stats.json').\n");
printf(" --record-stats=<granularity>: granularity of compilation stats recorded in stats.json (total, file, function).\n");
printf(" --vector-lib=<name>: name of the library providing vector type operations.\n");
printf(" --vector-ctor=<name>: name of the function constructing a vector value.\n");
printf(" --vector-type=<name>: name of the vector type.\n");
}
static int assertionHandler(const char* expr, const char* file, int line, const char* function)
@ -321,6 +360,30 @@ static int assertionHandler(const char* expr, const char* file, int line, const
return 1;
}
std::string escapeFilename(const std::string& filename)
{
std::string escaped;
escaped.reserve(filename.size());
for (const char ch : filename)
{
switch (ch)
{
case '\\':
escaped.push_back('/');
break;
case '"':
escaped.push_back('\\');
escaped.push_back(ch);
break;
default:
escaped.push_back(ch);
}
}
return escaped;
}
int main(int argc, char** argv)
{
Luau::assertHandler() = assertionHandler;
@ -330,6 +393,7 @@ int main(int argc, char** argv)
CompileFormat compileFormat = CompileFormat::Text;
Luau::CodeGen::AssemblyOptions::Target assemblyTarget = Luau::CodeGen::AssemblyOptions::Host;
RecordStats recordStats = RecordStats::None;
std::string statsFile("stats.json");
for (int i = 1; i < argc; i++)
{
@ -386,11 +450,23 @@ int main(int argc, char** argv)
if (strcmp(value, "total") == 0)
recordStats = RecordStats::Total;
else if (strcmp(value, "split") == 0)
recordStats = RecordStats::Split;
else if (strcmp(value, "file") == 0)
recordStats = RecordStats::File;
else if (strcmp(value, "function") == 0)
recordStats = RecordStats::Function;
else
{
fprintf(stderr, "Error: unknown 'style' for '--record-stats'\n");
fprintf(stderr, "Error: unknown 'granularity' for '--record-stats'\n");
return 1;
}
}
else if (strncmp(argv[i], "--stats-file=", 13) == 0)
{
statsFile = argv[i] + 13;
if (statsFile.size() == 0)
{
fprintf(stderr, "Error: filename missing for '--stats-file'.\n\n");
return 1;
}
}
@ -398,6 +474,18 @@ int main(int argc, char** argv)
{
setLuauFlags(argv[i] + 9);
}
else if (strncmp(argv[i], "--vector-lib=", 13) == 0)
{
globalOptions.vectorLib = argv[i] + 13;
}
else if (strncmp(argv[i], "--vector-ctor=", 14) == 0)
{
globalOptions.vectorCtor = argv[i] + 14;
}
else if (strncmp(argv[i], "--vector-type=", 14) == 0)
{
globalOptions.vectorType = argv[i] + 14;
}
else if (argv[i][0] == '-' && argv[i][1] == '-' && getCompileFormat(argv[i] + 2))
{
compileFormat = *getCompileFormat(argv[i] + 2);
@ -429,7 +517,7 @@ int main(int argc, char** argv)
CompileStats stats = {};
std::vector<CompileStats> fileStats;
if (recordStats == RecordStats::Split)
if (recordStats == RecordStats::File || recordStats == RecordStats::Function)
fileStats.reserve(fileCount);
int failed = 0;
@ -437,9 +525,10 @@ int main(int argc, char** argv)
for (const std::string& path : files)
{
CompileStats fileStat = {};
fileStat.lowerStats.collectFunctionStats = (recordStats == RecordStats::Function);
failed += !compileFile(path.c_str(), compileFormat, assemblyTarget, fileStat);
stats += fileStat;
if (recordStats == RecordStats::Split)
if (recordStats == RecordStats::File || recordStats == RecordStats::Function)
fileStats.push_back(fileStat);
}
@ -462,8 +551,7 @@ int main(int argc, char** argv)
if (recordStats != RecordStats::None)
{
FILE* fp = fopen("stats.json", "w");
FILE* fp = fopen(statsFile.c_str(), "w");
if (!fp)
{
@ -475,12 +563,13 @@ int main(int argc, char** argv)
{
stats.serializeToJson(fp);
}
else if (recordStats == RecordStats::Split)
else if (recordStats == RecordStats::File || recordStats == RecordStats::Function)
{
fprintf(fp, "{\n");
for (size_t i = 0; i < fileCount; ++i)
{
fprintf(fp, "\"%s\": ", files[i].c_str());
std::string escaped(escapeFilename(files[i]));
fprintf(fp, "\"%s\": ", escaped.c_str());
fileStats[i].serializeToJson(fp);
fprintf(fp, i == (fileCount - 1) ? "\n" : ",\n");
}

View file

@ -10,6 +10,7 @@
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <direct.h>
#include <windows.h>
#else
#include <dirent.h>
@ -44,6 +45,148 @@ static std::string toUtf8(const std::wstring& path)
}
#endif
bool isAbsolutePath(std::string_view path)
{
#ifdef _WIN32
// Must either begin with "X:/", "X:\", "/", or "\", where X is a drive letter
return (path.size() >= 3 && isalpha(path[0]) && path[1] == ':' && (path[2] == '/' || path[2] == '\\')) ||
(path.size() >= 1 && (path[0] == '/' || path[0] == '\\'));
#else
// Must begin with '/'
return path.size() >= 1 && path[0] == '/';
#endif
}
bool isExplicitlyRelative(std::string_view path)
{
return (path == ".") || (path == "..") || (path.size() >= 2 && path[0] == '.' && path[1] == '/') ||
(path.size() >= 3 && path[0] == '.' && path[1] == '.' && path[2] == '/');
}
std::optional<std::string> getCurrentWorkingDirectory()
{
// 2^17 - derived from the Windows path length limit
constexpr size_t maxPathLength = 131072;
constexpr size_t initialPathLength = 260;
std::string directory(initialPathLength, '\0');
char* cstr = nullptr;
while (!cstr && directory.size() <= maxPathLength)
{
#ifdef _WIN32
cstr = _getcwd(directory.data(), static_cast<int>(directory.size()));
#else
cstr = getcwd(directory.data(), directory.size());
#endif
if (cstr)
{
directory.resize(strlen(cstr));
return directory;
}
else if (errno != ERANGE || directory.size() * 2 > maxPathLength)
{
return std::nullopt;
}
else
{
directory.resize(directory.size() * 2);
}
}
return std::nullopt;
}
// Returns the normal/canonical form of a path (e.g. "../subfolder/../module.luau" -> "../module.luau")
std::string normalizePath(std::string_view path)
{
return resolvePath(path, "");
}
// Takes a path that is relative to the file at baseFilePath and returns the path explicitly rebased onto baseFilePath.
// For absolute paths, baseFilePath will be ignored, and this function will resolve the path to a canonical path:
// (e.g. "/Users/.././Users/johndoe" -> "/Users/johndoe").
std::string resolvePath(std::string_view path, std::string_view baseFilePath)
{
std::vector<std::string_view> pathComponents;
std::vector<std::string_view> baseFilePathComponents;
// Dependent on whether the final resolved path is absolute or relative
// - if relative (when path and baseFilePath are both relative), resolvedPathPrefix remains empty
// - if absolute (if either path or baseFilePath are absolute), resolvedPathPrefix is "C:\", "/", etc.
std::string resolvedPathPrefix;
if (isAbsolutePath(path))
{
// path is absolute, we use path's prefix and ignore baseFilePath
size_t afterPrefix = path.find_first_of("\\/") + 1;
resolvedPathPrefix = path.substr(0, afterPrefix);
pathComponents = splitPath(path.substr(afterPrefix));
}
else
{
pathComponents = splitPath(path);
if (isAbsolutePath(baseFilePath))
{
// path is relative and baseFilePath is absolute, we use baseFilePath's prefix
size_t afterPrefix = baseFilePath.find_first_of("\\/") + 1;
resolvedPathPrefix = baseFilePath.substr(0, afterPrefix);
baseFilePathComponents = splitPath(baseFilePath.substr(afterPrefix));
}
else
{
// path and baseFilePath are both relative, we do not set a prefix (resolved path will be relative)
baseFilePathComponents = splitPath(baseFilePath);
}
}
// Remove filename from components
if (!baseFilePathComponents.empty())
baseFilePathComponents.pop_back();
// Resolve the path by applying pathComponents to baseFilePathComponents
int numPrependedParents = 0;
for (std::string_view component : pathComponents)
{
if (component == "..")
{
if (baseFilePathComponents.empty())
{
if (resolvedPathPrefix.empty()) // only when final resolved path will be relative
numPrependedParents++; // "../" will later be added to the beginning of the resolved path
}
else if (baseFilePathComponents.back() != "..")
{
baseFilePathComponents.pop_back(); // Resolve cases like "folder/subfolder/../../file" to "file"
}
}
else if (component != "." && !component.empty())
{
baseFilePathComponents.push_back(component);
}
}
// Join baseFilePathComponents to form the resolved path
std::string resolvedPath = resolvedPathPrefix;
// Only when resolvedPath will be relative
for (int i = 0; i < numPrependedParents; i++)
{
resolvedPath += "../";
}
for (auto iter = baseFilePathComponents.begin(); iter != baseFilePathComponents.end(); ++iter)
{
if (iter != baseFilePathComponents.begin())
resolvedPath += "/";
resolvedPath += *iter;
}
if (resolvedPath.size() > resolvedPathPrefix.size() && resolvedPath.back() == '/')
{
// Remove trailing '/' if present
resolvedPath.pop_back();
}
return resolvedPath;
}
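// Worked examples (sketch), following the rules above:
//   resolvePath("../module.luau", "folder/subfolder/file.luau") -> "folder/module.luau"
//   resolvePath("/Users/.././Users/johndoe", "ignored")         -> "/Users/johndoe"
//   normalizePath("../subfolder/../module.luau")                -> "../module.luau"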
std::optional<std::string> readFile(const std::string& name)
{
#ifdef _WIN32
@ -224,6 +367,24 @@ bool isDirectory(const std::string& path)
#endif
}
std::vector<std::string_view> splitPath(std::string_view path)
{
std::vector<std::string_view> components;
size_t pos = 0;
size_t nextPos = path.find_first_of("\\/", pos);
while (nextPos != std::string::npos)
{
components.push_back(path.substr(pos, nextPos - pos));
pos = nextPos + 1;
nextPos = path.find_first_of("\\/", pos);
}
components.push_back(path.substr(pos));
return components;
}
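// Sketch: splitPath("a/b\\c") -> {"a", "b", "c"}; splitPath("/x") -> {"", "x"},
// i.e. an absolute prefix surfaces as a leading empty component.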
std::string joinPaths(const std::string& lhs, const std::string& rhs)
{
std::string result = lhs;

View file

@ -3,15 +3,24 @@
#include <optional>
#include <string>
#include <string_view>
#include <functional>
#include <vector>
std::optional<std::string> getCurrentWorkingDirectory();
std::string normalizePath(std::string_view path);
std::string resolvePath(std::string_view relativePath, std::string_view baseFilePath);
std::optional<std::string> readFile(const std::string& name);
std::optional<std::string> readStdin();
bool isAbsolutePath(std::string_view path);
bool isExplicitlyRelative(std::string_view path);
bool isDirectory(const std::string& path);
bool traverseDirectory(const std::string& path, const std::function<void(const std::string& name)>& callback);
std::vector<std::string_view> splitPath(std::string_view path);
std::string joinPaths(const std::string& lhs, const std::string& rhs);
std::optional<std::string> getParentPath(const std::string& path);

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Repl.h"
#include "Luau/Common.h"
#include "lua.h"
#include "lualib.h"
@ -13,6 +14,7 @@
#include "FileUtils.h"
#include "Flags.h"
#include "Profiler.h"
#include "Require.h"
#include "isocline.h"
@ -39,6 +41,8 @@
LUAU_FASTFLAG(DebugLuauTimeTracing)
LUAU_FASTFLAGVARIABLE(LuauUpdatedRequireByStringSemantics, false)
constexpr int MaxTraversalLimit = 50;
static bool codegen = false;
@ -115,74 +119,129 @@ static int finishrequire(lua_State* L)
static int lua_require(lua_State* L)
{
std::string name = luaL_checkstring(L, 1);
std::string chunkname = "=" + name;
luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1);
// return the module from the cache
lua_getfield(L, -1, name.c_str());
if (!lua_isnil(L, -1))
if (FFlag::LuauUpdatedRequireByStringSemantics)
{
// L stack: _MODULES result
std::string name = luaL_checkstring(L, 1);
RequireResolver::ResolvedRequire resolvedRequire = RequireResolver::resolveRequire(L, std::move(name));
if (resolvedRequire.status == RequireResolver::ModuleStatus::Cached)
return finishrequire(L);
else if (resolvedRequire.status == RequireResolver::ModuleStatus::NotFound)
luaL_errorL(L, "error requiring module");
// module needs to run in a new thread, isolated from the rest
// note: we create ML on main thread so that it doesn't inherit environment of L
lua_State* GL = lua_mainthread(L);
lua_State* ML = lua_newthread(GL);
lua_xmove(GL, L, 1);
// new thread needs to have the globals sandboxed
luaL_sandboxthread(ML);
// now we can compile & run module on the new thread
std::string bytecode = Luau::compile(resolvedRequire.sourceCode, copts());
if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
if (codegen)
Luau::CodeGen::compile(ML, -1);
if (coverageActive())
coverageTrack(ML, -1);
int status = lua_resume(ML, L, 0);
if (status == 0)
{
if (lua_gettop(ML) == 0)
lua_pushstring(ML, "module must return a value");
else if (!lua_istable(ML, -1) && !lua_isfunction(ML, -1))
lua_pushstring(ML, "module must return a table or function");
}
else if (status == LUA_YIELD)
{
lua_pushstring(ML, "module can not yield");
}
else if (!lua_isstring(ML, -1))
{
lua_pushstring(ML, "unknown error while running module");
}
}
// there's now a return value on top of ML; L stack: _MODULES ML
lua_xmove(ML, L, 1);
lua_pushvalue(L, -1);
lua_setfield(L, -4, resolvedRequire.absolutePath.c_str());
// L stack: _MODULES ML result
return finishrequire(L);
}
lua_pop(L, 1);
std::optional<std::string> source = readFile(name + ".luau");
if (!source)
else
{
source = readFile(name + ".lua"); // try .lua if .luau doesn't exist
std::string name = luaL_checkstring(L, 1);
std::string chunkname = "=" + name;
luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1);
// return the module from the cache
lua_getfield(L, -1, name.c_str());
if (!lua_isnil(L, -1))
{
// L stack: _MODULES result
return finishrequire(L);
}
lua_pop(L, 1);
std::optional<std::string> source = readFile(name + ".luau");
if (!source)
luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error
{
source = readFile(name + ".lua"); // try .lua if .luau doesn't exist
if (!source)
luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error
}
// module needs to run in a new thread, isolated from the rest
// note: we create ML on main thread so that it doesn't inherit environment of L
lua_State* GL = lua_mainthread(L);
lua_State* ML = lua_newthread(GL);
lua_xmove(GL, L, 1);
// new thread needs to have the globals sandboxed
luaL_sandboxthread(ML);
// now we can compile & run module on the new thread
std::string bytecode = Luau::compile(*source, copts());
if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
if (codegen)
Luau::CodeGen::compile(ML, -1);
if (coverageActive())
coverageTrack(ML, -1);
int status = lua_resume(ML, L, 0);
if (status == 0)
{
if (lua_gettop(ML) == 0)
lua_pushstring(ML, "module must return a value");
else if (!lua_istable(ML, -1) && !lua_isfunction(ML, -1))
lua_pushstring(ML, "module must return a table or function");
}
else if (status == LUA_YIELD)
{
lua_pushstring(ML, "module can not yield");
}
else if (!lua_isstring(ML, -1))
{
lua_pushstring(ML, "unknown error while running module");
}
}
// there's now a return value on top of ML; L stack: _MODULES ML
lua_xmove(ML, L, 1);
lua_pushvalue(L, -1);
lua_setfield(L, -4, name.c_str());
// L stack: _MODULES ML result
return finishrequire(L);
}
// module needs to run in a new thread, isolated from the rest
// note: we create ML on main thread so that it doesn't inherit environment of L
lua_State* GL = lua_mainthread(L);
lua_State* ML = lua_newthread(GL);
lua_xmove(GL, L, 1);
// new thread needs to have the globals sandboxed
luaL_sandboxthread(ML);
// now we can compile & run module on the new thread
std::string bytecode = Luau::compile(*source, copts());
if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
if (codegen)
Luau::CodeGen::compile(ML, -1);
if (coverageActive())
coverageTrack(ML, -1);
int status = lua_resume(ML, L, 0);
if (status == 0)
{
if (lua_gettop(ML) == 0)
lua_pushstring(ML, "module must return a value");
else if (!lua_istable(ML, -1) && !lua_isfunction(ML, -1))
lua_pushstring(ML, "module must return a table or function");
}
else if (status == LUA_YIELD)
{
lua_pushstring(ML, "module can not yield");
}
else if (!lua_isstring(ML, -1))
{
lua_pushstring(ML, "unknown error while running module");
}
}
// there's now a return value on top of ML; L stack: _MODULES ML
lua_xmove(ML, L, 1);
lua_pushvalue(L, -1);
lua_setfield(L, -4, name.c_str());
// L stack: _MODULES ML result
return finishrequire(L);
}
static int lua_collectgarbage(lua_State* L)

290
CLI/Require.cpp Normal file
View file

@ -0,0 +1,290 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Require.h"
#include "FileUtils.h"
#include "Luau/Common.h"
#include <algorithm>
#include <array>
#include <utility>
RequireResolver::RequireResolver(lua_State* L, std::string path)
: pathToResolve(std::move(path))
, L(L)
{
lua_Debug ar;
lua_getinfo(L, 1, "s", &ar);
sourceChunkname = ar.source;
if (!isRequireAllowed(sourceChunkname))
luaL_errorL(L, "require is not supported in this context");
if (isAbsolutePath(pathToResolve))
luaL_argerrorL(L, 1, "cannot require an absolute path");
bool isAlias = !pathToResolve.empty() && pathToResolve[0] == '@';
if (!isAlias && !isExplicitlyRelative(pathToResolve))
luaL_argerrorL(L, 1, "must require an alias prepended with '@' or an explicitly relative path");
std::replace(pathToResolve.begin(), pathToResolve.end(), '\\', '/');
if (isAlias)
{
pathToResolve = pathToResolve.substr(1);
substituteAliasIfPresent(pathToResolve);
}
}
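// Sketch of the validation above: require("./sibling") and require("@alias/mod")
// are accepted; require("/abs/path") and require("bare") raise the errors above.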
[[nodiscard]] RequireResolver::ResolvedRequire RequireResolver::resolveRequire(lua_State* L, std::string path)
{
RequireResolver resolver(L, std::move(path));
ModuleStatus status = resolver.findModule();
if (status != ModuleStatus::FileRead)
return ResolvedRequire{status};
else
return ResolvedRequire{status, std::move(resolver.chunkname), std::move(resolver.absolutePath), std::move(resolver.sourceCode)};
}
RequireResolver::ModuleStatus RequireResolver::findModule()
{
resolveAndStoreDefaultPaths();
// Put _MODULES table on stack for checking and saving to the cache
luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1);
RequireResolver::ModuleStatus moduleStatus = findModuleImpl();
if (moduleStatus != RequireResolver::ModuleStatus::NotFound)
return moduleStatus;
if (!shouldSearchPathsArray())
return moduleStatus;
if (!isConfigFullyResolved)
parseNextConfig();
// Index-based iteration because iterators into config.paths may be invalidated if it is reallocated
for (size_t i = 0; i < config.paths.size(); ++i)
{
// "placeholder" acts as a requiring file in the relevant directory
std::optional<std::string> absolutePathOpt = resolvePath(pathToResolve, joinPaths(config.paths[i], "placeholder"));
if (!absolutePathOpt)
luaL_errorL(L, "error requiring module");
chunkname = *absolutePathOpt;
absolutePath = *absolutePathOpt;
moduleStatus = findModuleImpl();
if (moduleStatus != RequireResolver::ModuleStatus::NotFound)
return moduleStatus;
// Before finishing the loop, parse more config files if there are any
if (i == config.paths.size() - 1 && !isConfigFullyResolved)
parseNextConfig(); // could reallocate config.paths when paths are parsed and added
}
return RequireResolver::ModuleStatus::NotFound;
}
RequireResolver::ModuleStatus RequireResolver::findModuleImpl()
{
static const std::array<const char*, 4> possibleSuffixes = {".luau", ".lua", "/init.luau", "/init.lua"};
size_t unsuffixedAbsolutePathSize = absolutePath.size();
for (const char* possibleSuffix : possibleSuffixes)
{
absolutePath += possibleSuffix;
// Check cache for module
lua_getfield(L, -1, absolutePath.c_str());
if (!lua_isnil(L, -1))
{
return ModuleStatus::Cached;
}
lua_pop(L, 1);
// Try to read the matching file
std::optional<std::string> source = readFile(absolutePath);
if (source)
{
chunkname = "=" + chunkname + possibleSuffix;
sourceCode = *source;
return ModuleStatus::FileRead;
}
absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix
}
return ModuleStatus::NotFound;
}
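// Sketch: for require("./foo") resolved to <abs>/foo, the loop probes
// <abs>/foo.luau, <abs>/foo.lua, <abs>/foo/init.luau, then <abs>/foo/init.lua,
// consulting the _MODULES cache before each file read.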
bool RequireResolver::isRequireAllowed(std::string_view sourceChunkname)
{
LUAU_ASSERT(!sourceChunkname.empty());
return (sourceChunkname[0] == '=' || sourceChunkname[0] == '@');
}
bool RequireResolver::shouldSearchPathsArray()
{
return !isAbsolutePath(pathToResolve) && !isExplicitlyRelative(pathToResolve);
}
void RequireResolver::resolveAndStoreDefaultPaths()
{
if (!isAbsolutePath(pathToResolve))
{
std::string chunknameContext = getRequiringContextRelative();
std::optional<std::string> absolutePathContext = getRequiringContextAbsolute();
if (!absolutePathContext)
luaL_errorL(L, "error requiring module");
// resolvePath automatically sanitizes/normalizes the paths
std::optional<std::string> chunknameOpt = resolvePath(pathToResolve, chunknameContext);
std::optional<std::string> absolutePathOpt = resolvePath(pathToResolve, *absolutePathContext);
if (!chunknameOpt || !absolutePathOpt)
luaL_errorL(L, "error requiring module");
chunkname = std::move(*chunknameOpt);
absolutePath = std::move(*absolutePathOpt);
}
else
{
// Here we must explicitly sanitize, as the path is taken as is
std::optional<std::string> sanitizedPath = normalizePath(pathToResolve);
if (!sanitizedPath)
luaL_errorL(L, "error requiring module");
chunkname = *sanitizedPath;
absolutePath = std::move(*sanitizedPath);
}
}
std::optional<std::string> RequireResolver::getRequiringContextAbsolute()
{
std::string requiringFile;
if (isAbsolutePath(sourceChunkname.substr(1)))
{
// We already have an absolute path for the requiring file
requiringFile = sourceChunkname.substr(1);
}
else
{
// Requiring file's stored path is relative to the CWD, must make absolute
std::optional<std::string> cwd = getCurrentWorkingDirectory();
if (!cwd)
return std::nullopt;
if (sourceChunkname.substr(1) == "stdin")
{
// Require statement is being executed from REPL input prompt
// The requiring context is the pseudo-file "stdin" in the CWD
requiringFile = joinPaths(*cwd, "stdin");
}
else
{
// Require statement is being executed in a file, must resolve relative to CWD
std::optional<std::string> requiringFileOpt = resolvePath(sourceChunkname.substr(1), joinPaths(*cwd, "stdin"));
if (!requiringFileOpt)
return std::nullopt;
requiringFile = *requiringFileOpt;
}
}
std::replace(requiringFile.begin(), requiringFile.end(), '\\', '/');
return requiringFile;
}
std::string RequireResolver::getRequiringContextRelative()
{
std::string baseFilePath;
if (sourceChunkname.substr(1) != "stdin")
baseFilePath = sourceChunkname.substr(1);
return baseFilePath;
}
void RequireResolver::substituteAliasIfPresent(std::string& path)
{
std::string potentialAlias = path.substr(0, path.find_first_of("\\/"));
// Not worth searching when potentialAlias cannot be an alias
if (!Luau::isValidAlias(potentialAlias))
return;
std::optional<std::string> alias = getAlias(potentialAlias);
if (alias)
{
path = *alias + path.substr(potentialAlias.size());
}
}
std::optional<std::string> RequireResolver::getAlias(std::string alias)
{
std::transform(alias.begin(), alias.end(), alias.begin(), [](unsigned char c) {
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
});
while (!config.aliases.count(alias) && !isConfigFullyResolved)
{
parseNextConfig();
}
if (!config.aliases.count(alias) && isConfigFullyResolved)
return std::nullopt; // could not find alias
return resolvePath(config.aliases[alias], joinPaths(lastSearchedDir, Luau::kConfigName));
}
void RequireResolver::parseNextConfig()
{
if (isConfigFullyResolved)
return; // no config files left to parse
std::optional<std::string> directory;
if (lastSearchedDir.empty())
{
std::optional<std::string> requiringFile = getRequiringContextAbsolute();
if (!requiringFile)
luaL_errorL(L, "error requiring module");
directory = getParentPath(*requiringFile);
}
else
directory = getParentPath(lastSearchedDir);
if (directory)
{
lastSearchedDir = *directory;
parseConfigInDirectory(*directory);
}
else
isConfigFullyResolved = true;
}
void RequireResolver::parseConfigInDirectory(const std::string& directory)
{
std::string configPath = joinPaths(directory, Luau::kConfigName);
size_t numPaths = config.paths.size();
if (std::optional<std::string> contents = readFile(configPath))
{
std::optional<std::string> error = Luau::parseConfig(*contents, config);
if (error)
luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str());
}
// Resolve any newly obtained relative paths in "paths" in relation to configPath
for (auto it = config.paths.begin() + numPaths; it != config.paths.end(); ++it)
{
if (!isAbsolutePath(*it))
{
if (std::optional<std::string> resolvedPath = resolvePath(*it, configPath))
*it = std::move(*resolvedPath);
else
luaL_errorL(L, "error requiring module");
}
}
}

62
CLI/Require.h Normal file
View file

@ -0,0 +1,62 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "lua.h"
#include "lualib.h"
#include "Luau/Config.h"
#include <string>
#include <string_view>
class RequireResolver
{
public:
std::string chunkname;
std::string absolutePath;
std::string sourceCode;
enum class ModuleStatus
{
Cached,
FileRead,
NotFound
};
struct ResolvedRequire
{
ModuleStatus status;
std::string chunkName;
std::string absolutePath;
std::string sourceCode;
};
[[nodiscard]] static ResolvedRequire resolveRequire(lua_State* L, std::string path);
private:
std::string pathToResolve;
std::string_view sourceChunkname;
RequireResolver(lua_State* L, std::string path);
ModuleStatus findModule();
lua_State* L;
Luau::Config config;
std::string lastSearchedDir;
bool isConfigFullyResolved = false;
bool isRequireAllowed(std::string_view sourceChunkname);
bool shouldSearchPathsArray();
void resolveAndStoreDefaultPaths();
ModuleStatus findModuleImpl();
std::optional<std::string> getRequiringContextAbsolute();
std::string getRequiringContextRelative();
void substituteAliasIfPresent(std::string& path);
std::optional<std::string> getAlias(std::string alias);
void parseNextConfig();
void parseConfigInDirectory(const std::string& path);
};

View file

@ -37,6 +37,7 @@ if(LUAU_BUILD_CLI)
add_executable(Luau.Ast.CLI)
add_executable(Luau.Reduce.CLI)
add_executable(Luau.Compile.CLI)
add_executable(Luau.Bytecode.CLI)
# This also adds target `name` on Linux/macOS and `name.exe` on Windows
set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau)
@ -44,6 +45,7 @@ if(LUAU_BUILD_CLI)
set_target_properties(Luau.Ast.CLI PROPERTIES OUTPUT_NAME luau-ast)
set_target_properties(Luau.Reduce.CLI PROPERTIES OUTPUT_NAME luau-reduce)
set_target_properties(Luau.Compile.CLI PROPERTIES OUTPUT_NAME luau-compile)
set_target_properties(Luau.Bytecode.CLI PROPERTIES OUTPUT_NAME luau-bytecode)
endif()
if(LUAU_BUILD_TESTS)
@ -187,10 +189,11 @@ if(LUAU_BUILD_CLI)
target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS})
target_compile_options(Luau.Ast.CLI PRIVATE ${LUAU_OPTIONS})
target_compile_options(Luau.Compile.CLI PRIVATE ${LUAU_OPTIONS})
target_compile_options(Luau.Bytecode.CLI PRIVATE ${LUAU_OPTIONS})
target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include)
target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.CodeGen Luau.VM isocline)
target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM isocline)
if(UNIX)
find_library(LIBPTHREAD pthread)
@ -209,6 +212,8 @@ if(LUAU_BUILD_CLI)
target_link_libraries(Luau.Reduce.CLI PRIVATE Luau.Common Luau.Ast Luau.Analysis)
target_link_libraries(Luau.Compile.CLI PRIVATE Luau.Compiler Luau.VM Luau.CodeGen)
target_link_libraries(Luau.Bytecode.CLI PRIVATE Luau.Compiler Luau.VM Luau.CodeGen)
endif()
if(LUAU_BUILD_TESTS)
@ -225,7 +230,7 @@ if(LUAU_BUILD_TESTS)
target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS})
target_include_directories(Luau.CLI.Test PRIVATE extern CLI)
target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.CodeGen Luau.VM isocline)
target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM isocline)
if(UNIX)
find_library(LIBPTHREAD pthread)
if (LIBPTHREAD)
@ -249,6 +254,8 @@ if(LUAU_BUILD_WEB)
target_link_options(Luau.Web PRIVATE -sSINGLE_FILE=1)
endif()
add_subdirectory(fuzz)
# validate dependencies for internal libraries
foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.CodeGen Luau.VM)
if(TARGET ${LIB})

47
CMakePresets.json Normal file
View file

@ -0,0 +1,47 @@
{
"version": 6,
"configurePresets": [
{
"name": "fuzz",
"displayName": "Fuzz",
"description": "Configures required fuzzer settings.",
"binaryDir": "build",
"condition": {
"type": "anyOf",
"conditions": [
{
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Darwin"
},
{
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Linux"
}
]
},
"cacheVariables": {
"CMAKE_OSX_ARCHITECTURES": "x86_64",
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_CXX_STANDARD": "17",
"CMAKE_CXX_EXTENSIONS": false
},
"warnings": {
"dev": false
}
}
],
"buildPresets": [
{
"name": "fuzz-proto",
"displayName": "Protobuf Fuzzer",
"description": "Builds the protobuf-based fuzzer and transpiler tools.",
"configurePreset": "fuzz",
"targets": [
"Luau.Fuzz.Proto",
"Luau.Fuzz.ProtoTest"
]
}
]
}

View file

@ -133,6 +133,7 @@ public:
void vcvttsd2si(OperandX64 dst, OperandX64 src);
void vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vcvtsd2ss(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vcvtss2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode); // inexact
@ -158,7 +159,6 @@ public:
void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3);
// Run final checks
bool finalize();
@ -228,6 +228,7 @@ private:
void placeVex(OperandX64 dst, OperandX64 src1, OperandX64 src2, bool setW, uint8_t mode, uint8_t prefix);
void placeImm8Or32(int32_t imm);
void placeImm8(int32_t imm);
void placeImm16(int16_t imm);
void placeImm32(int32_t imm);
void placeImm64(int64_t imm);
void placeLabel(Label& label);


@ -0,0 +1,21 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <vector>
#include <stdint.h>
namespace Luau
{
namespace CodeGen
{
struct IrFunction;
void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets);
void analyzeBytecodeTypes(IrFunction& function);
} // namespace CodeGen
} // namespace Luau


@ -0,0 +1,83 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include "Luau/Bytecode.h"
#include <string>
#include <vector>
#include <stdint.h>
struct lua_State;
struct Proto;
namespace Luau
{
namespace CodeGen
{
class FunctionBytecodeSummary
{
public:
FunctionBytecodeSummary(std::string source, std::string name, const int line, unsigned nestingLimit);
const std::string& getSource() const
{
return source;
}
const std::string& getName() const
{
return name;
}
int getLine() const
{
return line;
}
unsigned getNestingLimit() const
{
return nestingLimit;
}
unsigned getOpLimit() const
{
return LOP__COUNT;
}
void incCount(unsigned nesting, uint8_t op)
{
LUAU_ASSERT(nesting <= getNestingLimit());
LUAU_ASSERT(op < getOpLimit());
++counts[nesting][op];
}
unsigned getCount(unsigned nesting, uint8_t op) const
{
LUAU_ASSERT(nesting <= getNestingLimit());
LUAU_ASSERT(op < getOpLimit());
return counts[nesting][op];
}
const std::vector<unsigned>& getCounts(unsigned nesting) const
{
LUAU_ASSERT(nesting <= getNestingLimit());
return counts[nesting];
}
static FunctionBytecodeSummary fromProto(Proto* proto, unsigned nestingLimit);
private:
std::string source;
std::string name;
int line;
unsigned nestingLimit;
std::vector<std::vector<unsigned>> counts;
};
std::vector<FunctionBytecodeSummary> summarizeBytecode(lua_State* L, int idx, unsigned nestingLimit);
} // namespace CodeGen
} // namespace Luau
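The summary API above is declaration-only here; as a rough usage sketch (not part of the diff, and assuming a lua_State with a compiled Lua function on top of the stack):

```cpp
#include "Luau/BytecodeSummary.h"

#include <cstdio>

// Dump per-function opcode histograms at nesting level 0 using the API above.
void dumpOpcodeCounts(lua_State* L)
{
    std::vector<Luau::CodeGen::FunctionBytecodeSummary> summaries =
        Luau::CodeGen::summarizeBytecode(L, /*idx*/ -1, /*nestingLimit*/ 0);

    for (const Luau::CodeGen::FunctionBytecodeSummary& summary : summaries)
    {
        printf("%s:%d %s\n", summary.getSource().c_str(), summary.getLine(), summary.getName().c_str());

        const std::vector<unsigned>& counts = summary.getCounts(0);

        for (unsigned op = 0; op < summary.getOpLimit(); ++op)
        {
            if (counts[op] != 0)
                printf("  opcode %u: %u\n", op, counts[op]);
        }
    }
}
```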


@ -3,6 +3,7 @@
#include <algorithm>
#include <string>
#include <vector>
#include <stddef.h>
#include <stdint.h>
@ -80,6 +81,36 @@ struct AssemblyOptions
void* annotatorContext = nullptr;
};
struct BlockLinearizationStats
{
unsigned int constPropInstructionCount = 0;
double timeSeconds = 0.0;
BlockLinearizationStats& operator+=(const BlockLinearizationStats& that)
{
this->constPropInstructionCount += that.constPropInstructionCount;
this->timeSeconds += that.timeSeconds;
return *this;
}
BlockLinearizationStats operator+(const BlockLinearizationStats& other) const
{
BlockLinearizationStats result(*this);
result += other;
return result;
}
};
struct FunctionStats
{
std::string name;
int line = -1;
unsigned bcodeCount = 0;
unsigned irCount = 0;
unsigned asmCount = 0;
};
struct LoweringStats
{
unsigned totalFunctions = 0;
@ -94,6 +125,11 @@ struct LoweringStats
int regAllocErrors = 0;
int loweringErrors = 0;
BlockLinearizationStats blockLinearizationStats;
bool collectFunctionStats = false;
std::vector<FunctionStats> functions;
LoweringStats operator+(const LoweringStats& other) const
{
LoweringStats result(*this);
@ -113,6 +149,9 @@ struct LoweringStats
this->maxBlockInstructions = std::max(this->maxBlockInstructions, that.maxBlockInstructions);
this->regAllocErrors += that.regAllocErrors;
this->loweringErrors += that.loweringErrors;
this->blockLinearizationStats += that.blockLinearizationStats;
if (this->collectFunctionStats)
this->functions.insert(this->functions.end(), that.functions.begin(), that.functions.end());
return *this;
}
};
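A short sketch (not part of the diff) of how the aggregation operators above compose, assuming only the fields shown:

```cpp
#include <cstdio>

// Merge stats from two lowering runs and print per-function entries.
// Function lists are concatenated only when collectFunctionStats is set.
void reportStats(const Luau::CodeGen::LoweringStats& workerA, const Luau::CodeGen::LoweringStats& workerB)
{
    Luau::CodeGen::LoweringStats total = workerA + workerB; // copies workerA, then operator+= folds in workerB

    for (const Luau::CodeGen::FunctionStats& fn : total.functions)
        printf("%s:%d bytecode=%u ir=%u asm=%u\n", fn.name.c_str(), fn.line, fn.bcodeCount, fn.irCount, fn.asmCount);
}
```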


@ -78,7 +78,13 @@ struct IrBuilder
std::vector<uint32_t> instIndexToBlock; // Block index at the bytecode instruction
std::vector<IrOp> loopStepStack;
struct LoopInfo
{
IrOp step;
int startpc = 0;
};
std::vector<LoopInfo> numericLoopStack;
// Similar to BytecodeBuilder, duplicate constants are removed using the same method
struct ConstantKey


@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Bytecode.h"
#include "Luau/IrAnalysis.h"
#include "Luau/Label.h"
#include "Luau/RegisterX64.h"
@ -12,6 +13,8 @@
#include <stdint.h>
#include <string.h>
LUAU_FASTFLAG(LuauKeepVmapLinear2)
struct Proto;
namespace Luau
@ -251,7 +254,7 @@ enum class IrCmd : uint8_t
// A: pointer (Table)
DUP_TABLE,
// Insert an integer key into a table
// Insert an integer key into a table and return a pointer to the inserted value (TValue)
// A: pointer (Table)
// B: int (key)
TABLE_SETNUM,
@ -281,7 +284,7 @@ enum class IrCmd : uint8_t
NUM_TO_UINT,
// Adjust stack top (L->top) to point at 'B' TValues *after* the specified register
// This is used to return muliple values
// This is used to return multiple values
// A: Rn
// B: int (offset)
ADJUST_STACK_TO_REG,
@ -420,6 +423,14 @@ enum class IrCmd : uint8_t
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_NODE_VALUE,
// Guard against access at specified offset/size overflowing the buffer length
// A: pointer (buffer)
// B: int (offset)
// C: int (size)
// D: block/vmexit/undef
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_BUFFER_LEN,
// Special operations
// Check interrupt handler
@ -600,6 +611,10 @@ enum class IrCmd : uint8_t
BITCOUNTLZ_UINT,
BITCOUNTRZ_UINT,
// Swap byte order in A
// A: int
BYTESWAP_UINT,
// Calls native libm function with 1 or 2 arguments
// A: builtin function ID
// B: double
@ -617,6 +632,71 @@ enum class IrCmd : uint8_t
// Find or create an upval at the given level
// A: Rn (level)
FINDUPVAL,
// Read i8 (sign-extended to int) from buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
BUFFER_READI8,
// Read u8 (zero-extended to int) from buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
BUFFER_READU8,
// Write i8/u8 value (int argument is truncated) to buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
// C: int (value)
BUFFER_WRITEI8,
// Read i16 (sign-extended to int) from buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
BUFFER_READI16,
// Read u16 (zero-extended to int) from buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
BUFFER_READU16,
// Write i16/u16 value (int argument is truncated) to buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
// C: int (value)
BUFFER_WRITEI16,
// Read i32 value from buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
BUFFER_READI32,
// Write i32/u32 value to buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
// C: int (value)
BUFFER_WRITEI32,
// Read float value (converted to double) from buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
BUFFER_READF32,
// Write float value (converted from double) to buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
// C: double (value)
BUFFER_WRITEF32,
// Read double value from buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
BUFFER_READF64,
// Write double value to buffer storage at specified offset
// A: pointer (buffer)
// B: int (offset)
// C: double (value)
BUFFER_WRITEF64,
};
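A scalar model (not part of the diff) of the 16-bit buffer ops above; it assumes a little-endian host, matching the byte order the native lowerings rely on, and that a CHECK_BUFFER_LEN guard has already validated the offset:

```cpp
#include <stdint.h>
#include <string.h>

// BUFFER_READI16: read two bytes at 'offset' and sign-extend to int.
int bufferReadI16(const uint8_t* data, int offset)
{
    int16_t value;
    memcpy(&value, data + offset, sizeof(value));
    return value; // implicit sign extension to int
}

// BUFFER_WRITEI16: the int argument is truncated, as documented above.
void bufferWriteI16(uint8_t* data, int offset, int value)
{
    int16_t truncated = int16_t(value);
    memcpy(data + offset, &truncated, sizeof(truncated));
}
```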
enum class IrConstKind : uint8_t
@ -853,18 +933,37 @@ struct BytecodeMapping
uint32_t asmLocation;
};
struct BytecodeBlock
{
// 'start' and 'finish' define an inclusive range of instructions which belong to the block
int startpc = -1;
int finishpc = -1;
};
struct BytecodeTypes
{
uint8_t result = LBC_TYPE_ANY;
uint8_t a = LBC_TYPE_ANY;
uint8_t b = LBC_TYPE_ANY;
uint8_t c = LBC_TYPE_ANY;
};
struct IrFunction
{
std::vector<IrBlock> blocks;
std::vector<IrInst> instructions;
std::vector<IrConst> constants;
std::vector<BytecodeBlock> bcBlocks;
std::vector<BytecodeTypes> bcTypes;
std::vector<BytecodeMapping> bcMapping;
uint32_t entryBlock = 0;
uint32_t entryLocation = 0;
// For each instruction, an operand that can be used to recompute the value
std::vector<IrOp> valueRestoreOps;
std::vector<uint32_t> validRestoreOpBlocks;
uint32_t validRestoreOpBlockIdx = 0;
Proto* proto = nullptr;
@ -1009,22 +1108,53 @@ struct IrFunction
if (instIdx >= valueRestoreOps.size())
return {};
const IrBlock& block = blocks[validRestoreOpBlockIdx];
// When spilled, values can only reference restore operands in the current block
if (limitToCurrentBlock)
if (FFlag::LuauKeepVmapLinear2)
{
if (instIdx < block.start || instIdx > block.finish)
return {};
}
// When spilled, values can only reference restore operands in the current block chain
if (limitToCurrentBlock)
{
for (uint32_t blockIdx : validRestoreOpBlocks)
{
const IrBlock& block = blocks[blockIdx];
return valueRestoreOps[instIdx];
if (instIdx >= block.start && instIdx <= block.finish)
return valueRestoreOps[instIdx];
}
return {};
}
return valueRestoreOps[instIdx];
}
else
{
const IrBlock& block = blocks[validRestoreOpBlockIdx];
// When spilled, values can only reference restore operands in the current block
if (limitToCurrentBlock)
{
if (instIdx < block.start || instIdx > block.finish)
return {};
}
return valueRestoreOps[instIdx];
}
}
IrOp findRestoreOp(const IrInst& inst, bool limitToCurrentBlock) const
{
return findRestoreOp(getInstIndex(inst), limitToCurrentBlock);
}
BytecodeTypes getBytecodeTypesAt(int pcpos) const
{
LUAU_ASSERT(pcpos >= 0);
if (size_t(pcpos) < bcTypes.size())
return bcTypes[pcpos];
return BytecodeTypes();
}
};
inline IrCondition conditionOp(IrOp op)


@ -29,6 +29,7 @@ void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index); //
void toString(IrToStringContext& ctx, IrOp op);
void toString(std::string& result, IrConst constant);
void toString(std::string& result, const BytecodeTypes& bcTypes);
void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, bool includeUseInfo);
void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo); // Block title


@ -128,6 +128,7 @@ inline bool isNonTerminatingJump(IrCmd cmd)
case IrCmd::CHECK_SLOT_MATCH:
case IrCmd::CHECK_NODE_NO_NEXT:
case IrCmd::CHECK_NODE_VALUE:
case IrCmd::CHECK_BUFFER_LEN:
return true;
default:
break;
@ -197,6 +198,13 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::GET_TYPEOF:
case IrCmd::NEWCLOSURE:
case IrCmd::FINDUPVAL:
case IrCmd::BUFFER_READI8:
case IrCmd::BUFFER_READU8:
case IrCmd::BUFFER_READI16:
case IrCmd::BUFFER_READU16:
case IrCmd::BUFFER_READI32:
case IrCmd::BUFFER_READF32:
case IrCmd::BUFFER_READF64:
return true;
default:
break;

View file

@ -41,6 +41,11 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
break;
// A <- B, C
case IrCmd::DO_ARITH:
visitor.maybeUse(inst.b); // Argument can also be a VmConst
visitor.maybeUse(inst.c); // Argument can also be a VmConst
visitor.def(inst.a);
break;
case IrCmd::GET_TABLE:
visitor.use(inst.b);
visitor.maybeUse(inst.c); // Argument can also be a VmConst


@ -6,6 +6,8 @@
#include <stdarg.h>
#include <stdio.h>
LUAU_FASTFLAG(LuauCodeGenFixByteLower)
namespace Luau
{
namespace CodeGen
@ -175,6 +177,12 @@ void AssemblyBuilderX64::mov(OperandX64 lhs, OperandX64 rhs)
place(OP_PLUS_REG(0xb0, lhs.base.index));
placeImm8(rhs.imm);
}
else if (size == SizeX64::word)
{
place(0x66);
place(OP_PLUS_REG(0xb8, lhs.base.index));
placeImm16(rhs.imm);
}
else if (size == SizeX64::dword)
{
place(OP_PLUS_REG(0xb8, lhs.base.index));
@ -200,6 +208,13 @@ void AssemblyBuilderX64::mov(OperandX64 lhs, OperandX64 rhs)
placeModRegMem(lhs, 0, /*extraCodeBytes=*/1);
placeImm8(rhs.imm);
}
else if (size == SizeX64::word)
{
place(0x66);
place(0xc7);
placeModRegMem(lhs, 0, /*extraCodeBytes=*/2);
placeImm16(rhs.imm);
}
else
{
LUAU_ASSERT(size == SizeX64::dword || size == SizeX64::qword);
@ -780,6 +795,16 @@ void AssemblyBuilderX64::vcvtsd2ss(OperandX64 dst, OperandX64 src1, OperandX64 s
placeAvx("vcvtsd2ss", dst, src1, src2, 0x5a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vcvtss2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
if (src2.cat == CategoryX64::reg)
LUAU_ASSERT(src2.base.size == SizeX64::xmmword);
else
LUAU_ASSERT(src2.memSize == SizeX64::dword);
placeAvx("vcvtss2sd", dst, src1, src2, 0x5a, false, AVX_0F, AVX_F3);
}
void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode)
{
placeAvx("vroundsd", dst, src1, src2, uint8_t(roundingMode) | kRoundingPrecisionInexact, 0x0b, false, AVX_0F3A, AVX_66);
@ -1086,7 +1111,10 @@ void AssemblyBuilderX64::placeBinaryRegAndRegMem(OperandX64 lhs, OperandX64 rhs,
LUAU_ASSERT(lhs.base.size == (rhs.cat == CategoryX64::reg ? rhs.base.size : rhs.memSize));
SizeX64 size = lhs.base.size;
LUAU_ASSERT(size == SizeX64::byte || size == SizeX64::dword || size == SizeX64::qword);
LUAU_ASSERT(size == SizeX64::byte || size == SizeX64::word || size == SizeX64::dword || size == SizeX64::qword);
if (size == SizeX64::word)
place(0x66);
placeRex(lhs.base, rhs);
place(size == SizeX64::byte ? code8 : code);
@ -1411,10 +1439,25 @@ void AssemblyBuilderX64::placeImm8(int32_t imm)
{
int8_t imm8 = int8_t(imm);
if (imm8 == imm)
if (FFlag::LuauCodeGenFixByteLower)
{
LUAU_ASSERT(imm8 == imm);
place(imm8);
}
else
LUAU_ASSERT(!"Invalid immediate value");
{
if (imm8 == imm)
place(imm8);
else
LUAU_ASSERT(!"Invalid immediate value");
}
}
void AssemblyBuilderX64::placeImm16(int16_t imm)
{
uint8_t* pos = codePos;
LUAU_ASSERT(pos + sizeof(imm) < codeEnd);
codePos = writeu16(pos, imm);
}
void AssemblyBuilderX64::placeImm32(int32_t imm)

View file

@ -15,6 +15,16 @@ inline uint8_t* writeu8(uint8_t* target, uint8_t value)
return target + sizeof(value);
}
inline uint8_t* writeu16(uint8_t* target, uint16_t value)
{
#if defined(LUAU_BIG_ENDIAN)
value = htole16(value);
#endif
memcpy(target, &value, sizeof(value));
return target + sizeof(value);
}
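A quick check (not part of the diff) of the encoding contract: the emitted bytes are little-endian regardless of host order, since htole16 swaps only on big-endian targets:

```cpp
#include <assert.h>

void writeu16Example()
{
    uint8_t buf[2];
    writeu16(buf, 0x1234);
    assert(buf[0] == 0x34 && buf[1] == 0x12); // low byte first
}
```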
inline uint8_t* writeu32(uint8_t* target, uint32_t value)
{
#if defined(LUAU_BIG_ENDIAN)


@ -0,0 +1,884 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BytecodeAnalysis.h"
#include "Luau/BytecodeUtils.h"
#include "Luau/IrData.h"
#include "Luau/IrUtils.h"
#include "lobject.h"
namespace Luau
{
namespace CodeGen
{
static bool hasTypedParameters(Proto* proto)
{
return proto->typeinfo && proto->numparams != 0;
}
static uint8_t getBytecodeConstantTag(Proto* proto, unsigned ki)
{
TValue protok = proto->k[ki];
switch (protok.tt)
{
case LUA_TNIL:
return LBC_TYPE_NIL;
case LUA_TBOOLEAN:
return LBC_TYPE_BOOLEAN;
case LUA_TLIGHTUSERDATA:
return LBC_TYPE_USERDATA;
case LUA_TNUMBER:
return LBC_TYPE_NUMBER;
case LUA_TVECTOR:
return LBC_TYPE_VECTOR;
case LUA_TSTRING:
return LBC_TYPE_STRING;
case LUA_TTABLE:
return LBC_TYPE_TABLE;
case LUA_TFUNCTION:
return LBC_TYPE_FUNCTION;
case LUA_TUSERDATA:
return LBC_TYPE_USERDATA;
case LUA_TTHREAD:
return LBC_TYPE_THREAD;
case LUA_TBUFFER:
return LBC_TYPE_BUFFER;
}
return LBC_TYPE_ANY;
}
static void applyBuiltinCall(int bfid, BytecodeTypes& types)
{
switch (bfid)
{
case LBF_NONE:
case LBF_ASSERT:
types.result = LBC_TYPE_ANY;
break;
case LBF_MATH_ABS:
case LBF_MATH_ACOS:
case LBF_MATH_ASIN:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_MATH_ATAN2:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_MATH_ATAN:
case LBF_MATH_CEIL:
case LBF_MATH_COSH:
case LBF_MATH_COS:
case LBF_MATH_DEG:
case LBF_MATH_EXP:
case LBF_MATH_FLOOR:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_MATH_FMOD:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_MATH_FREXP:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_MATH_LDEXP:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_MATH_LOG10:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_MATH_LOG:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER; // We can mark optional arguments
break;
case LBF_MATH_MAX:
case LBF_MATH_MIN:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER; // We can mark optional arguments
break;
case LBF_MATH_MODF:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_MATH_POW:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_MATH_RAD:
case LBF_MATH_SINH:
case LBF_MATH_SIN:
case LBF_MATH_SQRT:
case LBF_MATH_TANH:
case LBF_MATH_TAN:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_BIT32_ARSHIFT:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_BIT32_BAND:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER; // We can mark optional arguments
break;
case LBF_BIT32_BNOT:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_BIT32_BOR:
case LBF_BIT32_BXOR:
case LBF_BIT32_BTEST:
case LBF_BIT32_EXTRACT:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER; // We can mark optional arguments
break;
case LBF_BIT32_LROTATE:
case LBF_BIT32_LSHIFT:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_BIT32_REPLACE:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER; // We can mark optional arguments
break;
case LBF_BIT32_RROTATE:
case LBF_BIT32_RSHIFT:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_TYPE:
types.result = LBC_TYPE_STRING;
break;
case LBF_STRING_BYTE:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_STRING;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_STRING_CHAR:
types.result = LBC_TYPE_STRING;
// We can mark optional arguments
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
case LBF_STRING_LEN:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_STRING;
break;
case LBF_TYPEOF:
types.result = LBC_TYPE_STRING;
break;
case LBF_STRING_SUB:
types.result = LBC_TYPE_STRING;
types.a = LBC_TYPE_STRING;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
case LBF_MATH_CLAMP:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
case LBF_MATH_SIGN:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_MATH_ROUND:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_RAWGET:
types.result = LBC_TYPE_ANY;
types.a = LBC_TYPE_TABLE;
break;
case LBF_RAWEQUAL:
types.result = LBC_TYPE_BOOLEAN;
break;
case LBF_TABLE_UNPACK:
types.result = LBC_TYPE_ANY;
types.a = LBC_TYPE_TABLE;
types.b = LBC_TYPE_NUMBER; // We can mark optional arguments
break;
case LBF_VECTOR:
types.result = LBC_TYPE_VECTOR;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
case LBF_BIT32_COUNTLZ:
case LBF_BIT32_COUNTRZ:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_SELECT_VARARG:
types.result = LBC_TYPE_ANY;
break;
case LBF_RAWLEN:
types.result = LBC_TYPE_NUMBER;
break;
case LBF_BIT32_EXTRACTK:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_GETMETATABLE:
types.result = LBC_TYPE_TABLE;
break;
case LBF_TONUMBER:
types.result = LBC_TYPE_NUMBER;
break;
case LBF_TOSTRING:
types.result = LBC_TYPE_STRING;
break;
case LBF_BIT32_BYTESWAP:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_READI8:
case LBF_BUFFER_READU8:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_WRITEU8:
types.result = LBC_TYPE_NIL;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_READI16:
case LBF_BUFFER_READU16:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_WRITEU16:
types.result = LBC_TYPE_NIL;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_READI32:
case LBF_BUFFER_READU32:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_WRITEU32:
types.result = LBC_TYPE_NIL;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_READF32:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_WRITEF32:
types.result = LBC_TYPE_NIL;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_READF64:
types.result = LBC_TYPE_NUMBER;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
break;
case LBF_BUFFER_WRITEF64:
types.result = LBC_TYPE_NIL;
types.a = LBC_TYPE_BUFFER;
types.b = LBC_TYPE_NUMBER;
types.c = LBC_TYPE_NUMBER;
break;
case LBF_TABLE_INSERT:
types.result = LBC_TYPE_NIL;
types.a = LBC_TYPE_TABLE;
break;
case LBF_RAWSET:
types.result = LBC_TYPE_ANY;
types.a = LBC_TYPE_TABLE;
break;
case LBF_SETMETATABLE:
types.result = LBC_TYPE_TABLE;
types.a = LBC_TYPE_TABLE;
types.b = LBC_TYPE_TABLE;
break;
}
}
void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets)
{
Proto* proto = function.proto;
LUAU_ASSERT(proto);
std::vector<BytecodeBlock>& bcBlocks = function.bcBlocks;
// Using the same jump targets, create VM bytecode basic blocks
bcBlocks.push_back(BytecodeBlock{0, -1});
int previ = 0;
for (int i = 0; i < proto->sizecode;)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
// If instruction is a jump target, begin new block starting from it
if (i != 0 && jumpTargets[i])
{
bcBlocks.back().finishpc = previ;
bcBlocks.push_back(BytecodeBlock{i, -1});
}
int target = getJumpTarget(*pc, uint32_t(i));
// An explicit jump terminates the block, and its fallthrough might start a new one
if (target >= 0 && !isFastCall(op))
{
bcBlocks.back().finishpc = i;
// Start a new block if there was no explicit jump for the fallthrough
if (!jumpTargets[nexti])
bcBlocks.push_back(BytecodeBlock{nexti, -1});
}
// Returns just terminate the block
else if (op == LOP_RETURN)
{
bcBlocks.back().finishpc = i;
}
previ = i;
i = nexti;
LUAU_ASSERT(i <= proto->sizecode);
}
}
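A compressed stand-alone model (hypothetical helper, not part of the diff) of the splitting loop above, simplified to one word per instruction; jumpTo[i] holds the jump target or -1, isReturn[i] marks returns, and jumpTargets plays the same role as the caller-provided vector:

```cpp
#include <vector>

struct Block
{
    int startpc;
    int finishpc;
};

std::vector<Block> splitIntoBlocks(const std::vector<int>& jumpTo, const std::vector<bool>& isReturn, const std::vector<bool>& jumpTargets)
{
    std::vector<Block> blocks{{0, -1}};
    int count = int(jumpTo.size());

    for (int i = 0; i < count; ++i)
    {
        // a jump target begins a new block
        if (i != 0 && jumpTargets[i])
        {
            blocks.back().finishpc = i - 1;
            blocks.push_back({i, -1});
        }

        // an explicit jump terminates the block; its fallthrough opens a new
        // one unless the next instruction is already a jump target
        if (jumpTo[i] >= 0)
        {
            blocks.back().finishpc = i;

            if (i + 1 < count && !jumpTargets[i + 1])
                blocks.push_back({i + 1, -1});
        }
        else if (isReturn[i])
        {
            blocks.back().finishpc = i;
        }
    }

    return blocks;
}
```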
void analyzeBytecodeTypes(IrFunction& function)
{
Proto* proto = function.proto;
LUAU_ASSERT(proto);
// Setup our current knowledge of type tags based on arguments
uint8_t regTags[256];
memset(regTags, LBC_TYPE_ANY, 256);
function.bcTypes.resize(proto->sizecode);
// Now that we have VM basic blocks, we can attempt to track register type tags locally
for (const BytecodeBlock& block : function.bcBlocks)
{
LUAU_ASSERT(block.startpc != -1);
LUAU_ASSERT(block.finishpc != -1);
// At the block start, reset our knowledge to the starting state
// In the future we might be able to propagate some info between the blocks as well
if (hasTypedParameters(proto))
{
for (int i = 0; i < proto->numparams; ++i)
{
uint8_t et = proto->typeinfo[2 + i];
// TODO: if argument is optional, this might force a VM exit unnecessarily
regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT;
}
}
for (int i = proto->numparams; i < proto->maxstacksize; ++i)
regTags[i] = LBC_TYPE_ANY;
for (int i = block.startpc; i <= block.finishpc;)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
BytecodeTypes& bcType = function.bcTypes[i];
switch (op)
{
case LOP_NOP:
break;
case LOP_LOADNIL:
{
int ra = LUAU_INSN_A(*pc);
regTags[ra] = LBC_TYPE_NIL;
bcType.result = regTags[ra];
break;
}
case LOP_LOADB:
{
int ra = LUAU_INSN_A(*pc);
regTags[ra] = LBC_TYPE_BOOLEAN;
bcType.result = regTags[ra];
break;
}
case LOP_LOADN:
{
int ra = LUAU_INSN_A(*pc);
regTags[ra] = LBC_TYPE_NUMBER;
bcType.result = regTags[ra];
break;
}
case LOP_LOADK:
{
int ra = LUAU_INSN_A(*pc);
int kb = LUAU_INSN_D(*pc);
bcType.a = getBytecodeConstantTag(proto, kb);
regTags[ra] = bcType.a;
bcType.result = regTags[ra];
break;
}
case LOP_LOADKX:
{
int ra = LUAU_INSN_A(*pc);
int kb = int(pc[1]);
bcType.a = getBytecodeConstantTag(proto, kb);
regTags[ra] = bcType.a;
bcType.result = regTags[ra];
break;
}
case LOP_MOVE:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
bcType.a = regTags[rb];
regTags[ra] = regTags[rb];
bcType.result = regTags[ra];
break;
}
case LOP_GETTABLE:
{
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
bcType.a = regTags[rb];
bcType.b = regTags[rc];
break;
}
case LOP_SETTABLE:
{
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
bcType.a = regTags[rb];
bcType.b = regTags[rc];
break;
}
case LOP_GETTABLEKS:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
uint32_t kc = pc[1];
bcType.a = regTags[rb];
bcType.b = getBytecodeConstantTag(proto, kc);
regTags[ra] = LBC_TYPE_ANY;
// Assuming that vector component is being indexed
// TODO: check what key is used
if (bcType.a == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_NUMBER;
bcType.result = regTags[ra];
break;
}
case LOP_SETTABLEKS:
{
int rb = LUAU_INSN_B(*pc);
bcType.a = regTags[rb];
bcType.b = LBC_TYPE_STRING;
break;
}
case LOP_GETTABLEN:
case LOP_SETTABLEN:
{
int rb = LUAU_INSN_B(*pc);
bcType.a = regTags[rb];
bcType.b = LBC_TYPE_NUMBER;
break;
}
case LOP_ADD:
case LOP_SUB:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
bcType.a = regTags[rb];
bcType.b = regTags[rc];
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
bcType.result = regTags[ra];
break;
}
case LOP_MUL:
case LOP_DIV:
case LOP_IDIV:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
bcType.a = regTags[rb];
bcType.b = regTags[rc];
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_NUMBER)
{
if (bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (bcType.a == LBC_TYPE_VECTOR)
{
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
bcType.result = regTags[ra];
break;
}
case LOP_MOD:
case LOP_POW:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
bcType.a = regTags[rb];
bcType.b = regTags[rc];
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
bcType.result = regTags[ra];
break;
}
case LOP_ADDK:
case LOP_SUBK:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int kc = LUAU_INSN_C(*pc);
bcType.a = regTags[rb];
bcType.b = getBytecodeConstantTag(proto, kc);
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
bcType.result = regTags[ra];
break;
}
case LOP_MULK:
case LOP_DIVK:
case LOP_IDIVK:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int kc = LUAU_INSN_C(*pc);
bcType.a = regTags[rb];
bcType.b = getBytecodeConstantTag(proto, kc);
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_NUMBER)
{
if (bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (bcType.a == LBC_TYPE_VECTOR)
{
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
bcType.result = regTags[ra];
break;
}
case LOP_MODK:
case LOP_POWK:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int kc = LUAU_INSN_C(*pc);
bcType.a = regTags[rb];
bcType.b = getBytecodeConstantTag(proto, kc);
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
bcType.result = regTags[ra];
break;
}
case LOP_SUBRK:
{
int ra = LUAU_INSN_A(*pc);
int kb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
bcType.a = getBytecodeConstantTag(proto, kb);
bcType.b = regTags[rc];
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
bcType.result = regTags[ra];
break;
}
case LOP_DIVRK:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int kc = LUAU_INSN_C(*pc);
bcType.a = regTags[rb];
bcType.b = getBytecodeConstantTag(proto, kc);
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_NUMBER)
{
if (bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (bcType.a == LBC_TYPE_VECTOR)
{
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
bcType.result = regTags[ra];
break;
}
case LOP_NOT:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
bcType.a = regTags[rb];
regTags[ra] = LBC_TYPE_BOOLEAN;
bcType.result = regTags[ra];
break;
}
case LOP_MINUS:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
bcType.a = regTags[rb];
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
bcType.result = regTags[ra];
break;
}
case LOP_LENGTH:
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
bcType.a = regTags[rb];
regTags[ra] = LBC_TYPE_NUMBER; // Even if it's a custom __len, it's ok to assume a sane result
bcType.result = regTags[ra];
break;
}
case LOP_NEWTABLE:
case LOP_DUPTABLE:
{
int ra = LUAU_INSN_A(*pc);
regTags[ra] = LBC_TYPE_TABLE;
bcType.result = regTags[ra];
break;
}
case LOP_FASTCALL:
{
int bfid = LUAU_INSN_A(*pc);
int skip = LUAU_INSN_C(*pc);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call);
applyBuiltinCall(bfid, bcType);
regTags[ra + 1] = bcType.a;
regTags[ra + 2] = bcType.b;
regTags[ra + 3] = bcType.c;
regTags[ra] = bcType.result;
break;
}
case LOP_FASTCALL1:
case LOP_FASTCALL2K:
{
int bfid = LUAU_INSN_A(*pc);
int skip = LUAU_INSN_C(*pc);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call);
applyBuiltinCall(bfid, bcType);
regTags[LUAU_INSN_B(*pc)] = bcType.a;
regTags[ra] = bcType.result;
break;
}
case LOP_FASTCALL2:
{
int bfid = LUAU_INSN_A(*pc);
int skip = LUAU_INSN_C(*pc);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call);
applyBuiltinCall(bfid, bcType);
regTags[LUAU_INSN_B(*pc)] = bcType.a;
regTags[int(pc[1])] = bcType.b;
regTags[ra] = bcType.result;
break;
}
case LOP_FORNPREP:
{
int ra = LUAU_INSN_A(*pc);
regTags[ra] = LBC_TYPE_NUMBER;
regTags[ra + 1] = LBC_TYPE_NUMBER;
regTags[ra + 2] = LBC_TYPE_NUMBER;
break;
}
case LOP_FORNLOOP:
{
int ra = LUAU_INSN_A(*pc);
// These types are established by LOP_FORNPREP and we reinforce that here
regTags[ra] = LBC_TYPE_NUMBER;
regTags[ra + 1] = LBC_TYPE_NUMBER;
regTags[ra + 2] = LBC_TYPE_NUMBER;
break;
}
case LOP_CONCAT:
{
int ra = LUAU_INSN_A(*pc);
regTags[ra] = LBC_TYPE_STRING;
bcType.result = regTags[ra];
break;
}
case LOP_NEWCLOSURE:
case LOP_DUPCLOSURE:
{
int ra = LUAU_INSN_A(*pc);
regTags[ra] = LBC_TYPE_FUNCTION;
bcType.result = regTags[ra];
break;
}
case LOP_GETGLOBAL:
case LOP_SETGLOBAL:
case LOP_CALL:
case LOP_RETURN:
case LOP_JUMP:
case LOP_JUMPBACK:
case LOP_JUMPIF:
case LOP_JUMPIFNOT:
case LOP_JUMPIFEQ:
case LOP_JUMPIFLE:
case LOP_JUMPIFLT:
case LOP_JUMPIFNOTEQ:
case LOP_JUMPIFNOTLE:
case LOP_JUMPIFNOTLT:
case LOP_JUMPX:
case LOP_JUMPXEQKNIL:
case LOP_JUMPXEQKB:
case LOP_JUMPXEQKN:
case LOP_JUMPXEQKS:
case LOP_SETLIST:
case LOP_GETUPVAL:
case LOP_SETUPVAL:
case LOP_CLOSEUPVALS:
case LOP_FORGLOOP:
case LOP_FORGPREP_NEXT:
case LOP_FORGPREP_INEXT:
case LOP_AND:
case LOP_ANDK:
case LOP_OR:
case LOP_ORK:
case LOP_COVERAGE:
case LOP_GETIMPORT:
case LOP_CAPTURE:
case LOP_NAMECALL:
case LOP_PREPVARARGS:
case LOP_GETVARARGS:
case LOP_FORGPREP:
break;
default:
LUAU_ASSERT(!"Unknown instruction");
}
i += getOpLength(op);
}
}
}
} // namespace CodeGen
} // namespace Luau
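The per-opcode rules above can be read as small tag-transfer functions; for example, the LOP_ADD/LOP_SUB rule, factored into a stand-alone sketch (assuming the LBC_TYPE_* constants from Luau/Bytecode.h):

```cpp
#include <stdint.h>

// number+number stays number, vector+vector stays vector,
// everything else falls back to the generic 'any' path.
uint8_t addSubResultTag(uint8_t a, uint8_t b)
{
    if (a == LBC_TYPE_NUMBER && b == LBC_TYPE_NUMBER)
        return LBC_TYPE_NUMBER;

    if (a == LBC_TYPE_VECTOR && b == LBC_TYPE_VECTOR)
        return LBC_TYPE_VECTOR;

    return LBC_TYPE_ANY;
}
```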


@ -0,0 +1,71 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BytecodeSummary.h"
#include "CodeGenLower.h"
#include "lua.h"
#include "lapi.h"
#include "lobject.h"
#include "lstate.h"
namespace Luau
{
namespace CodeGen
{
FunctionBytecodeSummary::FunctionBytecodeSummary(std::string source, std::string name, const int line, unsigned nestingLimit)
: source(std::move(source))
, name(std::move(name))
, line(line)
, nestingLimit(nestingLimit)
{
counts.reserve(1 + nestingLimit);
for (unsigned i = 0; i < 1 + nestingLimit; ++i)
{
counts.push_back(std::vector<unsigned>(getOpLimit(), 0));
}
}
FunctionBytecodeSummary FunctionBytecodeSummary::fromProto(Proto* proto, unsigned nestingLimit)
{
const char* source = getstr(proto->source);
source = (source[0] == '=' || source[0] == '@') ? source + 1 : "[string]";
const char* name = proto->debugname ? getstr(proto->debugname) : "";
int line = proto->linedefined;
FunctionBytecodeSummary summary(source, name, line, nestingLimit);
for (int i = 0; i < proto->sizecode; ++i)
{
Instruction insn = proto->code[i];
uint8_t op = LUAU_INSN_OP(insn);
summary.incCount(0, op);
}
return summary;
}
std::vector<FunctionBytecodeSummary> summarizeBytecode(lua_State* L, int idx, unsigned nestingLimit)
{
LUAU_ASSERT(lua_isLfunction(L, idx));
const TValue* func = luaA_toobject(L, idx);
Proto* root = clvalue(func)->l.p;
std::vector<Proto*> protos;
gatherFunctions(protos, root, CodeGen_ColdFunctions);
std::vector<FunctionBytecodeSummary> summaries;
summaries.reserve(protos.size());
for (Proto* proto : protos)
{
summaries.push_back(FunctionBytecodeSummary::fromProto(proto, nestingLimit));
}
return summaries;
}
} // namespace CodeGen
} // namespace Luau


@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/CodeGen.h"
#include "Luau/BytecodeUtils.h"
#include "CodeGenLower.h"
@ -42,6 +43,17 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto)
build.logAppend("\n");
}
unsigned getInstructionCount(const Instruction* insns, const unsigned size)
{
unsigned count = 0;
for (unsigned i = 0; i < size;)
{
++count;
i += Luau::getOpLength(LuauOpcode(LUAU_INSN_OP(insns[i])));
}
return count;
}
template<typename AssemblyBuilder>
static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options, LoweringStats* stats)
{
@ -53,7 +65,11 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
std::vector<Proto*> protos;
gatherFunctions(protos, root, options.flags);
protos.erase(std::remove_if(protos.begin(), protos.end(), [](Proto* p) { return p == nullptr; }), protos.end());
protos.erase(std::remove_if(protos.begin(), protos.end(),
[](Proto* p) {
return p == nullptr;
}),
protos.end());
if (stats)
stats->totalFunctions += unsigned(protos.size());
@ -77,6 +93,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
{
IrBuilder ir;
ir.buildFunctionIr(p);
unsigned asmCount = build.getCodeSize();
if (options.includeAssembly || options.includeIr)
logFunctionHeader(build, p);
@ -86,9 +103,24 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
if (build.logText)
build.logAppend("; skipping (can't lower)\n");
asmCount = 0;
if (stats)
stats->skippedFunctions += 1;
}
else
{
asmCount = build.getCodeSize() - asmCount;
}
if (stats && stats->collectFunctionStats)
{
const char* name = p->debugname ? getstr(p->debugname) : "";
int line = p->linedefined;
unsigned bcodeCount = getInstructionCount(p->code, p->sizecode);
unsigned irCount = unsigned(ir.function.instructions.size());
stats->functions.push_back({name, line, bcodeCount, irCount, asmCount});
}
if (build.logText)
build.logAppend("\n");


@ -26,6 +26,7 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_FASTINT(CodegenHeuristicsBlockLimit)
LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit)
LUAU_FASTFLAG(LuauKeepVmapLinear2)
namespace Luau
{
@ -50,6 +51,13 @@ inline void gatherFunctions(std::vector<Proto*>& results, Proto* proto, unsigned
gatherFunctions(results, proto->p[i], flags);
}
inline unsigned getInstructionCount(const std::vector<IrInst>& instructions, IrCmd cmd)
{
return unsigned(std::count_if(instructions.begin(), instructions.end(), [&cmd](const IrInst& inst) {
return inst.cmd == cmd;
}));
}
template<typename AssemblyBuilder, typename IrLowering>
inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, const std::vector<uint32_t>& sortedBlocks, int bytecodeid,
AssemblyOptions options)
@ -105,8 +113,16 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
toStringDetailed(ctx, block, blockIndex, /* includeUseInfo */ true);
}
// Values can only reference restore operands in the current block
function.validRestoreOpBlockIdx = blockIndex;
if (FFlag::LuauKeepVmapLinear2)
{
// Values can only reference restore operands in the current block chain
function.validRestoreOpBlocks.push_back(blockIndex);
}
else
{
// Values can only reference restore operands in the current block
function.validRestoreOpBlockIdx = blockIndex;
}
build.setLabel(block.label);
@ -132,6 +148,15 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
if (outputEnabled && options.annotator && bcLocation != ~0u)
{
options.annotator(options.annotatorContext, build.text, bytecodeid, bcLocation);
// If available, report inferred register tags
BytecodeTypes bcTypes = function.getBytecodeTypesAt(bcLocation);
if (bcTypes.result != LBC_TYPE_ANY || bcTypes.a != LBC_TYPE_ANY || bcTypes.b != LBC_TYPE_ANY || bcTypes.c != LBC_TYPE_ANY)
{
toString(ctx.result, bcTypes);
build.logAppend("\n");
}
}
// If bytecode needs the location of this instruction for jumps, record it
@ -183,6 +208,9 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
if (options.includeIr)
build.logAppend("#\n");
if (FFlag::LuauKeepVmapLinear2 && block.expectedNextBlock == ~0u)
function.validRestoreOpBlocks.clear();
}
if (!seenFallback)
@ -269,7 +297,25 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers&
constPropInBlockChains(ir, useValueNumbering);
if (!FFlag::DebugCodegenOptSize)
{
double startTime = 0.0;
unsigned constPropInstructionCount = 0;
if (stats)
{
constPropInstructionCount = getInstructionCount(ir.function.instructions, IrCmd::SUBSTITUTE);
startTime = lua_clock();
}
createLinearBlocks(ir, useValueNumbering);
if (stats)
{
stats->blockLinearizationStats.timeSeconds += lua_clock() - startTime;
constPropInstructionCount = getInstructionCount(ir.function.instructions, IrCmd::SUBSTITUTE) - constPropInstructionCount;
stats->blockLinearizationStats.constPropInstructionCount += constPropInstructionCount;
}
}
}
std::vector<uint32_t> sortedBlocks = getSortedBlockOrder(ir.function);
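The before/after counting plus timing above is a general way to attribute a pass's cost and effect; factored as a sketch (runTimedPass is a hypothetical helper, not part of the diff):

```cpp
template<typename Pass>
void runTimedPass(IrBuilder& ir, LoweringStats* stats, Pass&& pass)
{
    unsigned substBefore = 0;
    double startTime = 0.0;

    if (stats)
    {
        substBefore = getInstructionCount(ir.function.instructions, IrCmd::SUBSTITUTE);
        startTime = lua_clock();
    }

    pass(ir);

    if (stats)
    {
        stats->blockLinearizationStats.timeSeconds += lua_clock() - startTime;
        stats->blockLinearizationStats.constPropInstructionCount +=
            getInstructionCount(ir.function.instructions, IrCmd::SUBSTITUTE) - substBefore;
    }
}
```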


@ -426,7 +426,7 @@ const Instruction* executeGETTABLEKS(lua_State* L, const Instruction* pc, StkId
if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0')
{
const float* v = rb->value.v; // silences ubsan when indexing v[]
const float* v = vvalue(rb); // silences ubsan when indexing v[]
setnvalue(ra, v[ic]);
return pc;
}
@ -531,50 +531,6 @@ const Instruction* executeSETTABLEKS(lua_State* L, const Instruction* pc, StkId
}
}
const Instruction* executeNEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k)
{
[[maybe_unused]] Closure* cl = clvalue(L->ci->func);
Instruction insn = *pc++;
StkId ra = VM_REG(LUAU_INSN_A(insn));
Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)];
LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep));
VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM
// note: we save closure to stack early in case the code below wants to capture it by value
Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv);
setclvalue(L, ra, ncl);
for (int ui = 0; ui < pv->nups; ++ui)
{
Instruction uinsn = *pc++;
LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE);
switch (LUAU_INSN_A(uinsn))
{
case LCT_VAL:
setobj(L, &ncl->l.uprefs[ui], VM_REG(LUAU_INSN_B(uinsn)));
break;
case LCT_REF:
setupvalue(L, &ncl->l.uprefs[ui], luaF_findupval(L, VM_REG(LUAU_INSN_B(uinsn))));
break;
case LCT_UPVAL:
setobj(L, &ncl->l.uprefs[ui], VM_UV(LUAU_INSN_B(uinsn)));
break;
default:
LUAU_ASSERT(!"Unknown upvalue capture type");
LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks
}
}
VM_PROTECT(luaC_checkGC(L));
return pc;
}
const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k)
{
[[maybe_unused]] Closure* cl = clvalue(L->ci->func);
@ -587,43 +543,19 @@ const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId ba
if (ttistable(rb))
{
Table* h = hvalue(rb);
// note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works
// for predictive lookups
LuaNode* n = &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)];
// note: lvmexecute.cpp version of NAMECALL has two fast paths, but both fast paths are inlined into IR
// as such, if we get here we can just use the generic path which makes the fallback path a little faster
const TValue* mt = 0;
const LuaNode* mtn = 0;
// fast-path: key is in the table in expected slot
if (ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))
{
// note: order of copies allows rb to alias ra+1 or ra
setobj2s(L, ra + 1, rb);
setobj2s(L, ra, gval(n));
}
// fast-path: key is absent from the base, table has an __index table, and it has the result in the expected slot
else if (gnext(n) == 0 && (mt = fasttm(L, hvalue(rb)->metatable, TM_INDEX)) && ttistable(mt) &&
(mtn = &hvalue(mt)->node[LUAU_INSN_C(insn) & hvalue(mt)->nodemask8]) && ttisstring(gkey(mtn)) && tsvalue(gkey(mtn)) == tsvalue(kv) &&
!ttisnil(gval(mtn)))
{
// note: order of copies allows rb to alias ra+1 or ra
setobj2s(L, ra + 1, rb);
setobj2s(L, ra, gval(mtn));
}
else
{
// slow-path: handles full table lookup
setobj2s(L, ra + 1, rb);
L->cachedslot = LUAU_INSN_C(insn);
VM_PROTECT(luaV_gettable(L, rb, kv, ra));
// save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++
VM_PATCH_C(pc - 2, L->cachedslot);
// recompute ra since stack might have been reallocated
ra = VM_REG(LUAU_INSN_A(insn));
if (ttisnil(ra))
luaG_methoderror(L, ra + 1, tsvalue(kv));
}
// slow-path: handles full table lookup
setobj2s(L, ra + 1, rb);
L->cachedslot = LUAU_INSN_C(insn);
VM_PROTECT(luaV_gettable(L, rb, kv, ra));
// save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++
VM_PATCH_C(pc - 2, L->cachedslot);
// recompute ra since stack might have been reallocated
ra = VM_REG(LUAU_INSN_A(insn));
if (ttisnil(ra))
luaG_methoderror(L, ra + 1, tsvalue(kv));
}
else
{


@ -25,7 +25,6 @@ const Instruction* executeGETGLOBAL(lua_State* L, const Instruction* pc, StkId b
const Instruction* executeSETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* executeGETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* executeSETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* executeNEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* executeSETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k);


@ -148,12 +148,12 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi
build.jcc(ConditionX64::NotZero, label);
}
void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm)
void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, OperandX64 b, OperandX64 c, TMS tm)
{
IrCallWrapperX64 callWrap(regs, build);
callWrap.addArgument(SizeX64::qword, rState);
callWrap.addArgument(SizeX64::qword, luauRegAddress(ra));
callWrap.addArgument(SizeX64::qword, luauRegAddress(rb));
callWrap.addArgument(SizeX64::qword, b);
callWrap.addArgument(SizeX64::qword, c);
callWrap.addArgument(SizeX64::dword, tm);
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]);


@ -200,7 +200,7 @@ ConditionX64 getConditionInt(IrCondition cond);
void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos);
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label);
void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm);
void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, OperandX64 b, OperandX64 c, TMS tm);
void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb);
void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra);
void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra);


@ -8,6 +8,7 @@
#include "lobject.h"
#include <algorithm>
#include <bitset>
#include <stddef.h>


@ -2,6 +2,7 @@
#include "Luau/IrBuilder.h"
#include "Luau/Bytecode.h"
#include "Luau/BytecodeAnalysis.h"
#include "Luau/BytecodeUtils.h"
#include "Luau/IrData.h"
#include "Luau/IrUtils.h"
@ -12,6 +13,8 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauCodegenBytecodeInfer, false)
namespace Luau
{
namespace CodeGen
@ -119,6 +122,10 @@ void IrBuilder::buildFunctionIr(Proto* proto)
// Rebuild original control flow blocks
rebuildBytecodeBasicBlocks(proto);
// Infer register tags in bytecode
if (FFlag::LuauCodegenBytecodeInfer)
analyzeBytecodeTypes(function);
function.bcMapping.resize(proto->sizecode, {~0u, ~0u});
if (generateTypeChecks)
@ -152,7 +159,7 @@ void IrBuilder::buildFunctionIr(Proto* proto)
// Numeric for loops require additional processing to maintain loop stack
// Notably, this must be performed even when the block is dead so that we maintain the pairing FORNPREP-FORNLOOP
if (op == LOP_FORNPREP)
beforeInstForNPrep(*this, pc);
beforeInstForNPrep(*this, pc, i);
// We skip dead bytecode instructions when they appear after block was already terminated
if (!inTerminatedBlock)
@ -212,7 +219,6 @@ void IrBuilder::rebuildBytecodeBasicBlocks(Proto* proto)
LUAU_ASSERT(i <= proto->sizecode);
}
// Bytecode blocks are created at bytecode jump targets and the start of a function
jumpTargets[0] = true;
@ -224,6 +230,9 @@ void IrBuilder::rebuildBytecodeBasicBlocks(Proto* proto)
instIndexToBlock[i] = b.index;
}
}
if (FFlag::LuauCodegenBytecodeInfer)
buildBytecodeBlocks(function, jumpTargets);
}
void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
@ -381,6 +390,12 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_POWK:
translateInstBinaryK(*this, pc, i, TM_POW);
break;
case LOP_SUBRK:
translateInstBinaryRK(*this, pc, i, TM_SUB);
break;
case LOP_DIVRK:
translateInstBinaryRK(*this, pc, i, TM_DIV);
break;
case LOP_NOT:
translateInstNot(*this, pc);
break;


@ -235,6 +235,8 @@ const char* getCmdName(IrCmd cmd)
return "CHECK_NODE_NO_NEXT";
case IrCmd::CHECK_NODE_VALUE:
return "CHECK_NODE_VALUE";
case IrCmd::CHECK_BUFFER_LEN:
return "CHECK_BUFFER_LEN";
case IrCmd::INTERRUPT:
return "INTERRUPT";
case IrCmd::CHECK_GC:
@ -309,6 +311,8 @@ const char* getCmdName(IrCmd cmd)
return "BITCOUNTLZ_UINT";
case IrCmd::BITCOUNTRZ_UINT:
return "BITCOUNTRZ_UINT";
case IrCmd::BYTESWAP_UINT:
return "BYTESWAP_UINT";
case IrCmd::INVOKE_LIBM:
return "INVOKE_LIBM";
case IrCmd::GET_TYPE:
@ -317,6 +321,30 @@ const char* getCmdName(IrCmd cmd)
return "GET_TYPEOF";
case IrCmd::FINDUPVAL:
return "FINDUPVAL";
case IrCmd::BUFFER_READI8:
return "BUFFER_READI8";
case IrCmd::BUFFER_READU8:
return "BUFFER_READU8";
case IrCmd::BUFFER_WRITEI8:
return "BUFFER_WRITEI8";
case IrCmd::BUFFER_READI16:
return "BUFFER_READI16";
case IrCmd::BUFFER_READU16:
return "BUFFER_READU16";
case IrCmd::BUFFER_WRITEI16:
return "BUFFER_WRITEI16";
case IrCmd::BUFFER_READI32:
return "BUFFER_READI32";
case IrCmd::BUFFER_WRITEI32:
return "BUFFER_WRITEI32";
case IrCmd::BUFFER_READF32:
return "BUFFER_READF32";
case IrCmd::BUFFER_WRITEF32:
return "BUFFER_WRITEF32";
case IrCmd::BUFFER_READF64:
return "BUFFER_READF64";
case IrCmd::BUFFER_WRITEF64:
return "BUFFER_WRITEF64";
}
LUAU_UNREACHABLE();
@ -404,7 +432,10 @@ void toString(IrToStringContext& ctx, IrOp op)
append(ctx.result, "U%d", vmUpvalueOp(op));
break;
case IrOpKind::VmExit:
append(ctx.result, "exit(%d)", vmExitOp(op));
if (vmExitOp(op) == kVmExitEntryGuardPc)
append(ctx.result, "exit(entry)");
else
append(ctx.result, "exit(%d)", vmExitOp(op));
break;
}
}
@ -431,6 +462,47 @@ void toString(std::string& result, IrConst constant)
}
}
const char* getBytecodeTypeName(uint8_t type)
{
switch (type)
{
case LBC_TYPE_NIL:
return "nil";
case LBC_TYPE_BOOLEAN:
return "boolean";
case LBC_TYPE_NUMBER:
return "number";
case LBC_TYPE_STRING:
return "string";
case LBC_TYPE_TABLE:
return "table";
case LBC_TYPE_FUNCTION:
return "function";
case LBC_TYPE_THREAD:
return "thread";
case LBC_TYPE_USERDATA:
return "userdata";
case LBC_TYPE_VECTOR:
return "vector";
case LBC_TYPE_BUFFER:
return "buffer";
case LBC_TYPE_ANY:
return "any";
}
LUAU_ASSERT(!"Unhandled type in getBytecodeTypeName");
return nullptr;
}
void toString(std::string& result, const BytecodeTypes& bcTypes)
{
if (bcTypes.c != LBC_TYPE_ANY)
append(result, "%s <- %s, %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b),
getBytecodeTypeName(bcTypes.c));
else
append(result, "%s <- %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b));
}
static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks)
{
bool comma = false;


@ -135,13 +135,9 @@ static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64
if (ratag == -1 || !isGCO(ratag))
{
if (ra.kind == IrOpKind::VmReg)
{
addr = mem(rBase, vmRegOp(ra) * sizeof(TValue) + offsetof(TValue, tt));
}
else if (ra.kind == IrOpKind::VmConst)
{
emitAddOffset(build, temp, rConstants, vmConstOp(ra) * sizeof(TValue) + offsetof(TValue, tt));
}
build.ldr(tempw, addr);
build.cmp(tempw, LUA_TSTRING);
@ -154,13 +150,10 @@ static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64
// iswhite(gcvalue(ra))
if (ra.kind == IrOpKind::VmReg)
{
addr = mem(rBase, vmRegOp(ra) * sizeof(TValue) + offsetof(TValue, value));
}
else if (ra.kind == IrOpKind::VmConst)
{
emitAddOffset(build, temp, rConstants, vmConstOp(ra) * sizeof(TValue) + offsetof(TValue, value));
}
build.ldr(temp, addr);
build.ldrb(tempw, mem(temp, offsetof(GCheader, marked)));
build.tst(tempw, bit2mask(WHITE0BIT, WHITE1BIT));
@ -240,6 +233,14 @@ static bool emitBuiltin(
}
}
static uint64_t getDoubleBits(double value)
{
uint64_t result;
static_assert(sizeof(result) == sizeof(value), "Expecting double to be 64-bit");
memcpy(&result, &value, sizeof(value));
return result;
}
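The bit comparison this helper enables matters in STORE_DOUBLE below: -0.0 compares equal to 0.0 yet has the sign bit set, so a `== 0.0` check would wrongly allow storing xzr for it. A sketch, assuming access to the static helper above:

```cpp
#include <assert.h>
#include <stdint.h>

void zeroBitsExample()
{
    assert(getDoubleBits(0.0) == 0);             // all-zero bits: xzr store is safe
    assert(getDoubleBits(-0.0) == (1ull << 63)); // sign bit set: constant must be materialized
    assert(0.0 == -0.0);                         // why a floating-point compare is not enough
}
```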
IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats)
: build(build)
, helpers(helpers)
@ -309,7 +310,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
if (inst.b.kind == IrOpKind::Inst)
{
build.add(inst.regA64, inst.regA64, regOp(inst.b), kTValueSizeLog2);
build.add(inst.regA64, inst.regA64, regOp(inst.b), kTValueSizeLog2); // implicit uxtw
}
else if (inst.b.kind == IrOpKind::Constant)
{
@ -404,14 +405,29 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::STORE_POINTER:
{
AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value));
build.str(regOp(inst.b), addr);
if (inst.b.kind == IrOpKind::Constant)
{
LUAU_ASSERT(intOp(inst.b) == 0);
build.str(xzr, addr);
}
else
{
build.str(regOp(inst.b), addr);
}
break;
}
case IrCmd::STORE_DOUBLE:
{
RegisterA64 temp = tempDouble(inst.b);
AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value));
build.str(temp, addr);
if (inst.b.kind == IrOpKind::Constant && getDoubleBits(doubleOp(inst.b)) == 0)
{
build.str(xzr, addr);
}
else
{
RegisterA64 temp = tempDouble(inst.b);
build.str(temp, addr);
}
break;
}
case IrCmd::STORE_INT:
@ -816,11 +832,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
RegisterA64 index = tempDouble(inst.a);
RegisterA64 limit = tempDouble(inst.b);
RegisterA64 step = tempDouble(inst.c);
Label direct;
// step > 0
build.fcmpz(tempDouble(inst.c));
build.fcmpz(step);
build.b(getConditionFP(IrCondition::Greater), direct);
// !(limit <= index)
@ -974,6 +991,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
inst.regA64 = regs.allocReg(KindA64::w, index);
RegisterA64 temp = tempDouble(inst.a);
// note: we don't use fcvtzu for consistency with C++ code
build.fcvtzs(castReg(KindA64::x, inst.regA64), temp);
break;
}
@ -989,7 +1007,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
else if (inst.b.kind == IrOpKind::Inst)
{
build.add(temp, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue)));
build.add(temp, temp, regOp(inst.b), kTValueSizeLog2);
build.add(temp, temp, regOp(inst.b), kTValueSizeLog2); // implicit uxtw
build.str(temp, mem(rState, offsetof(lua_State, top)));
}
else
@ -1049,7 +1067,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
regs.spill(build, index);
build.mov(x0, rState);
build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue)));
build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));
if (inst.b.kind == IrOpKind::VmConst)
emitAddOffset(build, x2, rConstants, vmConstOp(inst.b) * sizeof(TValue));
else
build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));
if (inst.c.kind == IrOpKind::VmConst)
emitAddOffset(build, x3, rConstants, vmConstOp(inst.c) * sizeof(TValue));
@ -1372,6 +1394,63 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
finalizeTargetLabel(inst.b, fresh);
break;
}
case IrCmd::CHECK_BUFFER_LEN:
{
int accessSize = intOp(inst.c);
LUAU_ASSERT(accessSize > 0 && accessSize <= int(AssemblyBuilderA64::kMaxImmediate));
Label fresh; // used when guard aborts execution or jumps to a VM exit
Label& target = getTargetLabel(inst.d, fresh);
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.ldr(temp, mem(regOp(inst.a), offsetof(Buffer, len)));
if (inst.b.kind == IrOpKind::Inst)
{
if (accessSize == 1)
{
// fails if offset >= len
build.cmp(temp, regOp(inst.b));
build.b(ConditionA64::UnsignedLessEqual, target);
}
else
{
// fails if offset + size >= len; we compute it as len - offset <= size
RegisterA64 tempx = castReg(KindA64::x, temp);
build.sub(tempx, tempx, regOp(inst.b)); // implicit uxtw
build.cmp(tempx, uint16_t(accessSize));
build.b(ConditionA64::LessEqual, target); // note: this is a signed 64-bit comparison so that out of bounds offset fails
}
}
else if (inst.b.kind == IrOpKind::Constant)
{
int offset = intOp(inst.b);
// Constant folding can take care of it, but for safety we avoid overflow/underflow cases here
if (offset < 0 || unsigned(offset) + unsigned(accessSize) >= unsigned(INT_MAX))
{
build.b(target);
}
else if (offset + accessSize <= int(AssemblyBuilderA64::kMaxImmediate))
{
build.cmp(temp, uint16_t(offset + accessSize));
build.b(ConditionA64::UnsignedLessEqual, target);
}
else
{
RegisterA64 temp2 = regs.allocTemp(KindA64::w);
build.mov(temp2, offset + accessSize);
build.cmp(temp, temp2);
build.b(ConditionA64::UnsignedLess, target);
}
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
}
finalizeTargetLabel(inst.d, fresh);
break;
}
case IrCmd::INTERRUPT:
{
regs.spill(build, index);
@ -1912,6 +1991,13 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.clz(inst.regA64, inst.regA64);
break;
}
case IrCmd::BYTESWAP_UINT:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a});
RegisterA64 temp = tempUint(inst.a);
build.rev(inst.regA64, temp);
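// rev reverses the byte order of the 32-bit register, e.g. 0x11223344 -> 0x44332211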
break;
}
case IrCmd::INVOKE_LIBM:
{
if (inst.c.kind != IrOpKind::None)
@ -1960,7 +2046,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
LUAU_ASSERT(sizeof(TString*) == 8);
if (inst.a.kind == IrOpKind::Inst)
build.add(inst.regA64, rGlobalState, regOp(inst.a), 3);
build.add(inst.regA64, rGlobalState, regOp(inst.a), 3); // implicit uxtw
else if (inst.a.kind == IrOpKind::Constant)
build.add(inst.regA64, rGlobalState, uint16_t(tagOp(inst.a)) * 8);
else
@ -1993,6 +2079,118 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
break;
}
case IrCmd::BUFFER_READI8:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.ldrsb(inst.regA64, addr);
break;
}
case IrCmd::BUFFER_READU8:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.ldrb(inst.regA64, addr);
break;
}
case IrCmd::BUFFER_WRITEI8:
{
RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.strb(temp, addr);
break;
}
case IrCmd::BUFFER_READI16:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.ldrsh(inst.regA64, addr);
break;
}
case IrCmd::BUFFER_READU16:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.ldrh(inst.regA64, addr);
break;
}
case IrCmd::BUFFER_WRITEI16:
{
RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.strh(temp, addr);
break;
}
case IrCmd::BUFFER_READI32:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.ldr(inst.regA64, addr);
break;
}
case IrCmd::BUFFER_WRITEI32:
{
RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.str(temp, addr);
break;
}
case IrCmd::BUFFER_READF32:
{
inst.regA64 = regs.allocReg(KindA64::d, index);
RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.ldr(temp, addr);
build.fcvt(inst.regA64, temp);
break;
}
case IrCmd::BUFFER_WRITEF32:
{
RegisterA64 temp1 = tempDouble(inst.c);
RegisterA64 temp2 = regs.allocTemp(KindA64::s);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.fcvt(temp2, temp1);
build.str(temp2, addr);
break;
}
case IrCmd::BUFFER_READF64:
{
inst.regA64 = regs.allocReg(KindA64::d, index);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.ldr(inst.regA64, addr);
break;
}
case IrCmd::BUFFER_WRITEF64:
{
RegisterA64 temp = tempDouble(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
build.str(temp, addr);
break;
}
// To handle unsupported instructions, add "case IrCmd::OP" and make sure to set error = true!
}
@ -2119,9 +2317,7 @@ RegisterA64 IrLoweringA64::tempDouble(IrOp op)
RegisterA64 temp1 = regs.allocTemp(KindA64::x);
RegisterA64 temp2 = regs.allocTemp(KindA64::d);
uint64_t vali = getDoubleBits(val);
if ((vali << 16) == 0)
{
@ -2217,6 +2413,35 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset)
}
}
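// Forms a buffer element address: dynamic indices use an explicit add, while small constant indices fit the
// addressing mode directly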
AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp)
{
if (indexOp.kind == IrOpKind::Inst)
{
RegisterA64 temp = regs.allocTemp(KindA64::x);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
return mem(temp, offsetof(Buffer, data));
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited to the unscaled encoding
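// (unscaled loads/stores encode a signed 9-bit immediate, hence the 255 upper bound once offsetof(Buffer, data) is added)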
if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data)));
// indexOp can only be negative in dead code (since offsets are checked); this avoids an assertion in emitAddOffset
if (intOp(indexOp) < 0)
return mem(regOp(bufferOp), offsetof(Buffer, data));
RegisterA64 temp = regs.allocTemp(KindA64::x);
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
return mem(temp, offsetof(Buffer, data));
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
return noreg;
}
}
RegisterA64 IrLoweringA64::regOp(IrOp op)
{
IrInst& inst = function.instOp(op);

View file

@ -44,6 +44,7 @@ struct IrLoweringA64
RegisterA64 tempInt(IrOp op);
RegisterA64 tempUint(IrOp op);
AddressA64 tempAddr(IrOp op, int offset);
AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp);
// May emit restore instructions
RegisterA64 regOp(IrOp op);

View file

@ -15,6 +15,8 @@
#include "lstate.h"
#include "lgc.h"
LUAU_FASTFLAG(LuauCodeGenFixByteLower)
namespace Luau
{
namespace CodeGen
@ -213,28 +215,45 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
LUAU_ASSERT(!"Unsupported instruction form");
break;
case IrCmd::STORE_POINTER:
{
OperandX64 valueLhs = inst.a.kind == IrOpKind::Inst ? qword[regOp(inst.a) + offsetof(TValue, value)] : luauRegValue(vmRegOp(inst.a));
if (inst.b.kind == IrOpKind::Constant)
{
LUAU_ASSERT(intOp(inst.b) == 0);
build.mov(valueLhs, 0);
}
else if (inst.b.kind == IrOpKind::Inst)
{
build.mov(valueLhs, regOp(inst.b));
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
}
break;
}
case IrCmd::STORE_DOUBLE:
{
OperandX64 valueLhs = inst.a.kind == IrOpKind::Inst ? qword[regOp(inst.a) + offsetof(TValue, value)] : luauRegValue(vmRegOp(inst.a));
if (inst.b.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, build.f64(doubleOp(inst.b)));
build.vmovsd(valueLhs, tmp.reg);
}
else if (inst.b.kind == IrOpKind::Inst)
{
build.vmovsd(valueLhs, regOp(inst.b));
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
}
break;
}
case IrCmd::STORE_INT:
if (inst.b.kind == IrOpKind::Constant)
build.mov(luauRegValueInt(vmRegOp(inst.a)), intOp(inst.b));
@ -822,7 +841,19 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::UINT_TO_NUM:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
// AVX has no uint->double conversion; the source must come from a UINT op, and those all clear the top 32 bits, so we can usually
// use a 64-bit reg; the one exception is NUM_TO_UINT, which doesn't clear the top bits
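// e.g. 0x80000000 read as a signed 32-bit value would convert to -2147483648.0; converting the zero-extended
// 64-bit register yields the intended 2147483648.0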
if (IrCmd source = function.instOp(inst.a).cmd; source == IrCmd::NUM_TO_UINT)
{
ScopedRegX64 tmp{regs, SizeX64::dword};
build.mov(tmp.reg, regOp(inst.a));
build.vcvtsi2sd(inst.regX64, inst.regX64, qwordReg(tmp.reg));
}
else
{
LUAU_ASSERT(source != IrCmd::SUBSTITUTE); // we don't process substitutions
build.vcvtsi2sd(inst.regX64, inst.regX64, qwordReg(regOp(inst.a)));
}
break;
case IrCmd::NUM_TO_INT:
inst.regX64 = regs.allocReg(SizeX64::dword, index);
@ -931,11 +962,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
break;
}
case IrCmd::DO_ARITH:
{
OperandX64 opb = inst.b.kind == IrOpKind::VmReg ? luauRegAddress(vmRegOp(inst.b)) : luauConstantAddress(vmConstOp(inst.b));
OperandX64 opc = inst.c.kind == IrOpKind::VmReg ? luauRegAddress(vmRegOp(inst.c)) : luauConstantAddress(vmConstOp(inst.c));
callArithHelper(regs, build, vmRegOp(inst.a), opb, opc, TMS(intOp(inst.d)));
break;
}
case IrCmd::DO_LEN:
callLengthHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b));
break;
@ -1157,6 +1189,64 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
jumpOrAbortOnUndef(ConditionX64::Equal, inst.b, next);
break;
}
case IrCmd::CHECK_BUFFER_LEN:
{
int accessSize = intOp(inst.c);
LUAU_ASSERT(accessSize > 0);
if (inst.b.kind == IrOpKind::Inst)
{
if (accessSize == 1)
{
// Simpler check for a single byte access
build.cmp(dword[regOp(inst.a) + offsetof(Buffer, len)], regOp(inst.b));
jumpOrAbortOnUndef(ConditionX64::BelowEqual, inst.d, next);
}
else
{
ScopedRegX64 tmp1{regs, SizeX64::qword};
ScopedRegX64 tmp2{regs, SizeX64::dword};
// To perform the bounds check using a single branch, we take an index that is limited to a 32-bit int
// The access size is then added using a 64-bit addition
// This will make sure that addition will not wrap around for values like 0xffffffff
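// e.g. an index of 0xffffffff plus an access size of 4 becomes 0x1'0000'0003 in 64 bits instead of wrapping
// to 3, so the unsigned compare against len always fails as required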
if (IrCmd source = function.instOp(inst.b).cmd; source == IrCmd::NUM_TO_INT)
{
// When the previous operation is a conversion to an integer (the common case), it is guaranteed to have the high register bits cleared
build.lea(tmp1.reg, addr[qwordReg(regOp(inst.b)) + accessSize]);
}
else
{
// When the source of the index is unknown, it could contain garbage in the high bits, so we zero-extend it explicitly
build.mov(dwordReg(tmp1.reg), regOp(inst.b));
build.add(tmp1.reg, accessSize);
}
build.mov(tmp2.reg, dword[regOp(inst.a) + offsetof(Buffer, len)]);
build.cmp(qwordReg(tmp2.reg), tmp1.reg);
jumpOrAbortOnUndef(ConditionX64::Below, inst.d, next);
}
}
else if (inst.b.kind == IrOpKind::Constant)
{
int offset = intOp(inst.b);
// Constant folding can take care of it, but for safety we avoid overflow/underflow cases here
if (offset < 0 || unsigned(offset) + unsigned(accessSize) >= unsigned(INT_MAX))
{
jumpOrAbortOnUndef(inst.d, next);
}
else
{
build.cmp(dword[regOp(inst.a) + offsetof(Buffer, len)], offset + accessSize);
jumpOrAbortOnUndef(ConditionX64::Below, inst.d, next);
}
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
}
break;
}
case IrCmd::INTERRUPT:
{
unsigned pcpos = uintOp(inst.a);
@ -1633,6 +1723,16 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.setLabel(exit);
break;
}
case IrCmd::BYTESWAP_UINT:
{
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a});
if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a))
build.mov(inst.regX64, memRegUintOp(inst.a));
build.bswap(inst.regX64);
break;
}
case IrCmd::INVOKE_LIBM:
{
IrCallWrapperX64 callWrap(regs, build, index);
@ -1689,6 +1789,111 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
break;
}
case IrCmd::BUFFER_READI8:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
break;
case IrCmd::BUFFER_READU8:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
break;
case IrCmd::BUFFER_WRITEI8:
{
if (FFlag::LuauCodeGenFixByteLower)
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c)));
build.mov(byte[bufferAddrOp(inst.a, inst.b)], value);
}
else
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(intOp(inst.c));
build.mov(byte[bufferAddrOp(inst.a, inst.b)], value);
}
break;
}
case IrCmd::BUFFER_READI16:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]);
break;
case IrCmd::BUFFER_READU16:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]);
break;
case IrCmd::BUFFER_WRITEI16:
{
if (FFlag::LuauCodeGenFixByteLower)
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c)));
build.mov(word[bufferAddrOp(inst.a, inst.b)], value);
}
else
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(intOp(inst.c));
build.mov(word[bufferAddrOp(inst.a, inst.b)], value);
}
break;
}
case IrCmd::BUFFER_READI32:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]);
break;
case IrCmd::BUFFER_WRITEI32:
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c));
build.mov(dword[bufferAddrOp(inst.a, inst.b)], value);
break;
}
case IrCmd::BUFFER_READF32:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]);
break;
case IrCmd::BUFFER_WRITEF32:
storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b)], inst.c);
break;
case IrCmd::BUFFER_READF64:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b)]);
break;
case IrCmd::BUFFER_WRITEF64:
if (inst.c.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c)));
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], tmp.reg);
}
else if (inst.c.kind == IrOpKind::Inst)
{
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], regOp(inst.c));
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
}
break;
// Pseudo instructions
case IrCmd::NOP:
case IrCmd::SUBSTITUTE:
@ -1707,7 +1912,7 @@ void IrLoweringX64::finishBlock(const IrBlock& curr, const IrBlock& next)
{
// If we have spills remaining, we have to immediately lower the successor block
for (uint32_t predIdx : predecessors(function.cfg, function.getBlockIndex(next)))
LUAU_ASSERT(predIdx == function.getBlockIndex(curr) || function.blocks[predIdx].kind == IrBlockKind::Dead);
// And the next block cannot be a join block in cfg
LUAU_ASSERT(next.useCount == 1);
@ -1900,6 +2105,17 @@ RegisterX64 IrLoweringX64::regOp(IrOp op)
return inst.regX64;
}
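// Folds the buffer pointer, element index and offsetof(Buffer, data) into a single [base + index + disp] operand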
OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp)
{
if (indexOp.kind == IrOpKind::Inst)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data);
else if (indexOp.kind == IrOpKind::Constant)
return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data);
LUAU_ASSERT(!"Unsupported instruction form");
return noreg;
}
IrConst IrLoweringX64::constOp(IrOp op) const
{
return function.constOp(op);

View file

@ -50,6 +50,7 @@ struct IrLoweringX64
OperandX64 memRegUintOp(IrOp op);
OperandX64 memRegTagOp(IrOp op);
RegisterX64 regOp(IrOp op);
OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp);
IrConst constOp(IrOp op) const;
uint8_t tagOp(IrOp op) const;

View file

@ -8,6 +8,8 @@
#include <math.h>
LUAU_FASTFLAGVARIABLE(LuauBufferTranslateIr, false)
// TODO: when nresults is less than our actual result count, we can skip computing/writing unused results
static const int kMinMaxUnrolledParams = 5;
@ -150,13 +152,12 @@ static BuiltinImplResult translateBuiltinMathDegRad(IrBuilder& build, IrCmd cmd,
return {BuiltinImplType::Full, 1};
}
static BuiltinImplResult translateBuiltinMathLog(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos)
{
if (nparams < 1 || nresults > 1)
return {BuiltinImplType::None, -1};
int libmId = LBF_MATH_LOG;
std::optional<double> denom;
if (nparams != 1)
@ -298,7 +299,7 @@ static BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, i
}
static BuiltinImplResult translateBuiltinBit32BinaryOp(
IrBuilder& build, IrCmd cmd, bool btest, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos)
{
if (nparams < 2 || nparams > kBit32BinaryOpUnrolledParams || nresults > 1)
return {BuiltinImplType::None, -1};
@ -315,17 +316,6 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp(
IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va);
IrOp vbui = build.inst(IrCmd::NUM_TO_UINT, vb);
IrOp res = build.inst(cmd, vaui, vbui);
for (int i = 3; i <= nparams; ++i)
@ -336,7 +326,7 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp(
res = build.inst(cmd, res, arg);
}
if (btest)
{
IrOp falsey = build.block(IrBlockKind::Internal);
IrOp truthy = build.block(IrBlockKind::Internal);
@ -351,7 +341,6 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp(
build.inst(IrCmd::STORE_INT, build.vmReg(ra), build.constInt(1));
build.inst(IrCmd::JUMP, exit);
build.beginBlock(exit);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TBOOLEAN));
}
@ -367,8 +356,7 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp(
return {BuiltinImplType::Full, 1};
}
static BuiltinImplResult translateBuiltinBit32Bnot(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos)
{
if (nparams < 1 || nresults > 1)
return {BuiltinImplType::None, -1};
@ -389,7 +377,7 @@ static BuiltinImplResult translateBuiltinBit32Bnot(
}
static BuiltinImplResult translateBuiltinBit32Shift(
IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos)
{
if (nparams < 2 || nresults > 1)
return {BuiltinImplType::None, -1};
@ -418,16 +406,6 @@ static BuiltinImplResult translateBuiltinBit32Shift(
build.beginBlock(block);
}
IrOp shift = build.inst(cmd, vaui, vbi);
IrOp value = build.inst(IrCmd::UINT_TO_NUM, shift);
@ -439,8 +417,7 @@ static BuiltinImplResult translateBuiltinBit32Shift(
return {BuiltinImplType::UsesFallback, 1};
}
static BuiltinImplResult translateBuiltinBit32Rotate(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos)
{
if (nparams < 2 || nresults > 1)
return {BuiltinImplType::None, -1};
@ -454,7 +431,6 @@ static BuiltinImplResult translateBuiltinBit32Rotate(
IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va);
IrOp vbi = build.inst(IrCmd::NUM_TO_INT, vb);
IrOp shift = build.inst(cmd, vaui, vbi);
IrOp value = build.inst(IrCmd::UINT_TO_NUM, shift);
@ -467,7 +443,7 @@ static BuiltinImplResult translateBuiltinBit32Rotate(
}
static BuiltinImplResult translateBuiltinBit32Extract(
IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos)
{
if (nparams < 2 || nresults > 1)
return {BuiltinImplType::None, -1};
@ -547,8 +523,7 @@ static BuiltinImplResult translateBuiltinBit32Extract(
return {BuiltinImplType::UsesFallback, 1};
}
static BuiltinImplResult translateBuiltinBit32ExtractK(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos)
{
if (nparams < 2 || nresults > 1)
return {BuiltinImplType::None, -1};
@ -583,8 +558,7 @@ static BuiltinImplResult translateBuiltinBit32ExtractK(
return {BuiltinImplType::Full, 1};
}
static BuiltinImplResult translateBuiltinBit32Unary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos)
{
if (nparams < 1 || nresults > 1)
return {BuiltinImplType::None, -1};
@ -594,7 +568,6 @@ static BuiltinImplResult translateBuiltinBit32Countz(
IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va);
IrOp bin = build.inst(cmd, vaui);
IrOp value = build.inst(IrCmd::UINT_TO_NUM, bin);
@ -608,7 +581,7 @@ static BuiltinImplResult translateBuiltinBit32Countz(
}
static BuiltinImplResult translateBuiltinBit32Replace(
IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos)
{
if (nparams < 3 || nresults > 1)
return {BuiltinImplType::None, -1};
@ -632,7 +605,6 @@ static BuiltinImplResult translateBuiltinBit32Replace(
build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block);
build.beginBlock(block);
// TODO: this can be optimized using a bit-select instruction (btr on x86)
IrOp m = build.constInt(1);
IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, m, f);
IrOp not_ = build.inst(IrCmd::BITNOT_UINT, shift);
@ -718,10 +690,25 @@ static BuiltinImplResult translateBuiltinTableInsert(IrBuilder& build, int npara
IrOp setnum = build.inst(IrCmd::TABLE_SETNUM, table, pos);
if (args.kind == IrOpKind::Constant)
{
LUAU_ASSERT(build.function.constOp(args).kind == IrConstKind::Double);
// No barrier necessary since numbers aren't collectable
build.inst(IrCmd::STORE_DOUBLE, setnum, args);
build.inst(IrCmd::STORE_TAG, setnum, build.constTag(LUA_TNUMBER));
}
else
{
IrOp va = build.inst(IrCmd::LOAD_TVALUE, args);
build.inst(IrCmd::STORE_TVALUE, setnum, va);
// Compiler only generates FASTCALL*K for source-level constants, so dynamic imports are not affected
LUAU_ASSERT(build.function.proto);
IrOp argstag = args.kind == IrOpKind::VmConst ? build.constTag(build.function.proto->k[vmConstOp(args)].tt) : build.undef();
build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, args, argstag);
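// passing the constant's tag allows lowering to skip the barrier when the stored value is known not to be collectable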
}
return {BuiltinImplType::Full, 0};
}
@ -743,6 +730,59 @@ static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams
return {BuiltinImplType::Full, 1};
}
static void translateBufferArgsAndCheckBounds(IrBuilder& build, int nparams, int arg, IrOp args, int size, int pcpos, IrOp& buf, IrOp& intIndex)
{
build.loadAndCheckTag(build.vmReg(arg), LUA_TBUFFER, build.vmExit(pcpos));
builtinCheckDouble(build, args, pcpos);
if (nparams == 3)
builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), pcpos);
buf = build.inst(IrCmd::LOAD_POINTER, build.vmReg(arg));
IrOp numIndex = builtinLoadDouble(build, args);
intIndex = build.inst(IrCmd::NUM_TO_INT, numIndex);
build.inst(IrCmd::CHECK_BUFFER_LEN, buf, intIndex, build.constInt(size), build.vmExit(pcpos));
}
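// Illustrative IR shape for buffer.readi32(b, i) (virtual register numbers elided):
//   CHECK_TAG R(arg), tbuffer, vmExit(pcpos)
//   %buf = LOAD_POINTER R(arg)
//   %idx = NUM_TO_INT %index_double
//   CHECK_BUFFER_LEN %buf, %idx, 4, vmExit(pcpos)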
static BuiltinImplResult translateBuiltinBufferRead(
IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos, IrCmd readCmd, int size, IrCmd convCmd)
{
if (!FFlag::LuauBufferTranslateIr)
return {BuiltinImplType::None, -1};
if (nparams < 2 || nresults > 1)
return {BuiltinImplType::None, -1};
IrOp buf, intIndex;
translateBufferArgsAndCheckBounds(build, nparams, arg, args, size, pcpos, buf, intIndex);
IrOp result = build.inst(readCmd, buf, intIndex);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), convCmd == IrCmd::NOP ? result : build.inst(convCmd, result));
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
return {BuiltinImplType::Full, 1};
}
static BuiltinImplResult translateBuiltinBufferWrite(
IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos, IrCmd writeCmd, int size, IrCmd convCmd)
{
if (!FFlag::LuauBufferTranslateIr)
return {BuiltinImplType::None, -1};
if (nparams < 3 || nresults > 0)
return {BuiltinImplType::None, -1};
IrOp buf, intIndex;
translateBufferArgsAndCheckBounds(build, nparams, arg, args, size, pcpos, buf, intIndex);
IrOp numValue = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1));
build.inst(writeCmd, buf, intIndex, convCmd == IrCmd::NOP ? numValue : build.inst(convCmd, numValue));
return {BuiltinImplType::Full, 0};
}
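// Note that unsigned variants reuse the signed store ops, e.g. buffer.writeu32 lowers to NUM_TO_UINT
// followed by BUFFER_WRITEI32, since i32 and u32 stores write the same bits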
BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback, int pcpos)
{
// Builtins are not allowed to handle variadic arguments
@ -758,7 +798,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg,
case LBF_MATH_RAD:
return translateBuiltinMathDegRad(build, IrCmd::MUL_NUM, nparams, ra, arg, args, nresults, pcpos);
case LBF_MATH_LOG:
return translateBuiltinMathLog(build, nparams, ra, arg, args, nresults, pcpos);
case LBF_MATH_MIN:
return translateBuiltinMathMinMax(build, IrCmd::MIN_NUM, nparams, ra, arg, args, nresults, pcpos);
case LBF_MATH_MAX:
@ -798,28 +838,35 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg,
case LBF_MATH_MODF:
return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_BAND:
return translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ false, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_BOR:
return translateBuiltinBit32BinaryOp(build, IrCmd::BITOR_UINT, /* btest= */ false, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_BXOR:
return translateBuiltinBit32BinaryOp(build, IrCmd::BITXOR_UINT, /* btest= */ false, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_BTEST:
return translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ true, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_BNOT:
return translateBuiltinBit32Bnot(build, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_LSHIFT:
return translateBuiltinBit32Shift(build, IrCmd::BITLSHIFT_UINT, nparams, ra, arg, args, nresults, fallback, pcpos);
case LBF_BIT32_RSHIFT:
return translateBuiltinBit32Shift(build, IrCmd::BITRSHIFT_UINT, nparams, ra, arg, args, nresults, fallback, pcpos);
case LBF_BIT32_ARSHIFT:
return translateBuiltinBit32Shift(build, IrCmd::BITARSHIFT_UINT, nparams, ra, arg, args, nresults, fallback, pcpos);
case LBF_BIT32_LROTATE:
return translateBuiltinBit32Rotate(build, IrCmd::BITLROTATE_UINT, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_RROTATE:
return translateBuiltinBit32Rotate(build, IrCmd::BITRROTATE_UINT, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_EXTRACT:
return translateBuiltinBit32Extract(build, nparams, ra, arg, args, nresults, fallback, pcpos);
case LBF_BIT32_EXTRACTK:
return translateBuiltinBit32ExtractK(build, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_COUNTLZ:
return translateBuiltinBit32Unary(build, IrCmd::BITCOUNTLZ_UINT, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_COUNTRZ:
return translateBuiltinBit32Unary(build, IrCmd::BITCOUNTRZ_UINT, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_REPLACE:
return translateBuiltinBit32Replace(build, nparams, ra, arg, args, nresults, fallback, pcpos);
case LBF_TYPE:
return translateBuiltinType(build, nparams, ra, arg, args, nresults);
case LBF_TYPEOF:
@ -830,6 +877,34 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg,
return translateBuiltinTableInsert(build, nparams, ra, arg, args, nresults, pcpos);
case LBF_STRING_LEN:
return translateBuiltinStringLen(build, nparams, ra, arg, args, nresults, pcpos);
case LBF_BIT32_BYTESWAP:
return translateBuiltinBit32Unary(build, IrCmd::BYTESWAP_UINT, nparams, ra, arg, args, nresults, pcpos);
case LBF_BUFFER_READI8:
return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI8, 1, IrCmd::INT_TO_NUM);
case LBF_BUFFER_READU8:
return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READU8, 1, IrCmd::INT_TO_NUM);
case LBF_BUFFER_WRITEU8:
return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEI8, 1, IrCmd::NUM_TO_UINT);
case LBF_BUFFER_READI16:
return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI16, 2, IrCmd::INT_TO_NUM);
case LBF_BUFFER_READU16:
return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READU16, 2, IrCmd::INT_TO_NUM);
case LBF_BUFFER_WRITEU16:
return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEI16, 2, IrCmd::NUM_TO_UINT);
case LBF_BUFFER_READI32:
return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::INT_TO_NUM);
case LBF_BUFFER_READU32:
return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::UINT_TO_NUM);
case LBF_BUFFER_WRITEU32:
return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEI32, 4, IrCmd::NUM_TO_UINT);
case LBF_BUFFER_READF32:
return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READF32, 4, IrCmd::NOP);
case LBF_BUFFER_WRITEF32:
return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEF32, 4, IrCmd::NOP);
case LBF_BUFFER_READF64:
return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READF64, 8, IrCmd::NOP);
case LBF_BUFFER_WRITEF64:
return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEF64, 8, IrCmd::NOP);
default:
return {BuiltinImplType::None, -1};
}

View file

@ -12,9 +12,8 @@
#include "lstate.h"
#include "ltm.h"
LUAU_FASTFLAGVARIABLE(LuauFullLoopLuserdata, false)
LUAU_FASTFLAGVARIABLE(LuauLoopInterruptFix, false)
namespace Luau
{
@ -44,6 +43,14 @@ struct FallbackStreamScope
IrOp next;
};
static IrOp getInitializedFallback(IrBuilder& build, IrOp& fallback)
{
if (fallback.kind == IrOpKind::None)
fallback = build.block(IrBlockKind::Fallback);
return fallback;
}
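// The fallback block is created lazily: when bytecode type info lets every check exit straight to the VM,
// no fallback block is emitted at all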
void translateInstLoadNil(IrBuilder& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
@ -327,25 +334,43 @@ void translateInstJumpxEqS(IrBuilder& build, const Instruction* pc, int pcpos)
build.beginBlock(next);
}
static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, IrOp opb, IrOp opc, int pcpos, TMS tm)
{
IrOp fallback;
BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos);
// fast-path: number
if (rb != -1)
{
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TNUMBER),
bcTypes.a == LBC_TYPE_NUMBER ? build.vmExit(pcpos) : getInitializedFallback(build, fallback));
}
if (rc != -1 && rc != rb)
{
IrOp tc = build.inst(IrCmd::LOAD_TAG, build.vmReg(rc));
build.inst(IrCmd::CHECK_TAG, tc, build.constTag(LUA_TNUMBER),
bcTypes.b == LBC_TYPE_NUMBER ? build.vmExit(pcpos) : getInitializedFallback(build, fallback));
}
IrOp vb, vc;
IrOp result;
if (opb.kind == IrOpKind::VmConst)
{
LUAU_ASSERT(build.function.proto);
TValue protok = build.function.proto->k[vmConstOp(opb)];
LUAU_ASSERT(protok.tt == LUA_TNUMBER);
vb = build.constDouble(protok.value.n);
}
else
{
vb = build.inst(IrCmd::LOAD_DOUBLE, opb);
}
if (opc.kind == IrOpKind::VmConst)
{
LUAU_ASSERT(build.function.proto);
@ -405,22 +430,33 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,
if (ra != rb && ra != rc) // TODO: optimization should handle second check, but we'll test this later
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
if (fallback.kind != IrOpKind::None)
{
IrOp next = build.blockAtInst(pcpos + 1);
FallbackStreamScope scope(build, fallback, next);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_ARITH, build.vmReg(ra), opb, opc, build.constInt(tm));
build.inst(IrCmd::JUMP, next);
}
}
void translateInstBinary(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm)
{
translateInstBinaryNumeric(
build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_C(*pc)), pcpos, tm);
}
void translateInstBinaryK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm)
{
translateInstBinaryNumeric(
build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, build.vmReg(LUAU_INSN_B(*pc)), build.vmConst(LUAU_INSN_C(*pc)), pcpos, tm);
}
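// Used for the RK instruction forms (e.g. SUBRK/DIVRK) where the left operand is a constant, hence rb == -1
// and a VmConst opb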
void translateInstBinaryRK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm)
{
translateInstBinaryNumeric(
build, LUAU_INSN_A(*pc), -1, LUAU_INSN_C(*pc), build.vmConst(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_C(*pc)), pcpos, tm);
}
void translateInstNot(IrBuilder& build, const Instruction* pc)
@ -439,13 +475,15 @@ void translateInstNot(IrBuilder& build, const Instruction* pc)
void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos)
{
IrOp fallback;
BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos);
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TNUMBER),
bcTypes.a == LBC_TYPE_NUMBER ? build.vmExit(pcpos) : getInitializedFallback(build, fallback));
// fast-path: number
IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb));
@ -456,23 +494,29 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos)
if (ra != rb)
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
if (fallback.kind != IrOpKind::None)
{
IrOp next = build.blockAtInst(pcpos + 1);
FallbackStreamScope scope(build, fallback, next);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(
IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM));
build.inst(IrCmd::JUMP, next);
}
}
void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos)
{
BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos);
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
IrOp fallback = build.block(IrBlockKind::Fallback);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), bcTypes.a == LBC_TYPE_TABLE ? build.vmExit(pcpos) : fallback);
// fast-path: table without __len
IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb));
@ -562,9 +606,10 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool
IrOp builtinArgs = args;
if (customArgs.kind == IrOpKind::VmConst)
{
LUAU_ASSERT(build.function.proto);
TValue protok = build.function.proto->k[vmConstOp(customArgs)];
if (protok.tt == LUA_TNUMBER)
builtinArgs = build.constDouble(protok.value.n);
@ -632,18 +677,18 @@ static IrOp getLoopStepK(IrBuilder& build, int ra)
return build.undef();
}
void beforeInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
{
int ra = LUAU_INSN_A(*pc);
IrOp stepK = getLoopStepK(build, ra);
build.numericLoopStack.push_back({stepK, pcpos + 1});
}
void afterInstForNLoop(IrBuilder& build, const Instruction* pc)
{
LUAU_ASSERT(!build.numericLoopStack.empty());
build.numericLoopStack.pop_back();
}
void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
@ -653,8 +698,8 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp loopStart = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc))));
IrOp loopExit = build.blockAtInst(getJumpTarget(*pc, pcpos));
LUAU_ASSERT(!build.numericLoopStack.empty());
IrOp stepK = build.numericLoopStack.back().step;
// When loop parameters are not numbers, the VM tries to perform type coercion from string and raises an exception if that fails
// Performing that fallback in native code increases code size and complicates the CFG, obscuring the values when they are constant
@ -672,35 +717,9 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1));
build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos));
IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1));
build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopStart, loopExit);
}
else
{
@ -728,17 +747,32 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos)
{
int ra = LUAU_INSN_A(*pc);
int repeatJumpTarget = getJumpTarget(*pc, pcpos);
IrOp loopRepeat = build.blockAtInst(repeatJumpTarget);
IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc))));
LUAU_ASSERT(!build.numericLoopStack.empty());
IrBuilder::LoopInfo loopInfo = build.numericLoopStack.back();
if (FFlag::LuauLoopInterruptFix)
{
// normally, the interrupt is placed at the beginning of the loop body by FORNPREP translation
// however, there are rare cases where FORNLOOP might not jump directly to the first loop instruction
// we detect this by checking the starting instruction of the loop body from loop information stack
if (repeatJumpTarget != loopInfo.startpc)
build.inst(IrCmd::INTERRUPT, build.constUint(pcpos));
}
else
{
// normally, the interrupt is placed at the beginning of the loop body by FORNPREP translation
// however, there are rare contrived cases where FORNLOOP ends up jumping to itself without an interrupt placed
// we detect this by checking if loopRepeat has any instructions (it should normally start with INTERRUPT) and emit a failsafe INTERRUPT if
// not
if (build.function.blockOp(loopRepeat).start == build.function.instructions.size())
build.inst(IrCmd::INTERRUPT, build.constUint(pcpos));
}
IrOp stepK = loopInfo.step;
IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0));
IrOp step = stepK.kind == IrOpKind::Undef ? build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)) : stepK;
@ -749,31 +783,7 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos)
if (stepK.kind == IrOpKind::Undef)
{
build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopRepeat, loopExit);
}
else
{
@ -808,7 +818,7 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNIL));
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
build.inst(FFlag::LuauFullLoopLuserdata ? IrCmd::STORE_POINTER : IrCmd::STORE_INT, build.vmReg(ra + 2), build.constInt(0));
build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 2), build.constTag(LUA_TLIGHTUSERDATA));
build.inst(IrCmd::JUMP, target);
@ -840,7 +850,7 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNIL));
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
build.inst(FFlag::LuauFullLoopLuserdata ? IrCmd::STORE_POINTER : IrCmd::STORE_INT, build.vmReg(ra + 2), build.constInt(0));
build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 2), build.constTag(LUA_TLIGHTUSERDATA));
build.inst(IrCmd::JUMP, target);
@ -912,29 +922,20 @@ void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos)
int c = LUAU_INSN_C(*pc);
IrOp fallback = build.block(IrBlockKind::Fallback);
BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), bcTypes.a == LBC_TYPE_TABLE ? build.vmExit(pcpos) : fallback);
IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb));
build.inst(IrCmd::CHECK_ARRAY_SIZE, vb, build.constInt(c), fallback);
build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback);
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constInt(0));
IrOp arrElTval = build.inst(IrCmd::LOAD_TVALUE, arrEl, build.constInt(c * sizeof(TValue)));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), arrElTval);
IrOp next = build.blockAtInst(pcpos + 1);
FallbackStreamScope scope(build, fallback, next);
@ -951,9 +952,10 @@ void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos)
int c = LUAU_INSN_C(*pc);
IrOp fallback = build.block(IrBlockKind::Fallback);
BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), bcTypes.a == LBC_TYPE_TABLE ? build.vmExit(pcpos) : fallback);
IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb));
@ -961,20 +963,10 @@ void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback);
build.inst(IrCmd::CHECK_READONLY, vb, fallback);
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constInt(0));
IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra));
build.inst(IrCmd::STORE_TVALUE, arrEl, tva, build.constInt(c * sizeof(TValue)));
build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra), build.undef());
@ -993,11 +985,12 @@ void translateInstGetTable(IrBuilder& build, const Instruction* pc, int pcpos)
int rc = LUAU_INSN_C(*pc);
IrOp fallback = build.block(IrBlockKind::Fallback);
BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), bcTypes.a == LBC_TYPE_TABLE ? build.vmExit(pcpos) : fallback);
IrOp tc = build.inst(IrCmd::LOAD_TAG, build.vmReg(rc));
build.inst(IrCmd::CHECK_TAG, tc, build.constTag(LUA_TNUMBER), bcTypes.b == LBC_TYPE_NUMBER ? build.vmExit(pcpos) : fallback);
// fast-path: table with a number index
IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb));
@ -1030,11 +1023,12 @@ void translateInstSetTable(IrBuilder& build, const Instruction* pc, int pcpos)
int rc = LUAU_INSN_C(*pc);
IrOp fallback = build.block(IrBlockKind::Fallback);
BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), bcTypes.a == LBC_TYPE_TABLE ? build.vmExit(pcpos) : fallback);
IrOp tc = build.inst(IrCmd::LOAD_TAG, build.vmReg(rc));
build.inst(IrCmd::CHECK_TAG, tc, build.constTag(LUA_TNUMBER), bcTypes.b == LBC_TYPE_NUMBER ? build.vmExit(pcpos) : fallback);
// fast-path: table with a number index
IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb));
@ -1099,9 +1093,10 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
uint32_t aux = pc[1];
IrOp fallback = build.block(IrBlockKind::Fallback);
BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), bcTypes.a == LBC_TYPE_TABLE ? build.vmExit(pcpos) : fallback);
IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb));
@ -1126,9 +1121,10 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
uint32_t aux = pc[1];
IrOp fallback = build.block(IrBlockKind::Fallback);
BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TTABLE), bcTypes.a == LBC_TYPE_TABLE ? build.vmExit(pcpos) : fallback);
IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb));
@ -1376,74 +1372,37 @@ void translateInstNewClosure(IrBuilder& build, const Instruction* pc, int pcpos)
Instruction uinsn = pc[ui + 1];
LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE);
switch (LUAU_INSN_A(uinsn))
{
case LCT_VAL:
{
IrOp src = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(LUAU_INSN_B(uinsn)));
IrOp dst = build.inst(IrCmd::GET_CLOSURE_UPVAL_ADDR, ncl, build.vmUpvalue(ui));
build.inst(IrCmd::STORE_TVALUE, dst, src);
break;
}
case LCT_REF:
{
IrOp src = build.inst(IrCmd::FINDUPVAL, build.vmReg(LUAU_INSN_B(uinsn)));
IrOp dst = build.inst(IrCmd::GET_CLOSURE_UPVAL_ADDR, ncl, build.vmUpvalue(ui));
build.inst(IrCmd::STORE_POINTER, dst, src);
build.inst(IrCmd::STORE_TAG, dst, build.constTag(LUA_TUPVAL));
break;
}
case LCT_UPVAL:
{
IrOp src = build.inst(IrCmd::GET_CLOSURE_UPVAL_ADDR, build.undef(), build.vmUpvalue(LUAU_INSN_B(uinsn)));
IrOp dst = build.inst(IrCmd::GET_CLOSURE_UPVAL_ADDR, ncl, build.vmUpvalue(ui));
IrOp load = build.inst(IrCmd::LOAD_TVALUE, src);
build.inst(IrCmd::STORE_TVALUE, dst, load);
break;
}
default:
LUAU_ASSERT(!"Unknown upvalue capture type");
LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks
}
}

Some files were not shown because too many files have changed in this diff