Sync to upstream/release/658 (#1625)
Some checks failed
benchmark / callgrind (map[branch:main name:luau-lang/benchmark-data], ubuntu-22.04) (push) Has been cancelled
build / macos (push) Has been cancelled
build / macos-arm (push) Has been cancelled
build / ubuntu (push) Has been cancelled
build / windows (Win32) (push) Has been cancelled
build / windows (x64) (push) Has been cancelled
build / coverage (push) Has been cancelled
build / web (push) Has been cancelled
release / macos (push) Has been cancelled
release / ubuntu (push) Has been cancelled
release / windows (push) Has been cancelled
release / web (push) Has been cancelled

## What's Changed

### General
- Allow types of tables to diverge after using `table.clone` (fixes
#1617).
- Allow 2-argument vector.create in Luau.
- Fix a crash when suggesting autocomplete after encountering parsing
errors.
- Add lua_tolstringatom C API which returns the string length (whether
or not the atom exists) and which extends the existing lua_tostringatom
function the same way lua_tolstring/lua_tostring do.
- Luau now retains the DFGs of typechecked modules.

### Magic Functions Migration Note
We've made a change to the API used to define magic functions.

Previously, we had a set of function pointers on each `FunctionType`
that would be invoked by the type inference engine at the correct point.

The problem we'd run into is that they were all `std::function`s, we'd
grown quite a few of them, and Luau allocates tens of thousands of types
as it performs type inference. This adds up to a large amount of memory
for data that isn't used by 99% of types.

To slim things down a bit, we've replaced all of those `std::function`s
with a single `shared_ptr` to a new interface called `MagicFunction`.
This slims down the memory footprint of each type by about 50 bytes.

The virtual methods of `MagicFunction` have roughly 1:1 correspondence
with the old interface, so updating things should not be too difficult:

* `FunctionType::magicFunction` is now `MagicFunction::handleOldSolver`
* `FunctionType::dcrMagicFunction` is now `MagicFunction::infer`
* `FunctionType::dcrMagicRefinement` is now `MagicFunction::refine`
* `FunctionType::dcrMagicTypeCheck` is now `MagicFunction::typeCheck`

**Full Changelog**:
https://github.com/luau-lang/luau/compare/0.657...0.658

---

Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
This commit is contained in:
ayoungbloodrbx 2025-01-24 12:15:19 -08:00 committed by GitHub
parent 6061a14e9f
commit c13b5b7440
Signed by: DevComp
GPG key ID: B5690EEEBB952194
37 changed files with 751 additions and 360 deletions

View file

@ -65,10 +65,7 @@ TypeId makeFunction( // Polymorphic
bool checked = false
);
void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn);
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn);
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn);
void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> fn);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName);

View file

@ -166,7 +166,7 @@ struct ConstraintSolver
**/
void finalizeTypeFunctions();
bool isDone();
bool isDone() const;
private:
/**
@ -298,10 +298,10 @@ public:
// FIXME: This use of a boolean for the return result is an appalling
// interface.
bool blockOnPendingTypes(TypeId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypePackId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypePackId targetPack, NotNull<const Constraint> constraint);
void unblock(NotNull<const Constraint> progressed);
void unblock(TypeId progressed, Location location);
void unblock(TypeId ty, Location location);
void unblock(TypePackId progressed, Location location);
void unblock(const std::vector<TypeId>& types, Location location);
void unblock(const std::vector<TypePackId>& packs, Location location);
@ -336,7 +336,7 @@ public:
* @param location the location where the require is taking place; used for
* error locations.
**/
TypeId resolveModule(const ModuleInfo& module, const Location& location);
TypeId resolveModule(const ModuleInfo& info, const Location& location);
void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e);

View file

@ -6,6 +6,7 @@
#include "Luau/ControlFlow.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Symbol.h"
#include "Luau/TypedAllocator.h"
@ -48,13 +49,13 @@ struct DataFlowGraph
const RefinementKey* getRefinementKey(const AstExpr* expr) const;
private:
DataFlowGraph() = default;
DataFlowGraph(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena defArena;
RefinementKeyArena keyArena;
NotNull<DefArena> defArena;
NotNull<RefinementKeyArena> keyArena;
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
@ -110,30 +111,22 @@ using ScopeStack = std::vector<DfgScope*>;
struct DataFlowGraphBuilder
{
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
/**
* This method is identical to the build method above, but returns a pair of dfg, scopes as the data flow graph
* here is intended to live on the module between runs of typechecking. Before, the DFG only needed to live as
* long as the typecheck, but in a world with incremental typechecking, we need the information on the dfg to incrementally
* typecheck small fragments of code.
* @param block - pointer to the ast to build the dfg for
* @param handle - for raising internal errors while building the dfg
*/
static std::pair<std::shared_ptr<DataFlowGraph>, std::vector<std::unique_ptr<DfgScope>>> buildShared(
static DataFlowGraph build(
AstStatBlock* block,
NotNull<InternalErrorReporter> handle
NotNull<DefArena> defArena,
NotNull<RefinementKeyArena> keyArena,
NotNull<struct InternalErrorReporter> handle
);
private:
DataFlowGraphBuilder() = default;
DataFlowGraphBuilder(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete;
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph;
NotNull<DefArena> defArena{&graph.defArena};
NotNull<RefinementKeyArena> keyArena{&graph.keyArena};
NotNull<DefArena> defArena;
NotNull<RefinementKeyArena> keyArena;
struct InternalErrorReporter* handle = nullptr;

View file

@ -105,6 +105,9 @@ private:
std::vector<Id> storage;
};
template <typename L>
using Node = EqSat::Node<L>;
using EType = EqSat::Language<
TNil,
TBoolean,
@ -171,6 +174,9 @@ struct Subst
Id eclass;
Id newClass;
// The node into eclass which is boring, if any
std::optional<size_t> boringIndex;
std::string desc;
Subst(Id eclass, Id newClass, std::string desc = "");
@ -211,6 +217,7 @@ struct Simplifier
void subst(Id from, Id to);
void subst(Id from, Id to, const std::string& ruleName);
void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void unionClasses(std::vector<Id>& hereParts, Id there);
@ -295,13 +302,13 @@ QueryIterator<Tag>::QueryIterator(EGraph* egraph_, Id eclass)
for (const auto& enode : ecl.nodes)
{
if (enode.index() < idx)
if (enode.node.index() < idx)
++index;
else
break;
}
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx)
if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != idx)
{
egraph = nullptr;
index = 0;
@ -331,7 +338,7 @@ std::pair<const Tag*, size_t> QueryIterator<Tag>::operator*() const
EGraph::EClassT& ecl = (*egraph)[eclass];
LUAU_ASSERT(index < ecl.nodes.size());
auto& enode = ecl.nodes[index];
auto& enode = ecl.nodes[index].node;
Tag* result = enode.template get<Tag>();
LUAU_ASSERT(result);
return {result, index};
@ -343,12 +350,16 @@ QueryIterator<Tag>& QueryIterator<Tag>::operator++()
{
const auto& ecl = (*egraph)[eclass];
do
{
++index;
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId<Tag>())
if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != EType::VariantTy::getTypeId<Tag>())
{
egraph = nullptr;
index = 0;
break;
}
} while (ecl.nodes[index].boring);
return *this;
}

View file

