Sync to upstream/release/583

This commit is contained in:
Andy Friesen 2023-07-07 10:14:35 -07:00
parent 8bc2f51d89
commit e00dbbeaf2
68 changed files with 1897 additions and 297 deletions

View file

@ -34,6 +34,11 @@ struct PackSubtypeConstraint
{
TypePackId subPack;
TypePackId superPack;
// HACK!! TODO clip.
// We need to know which of `PackSubtypeConstraint` are emitted from `AstStatReturn` vs any others.
// Then we force these specific `PackSubtypeConstraint` to only dispatch in the order of the `return`s.
bool returns = false;
};
// generalizedType ~ gen sourceType
@ -108,13 +113,12 @@ struct FunctionCallConstraint
TypeId fn;
TypePackId argsPack;
TypePackId result;
class AstExprCall* callSite;
class AstExprCall* callSite = nullptr;
std::vector<std::optional<TypeId>> discriminantTypes;
// When we dispatch this constraint, we update the key at this map to record
// the overload that we selected.
DenseHashMap<const AstNode*, TypeId>* astOriginalCallTypes;
DenseHashMap<const AstNode*, TypeId>* astOverloadResolvedTypes;
DenseHashMap<const AstNode*, TypeId>* astOverloadResolvedTypes = nullptr;
};
// result ~ prim ExpectedType SomeSingletonType MultitonType

View file

@ -0,0 +1,138 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include <optional>
#include <string>
namespace Luau
{
struct DiffPathNode
{
// TODO: consider using Variants to simplify toString implementation
enum Kind
{
TableProperty,
FunctionArgument,
FunctionReturn,
Union,
Intersection,
};
Kind kind;
// non-null when TableProperty
std::optional<Name> tableProperty;
// non-null when FunctionArgument, FunctionReturn, Union, or Intersection (i.e. anonymous fields)
std::optional<int> index;
/**
* Do not use for leaf nodes
*/
DiffPathNode(Kind kind)
: kind(kind)
{
}
DiffPathNode(Kind kind, std::optional<Name> tableProperty, std::optional<int> index)
: kind(kind)
, tableProperty(tableProperty)
, index(index)
{
}
std::string toString() const;
static DiffPathNode constructWithTableProperty(Name tableProperty);
};
struct DiffPathNodeLeaf
{
std::optional<TypeId> ty;
std::optional<Name> tableProperty;
DiffPathNodeLeaf(std::optional<TypeId> ty, std::optional<Name> tableProperty)
: ty(ty)
, tableProperty(tableProperty)
{
}
static DiffPathNodeLeaf nullopts();
};
struct DiffPath
{
std::vector<DiffPathNode> path;
std::string toString(bool prependDot) const;
};
struct DiffError
{
enum Kind
{
Normal,
MissingProperty,
LengthMismatchInFnArgs,
LengthMismatchInFnRets,
LengthMismatchInUnion,
LengthMismatchInIntersection,
};
Kind kind;
DiffPath diffPath;
DiffPathNodeLeaf left;
DiffPathNodeLeaf right;
std::string leftRootName;
std::string rightRootName;
DiffError(Kind kind, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName)
: kind(kind)
, left(left)
, right(right)
, leftRootName(leftRootName)
, rightRootName(rightRootName)
{
checkValidInitialization(left, right);
}
DiffError(Kind kind, DiffPath diffPath, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName)
: kind(kind)
, diffPath(diffPath)
, left(left)
, right(right)
, leftRootName(leftRootName)
, rightRootName(rightRootName)
{
checkValidInitialization(left, right);
}
std::string toString() const;
private:
std::string toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf) const;
void checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right);
void checkNonMissingPropertyLeavesHaveNulloptTableProperty() const;
};
struct DifferResult
{
std::optional<DiffError> diffError;
DifferResult() {}
DifferResult(DiffError diffError)
: diffError(diffError)
{
}
void wrapDiffPath(DiffPathNode node);
};
struct DifferEnvironment
{
TypeId rootLeft;
TypeId rootRight;
};
DifferResult diff(TypeId ty1, TypeId ty2);
/**
* True if ty is a "simple" type, i.e. cannot contain types.
* string, number, boolean are simple types.
* function and table are not simple types.
*/
bool isSimple(TypeId ty);
} // namespace Luau

View file

@ -144,6 +144,10 @@ struct Frontend
Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {});
// Parse module graph and prepare SourceNode/SourceModule data, including required dependencies without running typechecking
void parse(const ModuleName& name);
// Parse and typecheck module graph
CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;

View file

@ -274,6 +274,9 @@ struct NormalizedType
/// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString()
bool isSubtypeOfString() const;
/// Returns true if this type should result in error suppressing behavior.
bool shouldSuppressErrors() const;
// Helpers that improve readability of the above (they just say if the component is present)
bool hasTops() const;
bool hasBooleans() const;
@ -343,7 +346,7 @@ public:
void unionTablesWithTable(TypeIds& heres, TypeId there);
void unionTables(TypeIds& heres, const TypeIds& theres);
bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes, int ignoreSmallerTyvars = -1);
// ------- Negations
std::optional<NormalizedType> negateNormal(const NormalizedType& here);
@ -365,9 +368,9 @@ public:
std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress);
bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there);
bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes);
bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool intersectNormalWithTy(NormalizedType& here, TypeId there);
bool intersectNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes);
bool normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType);
// Check for inhabitance

View file

@ -23,7 +23,6 @@
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
LUAU_FASTINT(LuauTypeMaximumStringifierLength)
LUAU_FASTFLAG(LuauTypecheckClassTypeIndexers)
namespace Luau
{
@ -527,7 +526,6 @@ struct ClassType
, definitionModuleName(definitionModuleName)
, indexer(indexer)
{
LUAU_ASSERT(FFlag::LuauTypecheckClassTypeIndexers);
}
};

View file

@ -304,13 +304,10 @@ struct GenericTypeVisitor
if (ctv->metatable)
traverse(*ctv->metatable);
if (FFlag::LuauTypecheckClassTypeIndexers)
if (ctv->indexer)
{
if (ctv->indexer)
{
traverse(ctv->indexer->indexType);
traverse(ctv->indexer->indexResultType);
}
traverse(ctv->indexer->indexType);
traverse(ctv->indexer->indexResultType);
}
}
}

View file

@ -55,7 +55,6 @@ Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState)
static TableIndexer clone(const TableIndexer& indexer, TypeArena& dest, CloneState& cloneState)
{
LUAU_ASSERT(FFlag::LuauTypecheckClassTypeIndexers);
return TableIndexer{clone(indexer.indexType, dest, cloneState), clone(indexer.indexResultType, dest, cloneState)};
}
@ -312,16 +311,8 @@ void TypeCloner::operator()(const TableType& t)
for (const auto& [name, prop] : t.props)
ttv->props[name] = clone(prop, dest, cloneState);
if (FFlag::LuauTypecheckClassTypeIndexers)
{
if (t.indexer)
ttv->indexer = clone(*t.indexer, dest, cloneState);
}
else
{
if (t.indexer)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)};
}
if (t.indexer)
ttv->indexer = clone(*t.indexer, dest, cloneState);
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = clone(arg, dest, cloneState);
@ -360,11 +351,8 @@ void TypeCloner::operator()(const ClassType& t)
if (t.metatable)
ctv->metatable = clone(*t.metatable, dest, cloneState);
if (FFlag::LuauTypecheckClassTypeIndexers)
{
if (t.indexer)
ctv->indexer = clone(*t.indexer, dest, cloneState);
}
if (t.indexer)
ctv->indexer = clone(*t.indexer, dest, cloneState);
}
void TypeCloner::operator()(const AnyType& t)

View file

@ -776,9 +776,10 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* f
ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_)
{
check(scope, while_->condition);
RefinementId refinement = check(scope, while_->condition).refinement;
ScopePtr whileScope = childScope(while_, scope);
applyRefinements(whileScope, while_->condition->location, refinement);
visit(whileScope, while_->body);
@ -825,8 +826,17 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFun
std::unique_ptr<Constraint> c =
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature});
forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) {
Constraint* previous = nullptr;
forEachConstraint(start, end, this, [&c, &previous](const ConstraintPtr& constraint) {
c->dependencies.push_back(NotNull{constraint.get()});
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
previous = constraint.get();
}
});
addConstraint(scope, std::move(c));
@ -915,9 +925,18 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction
std::unique_ptr<Constraint> c =
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature});
forEachConstraint(start, end, this, [&c, &excludeList](const ConstraintPtr& constraint) {
Constraint* previous = nullptr;
forEachConstraint(start, end, this, [&c, &excludeList, &previous](const ConstraintPtr& constraint) {
if (!excludeList.count(constraint.get()))
c->dependencies.push_back(NotNull{constraint.get()});
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
previous = constraint.get();
}
});
addConstraint(scope, std::move(c));
@ -936,7 +955,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn*
expectedTypes.push_back(ty);
TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp;
addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType});
addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType, /*returns*/ true});
return ControlFlow::Returns;
}
@ -1408,6 +1427,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
std::vector<std::optional<TypeId>> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType);
module->astOriginalCallTypes[call->func] = fnType;
module->astOriginalCallTypes[call] = fnType;
TypePackId expectedArgPack = arena->freshTypePack(scope.get());
TypePackId expectedRetPack = arena->freshTypePack(scope.get());
@ -1547,7 +1567,6 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
rets,
call,
std::move(discriminantTypes),
&module->astOriginalCallTypes,
&module->astOverloadResolvedTypes,
});
@ -1642,7 +1661,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantSt
if (expectedType)
{
const TypeId expectedTy = follow(*expectedType);
if (get<BlockedType>(expectedTy) || get<PendingExpansionType>(expectedTy))
if (get<BlockedType>(expectedTy) || get<PendingExpansionType>(expectedTy) || get<FreeType>(expectedTy))
{
TypeId ty = arena->addType(BlockedType{});
TypeId singletonType = arena->addType(SingletonType(StringSingleton{std::string(string->value.data, string->value.size)}));
@ -1774,8 +1793,17 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprFunction*
TypeId generalizedTy = arena->addType(BlockedType{});
NotNull<Constraint> gc = addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature});
forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) {
Constraint* previous = nullptr;
forEachConstraint(startCheckpoint, endCheckpoint, this, [gc, &previous](const ConstraintPtr& constraint) {
gc->dependencies.emplace_back(constraint.get());
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
{
if (previous)
constraint->dependencies.push_back(NotNull{previous});
previous = constraint.get();
}
});
return Inference{generalizedTy};
@ -2412,7 +2440,7 @@ void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFun
if (nullptr != getFallthrough(fn->body))
{
TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever
TypePackId empty = arena->addTypePack({}); // TODO we could have CGB retain one of these forever
addConstraint(scope, fn->location, PackSubtypeConstraint{scope->returnType, empty});
}
}

View file

@ -1259,18 +1259,13 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
else if (auto it = get<IntersectionType>(fn))
fn = collapse(it).value_or(fn);
if (c.callSite)
(*c.astOriginalCallTypes)[c.callSite] = fn;
// We don't support magic __call metamethods.
if (std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location))
{
std::vector<TypeId> args{fn};
auto [head, tail] = flatten(c.argsPack);
head.insert(head.begin(), fn);
for (TypeId arg : c.argsPack)
args.push_back(arg);
argsPack = arena->addTypePack(TypePack{args, {}});
argsPack = arena->addTypePack(TypePack{std::move(head), tail});
fn = *callMm;
asMutable(c.result)->ty.emplace<FreeTypePack>(constraint->scope);
}
@ -1890,7 +1885,20 @@ bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNull<const Cons
if (!force && !blockedTypes.empty())
return block(blockedTypes, constraint);
asMutable(c.resultType)->ty.emplace<BoundType>(result);
const NormalizedType* normType = normalizer->normalize(c.type);
if (!normType)
reportError(NormalizationTooComplex{}, constraint->location);
if (normType && normType->shouldSuppressErrors())
{
auto resultOrError = simplifyUnion(builtinTypes, arena, result, builtinTypes->errorType).result;
asMutable(c.resultType)->ty.emplace<BoundType>(resultOrError);
}
else
{
asMutable(c.resultType)->ty.emplace<BoundType>(result);
}
unblock(c.resultType, constraint->location);

273
Analysis/src/Differ.cpp Normal file
View file

