Sync to upstream/release/658 (#1625)
Some checks failed
benchmark / callgrind (map[branch:main name:luau-lang/benchmark-data], ubuntu-22.04) (push) Has been cancelled
build / macos (push) Has been cancelled
build / macos-arm (push) Has been cancelled
build / ubuntu (push) Has been cancelled
build / windows (Win32) (push) Has been cancelled
build / windows (x64) (push) Has been cancelled
build / coverage (push) Has been cancelled
build / web (push) Has been cancelled
release / macos (push) Has been cancelled
release / ubuntu (push) Has been cancelled
release / windows (push) Has been cancelled
release / web (push) Has been cancelled

## What's Changed

### General
- Allow types of tables to diverge after using `table.clone` (fixes
#1617).
- Allow 2-argument vector.create in Luau.
- Fix a crash when suggesting autocomplete after encountering parsing
errors.
- Add lua_tolstringatom C API which returns the string length (whether
or not the atom exists) and which extends the existing lua_tostringatom
function the same way lua_tolstring/lua_tostring do.
- Luau now retains the DFGs of typechecked modules.

### Magic Functions Migration Note
We've made a change to the API used to define magic functions.

Previously, we had a set of function pointers on each `FunctionType`
that would be invoked by the type inference engine at the correct point.

The problem we'd run into is that they were all `std::function`s, we'd
grown quite a few of them, and Luau allocates tens of thousands of types
as it performs type inference. This adds up to a large amount of memory
for data that isn't used by 99% of types.

To slim things down a bit, we've replaced all of those `std::function`s
with a single `shared_ptr` to a new interface called `MagicFunction`.
This slims down the memory footprint of each type by about 50 bytes.

The virtual methods of `MagicFunction` have roughly 1:1 correspondence
with the old interface, so updating things should not be too difficult:

* `FunctionType::magicFunction` is now `MagicFunction::handleOldSolver`
* `FunctionType::dcrMagicFunction` is now `MagicFunction::infer`
* `FunctionType::dcrMagicRefinement` is now `MagicFunction::refine`
* `FunctionType::dcrMagicTypeCheck` is now `MagicFunction::typeCheck`

**Full Changelog**:
https://github.com/luau-lang/luau/compare/0.657...0.658

---

Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
This commit is contained in:
ayoungbloodrbx 2025-01-24 12:15:19 -08:00 committed by GitHub
parent 6061a14e9f
commit c13b5b7440
Signed by: DevComp
GPG key ID: B5690EEEBB952194
37 changed files with 751 additions and 360 deletions

View file

@ -65,10 +65,7 @@ TypeId makeFunction( // Polymorphic
bool checked = false bool checked = false
); );
void attachMagicFunction(TypeId ty, MagicFunction fn); void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> fn);
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn);
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn);
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt); Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName); void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName);

View file

@ -166,7 +166,7 @@ struct ConstraintSolver
**/ **/
void finalizeTypeFunctions(); void finalizeTypeFunctions();
bool isDone(); bool isDone() const;
private: private:
/** /**
@ -298,10 +298,10 @@ public:
// FIXME: This use of a boolean for the return result is an appalling // FIXME: This use of a boolean for the return result is an appalling
// interface. // interface.
bool blockOnPendingTypes(TypeId target, NotNull<const Constraint> constraint); bool blockOnPendingTypes(TypeId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypePackId target, NotNull<const Constraint> constraint); bool blockOnPendingTypes(TypePackId targetPack, NotNull<const Constraint> constraint);
void unblock(NotNull<const Constraint> progressed); void unblock(NotNull<const Constraint> progressed);
void unblock(TypeId progressed, Location location); void unblock(TypeId ty, Location location);
void unblock(TypePackId progressed, Location location); void unblock(TypePackId progressed, Location location);
void unblock(const std::vector<TypeId>& types, Location location); void unblock(const std::vector<TypeId>& types, Location location);
void unblock(const std::vector<TypePackId>& packs, Location location); void unblock(const std::vector<TypePackId>& packs, Location location);
@ -336,7 +336,7 @@ public:
* @param location the location where the require is taking place; used for * @param location the location where the require is taking place; used for
* error locations. * error locations.
**/ **/
TypeId resolveModule(const ModuleInfo& module, const Location& location); TypeId resolveModule(const ModuleInfo& info, const Location& location);
void reportError(TypeErrorData&& data, const Location& location); void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e); void reportError(TypeError e);

View file

@ -6,6 +6,7 @@
#include "Luau/ControlFlow.h" #include "Luau/ControlFlow.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Symbol.h" #include "Luau/Symbol.h"
#include "Luau/TypedAllocator.h" #include "Luau/TypedAllocator.h"
@ -48,13 +49,13 @@ struct DataFlowGraph
const RefinementKey* getRefinementKey(const AstExpr* expr) const; const RefinementKey* getRefinementKey(const AstExpr* expr) const;
private: private:
DataFlowGraph() = default; DataFlowGraph(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena defArena; NotNull<DefArena> defArena;
RefinementKeyArena keyArena; NotNull<RefinementKeyArena> keyArena;
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr}; DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
@ -110,30 +111,22 @@ using ScopeStack = std::vector<DfgScope*>;
struct DataFlowGraphBuilder struct DataFlowGraphBuilder
{ {
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle); static DataFlowGraph build(
/**
* This method is identical to the build method above, but returns a pair of dfg, scopes as the data flow graph
* here is intended to live on the module between runs of typechecking. Before, the DFG only needed to live as
* long as the typecheck, but in a world with incremental typechecking, we need the information on the dfg to incrementally
* typecheck small fragments of code.
* @param block - pointer to the ast to build the dfg for
* @param handle - for raising internal errors while building the dfg
*/
static std::pair<std::shared_ptr<DataFlowGraph>, std::vector<std::unique_ptr<DfgScope>>> buildShared(
AstStatBlock* block, AstStatBlock* block,
NotNull<InternalErrorReporter> handle NotNull<DefArena> defArena,
NotNull<RefinementKeyArena> keyArena,
NotNull<struct InternalErrorReporter> handle
); );
private: private:
DataFlowGraphBuilder() = default; DataFlowGraphBuilder(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete; DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete;
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph; DataFlowGraph graph;
NotNull<DefArena> defArena{&graph.defArena}; NotNull<DefArena> defArena;
NotNull<RefinementKeyArena> keyArena{&graph.keyArena}; NotNull<RefinementKeyArena> keyArena;
struct InternalErrorReporter* handle = nullptr; struct InternalErrorReporter* handle = nullptr;

View file

@ -105,6 +105,9 @@ private:
std::vector<Id> storage; std::vector<Id> storage;
}; };
template <typename L>
using Node = EqSat::Node<L>;
using EType = EqSat::Language< using EType = EqSat::Language<
TNil, TNil,
TBoolean, TBoolean,
@ -171,6 +174,9 @@ struct Subst
Id eclass; Id eclass;
Id newClass; Id newClass;
// The node into eclass which is boring, if any
std::optional<size_t> boringIndex;
std::string desc; std::string desc;
Subst(Id eclass, Id newClass, std::string desc = ""); Subst(Id eclass, Id newClass, std::string desc = "");
@ -211,6 +217,7 @@ struct Simplifier
void subst(Id from, Id to); void subst(Id from, Id to);
void subst(Id from, Id to, const std::string& ruleName); void subst(Id from, Id to, const std::string& ruleName);
void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes); void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void unionClasses(std::vector<Id>& hereParts, Id there); void unionClasses(std::vector<Id>& hereParts, Id there);
@ -295,13 +302,13 @@ QueryIterator<Tag>::QueryIterator(EGraph* egraph_, Id eclass)
for (const auto& enode : ecl.nodes) for (const auto& enode : ecl.nodes)
{ {
if (enode.index() < idx) if (enode.node.index() < idx)
++index; ++index;
else else
break; break;
} }
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx) if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != idx)
{ {
egraph = nullptr; egraph = nullptr;
index = 0; index = 0;
@ -331,7 +338,7 @@ std::pair<const Tag*, size_t> QueryIterator<Tag>::operator*() const
EGraph::EClassT& ecl = (*egraph)[eclass]; EGraph::EClassT& ecl = (*egraph)[eclass];
LUAU_ASSERT(index < ecl.nodes.size()); LUAU_ASSERT(index < ecl.nodes.size());
auto& enode = ecl.nodes[index]; auto& enode = ecl.nodes[index].node;
Tag* result = enode.template get<Tag>(); Tag* result = enode.template get<Tag>();
LUAU_ASSERT(result); LUAU_ASSERT(result);
return {result, index}; return {result, index};
@ -343,12 +350,16 @@ QueryIterator<Tag>& QueryIterator<Tag>::operator++()
{ {
const auto& ecl = (*egraph)[eclass]; const auto& ecl = (*egraph)[eclass];
++index; do
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId<Tag>())
{ {
egraph = nullptr; ++index;
index = 0; if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != EType::VariantTy::getTypeId<Tag>())
} {
egraph = nullptr;
index = 0;
break;
}
} while (ecl.nodes[index].boring);
return *this; return *this;
} }

View file