@ -17,8 +17,8 @@ struct FrontendOptions;
enum class FragmentTypeCheckStatus
{
Success,
SkipAutocomplete,
Success,
};
struct FragmentAutocompleteAncestryResult
@ -56,7 +56,7 @@ struct FragmentAutocompleteResult
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
FragmentParseResult parseFragment(
std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,

View file

@ -139,6 +139,11 @@ struct Module
TypePackId returnType = nullptr;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
// Arenas related to the DFG must persist after the DFG no longer exists, as
// Module objects maintain raw pointers to objects in these arenas.
DefArena defArena;
RefinementKeyArena keyArena;
bool hasModuleScope() const;
ScopePtr getModuleScope() const;

View file

@ -131,14 +131,14 @@ struct BlockedType
BlockedType();
int index;
Constraint* getOwner() const;
void setOwner(Constraint* newOwner);
void replaceOwner(Constraint* newOwner);
const Constraint* getOwner() const;
void setOwner(const Constraint* newOwner);
void replaceOwner(const Constraint* newOwner);
private:
// The constraint that is intended to unblock this type. Other constraints
// should block on this constraint if present.
Constraint* owner = nullptr;
const Constraint* owner = nullptr;
};
struct PrimitiveType
@ -279,9 +279,6 @@ struct WithPredicate
}
};
using MagicFunction = std::function<std::optional<
WithPredicate<TypePackId>>(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
struct MagicFunctionCallContext
{
NotNull<struct ConstraintSolver> solver;
@ -291,7 +288,6 @@ struct MagicFunctionCallContext
TypePackId result;
};
using DcrMagicFunction = std::function<bool(MagicFunctionCallContext)>;
struct MagicRefinementContext
{
NotNull<Scope> scope;
@ -308,8 +304,29 @@ struct MagicFunctionTypeCheckContext
NotNull<Scope> checkScope;
};
using DcrMagicRefinement = void (*)(const MagicRefinementContext&);
using DcrMagicFunctionTypeCheck = std::function<void(const MagicFunctionTypeCheckContext&)>;
struct MagicFunction
{
virtual std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) = 0;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
virtual bool infer(const MagicFunctionCallContext&) = 0;
virtual void refine(const MagicRefinementContext&) {}
// If a magic function needs to do its own special typechecking, do it here.
// Returns true if magic typechecking was performed. Return false if the
// default typechecking logic should run.
virtual bool typeCheck(const MagicFunctionTypeCheckContext&)
{
return false;
}
virtual ~MagicFunction() {}
};
struct FunctionType
{
// Global monomorphic function
@ -367,16 +384,7 @@ struct FunctionType
Scope* scope = nullptr;
TypePackId argTypes;
TypePackId retTypes;
MagicFunction magicFunction = nullptr;
DcrMagicFunction dcrMagicFunction = nullptr;
DcrMagicRefinement dcrMagicRefinement = nullptr;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
DcrMagicFunctionTypeCheck dcrMagicTypeCheck = nullptr;
std::shared_ptr<MagicFunction> magic = nullptr;
bool hasSelf;
// `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it.

View file

@ -85,6 +85,8 @@ struct GenericTypeVisitor
{
}
virtual ~GenericTypeVisitor() {}
virtual void cycle(TypeId) {}
virtual void cycle(TypePackId) {}

View file

@ -13,6 +13,8 @@
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon)
namespace Luau
{
@ -40,12 +42,27 @@ struct AutocompleteNodeFinder : public AstVisitor
}
bool visit(AstStat* stat) override
{
if (FFlag::LuauExtendStatEndPosWithSemicolon)
{
// Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal
// to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case
// (no semicolon) we are still part of the AstStatLocal, hence the different comparison check.
if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end))
{
ancestry.push_back(stat);
return true;
}
}
else
{
if (stat->location.begin < pos && pos <= stat->location.end)
{
ancestry.push_back(stat);
return true;
}
}
return false;
}

View file

@ -34,46 +34,78 @@ LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix)
LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression)
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2)
LUAU_FASTFLAG(LuauVectorDefinitionsExtra)
LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType)
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope)
namespace Luau
{
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
struct MagicSelect final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicSetMetatable final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context);
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context);
static bool dcrMagicFunctionPack(MagicFunctionCallContext context);
static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context);
struct MagicAssert final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicPack final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicRequire final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicClone final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFreeze final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFormat final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override;
};
struct MagicMatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicGmatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFind final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{
@ -168,34 +200,10 @@ TypeId makeFunction(
return arena.addType(std::move(ftv));
}
void attachMagicFunction(TypeId ty, MagicFunction fn)
void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> magic)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->magicFunction = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicFunction = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicRefinement = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicTypeCheck = fn;
ftv->magic = std::move(magic);
else
LUAU_ASSERT(!"Got a non functional type");
}
@ -396,7 +404,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
}
}
attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert);
attachMagicFunction(getGlobalBinding(globals, "assert"), std::make_shared<MagicAssert>());
if (FFlag::LuauSolverV2)
{
@ -412,9 +420,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
addGlobalBinding(globals, "assert", assertTy, "@luau");
}
attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable);
attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect);
attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect);
attachMagicFunction(getGlobalBinding(globals, "setmetatable"), std::make_shared<MagicSetMetatable>());
attachMagicFunction(getGlobalBinding(globals, "select"), std::make_shared<MagicSelect>());
if (TableType* ttv = getMutable<TableType>(getGlobalBinding(globals, "table")))
{
@ -445,23 +452,22 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
ttv->props["foreach"].deprecated = true;
ttv->props["foreachi"].deprecated = true;
attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack);
attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack);
attachMagicFunction(ttv->props["pack"].type(), std::make_shared<MagicPack>());
if (FFlag::LuauTableCloneClonesType)
attachMagicFunction(ttv->props["clone"].type(), std::make_shared<MagicClone>());
if (FFlag::LuauTypestateBuiltins2)
attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze);
attachMagicFunction(ttv->props["freeze"].type(), std::make_shared<MagicFreeze>());
}
if (FFlag::AutocompleteRequirePathSuggestions2)
{
TypeId requireTy = getGlobalBinding(globals, "require");
attachTag(requireTy, kRequireTagName);
attachMagicFunction(requireTy, magicFunctionRequire);
attachDcrMagicFunction(requireTy, dcrMagicFunctionRequire);
attachMagicFunction(requireTy, std::make_shared<MagicRequire>());
}
else
{
attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire);
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire);
attachMagicFunction(getGlobalBinding(globals, "require"), std::make_shared<MagicRequire>());
}
}
@ -501,7 +507,7 @@ static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes,
return result;
}
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
std::optional<WithPredicate<TypePackId>> MagicFormat::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -551,7 +557,7 @@ std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
bool MagicFormat::infer(const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
@ -595,7 +601,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
return true;
}
static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext context)
bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context)
{
AstExprConstantString* fmt = nullptr;
if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
@ -615,7 +621,7 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
context.typechecker->reportError(
CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location
);
return;
return true;
}
std::vector<TypeId> expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size);
@ -657,6 +663,8 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
}
}
}
return true;
}
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
@ -719,7 +727,7 @@ static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes
return result;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
std::optional<WithPredicate<TypePackId>> MagicGmatch::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -755,7 +763,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
}
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
bool MagicGmatch::infer(const MagicFunctionCallContext& context)
{
const auto& [params, tail] = flatten(context.arguments);
@ -788,7 +796,7 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
std::optional<WithPredicate<TypePackId>> MagicMatch::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -828,7 +836,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
bool MagicMatch::infer(const MagicFunctionCallContext& context)
{
const auto& [params, tail] = flatten(context.arguments);
@ -864,7 +872,7 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
std::optional<WithPredicate<TypePackId>> MagicFind::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -922,7 +930,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionFind(MagicFunctionCallContext context)
bool MagicFind::infer(const MagicFunctionCallContext& context)
{
const auto& [params, tail] = flatten(context.arguments);
@ -999,11 +1007,9 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
formatFTV.isCheckedFunction = true;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
attachDcrMagicFunctionTypeCheck(formatFn, dcrMagicFunctionTypeCheckFormat);
attachMagicFunction(formatFn, std::make_shared<MagicFormat>());
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
@ -1017,16 +1023,14 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true);
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
attachMagicFunction(gmatchFunc, std::make_shared<MagicGmatch>());
FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})
};
matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
attachMagicFunction(matchFunc, std::make_shared<MagicMatch>());
FunctionType findFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
@ -1034,8 +1038,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
};
findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
attachMagicFunction(findFunc, std::make_shared<MagicFind>());
// string.byte : string -> number? -> number? -> ...number
FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList};
@ -1096,7 +1099,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
std::optional<WithPredicate<TypePackId>> MagicSelect::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1141,7 +1144,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
return std::nullopt;
}
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
bool MagicSelect::infer(const MagicFunctionCallContext& context)
{
if (context.callSite->args.size <= 0)
{
@ -1186,7 +1189,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
return false;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
std::optional<WithPredicate<TypePackId>> MagicSetMetatable::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1268,7 +1271,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
return WithPredicate<TypePackId>{arena.addTypePack({target})};
}
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
bool MagicSetMetatable::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicAssert::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1302,7 +1310,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
return WithPredicate<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})};
}
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
bool MagicAssert::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicPack::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1345,7 +1358,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
return WithPredicate<TypePackId>{arena.addTypePack({packedTable})};
}
static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
bool MagicPack::infer(const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
@ -1385,7 +1398,68 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
return true;
}
static std::optional<TypeId> freezeTable(TypeId inputType, MagicFunctionCallContext& context)
std::optional<WithPredicate<TypePackId>> MagicClone::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType);
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
const auto& [paramTypes, paramTail] = flatten(paramPack);
if (paramTypes.empty() || expr.args.size == 0)
{
typechecker.reportError(expr.argLocation, CountMismatch{1, std::nullopt, 0});
return std::nullopt;
}
TypeId inputType = follow(paramTypes[0]);
CloneState cloneState{typechecker.builtinTypes};
TypeId resultType = shallowClone(inputType, arena, cloneState);
TypePackId clonedTypePack = arena.addTypePack({resultType});
return WithPredicate<TypePackId>{clonedTypePack};
}
bool MagicClone::infer(const MagicFunctionCallContext& context)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType);
TypeArena* arena = context.solver->arena;
const auto& [paramTypes, paramTail] = flatten(context.arguments);
if (paramTypes.empty() || context.callSite->args.size == 0)
{
context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation);
return false;
}
TypeId inputType = follow(paramTypes[0]);
CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState);
if (auto tableType = getMutable<TableType>(resultType))
{
tableType->scope = context.constraint->scope.get();
}
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(context.constraint->scope.get(), resultType);
TypePackId clonedTypePack = arena->addTypePack({resultType});
asMutable(context.result)->ty.emplace<BoundTypePack>(clonedTypePack);
return true;
}
static std::optional<TypeId> freezeTable(TypeId inputType, const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
@ -1430,7 +1504,12 @@ static std::optional<TypeId> freezeTable(TypeId inputType, MagicFunctionCallCont
return std::nullopt;
}
static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context)
std::optional<WithPredicate<TypePackId>> MagicFreeze::handleOldSolver(struct TypeChecker &, const std::shared_ptr<struct Scope> &, const class AstExprCall &, WithPredicate<TypePackId>)
{
return std::nullopt;
}
bool MagicFreeze::infer(const MagicFunctionCallContext& context)
{
LUAU_ASSERT(FFlag::LuauTypestateBuiltins2);
@ -1491,7 +1570,7 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
return good;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
std::optional<WithPredicate<TypePackId>> MagicRequire::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1537,7 +1616,7 @@ static bool checkRequirePathDcr(NotNull<ConstraintSolver> solver, AstExpr* expr)
return good;
}
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context)
bool MagicRequire::infer(const MagicFunctionCallContext& context)
{
if (context.callSite->args.size != 1)
{

View file

@ -3,8 +3,6 @@
#include "Luau/Constraint.h"
#include "Luau/VisitType.h"
LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions)
namespace Luau
{
@ -60,7 +58,7 @@ struct ReferenceCountInitializer : TypeOnceVisitor
//
// The default behavior here is `true` for "visit the child types"
// of this type, hence:
return !FFlag::LuauDontRefCountTypesInTypeFunctions;
return false;
}
};