@ -0,0 +1,273 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Differ.h"
#include "Luau/Error.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include <optional>
namespace Luau
{
std::string DiffPathNode::toString() const
{
switch (kind)
{
case DiffPathNode::Kind::TableProperty:
{
if (!tableProperty.has_value())
throw InternalCompilerError{"DiffPathNode has kind TableProperty but tableProperty is nullopt"};
return *tableProperty;
break;
}
default:
{
throw InternalCompilerError{"DiffPathNode::toString is not exhaustive"};
}
}
}
DiffPathNode DiffPathNode::constructWithTableProperty(Name tableProperty)
{
return DiffPathNode{DiffPathNode::Kind::TableProperty, tableProperty, std::nullopt};
}
DiffPathNodeLeaf DiffPathNodeLeaf::nullopts()
{
return DiffPathNodeLeaf{std::nullopt, std::nullopt};
}
std::string DiffPath::toString(bool prependDot) const
{
std::string pathStr;
bool isFirstInForLoop = !prependDot;
for (auto node = path.rbegin(); node != path.rend(); node++)
{
if (isFirstInForLoop)
{
isFirstInForLoop = false;
}
else
{
pathStr += ".";
}
pathStr += node->toString();
}
return pathStr;
}
std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf) const
{
std::string pathStr{rootName + diffPath.toString(true)};
switch (kind)
{
case DiffError::Kind::Normal:
{
checkNonMissingPropertyLeavesHaveNulloptTableProperty();
return pathStr + " has type " + Luau::toString(*leaf.ty);
}
case DiffError::Kind::MissingProperty:
{
if (leaf.ty.has_value())
{
if (!leaf.tableProperty.has_value())
throw InternalCompilerError{"leaf.tableProperty is nullopt"};
return pathStr + "." + *leaf.tableProperty + " has type " + Luau::toString(*leaf.ty);
}
else if (otherLeaf.ty.has_value())
{
if (!otherLeaf.tableProperty.has_value())
throw InternalCompilerError{"otherLeaf.tableProperty is nullopt"};
return pathStr + " is missing the property " + *otherLeaf.tableProperty;
}
throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"};
}
default:
{
throw InternalCompilerError{"DiffPath::toStringWithLeaf is not exhaustive"};
}
}
}
void DiffError::checkNonMissingPropertyLeavesHaveNulloptTableProperty() const
{
if (left.tableProperty.has_value() || right.tableProperty.has_value())
throw InternalCompilerError{"Non-MissingProperty DiffError should have nullopt tableProperty in both leaves"};
}
std::string getDevFixFriendlyName(TypeId ty)
{
if (auto table = get<TableType>(ty))
{
if (table->name.has_value())
return *table->name;
else if (table->syntheticName.has_value())
return *table->syntheticName;
}
// else if (auto primitive = get<PrimitiveType>(ty))
//{
// return "<unlabeled-symbol>";
//}
return "<unlabeled-symbol>";
}
std::string DiffError::toString() const
{
std::string msg = "DiffError: these two types are not equal because the left type at " + toStringALeaf(leftRootName, left, right) +
", while the right type at " + toStringALeaf(rightRootName, right, left);
return msg;
}
void DiffError::checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right)
{
if (!left.ty.has_value() || !right.ty.has_value())
{
// TODO: think about whether this should be always thrown!
// For example, Kind::Primitive doesn't make too much sense to have a TypeId
// throw InternalCompilerError{"Left and Right fields are leaf nodes and must have a TypeId"};
}
}
void DifferResult::wrapDiffPath(DiffPathNode node)
{
if (!diffError.has_value())
{
throw InternalCompilerError{"Cannot wrap diffPath because there is no diffError"};
}
diffError->diffPath.path.push_back(node);
}
static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right)
{
const TableType* leftTable = get<TableType>(left);
const TableType* rightTable = get<TableType>(right);
for (auto const& [field, value] : leftTable->props)
{
if (rightTable->props.find(field) == rightTable->props.end())
{
// left has a field the right doesn't
return DifferResult{DiffError{
DiffError::Kind::MissingProperty,
DiffPathNodeLeaf{value.type(), field},
DiffPathNodeLeaf::nullopts(),
getDevFixFriendlyName(env.rootLeft),
getDevFixFriendlyName(env.rootRight),
}};
}
}
for (auto const& [field, value] : rightTable->props)
{
if (leftTable->props.find(field) == leftTable->props.end())
{
// right has a field the left doesn't
return DifferResult{DiffError{DiffError::Kind::MissingProperty, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf{value.type(), field},
getDevFixFriendlyName(env.rootLeft), getDevFixFriendlyName(env.rootRight)}};
}
}
// left and right have the same set of keys
for (auto const& [field, leftValue] : leftTable->props)
{
auto const& rightValue = rightTable->props.at(field);
DifferResult differResult = diffUsingEnv(env, leftValue.type(), rightValue.type());
if (differResult.diffError.has_value())
{
differResult.wrapDiffPath(DiffPathNode::constructWithTableProperty(field));
return differResult;
}
}
return DifferResult{};
}
static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right)
{
const PrimitiveType* leftPrimitive = get<PrimitiveType>(left);
const PrimitiveType* rightPrimitive = get<PrimitiveType>(right);
if (leftPrimitive->type != rightPrimitive->type)
{
return DifferResult{DiffError{
DiffError::Kind::Normal,
DiffPathNodeLeaf{left, std::nullopt},
DiffPathNodeLeaf{right, std::nullopt},
getDevFixFriendlyName(env.rootLeft),
getDevFixFriendlyName(env.rootRight),
}};
}
return DifferResult{};
}
static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right)
{
const SingletonType* leftSingleton = get<SingletonType>(left);
const SingletonType* rightSingleton = get<SingletonType>(right);
if (*leftSingleton != *rightSingleton)
{
return DifferResult{DiffError{
DiffError::Kind::Normal,
DiffPathNodeLeaf{left, std::nullopt},
DiffPathNodeLeaf{right, std::nullopt},
getDevFixFriendlyName(env.rootLeft),
getDevFixFriendlyName(env.rootRight),
}};
}
return DifferResult{};
}
static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right)
{
left = follow(left);
right = follow(right);
if (left->ty.index() != right->ty.index())
{
return DifferResult{DiffError{
DiffError::Kind::Normal,
DiffPathNodeLeaf{left, std::nullopt},
DiffPathNodeLeaf{right, std::nullopt},
getDevFixFriendlyName(env.rootLeft),
getDevFixFriendlyName(env.rootRight),
}};
}
// Both left and right are the same variant
if (isSimple(left))
{
if (auto lp = get<PrimitiveType>(left))
return diffPrimitive(env, left, right);
else if (auto ls = get<SingletonType>(left))
{
return diffSingleton(env, left, right);
}
throw InternalCompilerError{"Unimplemented Simple TypeId variant for diffing"};
}
// Both left and right are the same non-Simple
if (auto lt = get<TableType>(left))
{
return diffTable(env, left, right);
}
throw InternalCompilerError{"Unimplemented non-simple TypeId variant for diffing"};
}
DifferResult diff(TypeId ty1, TypeId ty2)
{
DifferEnvironment differEnv{ty1, ty2};
return diffUsingEnv(differEnv, ty1, ty2);
}
bool isSimple(TypeId ty)
{
ty = follow(ty);
// TODO: think about GenericType, etc.
return get<PrimitiveType>(ty) || get<SingletonType>(ty);
}
} // namespace Luau

View file

@ -415,6 +415,18 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c
{
}
void Frontend::parse(const ModuleName& name)
{
LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
if (getCheckResult(name, false, false))
return;
std::vector<ModuleName> buildQueue;
parseGraph(buildQueue, name, false);
}
CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOptions> optionOverride)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");

View file

@ -18,6 +18,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCyclicUnions, false);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauTransitiveSubtyping)
LUAU_FASTFLAG(DebugLuauReadWriteProperties)
@ -247,6 +248,11 @@ bool NormalizedType::isSubtypeOfString() const
!hasTables() && !hasFunctions() && !hasTyvars();
}
bool NormalizedType::shouldSuppressErrors() const
{
return hasErrors() || get<AnyType>(tops);
}
bool NormalizedType::hasTops() const
{
return !get<NeverType>(tops);
@ -690,7 +696,8 @@ const NormalizedType* Normalizer::normalize(TypeId ty)
return found->second.get();
NormalizedType norm{builtinTypes};
if (!unionNormalWithTy(norm, ty))
std::unordered_set<TypeId> seenSetTypes;
if (!unionNormalWithTy(norm, ty, seenSetTypes))
return nullptr;
std::unique_ptr<NormalizedType> uniq = std::make_unique<NormalizedType>(std::move(norm));
const NormalizedType* result = uniq.get();
@ -705,9 +712,12 @@ bool Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections
NormalizedType norm{builtinTypes};
norm.tops = builtinTypes->anyType;
// Now we need to intersect the two types
std::unordered_set<TypeId> seenSetTypes;
for (auto ty : intersections)
if (!intersectNormalWithTy(norm, ty))
{
if (!intersectNormalWithTy(norm, ty, seenSetTypes))
return false;
}
if (!unionNormals(outType, norm))
return false;
@ -1438,13 +1448,14 @@ bool Normalizer::withinResourceLimits()
}
// See above for an explaination of `ignoreSmallerTyvars`.
bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars)
bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes, int ignoreSmallerTyvars)
{
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
return false;
there = follow(there);
if (get<AnyType>(there) || get<UnknownType>(there))
{
TypeId tops = unionOfTops(here.tops, there);
@ -1465,9 +1476,23 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
}
else if (const UnionType* utv = get<UnionType>(there))
{
if (FFlag::LuauNormalizeCyclicUnions)
{
if (seenSetTypes.count(there))
return true;
seenSetTypes.insert(there);
}
for (UnionTypeIterator it = begin(utv); it != end(utv); ++it)
if (!unionNormalWithTy(here, *it))
{
if (!unionNormalWithTy(here, *it, seenSetTypes))
{
seenSetTypes.erase(there);
return false;
}
}
seenSetTypes.erase(there);
return true;
}
else if (const IntersectionType* itv = get<IntersectionType>(there))
@ -1475,8 +1500,10 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
NormalizedType norm{builtinTypes};
norm.tops = builtinTypes->anyType;
for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it)
if (!intersectNormalWithTy(norm, *it))
{
if (!intersectNormalWithTy(norm, *it, seenSetTypes))
return false;
}
return unionNormals(here, norm);
}
else if (FFlag::LuauTransitiveSubtyping && get<UnknownType>(here.tops))
@ -1560,7 +1587,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
LUAU_ASSERT(!"Unreachable");
for (auto& [tyvar, intersect] : here.tyvars)
if (!unionNormalWithTy(*intersect, there, tyvarIndex(tyvar)))
if (!unionNormalWithTy(*intersect, there, seenSetTypes, tyvarIndex(tyvar)))
return false;
assertInvariant(here);
@ -2463,12 +2490,12 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali
}
}
bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there)
bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes)
{
for (auto it = here.begin(); it != here.end();)
{
NormalizedType& inter = *it->second;
if (!intersectNormalWithTy(inter, there))
if (!intersectNormalWithTy(inter, there, seenSetTypes))
return false;
if (isShallowInhabited(inter))
++it;
@ -2541,13 +2568,14 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th
return true;
}
bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, std::unordered_set<TypeId>& seenSetTypes)
{
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
return false;
there = follow(there);
if (get<AnyType>(there) || get<UnknownType>(there))
{
here.tops = intersectionOfTops(here.tops, there);
@ -2556,20 +2584,20 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
else if (!get<NeverType>(here.tops))
{
clearNormal(here);
return unionNormalWithTy(here, there);
return unionNormalWithTy(here, there, seenSetTypes);
}
else if (const UnionType* utv = get<UnionType>(there))
{
NormalizedType norm{builtinTypes};
for (UnionTypeIterator it = begin(utv); it != end(utv); ++it)
if (!unionNormalWithTy(norm, *it))
if (!unionNormalWithTy(norm, *it, seenSetTypes))
return false;
return intersectNormals(here, norm);
}
else if (const IntersectionType* itv = get<IntersectionType>(there))
{
for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it)
if (!intersectNormalWithTy(here, *it))
if (!intersectNormalWithTy(here, *it, seenSetTypes))
return false;
return true;
}
@ -2691,7 +2719,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
return true;
}
else if (auto nt = get<NegationType>(t))
return intersectNormalWithTy(here, nt->ty);
return intersectNormalWithTy(here, nt->ty, seenSetTypes);
else
{
// TODO negated unions, intersections, table, and function.
@ -2706,7 +2734,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
else
LUAU_ASSERT(!"Unreachable");
if (!intersectTyvarsWithTy(tyvars, there))
if (!intersectTyvarsWithTy(tyvars, there, seenSetTypes))
return false;
here.tyvars = std::move(tyvars);

View file

@ -191,16 +191,8 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
{
if (alwaysClone)
{
if (FFlag::LuauTypecheckClassTypeIndexers)
{
ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName, a.indexer};
return dest.addType(std::move(clone));
}
else
{
ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName};
return dest.addType(std::move(clone));
}
ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName, a.indexer};
return dest.addType(std::move(clone));
}
else
return ty;
@ -316,13 +308,10 @@ void Tarjan::visitChildren(TypeId ty, int index)
if (ctv->metatable)
visitChild(*ctv->metatable);
if (FFlag::LuauTypecheckClassTypeIndexers)
if (ctv->indexer)
{
if (ctv->indexer)
{
visitChild(ctv->indexer->indexType);
visitChild(ctv->indexer->indexResultType);
}
visitChild(ctv->indexer->indexType);
visitChild(ctv->indexer->indexResultType);
}
}
else if (const NegationType* ntv = get<NegationType>(ty))
@ -1038,13 +1027,10 @@ void Substitution::replaceChildren(TypeId ty)
if (ctv->metatable)
ctv->metatable = replace(*ctv->metatable);
if (FFlag::LuauTypecheckClassTypeIndexers)
if (ctv->indexer)
{
if (ctv->indexer)
{
ctv->indexer->indexType = replace(ctv->indexer->indexType);
ctv->indexer->indexResultType = replace(ctv->indexer->indexResultType);
}
ctv->indexer->indexType = replace(ctv->indexer->indexType);
ctv->indexer->indexResultType = replace(ctv->indexer->indexResultType);
}
}
else if (NegationType* ntv = getMutable<NegationType>(ty))