@ -17,8 +17,8 @@ struct FrontendOptions;
enum class FragmentTypeCheckStatus enum class FragmentTypeCheckStatus
{ {
Success,
SkipAutocomplete, SkipAutocomplete,
Success,
}; };
struct FragmentAutocompleteAncestryResult struct FragmentAutocompleteAncestryResult
@ -56,7 +56,7 @@ struct FragmentAutocompleteResult
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
FragmentParseResult parseFragment( std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule, const SourceModule& srcModule,
std::string_view src, std::string_view src,
const Position& cursorPos, const Position& cursorPos,

View file

@ -139,6 +139,11 @@ struct Module
TypePackId returnType = nullptr; TypePackId returnType = nullptr;
std::unordered_map<Name, TypeFun> exportedTypeBindings; std::unordered_map<Name, TypeFun> exportedTypeBindings;
// Arenas related to the DFG must persist after the DFG no longer exists, as
// Module objects maintain raw pointers to objects in these arenas.
DefArena defArena;
RefinementKeyArena keyArena;
bool hasModuleScope() const; bool hasModuleScope() const;
ScopePtr getModuleScope() const; ScopePtr getModuleScope() const;

View file

@ -131,14 +131,14 @@ struct BlockedType
BlockedType(); BlockedType();
int index; int index;
Constraint* getOwner() const; const Constraint* getOwner() const;
void setOwner(Constraint* newOwner); void setOwner(const Constraint* newOwner);
void replaceOwner(Constraint* newOwner); void replaceOwner(const Constraint* newOwner);
private: private:
// The constraint that is intended to unblock this type. Other constraints // The constraint that is intended to unblock this type. Other constraints
// should block on this constraint if present. // should block on this constraint if present.
Constraint* owner = nullptr; const Constraint* owner = nullptr;
}; };
struct PrimitiveType struct PrimitiveType
@ -279,9 +279,6 @@ struct WithPredicate
} }
}; };
using MagicFunction = std::function<std::optional<
WithPredicate<TypePackId>>(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
struct MagicFunctionCallContext struct MagicFunctionCallContext
{ {
NotNull<struct ConstraintSolver> solver; NotNull<struct ConstraintSolver> solver;
@ -291,7 +288,6 @@ struct MagicFunctionCallContext
TypePackId result; TypePackId result;
}; };
using DcrMagicFunction = std::function<bool(MagicFunctionCallContext)>;
struct MagicRefinementContext struct MagicRefinementContext
{ {
NotNull<Scope> scope; NotNull<Scope> scope;
@ -308,8 +304,29 @@ struct MagicFunctionTypeCheckContext
NotNull<Scope> checkScope; NotNull<Scope> checkScope;
}; };
using DcrMagicRefinement = void (*)(const MagicRefinementContext&); struct MagicFunction
using DcrMagicFunctionTypeCheck = std::function<void(const MagicFunctionTypeCheckContext&)>; {
virtual std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) = 0;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
virtual bool infer(const MagicFunctionCallContext&) = 0;
virtual void refine(const MagicRefinementContext&) {}
// If a magic function needs to do its own special typechecking, do it here.
// Returns true if magic typechecking was performed. Return false if the
// default typechecking logic should run.
virtual bool typeCheck(const MagicFunctionTypeCheckContext&)
{
return false;
}
virtual ~MagicFunction() {}
};
struct FunctionType struct FunctionType
{ {
// Global monomorphic function // Global monomorphic function
@ -367,16 +384,7 @@ struct FunctionType
Scope* scope = nullptr; Scope* scope = nullptr;
TypePackId argTypes; TypePackId argTypes;
TypePackId retTypes; TypePackId retTypes;
MagicFunction magicFunction = nullptr; std::shared_ptr<MagicFunction> magic = nullptr;
DcrMagicFunction dcrMagicFunction = nullptr;
DcrMagicRefinement dcrMagicRefinement = nullptr;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
DcrMagicFunctionTypeCheck dcrMagicTypeCheck = nullptr;
bool hasSelf; bool hasSelf;
// `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it. // `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it.

View file

@ -85,6 +85,8 @@ struct GenericTypeVisitor
{ {
} }
virtual ~GenericTypeVisitor() {}
virtual void cycle(TypeId) {} virtual void cycle(TypeId) {}
virtual void cycle(TypePackId) {} virtual void cycle(TypePackId) {}

View file

@ -13,6 +13,8 @@
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon)
namespace Luau namespace Luau
{ {
@ -41,11 +43,26 @@ struct AutocompleteNodeFinder : public AstVisitor
bool visit(AstStat* stat) override bool visit(AstStat* stat) override
{ {
if (stat->location.begin < pos && pos <= stat->location.end) if (FFlag::LuauExtendStatEndPosWithSemicolon)
{ {
ancestry.push_back(stat); // Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal
return true; // to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case
// (no semicolon) we are still part of the AstStatLocal, hence the different comparison check.
if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end))
{
ancestry.push_back(stat);
return true;
}
} }
else
{
if (stat->location.begin < pos && pos <= stat->location.end)
{
ancestry.push_back(stat);
return true;
}
}
return false; return false;
} }

View file

@ -34,46 +34,78 @@ LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix)
LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression) LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression)
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2) LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2)
LUAU_FASTFLAG(LuauVectorDefinitionsExtra) LUAU_FASTFLAG(LuauVectorDefinitionsExtra)
LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType)
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope)
namespace Luau namespace Luau
{ {
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect( struct MagicSelect final : MagicFunction
TypeChecker& typechecker, {
const ScopePtr& scope, std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
const AstExprCall& expr, bool infer(const MagicFunctionCallContext& ctx) override;
WithPredicate<TypePackId> withPredicate };
);
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
struct MagicSetMetatable final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); struct MagicAssert final : MagicFunction
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); {
static bool dcrMagicFunctionPack(MagicFunctionCallContext context); std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context); bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicPack final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicRequire final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicClone final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFreeze final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFormat final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override;
};
struct MagicMatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicGmatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFind final : MagicFunction
{
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types) TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{ {
@ -168,34 +200,10 @@ TypeId makeFunction(
return arena.addType(std::move(ftv)); return arena.addType(std::move(ftv));
} }
void attachMagicFunction(TypeId ty, MagicFunction fn) void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> magic)
{ {
if (auto ftv = getMutable<FunctionType>(ty)) if (auto ftv = getMutable<FunctionType>(ty))
ftv->magicFunction = fn; ftv->magic = std::move(magic);
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicFunction = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicRefinement = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicTypeCheck = fn;
else else
LUAU_ASSERT(!"Got a non functional type"); LUAU_ASSERT(!"Got a non functional type");
} }
@ -396,7 +404,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
} }
} }
attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); attachMagicFunction(getGlobalBinding(globals, "assert"), std::make_shared<MagicAssert>());
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{ {
@ -412,9 +420,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
addGlobalBinding(globals, "assert", assertTy, "@luau"); addGlobalBinding(globals, "assert", assertTy, "@luau");
} }
attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(globals, "setmetatable"), std::make_shared<MagicSetMetatable>());
attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); attachMagicFunction(getGlobalBinding(globals, "select"), std::make_shared<MagicSelect>());
attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect);
if (TableType* ttv = getMutable<TableType>(getGlobalBinding(globals, "table"))) if (TableType* ttv = getMutable<TableType>(getGlobalBinding(globals, "table")))
{ {
@ -445,23 +452,22 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
ttv->props["foreach"].deprecated = true; ttv->props["foreach"].deprecated = true;
ttv->props["foreachi"].deprecated = true; ttv->props["foreachi"].deprecated = true;
attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); attachMagicFunction(ttv->props["pack"].type(), std::make_shared<MagicPack>());
attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); if (FFlag::LuauTableCloneClonesType)
attachMagicFunction(ttv->props["clone"].type(), std::make_shared<MagicClone>());
if (FFlag::LuauTypestateBuiltins2) if (FFlag::LuauTypestateBuiltins2)
attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze); attachMagicFunction(ttv->props["freeze"].type(), std::make_shared<MagicFreeze>());
} }
if (FFlag::AutocompleteRequirePathSuggestions2) if (FFlag::AutocompleteRequirePathSuggestions2)
{ {
TypeId requireTy = getGlobalBinding(globals, "require"); TypeId requireTy = getGlobalBinding(globals, "require");
attachTag(requireTy, kRequireTagName); attachTag(requireTy, kRequireTagName);
attachMagicFunction(requireTy, magicFunctionRequire); attachMagicFunction(requireTy, std::make_shared<MagicRequire>());
attachDcrMagicFunction(requireTy, dcrMagicFunctionRequire);
} }
else else
{ {
attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); attachMagicFunction(getGlobalBinding(globals, "require"), std::make_shared<MagicRequire>());
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire);
} }
} }
@ -501,7 +507,7 @@ static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes,
return result; return result;
} }
std::optional<WithPredicate<TypePackId>> magicFunctionFormat( std::optional<WithPredicate<TypePackId>> MagicFormat::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -551,7 +557,7 @@ std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})}; return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
} }
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) bool MagicFormat::infer(const MagicFunctionCallContext& context)
{ {
TypeArena* arena = context.solver->arena; TypeArena* arena = context.solver->arena;
@ -595,7 +601,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
return true; return true;
} }
static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext context) bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context)
{ {
AstExprConstantString* fmt = nullptr; AstExprConstantString* fmt = nullptr;
if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self) if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
@ -615,7 +621,7 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
context.typechecker->reportError( context.typechecker->reportError(
CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location
); );
return; return true;
} }
std::vector<TypeId> expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size); std::vector<TypeId> expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size);
@ -657,6 +663,8 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
} }
} }
} }
return true;
} }
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size) static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
@ -719,7 +727,7 @@ static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes
return result; return result;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch( std::optional<WithPredicate<TypePackId>> MagicGmatch::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -755,7 +763,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})}; return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
} }
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) bool MagicGmatch::infer(const MagicFunctionCallContext& context)
{ {
const auto& [params, tail] = flatten(context.arguments); const auto& [params, tail] = flatten(context.arguments);
@ -788,7 +796,7 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
return true; return true;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch( std::optional<WithPredicate<TypePackId>> MagicMatch::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -828,7 +836,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
return WithPredicate<TypePackId>{returnList}; return WithPredicate<TypePackId>{returnList};
} }
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) bool MagicMatch::infer(const MagicFunctionCallContext& context)
{ {
const auto& [params, tail] = flatten(context.arguments); const auto& [params, tail] = flatten(context.arguments);
@ -864,7 +872,7 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
return true; return true;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionFind( std::optional<WithPredicate<TypePackId>> MagicFind::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -922,7 +930,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
return WithPredicate<TypePackId>{returnList}; return WithPredicate<TypePackId>{returnList};
} }
static bool dcrMagicFunctionFind(MagicFunctionCallContext context) bool MagicFind::infer(const MagicFunctionCallContext& context)
{ {
const auto& [params, tail] = flatten(context.arguments); const auto& [params, tail] = flatten(context.arguments);
@ -999,11 +1007,9 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
formatFTV.isCheckedFunction = true; formatFTV.isCheckedFunction = true;
const TypeId formatFn = arena->addType(formatFTV); const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); attachMagicFunction(formatFn, std::make_shared<MagicFormat>());
attachDcrMagicFunctionTypeCheck(formatFn, dcrMagicFunctionTypeCheckFormat);
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
@ -1017,16 +1023,14 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
const TypeId gmatchFunc = const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true);
attachMagicFunction(gmatchFunc, magicFunctionGmatch); attachMagicFunction(gmatchFunc, std::make_shared<MagicGmatch>());
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
FunctionType matchFuncTy{ FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}) arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})
}; };
matchFuncTy.isCheckedFunction = true; matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy); const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch); attachMagicFunction(matchFunc, std::make_shared<MagicMatch>());
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
FunctionType findFuncTy{ FunctionType findFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
@ -1034,8 +1038,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
}; };
findFuncTy.isCheckedFunction = true; findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy); const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind); attachMagicFunction(findFunc, std::make_shared<MagicFind>());
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
// string.byte : string -> number? -> number? -> ...number // string.byte : string -> number? -> number? -> ...number
FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList};
@ -1096,7 +1099,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect( std::optional<WithPredicate<TypePackId>> MagicSelect::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1141,7 +1144,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
return std::nullopt; return std::nullopt;
} }
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) bool MagicSelect::infer(const MagicFunctionCallContext& context)
{ {
if (context.callSite->args.size <= 0) if (context.callSite->args.size <= 0)
{ {
@ -1186,7 +1189,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
return false; return false;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable( std::optional<WithPredicate<TypePackId>> MagicSetMetatable::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1268,7 +1271,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
return WithPredicate<TypePackId>{arena.addTypePack({target})}; return WithPredicate<TypePackId>{arena.addTypePack({target})};
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert( bool MagicSetMetatable::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicAssert::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1302,7 +1310,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
return WithPredicate<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})}; return WithPredicate<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})};
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionPack( bool MagicAssert::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicPack::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1345,7 +1358,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
return WithPredicate<TypePackId>{arena.addTypePack({packedTable})}; return WithPredicate<TypePackId>{arena.addTypePack({packedTable})};
} }
static bool dcrMagicFunctionPack(MagicFunctionCallContext context) bool MagicPack::infer(const MagicFunctionCallContext& context)
{ {
TypeArena* arena = context.solver->arena; TypeArena* arena = context.solver->arena;
@ -1385,7 +1398,68 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
return true; return true;
} }
static std::optional<TypeId> freezeTable(TypeId inputType, MagicFunctionCallContext& context) std::optional<WithPredicate<TypePackId>> MagicClone::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType);
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
const auto& [paramTypes, paramTail] = flatten(paramPack);
if (paramTypes.empty() || expr.args.size == 0)
{
typechecker.reportError(expr.argLocation, CountMismatch{1, std::nullopt, 0});
return std::nullopt;
}
TypeId inputType = follow(paramTypes[0]);
CloneState cloneState{typechecker.builtinTypes};
TypeId resultType = shallowClone(inputType, arena, cloneState);
TypePackId clonedTypePack = arena.addTypePack({resultType});
return WithPredicate<TypePackId>{clonedTypePack};
}
bool MagicClone::infer(const MagicFunctionCallContext& context)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType);
TypeArena* arena = context.solver->arena;
const auto& [paramTypes, paramTail] = flatten(context.arguments);
if (paramTypes.empty() || context.callSite->args.size == 0)
{
context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation);
return false;
}
TypeId inputType = follow(paramTypes[0]);
CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState);
if (auto tableType = getMutable<TableType>(resultType))
{
tableType->scope = context.constraint->scope.get();
}
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(context.constraint->scope.get(), resultType);
TypePackId clonedTypePack = arena->addTypePack({resultType});
asMutable(context.result)->ty.emplace<BoundTypePack>(clonedTypePack);
return true;
}
static std::optional<TypeId> freezeTable(TypeId inputType, const MagicFunctionCallContext& context)
{ {
TypeArena* arena = context.solver->arena; TypeArena* arena = context.solver->arena;
@ -1430,7 +1504,12 @@ static std::optional<TypeId> freezeTable(TypeId inputType, MagicFunctionCallCont
return std::nullopt; return std::nullopt;
} }
static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context) std::optional<WithPredicate<TypePackId>> MagicFreeze::handleOldSolver(struct TypeChecker &, const std::shared_ptr<struct Scope> &, const class AstExprCall &, WithPredicate<TypePackId>)
{
return std::nullopt;
}
bool MagicFreeze::infer(const MagicFunctionCallContext& context)
{ {
LUAU_ASSERT(FFlag::LuauTypestateBuiltins2); LUAU_ASSERT(FFlag::LuauTypestateBuiltins2);
@ -1491,7 +1570,7 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
return good; return good;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire( std::optional<WithPredicate<TypePackId>> MagicRequire::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1537,7 +1616,7 @@ static bool checkRequirePathDcr(NotNull<ConstraintSolver> solver, AstExpr* expr)
return good; return good;
} }
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) bool MagicRequire::infer(const MagicFunctionCallContext& context)
{ {
if (context.callSite->args.size != 1) if (context.callSite->args.size != 1)
{ {

View file

@ -3,8 +3,6 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions)
namespace Luau namespace Luau
{ {
@ -60,7 +58,7 @@ struct ReferenceCountInitializer : TypeOnceVisitor
// //
// The default behavior here is `true` for "visit the child types" // The default behavior here is `true` for "visit the child types"
// of this type, hence: // of this type, hence:
return !FFlag::LuauDontRefCountTypesInTypeFunctions; return false;
} }
}; };

View file

@ -75,7 +75,7 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const
{ {
if (auto blocked = get<BlockedType>(ty)) if (auto blocked = get<BlockedType>(ty))
{ {
Constraint* owner = blocked->getOwner(); const Constraint* owner = blocked->getOwner();
LUAU_ASSERT(owner); LUAU_ASSERT(owner);
return owner == constraint; return owner == constraint;
} }
@ -446,7 +446,7 @@ void ConstraintSolver::run()
if (success) if (success)
{ {
unblock(c); unblock(c);
unsolvedConstraints.erase(unsolvedConstraints.begin() + i); unsolvedConstraints.erase(unsolvedConstraints.begin() + ptrdiff_t(i));
// decrement the referenced free types for this constraint if we dispatched successfully! // decrement the referenced free types for this constraint if we dispatched successfully!
for (auto ty : c->getMaybeMutatedFreeTypes()) for (auto ty : c->getMaybeMutatedFreeTypes())
@ -553,7 +553,7 @@ void ConstraintSolver::finalizeTypeFunctions()
} }
} }
bool ConstraintSolver::isDone() bool ConstraintSolver::isDone() const
{ {
return unsolvedConstraints.empty(); return unsolvedConstraints.empty();
} }
@ -1293,11 +1293,11 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
if (ftv) if (ftv)
{ {
if (ftv->dcrMagicFunction) if (ftv->magic)
usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result}); {
usedMagic = ftv->magic->infer(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result});
if (ftv->dcrMagicRefinement) ftv->magic->refine(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes});
ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); }
} }
if (!usedMagic) if (!usedMagic)
@ -1702,7 +1702,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(
for (TypeId part : parts) for (TypeId part : parts)
{ {
TypeId r = arena->addType(BlockedType{}); TypeId r = arena->addType(BlockedType{});
getMutable<BlockedType>(r)->setOwner(const_cast<Constraint*>(constraint.get())); getMutable<BlockedType>(r)->setOwner(constraint.get());
bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen);
// If we've cut a recursive loop short, skip it. // If we've cut a recursive loop short, skip it.
@ -1734,7 +1734,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(
for (TypeId part : parts) for (TypeId part : parts)
{ {
TypeId r = arena->addType(BlockedType{}); TypeId r = arena->addType(BlockedType{});
getMutable<BlockedType>(r)->setOwner(const_cast<Constraint*>(constraint.get())); getMutable<BlockedType>(r)->setOwner(constraint.get());
bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen);
// If we've cut a recursive loop short, skip it. // If we've cut a recursive loop short, skip it.
@ -2874,10 +2874,10 @@ bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull<const Constrai
return !blocker.blocked; return !blocker.blocked;
} }
bool ConstraintSolver::blockOnPendingTypes(TypePackId pack, NotNull<const Constraint> constraint) bool ConstraintSolver::blockOnPendingTypes(TypePackId targetPack, NotNull<const Constraint> constraint)
{ {
Blocker blocker{NotNull{this}, constraint}; Blocker blocker{NotNull{this}, constraint};
blocker.traverse(pack); blocker.traverse(targetPack);
return !blocker.blocked; return !blocker.blocked;
} }

View file

@ -62,6 +62,12 @@ const RefinementKey* RefinementKeyArena::node(const RefinementKey* parent, DefId
return allocator.allocate(RefinementKey{parent, def, propName}); return allocator.allocate(RefinementKey{parent, def, propName});
} }
DataFlowGraph::DataFlowGraph(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena)
: defArena{defArena}
, keyArena{keyArena}
{
}
DefId DataFlowGraph::getDef(const AstExpr* expr) const DefId DataFlowGraph::getDef(const AstExpr* expr) const
{ {
auto def = astDefs.find(expr); auto def = astDefs.find(expr);
@ -178,11 +184,23 @@ bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const
return true; return true;
} }
DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle) DataFlowGraphBuilder::DataFlowGraphBuilder(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena)
: graph{defArena, keyArena}
, defArena{defArena}
, keyArena{keyArena}
{
}
DataFlowGraph DataFlowGraphBuilder::build(
AstStatBlock* block,
NotNull<DefArena> defArena,
NotNull<RefinementKeyArena> keyArena,
NotNull<struct InternalErrorReporter> handle
)
{ {
LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking");
DataFlowGraphBuilder builder; DataFlowGraphBuilder builder(defArena, keyArena);
builder.handle = handle; builder.handle = handle;
DfgScope* moduleScope = builder.makeChildScope(); DfgScope* moduleScope = builder.makeChildScope();
PushScope ps{builder.scopeStack, moduleScope}; PushScope ps{builder.scopeStack, moduleScope};
@ -198,30 +216,6 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalE
return std::move(builder.graph); return std::move(builder.graph);
} }
std::pair<std::shared_ptr<DataFlowGraph>, std::vector<std::unique_ptr<DfgScope>>> DataFlowGraphBuilder::buildShared(
AstStatBlock* block,
NotNull<InternalErrorReporter> handle
)
{
LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking");
DataFlowGraphBuilder builder;
builder.handle = handle;
DfgScope* moduleScope = builder.makeChildScope();
PushScope ps{builder.scopeStack, moduleScope};
builder.visitBlockWithoutChildScope(block);
builder.resolveCaptures();
if (FFlag::DebugLuauFreezeArena)
{
builder.defArena->allocator.freeze();
builder.keyArena->allocator.freeze();
}
return {std::make_shared<DataFlowGraph>(std::move(builder.graph)), std::move(builder.scopes)};
}
void DataFlowGraphBuilder::resolveCaptures() void DataFlowGraphBuilder::resolveCaptures()
{ {
for (const auto& [_, capture] : captures) for (const auto& [_, capture] : captures)

View file

@ -206,7 +206,7 @@ static bool isTerminal(const EGraph& egraph, Id eclass)
nodes.end(), nodes.end(),
[](auto& a) [](auto& a)
{ {
return isTerminal(a); return isTerminal(a.node);
} }
); );
} }
@ -464,7 +464,7 @@ static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGrap
if (auto it = costs.find(id); it != costs.end()) if (auto it = costs.find(id); it != costs.end())
return it->second; return it->second;
const std::vector<EType>& nodes = egraph[id].nodes; const std::vector<Node<EType>>& nodes = egraph[id].nodes;
size_t minCost = std::numeric_limits<size_t>::max(); size_t minCost = std::numeric_limits<size_t>::max();
size_t bestNode = std::numeric_limits<size_t>::max(); size_t bestNode = std::numeric_limits<size_t>::max();
@ -481,7 +481,7 @@ static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGrap
// First, quickly scan for a terminal type. If we can find one, it is obviously the best. // First, quickly scan for a terminal type. If we can find one, it is obviously the best.
for (size_t index = 0; index < nodes.size(); ++index) for (size_t index = 0; index < nodes.size(); ++index)
{ {
if (isTerminal(nodes[index])) if (isTerminal(nodes[index].node))
{ {
minCost = 1; minCost = 1;
bestNode = index; bestNode = index;
@ -533,44 +533,44 @@ static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGrap
{ {
const auto& node = nodes[index]; const auto& node = nodes[index];
if (node.get<TBound>()) if (node.node.get<TBound>())
updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound. updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound.
else if (node.get<TFunction>()) else if (node.node.get<TFunction>())
{ {
minCost = 1; minCost = 1;
bestNode = index; bestNode = index;
} }
else if (auto tbl = node.get<TTable>()) else if (auto tbl = node.node.get<TTable>())
{ {
// TODO: We could make the penalty a parameter to computeChildren. // TODO: We could make the penalty a parameter to computeChildren.
std::optional<size_t> maybeCost = computeChildren(tbl->operands(), minCost); std::optional<size_t> maybeCost = computeChildren(tbl->operands(), minCost);
if (maybeCost) if (maybeCost)
updateCost(TABLE_TYPE_PENALTY + *maybeCost, index); updateCost(TABLE_TYPE_PENALTY + *maybeCost, index);
} }
else if (node.get<TImportedTable>()) else if (node.node.get<TImportedTable>())
{ {
minCost = IMPORTED_TABLE_PENALTY; minCost = IMPORTED_TABLE_PENALTY;
bestNode = index; bestNode = index;
} }
else if (auto u = node.get<Union>()) else if (auto u = node.node.get<Union>())
{ {
std::optional<size_t> maybeCost = computeChildren(u->operands(), minCost); std::optional<size_t> maybeCost = computeChildren(u->operands(), minCost);
if (maybeCost) if (maybeCost)
updateCost(SET_TYPE_PENALTY + *maybeCost, index); updateCost(SET_TYPE_PENALTY + *maybeCost, index);
} }
else if (auto i = node.get<Intersection>()) else if (auto i = node.node.get<Intersection>())
{ {
std::optional<size_t> maybeCost = computeChildren(i->operands(), minCost); std::optional<size_t> maybeCost = computeChildren(i->operands(), minCost);
if (maybeCost) if (maybeCost)
updateCost(SET_TYPE_PENALTY + *maybeCost, index); updateCost(SET_TYPE_PENALTY + *maybeCost, index);
} }
else if (auto negation = node.get<Negation>()) else if (auto negation = node.node.get<Negation>())
{ {
std::optional<size_t> maybeCost = computeChildren(negation->operands(), minCost); std::optional<size_t> maybeCost = computeChildren(negation->operands(), minCost);
if (maybeCost) if (maybeCost)
updateCost(NEGATION_PENALTY + *maybeCost, index); updateCost(NEGATION_PENALTY + *maybeCost, index);
} }
else if (auto tfun = node.get<TTypeFun>()) else if (auto tfun = node.node.get<TTypeFun>())
{ {
std::optional<size_t> maybeCost = computeChildren(tfun->operands(), minCost); std::optional<size_t> maybeCost = computeChildren(tfun->operands(), minCost);
if (maybeCost) if (maybeCost)
@ -643,7 +643,7 @@ TypeId flattenTableNode(
for (size_t i = 0; i < eclass.nodes.size(); ++i) for (size_t i = 0; i < eclass.nodes.size(); ++i)
{ {
if (eclass.nodes[i].get<TImportedTable>()) if (eclass.nodes[i].node.get<TImportedTable>())
{ {
found = true; found = true;
index = i; index = i;
@ -660,13 +660,13 @@ TypeId flattenTableNode(
} }
const auto& node = eclass.nodes[index]; const auto& node = eclass.nodes[index];
if (const TTable* ttable = node.get<TTable>()) if (const TTable* ttable = node.node.get<TTable>())
{ {
stack.push_back(ttable); stack.push_back(ttable);
id = ttable->getBasis(); id = ttable->getBasis();
continue; continue;
} }
else if (const TImportedTable* ti = node.get<TImportedTable>()) else if (const TImportedTable* ti = node.node.get<TImportedTable>())
{ {
importedTable = ti; importedTable = ti;
break; break;
@ -718,7 +718,7 @@ TypeId fromId(
size_t index = bestNodes.at(rootId); size_t index = bestNodes.at(rootId);
LUAU_ASSERT(index <= egraph[rootId].nodes.size()); LUAU_ASSERT(index <= egraph[rootId].nodes.size());
const EType& node = egraph[rootId].nodes[index]; const EType& node = egraph[rootId].nodes[index].node;
if (node.get<TNil>()) if (node.get<TNil>())
return builtinTypes->nilType; return builtinTypes->nilType;
@ -1025,8 +1025,9 @@ std::string toDot(const StringCache& strings, const EGraph& egraph)
for (const auto& [id, eclass] : egraph.getAllClasses()) for (const auto& [id, eclass] : egraph.getAllClasses())
{ {
for (const auto& node : eclass.nodes) for (const auto& n : eclass.nodes)
{ {
const EType& node = n.node;
if (!node.operands().empty()) if (!node.operands().empty())
populated.insert(id); populated.insert(id);
for (Id op : node.operands()) for (Id op : node.operands())
@ -1047,7 +1048,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph)
for (size_t index = 0; index < eclass.nodes.size(); ++index) for (size_t index = 0; index < eclass.nodes.size(); ++index)
{ {
const auto& node = eclass.nodes[index]; const auto& node = eclass.nodes[index].node;
const std::string label = getNodeName(strings, node); const std::string label = getNodeName(strings, node);
const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index); const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index);
@ -1062,7 +1063,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph)
{ {
for (size_t index = 0; index < eclass.nodes.size(); ++index) for (size_t index = 0; index < eclass.nodes.size(); ++index)
{ {
const auto& node = eclass.nodes[index]; const auto& node = eclass.nodes[index].node;
const std::string label = getNodeName(strings, node); const std::string label = getNodeName(strings, node);
const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index); const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index);
@ -1098,7 +1099,7 @@ static Tag const* isTag(const EGraph& egraph, Id id)
{ {
for (const auto& node : egraph[id].nodes) for (const auto& node : egraph[id].nodes)
{ {
if (auto n = isTag<Tag>(node)) if (auto n = isTag<Tag>(node.node))
return n; return n;
} }
return nullptr; return nullptr;
@ -1134,7 +1135,7 @@ protected:
{ {
for (const auto& node : (*egraph)[id].nodes) for (const auto& node : (*egraph)[id].nodes)
{ {
if (auto n = node.get<Tag>()) if (auto n = node.node.get<Tag>())
return n; return n;
} }
return nullptr; return nullptr;
@ -1322,8 +1323,10 @@ const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set<Id>&
const EType* bestUnion = nullptr; const EType* bestUnion = nullptr;
std::optional<size_t> unionSize; std::optional<size_t> unionSize;
for (const auto& node : egraph[id].nodes) for (const auto& n : egraph[id].nodes)
{ {
const EType& node = n.node;
if (isTerminal(node)) if (isTerminal(node))
return &node; return &node;
@ -1439,14 +1442,14 @@ bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part)
return true; return true;
} }
Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) static std::pair<Id, size_t> fromCanonicalized(EGraph& egraph, CanonicalizedType& ct)
{ {
if (ct.isUnknown()) if (ct.isUnknown())
{ {
if (ct.errorPart) if (ct.errorPart)
return egraph.add(TAny{}); return {egraph.add(TAny{}), 1};
else else
return egraph.add(TUnknown{}); return {egraph.add(TUnknown{}), 1};
} }
std::vector<Id> parts; std::vector<Id> parts;
@ -1484,7 +1487,12 @@ Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct)
parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end()); parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end());
parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end()); parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end());
return mkUnion(egraph, std::move(parts)); std::sort(parts.begin(), parts.end());
auto it = std::unique(parts.begin(), parts.end());
parts.erase(it, parts.end());
const size_t size = parts.size();
return {mkUnion(egraph, std::move(parts)), size};
} }
void addChildren(const EGraph& egraph, const EType* enode, VecDeque<Id>& worklist) void addChildren(const EGraph& egraph, const EType* enode, VecDeque<Id>& worklist)
@ -1530,7 +1538,7 @@ const Tag* Simplifier::isTag(Id id) const
{ {
for (const auto& node : get(id).nodes) for (const auto& node : get(id).nodes)
{ {
if (const Tag* ty = node.get<Tag>()) if (const Tag* ty = node.node.get<Tag>())
return ty; return ty;
} }
@ -1564,6 +1572,16 @@ void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::u
substs.emplace_back(from, to, desc); substs.emplace_back(from, to, desc);
} }
void Simplifier::subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes)
{
std::string desc;
if (FFlag::DebugLuauLogSimplification)
desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName);
egraph.markBoring(from, boringIndex);
substs.emplace_back(from, to, desc);
}
void Simplifier::unionClasses(std::vector<Id>& hereParts, Id there) void Simplifier::unionClasses(std::vector<Id>& hereParts, Id there)
{ {
if (1 == hereParts.size() && isTag<TTopClass>(hereParts[0])) if (1 == hereParts.size() && isTag<TTopClass>(hereParts[0]))
@ -1614,9 +1632,12 @@ void Simplifier::simplifyUnion(Id id)
for (Id part : u->operands()) for (Id part : u->operands())
unionWithType(egraph, canonicalized, find(part)); unionWithType(egraph, canonicalized, find(part));
Id resultId = fromCanonicalized(egraph, canonicalized); const auto [resultId, newSize] = fromCanonicalized(egraph, canonicalized);
subst(id, resultId, "simplifyUnion", {{id, unionIndex}}); if (newSize < u->operands().size())
subst(id, unionIndex, resultId, "simplifyUnion", {{id, unionIndex}});
else
subst(id, resultId, "simplifyUnion", {{id, unionIndex}});
} }
} }
@ -1824,7 +1845,7 @@ void Simplifier::uninhabitedIntersection(Id id)
const auto& partNodes = egraph[partId].nodes; const auto& partNodes = egraph[partId].nodes;
for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex) for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex)
{ {
const EType& N = partNodes[partIndex]; const EType& N = partNodes[partIndex].node;
if (std::optional<EType> intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N)) if (std::optional<EType> intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N))
{ {
if (isTag<TNever>(*intersection)) if (isTag<TNever>(*intersection))
@ -1847,9 +1868,14 @@ void Simplifier::uninhabitedIntersection(Id id)
if ((unsimplified.empty() || !isTag<TUnknown>(accumulator)) && find(accumulator) != id) if ((unsimplified.empty() || !isTag<TUnknown>(accumulator)) && find(accumulator) != id)
unsimplified.push_back(accumulator); unsimplified.push_back(accumulator);
const bool isSmaller = unsimplified.size() < parts.size();
const Id result = mkIntersection(egraph, std::move(unsimplified)); const Id result = mkIntersection(egraph, std::move(unsimplified));
subst(id, result, "uninhabitedIntersection", {{id, index}}); if (isSmaller)
subst(id, index, result, "uninhabitedIntersection", {{id, index}});
else
subst(id, result, "uninhabitedIntersection", {{id, index}});
} }
} }
@ -1880,7 +1906,7 @@ void Simplifier::intersectWithNegatedClass(Id id)
const auto& iNodes = egraph[iId].nodes; const auto& iNodes = egraph[iId].nodes;
for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex) for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex)
{ {
const EType& iNode = iNodes[iIndex]; const EType& iNode = iNodes[iIndex].node;
if (isTag<TNil>(iNode) || isTag<TBoolean>(iNode) || isTag<TNumber>(iNode) || isTag<TString>(iNode) || isTag<TThread>(iNode) || if (isTag<TNil>(iNode) || isTag<TBoolean>(iNode) || isTag<TNumber>(iNode) || isTag<TString>(iNode) || isTag<TThread>(iNode) ||
isTag<TTopFunction>(iNode) || isTag<TTopFunction>(iNode) ||
// isTag<TTopTable>(iNode) || // I'm not sure about this one. // isTag<TTopTable>(iNode) || // I'm not sure about this one.
@ -1923,7 +1949,7 @@ void Simplifier::intersectWithNegatedClass(Id id)
newParts.push_back(part); newParts.push_back(part);
} }
Id substId = egraph.add(Intersection{newParts.begin(), newParts.end()}); Id substId = mkIntersection(egraph, newParts);
subst( subst(
id, id,
substId, substId,
@ -1965,7 +1991,7 @@ void Simplifier::intersectWithNegatedAtom(Id id)
{ {
for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex) for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex)
{ {
const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex]; const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex].node;
if (!isTerminal(*negationOperand) || negationOperand->get<TOpaque>()) if (!isTerminal(*negationOperand) || negationOperand->get<TOpaque>())
continue; continue;
@ -1976,7 +2002,7 @@ void Simplifier::intersectWithNegatedAtom(Id id)
for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex) for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex)
{ {
const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex]; const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex].node;
if (!isTerminal(*jNode) || jNode->get<TOpaque>()) if (!isTerminal(*jNode) || jNode->get<TOpaque>())
continue; continue;
@ -2001,7 +2027,7 @@ void Simplifier::intersectWithNegatedAtom(Id id)
subst( subst(
id, id,
egraph.add(Intersection{newOperands}), mkIntersection(egraph, std::move(newOperands)),
"intersectWithNegatedAtom", "intersectWithNegatedAtom",
{{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}}
); );
@ -2178,7 +2204,7 @@ void Simplifier::expandNegation(Id id)
if (!ok) if (!ok)
continue; continue;
subst(id, fromCanonicalized(egraph, canonicalized), "expandNegation", {{id, index}}); subst(id, fromCanonicalized(egraph, canonicalized).first, "expandNegation", {{id, index}});
} }
} }
@ -2576,9 +2602,9 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
// try to run any rules on it. // try to run any rules on it.
bool shouldAbort = false; bool shouldAbort = false;
for (const EType& enode : egraph[id].nodes) for (const auto& enode : egraph[id].nodes)
{ {
if (isTerminal(enode)) if (isTerminal(enode.node))
{ {
shouldAbort = true; shouldAbort = true;
break; break;
@ -2588,8 +2614,8 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
if (shouldAbort) if (shouldAbort)
continue; continue;
for (const EType& enode : egraph[id].nodes) for (const auto& enode : egraph[id].nodes)
addChildren(egraph, &enode, worklist); addChildren(egraph, &enode.node, worklist);
for (Simplifier::RewriteRuleFn rule : rules) for (Simplifier::RewriteRuleFn rule : rules)
(simplifier.get()->*rule)(id); (simplifier.get()->*rule)(id);

View file

@ -200,7 +200,7 @@ ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStateme
return closest; return closest;
} }
FragmentParseResult parseFragment( std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule, const SourceModule& srcModule,
std::string_view src, std::string_view src,
const Position& cursorPos, const Position& cursorPos,
@ -245,6 +245,10 @@ FragmentParseResult parseFragment(
opts.captureComments = true; opts.captureComments = true;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos}; opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos};
ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts); ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts);
// This means we threw a ParseError and we should decline to offer autocomplete here.
if (p.root == nullptr)
return std::nullopt;
std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry); std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
// Get the ancestry for the fragment at the offset cursor position. // Get the ancestry for the fragment at the offset cursor position.
@ -366,7 +370,8 @@ FragmentTypeCheckResult typecheckFragment_(
TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits}); TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits});
/// Create a DataFlowGraph just for the surrounding context /// Create a DataFlowGraph just for the surrounding context
auto dfg = DataFlowGraphBuilder::build(root, iceHandler); DataFlowGraph dfg = DataFlowGraphBuilder::build(root, NotNull{&incrementalModule->defArena}, NotNull{&incrementalModule->keyArena}, iceHandler);
SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes); SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes);
FrontendModuleResolver& resolver = getModuleResolver(frontend, opts); FrontendModuleResolver& resolver = getModuleResolver(frontend, opts);
@ -468,7 +473,13 @@ std::pair<FragmentTypeCheckStatus, FragmentTypeCheckResult> typecheckFragment(
return {}; return {};
} }
FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition); auto tryParse = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition);
if (!tryParse)
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
FragmentParseResult& parseResult = *tryParse;
if (isWithinComment(parseResult.commentLocations, fragmentEndPosition.value_or(cursorPos))) if (isWithinComment(parseResult.commentLocations, fragmentEndPosition.value_or(cursorPos)))
return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; return {FragmentTypeCheckStatus::SkipAutocomplete, {}};