View file

@ -75,7 +75,7 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const
{
if (auto blocked = get<BlockedType>(ty))
{
Constraint* owner = blocked->getOwner();
const Constraint* owner = blocked->getOwner();
LUAU_ASSERT(owner);
return owner == constraint;
}
@ -446,7 +446,7 @@ void ConstraintSolver::run()
if (success)
{
unblock(c);
unsolvedConstraints.erase(unsolvedConstraints.begin() + i);
unsolvedConstraints.erase(unsolvedConstraints.begin() + ptrdiff_t(i));
// decrement the referenced free types for this constraint if we dispatched successfully!
for (auto ty : c->getMaybeMutatedFreeTypes())
@ -553,7 +553,7 @@ void ConstraintSolver::finalizeTypeFunctions()
}
}
bool ConstraintSolver::isDone()
bool ConstraintSolver::isDone() const
{
return unsolvedConstraints.empty();
}
@ -1293,11 +1293,11 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
if (ftv)
{
if (ftv->dcrMagicFunction)
usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result});
if (ftv->dcrMagicRefinement)
ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes});
if (ftv->magic)
{
usedMagic = ftv->magic->infer(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result});
ftv->magic->refine(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes});
}
}
if (!usedMagic)
@ -1702,7 +1702,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(
for (TypeId part : parts)
{
TypeId r = arena->addType(BlockedType{});
getMutable<BlockedType>(r)->setOwner(const_cast<Constraint*>(constraint.get()));
getMutable<BlockedType>(r)->setOwner(constraint.get());
bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen);
// If we've cut a recursive loop short, skip it.
@ -1734,7 +1734,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(
for (TypeId part : parts)
{
TypeId r = arena->addType(BlockedType{});
getMutable<BlockedType>(r)->setOwner(const_cast<Constraint*>(constraint.get()));
getMutable<BlockedType>(r)->setOwner(constraint.get());
bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen);
// If we've cut a recursive loop short, skip it.
@ -2874,10 +2874,10 @@ bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull<const Constrai
return !blocker.blocked;
}
bool ConstraintSolver::blockOnPendingTypes(TypePackId pack, NotNull<const Constraint> constraint)
bool ConstraintSolver::blockOnPendingTypes(TypePackId targetPack, NotNull<const Constraint> constraint)
{
Blocker blocker{NotNull{this}, constraint};
blocker.traverse(pack);
blocker.traverse(targetPack);
return !blocker.blocked;
}

View file

@ -62,6 +62,12 @@ const RefinementKey* RefinementKeyArena::node(const RefinementKey* parent, DefId
return allocator.allocate(RefinementKey{parent, def, propName});
}
DataFlowGraph::DataFlowGraph(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena)
: defArena{defArena}
, keyArena{keyArena}
{
}
DefId DataFlowGraph::getDef(const AstExpr* expr) const
{
auto def = astDefs.find(expr);
@ -178,11 +184,23 @@ bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const
return true;
}
DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle)
DataFlowGraphBuilder::DataFlowGraphBuilder(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena)
: graph{defArena, keyArena}
, defArena{defArena}
, keyArena{keyArena}
{
}
DataFlowGraph DataFlowGraphBuilder::build(
AstStatBlock* block,
NotNull<DefArena> defArena,
NotNull<RefinementKeyArena> keyArena,
NotNull<struct InternalErrorReporter> handle
)
{
LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking");
DataFlowGraphBuilder builder;
DataFlowGraphBuilder builder(defArena, keyArena);
builder.handle = handle;
DfgScope* moduleScope = builder.makeChildScope();
PushScope ps{builder.scopeStack, moduleScope};
@ -198,30 +216,6 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalE
return std::move(builder.graph);
}
std::pair<std::shared_ptr<DataFlowGraph>, std::vector<std::unique_ptr<DfgScope>>> DataFlowGraphBuilder::buildShared(
AstStatBlock* block,
NotNull<InternalErrorReporter> handle
)
{
LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking");
DataFlowGraphBuilder builder;
builder.handle = handle;
DfgScope* moduleScope = builder.makeChildScope();
PushScope ps{builder.scopeStack, moduleScope};
builder.visitBlockWithoutChildScope(block);
builder.resolveCaptures();
if (FFlag::DebugLuauFreezeArena)
{
builder.defArena->allocator.freeze();
builder.keyArena->allocator.freeze();
}
return {std::make_shared<DataFlowGraph>(std::move(builder.graph)), std::move(builder.scopes)};
}
void DataFlowGraphBuilder::resolveCaptures()
{
for (const auto& [_, capture] : captures)

View file

@ -206,7 +206,7 @@ static bool isTerminal(const EGraph& egraph, Id eclass)
nodes.end(),
[](auto& a)
{
return isTerminal(a);
return isTerminal(a.node);
}
);
}
@ -464,7 +464,7 @@ static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGrap
if (auto it = costs.find(id); it != costs.end())
return it->second;
const std::vector<EType>& nodes = egraph[id].nodes;
const std::vector<Node<EType>>& nodes = egraph[id].nodes;
size_t minCost = std::numeric_limits<size_t>::max();
size_t bestNode = std::numeric_limits<size_t>::max();
@ -481,7 +481,7 @@ static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGrap
// First, quickly scan for a terminal type. If we can find one, it is obviously the best.
for (size_t index = 0; index < nodes.size(); ++index)
{
if (isTerminal(nodes[index]))
if (isTerminal(nodes[index].node))
{
minCost = 1;
bestNode = index;
@ -533,44 +533,44 @@ static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGrap
{
const auto& node = nodes[index];
if (node.get<TBound>())
if (node.node.get<TBound>())
updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound.
else if (node.get<TFunction>())
else if (node.node.get<TFunction>())
{
minCost = 1;
bestNode = index;
}
else if (auto tbl = node.get<TTable>())
else if (auto tbl = node.node.get<TTable>())
{
// TODO: We could make the penalty a parameter to computeChildren.
std::optional<size_t> maybeCost = computeChildren(tbl->operands(), minCost);
if (maybeCost)
updateCost(TABLE_TYPE_PENALTY + *maybeCost, index);
}
else if (node.get<TImportedTable>())
else if (node.node.get<TImportedTable>())
{
minCost = IMPORTED_TABLE_PENALTY;
bestNode = index;
}
else if (auto u = node.get<Union>())
else if (auto u = node.node.get<Union>())
{
std::optional<size_t> maybeCost = computeChildren(u->operands(), minCost);
if (maybeCost)
updateCost(SET_TYPE_PENALTY + *maybeCost, index);
}
else if (auto i = node.get<Intersection>())
else if (auto i = node.node.get<Intersection>())
{
std::optional<size_t> maybeCost = computeChildren(i->operands(), minCost);
if (maybeCost)
updateCost(SET_TYPE_PENALTY + *maybeCost, index);
}
else if (auto negation = node.get<Negation>())
else if (auto negation = node.node.get<Negation>())
{
std::optional<size_t> maybeCost = computeChildren(negation->operands(), minCost);
if (maybeCost)
updateCost(NEGATION_PENALTY + *maybeCost, index);
}
else if (auto tfun = node.get<TTypeFun>())
else if (auto tfun = node.node.get<TTypeFun>())
{
std::optional<size_t> maybeCost = computeChildren(tfun->operands(), minCost);
if (maybeCost)
@ -643,7 +643,7 @@ TypeId flattenTableNode(
for (size_t i = 0; i < eclass.nodes.size(); ++i)
{
if (eclass.nodes[i].get<TImportedTable>())
if (eclass.nodes[i].node.get<TImportedTable>())
{
found = true;
index = i;
@ -660,13 +660,13 @@ TypeId flattenTableNode(
}
const auto& node = eclass.nodes[index];
if (const TTable* ttable = node.get<TTable>())
if (const TTable* ttable = node.node.get<TTable>())
{
stack.push_back(ttable);
id = ttable->getBasis();
continue;
}
else if (const TImportedTable* ti = node.get<TImportedTable>())
else if (const TImportedTable* ti = node.node.get<TImportedTable>())
{
importedTable = ti;
break;
@ -718,7 +718,7 @@ TypeId fromId(
size_t index = bestNodes.at(rootId);
LUAU_ASSERT(index <= egraph[rootId].nodes.size());
const EType& node = egraph[rootId].nodes[index];
const EType& node = egraph[rootId].nodes[index].node;
if (node.get<TNil>())
return builtinTypes->nilType;
@ -1025,8 +1025,9 @@ std::string toDot(const StringCache& strings, const EGraph& egraph)
for (const auto& [id, eclass] : egraph.getAllClasses())
{
for (const auto& node : eclass.nodes)
for (const auto& n : eclass.nodes)
{
const EType& node = n.node;
if (!node.operands().empty())
populated.insert(id);
for (Id op : node.operands())
@ -1047,7 +1048,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph)
for (size_t index = 0; index < eclass.nodes.size(); ++index)
{
const auto& node = eclass.nodes[index];
const auto& node = eclass.nodes[index].node;
const std::string label = getNodeName(strings, node);
const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index);
@ -1062,7 +1063,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph)
{
for (size_t index = 0; index < eclass.nodes.size(); ++index)
{
const auto& node = eclass.nodes[index];
const auto& node = eclass.nodes[index].node;
const std::string label = getNodeName(strings, node);
const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index);
@ -1098,7 +1099,7 @@ static Tag const* isTag(const EGraph& egraph, Id id)
{
for (const auto& node : egraph[id].nodes)
{
if (auto n = isTag<Tag>(node))
if (auto n = isTag<Tag>(node.node))
return n;
}
return nullptr;
@ -1134,7 +1135,7 @@ protected:
{
for (const auto& node : (*egraph)[id].nodes)
{
if (auto n = node.get<Tag>())
if (auto n = node.node.get<Tag>())
return n;
}
return nullptr;
@ -1322,8 +1323,10 @@ const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set<Id>&
const EType* bestUnion = nullptr;
std::optional<size_t> unionSize;
for (const auto& node : egraph[id].nodes)
for (const auto& n : egraph[id].nodes)
{
const EType& node = n.node;
if (isTerminal(node))
return &node;
@ -1439,14 +1442,14 @@ bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part)
return true;
}
Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct)
static std::pair<Id, size_t> fromCanonicalized(EGraph& egraph, CanonicalizedType& ct)
{
if (ct.isUnknown())
{
if (ct.errorPart)
return egraph.add(TAny{});
return {egraph.add(TAny{}), 1};
else
return egraph.add(TUnknown{});
return {egraph.add(TUnknown{}), 1};
}
std::vector<Id> parts;
@ -1484,7 +1487,12 @@ Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct)
parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end());
parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end());
return mkUnion(egraph, std::move(parts));
std::sort(parts.begin(), parts.end());
auto it = std::unique(parts.begin(), parts.end());
parts.erase(it, parts.end());
const size_t size = parts.size();
return {mkUnion(egraph, std::move(parts)), size};
}
void addChildren(const EGraph& egraph, const EType* enode, VecDeque<Id>& worklist)
@ -1530,7 +1538,7 @@ const Tag* Simplifier::isTag(Id id) const
{
for (const auto& node : get(id).nodes)
{
if (const Tag* ty = node.get<Tag>())
if (const Tag* ty = node.node.get<Tag>())
return ty;
}
@ -1564,6 +1572,16 @@ void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::u
substs.emplace_back(from, to, desc);
}
void Simplifier::subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes)
{
std::string desc;
if (FFlag::DebugLuauLogSimplification)
desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName);
egraph.markBoring(from, boringIndex);
substs.emplace_back(from, to, desc);
}
void Simplifier::unionClasses(std::vector<Id>& hereParts, Id there)
{
if (1 == hereParts.size() && isTag<TTopClass>(hereParts[0]))
@ -1614,8 +1632,11 @@ void Simplifier::simplifyUnion(Id id)
for (Id part : u->operands())
unionWithType(egraph, canonicalized, find(part));
Id resultId = fromCanonicalized(egraph, canonicalized);
const auto [resultId, newSize] = fromCanonicalized(egraph, canonicalized);
if (newSize < u->operands().size())
subst(id, unionIndex, resultId, "simplifyUnion", {{id, unionIndex}});
else
subst(id, resultId, "simplifyUnion", {{id, unionIndex}});
}
}
@ -1824,7 +1845,7 @@ void Simplifier::uninhabitedIntersection(Id id)
const auto& partNodes = egraph[partId].nodes;
for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex)
{
const EType& N = partNodes[partIndex];
const EType& N = partNodes[partIndex].node;
if (std::optional<EType> intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N))
{
if (isTag<TNever>(*intersection))
@ -1847,8 +1868,13 @@ void Simplifier::uninhabitedIntersection(Id id)
if ((unsimplified.empty() || !isTag<TUnknown>(accumulator)) && find(accumulator) != id)
unsimplified.push_back(accumulator);
const bool isSmaller = unsimplified.size() < parts.size();
const Id result = mkIntersection(egraph, std::move(unsimplified));
if (isSmaller)
subst(id, index, result, "uninhabitedIntersection", {{id, index}});
else
subst(id, result, "uninhabitedIntersection", {{id, index}});
}
}
@ -1880,7 +1906,7 @@ void Simplifier::intersectWithNegatedClass(Id id)
const auto& iNodes = egraph[iId].nodes;
for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex)
{
const EType& iNode = iNodes[iIndex];
const EType& iNode = iNodes[iIndex].node;
if (isTag<TNil>(iNode) || isTag<TBoolean>(iNode) || isTag<TNumber>(iNode) || isTag<TString>(iNode) || isTag<TThread>(iNode) ||
isTag<TTopFunction>(iNode) ||
// isTag<TTopTable>(iNode) || // I'm not sure about this one.
@ -1923,7 +1949,7 @@ void Simplifier::intersectWithNegatedClass(Id id)
newParts.push_back(part);
}
Id substId = egraph.add(Intersection{newParts.begin(), newParts.end()});
Id substId = mkIntersection(egraph, newParts);
subst(
id,
substId,
@ -1965,7 +1991,7 @@ void Simplifier::intersectWithNegatedAtom(Id id)
{
for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex)
{
const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex];
const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex].node;
if (!isTerminal(*negationOperand) || negationOperand->get<TOpaque>())
continue;
@ -1976,7 +2002,7 @@ void Simplifier::intersectWithNegatedAtom(Id id)
for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex)
{
const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex];
const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex].node;
if (!isTerminal(*jNode) || jNode->get<TOpaque>())
continue;
@ -2001,7 +2027,7 @@ void Simplifier::intersectWithNegatedAtom(Id id)
subst(
id,
egraph.add(Intersection{newOperands}),
mkIntersection(egraph, std::move(newOperands)),
"intersectWithNegatedAtom",
{{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}}
);
@ -2178,7 +2204,7 @@ void Simplifier::expandNegation(Id id)
if (!ok)
continue;
subst(id, fromCanonicalized(egraph, canonicalized), "expandNegation", {{id, index}});
subst(id, fromCanonicalized(egraph, canonicalized).first, "expandNegation", {{id, index}});
}
}
@ -2576,9 +2602,9 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
// try to run any rules on it.
bool shouldAbort = false;
for (const EType& enode : egraph[id].nodes)
for (const auto& enode : egraph[id].nodes)
{
if (isTerminal(enode))
if (isTerminal(enode.node))
{
shouldAbort = true;
break;
@ -2588,8 +2614,8 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
if (shouldAbort)
continue;
for (const EType& enode : egraph[id].nodes)
addChildren(egraph, &enode, worklist);
for (const auto& enode : egraph[id].nodes)
addChildren(egraph, &enode.node, worklist);
for (Simplifier::RewriteRuleFn rule : rules)
(simplifier.get()->*rule)(id);

View file

@ -200,7 +200,7 @@ ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStateme
return closest;
}
FragmentParseResult parseFragment(
std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,
@ -245,6 +245,10 @@ FragmentParseResult parseFragment(
opts.captureComments = true;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos};
ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts);
// This means we threw a ParseError and we should decline to offer autocomplete here.
if (p.root == nullptr)
return std::nullopt;
std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
// Get the ancestry for the fragment at the offset cursor position.
@ -366,7 +370,8 @@ FragmentTypeCheckResult typecheckFragment_(
TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits});
/// Create a DataFlowGraph just for the surrounding context
auto dfg = DataFlowGraphBuilder::build(root, iceHandler);
DataFlowGraph dfg = DataFlowGraphBuilder::build(root, NotNull{&incrementalModule->defArena}, NotNull{&incrementalModule->keyArena}, iceHandler);
SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes);
FrontendModuleResolver& resolver = getModuleResolver(frontend, opts);
@ -468,7 +473,13 @@ std::pair<FragmentTypeCheckStatus, FragmentTypeCheckResult> typecheckFragment(
return {};
}
FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition);
auto tryParse = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition);
if (!tryParse)
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
FragmentParseResult& parseResult = *tryParse;
if (isWithinComment(parseResult.commentLocations, fragmentEndPosition.value_or(cursorPos)))
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};