View file

@ -258,13 +258,10 @@ void StateDot::visitChildren(TypeId ty, int index)
if (ctv->metatable)
visitChild(*ctv->metatable, index, "[metatable]");
if (FFlag::LuauTypecheckClassTypeIndexers)
if (ctv->indexer)
{
if (ctv->indexer)
{
visitChild(ctv->indexer->indexType, index, "[index]");
visitChild(ctv->indexer->indexResultType, index, "[value]");
}
visitChild(ctv->indexer->indexType, index, "[index]");
visitChild(ctv->indexer->indexResultType, index, "[value]");
}
}
else if (const SingletonType* stv = get<SingletonType>(ty))

View file

@ -1090,6 +1090,11 @@ struct TypeChecker2
args.head.push_back(lookupType(indexExpr->expr));
argLocs.push_back(indexExpr->expr->location);
}
else if (findMetatableEntry(builtinTypes, module->errors, *originalCallTy, "__call", call->func->location))
{
args.head.insert(args.head.begin(), lookupType(call->func));
argLocs.push_back(call->func->location);
}
for (size_t i = 0; i < call->args.size; ++i)
{

View file

@ -38,7 +38,6 @@ LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false)
LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure)
LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false)
LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckClassTypeIndexers, false)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
LUAU_FASTFLAG(LuauParseDeclareClassIndexer)
@ -2107,21 +2106,18 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromTypeImpl(
if (prop)
return prop->type();
if (FFlag::LuauTypecheckClassTypeIndexers)
if (auto indexer = cls->indexer)
{
if (auto indexer = cls->indexer)
{
// TODO: Property lookup should work with string singletons or unions thereof as the indexer key type.
ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location);
// TODO: Property lookup should work with string singletons or unions thereof as the indexer key type.
ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location);
if (errors.empty())
return indexer->indexResultType;
if (errors.empty())
return indexer->indexResultType;
if (addErrors)
reportError(location, UnknownProperty{type, name});
if (addErrors)
reportError(location, UnknownProperty{type, name});
return std::nullopt;
}
return std::nullopt;
}
}
else if (const UnionType* utv = get<UnionType>(type))
@ -3312,38 +3308,24 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
}
else if (const ClassType* lhsClass = get<ClassType>(lhs))
{
if (FFlag::LuauTypecheckClassTypeIndexers)
if (const Property* prop = lookupClassProp(lhsClass, name))
{
if (const Property* prop = lookupClassProp(lhsClass, name))
{
return prop->type();
}
if (auto indexer = lhsClass->indexer)
{
Unifier state = mkUnifier(scope, expr.location);
state.tryUnify(stringType, indexer->indexType);
if (state.errors.empty())
{
state.log.commit();
return indexer->indexResultType;
}
}
reportError(TypeError{expr.location, UnknownProperty{lhs, name}});
return errorRecoveryType(scope);
}
else
{
const Property* prop = lookupClassProp(lhsClass, name);
if (!prop)
{
reportError(TypeError{expr.location, UnknownProperty{lhs, name}});
return errorRecoveryType(scope);
}
return prop->type();
}
if (auto indexer = lhsClass->indexer)
{
Unifier state = mkUnifier(scope, expr.location);
state.tryUnify(stringType, indexer->indexType);
if (state.errors.empty())
{
state.log.commit();
return indexer->indexResultType;
}
}
reportError(TypeError{expr.location, UnknownProperty{lhs, name}});
return errorRecoveryType(scope);
}
else if (get<IntersectionType>(lhs))
{
@ -3385,45 +3367,29 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
{
if (const ClassType* exprClass = get<ClassType>(exprType))
{
if (FFlag::LuauTypecheckClassTypeIndexers)
if (const Property* prop = lookupClassProp(exprClass, value->value.data))
{
if (const Property* prop = lookupClassProp(exprClass, value->value.data))
{
return prop->type();
}
if (auto indexer = exprClass->indexer)
{
unify(stringType, indexer->indexType, scope, expr.index->location);
return indexer->indexResultType;
}
reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}});
return errorRecoveryType(scope);
}
else
{
const Property* prop = lookupClassProp(exprClass, value->value.data);
if (!prop)
{
reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}});
return errorRecoveryType(scope);
}
return prop->type();
}
if (auto indexer = exprClass->indexer)
{
unify(stringType, indexer->indexType, scope, expr.index->location);
return indexer->indexResultType;
}
reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}});
return errorRecoveryType(scope);
}
}
else
{
if (FFlag::LuauTypecheckClassTypeIndexers)
if (const ClassType* exprClass = get<ClassType>(exprType))
{
if (const ClassType* exprClass = get<ClassType>(exprType))
if (auto indexer = exprClass->indexer)
{
if (auto indexer = exprClass->indexer)
{
unify(indexType, indexer->indexType, scope, expr.index->location);
return indexer->indexResultType;
}
unify(indexType, indexer->indexType, scope, expr.index->location);
return indexer->indexResultType;
}
}

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <string>
namespace Luau
{
@ -38,6 +39,11 @@ struct Location
bool containsClosed(const Position& p) const;
void extend(const Location& other);
void shift(const Position& start, const Position& oldEnd, const Position& newEnd);
/**
* Use offset=1 when displaying for the user.
*/
std::string toString(int offset = 0, bool useBegin = true) const;
};
} // namespace Luau

View file

@ -4,6 +4,7 @@
#include "Luau/Common.h"
#include <vector>
#include <memory>
#include <stdint.h>
@ -54,7 +55,7 @@ struct Event
struct GlobalContext;
struct ThreadContext;
GlobalContext& getGlobalContext();
std::shared_ptr<GlobalContext> getGlobalContext();
uint16_t createToken(GlobalContext& context, const char* name, const char* category);
uint32_t createThread(GlobalContext& context, ThreadContext* threadContext);
@ -66,7 +67,7 @@ struct ThreadContext
ThreadContext()
: globalContext(getGlobalContext())
{
threadId = createThread(globalContext, this);
threadId = createThread(*globalContext, this);
}
~ThreadContext()
@ -74,16 +75,16 @@ struct ThreadContext
if (!events.empty())
flushEvents();
releaseThread(globalContext, this);
releaseThread(*globalContext, this);
}
void flushEvents()
{
static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace");
static uint16_t flushToken = createToken(*globalContext, "flushEvents", "TimeTrace");
events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}});
TimeTrace::flushEvents(globalContext, threadId, events, data);
TimeTrace::flushEvents(*globalContext, threadId, events, data);
events.clear();
data.clear();
@ -125,7 +126,7 @@ struct ThreadContext
events.push_back({EventType::ArgValue, 0, {pos}});
}
GlobalContext& globalContext;
std::shared_ptr<GlobalContext> globalContext;
uint32_t threadId;
std::vector<Event> events;
std::vector<char> data;

View file

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Location.h"
#include <string>
namespace Luau
{
@ -128,4 +129,12 @@ void Location::shift(const Position& start, const Position& oldEnd, const Positi
end.shift(start, oldEnd, newEnd);
}
std::string Location::toString(int offset, bool useBegin) const
{
const Position& pos = useBegin ? this->begin : this->end;
std::string line{std::to_string(pos.line + offset)};
std::string column{std::to_string(pos.column + offset)};
return "(" + line + ", " + column + ")";
}
} // namespace Luau

View file

@ -90,7 +90,6 @@ namespace TimeTrace
{
struct GlobalContext
{
GlobalContext() = default;
~GlobalContext()
{
// Ideally we would want all ThreadContext destructors to run
@ -110,11 +109,15 @@ struct GlobalContext
uint32_t nextThreadId = 0;
std::vector<Token> tokens;
FILE* traceFile = nullptr;
private:
friend std::shared_ptr<GlobalContext> getGlobalContext();
GlobalContext() = default;
};
GlobalContext& getGlobalContext()
std::shared_ptr<GlobalContext> getGlobalContext()
{
static GlobalContext context;
static std::shared_ptr<GlobalContext> context = std::shared_ptr<GlobalContext>{new GlobalContext};
return context;
}
@ -261,7 +264,7 @@ ThreadContext& getThreadContext()
uint16_t createScopeData(const char* name, const char* category)
{
return createToken(Luau::TimeTrace::getGlobalContext(), name, category);
return createToken(*Luau::TimeTrace::getGlobalContext(), name, category);
}
} // namespace TimeTrace
} // namespace Luau

View file

@ -236,6 +236,10 @@ enum class IrCmd : uint8_t
// A: pointer (Table)
TABLE_LEN,
// Get string length
// A: pointer (string)
STRING_LEN,
// Allocate new table
// A: int (array element count)
// B: int (node element count)
@ -361,8 +365,10 @@ enum class IrCmd : uint8_t
// Guard against tag mismatch
// A, B: tag
// C: block/undef
// D: bool (finish execution in VM on failure)
// In final x64 lowering, A can also be Rn
// When undef is specified instead of a block, execution is aborted on check failure
// When undef is specified instead of a block, execution is aborted on check failure; if D is true, execution is continued in VM interpreter
// instead.
CHECK_TAG,
// Guard against readonly table
@ -377,9 +383,9 @@ enum class IrCmd : uint8_t
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_NO_METATABLE,
// Guard against executing in unsafe environment
// A: block/undef
// When undef is specified instead of a block, execution is aborted on check failure
// Guard against executing in unsafe environment, exits to VM on check failure
// A: unsigned int (pcpos)/undef
// When undef is specified, execution is aborted on check failure
CHECK_SAFE_ENV,
// Guard against index overflowing the table array size
@ -610,7 +616,6 @@ struct IrConst
union
{
bool valueBool;
int valueInt;
unsigned valueUint;
double valueDouble;

View file

@ -39,6 +39,8 @@ struct IrRegAllocX64
RegisterX64 allocRegOrReuse(SizeX64 size, uint32_t instIdx, std::initializer_list<IrOp> oprefs);
RegisterX64 takeReg(RegisterX64 reg, uint32_t instIdx);
bool canTakeReg(RegisterX64 reg) const;
void freeReg(RegisterX64 reg);
void freeLastUseReg(IrInst& target, uint32_t instIdx);
void freeLastUseRegs(const IrInst& inst, uint32_t instIdx);

View file

@ -167,6 +167,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::ABS_NUM:
case IrCmd::NOT_ANY:
case IrCmd::TABLE_LEN:
case IrCmd::STRING_LEN:
case IrCmd::NEW_TABLE:
case IrCmd::DUP_TABLE:
case IrCmd::TRY_NUM_TO_INDEX:
@ -256,5 +257,8 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
uint32_t getNativeContextOffset(int bfid);
// Cleans up blocks that were created with no users
void killUnusedBlocks(IrFunction& function);
} // namespace CodeGen
} // namespace Luau

View file