View file

@ -13,6 +13,7 @@
#include "Luau/EqSatSimplification.h" #include "Luau/EqSatSimplification.h"
#include "Luau/FileResolver.h" #include "Luau/FileResolver.h"
#include "Luau/NonStrictTypeChecker.h" #include "Luau/NonStrictTypeChecker.h"
#include "Luau/NotNull.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
@ -1338,7 +1339,7 @@ ModulePtr check(
} }
} }
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&result->defArena}, NotNull{&result->keyArena}, iceHandler);
UnifierSharedState unifierState{iceHandler}; UnifierSharedState unifierState{iceHandler};
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;

View file

@ -61,9 +61,7 @@ TypeId Instantiation::clean(TypeId ty)
LUAU_ASSERT(ftv); LUAU_ASSERT(ftv);
FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.magicFunction = ftv->magicFunction; clone.magic = ftv->magic;
clone.dcrMagicFunction = ftv->dcrMagicFunction;
clone.dcrMagicRefinement = ftv->dcrMagicRefinement;
clone.tags = ftv->tags; clone.tags = ftv->tags;
clone.argNames = ftv->argNames; clone.argNames = ftv->argNames;
TypeId result = addType(std::move(clone)); TypeId result = addType(std::move(clone));

View file

@ -98,9 +98,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf}; FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf};
clone.generics = a.generics; clone.generics = a.generics;
clone.genericPacks = a.genericPacks; clone.genericPacks = a.genericPacks;
clone.magicFunction = a.magicFunction; clone.magic = a.magic;
clone.dcrMagicFunction = a.dcrMagicFunction;
clone.dcrMagicRefinement = a.dcrMagicRefinement;
clone.tags = a.tags; clone.tags = a.tags;
clone.argNames = a.argNames; clone.argNames = a.argNames;
clone.isCheckedFunction = a.isCheckedFunction; clone.isCheckedFunction = a.isCheckedFunction;