View file

@ -13,6 +13,7 @@
#include "Luau/EqSatSimplification.h"
#include "Luau/FileResolver.h"
#include "Luau/NonStrictTypeChecker.h"
#include "Luau/NotNull.h"
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h"
@ -1338,7 +1339,7 @@ ModulePtr check(
}
}
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler);
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&result->defArena}, NotNull{&result->keyArena}, iceHandler);
UnifierSharedState unifierState{iceHandler};
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;

View file

@ -61,9 +61,7 @@ TypeId Instantiation::clean(TypeId ty)
LUAU_ASSERT(ftv);
FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.magicFunction = ftv->magicFunction;
clone.dcrMagicFunction = ftv->dcrMagicFunction;
clone.dcrMagicRefinement = ftv->dcrMagicRefinement;
clone.magic = ftv->magic;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
TypeId result = addType(std::move(clone));

View file

@ -98,9 +98,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf};
clone.generics = a.generics;
clone.genericPacks = a.genericPacks;
clone.magicFunction = a.magicFunction;
clone.dcrMagicFunction = a.dcrMagicFunction;
clone.dcrMagicRefinement = a.dcrMagicRefinement;
clone.magic = a.magic;
clone.tags = a.tags;
clone.argNames = a.argNames;
clone.isCheckedFunction = a.isCheckedFunction;