@ -24,6 +24,15 @@ struct EntryLocations
Label epilogueStart;
};
static void emitClearNativeFlag(AssemblyBuilderA64& build)
{
build.ldr(x0, mem(rState, offsetof(lua_State, ci)));
build.ldr(w1, mem(x0, offsetof(CallInfo, flags)));
build.mov(w2, ~LUA_CALLINFO_NATIVE);
build.and_(w1, w1, w2);
build.str(w1, mem(x0, offsetof(CallInfo, flags)));
}
static void emitExit(AssemblyBuilderA64& build, bool continueInVm)
{
build.mov(x0, continueInVm);
@ -31,6 +40,16 @@ static void emitExit(AssemblyBuilderA64& build, bool continueInVm)
build.br(x1);
}
static void emitUpdatePcAndContinueInVm(AssemblyBuilderA64& build)
{
// x0 = pcpos * sizeof(Instruction)
build.add(x0, rCode, x0);
build.ldr(x1, mem(rState, offsetof(lua_State, ci)));
build.str(x0, mem(x1, offsetof(CallInfo, savedpc)));
emitExit(build, /* continueInVm */ true);
}
static void emitInterrupt(AssemblyBuilderA64& build)
{
// x0 = pc offset
@ -286,6 +305,11 @@ bool initHeaderFunctions(NativeState& data)
void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers)
{
if (build.logText)
build.logAppend("; exitContinueVmClearNativeFlag\n");
build.setLabel(helpers.exitContinueVmClearNativeFlag);
emitClearNativeFlag(build);
if (build.logText)
build.logAppend("; exitContinueVm\n");
build.setLabel(helpers.exitContinueVm);
@ -296,6 +320,11 @@ void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers)
build.setLabel(helpers.exitNoContinueVm);
emitExit(build, /* continueInVm */ false);
if (build.logText)
build.logAppend("; updatePcAndContinueInVm\n");
build.setLabel(helpers.updatePcAndContinueInVm);
emitUpdatePcAndContinueInVm(build);
if (build.logText)
build.logAppend("; reentry\n");
build.setLabel(helpers.reentry);

View file

@ -221,6 +221,8 @@ inline bool lowerIr(A64::AssemblyBuilderA64& build, IrBuilder& ir, ModuleHelpers
template<typename AssemblyBuilder>
inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options)
{
killUnusedBlocks(ir.function);
computeCfgInfo(ir.function);
if (!FFlag::DebugCodegenNoOpt)

View file

@ -180,6 +180,11 @@ bool initHeaderFunctions(NativeState& data)
void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers)
{
if (build.logText)
build.logAppend("; exitContinueVmClearNativeFlag\n");
build.setLabel(helpers.exitContinueVmClearNativeFlag);
emitClearNativeFlag(build);
if (build.logText)
build.logAppend("; exitContinueVm\n");
build.setLabel(helpers.exitContinueVm);
@ -190,6 +195,11 @@ void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers)
build.setLabel(helpers.exitNoContinueVm);
emitExit(build, /* continueInVm */ false);
if (build.logText)
build.logAppend("; updatePcAndContinueInVm\n");
build.setLabel(helpers.updatePcAndContinueInVm);
emitUpdatePcAndContinueInVm(build);
if (build.logText)
build.logAppend("; continueCallInVm\n");
build.setLabel(helpers.continueCallInVm);

View file

@ -24,6 +24,8 @@ struct ModuleHelpers
// A64/X64
Label exitContinueVm;
Label exitNoContinueVm;
Label exitContinueVmClearNativeFlag;
Label updatePcAndContinueInVm;
Label return_;
Label interrupt;

View file

@ -268,6 +268,13 @@ void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build)
build.setLabel(skip);
}
void emitClearNativeFlag(AssemblyBuilderX64& build)
{
build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.and_(dword[rax + offsetof(CallInfo, flags)], ~LUA_CALLINFO_NATIVE);
}
void emitExit(AssemblyBuilderX64& build, bool continueInVm)
{
if (continueInVm)
@ -345,6 +352,16 @@ void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, in
emitUpdateBase(build);
}
void emitUpdatePcAndContinueInVm(AssemblyBuilderX64& build)
{
// edx = pcpos * sizeof(Instruction)
build.add(rdx, sCode);
build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.mov(qword[rax + offsetof(CallInfo, savedpc)], rdx);
emitExit(build, /* continueInVm */ true);
}
void emitContinueCallInVm(AssemblyBuilderX64& build)
{
RegisterX64 proto = rcx; // Sync with emitInstCall

View file

@ -175,11 +175,13 @@ void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX
void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp);
void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build);
void emitClearNativeFlag(AssemblyBuilderX64& build);
void emitExit(AssemblyBuilderX64& build, bool continueInVm);
void emitUpdateBase(AssemblyBuilderX64& build);
void emitInterrupt(AssemblyBuilderX64& build);
void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos);
void emitUpdatePcAndContinueInVm(AssemblyBuilderX64& build);
void emitContinueCallInVm(AssemblyBuilderX64& build);
void emitReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers);

View file

@ -90,6 +90,11 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int
build.mov(qword[rState + offsetof(lua_State, top)], argi);
build.setLabel(skipVararg);
// Keep executing new function
// ci->savedpc = p->code;
build.mov(rax, qword[proto + offsetof(Proto, code)]);
build.mov(qword[ci + offsetof(CallInfo, savedpc)], rax);
// Get native function entry
build.mov(rax, qword[proto + offsetof(Proto, exectarget)]);
build.test(rax, rax);

View file

@ -22,6 +22,82 @@ IrBuilder::IrBuilder()
{
}
static void buildArgumentTypeChecks(IrBuilder& build, Proto* proto)
{
if (!proto->typeinfo || proto->numparams == 0)
return;
for (int i = 0; i < proto->numparams; ++i)
{
uint8_t et = proto->typeinfo[2 + i];
uint8_t tag = et & ~LBC_TYPE_OPTIONAL_BIT;
uint8_t optional = et & LBC_TYPE_OPTIONAL_BIT;
if (tag == LBC_TYPE_ANY)
continue;
IrOp load = build.inst(IrCmd::LOAD_TAG, build.vmReg(i));
IrOp nextCheck;
if (optional)
{
nextCheck = build.block(IrBlockKind::Internal);
IrOp fallbackCheck = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_EQ_TAG, load, build.constTag(LUA_TNIL), nextCheck, fallbackCheck);
build.beginBlock(fallbackCheck);
}
switch (tag)
{
case LBC_TYPE_NIL:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.undef(), build.constInt(1));
break;
case LBC_TYPE_BOOLEAN:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.undef(), build.constInt(1));
break;
case LBC_TYPE_NUMBER:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.undef(), build.constInt(1));
break;
case LBC_TYPE_STRING:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.undef(), build.constInt(1));
break;
case LBC_TYPE_TABLE:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.undef(), build.constInt(1));
break;
case LBC_TYPE_FUNCTION:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TFUNCTION), build.undef(), build.constInt(1));
break;
case LBC_TYPE_THREAD:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.undef(), build.constInt(1));
break;
case LBC_TYPE_USERDATA:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.undef(), build.constInt(1));
break;
case LBC_TYPE_VECTOR:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.undef(), build.constInt(1));
break;
}
if (optional)
{
build.inst(IrCmd::JUMP, nextCheck);
build.beginBlock(nextCheck);
}
}
// If the last argument is optional, we can skip creating a new internal block since one will already have been created.
if (!(proto->typeinfo[2 + proto->numparams - 1] & LBC_TYPE_OPTIONAL_BIT))
{
IrOp next = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP, next);
build.beginBlock(next);
}
}
void IrBuilder::buildFunctionIr(Proto* proto)
{
function.proto = proto;
@ -47,6 +123,9 @@ void IrBuilder::buildFunctionIr(Proto* proto)
if (instIndexToBlock[i] != kNoAssociatedBlockIndex)
beginBlock(blockAtInst(i));
if (i == 0)
buildArgumentTypeChecks(*this, proto);
// We skip dead bytecode instructions when they appear after block was already terminated
if (!inTerminatedBlock)
translateInst(op, pc, i);

View file

@ -212,7 +212,13 @@ RegisterX64 IrCallWrapperX64::suggestNextArgumentRegister(SizeX64 size) const
{
OperandX64 target = getNextArgumentTarget(size);
return target.cat == CategoryX64::reg ? regs.takeReg(target.base, kInvalidInstIdx) : regs.allocReg(size, kInvalidInstIdx);
if (target.cat != CategoryX64::reg)
return regs.allocReg(size, kInvalidInstIdx);
if (!regs.canTakeReg(target.base))
return regs.allocReg(size, kInvalidInstIdx);
return regs.takeReg(target.base, kInvalidInstIdx);
}
OperandX64 IrCallWrapperX64::getNextArgumentTarget(SizeX64 size) const

View file

@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd)
return "JUMP_SLOT_MATCH";
case IrCmd::TABLE_LEN:
return "TABLE_LEN";
case IrCmd::STRING_LEN:
return "STRING_LEN";
case IrCmd::NEW_TABLE:
return "NEW_TABLE";
case IrCmd::DUP_TABLE:

View file

@ -376,8 +376,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.add(inst.regA64, regOp(inst.b), uint16_t(intOp(inst.a)));
else
{
RegisterA64 temp = tempInt(inst.b);
build.add(inst.regA64, regOp(inst.a), temp);
RegisterA64 temp1 = tempInt(inst.a);
RegisterA64 temp2 = tempInt(inst.b);
build.add(inst.regA64, temp1, temp2);
}
break;
case IrCmd::SUB_INT:
@ -386,8 +387,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.sub(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b)));
else
{
RegisterA64 temp = tempInt(inst.b);
build.sub(inst.regA64, regOp(inst.a), temp);
RegisterA64 temp1 = tempInt(inst.a);
RegisterA64 temp2 = tempInt(inst.b);
build.sub(inst.regA64, temp1, temp2);
}
break;
case IrCmd::ADD_NUM:
@ -689,6 +691,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.scvtf(inst.regA64, x0);
break;
}
case IrCmd::STRING_LEN:
{
RegisterA64 reg = regOp(inst.a);
inst.regA64 = regs.allocReg(KindA64::w, index);
build.ldr(inst.regA64, mem(reg, offsetof(TString, len)));
break;
}
case IrCmd::NEW_TABLE:
{
regs.spill(build, index);
@ -816,7 +826,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
case IrCmd::FASTCALL:
regs.spill(build, index);
error |= emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f));
error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f));
break;
case IrCmd::INVOKE_FASTCALL:
{
@ -1018,8 +1028,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break;
case IrCmd::CHECK_TAG:
{
bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d));
Label abort; // used when guard aborts execution
Label& fail = inst.c.kind == IrOpKind::Undef ? abort : labelOp(inst.c);
Label& fail = inst.c.kind == IrOpKind::Undef ? (continueInVm ? helpers.exitContinueVmClearNativeFlag : abort) : labelOp(inst.c);
if (tagOp(inst.b) == 0)
{
build.cbnz(regOp(inst.a), fail);
@ -1029,7 +1040,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.cmp(regOp(inst.a), tagOp(inst.b));
build.b(ConditionA64::NotEqual, fail);
}
if (abort.id)
if (abort.id && !continueInVm)
emitAbort(build, abort);
break;
}
@ -1060,9 +1071,18 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
RegisterA64 tempw = castReg(KindA64::w, temp);
build.ldr(temp, mem(rClosure, offsetof(Closure, env)));
build.ldrb(tempw, mem(temp, offsetof(Table, safeenv)));
build.cbz(tempw, inst.a.kind == IrOpKind::Undef ? abort : labelOp(inst.a));
if (abort.id)
if (inst.a.kind == IrOpKind::Undef)
{
build.cbz(tempw, abort);
emitAbort(build, abort);
}
else
{
Label self;
build.cbz(tempw, self);
exitHandlers.push_back({self, uintOp(inst.a)});
}
break;
}
case IrCmd::CHECK_ARRAY_SIZE:
@ -1528,7 +1548,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::BITAND_UINT:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b))))
if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b))))
build.and_(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b)));
else
{
@ -1541,7 +1561,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::BITXOR_UINT:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b))))
if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b))))
build.eor(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b)));
else
{
@ -1554,7 +1574,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::BITOR_UINT:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b))))
if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b))))
build.orr(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b)));
else
{
@ -1574,7 +1594,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::BITLSHIFT_UINT:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
if (inst.b.kind == IrOpKind::Constant)
if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant)
build.lsl(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31));
else
{
@ -1587,7 +1607,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::BITRSHIFT_UINT:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
if (inst.b.kind == IrOpKind::Constant)
if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant)
build.lsr(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31));
else
{
@ -1600,7 +1620,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::BITARSHIFT_UINT:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
if (inst.b.kind == IrOpKind::Constant)
if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant)
build.asr(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31));
else
{
@ -1612,7 +1632,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
case IrCmd::BITLROTATE_UINT:
{
if (inst.b.kind == IrOpKind::Constant)
if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant)
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a});
build.ror(inst.regA64, regOp(inst.a), uint8_t((32 - unsigned(intOp(inst.b))) & 31));
@ -1630,7 +1650,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::BITRROTATE_UINT:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
if (inst.b.kind == IrOpKind::Constant)
if (inst.a.kind == IrOpKind::Inst && inst.b.kind == IrOpKind::Constant)
build.ror(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31));
else
{
@ -1751,11 +1771,21 @@ void IrLoweringA64::finishFunction()
build.adr(x1, handler.next);
build.b(helpers.interrupt);
}
if (build.logText)
build.logAppend("; exit handlers\n");
for (ExitHandler& handler : exitHandlers)
{
build.setLabel(handler.self);
build.mov(x0, handler.pcpos * sizeof(Instruction));
build.b(helpers.updatePcAndContinueInVm);
}
}
bool IrLoweringA64::hasError() const
{
return error;
return error || regs.error;
}
bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next)