View file

@ -554,12 +554,12 @@ BlockedType::BlockedType()
{ {
} }
Constraint* BlockedType::getOwner() const const Constraint* BlockedType::getOwner() const
{ {
return owner; return owner;
} }
void BlockedType::setOwner(Constraint* newOwner) void BlockedType::setOwner(const Constraint* newOwner)
{ {
LUAU_ASSERT(owner == nullptr); LUAU_ASSERT(owner == nullptr);
@ -569,7 +569,7 @@ void BlockedType::setOwner(Constraint* newOwner)
owner = newOwner; owner = newOwner;
} }
void BlockedType::replaceOwner(Constraint* newOwner) void BlockedType::replaceOwner(const Constraint* newOwner)
{ {
owner = newOwner; owner = newOwner;
} }

View file

@ -1454,10 +1454,11 @@ void TypeChecker2::visitCall(AstExprCall* call)
TypePackId argsTp = module->internalTypes.addTypePack(args); TypePackId argsTp = module->internalTypes.addTypePack(args);
if (auto ftv = get<FunctionType>(follow(*originalCallTy))) if (auto ftv = get<FunctionType>(follow(*originalCallTy)))
{ {
if (ftv->dcrMagicTypeCheck) if (ftv->magic)
{ {
ftv->dcrMagicTypeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope}); bool usedMagic = ftv->magic->typeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope});
return; if (usedMagic)
return;
} }
} }