View file

@ -554,12 +554,12 @@ BlockedType::BlockedType()
{
}
Constraint* BlockedType::getOwner() const
const Constraint* BlockedType::getOwner() const
{
return owner;
}
void BlockedType::setOwner(Constraint* newOwner)
void BlockedType::setOwner(const Constraint* newOwner)
{
LUAU_ASSERT(owner == nullptr);
@ -569,7 +569,7 @@ void BlockedType::setOwner(Constraint* newOwner)
owner = newOwner;
}
void BlockedType::replaceOwner(Constraint* newOwner)
void BlockedType::replaceOwner(const Constraint* newOwner)
{
owner = newOwner;
}

View file

@ -1454,9 +1454,10 @@ void TypeChecker2::visitCall(AstExprCall* call)
TypePackId argsTp = module->internalTypes.addTypePack(args);
if (auto ftv = get<FunctionType>(follow(*originalCallTy)))
{
if (ftv->dcrMagicTypeCheck)
if (ftv->magic)
{
ftv->dcrMagicTypeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope});
bool usedMagic = ftv->magic->typeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope});
if (usedMagic)
return;
}
}

View file

@ -4506,10 +4506,10 @@ std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(
// When this function type has magic functions and did return something, we select that overload instead.
// TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution.
if (ftv->magicFunction)
if (ftv->magic)
{
// TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magicFunction(*this, scope, expr, argListResult))
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magic->handleOldSolver(*this, scope, expr, argListResult))
return std::make_unique<WithPredicate<TypePackId>>(std::move(*ret));
}