View file

@ -60,6 +60,12 @@ struct IrLoweringA64
Label next;
};
struct ExitHandler
{
Label self;
unsigned int pcpos;
};
AssemblyBuilderA64& build;
ModuleHelpers& helpers;
@ -70,6 +76,7 @@ struct IrLoweringA64
IrValueLocationTracking valueTracker;
std::vector<InterruptHandler> interruptHandlers;
std::vector<ExitHandler> exitHandlers;
bool error = false;
};

View file

@ -584,6 +584,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.vcvtsi2sd(inst.regX64, inst.regX64, eax);
break;
}
case IrCmd::STRING_LEN:
{
RegisterX64 ptr = regOp(inst.a);
inst.regX64 = regs.allocReg(SizeX64::dword, index);
build.mov(inst.regX64, dword[ptr + offsetof(TString, len)]);
break;
}
case IrCmd::NEW_TABLE:
{
IrCallWrapperX64 callWrap(regs, build, index);
@ -720,9 +727,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
int nparams = intOp(inst.e);
int nresults = intOp(inst.f);
ScopedRegX64 func{regs, SizeX64::qword};
build.mov(func.reg, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]);
IrCallWrapperX64 callWrap(regs, build, index);
callWrap.addArgument(SizeX64::qword, rState);
callWrap.addArgument(SizeX64::qword, luauRegAddress(ra));
@ -748,6 +752,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
callWrap.addArgument(SizeX64::dword, nparams);
}
ScopedRegX64 func{regs, SizeX64::qword};
build.mov(func.reg, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]);
callWrap.call(func.release());
inst.regX64 = regs.takeReg(eax, index); // Result of a builtin call is returned in eax
break;
@ -878,9 +885,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
callPrepareForN(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c));
break;
case IrCmd::CHECK_TAG:
{
bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d));
build.cmp(memRegTagOp(inst.a), tagOp(inst.b));
jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.c);
jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.c, continueInVm);
break;
}
case IrCmd::CHECK_READONLY:
build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0);
jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.b);
@ -896,7 +906,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.mov(tmp.reg, sClosure);
build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]);
build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0);
jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.a);
if (inst.a.kind == IrOpKind::Undef)
{
Label skip;
build.jcc(ConditionX64::NotEqual, skip);
build.ud2();
build.setLabel(skip);
}
else
{
Label self;
build.jcc(ConditionX64::Equal, self);
exitHandlers.push_back({self, uintOp(inst.a)});
}
break;
}
case IrCmd::CHECK_ARRAY_SIZE:
@ -1403,6 +1426,16 @@ void IrLoweringX64::finishFunction()
build.lea(rbx, handler.next);
build.jmp(helpers.interrupt);
}
if (build.logText)
build.logAppend("; exit handlers\n");
for (ExitHandler& handler : exitHandlers)
{
build.setLabel(handler.self);
build.mov(edx, handler.pcpos * sizeof(Instruction));
build.jmp(helpers.updatePcAndContinueInVm);
}
}
bool IrLoweringX64::hasError() const
@ -1425,10 +1458,16 @@ void IrLoweringX64::jumpOrFallthrough(IrBlock& target, IrBlock& next)
build.jmp(target.label);
}
void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef)
void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef, bool continueInVm)
{
if (targetOrUndef.kind == IrOpKind::Undef)
{
if (continueInVm)
{
build.jcc(cond, helpers.exitContinueVmClearNativeFlag);
return;
}
Label skip;
build.jcc(condInverse, skip);
build.ud2();

View file

@ -34,7 +34,7 @@ struct IrLoweringX64
bool isFallthroughBlock(IrBlock target, IrBlock next);
void jumpOrFallthrough(IrBlock& target, IrBlock& next);
void jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef);
void jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef, bool continueInVm = false);
void storeDoubleAsFloat(OperandX64 dst, IrOp src);
@ -60,6 +60,12 @@ struct IrLoweringX64
Label next;
};
struct ExitHandler
{
Label self;
unsigned int pcpos;
};
AssemblyBuilderX64& build;
ModuleHelpers& helpers;
@ -70,6 +76,7 @@ struct IrLoweringX64
IrValueLocationTracking valueTracker;
std::vector<InterruptHandler> interruptHandlers;
std::vector<ExitHandler> exitHandlers;
};
} // namespace X64

View file

@ -18,6 +18,8 @@ namespace CodeGen
namespace A64
{
static const int8_t kInvalidSpill = 64;
static int allocSpill(uint32_t& free, KindA64 kind)
{
LUAU_ASSERT(kStackSize <= 256); // to support larger stack frames, we need to ensure qN is allocated at 16b boundary to fit in ldr/str encoding
@ -91,7 +93,8 @@ static void restoreInst(AssemblyBuilderA64& build, uint32_t& freeSpillSlots, IrF
{
build.ldr(reg, mem(sp, sSpillArea.data + s.slot * 8));
freeSpill(freeSpillSlots, reg.kind, s.slot);
if (s.slot != kInvalidSpill)
freeSpill(freeSpillSlots, reg.kind, s.slot);
}
else
{
@ -135,9 +138,8 @@ RegisterA64 IrRegAllocA64::allocReg(KindA64 kind, uint32_t index)
if (set.free == 0)
{
// TODO: remember the error and fail lowering
LUAU_ASSERT(!"Out of registers to allocate");
return noreg;
error = true;
return RegisterA64{kind, 0};
}
int reg = 31 - countlz(set.free);
@ -157,9 +159,8 @@ RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind)
if (set.free == 0)
{
// TODO: remember the error and fail lowering
LUAU_ASSERT(!"Out of registers to allocate");
return noreg;
error = true;
return RegisterA64{kind, 0};
}
int reg = 31 - countlz(set.free);
@ -332,7 +333,11 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init
else
{
int slot = allocSpill(freeSpillSlots, def.regA64.kind);
LUAU_ASSERT(slot >= 0); // TODO: remember the error and fail lowering
if (slot < 0)
{
slot = kInvalidSpill;
error = true;
}
build.str(def.regA64, mem(sp, sSpillArea.data + slot * 8));

View file

@ -77,6 +77,8 @@ struct IrRegAllocA64
// which 8-byte slots are free
uint32_t freeSpillSlots = 0;
bool error = false;
};
} // namespace A64

View file

@ -121,6 +121,14 @@ RegisterX64 IrRegAllocX64::takeReg(RegisterX64 reg, uint32_t instIdx)
return reg;
}
bool IrRegAllocX64::canTakeReg(RegisterX64 reg) const
{
const std::array<bool, 16>& freeMap = reg.size == SizeX64::xmmword ? freeXmmMap : freeGprMap;
const std::array<uint32_t, 16>& instUsers = reg.size == SizeX64::xmmword ? xmmInstUsers : gprInstUsers;
return freeMap[reg.index] || instUsers[reg.index] != kInvalidInstIdx;
}
void IrRegAllocX64::freeReg(RegisterX64 reg)
{
if (reg.size == SizeX64::xmmword)

View file

@ -737,6 +737,23 @@ static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, i
return {BuiltinImplType::UsesFallback, 1};
}
static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback)
{
if (nparams < 1 || nresults > 1)
return {BuiltinImplType::None, -1};
build.loadAndCheckTag(build.vmReg(arg), LUA_TSTRING, fallback);
IrOp ts = build.inst(IrCmd::LOAD_POINTER, build.vmReg(arg));
IrOp len = build.inst(IrCmd::STRING_LEN, ts);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), build.inst(IrCmd::INT_TO_NUM, len));
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
return {BuiltinImplType::UsesFallback, 1};
}
BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback)
{
// Builtins are not allowed to handle variadic arguments
@ -821,6 +838,8 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg,
return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults, fallback);
case LBF_VECTOR:
return translateBuiltinVector(build, nparams, ra, arg, args, nresults, fallback);
case LBF_STRING_LEN:
return translateBuiltinStringLen(build, nparams, ra, arg, args, nresults, fallback);
default:
return {BuiltinImplType::None, -1};
}

View file

@ -516,6 +516,7 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc)
void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next)
{
LuauOpcode opcode = LuauOpcode(LUAU_INSN_OP(*pc));
int bfid = LUAU_INSN_A(*pc);
int skip = LUAU_INSN_C(*pc);
@ -540,7 +541,8 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool
IrOp fallback = build.block(IrBlockKind::Fallback);
build.inst(IrCmd::CHECK_SAFE_ENV, fallback);
// In unsafe environment, instead of retrying fastcall at 'pcpos' we side-exit directly to fallback sequence
build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos + getOpLength(opcode)));
BuiltinImplResult br = translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, nparams, nresults, fallback);
@ -554,7 +556,7 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool
else
{
// TODO: we can skip saving pc for some well-behaved builtins which we didn't inline
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + getOpLength(opcode)));
IrOp res = build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams),
build.constInt(nresults));
@ -668,7 +670,7 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo
IrOp fallback = build.block(IrBlockKind::Fallback);
// fast-path: pairs/next
build.inst(IrCmd::CHECK_SAFE_ENV, fallback);
build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos));
IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1));
build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback);
IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2));
@ -695,7 +697,7 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp
IrOp finish = build.block(IrBlockKind::Internal);
// fast-path: ipairs/inext
build.inst(IrCmd::CHECK_SAFE_ENV, fallback);
build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos));
IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1));
build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback);
IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2));
@ -921,7 +923,7 @@ void translateInstGetImport(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp fastPath = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.inst(IrCmd::CHECK_SAFE_ENV, fallback);
build.inst(IrCmd::CHECK_SAFE_ENV, build.constUint(pcpos));
// note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback
// this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime

View file

@ -80,6 +80,8 @@ IrValueKind getCmdValueKind(IrCmd cmd)
return IrValueKind::None;
case IrCmd::TABLE_LEN:
return IrValueKind::Double;
case IrCmd::STRING_LEN:
return IrValueKind::Int;
case IrCmd::NEW_TABLE:
case IrCmd::DUP_TABLE:
return IrValueKind::Pointer;
@ -684,8 +686,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
unsigned op1 = unsigned(function.intOp(inst.a));
int op2 = function.intOp(inst.b);
if (unsigned(op2) < 32)
substitute(function, inst, build.constInt(op1 << op2));
substitute(function, inst, build.constInt(op1 << (op2 & 31)));
}
else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0)
{
@ -698,8 +699,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
unsigned op1 = unsigned(function.intOp(inst.a));
int op2 = function.intOp(inst.b);
if (unsigned(op2) < 32)
substitute(function, inst, build.constInt(op1 >> op2));
substitute(function, inst, build.constInt(op1 >> (op2 & 31)));
}
else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0)
{
@ -712,12 +712,9 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
int op1 = function.intOp(inst.a);
int op2 = function.intOp(inst.b);
if (unsigned(op2) < 32)
{
// note: technically right shift of negative values is UB, but this behavior is getting defined in C++20 and all compilers do the
// right (shift) thing.
substitute(function, inst, build.constInt(op1 >> op2));
}
// note: technically right shift of negative values is UB, but this behavior is getting defined in C++20 and all compilers do the
// right (shift) thing.
substitute(function, inst, build.constInt(op1 >> (op2 & 31)));
}
else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0)
{
@ -794,5 +791,17 @@ uint32_t getNativeContextOffset(int bfid)
return 0;
}
void killUnusedBlocks(IrFunction& function)
{
// Start from 1 as the first block is the entry block
for (unsigned i = 1; i < function.blocks.size(); i++)
{
IrBlock& block = function.blocks[i];
if (block.kind != IrBlockKind::Dead && block.useCount == 0)
kill(function, block);
}
}
} // namespace CodeGen
} // namespace Luau