View file

@ -4506,10 +4506,10 @@ std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(
// When this function type has magic functions and did return something, we select that overload instead. // When this function type has magic functions and did return something, we select that overload instead.
// TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution. // TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution.
if (ftv->magicFunction) if (ftv->magic)
{ {
// TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magicFunction(*this, scope, expr, argListResult)) if (std::optional<WithPredicate<TypePackId>> ret = ftv->magic->handleOldSolver(*this, scope, expr, argListResult))
return std::make_unique<WithPredicate<TypePackId>>(std::move(*ret)); return std::make_unique<WithPredicate<TypePackId>>(std::move(*ret));
} }

View file

@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames)
LUAU_FASTFLAGVARIABLE(LuauFixFunctionNameStartPosition) LUAU_FASTFLAGVARIABLE(LuauFixFunctionNameStartPosition)
LUAU_FASTFLAGVARIABLE(LuauExtendStatEndPosWithSemicolon)
namespace Luau namespace Luau
{ {
@ -288,6 +289,10 @@ AstStatBlock* Parser::parseBlockNoScope()
{ {
nextLexeme(); nextLexeme();
stat->hasSemicolon = true; stat->hasSemicolon = true;
if (FFlag::LuauExtendStatEndPosWithSemicolon)
{
stat->location.end = lexer.previousLocation().end;
}
} }
body.push_back(stat); body.push_back(stat);

View file

@ -14,6 +14,7 @@ inline bool isFlagExperimental(const char* flag)
"LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code
"LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative "LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative
"StudioReportLuauAny2", // takes telemetry data for usage of any types "StudioReportLuauAny2", // takes telemetry data for usage of any types
"LuauTableCloneClonesType", // requires fixes in lua-apps code, terrifyingly
"LuauSolverV2", "LuauSolverV2",
// makes sure we always have at least one entry // makes sure we always have at least one entry
nullptr, nullptr,

View file

@ -51,13 +51,70 @@ struct Analysis final
} }
}; };
template<typename L>
struct Node
{
L node;
bool boring = false;
struct Hash
{
size_t operator()(const Node& node) const
{
return typename L::Hash{}(node.node);
}
};
};
template <typename L>
struct NodeIterator
{
private:
using iterator = std::vector<Node<L>>;
iterator iter;
public:
L& operator*()
{
return iter->node;
}
const L& operator*() const
{
return iter->node;
}
iterator& operator++()
{
++iter;
return *this;
}
iterator operator++(int)
{
iterator copy = *this;
++*this;
return copy;
}
bool operator==(const iterator& rhs) const
{
return iter == rhs.iter;
}
bool operator!=(const iterator& rhs) const
{
return iter != rhs.iter;
}
};
/// Each e-class is a set of e-nodes representing equivalent terms from a given language, /// Each e-class is a set of e-nodes representing equivalent terms from a given language,
/// and an e-node is a function symbol paired with a list of children e-classes. /// and an e-node is a function symbol paired with a list of children e-classes.
template<typename L, typename D> template<typename L, typename D>
struct EClass final struct EClass final
{ {
Id id; Id id;
std::vector<L> nodes; std::vector<Node<L>> nodes;
D data; D data;
std::vector<std::pair<L, Id>> parents; std::vector<std::pair<L, Id>> parents;
}; };
@ -125,9 +182,9 @@ struct EGraph final
std::sort( std::sort(
eclass1.nodes.begin(), eclass1.nodes.begin(),
eclass1.nodes.end(), eclass1.nodes.end(),
[](const L& left, const L& right) [](const Node<L>& left, const Node<L>& right)
{ {
return left.index() < right.index(); return left.node.index() < right.node.index();
} }
); );
@ -177,6 +234,11 @@ struct EGraph final
return classes; return classes;
} }
void markBoring(Id id, size_t index)
{
get(id).nodes[index].boring = true;
}
private: private:
Analysis<L, N> analysis; Analysis<L, N> analysis;
@ -225,7 +287,7 @@ private:
id, id,
EClassT{ EClassT{
id, id,
{enode}, {Node<L>{enode, false}},
analysis.make(*this, enode), analysis.make(*this, enode),
{}, {},
} }
@ -264,18 +326,18 @@ private:
std::vector<std::pair<L, Id>> parents = get(id).parents; std::vector<std::pair<L, Id>> parents = get(id).parents;
for (auto& pair : parents) for (auto& pair : parents)
{ {
L& enode = pair.first; L& parentNode = pair.first;
Id id = pair.second; Id parentId = pair.second;
// By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id.
hashcons.erase(enode); hashcons.erase(parentNode);
canonicalize(enode); canonicalize(parentNode);
hashcons.insert_or_assign(enode, find(id)); hashcons.insert_or_assign(parentNode, find(parentId));
if (auto it = newParents.find(enode); it != newParents.end()) if (auto it = newParents.find(parentNode); it != newParents.end())
merge(id, it->second); merge(parentId, it->second);
newParents.insert_or_assign(enode, find(id)); newParents.insert_or_assign(parentNode, find(parentId));
} }
// We reacquire the pointer because the prior loop potentially merges // We reacquire the pointer because the prior loop potentially merges
@ -287,22 +349,30 @@ private:
for (const auto& [node, id] : newParents) for (const auto& [node, id] : newParents)
eclass->parents.emplace_back(std::move(node), std::move(id)); eclass->parents.emplace_back(std::move(node), std::move(id));
std::unordered_set<L, typename L::Hash> newNodes; std::unordered_map<L, bool, typename L::Hash> newNodes;
for (L node : eclass->nodes) for (Node<L> node : eclass->nodes)
{ {
canonicalize(node); canonicalize(node.node);
newNodes.insert(std::move(node));
bool& b = newNodes[std::move(node.node)];
b = b || node.boring;
} }
eclass->nodes.assign(newNodes.begin(), newNodes.end()); eclass->nodes.clear();
while (!newNodes.empty())
{
auto n = newNodes.extract(newNodes.begin());
eclass->nodes.push_back(Node<L>{n.key(), n.mapped()});
}
// FIXME: Extract into sortByTag() // FIXME: Extract into sortByTag()
std::sort( std::sort(
eclass->nodes.begin(), eclass->nodes.begin(),
eclass->nodes.end(), eclass->nodes.end(),
[](const L& left, const L& right) [](const Node<L>& left, const Node<L>& right)
{ {
return left.index() < right.index(); return left.node.index() < right.node.index();
} }
); );
} }