View file

@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames)
LUAU_FASTFLAGVARIABLE(LuauFixFunctionNameStartPosition)
LUAU_FASTFLAGVARIABLE(LuauExtendStatEndPosWithSemicolon)
namespace Luau
{
@ -288,6 +289,10 @@ AstStatBlock* Parser::parseBlockNoScope()
{
nextLexeme();
stat->hasSemicolon = true;
if (FFlag::LuauExtendStatEndPosWithSemicolon)
{
stat->location.end = lexer.previousLocation().end;
}
}
body.push_back(stat);

View file

@ -14,6 +14,7 @@ inline bool isFlagExperimental(const char* flag)
"LuauInstantiateInSubtyping", // requires some fixes to lua-apps code
"LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative
"StudioReportLuauAny2", // takes telemetry data for usage of any types
"LuauTableCloneClonesType", // requires fixes in lua-apps code, terrifyingly
"LuauSolverV2",
// makes sure we always have at least one entry
nullptr,

View file

@ -51,13 +51,70 @@ struct Analysis final
}
};
template<typename L>
struct Node
{
L node;
bool boring = false;
struct Hash
{
size_t operator()(const Node& node) const
{
return typename L::Hash{}(node.node);
}
};
};
template <typename L>
struct NodeIterator
{
private:
using iterator = std::vector<Node<L>>;
iterator iter;
public:
L& operator*()
{
return iter->node;
}
const L& operator*() const
{
return iter->node;
}
iterator& operator++()
{
++iter;
return *this;
}
iterator operator++(int)
{
iterator copy = *this;
++*this;
return copy;
}
bool operator==(const iterator& rhs) const
{
return iter == rhs.iter;
}
bool operator!=(const iterator& rhs) const
{
return iter != rhs.iter;
}
};
/// Each e-class is a set of e-nodes representing equivalent terms from a given language,
/// and an e-node is a function symbol paired with a list of children e-classes.
template<typename L, typename D>
struct EClass final
{
Id id;
std::vector<L> nodes;
std::vector<Node<L>> nodes;
D data;
std::vector<std::pair<L, Id>> parents;
};
@ -125,9 +182,9 @@ struct EGraph final
std::sort(
eclass1.nodes.begin(),
eclass1.nodes.end(),
[](const L& left, const L& right)
[](const Node<L>& left, const Node<L>& right)
{
return left.index() < right.index();
return left.node.index() < right.node.index();
}
);
@ -177,6 +234,11 @@ struct EGraph final
return classes;
}
void markBoring(Id id, size_t index)
{
get(id).nodes[index].boring = true;
}
private:
Analysis<L, N> analysis;
@ -225,7 +287,7 @@ private:
id,
EClassT{
id,
{enode},
{Node<L>{enode, false}},
analysis.make(*this, enode),
{},
}
@ -264,18 +326,18 @@ private:
std::vector<std::pair<L, Id>> parents = get(id).parents;
for (auto& pair : parents)
{
L& enode = pair.first;
Id id = pair.second;
L& parentNode = pair.first;
Id parentId = pair.second;
// By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id.
hashcons.erase(enode);
canonicalize(enode);
hashcons.insert_or_assign(enode, find(id));
hashcons.erase(parentNode);
canonicalize(parentNode);
hashcons.insert_or_assign(parentNode, find(parentId));
if (auto it = newParents.find(enode); it != newParents.end())
merge(id, it->second);
if (auto it = newParents.find(parentNode); it != newParents.end())
merge(parentId, it->second);
newParents.insert_or_assign(enode, find(id));
newParents.insert_or_assign(parentNode, find(parentId));
}
// We reacquire the pointer because the prior loop potentially merges
@ -287,22 +349,30 @@ private:
for (const auto& [node, id] : newParents)
eclass->parents.emplace_back(std::move(node), std::move(id));
std::unordered_set<L, typename L::Hash> newNodes;
for (L node : eclass->nodes)
std::unordered_map<L, bool, typename L::Hash> newNodes;
for (Node<L> node : eclass->nodes)
{
canonicalize(node);
newNodes.insert(std::move(node));
canonicalize(node.node);
bool& b = newNodes[std::move(node.node)];
b = b || node.boring;
}
eclass->nodes.assign(newNodes.begin(), newNodes.end());
eclass->nodes.clear();
while (!newNodes.empty())
{
auto n = newNodes.extract(newNodes.begin());
eclass->nodes.push_back(Node<L>{n.key(), n.mapped()});
}
// FIXME: Extract into sortByTag()
std::sort(
eclass->nodes.begin(),
eclass->nodes.end(),
[](const L& left, const L& right)
[](const Node<L>& left, const Node<L>& right)
{
return left.index() < right.index();
return left.node.index() < right.node.index();
}
);
}

View file

@ -154,6 +154,7 @@ LUA_API const float* lua_tovector(lua_State* L, int idx);
LUA_API int lua_toboolean(lua_State* L, int idx);
LUA_API const char* lua_tolstring(lua_State* L, int idx, size_t* len);
LUA_API const char* lua_tostringatom(lua_State* L, int idx, int* atom);
LUA_API const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom);
LUA_API const char* lua_namecallatom(lua_State* L, int* atom);
LUA_API int lua_objlen(lua_State* L, int idx);
LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx);

View file

@ -454,6 +454,29 @@ const char* lua_tostringatom(lua_State* L, int idx, int* atom)
return getstr(s);
}
const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom)
{
StkId o = index2addr(L, idx);
if (!ttisstring(o))
{
if (len)
*len = 0;
return NULL;
}
TString* s = tsvalue(o);
if (len)
*len = s->len;
if (atom)
{
updateatom(L, s);
*atom = s->atom;
}
return getstr(s);
}
const char* lua_namecallatom(lua_State* L, int* atom)
{
TString* s = L->namecall;

View file

@ -422,6 +422,20 @@ int luaG_isnative(lua_State* L, int level)
return (ci->flags & LUA_CALLINFO_NATIVE) != 0 ? 1 : 0;
}
int luaG_hasnative(lua_State* L, int level)
{
if (unsigned(level) >= unsigned(L->ci - L->base_ci))
return 0;
CallInfo* ci = L->ci - level;
Proto* proto = getluaproto(ci);
if (proto == nullptr)
return 0;
return (proto->execdata != nullptr);
}
void lua_singlestep(lua_State* L, int enabled)
{
L->singlestep = bool(enabled);

View file

@ -31,3 +31,4 @@ LUAI_FUNC bool luaG_onbreak(lua_State* L);
LUAI_FUNC int luaG_getline(Proto* p, int pc);
LUAI_FUNC int luaG_isnative(lua_State* L, int level);
LUAI_FUNC int luaG_hasnative(lua_State* L, int level);

View file

@ -4304,4 +4304,29 @@ foo(@1)
CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText);
}
TEST_CASE_FIXTURE(ACFixture, "autocomplete_at_end_of_stmt_should_continue_as_part_of_stmt")
{
check(R"(
local data = { x = 1 }
local var = data.@1
)");
auto ac = autocomplete('1');
CHECK(!ac.entryMap.empty());
CHECK(ac.entryMap.count("x"));
CHECK_EQ(ac.context, AutocompleteContext::Property);
}
TEST_CASE_FIXTURE(ACFixture, "autocomplete_after_semicolon_should_complete_a_new_statement")
{
check(R"(
local data = { x = 1 }
local var = data;@1
)");
auto ac = autocomplete('1');
CHECK(!ac.entryMap.empty());
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Statement);
}
TEST_SUITE_END();