View file

@ -441,17 +441,25 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
state.substituteOrRecordVmRegLoad(inst);
break;
case IrCmd::LOAD_DOUBLE:
if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant)
{
IrOp value = state.tryGetValue(inst.a);
if (function.asDoubleOp(value))
substitute(function, inst, value);
else if (inst.a.kind == IrOpKind::VmReg)
state.substituteOrRecordVmRegLoad(inst);
break;
}
case IrCmd::LOAD_INT:
if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant)
{
IrOp value = state.tryGetValue(inst.a);
if (function.asIntOp(value))
substitute(function, inst, value);
else if (inst.a.kind == IrOpKind::VmReg)
state.substituteOrRecordVmRegLoad(inst);
break;
}
case IrCmd::LOAD_TVALUE:
if (inst.a.kind == IrOpKind::VmReg)
state.substituteOrRecordVmRegLoad(inst);
@ -775,6 +783,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
case IrCmd::JUMP_EQ_POINTER:
case IrCmd::JUMP_SLOT_MATCH:
case IrCmd::TABLE_LEN:
case IrCmd::STRING_LEN:
case IrCmd::NEW_TABLE:
case IrCmd::DUP_TABLE:
case IrCmd::TRY_NUM_TO_INDEX:

View file

@ -2314,9 +2314,9 @@ static const char* getBaseTypeString(uint8_t type)
case LBC_TYPE_STRING:
return "string";
case LBC_TYPE_TABLE:
return "{ }";
return "table";
case LBC_TYPE_FUNCTION:
return "function( )";
return "function";
case LBC_TYPE_THREAD:
return "thread";
case LBC_TYPE_USERDATA:

View file

@ -26,8 +26,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25)
LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300)
LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAGVARIABLE(CompileFunctionType, false)
LUAU_FASTFLAG(BytecodeVersion4)
LUAU_FASTFLAGVARIABLE(LuauCompileFunctionType, false)
namespace Luau
{
@ -103,6 +102,7 @@ struct Compiler
, locstants(nullptr)
, tableShapes(nullptr)
, builtins(nullptr)
, typeMap(nullptr)
{
// preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays
localStack.reserve(16);
@ -204,11 +204,11 @@ struct Compiler
setDebugLine(func);
if (FFlag::BytecodeVersion4 && FFlag::CompileFunctionType)
if (FFlag::LuauCompileFunctionType)
{
std::string funcType = getFunctionType(func);
if (!funcType.empty())
bytecode.setFunctionTypeInfo(std::move(funcType));
// note: we move types out of typeMap which is safe because compileFunction is only called once per function
if (std::string* funcType = typeMap.find(func))
bytecode.setFunctionTypeInfo(std::move(*funcType));
}
if (func->vararg)
@ -3807,6 +3807,8 @@ struct Compiler
DenseHashMap<AstLocal*, Constant> locstants;
DenseHashMap<AstExprTable*, TableShape> tableShapes;
DenseHashMap<AstExprCall*, int> builtins;
DenseHashMap<AstExprFunction*, std::string> typeMap;
const DenseHashMap<AstExprCall*, int>* builtinsFold = nullptr;
unsigned int regTop = 0;
@ -3870,6 +3872,11 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c
root->visit(&fenvVisitor);
}
if (FFlag::LuauCompileFunctionType)
{
buildTypeMap(compiler.typeMap, root);
}
// gathers all functions with the invariant that all function references are to functions earlier in the list
// for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo
std::vector<AstExprFunction*> functions;

View file

@ -6,22 +6,58 @@
namespace Luau
{
static LuauBytecodeEncodedType getType(AstType* ty)
static bool isGeneric(AstName name, const AstArray<AstGenericType>& generics)
{
for (const AstGenericType& gt : generics)
if (gt.name == name)
return true;
return false;
}
static LuauBytecodeEncodedType getPrimitiveType(AstName name)
{
if (name == "nil")
return LBC_TYPE_NIL;
else if (name == "boolean")
return LBC_TYPE_BOOLEAN;
else if (name == "number")
return LBC_TYPE_NUMBER;
else if (name == "string")
return LBC_TYPE_STRING;
else if (name == "thread")
return LBC_TYPE_THREAD;
else if (name == "any" || name == "unknown")
return LBC_TYPE_ANY;
else
return LBC_TYPE_INVALID;
}
static LuauBytecodeEncodedType getType(
AstType* ty, const AstArray<AstGenericType>& generics, const DenseHashMap<AstName, AstStatTypeAlias*>& typeAliases, bool resolveAliases)
{
if (AstTypeReference* ref = ty->as<AstTypeReference>())
{
if (ref->name == "nil")
return LBC_TYPE_NIL;
else if (ref->name == "boolean")
return LBC_TYPE_BOOLEAN;
else if (ref->name == "number")
return LBC_TYPE_NUMBER;
else if (ref->name == "string")
return LBC_TYPE_STRING;
else if (ref->name == "thread")
return LBC_TYPE_THREAD;
else if (ref->name == "any" || ref->name == "unknown")
if (ref->prefix)
return LBC_TYPE_ANY;
if (AstStatTypeAlias* const* alias = typeAliases.find(ref->name); alias && *alias)
{
// note: we only resolve aliases to the depth of 1 to avoid dealing with recursive aliases
if (resolveAliases)
return getType((*alias)->type, (*alias)->generics, typeAliases, /* resolveAliases= */ false);
else
return LBC_TYPE_ANY;
}
if (isGeneric(ref->name, generics))
return LBC_TYPE_ANY;
if (LuauBytecodeEncodedType prim = getPrimitiveType(ref->name); prim != LBC_TYPE_INVALID)
return prim;
// not primitive or alias or generic => host-provided, we assume userdata for now
return LBC_TYPE_USERDATA;
}
else if (AstTypeTable* table = ty->as<AstTypeTable>())
{
@ -38,7 +74,7 @@ static LuauBytecodeEncodedType getType(AstType* ty)
for (AstType* ty : un->types)
{
LuauBytecodeEncodedType et = getType(ty);
LuauBytecodeEncodedType et = getType(ty, generics, typeAliases, resolveAliases);
if (et == LBC_TYPE_NIL)
{
@ -69,11 +105,8 @@ static LuauBytecodeEncodedType getType(AstType* ty)
return LBC_TYPE_ANY;
}
std::string getFunctionType(const AstExprFunction* func)
static std::string getFunctionType(const AstExprFunction* func, const DenseHashMap<AstName, AstStatTypeAlias*>& typeAliases)
{
if (func->vararg || func->generics.size || func->genericPacks.size)
return {};
bool self = func->self != 0;
std::string typeInfo;
@ -88,7 +121,8 @@ std::string getFunctionType(const AstExprFunction* func)
bool haveNonAnyParam = false;
for (AstLocal* arg : func->args)
{
LuauBytecodeEncodedType ty = arg->annotation ? getType(arg->annotation) : LBC_TYPE_ANY;
LuauBytecodeEncodedType ty =
arg->annotation ? getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true) : LBC_TYPE_ANY;
if (ty != LBC_TYPE_ANY)
haveNonAnyParam = true;
@ -103,4 +137,88 @@ std::string getFunctionType(const AstExprFunction* func)
return typeInfo;
}
} // namespace Luau
struct TypeMapVisitor : AstVisitor
{
DenseHashMap<AstExprFunction*, std::string>& typeMap;
DenseHashMap<AstName, AstStatTypeAlias*> typeAliases;
std::vector<std::pair<AstName, AstStatTypeAlias*>> typeAliasStack;
TypeMapVisitor(DenseHashMap<AstExprFunction*, std::string>& typeMap)
: typeMap(typeMap)
, typeAliases(AstName())
{
}
size_t pushTypeAliases(AstStatBlock* block)
{
size_t aliasStackTop = typeAliasStack.size();
for (AstStat* stat : block->body)
if (AstStatTypeAlias* alias = stat->as<AstStatTypeAlias>())
{
AstStatTypeAlias*& prevAlias = typeAliases[alias->name];
typeAliasStack.push_back(std::make_pair(alias->name, prevAlias));
prevAlias = alias;
}
return aliasStackTop;
}
void popTypeAliases(size_t aliasStackTop)
{
while (typeAliasStack.size() > aliasStackTop)
{
std::pair<AstName, AstStatTypeAlias*>& top = typeAliasStack.back();
typeAliases[top.first] = top.second;
typeAliasStack.pop_back();
}
}
bool visit(AstStatBlock* node) override
{
size_t aliasStackTop = pushTypeAliases(node);
for (AstStat* stat : node->body)
stat->visit(this);
popTypeAliases(aliasStackTop);
return false;
}
// repeat..until scoping rules are such that condition (along with any possible functions declared in it) has aliases from repeat body in scope
bool visit(AstStatRepeat* node) override
{
size_t aliasStackTop = pushTypeAliases(node->body);
for (AstStat* stat : node->body->body)
stat->visit(this);
node->condition->visit(this);
popTypeAliases(aliasStackTop);
return false;
}
bool visit(AstExprFunction* node) override
{
std::string type = getFunctionType(node, typeAliases);
if (!type.empty())
typeMap[node] = std::move(type);
return true;
}
};
void buildTypeMap(DenseHashMap<AstExprFunction*, std::string>& typeMap, AstNode* root)
{
TypeMapVisitor visitor(typeMap);
root->visit(&visitor);
}
} // namespace Luau

View file

@ -3,7 +3,11 @@
#include "Luau/Ast.h"
#include <string>
namespace Luau
{
std::string getFunctionType(const AstExprFunction* func);
void buildTypeMap(DenseHashMap<AstExprFunction*, std::string>& typeMap, AstNode* root);
} // namespace Luau

View file

@ -176,9 +176,9 @@ coverage: $(TESTS_TARGET) $(COMPILE_CLI_TARGET)
mv default.profraw codegen-x64.profraw
llvm-profdata merge *.profraw -o default.profdata
rm *.profraw
llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests
llvm-cov report -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests
llvm-cov export -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -format lcov --instr-profile default.profdata build/coverage/luau-tests >coverage.info
llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -output-dir=coverage --instr-profile default.profdata -object build/coverage/luau-tests -object build/coverage/luau-compile
llvm-cov report -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -show-region-summary=false --instr-profile default.profdata -object build/coverage/luau-tests -object build/coverage/luau-compile
llvm-cov export -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -format lcov --instr-profile default.profdata -object build/coverage/luau-tests -object build/coverage/luau-compile >coverage.info
format:
git ls-files '*.h' '*.cpp' | xargs clang-format-11 -i

View file

@ -152,6 +152,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/DataFlowGraph.h
Analysis/include/Luau/DcrLogger.h
Analysis/include/Luau/Def.h
Analysis/include/Luau/Differ.h
Analysis/include/Luau/Documentation.h
Analysis/include/Luau/Error.h
Analysis/include/Luau/FileResolver.h
@ -209,6 +210,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/DataFlowGraph.cpp
Analysis/src/DcrLogger.cpp
Analysis/src/Def.cpp
Analysis/src/Differ.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp
Analysis/src/Frontend.cpp
@ -366,6 +368,7 @@ if(TARGET Luau.UnitTest)
tests/CostModel.test.cpp
tests/DataFlowGraph.test.cpp
tests/DenseHash.test.cpp
tests/Differ.test.cpp
tests/Error.test.cpp
tests/Frontend.test.cpp
tests/IrBuilder.test.cpp

View file

@ -34,6 +34,7 @@ Proto* luaF_newproto(lua_State* L)
f->codeentry = NULL;
f->execdata = NULL;
f->exectarget = 0;
f->typeinfo = NULL;
return f;
}
@ -162,6 +163,9 @@ void luaF_freeproto(lua_State* L, Proto* f, lua_Page* page)
}
#endif
if (f->typeinfo)
luaM_freearray(L, f->typeinfo, f->numparams + 2, uint8_t, f->memcat);
luaM_freegco(L, f, sizeof(Proto), f->memcat, page);
}

View file

@ -279,6 +279,8 @@ typedef struct Proto
void* execdata;
uintptr_t exectarget;
uint8_t* typeinfo;
GCObject* gclist;

View file