View file

@ -154,6 +154,7 @@ LUA_API const float* lua_tovector(lua_State* L, int idx);
LUA_API int lua_toboolean(lua_State* L, int idx); LUA_API int lua_toboolean(lua_State* L, int idx);
LUA_API const char* lua_tolstring(lua_State* L, int idx, size_t* len); LUA_API const char* lua_tolstring(lua_State* L, int idx, size_t* len);
LUA_API const char* lua_tostringatom(lua_State* L, int idx, int* atom); LUA_API const char* lua_tostringatom(lua_State* L, int idx, int* atom);
LUA_API const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom);
LUA_API const char* lua_namecallatom(lua_State* L, int* atom); LUA_API const char* lua_namecallatom(lua_State* L, int* atom);
LUA_API int lua_objlen(lua_State* L, int idx); LUA_API int lua_objlen(lua_State* L, int idx);
LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx); LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx);

View file

@ -454,6 +454,29 @@ const char* lua_tostringatom(lua_State* L, int idx, int* atom)
return getstr(s); return getstr(s);
} }
const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom)
{
StkId o = index2addr(L, idx);
if (!ttisstring(o))
{
if (len)
*len = 0;
return NULL;
}
TString* s = tsvalue(o);
if (len)
*len = s->len;
if (atom)
{
updateatom(L, s);
*atom = s->atom;
}
return getstr(s);
}
const char* lua_namecallatom(lua_State* L, int* atom) const char* lua_namecallatom(lua_State* L, int* atom)
{ {
TString* s = L->namecall; TString* s = L->namecall;

View file

@ -422,6 +422,20 @@ int luaG_isnative(lua_State* L, int level)
return (ci->flags & LUA_CALLINFO_NATIVE) != 0 ? 1 : 0; return (ci->flags & LUA_CALLINFO_NATIVE) != 0 ? 1 : 0;
} }
int luaG_hasnative(lua_State* L, int level)
{
if (unsigned(level) >= unsigned(L->ci - L->base_ci))
return 0;
CallInfo* ci = L->ci - level;
Proto* proto = getluaproto(ci);
if (proto == nullptr)
return 0;
return (proto->execdata != nullptr);
}
void lua_singlestep(lua_State* L, int enabled) void lua_singlestep(lua_State* L, int enabled)
{ {
L->singlestep = bool(enabled); L->singlestep = bool(enabled);

View file

@ -31,3 +31,4 @@ LUAI_FUNC bool luaG_onbreak(lua_State* L);
LUAI_FUNC int luaG_getline(Proto* p, int pc); LUAI_FUNC int luaG_getline(Proto* p, int pc);
LUAI_FUNC int luaG_isnative(lua_State* L, int level); LUAI_FUNC int luaG_isnative(lua_State* L, int level);
LUAI_FUNC int luaG_hasnative(lua_State* L, int level);

View file

@ -4304,4 +4304,29 @@ foo(@1)
CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText);
} }
TEST_CASE_FIXTURE(ACFixture, "autocomplete_at_end_of_stmt_should_continue_as_part_of_stmt")
{
check(R"(
local data = { x = 1 }
local var = data.@1
)");
auto ac = autocomplete('1');
CHECK(!ac.entryMap.empty());
CHECK(ac.entryMap.count("x"));
CHECK_EQ(ac.context, AutocompleteContext::Property);
}
TEST_CASE_FIXTURE(ACFixture, "autocomplete_after_semicolon_should_complete_a_new_statement")
{
check(R"(
local data = { x = 1 }
local var = data;@1
)");
auto ac = autocomplete('1');
CHECK(!ac.entryMap.empty());
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Statement);
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -22,7 +22,9 @@ ConstraintGeneratorFixture::ConstraintGeneratorFixture()
void ConstraintGeneratorFixture::generateConstraints(const std::string& code) void ConstraintGeneratorFixture::generateConstraints(const std::string& code)
{ {
AstStatBlock* root = parse(code); AstStatBlock* root = parse(code);
dfg = std::make_unique<DataFlowGraph>(DataFlowGraphBuilder::build(root, NotNull{&ice})); dfg = std::make_unique<DataFlowGraph>(
DataFlowGraphBuilder::build(root, NotNull{&mainModule->defArena}, NotNull{&mainModule->keyArena}, NotNull{&ice})
);
cg = std::make_unique<ConstraintGenerator>( cg = std::make_unique<ConstraintGenerator>(
mainModule, mainModule,
NotNull{&normalizer}, NotNull{&normalizer},

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Fixture.h" #include "Fixture.h"
#include "Luau/Def.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
@ -18,6 +19,8 @@ struct DataFlowGraphFixture
// Only needed to fix the operator== reflexivity of an empty Symbol. // Only needed to fix the operator== reflexivity of an empty Symbol.
ScopedFastFlag dcr{FFlag::LuauSolverV2, true}; ScopedFastFlag dcr{FFlag::LuauSolverV2, true};
DefArena defArena;
RefinementKeyArena keyArena;
InternalErrorReporter handle; InternalErrorReporter handle;
Allocator allocator; Allocator allocator;
@ -32,7 +35,7 @@ struct DataFlowGraphFixture
if (!parseResult.errors.empty()) if (!parseResult.errors.empty())
throw ParseErrors(std::move(parseResult.errors)); throw ParseErrors(std::move(parseResult.errors));
module = parseResult.root; module = parseResult.root;
graph = DataFlowGraphBuilder::build(module, NotNull{&handle}); graph = DataFlowGraphBuilder::build(module, NotNull{&defArena}, NotNull{&keyArena}, NotNull{&handle});
} }
template<typename T, int N> template<typename T, int N>

View file

@ -26,6 +26,7 @@ LUAU_FASTFLAG(LuauSymbolEquality);
LUAU_FASTFLAG(LuauStoreSolverTypeOnModule); LUAU_FASTFLAG(LuauStoreSolverTypeOnModule);
LUAU_FASTFLAG(LexerResumesFromPosition2) LUAU_FASTFLAG(LexerResumesFromPosition2)
LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection) LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection)
LUAU_FASTINT(LuauParseErrorLimit)
static std::optional<AutocompleteEntryMap> nullCallback(std::string tag, std::optional<const ClassType*> ptr, std::optional<std::string> contents) static std::optional<AutocompleteEntryMap> nullCallback(std::string tag, std::optional<const ClassType*> ptr, std::optional<std::string> contents)
{ {
@ -69,7 +70,7 @@ struct FragmentAutocompleteFixtureImpl : BaseType
} }
FragmentParseResult parseFragment( std::optional<FragmentParseResult> parseFragment(
const std::string& document, const std::string& document,
const Position& cursorPos, const Position& cursorPos,
std::optional<Position> fragmentEndPosition = std::nullopt std::optional<Position> fragmentEndPosition = std::nullopt
@ -164,6 +165,7 @@ end
} }
}; };
//NOLINTBEGIN(bugprone-unchecked-optional-access)
TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals")
@ -286,13 +288,23 @@ TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); TEST_SUITE_BEGIN("FragmentAutocompleteParserTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "thrown_parse_error_leads_to_null_root")
{
check("type A = ");
ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1};
auto fragment = parseFragment("type A = <>function<> more garbage here", Position(0, 39));
CHECK(fragment == std::nullopt);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer")
{ {
ScopedFastFlag sff{FFlag::LuauSolverV2, true}; ScopedFastFlag sff{FFlag::LuauSolverV2, true};
check("local a ="); check("local a =");
auto fragment = parseFragment("local a =", Position(0, 10)); auto fragment = parseFragment("local a =", Position(0, 10));
CHECK_EQ("local a =", fragment.fragmentToParse);
CHECK_EQ(Location{Position{0, 0}, 9}, fragment.root->location); REQUIRE(fragment.has_value());
CHECK_EQ("local a =", fragment->fragmentToParse);
CHECK_EQ(Location{Position{0, 0}, 9}, fragment->root->location);
} }
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null")
@ -310,11 +322,12 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_n
)", )",
Position(1, 0) Position(1, 0)
); );
CHECK_EQ("\n", fragment.fragmentToParse); REQUIRE(fragment.has_value());
CHECK_EQ(2, fragment.ancestry.size()); CHECK_EQ("\n", fragment->fragmentToParse);
REQUIRE(fragment.root); CHECK_EQ(2, fragment->ancestry.size());
CHECK_EQ(0, fragment.root->body.size); REQUIRE(fragment->root);
auto statBody = fragment.root->as<AstStatBlock>(); CHECK_EQ(0, fragment->root->body.size);
auto statBody = fragment->root->as<AstStatBlock>();
CHECK(statBody != nullptr); CHECK(statBody != nullptr);
} }
@ -339,13 +352,15 @@ local z = x + y
Position{3, 15} Position{3, 15}
); );
CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment.root->location); REQUIRE(fragment.has_value());
CHECK_EQ("local y = 5\nlocal z = x + y", fragment.fragmentToParse); CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment->root->location);
CHECK_EQ(5, fragment.ancestry.size());
REQUIRE(fragment.root); CHECK_EQ("local y = 5\nlocal z = x + y", fragment->fragmentToParse);
CHECK_EQ(2, fragment.root->body.size); CHECK_EQ(5, fragment->ancestry.size());
auto stat = fragment.root->body.data[1]->as<AstStatLocal>(); REQUIRE(fragment->root);
CHECK_EQ(2, fragment->root->body.size);
auto stat = fragment->root->body.data[1]->as<AstStatLocal>();
REQUIRE(stat); REQUIRE(stat);
CHECK_EQ(1, stat->vars.size); CHECK_EQ(1, stat->vars.size);
CHECK_EQ(1, stat->values.size); CHECK_EQ(1, stat->values.size);
@ -384,12 +399,14 @@ local y = 5
Position{2, 15} Position{2, 15}
); );
CHECK_EQ("local z = x + y", fragment.fragmentToParse); REQUIRE(fragment.has_value());
CHECK_EQ(5, fragment.ancestry.size());
REQUIRE(fragment.root); CHECK_EQ("local z = x + y", fragment->fragmentToParse);
CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment.root->location); CHECK_EQ(5, fragment->ancestry.size());
CHECK_EQ(1, fragment.root->body.size); REQUIRE(fragment->root);
auto stat = fragment.root->body.data[0]->as<AstStatLocal>(); CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment->root->location);
CHECK_EQ(1, fragment->root->body.size);
auto stat = fragment->root->body.data[0]->as<AstStatLocal>();
REQUIRE(stat); REQUIRE(stat);
CHECK_EQ(1, stat->vars.size); CHECK_EQ(1, stat->vars.size);
CHECK_EQ(1, stat->values.size); CHECK_EQ(1, stat->values.size);
@ -429,7 +446,9 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope")
Position{6, 0} Position{6, 0}
); );
CHECK_EQ("\n ", fragment.fragmentToParse); REQUIRE(fragment.has_value());
CHECK_EQ("\n ", fragment->fragmentToParse);
} }
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_single_line_fragment_override") TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_single_line_fragment_override")
@ -448,17 +467,19 @@ abc("bar")
Position{1, 10} Position{1, 10}
); );
CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", callFragment.fragmentToParse); REQUIRE(callFragment.has_value());
CHECK(callFragment.nearestStatement->is<AstStatFunction>());
CHECK_GE(callFragment.ancestry.size(), 2); CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", callFragment->fragmentToParse);
CHECK(callFragment->nearestStatement->is<AstStatFunction>());
AstNode* back = callFragment.ancestry.back(); CHECK_GE(callFragment->ancestry.size(), 2);
AstNode* back = callFragment->ancestry.back();
CHECK(back->is<AstExprConstantString>()); CHECK(back->is<AstExprConstantString>());
CHECK_EQ(Position{1, 4}, back->location.begin); CHECK_EQ(Position{1, 4}, back->location.begin);
CHECK_EQ(Position{1, 9}, back->location.end); CHECK_EQ(Position{1, 9}, back->location.end);
AstNode* parent = callFragment.ancestry.rbegin()[1]; AstNode* parent = callFragment->ancestry.rbegin()[1];
CHECK(parent->is<AstExprCall>()); CHECK(parent->is<AstExprCall>());
CHECK_EQ(Position{1, 0}, parent->location.begin); CHECK_EQ(Position{1, 0}, parent->location.begin);
CHECK_EQ(Position{1, 10}, parent->location.end); CHECK_EQ(Position{1, 10}, parent->location.end);
@ -473,12 +494,14 @@ abc("bar")
Position{1, 9} Position{1, 9}
); );
CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment.fragmentToParse); REQUIRE(stringFragment.has_value());
CHECK(stringFragment.nearestStatement->is<AstStatFunction>());
CHECK_GE(stringFragment.ancestry.size(), 1); CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment->fragmentToParse);
CHECK(stringFragment->nearestStatement->is<AstStatFunction>());
back = stringFragment.ancestry.back(); CHECK_GE(stringFragment->ancestry.size(), 1);
back = stringFragment->ancestry.back();
auto asString = back->as<AstExprConstantString>(); auto asString = back->as<AstExprConstantString>();
CHECK(asString); CHECK(asString);
@ -508,17 +531,19 @@ abc("bar")
Position{3, 1} Position{3, 1}
); );
CHECK_EQ("function abc(foo: string) end\nabc(\n\"foo\"\n)", fragment.fragmentToParse); REQUIRE(fragment.has_value());
CHECK(fragment.nearestStatement->is<AstStatFunction>());
CHECK_GE(fragment.ancestry.size(), 2); CHECK_EQ("function abc(foo: string) end\nabc(\n\"foo\"\n)", fragment->fragmentToParse);
CHECK(fragment->nearestStatement->is<AstStatFunction>());
AstNode* back = fragment.ancestry.back(); CHECK_GE(fragment->ancestry.size(), 2);
AstNode* back = fragment->ancestry.back();
CHECK(back->is<AstExprConstantString>()); CHECK(back->is<AstExprConstantString>());
CHECK_EQ(Position{2, 0}, back->location.begin); CHECK_EQ(Position{2, 0}, back->location.begin);
CHECK_EQ(Position{2, 5}, back->location.end); CHECK_EQ(Position{2, 5}, back->location.end);
AstNode* parent = fragment.ancestry.rbegin()[1]; AstNode* parent = fragment->ancestry.rbegin()[1];
CHECK(parent->is<AstExprCall>()); CHECK(parent->is<AstExprCall>());
CHECK_EQ(Position{1, 0}, parent->location.begin); CHECK_EQ(Position{1, 0}, parent->location.begin);
CHECK_EQ(Position{3, 1}, parent->location.end); CHECK_EQ(Position{3, 1}, parent->location.end);
@ -549,6 +574,7 @@ t
} }
TEST_SUITE_END(); TEST_SUITE_END();
//NOLINTEND(bugprone-unchecked-optional-access)
TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests");
@ -1558,4 +1584,26 @@ if x == 5 then -- a comment
); );
} }
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_handles_parse_errors")
{
ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1};
const std::string source = R"(
)";
const std::string updated = R"(
type A = <>random non code text here
)";
autocompleteFragmentInBothSolvers(
source,
updated,
Position{1, 38},
[](FragmentAutocompleteResult& result)
{
CHECK(result.acResults.entryMap.empty());
}
);
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -20,6 +20,7 @@ LUAU_FASTFLAG(LuauAllowComplexTypesInGenericParams)
LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes)
LUAU_FASTFLAG(LuauErrorRecoveryForClassNames) LUAU_FASTFLAG(LuauErrorRecoveryForClassNames)
LUAU_FASTFLAG(LuauFixFunctionNameStartPosition) LUAU_FASTFLAG(LuauFixFunctionNameStartPosition)
LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon)
namespace namespace
{ {
@ -3766,5 +3767,32 @@ TEST_CASE_FIXTURE(Fixture, "function_name_has_correct_start_location")
CHECK_EQ(Position{4, 17}, function2->name->location.begin); CHECK_EQ(Position{4, 17}, function2->name->location.begin);
} }
TEST_CASE_FIXTURE(Fixture, "stat_end_includes_semicolon_position")
{
ScopedFastFlag _{FFlag::LuauExtendStatEndPosWithSemicolon, true};
AstStatBlock* block = parse(R"(
local x = 1
local y = 2;
local z = 3 ;
)");
REQUIRE_EQ(3, block->body.size);
const auto stat1 = block->body.data[0];
LUAU_ASSERT(stat1);
CHECK_FALSE(stat1->hasSemicolon);
CHECK_EQ(Position{1, 19}, stat1->location.end);
const auto stat2 = block->body.data[1];
LUAU_ASSERT(stat2);
CHECK(stat2->hasSemicolon);
CHECK_EQ(Position{2, 20}, stat2->location.end);
const auto stat3 = block->body.data[2];
LUAU_ASSERT(stat3);
CHECK(stat3->hasSemicolon);
CHECK_EQ(Position{3, 22}, stat3->location.end);
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -12,6 +12,7 @@ using namespace Luau;
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAG(LuauTypestateBuiltins2)
LUAU_FASTFLAG(LuauStringFormatArityFix) LUAU_FASTFLAG(LuauStringFormatArityFix)
LUAU_FASTFLAG(LuauTableCloneClonesType)
LUAU_FASTFLAG(LuauStringFormatErrorSuppression) LUAU_FASTFLAG(LuauStringFormatErrorSuppression)
TEST_SUITE_BEGIN("BuiltinTests"); TEST_SUITE_BEGIN("BuiltinTests");
@ -1587,6 +1588,30 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash")
)")); )"));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_clone_type_states")
{
CheckResult result = check(R"(
local t1 = {}
t1.x = 5
local t2 = table.clone(t1)
t2.y = 6
t1.z = 3
)");
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::LuauTableCloneClonesType)
{
CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, z: number }");
CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number }");
}
else
{
CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, y: number, z: number }");
CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number, z: number }");
}
}
TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_should_support_any") TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_should_support_any")
{ {
ScopedFastFlag _{FFlag::LuauSolverV2, true}; ScopedFastFlag _{FFlag::LuauSolverV2, true};

View file

@ -20,7 +20,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
LUAU_FASTFLAG(DebugLuauEqSatSimplification) LUAU_FASTFLAG(DebugLuauEqSatSimplification)
TEST_SUITE_BEGIN("TypeInferFunctions"); TEST_SUITE_BEGIN("TypeInferFunctions");
@ -2566,7 +2565,7 @@ end
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type")
{ {
ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontRefCountTypesInTypeFunctions, true}}; ScopedFastFlag _{FFlag::LuauSolverV2, true};
// CLI-114134: This test: // CLI-114134: This test:
// a) Has a kind of weird result (suggesting `number | false` is not great); // a) Has a kind of weird result (suggesting `number | false` is not great);
@ -2878,8 +2877,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun")
TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types")
{ {
ScopedFastFlag _{FFlag::LuauDontRefCountTypesInTypeFunctions, true};
CheckResult result = check(R"( CheckResult result = check(R"(
function foo(player) function foo(player)
local success,result = player:thing() local success,result = player:thing()

View file

@ -16,52 +16,63 @@ using namespace Luau;
namespace namespace
{ {
std::optional<WithPredicate<TypePackId>> magicFunctionInstanceIsA(
TypeChecker& typeChecker, struct MagicInstanceIsA final : MagicFunction
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
if (expr.args.size != 1) std::optional<WithPredicate<TypePackId>> handleOldSolver(
return std::nullopt; TypeChecker& typeChecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
) override
{
if (expr.args.size != 1)
return std::nullopt;
auto index = expr.func->as<Luau::AstExprIndexName>(); auto index = expr.func->as<Luau::AstExprIndexName>();
auto str = expr.args.data[0]->as<Luau::AstExprConstantString>(); auto str = expr.args.data[0]->as<Luau::AstExprConstantString>();
if (!index || !str) if (!index || !str)
return std::nullopt; return std::nullopt;
std::optional<LValue> lvalue = tryGetLValue(*index->expr); std::optional<LValue> lvalue = tryGetLValue(*index->expr);
std::optional<TypeFun> tfun = scope->lookupType(std::string(str->value.data, str->value.size)); std::optional<TypeFun> tfun = scope->lookupType(std::string(str->value.data, str->value.size));
if (!lvalue || !tfun) if (!lvalue || !tfun)
return std::nullopt; return std::nullopt;
ModulePtr module = typeChecker.currentModule; ModulePtr module = typeChecker.currentModule;
TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType});
return WithPredicate<TypePackId>{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; return WithPredicate<TypePackId>{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}};
} }
void dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) bool infer(const MagicFunctionCallContext&) override
{ {
if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty()) return false;
return; }
auto index = ctx.callSite->func->as<Luau::AstExprIndexName>(); void refine(const MagicRefinementContext& ctx) override
auto str = ctx.callSite->args.data[0]->as<Luau::AstExprConstantString>(); {
if (!index || !str) if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty())
return; return;
std::optional<TypeId> discriminantTy = ctx.discriminantTypes[0]; auto index = ctx.callSite->func->as<Luau::AstExprIndexName>();
if (!discriminantTy) auto str = ctx.callSite->args.data[0]->as<Luau::AstExprConstantString>();
return; if (!index || !str)
return;
std::optional<TypeId> discriminantTy = ctx.discriminantTypes[0];
if (!discriminantTy)
return;
std::optional<TypeFun> tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size));
if (!tfun)
return;
LUAU_ASSERT(get<BlockedType>(*discriminantTy));
asMutable(*discriminantTy)->ty.emplace<BoundType>(tfun->type);
}
};
std::optional<TypeFun> tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size));
if (!tfun)
return;
LUAU_ASSERT(get<BlockedType>(*discriminantTy));
asMutable(*discriminantTy)->ty.emplace<BoundType>(tfun->type);
}
struct RefinementClassFixture : BuiltinsFixture struct RefinementClassFixture : BuiltinsFixture
{ {
@ -85,8 +96,7 @@ struct RefinementClassFixture : BuiltinsFixture
TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType}); TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType});
TypePackId isARets = arena.addTypePack({builtinTypes->booleanType}); TypePackId isARets = arena.addTypePack({builtinTypes->booleanType});
TypeId isA = arena.addType(FunctionType{isAParams, isARets}); TypeId isA = arena.addType(FunctionType{isAParams, isARets});
getMutable<FunctionType>(isA)->magicFunction = magicFunctionInstanceIsA; getMutable<FunctionType>(isA)->magic = std::make_shared<MagicInstanceIsA>();
getMutable<FunctionType>(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA;
getMutable<ClassType>(inst)->props = { getMutable<ClassType>(inst)->props = {
{"Name", Property{builtinTypes->stringType}}, {"Name", Property{builtinTypes->stringType}},

View file

@ -24,7 +24,6 @@ LUAU_FASTINT(LuauNormalizeCacheLimit);
LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
LUAU_FASTFLAG(InferGlobalTypes) LUAU_FASTFLAG(InferGlobalTypes)
using namespace Luau; using namespace Luau;
@ -1731,7 +1730,7 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue")
TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function") TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function")
{ {
ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontRefCountTypesInTypeFunctions, true}}; ScopedFastFlag _{FFlag::LuauSolverV2, true};
LUAU_CHECK_NO_ERRORS(check(R"( LUAU_CHECK_NO_ERRORS(check(R"(
--!strict --!strict
@ -1744,7 +1743,7 @@ TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function")
TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type") TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type")
{ {
ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontRefCountTypesInTypeFunctions, true}}; ScopedFastFlag _{FFlag::LuauSolverV2, true};
LUAU_CHECK_NO_ERRORS(check(R"( LUAU_CHECK_NO_ERRORS(check(R"(
--!strict --!strict