View file

@ -22,7 +22,9 @@ ConstraintGeneratorFixture::ConstraintGeneratorFixture()
void ConstraintGeneratorFixture::generateConstraints(const std::string& code)
{
AstStatBlock* root = parse(code);
dfg = std::make_unique<DataFlowGraph>(DataFlowGraphBuilder::build(root, NotNull{&ice}));
dfg = std::make_unique<DataFlowGraph>(
DataFlowGraphBuilder::build(root, NotNull{&mainModule->defArena}, NotNull{&mainModule->keyArena}, NotNull{&ice})
);
cg = std::make_unique<ConstraintGenerator>(
mainModule,
NotNull{&normalizer},

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/DataFlowGraph.h"
#include "Fixture.h"
#include "Luau/Def.h"
#include "Luau/Error.h"
#include "Luau/Parser.h"
@ -18,6 +19,8 @@ struct DataFlowGraphFixture
// Only needed to fix the operator== reflexivity of an empty Symbol.
ScopedFastFlag dcr{FFlag::LuauSolverV2, true};
DefArena defArena;
RefinementKeyArena keyArena;
InternalErrorReporter handle;
Allocator allocator;
@ -32,7 +35,7 @@ struct DataFlowGraphFixture
if (!parseResult.errors.empty())
throw ParseErrors(std::move(parseResult.errors));
module = parseResult.root;
graph = DataFlowGraphBuilder::build(module, NotNull{&handle});
graph = DataFlowGraphBuilder::build(module, NotNull{&defArena}, NotNull{&keyArena}, NotNull{&handle});
}
template<typename T, int N>

View file

@ -26,6 +26,7 @@ LUAU_FASTFLAG(LuauSymbolEquality);
LUAU_FASTFLAG(LuauStoreSolverTypeOnModule);
LUAU_FASTFLAG(LexerResumesFromPosition2)
LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection)
LUAU_FASTINT(LuauParseErrorLimit)
static std::optional<AutocompleteEntryMap> nullCallback(std::string tag, std::optional<const ClassType*> ptr, std::optional<std::string> contents)
{
@ -69,7 +70,7 @@ struct FragmentAutocompleteFixtureImpl : BaseType
}
FragmentParseResult parseFragment(
std::optional<FragmentParseResult> parseFragment(
const std::string& document,
const Position& cursorPos,
std::optional<Position> fragmentEndPosition = std::nullopt
@ -164,6 +165,7 @@ end
}
};
//NOLINTBEGIN(bugprone-unchecked-optional-access)
TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals")
@ -286,13 +288,23 @@ TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteParserTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "thrown_parse_error_leads_to_null_root")
{
check("type A = ");
ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1};
auto fragment = parseFragment("type A = <>function<> more garbage here", Position(0, 39));
CHECK(fragment == std::nullopt);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer")
{
ScopedFastFlag sff{FFlag::LuauSolverV2, true};
check("local a =");
auto fragment = parseFragment("local a =", Position(0, 10));
CHECK_EQ("local a =", fragment.fragmentToParse);
CHECK_EQ(Location{Position{0, 0}, 9}, fragment.root->location);
REQUIRE(fragment.has_value());
CHECK_EQ("local a =", fragment->fragmentToParse);
CHECK_EQ(Location{Position{0, 0}, 9}, fragment->root->location);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null")
@ -310,11 +322,12 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_n
)",
Position(1, 0)
);
CHECK_EQ("\n", fragment.fragmentToParse);
CHECK_EQ(2, fragment.ancestry.size());
REQUIRE(fragment.root);
CHECK_EQ(0, fragment.root->body.size);
auto statBody = fragment.root->as<AstStatBlock>();
REQUIRE(fragment.has_value());
CHECK_EQ("\n", fragment->fragmentToParse);
CHECK_EQ(2, fragment->ancestry.size());
REQUIRE(fragment->root);
CHECK_EQ(0, fragment->root->body.size);
auto statBody = fragment->root->as<AstStatBlock>();
CHECK(statBody != nullptr);
}
@ -339,13 +352,15 @@ local z = x + y
Position{3, 15}
);
CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment.root->location);
REQUIRE(fragment.has_value());
CHECK_EQ("local y = 5\nlocal z = x + y", fragment.fragmentToParse);
CHECK_EQ(5, fragment.ancestry.size());
REQUIRE(fragment.root);
CHECK_EQ(2, fragment.root->body.size);
auto stat = fragment.root->body.data[1]->as<AstStatLocal>();
CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment->root->location);
CHECK_EQ("local y = 5\nlocal z = x + y", fragment->fragmentToParse);
CHECK_EQ(5, fragment->ancestry.size());
REQUIRE(fragment->root);
CHECK_EQ(2, fragment->root->body.size);
auto stat = fragment->root->body.data[1]->as<AstStatLocal>();
REQUIRE(stat);
CHECK_EQ(1, stat->vars.size);
CHECK_EQ(1, stat->values.size);
@ -384,12 +399,14 @@ local y = 5
Position{2, 15}
);
CHECK_EQ("local z = x + y", fragment.fragmentToParse);
CHECK_EQ(5, fragment.ancestry.size());
REQUIRE(fragment.root);
CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment.root->location);
CHECK_EQ(1, fragment.root->body.size);
auto stat = fragment.root->body.data[0]->as<AstStatLocal>();
REQUIRE(fragment.has_value());
CHECK_EQ("local z = x + y", fragment->fragmentToParse);
CHECK_EQ(5, fragment->ancestry.size());
REQUIRE(fragment->root);
CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment->root->location);
CHECK_EQ(1, fragment->root->body.size);
auto stat = fragment->root->body.data[0]->as<AstStatLocal>();
REQUIRE(stat);
CHECK_EQ(1, stat->vars.size);
CHECK_EQ(1, stat->values.size);
@ -429,7 +446,9 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope")
Position{6, 0}
);
CHECK_EQ("\n ", fragment.fragmentToParse);
REQUIRE(fragment.has_value());
CHECK_EQ("\n ", fragment->fragmentToParse);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_single_line_fragment_override")
@ -448,17 +467,19 @@ abc("bar")
Position{1, 10}
);
CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", callFragment.fragmentToParse);
CHECK(callFragment.nearestStatement->is<AstStatFunction>());
REQUIRE(callFragment.has_value());
CHECK_GE(callFragment.ancestry.size(), 2);
CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", callFragment->fragmentToParse);
CHECK(callFragment->nearestStatement->is<AstStatFunction>());
AstNode* back = callFragment.ancestry.back();
CHECK_GE(callFragment->ancestry.size(), 2);
AstNode* back = callFragment->ancestry.back();
CHECK(back->is<AstExprConstantString>());
CHECK_EQ(Position{1, 4}, back->location.begin);
CHECK_EQ(Position{1, 9}, back->location.end);
AstNode* parent = callFragment.ancestry.rbegin()[1];
AstNode* parent = callFragment->ancestry.rbegin()[1];
CHECK(parent->is<AstExprCall>());
CHECK_EQ(Position{1, 0}, parent->location.begin);
CHECK_EQ(Position{1, 10}, parent->location.end);
@ -473,12 +494,14 @@ abc("bar")
Position{1, 9}
);
CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment.fragmentToParse);
CHECK(stringFragment.nearestStatement->is<AstStatFunction>());
REQUIRE(stringFragment.has_value());
CHECK_GE(stringFragment.ancestry.size(), 1);
CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment->fragmentToParse);
CHECK(stringFragment->nearestStatement->is<AstStatFunction>());
back = stringFragment.ancestry.back();
CHECK_GE(stringFragment->ancestry.size(), 1);
back = stringFragment->ancestry.back();
auto asString = back->as<AstExprConstantString>();
CHECK(asString);
@ -508,17 +531,19 @@ abc("bar")
Position{3, 1}
);
CHECK_EQ("function abc(foo: string) end\nabc(\n\"foo\"\n)", fragment.fragmentToParse);
CHECK(fragment.nearestStatement->is<AstStatFunction>());
REQUIRE(fragment.has_value());
CHECK_GE(fragment.ancestry.size(), 2);
CHECK_EQ("function abc(foo: string) end\nabc(\n\"foo\"\n)", fragment->fragmentToParse);
CHECK(fragment->nearestStatement->is<AstStatFunction>());
AstNode* back = fragment.ancestry.back();
CHECK_GE(fragment->ancestry.size(), 2);
AstNode* back = fragment->ancestry.back();
CHECK(back->is<AstExprConstantString>());
CHECK_EQ(Position{2, 0}, back->location.begin);
CHECK_EQ(Position{2, 5}, back->location.end);
AstNode* parent = fragment.ancestry.rbegin()[1];
AstNode* parent = fragment->ancestry.rbegin()[1];
CHECK(parent->is<AstExprCall>());
CHECK_EQ(Position{1, 0}, parent->location.begin);
CHECK_EQ(Position{3, 1}, parent->location.end);
@ -549,6 +574,7 @@ t
}
TEST_SUITE_END();
//NOLINTEND(bugprone-unchecked-optional-access)
TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests");
@ -1558,4 +1584,26 @@ if x == 5 then -- a comment
);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_handles_parse_errors")
{
ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1};
const std::string source = R"(
)";
const std::string updated = R"(
type A = <>random non code text here
)";
autocompleteFragmentInBothSolvers(
source,
updated,
Position{1, 38},
[](FragmentAutocompleteResult& result)
{
CHECK(result.acResults.entryMap.empty());
}
);
}
TEST_SUITE_END();