@ -13,6 +13,8 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauLoadCheckGC, false)
// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens
template<typename T>
struct TempBuffer
@ -178,6 +180,10 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
return 1;
}
// we will allocate a fair amount of memory so check GC before we do
if (FFlag::LuauLoadCheckGC)
luaC_checkGC(L);
// pause GC for the duration of deserialization - some objects we're creating aren't rooted
// TODO: if an allocation error happens mid-load, we do not unpause GC!
size_t GCthreshold = L->global->GCthreshold;
@ -188,11 +194,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
TString* source = luaS_new(L, chunkname);
uint8_t typesversion = 0;
if (version >= 4)
{
uint8_t typesversion = read<uint8_t>(data, size, offset);
LUAU_ASSERT(typesversion == 1);
typesversion = read<uint8_t>(data, size, offset);
}
// string table
@ -229,7 +235,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
uint32_t typesize = readVarInt(data, size, offset);
if (typesize)
if (typesize && typesversion == LBC_TYPE_VERSION)
{
uint8_t* types = (uint8_t*)data + offset;
@ -237,8 +243,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
LUAU_ASSERT(types[0] == LBC_TYPE_FUNCTION);
LUAU_ASSERT(types[1] == p->numparams);
offset += typesize;
p->typeinfo = luaM_newarray(L, typesize, uint8_t, p->memcat);
memcpy(p->typeinfo, types, typesize);
}
offset += typesize;
}
p->sizecode = readVarInt(data, size, offset);

View file

@ -27,12 +27,16 @@ const bool kFuzzTypeck = true;
const bool kFuzzVM = true;
const bool kFuzzTranspile = true;
const bool kFuzzCodegen = true;
const bool kFuzzCodegenAssembly = true;
// Should we generate type annotations?
const bool kFuzzTypes = true;
const Luau::CodeGen::AssemblyOptions::Target kFuzzCodegenTarget = Luau::CodeGen::AssemblyOptions::A64;
static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!");
static_assert(!(kFuzzCodegen && !kFuzzVM), "Codegen requires the VM!");
static_assert(!(kFuzzCodegenAssembly && !kFuzzCompiler), "Codegen requires the compiler!");
std::vector<std::string> protoprint(const luau::ModuleSet& stat, bool types);
@ -348,6 +352,23 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message)
}
}
// run codegen on resulting bytecode (in separate state)
if (kFuzzCodegenAssembly && bytecode.size())
{
static lua_State* globalState = luaL_newstate();
if (luau_load(globalState, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0)
{
Luau::CodeGen::AssemblyOptions options;
options.outputBinary = true;
options.target = kFuzzCodegenTarget;
Luau::CodeGen::getAssembly(globalState, -1, options);
}
lua_pop(globalState, 1);
lua_gc(globalState, LUA_GCCOLLECT, 0);
}
// run resulting bytecode (from last successfully compiler module)
if (kFuzzVM && bytecode.size())
{

View file

@ -107,7 +107,6 @@ ClassFixture::ClassFixture()
globals.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType};
auto addIndexableClass = [&arena, &globals](const char* className, TypeId keyType, TypeId returnType) {
ScopedFastFlag LuauTypecheckClassTypeIndexers("LuauTypecheckClassTypeIndexers", true);
TypeId indexableClassMetaType = arena.addType(TableType{});
TypeId indexableClassType =
arena.addType(ClassType{className, {}, nullopt, indexableClassMetaType, {}, {}, "Test", TableIndexer{keyType, returnType}});

View file

@ -49,7 +49,7 @@ static std::string compileFunction0Coverage(const char* source, int level)
return bcb.dumpFunction(0);
}
static std::string compileFunction0TypeTable(const char* source)
static std::string compileTypeTable(const char* source)
{
Luau::BytecodeBuilder bcb;
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code);
@ -7080,12 +7080,9 @@ L1: RETURN R3 1
TEST_CASE("EncodedTypeTable")
{
ScopedFastFlag sffs[] = {
{"BytecodeVersion4", true},
{"CompileFunctionType", true},
};
ScopedFastFlag sff("LuauCompileFunctionType", true);
CHECK_EQ("\n" + compileFunction0TypeTable(R"(
CHECK_EQ("\n" + compileTypeTable(R"(
function myfunc(test: string, num: number)
print(test)
end
@ -7104,6 +7101,9 @@ end
function myfunc5(test: string | number, n: number | boolean)
end
function myfunc6(test: (number) -> string)
end
myfunc('test')
)"),
R"(
@ -7111,9 +7111,10 @@ myfunc('test')
1: function(number?)
2: function(string, number)
3: function(any, number)
5: function(function)
)");
CHECK_EQ("\n" + compileFunction0TypeTable(R"(
CHECK_EQ("\n" + compileTypeTable(R"(
local Str = {
a = 1
}
@ -7126,7 +7127,95 @@ end
Str:test(234)
)"),
R"(
0: function({ }, number)
0: function(table, number)
)");
}
TEST_CASE("HostTypesAreUserdata")
{
ScopedFastFlag sff("LuauCompileFunctionType", true);
CHECK_EQ("\n" + compileTypeTable(R"(
function myfunc(test: string, num: number)
print(test)
end
function myfunc2(test: Instance, num: number)
end
type Foo = string
function myfunc3(test: string, n: Foo)
end
function myfunc4<Bar>(test: Bar, n: Part)
end
)"),
R"(
0: function(string, number)
1: function(userdata, number)
2: function(string, string)
3: function(any, userdata)
)");
}
TEST_CASE("TypeAliasScoping")
{
ScopedFastFlag sff("LuauCompileFunctionType", true);
CHECK_EQ("\n" + compileTypeTable(R"(
do
type Part = number
end
function myfunc1(test: Part, num: number)
end
do
type Part = number
function myfunc2(test: Part, num: number)
end
end
repeat
type Part = number
until (function(test: Part, num: number) end)()
function myfunc4(test: Instance, num: number)
end
type Instance = string
)"),
R"(
0: function(userdata, number)
1: function(number, number)
2: function(number, number)
3: function(string, number)
)");
}
TEST_CASE("TypeAliasResolve")
{
ScopedFastFlag sff("LuauCompileFunctionType", true);
CHECK_EQ("\n" + compileTypeTable(R"(
type Foo1 = number
type Foo2 = { number }
type Foo3 = Part
type Foo4 = Foo1 -- we do not resolve aliases within aliases
type Foo5<X> = X
function myfunc(f1: Foo1, f2: Foo2, f3: Foo3, f4: Foo4, f5: Foo5<number>)
end
function myfuncerr(f1: Foo1<string>, f2: Foo5)
end
)"),
R"(
0: function(number, table, userdata, any, any)
1: function(number, any)
)");
}

316
tests/Differ.test.cpp Normal file
View file

@ -0,0 +1,316 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Differ.h"
#include "Luau/Error.h"
#include "Luau/Frontend.h"
#include "Fixture.h"
#include "doctest.h"
#include <iostream>
using namespace Luau;
TEST_SUITE_BEGIN("Differ");
TEST_CASE_FIXTURE(Fixture, "equal_numbers")
{
CheckResult result = check(R"(
local foo = 5
local almostFoo = 78
almostFoo = foo
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
try
{
DifferResult diffRes = diff(foo, almostFoo);
CHECK(!diffRes.diffError.has_value());
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
}
TEST_CASE_FIXTURE(Fixture, "equal_strings")
{
CheckResult result = check(R"(
local foo = "hello"
local almostFoo = "world"
almostFoo = foo
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
try
{
DifferResult diffRes = diff(foo, almostFoo);
CHECK(!diffRes.diffError.has_value());
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
}
TEST_CASE_FIXTURE(Fixture, "equal_tables")
{
CheckResult result = check(R"(
local foo = { x = 1, y = "where" }
local almostFoo = { x = 5, y = "when" }
almostFoo = foo
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
try
{
DifferResult diffRes = diff(foo, almostFoo);
CHECK(!diffRes.diffError.has_value());
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
}
TEST_CASE_FIXTURE(Fixture, "a_table_missing_property")
{
CheckResult result = check(R"(
local foo = { x = 1, y = 2 }
local almostFoo = { x = 1, z = 3 }
almostFoo = foo
)");
LUAU_REQUIRE_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
std::string diffMessage;
try
{
diffMessage = diff(foo, almostFoo).diffError->toString();
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
CHECK_EQ("DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo is missing "
"the property y",
diffMessage);
}
TEST_CASE_FIXTURE(Fixture, "left_table_missing_property")
{
CheckResult result = check(R"(
local foo = { x = 1 }
local almostFoo = { x = 1, z = 3 }
almostFoo = foo
)");
LUAU_REQUIRE_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
std::string diffMessage;
try
{
diffMessage = diff(foo, almostFoo).diffError->toString();
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
CHECK_EQ("DiffError: these two types are not equal because the left type at foo is missing the property z, while the right type at almostFoo.z "
"has type number",
diffMessage);
}
TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type")
{
CheckResult result = check(R"(
local foo = { x = 1, y = 2 }
local almostFoo = { x = 1, y = "two" }
almostFoo = foo
)");
LUAU_REQUIRE_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
std::string diffMessage;
try
{
diffMessage = diff(foo, almostFoo).diffError->toString();
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
CHECK_EQ("DiffError: these two types are not equal because the left type at foo.y has type number, while the right type at almostFoo.y has type "
"string",
diffMessage);
}
TEST_CASE_FIXTURE(Fixture, "a_table_wrong_type")
{
CheckResult result = check(R"(
local foo: string
local almostFoo: number
almostFoo = foo
)");
LUAU_REQUIRE_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
std::string diffMessage;
try
{
diffMessage = diff(foo, almostFoo).diffError->toString();
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
CHECK_EQ("DiffError: these two types are not equal because the left type at <unlabeled-symbol> has type string, while the right type at "
"<unlabeled-symbol> has type number",
diffMessage);
}
TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_type")
{
CheckResult result = check(R"(
local foo = { x = 1, inner = { table = { has = { wrong = { value = 5 } } } } }
local almostFoo = { x = 1, inner = { table = { has = { wrong = { value = "five" } } } } }
almostFoo = foo
)");
LUAU_REQUIRE_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
std::string diffMessage;
try
{
diffMessage = diff(foo, almostFoo).diffError->toString();
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
CHECK_EQ("DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.value has type number, while the right "
"type at almostFoo.inner.table.has.wrong.value has type string",
diffMessage);
}
TEST_CASE_FIXTURE(Fixture, "a_nested_table_wrong_match")
{
CheckResult result = check(R"(
local foo = { x = 1, inner = { table = { has = { wrong = { variant = { because = { it = { goes = { on = "five" } } } } } } } } }
local almostFoo = { x = 1, inner = { table = { has = { wrong = { variant = "five" } } } } }
almostFoo = foo
)");
LUAU_REQUIRE_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
std::string diffMessage;
try
{
diffMessage = diff(foo, almostFoo).diffError->toString();
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
CHECK_EQ("DiffError: these two types are not equal because the left type at foo.inner.table.has.wrong.variant has type { because: { it: { goes: "
"{ on: string } } } }, while the right type at almostFoo.inner.table.has.wrong.variant has type string",
diffMessage);
}
TEST_CASE_FIXTURE(Fixture, "singleton")
{
CheckResult result = check(R"(
local foo: "hello" = "hello"
local almostFoo: true = true
almostFoo = foo
)");
LUAU_REQUIRE_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
std::string diffMessage;
try
{
diffMessage = diff(foo, almostFoo).diffError->toString();
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
CHECK_EQ(
R"(DiffError: these two types are not equal because the left type at <unlabeled-symbol> has type "hello", while the right type at <unlabeled-symbol> has type true)",
diffMessage);
}
TEST_CASE_FIXTURE(Fixture, "equal_singleton")
{
CheckResult result = check(R"(
local foo: "hello" = "hello"
local almostFoo: "hello"
almostFoo = foo
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
try
{
DifferResult diffRes = diff(foo, almostFoo);
INFO(diffRes.diffError->toString());
CHECK(!diffRes.diffError.has_value());
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
}
TEST_CASE_FIXTURE(Fixture, "singleton_string")
{
CheckResult result = check(R"(
local foo: "hello" = "hello"
local almostFoo: "world" = "world"
almostFoo = foo
)");
LUAU_REQUIRE_ERRORS(result);
TypeId foo = requireType("foo");
TypeId almostFoo = requireType("almostFoo");
std::string diffMessage;
try
{
diffMessage = diff(foo, almostFoo).diffError->toString();
}
catch (InternalCompilerError e)
{
INFO(("InternalCompilerError: " + e.message));
CHECK(false);
}
CHECK_EQ(
R"(DiffError: these two types are not equal because the left type at <unlabeled-symbol> has type "hello", while the right type at <unlabeled-symbol> has type "world")",
diffMessage);
}
TEST_SUITE_END();

View file

@ -1146,4 +1146,35 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "module_scope_check")
CHECK_EQ(toString(ty), "number");
}
TEST_CASE_FIXTURE(FrontendFixture, "parse_only")
{
fileResolver.source["game/Gui/Modules/A"] = R"(
local a: number = 'oh no a type error'
return {a=a}
)";
fileResolver.source["game/Gui/Modules/B"] = R"(
local Modules = script.Parent
local A = require(Modules.A)
local b: number = 2
)";
frontend.parse("game/Gui/Modules/B");
REQUIRE(frontend.sourceNodes.count("game/Gui/Modules/A"));
REQUIRE(frontend.sourceNodes.count("game/Gui/Modules/B"));
auto node = frontend.sourceNodes["game/Gui/Modules/B"];
CHECK_EQ(node->requireSet.count("game/Gui/Modules/A"), 1);
REQUIRE_EQ(node->requireLocations.size(), 1);
CHECK_EQ(node->requireLocations[0].second, Luau::Location(Position(2, 18), Position(2, 36)));
// Early parse doesn't cause typechecking to be skipped
CheckResult result = frontend.check("game/Gui/Modules/B");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("game/Gui/Modules/A", result.errors[0].moduleName);
CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0]));
}
TEST_SUITE_END();

View file

@ -526,7 +526,7 @@ bb_0:
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32Blocked")
TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32RangeReduction")
{
IrOp block = build.block(IrBlockKind::Internal);
@ -534,10 +534,10 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32Blocked")
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(-10)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(140)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xf), build.constInt(-10)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xf), build.constInt(140)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xf), build.constInt(-10)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xf), build.constInt(140)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xffffff), build.constInt(-10)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xffffff), build.constInt(140)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xffffff), build.constInt(-10)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xffffff), build.constInt(140)));
build.inst(IrCmd::RETURN, build.constUint(0));
@ -546,18 +546,12 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32Blocked")
CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"(
bb_0:
%0 = BITLSHIFT_UINT 15i, -10i
STORE_INT R10, %0
%2 = BITLSHIFT_UINT 15i, 140i
STORE_INT R10, %2
%4 = BITRSHIFT_UINT 15i, -10i
STORE_INT R10, %4
%6 = BITRSHIFT_UINT 15i, 140i
STORE_INT R10, %6
%8 = BITARSHIFT_UINT 15i, -10i
STORE_INT R10, %8
%10 = BITARSHIFT_UINT 15i, 140i
STORE_INT R10, %10
STORE_INT R10, 62914560i
STORE_INT R10, 61440i
STORE_INT R10, 3i
STORE_INT R10, 4095i
STORE_INT R10, 3i
STORE_INT R10, 4095i
RETURN 0u
)");
@ -1864,6 +1858,34 @@ bb_0:
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "LoadPropagatesOnlyRightType")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(2));
IrOp value1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), value1);
IrOp value2 = build.inst(IrCmd::LOAD_INT, build.vmReg(1));
build.inst(IrCmd::STORE_INT, build.vmReg(2), value2);
build.inst(IrCmd::RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build, true);
CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"(
bb_0:
STORE_INT R0, 2i
%1 = LOAD_DOUBLE R0
STORE_DOUBLE R1, %1
%3 = LOAD_INT R1
STORE_INT R2, %3
RETURN 0u
)");
}
TEST_SUITE_END();
TEST_SUITE_BEGIN("Analysis");