View file

@ -20,6 +20,7 @@ LUAU_FASTFLAG(LuauAllowComplexTypesInGenericParams)
LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes)
LUAU_FASTFLAG(LuauErrorRecoveryForClassNames)
LUAU_FASTFLAG(LuauFixFunctionNameStartPosition)
LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon)
namespace
{
@ -3766,5 +3767,32 @@ TEST_CASE_FIXTURE(Fixture, "function_name_has_correct_start_location")
CHECK_EQ(Position{4, 17}, function2->name->location.begin);
}
TEST_CASE_FIXTURE(Fixture, "stat_end_includes_semicolon_position")
{
ScopedFastFlag _{FFlag::LuauExtendStatEndPosWithSemicolon, true};
AstStatBlock* block = parse(R"(
local x = 1
local y = 2;
local z = 3 ;
)");
REQUIRE_EQ(3, block->body.size);
const auto stat1 = block->body.data[0];
LUAU_ASSERT(stat1);
CHECK_FALSE(stat1->hasSemicolon);
CHECK_EQ(Position{1, 19}, stat1->location.end);
const auto stat2 = block->body.data[1];
LUAU_ASSERT(stat2);
CHECK(stat2->hasSemicolon);
CHECK_EQ(Position{2, 20}, stat2->location.end);
const auto stat3 = block->body.data[2];
LUAU_ASSERT(stat3);
CHECK(stat3->hasSemicolon);
CHECK_EQ(Position{3, 22}, stat3->location.end);
}
TEST_SUITE_END();

View file

@ -12,6 +12,7 @@ using namespace Luau;
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauTypestateBuiltins2)
LUAU_FASTFLAG(LuauStringFormatArityFix)
LUAU_FASTFLAG(LuauTableCloneClonesType)
LUAU_FASTFLAG(LuauStringFormatErrorSuppression)
TEST_SUITE_BEGIN("BuiltinTests");
@ -1587,6 +1588,30 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash")
)"));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_clone_type_states")
{
CheckResult result = check(R"(
local t1 = {}
t1.x = 5
local t2 = table.clone(t1)
t2.y = 6
t1.z = 3
)");
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::LuauTableCloneClonesType)
{
CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, z: number }");
CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number }");
}
else
{
CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, y: number, z: number }");
CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number, z: number }");
}
}
TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_should_support_any")
{
ScopedFastFlag _{FFlag::LuauSolverV2, true};

View file

@ -20,7 +20,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
LUAU_FASTFLAG(DebugLuauEqSatSimplification)
TEST_SUITE_BEGIN("TypeInferFunctions");
@ -2566,7 +2565,7 @@ end
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type")
{
ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontRefCountTypesInTypeFunctions, true}};
ScopedFastFlag _{FFlag::LuauSolverV2, true};
// CLI-114134: This test:
// a) Has a kind of weird result (suggesting `number | false` is not great);
@ -2878,8 +2877,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun")
TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types")
{
ScopedFastFlag _{FFlag::LuauDontRefCountTypesInTypeFunctions, true};
CheckResult result = check(R"(
function foo(player)
local success,result = player:thing()

View file

@ -16,13 +16,16 @@ using namespace Luau;
namespace
{
std::optional<WithPredicate<TypePackId>> magicFunctionInstanceIsA(
struct MagicInstanceIsA final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(
TypeChecker& typeChecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
) override
{
if (expr.args.size != 1)
return std::nullopt;
@ -39,10 +42,15 @@ std::optional<WithPredicate<TypePackId>> magicFunctionInstanceIsA(
ModulePtr module = typeChecker.currentModule;
TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType});
return WithPredicate<TypePackId>{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}};
}
}
void dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx)
{
bool infer(const MagicFunctionCallContext&) override
{
return false;
}
void refine(const MagicRefinementContext& ctx) override
{
if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty())
return;
@ -61,7 +69,10 @@ void dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx)
LUAU_ASSERT(get<BlockedType>(*discriminantTy));
asMutable(*discriminantTy)->ty.emplace<BoundType>(tfun->type);
}
}
};
struct RefinementClassFixture : BuiltinsFixture
{
@ -85,8 +96,7 @@ struct RefinementClassFixture : BuiltinsFixture
TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType});
TypePackId isARets = arena.addTypePack({builtinTypes->booleanType});
TypeId isA = arena.addType(FunctionType{isAParams, isARets});
getMutable<FunctionType>(isA)->magicFunction = magicFunctionInstanceIsA;
getMutable<FunctionType>(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA;
getMutable<FunctionType>(isA)->magic = std::make_shared<MagicInstanceIsA>();
getMutable<ClassType>(inst)->props = {
{"Name", Property{builtinTypes->stringType}},

View file

@ -24,7 +24,6 @@ LUAU_FASTINT(LuauNormalizeCacheLimit);
LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
LUAU_FASTFLAG(InferGlobalTypes)
using namespace Luau;
@ -1731,7 +1730,7 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue")
TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function")
{
ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontRefCountTypesInTypeFunctions, true}};
ScopedFastFlag _{FFlag::LuauSolverV2, true};
LUAU_CHECK_NO_ERRORS(check(R"(
--!strict
@ -1744,7 +1743,7 @@ TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function")
TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type")
{
ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontRefCountTypesInTypeFunctions, true}};
ScopedFastFlag _{FFlag::LuauSolverV2, true};
LUAU_CHECK_NO_ERRORS(check(R"(
--!strict