View file

@ -10,8 +10,8 @@ using namespace Luau::CodeGen::X64;
class IrCallWrapperX64Fixture
{
public:
IrCallWrapperX64Fixture()
: build(/* logText */ true, ABIX64::Windows)
IrCallWrapperX64Fixture(ABIX64 abi = ABIX64::Windows)
: build(/* logText */ true, abi)
, regs(build, function)
, callWrap(regs, build, ~0u)
{
@ -42,6 +42,15 @@ public:
static constexpr RegisterX64 rArg4d = r9d;
};
class IrCallWrapperX64FixtureSystemV : public IrCallWrapperX64Fixture
{
public:
IrCallWrapperX64FixtureSystemV()
: IrCallWrapperX64Fixture(ABIX64::SystemV)
{
}
};
TEST_SUITE_BEGIN("IrCallWrapperX64");
TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleRegs")
@ -519,4 +528,35 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ImmediateConflictWithFunction")
)");
}
TEST_CASE_FIXTURE(IrCallWrapperX64FixtureSystemV, "SuggestedConflictWithReserved")
{
ScopedRegX64 tmp{regs, regs.takeReg(r9, kInvalidInstIdx)};
IrCallWrapperX64 callWrap(regs, build);
callWrap.addArgument(SizeX64::qword, r12);
callWrap.addArgument(SizeX64::qword, r13);
callWrap.addArgument(SizeX64::qword, r14);
callWrap.addArgument(SizeX64::dword, 2);
callWrap.addArgument(SizeX64::qword, 1);
RegisterX64 reg = callWrap.suggestNextArgumentRegister(SizeX64::dword);
build.mov(reg, 10);
callWrap.addArgument(SizeX64::dword, reg);
callWrap.call(tmp.release());
checkMatch(R"(
mov eax,Ah
mov rdi,r12
mov rsi,r13
mov rdx,r14
mov rcx,r9
mov r9d,eax
mov rax,rcx
mov ecx,2
mov r8,1
call rax
)");
}
TEST_SUITE_END();

View file

@ -735,6 +735,37 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_metatables_where_the_metata
CHECK("{ @metatable *error-type*, {| |} }" == toString(normal("Mt<{}, any> & Mt<{}, err>")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "recurring_intersection")
{
CheckResult result = check(R"(
type A = any?
type B = A & A
)");
std::optional<TypeId> t = lookupType("B");
REQUIRE(t);
const NormalizedType* nt = normalizer.normalize(*t);
REQUIRE(nt);
CHECK("any" == toString(normalizer.typeFromNormal(*nt)));
}
TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union")
{
ScopedFastFlag sff{"LuauNormalizeCyclicUnions", true};
// T where T = any & (number | T)
TypeId t = arena.addType(BlockedType{});
TypeId u = arena.addType(UnionType{{builtinTypes->numberType, t}});
asMutable(t)->ty.emplace<IntersectionType>(IntersectionType{{builtinTypes->anyType, u}});
const NormalizedType* nt = normalizer.normalize(t);
REQUIRE(nt);
CHECK("number" == toString(normalizer.typeFromNormal(*nt)));
}
TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable")
{
CHECK("never" == toString(normal("Mt<{}, number> & Mt<{}, string>")));

View file

@ -479,7 +479,6 @@ TEST_CASE_FIXTURE(ClassFixture, "callable_classes")
TEST_CASE_FIXTURE(ClassFixture, "indexable_classes")
{
// Test reading from an index
ScopedFastFlag LuauTypecheckClassTypeIndexers("LuauTypecheckClassTypeIndexers", true);
{
CheckResult result = check(R"(
local x : IndexableClass

View file

@ -398,7 +398,6 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props")
TEST_CASE_FIXTURE(Fixture, "class_definition_indexer")
{
ScopedFastFlag LuauParseDeclareClassIndexer("LuauParseDeclareClassIndexer", true);
ScopedFastFlag LuauTypecheckClassTypeIndexers("LuauTypecheckClassTypeIndexers", true);
loadDefinition(R"(
declare class Foo

View file

@ -2096,4 +2096,46 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "attempt_to_call_an_intersection_of_tables_wi
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "num_is_solved_before_num_or_str")
{
CheckResult result = check(R"(
function num()
return 5
end
local function num_or_str()
if math.random() > 0.5 then
return num()
else
return "some string"
end
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0]));
CHECK_EQ("() -> number", toString(requireType("num_or_str")));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "num_is_solved_after_num_or_str")
{
CheckResult result = check(R"(
local function num_or_str()
if math.random() > 0.5 then
return num()
else
return "some string"
end
end
function num()
return 5
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0]));
CHECK_EQ("() -> number", toString(requireType("num_or_str")));
}
TEST_SUITE_END();

View file

@ -917,6 +917,52 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression")
CHECK_EQ("any", toString(requireTypeAtPosition({6, 66})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_while_expression")
{
CheckResult result = check(R"(
function f(v:string?)
while v do
local foo = v
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("string", toString(requireTypeAtPosition({3, 28})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_while_expression")
{
CheckResult result = check(R"(
function f(v:string?)
while not v do
local foo = v
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "refine_the_correct_types_opposite_of_while_a_is_not_number_or_string")
{
CheckResult result = check(R"(
local function f(a: string | number | boolean)
while type(a) ~= "number" and type(a) ~= "string" do
local foo = a
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined")
{
CheckResult result = check(R"(
@ -1580,8 +1626,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_duri
TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global")
{
ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true};
CheckResult result = check(R"(
foo = { bar = 5 :: number? }
@ -1590,9 +1634,12 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global")
end
)");
LUAU_REQUIRE_ERROR_COUNT(3, result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
LUAU_REQUIRE_ERROR_COUNT(3, result);
CHECK_EQ("~(false?)", toString(requireTypeAtPosition({4, 30})));
CHECK_EQ("~(false?)", toString(requireTypeAtPosition({4, 30})));
}
}
TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never")
@ -1757,4 +1804,20 @@ TEST_CASE_FIXTURE(Fixture, "refinements_should_not_affect_assignment")
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "refinements_should_preserve_error_suppression")
{
CheckResult result = check(R"(
local a: any = {}
local b
if typeof(a) == "table" then
b = a.field
end
)");
if (FFlag::DebugLuauDeferredConstraintResolution)
LUAU_REQUIRE_NO_ERRORS(result);
else
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View file

@ -9,6 +9,20 @@ using namespace Luau;
TEST_SUITE_BEGIN("TypeSingletons");
TEST_CASE_FIXTURE(Fixture, "function_args_infer_singletons")
{
CheckResult result = check(R"(
--!strict
type Phase = "A" | "B" | "C"
local function f(e : Phase) : number
return 0
end
local e = f("B")
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "bool_singletons")
{
CheckResult result = check(R"(

View file

@ -61,4 +61,35 @@ end
assert(pcall(fuzzfail5) == false)
local function fuzzfail6(_)
return bit32.extract(_,671088640,_)
end
assert(pcall(fuzzfail6, 1) == false)
local function fuzzfail7(_)
return bit32.extract(_,_,671088640)
end
assert(pcall(fuzzfail7, 1) == false)
local function fuzzfail8(...)
local _ = _,_
_.n0,_,_,_,_,_,_,_,_._,_,_,_[...],_,_,_ = nil
_,n0,_,_,_,_,_,_,_,_,l0,_,_,_,_ = nil
function _()
end
_._,_,_,_,_,_,_,_,_,_,_[...],_,n0[l0],_ = nil
_[...],_,_,_,_,_,_,_,_()[_],_,_,_,_,_ = _(),...
end
assert(pcall(fuzzfail8) == false)
local function fuzzfail9()
local _ = bit32.bor
local x = _(_(_,_),_(_,_),_(-16834560,_),_(_(- _,-2130706432)),- _),_(_(_,_),_(-16834560,-2130706432))
end
assert(pcall(fuzzfail9) == false)
return('OK')

View file

@ -54,6 +54,8 @@ ProvisionalTests.table_insert_with_a_singleton_argument
ProvisionalTests.typeguard_inference_incomplete
RefinementTest.discriminate_from_truthiness_of_x
RefinementTest.not_t_or_some_prop_of_t
RefinementTest.refine_a_property_of_some_global
RefinementTest.refinements_should_preserve_error_suppression
RefinementTest.truthy_constraint_on_properties
RefinementTest.type_narrow_to_vector
RefinementTest.typeguard_cast_free_table_to_vector
@ -96,7 +98,6 @@ TableTests.shared_selfs
TableTests.shared_selfs_from_free_param
TableTests.shared_selfs_through_metatables
TableTests.table_call_metamethod_basic
TableTests.table_call_metamethod_generic
TableTests.table_simple_call
TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors
TableTests.used_colon_instead_of_dot
@ -131,7 +132,6 @@ TypeInfer.tc_after_error_recovery_no_replacement_name_in_error
TypeInfer.type_infer_recursion_limit_no_ice
TypeInfer.type_infer_recursion_limit_normalizer
TypeInferAnyError.for_in_loop_iterator_is_any2
TypeInferClasses.callable_classes
TypeInferClasses.class_type_mismatch_with_name_conflict
TypeInferClasses.index_instance_property
TypeInferFunctions.cannot_hoist_interior_defns_into_signature