Merge branch 'luau-lang:master' into branch-1

This commit is contained in:
ramdoys 2024-11-14 09:36:29 -05:00 committed by GitHub
commit 81283f8cb4
Signed by: DevComp
GPG key ID: B5690EEEBB952194
82 changed files with 7571 additions and 2526 deletions

View file

@ -1,10 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/AutocompleteTypes.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include <unordered_map>
#include <string> #include <string>
#include <memory> #include <memory>
#include <optional> #include <optional>
@ -16,90 +16,8 @@ struct Frontend;
struct SourceModule; struct SourceModule;
struct Module; struct Module;
struct TypeChecker; struct TypeChecker;
struct FileResolver;
using ModulePtr = std::shared_ptr<Module>;
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
RequirePath,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using ModuleName = std::string;
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback);
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,92 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Type.h"
#include <unordered_map>
namespace Luau
{
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
RequirePath,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau

View file

@ -5,6 +5,7 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/ControlFlow.h" #include "Luau/ControlFlow.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/InsertionOrderedMap.h" #include "Luau/InsertionOrderedMap.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
@ -15,7 +16,6 @@
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/Normalize.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -109,6 +109,9 @@ struct ConstraintGenerator
// Needed to be able to enable error-suppression preservation for immediate refinements. // Needed to be able to enable error-suppression preservation for immediate refinements.
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
// Needed to register all available type functions for execution at later stages. // Needed to register all available type functions for execution at later stages.
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// Needed to resolve modules to make 'require' import types properly. // Needed to resolve modules to make 'require' import types properly.
@ -128,6 +131,7 @@ struct ConstraintGenerator
ConstraintGenerator( ConstraintGenerator(
ModulePtr module, ModulePtr module,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
@ -405,6 +409,7 @@ private:
TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
// make an intersect type function of these two types // make an intersect type function of these two types
TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program);
/** Scan the program for global definitions. /** Scan the program for global definitions.
* *
@ -435,6 +440,8 @@ private:
const ScopePtr& scope, const ScopePtr& scope,
Location location Location location
); );
TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right);
}; };
/** Borrow a vector of pointers from a vector of owning pointers to constraints. /** Borrow a vector of pointers from a vector of owning pointers to constraints.

View file

@ -5,6 +5,7 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Module.h" #include "Luau/Module.h"
@ -64,6 +65,7 @@ struct ConstraintSolver
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter; InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// The entire set of constraints that the solver is trying to resolve. // The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints; std::vector<NotNull<Constraint>> constraints;
@ -117,6 +119,7 @@ struct ConstraintSolver
explicit ConstraintSolver( explicit ConstraintSolver(
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope, NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints, std::vector<NotNull<Constraint>> constraints,
@ -384,6 +387,10 @@ public:
**/ **/
void reproduceConstraints(NotNull<Scope> scope, const Location& location, const Substitution& subst); void reproduceConstraints(NotNull<Scope> scope, const Location& location, const Substitution& subst);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, std::set<TypeId> parts);
TypeId simplifyUnion(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId errorRecoveryType() const; TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const; TypePackId errorRecoveryTypePack() const;

View file

@ -0,0 +1,50 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeFwd.h"
#include "Luau/NotNull.h"
#include "Luau/DenseHash.h"
#include <memory>
#include <optional>
#include <vector>
namespace Luau
{
struct TypeArena;
}
// The EqSat stuff is pretty template heavy, so we go to some lengths to prevent
// the complexity from leaking outside its implementation sources.
namespace Luau::EqSatSimplification
{
struct Simplifier;
using SimplifierPtr = std::unique_ptr<Simplifier, void (*)(Simplifier*)>;
SimplifierPtr newSimplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
} // namespace Luau::EqSatSimplification
namespace Luau
{
struct EqSatSimplificationResult
{
TypeId result;
// New type function applications that were created by the reduction phase.
// We return these so that the ConstraintSolver can know to try to reduce
// them.
std::vector<TypeId> newTypeFunctions;
};
using EqSatSimplification::newSimplifier; // NOLINT: clang-tidy thinks these are unused. It is incorrect.
using Luau::EqSatSimplification::Simplifier; // NOLINT
using Luau::EqSatSimplification::SimplifierPtr;
std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simplifier, TypeId ty);
} // namespace Luau

View file

@ -0,0 +1,363 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/EGraph.h"
#include "Luau/Id.h"
#include "Luau/Language.h"
#include "Luau/Lexer.h" // For Allocator
#include "Luau/NotNull.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
struct TypeFunction;
}
namespace Luau::EqSatSimplification
{
using StringId = uint32_t;
using Id = Luau::EqSat::Id;
LUAU_EQSAT_UNIT(TNil);
LUAU_EQSAT_UNIT(TBoolean);
LUAU_EQSAT_UNIT(TNumber);
LUAU_EQSAT_UNIT(TString);
LUAU_EQSAT_UNIT(TThread);
LUAU_EQSAT_UNIT(TTopFunction);
LUAU_EQSAT_UNIT(TTopTable);
LUAU_EQSAT_UNIT(TTopClass);
LUAU_EQSAT_UNIT(TBuffer);
// Used for any type that eqsat can't do anything interesting with.
LUAU_EQSAT_ATOM(TOpaque, TypeId);
LUAU_EQSAT_ATOM(SBoolean, bool);
LUAU_EQSAT_ATOM(SString, StringId);
LUAU_EQSAT_ATOM(TFunction, TypeId);
LUAU_EQSAT_ATOM(TImportedTable, TypeId);
LUAU_EQSAT_ATOM(TClass, TypeId);
LUAU_EQSAT_UNIT(TAny);
LUAU_EQSAT_UNIT(TError);
LUAU_EQSAT_UNIT(TUnknown);
LUAU_EQSAT_UNIT(TNever);
LUAU_EQSAT_NODE_SET(Union);
LUAU_EQSAT_NODE_SET(Intersection);
LUAU_EQSAT_NODE_ARRAY(Negation, 1);
LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*);
LUAU_EQSAT_UNIT(TNoRefine);
LUAU_EQSAT_UNIT(Invalid);
// enodes are immutable, but types are cyclic. We need a way to tie the knot.
// We handle this by generating TBound nodes at points where we encounter cycles.
// Each TBound has an ordinal that we later map onto the type.
// We use a substitution rule to replace all TBound nodes with their referrent.
LUAU_EQSAT_ATOM(TBound, size_t);
// Tables are sufficiently unlike other enodes that the Language.h macros won't cut it.
struct TTable
{
explicit TTable(Id basis);
TTable(Id basis, std::vector<StringId> propNames_, std::vector<Id> propTypes_);
// All TTables extend some other table. This may be TTopTable.
//
// It will frequently be a TImportedTable, in which case we can reuse things
// like source location and documentation info.
Id getBasis() const;
EqSat::Slice<const Id> propTypes() const;
// TODO: Also support read-only table props
// TODO: Indexer type, index result type.
std::vector<StringId> propNames;
// The enode interface
EqSat::Slice<Id> mutableOperands();
EqSat::Slice<const Id> operands() const;
bool operator==(const TTable& rhs) const;
bool operator!=(const TTable& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const TTable& value) const;
};
private:
// The first element of this vector is the basis. Subsequent elements are
// property types. As we add other things like read-only properties and
// indexers, the structure of this array is likely to change.
//
// We encode our data in this way so that the operands() method can properly
// return a Slice<Id>.
std::vector<Id> storage;
};
using EType = EqSat::Language<
TNil,
TBoolean,
TNumber,
TString,
TThread,
TTopFunction,
TTopTable,
TTopClass,
TBuffer,
TOpaque,
SBoolean,
SString,
TFunction,
TTable,
TImportedTable,
TClass,
TAny,
TError,
TUnknown,
TNever,
Union,
Intersection,
Negation,
TTypeFun,
Invalid,
TNoRefine,
TBound>;
struct StringCache
{
Allocator allocator;
DenseHashMap<size_t, StringId> strings{{}};
std::vector<std::string_view> views;
StringId add(std::string_view s);
std::string_view asStringView(StringId id) const;
std::string asString(StringId id) const;
};
using EGraph = Luau::EqSat::EGraph<EType, struct Simplify>;
struct Simplify
{
using Data = bool;
template<typename T>
Data make(const EGraph&, const T&) const;
void join(Data& left, const Data& right) const;
};
struct Subst
{
Id eclass;
Id newClass;
std::string desc;
Subst(Id eclass, Id newClass, std::string desc = "");
};
struct Simplifier
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
EGraph egraph;
StringCache stringCache;
// enodes are immutable but types can be cyclic, so we need some way to
// encode the cycle. This map is used to connect TBound nodes to the right
// eclass.
//
// The cyclicIntersection rewrite rule uses this to sense when a cycle can
// be deleted from an intersection or union.
std::unordered_map<size_t, Id> mappingIdToClass;
std::vector<Subst> substs;
using RewriteRuleFn = void (Simplifier::*)(Id id);
Simplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
// Utilities
const EqSat::EClass<EType, Simplify::Data>& get(Id id) const;
Id find(Id id) const;
Id add(EType enode);
template<typename Tag>
const Tag* isTag(Id id) const;
template<typename Tag>
const Tag* isTag(const EType& enode) const;
void subst(Id from, Id to);
void subst(Id from, Id to, const std::string& ruleName);
void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void unionClasses(std::vector<Id>& hereParts, Id there);
// Rewrite rules
void simplifyUnion(Id id);
void uninhabitedIntersection(Id id);
void intersectWithNegatedClass(Id id);
void intersectWithNoRefine(Id id);
void cyclicIntersectionOfUnion(Id id);
void cyclicUnionOfIntersection(Id id);
void expandNegation(Id id);
void intersectionOfUnion(Id id);
void intersectTableProperty(Id id);
void uninhabitedTable(Id id);
void unneededTableModification(Id id);
void builtinTypeFunctions(Id id);
void iffyTypeFunctions(Id id);
};
template<typename Tag>
struct QueryIterator
{
QueryIterator();
QueryIterator(EGraph* egraph, Id eclass);
bool operator==(const QueryIterator& other) const;
bool operator!=(const QueryIterator& other) const;
std::pair<const Tag*, size_t> operator*() const;
QueryIterator& operator++();
QueryIterator& operator++(int);
private:
EGraph* egraph = nullptr;
Id eclass;
size_t index = 0;
};
template<typename Tag>
struct Query
{
EGraph* egraph;
Id eclass;
Query(EGraph* egraph, Id eclass)
: egraph(egraph)
, eclass(eclass)
{
}
QueryIterator<Tag> begin()
{
return QueryIterator<Tag>{egraph, eclass};
}
QueryIterator<Tag> end()
{
return QueryIterator<Tag>{};
}
};
template<typename Tag>
QueryIterator<Tag>::QueryIterator()
: egraph(nullptr)
, eclass(Id{0})
, index(0)
{
}
template<typename Tag>
QueryIterator<Tag>::QueryIterator(EGraph* egraph_, Id eclass)
: egraph(egraph_)
, eclass(eclass)
, index(0)
{
const auto& ecl = (*egraph)[eclass];
static constexpr const int idx = EType::VariantTy::getTypeId<Tag>();
for (const auto& enode : ecl.nodes)
{
if (enode.index() < idx)
++index;
else
break;
}
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx)
{
egraph = nullptr;
index = 0;
}
}
template<typename Tag>
bool QueryIterator<Tag>::operator==(const QueryIterator<Tag>& rhs) const
{
if (egraph == nullptr && rhs.egraph == nullptr)
return true;
return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index;
}
template<typename Tag>
bool QueryIterator<Tag>::operator!=(const QueryIterator<Tag>& rhs) const
{
return !(*this == rhs);
}
template<typename Tag>
std::pair<const Tag*, size_t> QueryIterator<Tag>::operator*() const
{
LUAU_ASSERT(egraph != nullptr);
EGraph::EClassT& ecl = (*egraph)[eclass];
LUAU_ASSERT(index < ecl.nodes.size());
auto& enode = ecl.nodes[index];
Tag* result = enode.template get<Tag>();
LUAU_ASSERT(result);
return {result, index};
}
// pre-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++()
{
const auto& ecl = (*egraph)[eclass];
++index;
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId<Tag>())
{
egraph = nullptr;
index = 0;
}
return *this;
}
// post-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++(int)
{
QueryIterator<Tag> res = *this;
++res;
return res;
}
} // namespace Luau::EqSatSimplification

View file

@ -32,7 +32,11 @@ struct ModuleInfo
bool optional = false; bool optional = false;
}; };
using RequireSuggestion = std::string; struct RequireSuggestion
{
std::string label;
std::string fullPath;
};
using RequireSuggestions = std::vector<RequireSuggestion>; using RequireSuggestions = std::vector<RequireSuggestion>;
struct FileResolver struct FileResolver

View file

@ -3,9 +3,10 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
#include "Luau/Autocomplete.h" #include "Luau/AutocompleteTypes.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/Frontend.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -27,13 +28,23 @@ struct FragmentParseResult
std::string fragmentToParse; std::string fragmentToParse;
AstStatBlock* root = nullptr; AstStatBlock* root = nullptr;
std::vector<AstNode*> ancestry; std::vector<AstNode*> ancestry;
AstStat* nearestStatement = nullptr;
std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>(); std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>();
}; };
struct FragmentTypeCheckResult struct FragmentTypeCheckResult
{ {
ModulePtr incrementalModule = nullptr; ModulePtr incrementalModule = nullptr;
Scope* freshScope = nullptr; ScopePtr freshScope;
std::vector<AstNode*> ancestry;
};
struct FragmentAutocompleteResult
{
ModulePtr incrementalModule;
Scope* freshScope;
TypeArena arenaForAutocomplete;
AutocompleteResult acResults;
}; };
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
@ -48,11 +59,11 @@ FragmentTypeCheckResult typecheckFragment(
std::string_view src std::string_view src
); );
AutocompleteResult fragmentAutocomplete( FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend, Frontend& frontend,
std::string_view src, std::string_view src,
const ModuleName& moduleName, const ModuleName& moduleName,
Position& cursorPosition, Position cursorPosition,
std::optional<FrontendOptions> opts, std::optional<FrontendOptions> opts,
StringCompletionCallback callback StringCompletionCallback callback
); );

View file

@ -44,6 +44,7 @@ struct ToStringOptions
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self
bool useQuestionMarks = true; // If true, use a postfix ? for options, else write them out as unions that include nil.
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections

View file

@ -31,6 +31,7 @@ namespace Luau
struct TypeArena; struct TypeArena;
struct Scope; struct Scope;
using ScopePtr = std::shared_ptr<Scope>; using ScopePtr = std::shared_ptr<Scope>;
struct Module;
struct TypeFunction; struct TypeFunction;
struct Constraint; struct Constraint;
@ -598,6 +599,18 @@ struct ClassType
} }
}; };
// Data required to initialize a user-defined function and its environment
struct UserDefinedFunctionData
{
// Store a weak module reference to ensure the lifetime requirements are preserved
std::weak_ptr<Module> owner;
// References to AST elements are owned by the Module allocator which also stores this type
AstStatTypeFunction* definition = nullptr;
DenseHashMap<Name, AstStatTypeFunction*> environment{""};
};
/** /**
* An instance of a type function that has not yet been reduced to a more concrete * An instance of a type function that has not yet been reduced to a more concrete
* type. The constraint solver receives a constraint to reduce each * type. The constraint solver receives a constraint to reduce each
@ -613,17 +626,20 @@ struct TypeFunctionInstanceType
std::vector<TypePackId> packArguments; std::vector<TypePackId> packArguments;
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
UserDefinedFunctionData userFuncData;
TypeFunctionInstanceType( TypeFunctionInstanceType(
NotNull<const TypeFunction> function, NotNull<const TypeFunction> function,
std::vector<TypeId> typeArguments, std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments, std::vector<TypePackId> packArguments,
std::optional<AstName> userFuncName = std::nullopt std::optional<AstName> userFuncName,
UserDefinedFunctionData userFuncData
) )
: function(function) : function(function)
, typeArguments(typeArguments) , typeArguments(typeArguments)
, packArguments(packArguments) , packArguments(packArguments)
, userFuncName(userFuncName) , userFuncName(userFuncName)
, userFuncData(userFuncData)
{ {
} }
@ -640,6 +656,13 @@ struct TypeFunctionInstanceType
, packArguments(packArguments) , packArguments(packArguments)
{ {
} }
TypeFunctionInstanceType(NotNull<const TypeFunction> function, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
: function{function}
, typeArguments(typeArguments)
, packArguments(packArguments)
{
}
}; };
/** Represents a pending type alias instantiation. /** Represents a pending type alias instantiation.

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,27 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/AutocompleteTypes.h"
namespace Luau
{
struct Module;
struct FileResolver;
using ModulePtr = std::shared_ptr<Module>;
using ModuleName = std::string;
AutocompleteResult autocomplete_(
const ModulePtr& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
std::vector<AstNode*>& ancestry,
Scope* globalScope,
const ScopePtr& scopeAtPosition,
Position position,
FileResolver* fileResolver,
StringCompletionCallback callback
);
} // namespace Luau

View file

@ -33,7 +33,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2)
LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix) LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix)
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions) LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2);
namespace Luau namespace Luau
{ {
@ -426,7 +426,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze); attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze);
} }
if (FFlag::AutocompleteRequirePathSuggestions) if (FFlag::AutocompleteRequirePathSuggestions2)
{ {
TypeId requireTy = getGlobalBinding(globals, "require"); TypeId requireTy = getGlobalBinding(globals, "require");
attachTag(requireTy, kRequireTagName); attachTag(requireTy, kRequireTagName);

View file

@ -3,6 +3,8 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions)
namespace Luau namespace Luau
{ {
@ -46,6 +48,21 @@ struct ReferenceCountInitializer : TypeOnceVisitor
// ClassTypes never contain free types. // ClassTypes never contain free types.
return false; return false;
} }
bool visit(TypeId, const TypeFunctionInstanceType&) override
{
// We do not consider reference counted types that are inside a type
// function to be part of the reachable reference counted types.
// Otherwise, code can be constructed in just the right way such
// that two type functions both claim to mutate a free type, which
// prevents either type function from trying to generalize it, so
// we potentially get stuck.
//
// The default behavior here is `true` for "visit the child types"
// of this type, hence:
return !FFlag::LuauDontRefCountTypesInTypeFunctions;
}
}; };
bool isReferenceCountedType(const TypeId typ) bool isReferenceCountedType(const TypeId typ)

View file

@ -10,6 +10,7 @@
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/NotNull.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/Refinement.h" #include "Luau/Refinement.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
@ -30,10 +31,14 @@
LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTINT(LuauCheckRecursionLimit)
LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauLogSolverToJson)
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(DebugLuauEqSatSimplification)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAG(LuauTypestateBuiltins2)
LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal)
LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations)
namespace Luau namespace Luau
{ {
@ -170,6 +175,7 @@ bool hasFreeType(TypeId ty)
ConstraintGenerator::ConstraintGenerator( ConstraintGenerator::ConstraintGenerator(
ModulePtr module, ModulePtr module,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
@ -186,6 +192,7 @@ ConstraintGenerator::ConstraintGenerator(
, rootScope(nullptr) , rootScope(nullptr)
, dfg(dfg) , dfg(dfg)
, normalizer(normalizer) , normalizer(normalizer)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, moduleResolver(moduleResolver) , moduleResolver(moduleResolver)
, ice(ice) , ice(ice)
@ -255,7 +262,7 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
d = follow(d); d = follow(d);
if (d == ty) if (d == ty)
continue; continue;
domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; domainTy = simplifyUnion(scope, Location{}, domainTy, d);
} }
LUAU_ASSERT(get<BlockedType>(ty)); LUAU_ASSERT(get<BlockedType>(ty));
@ -265,7 +272,15 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block) void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block)
{ {
// We prepopulate global data in the resumeScope to avoid writing data into the old modules scopes
prepopulateGlobalScopeForFragmentTypecheck(globalScope, resumeScope, block);
// Pre
// We need to pop the interior types,
interiorTypes.emplace_back();
visitBlockWithoutChildScope(resumeScope, block); visitBlockWithoutChildScope(resumeScope, block);
// Post
interiorTypes.pop_back();
fillInInferredBindings(resumeScope, block); fillInInferredBindings(resumeScope, block);
if (logger) if (logger)
@ -280,7 +295,7 @@ void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStat
d = follow(d); d = follow(d);
if (d == ty) if (d == ty)
continue; continue;
domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; domainTy = simplifyUnion(resumeScope, resumeScope->location, domainTy, d);
} }
LUAU_ASSERT(get<BlockedType>(ty)); LUAU_ASSERT(get<BlockedType>(ty));
@ -653,6 +668,7 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat
void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* block) void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* block)
{ {
std::unordered_map<Name, Location> aliasDefinitionLocations; std::unordered_map<Name, Location> aliasDefinitionLocations;
std::unordered_map<Name, Location> classDefinitionLocations;
// In order to enable mutually-recursive type aliases, we need to // In order to enable mutually-recursive type aliases, we need to
// populate the type bindings before we actually check any of the // populate the type bindings before we actually check any of the
@ -708,7 +724,7 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
continue; continue;
} }
if (scope->parent != globalScope) if (!FFlag::LuauUserTypeFunExportedAndLocal && scope->parent != globalScope)
{ {
reportError(function->location, GenericError{"Local user-defined functions are not supported yet"}); reportError(function->location, GenericError{"Local user-defined functions are not supported yet"});
continue; continue;
@ -737,19 +753,103 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function)) if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function))
reportError(function->location, GenericError{*error}); reportError(function->location, GenericError{*error});
TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ UserDefinedFunctionData udtfData;
NotNull{&builtinTypeFunctions().userFunc},
std::move(typeParams), if (FFlag::LuauUserTypeFunExportedAndLocal)
{}, {
function->name, udtfData.owner = module;
}); udtfData.definition = function;
}
TypeId typeFunctionTy = arena->addType(
TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, udtfData}
);
TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy};
// Set type bindings and definition locations for this user-defined type function // Set type bindings and definition locations for this user-defined type function
if (FFlag::LuauUserTypeFunExportedAndLocal && function->exported)
scope->exportedTypeBindings[function->name.value] = std::move(typeFunction);
else
scope->privateTypeBindings[function->name.value] = std::move(typeFunction); scope->privateTypeBindings[function->name.value] = std::move(typeFunction);
aliasDefinitionLocations[function->name.value] = function->location; aliasDefinitionLocations[function->name.value] = function->location;
} }
else if (auto classDeclaration = stat->as<AstStatDeclareClass>())
{
if (!FFlag::LuauNewSolverPrePopulateClasses)
continue;
if (scope->exportedTypeBindings.count(classDeclaration->name.value))
{
auto it = classDefinitionLocations.find(classDeclaration->name.value);
LUAU_ASSERT(it != classDefinitionLocations.end());
reportError(classDeclaration->location, DuplicateTypeDefinition{classDeclaration->name.value, it->second});
continue;
}
// A class might have no name if the code is syntactically
// illegal. We mustn't prepopulate anything in this case.
if (classDeclaration->name == kParseNameError)
continue;
ScopePtr defnScope = childScope(classDeclaration, scope);
TypeId initialType = arena->addType(BlockedType{});
TypeFun initialFun{initialType};
scope->exportedTypeBindings[classDeclaration->name.value] = std::move(initialFun);
classDefinitionLocations[classDeclaration->name.value] = classDeclaration->location;
}
}
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Additional pass for user-defined type functions to fill in their environments completely
for (AstStat* stat : block->body)
{
if (auto function = stat->as<AstStatTypeFunction>())
{
// Find the type function we have already created
TypeFunctionInstanceType* mainTypeFun = nullptr;
if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end())
mainTypeFun = getMutable<TypeFunctionInstanceType>(it->second.type);
if (!mainTypeFun)
{
if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end())
mainTypeFun = getMutable<TypeFunctionInstanceType>(it->second.type);
}
// Fill it with all visible type functions
if (mainTypeFun)
{
UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData;
for (Scope* curr = scope.get(); curr; curr = curr->parent.get())
{
for (auto& [name, tf] : curr->privateTypeBindings)
{
if (userFuncData.environment.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition;
}
for (auto& [name, tf] : curr->exportedTypeBindings)
{
if (userFuncData.environment.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition;
}
}
}
}
}
} }
} }
@ -871,12 +971,8 @@ ControlFlow ConstraintGenerator::visitBlockWithoutChildScope_DEPRECATED(const Sc
if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function)) if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function))
reportError(function->location, GenericError{*error}); reportError(function->location, GenericError{*error});
TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ TypeId typeFunctionTy =
NotNull{&builtinTypeFunctions().userFunc}, arena->addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, {}});
std::move(typeParams),
{},
function->name,
});
TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy};
@ -1645,6 +1741,11 @@ static bool isMetamethod(const Name& name)
ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass)
{ {
// If a class with the same name was already defined, we skip over
auto bindingIt = scope->exportedTypeBindings.find(declaredClass->name.value);
if (FFlag::LuauNewSolverPrePopulateClasses && bindingIt == scope->exportedTypeBindings.end())
return ControlFlow::None;
std::optional<TypeId> superTy = std::make_optional(builtinTypes->classType); std::optional<TypeId> superTy = std::make_optional(builtinTypes->classType);
if (declaredClass->superName) if (declaredClass->superName)
{ {
@ -1659,6 +1760,9 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas
// We don't have generic classes, so this assertion _should_ never be hit. // We don't have generic classes, so this assertion _should_ never be hit.
LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0);
if (FFlag::LuauNewSolverPrePopulateClasses)
superTy = follow(lookupType->type);
else
superTy = lookupType->type; superTy = lookupType->type;
if (!get<ClassType>(follow(*superTy))) if (!get<ClassType>(follow(*superTy)))
@ -1682,6 +1786,13 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas
ctv->metatable = metaTy; ctv->metatable = metaTy;
if (FFlag::LuauNewSolverPrePopulateClasses)
{
TypeId classBindTy = bindingIt->second.type;
emplaceType<BoundType>(asMutable(classBindTy), classTy);
}
else
scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; scope->exportedTypeBindings[className] = TypeFun{{}, classTy};
if (declaredClass->indexer) if (declaredClass->indexer)
@ -2763,7 +2874,7 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local
case ErrorSuppression::DoNotSuppress: case ErrorSuppression::DoNotSuppress:
break; break;
case ErrorSuppression::Suppress: case ErrorSuppression::Suppress:
ty = simplifyUnion(builtinTypes, arena, *ty, builtinTypes->errorType).result; ty = simplifyUnion(scope, local->location, *ty, builtinTypes->errorType);
break; break;
case ErrorSuppression::NormalizationFailed: case ErrorSuppression::NormalizationFailed:
reportError(local->local->annotation->location, NormalizationTooComplex{}); reportError(local->local->annotation->location, NormalizationTooComplex{});
@ -2844,6 +2955,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
ttv->state = TableState::Unsealed; ttv->state = TableState::Unsealed;
ttv->definitionModuleName = module->name; ttv->definitionModuleName = module->name;
if (FFlag::LuauNewSolverPopulateTableLocations)
{
ttv->definitionLocation = expr->location;
}
ttv->scope = scope.get(); ttv->scope = scope.get();
interiorTypes.back().push_back(ty); interiorTypes.back().push_back(ty);
@ -3301,7 +3416,16 @@ TypeId ConstraintGenerator::resolveTableType(const ScopePtr& scope, AstType* ty,
ice->ice("Unexpected property access " + std::to_string(int(astIndexer->access))); ice->ice("Unexpected property access " + std::to_string(int(astIndexer->access)));
} }
return arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); TypeId tableTy = arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed});
TableType* ttv = getMutable<TableType>(tableTy);
if (FFlag::LuauNewSolverPopulateTableLocations)
{
ttv->definitionModuleName = module->name;
ttv->definitionLocation = tab->location;
}
return tableTy;
} }
TypeId ConstraintGenerator::resolveFunctionType( TypeId ConstraintGenerator::resolveFunctionType(
@ -3616,6 +3740,32 @@ TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location locati
return resultType; return resultType;
} }
struct FragmentTypeCheckGlobalPrepopulator : AstVisitor
{
const NotNull<Scope> globalScope;
const NotNull<Scope> currentScope;
const NotNull<const DataFlowGraph> dfg;
FragmentTypeCheckGlobalPrepopulator(NotNull<Scope> globalScope, NotNull<Scope> currentScope, NotNull<const DataFlowGraph> dfg)
: globalScope(globalScope)
, currentScope(currentScope)
, dfg(dfg)
{
}
bool visit(AstExprGlobal* global) override
{
if (auto ty = globalScope->lookup(global->name))
{
DefId def = dfg->getDef(global);
// We only want to write into the current scope the type of the global
currentScope->lvalueTypes[def] = *ty;
}
return true;
}
};
struct GlobalPrepopulator : AstVisitor struct GlobalPrepopulator : AstVisitor
{ {
const NotNull<Scope> globalScope; const NotNull<Scope> globalScope;
@ -3662,6 +3812,14 @@ struct GlobalPrepopulator : AstVisitor
} }
}; };
void ConstraintGenerator::prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program)
{
FragmentTypeCheckGlobalPrepopulator gp{NotNull{globalScope.get()}, NotNull{resumeScope.get()}, dfg};
if (prepareModuleScope)
prepareModuleScope(module->name, resumeScope);
program->visit(&gp);
}
void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program)
{ {
GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg}; GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg};
@ -3813,6 +3971,24 @@ TypeId ConstraintGenerator::createTypeFunctionInstance(
return result; return result;
} }
TypeId ConstraintGenerator::simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(UnionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId tyFun : res->newTypeFunctions)
addConstraint(scope, location, ReduceConstraint{tyFun});
return res->result;
}
else
return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result;
}
std::vector<NotNull<Constraint>> borrowConstraints(const std::vector<ConstraintPtr>& constraints) std::vector<NotNull<Constraint>> borrowConstraints(const std::vector<ConstraintPtr>& constraints)
{ {
std::vector<NotNull<Constraint>> result; std::vector<NotNull<Constraint>> result;

View file

@ -33,6 +33,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings)
LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack) LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification)
LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations)
namespace Luau namespace Luau
{ {
@ -319,6 +321,7 @@ struct InstantiationQueuer : TypeOnceVisitor
ConstraintSolver::ConstraintSolver( ConstraintSolver::ConstraintSolver(
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope, NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints, std::vector<NotNull<Constraint>> constraints,
@ -332,6 +335,7 @@ ConstraintSolver::ConstraintSolver(
: arena(normalizer->arena) : arena(normalizer->arena)
, builtinTypes(normalizer->builtinTypes) , builtinTypes(normalizer->builtinTypes)
, normalizer(normalizer) , normalizer(normalizer)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, constraints(std::move(constraints)) , constraints(std::move(constraints))
, rootScope(rootScope) , rootScope(rootScope)
@ -1109,9 +1113,15 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
target = follow(instantiated); target = follow(instantiated);
} }
if (FFlag::LuauNewSolverPopulateTableLocations)
{
// This is a new type - redefine the location.
ttv->definitionLocation = constraint->location;
ttv->definitionModuleName = currentModuleName;
}
ttv->instantiatedTypeParams = typeArguments; ttv->instantiatedTypeParams = typeArguments;
ttv->instantiatedTypePackParams = packArguments; ttv->instantiatedTypePackParams = packArguments;
// TODO: Fill in definitionModuleName.
} }
bindResult(target); bindResult(target);
@ -1433,7 +1443,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
} }
} }
} }
else if (expr->is<AstExprConstantBool>() || expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantNil>()) else if (expr->is<AstExprConstantBool>() || expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() ||
expr->is<AstExprConstantNil>())
{ {
Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
u2.unify(actualArgTy, expectedArgTy); u2.unify(actualArgTy, expectedArgTy);
@ -1794,7 +1805,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
upperTable->props[c.propName] = rhsType; upperTable->props[c.propName] = rhsType;
// Food for thought: Could we block if simplification encounters a blocked type? // Food for thought: Could we block if simplification encounters a blocked type?
lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFreeUpperBound, newUpperBound).result; lhsFree->upperBound = simplifyIntersection(constraint->scope, constraint->location, lhsFreeUpperBound, newUpperBound);
bind(constraint, c.propType, rhsType); bind(constraint, c.propType, rhsType);
return true; return true;
@ -2008,7 +2019,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull<const
} }
} }
TypeId res = simplifyIntersection(builtinTypes, arena, std::move(parts)).result; TypeId res = simplifyIntersection(constraint->scope, constraint->location, std::move(parts));
unify(constraint, rhsType, res); unify(constraint, rhsType, res);
} }
@ -2326,12 +2337,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
return true; return true;
} }
bool ConstraintSolver::tryDispatchIterableFunction( bool ConstraintSolver::tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint)
TypeId nextTy,
TypeId tableTy,
const IterableConstraint& c,
NotNull<const Constraint> constraint
)
{ {
const FunctionType* nextFn = get<FunctionType>(nextTy); const FunctionType* nextFn = get<FunctionType>(nextTy);
// If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place.
@ -2593,9 +2599,9 @@ std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTa
// if we're in an lvalue context, we need the _common_ type here. // if we're in an lvalue context, we need the _common_ type here.
if (context == ValueContext::LValue) if (context == ValueContext::LValue)
return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)};
return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; return {{}, simplifyUnion(constraint->scope, constraint->location, one, two)};
} }
// if we're in an lvalue context, we need the _common_ type here. // if we're in an lvalue context, we need the _common_ type here.
else if (context == ValueContext::LValue) else if (context == ValueContext::LValue)
@ -2627,7 +2633,7 @@ std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTa
{ {
TypeId one = *begin(options); TypeId one = *begin(options);
TypeId two = *(++begin(options)); TypeId two = *(++begin(options));
return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)};
} }
else else
return {{}, arena->addType(IntersectionType{std::vector<TypeId>(begin(options), end(options))})}; return {{}, arena->addType(IntersectionType{std::vector<TypeId>(begin(options), end(options))})};
@ -3016,6 +3022,63 @@ bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty)
return false; return false;
} }
TypeId ConstraintSolver::simplifyIntersection(NotNull<Scope> scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(IntersectionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId ty : res->newTypeFunctions)
pushConstraint(scope, location, ReduceConstraint{ty});
return res->result;
}
else
return ::Luau::simplifyIntersection(builtinTypes, arena, left, right).result;
}
TypeId ConstraintSolver::simplifyIntersection(NotNull<Scope> scope, Location location, std::set<TypeId> parts)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(IntersectionType{std::vector(parts.begin(), parts.end())});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId ty : res->newTypeFunctions)
pushConstraint(scope, location, ReduceConstraint{ty});
return res->result;
}
else
return ::Luau::simplifyIntersection(builtinTypes, arena, std::move(parts)).result;
}
TypeId ConstraintSolver::simplifyUnion(NotNull<Scope> scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(UnionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId ty : res->newTypeFunctions)
pushConstraint(scope, location, ReduceConstraint{ty});
return res->result;
}
else
return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result;
}
TypeId ConstraintSolver::errorRecoveryType() const TypeId ConstraintSolver::errorRecoveryType() const
{ {
return builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType();

File diff suppressed because it is too large Load diff

View file

@ -4,6 +4,7 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
#include "Luau/ParseOptions.h" #include "Luau/ParseOptions.h"
#include "Luau/Module.h" #include "Luau/Module.h"
@ -18,11 +19,14 @@
#include "Luau/ParseOptions.h" #include "Luau/ParseOptions.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "AutocompleteCore.h"
LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAllowFragmentParsing);
LUAU_FASTFLAG(LuauStoreDFGOnModule2); LUAU_FASTFLAG(LuauStoreDFGOnModule2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
namespace namespace
{ {
@ -41,7 +45,6 @@ void copyModuleMap(Luau::DenseHashMap<K, V>& result, const Luau::DenseHashMap<K,
} // namespace } // namespace
namespace Luau namespace Luau
{ {
@ -88,14 +91,22 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* ro
return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)}; return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)};
} }
std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) /**
* Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that
* document and attempts to get the concrete text between those points. It returns a pair of:
* - start offset that represents an index in the source `char*` corresponding to startPos
* - length, that represents how many more bytes to read to get to endPos.
* Example - your document is "foo bar baz" and getDocumentOffsets is passed (1, 4) - (1, 8). This function returns the pair {3, 7},
* which corresponds to the string " bar "
*/
std::pair<size_t, size_t> getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos)
{ {
unsigned int lineCount = 0; size_t lineCount = 0;
unsigned int colCount = 0; size_t colCount = 0;
unsigned int docOffset = 0; size_t docOffset = 0;
unsigned int startOffset = 0; size_t startOffset = 0;
unsigned int endOffset = 0; size_t endOffset = 0;
bool foundStart = false; bool foundStart = false;
bool foundEnd = false; bool foundEnd = false;
for (char c : src) for (char c : src)
@ -115,6 +126,13 @@ std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view&
foundEnd = true; foundEnd = true;
} }
// We put a cursor position that extends beyond the extents of the current line
if (foundStart && !foundEnd && (lineCount > endPos.line))
{
foundEnd = true;
endOffset = docOffset - 1;
}
if (c == '\n') if (c == '\n')
{ {
lineCount++; lineCount++;
@ -125,20 +143,24 @@ std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view&
docOffset++; docOffset++;
} }
if (foundStart && !foundEnd)
endOffset = src.length();
unsigned int min = std::min(startOffset, endOffset); size_t min = std::min(startOffset, endOffset);
unsigned int len = std::max(startOffset, endOffset) - min; size_t len = std::max(startOffset, endOffset) - min;
return {min, len}; return {min, len};
} }
ScopePtr findClosestScope(const ModulePtr& module, const Position& cursorPos) ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement)
{ {
LUAU_ASSERT(module->hasModuleScope()); LUAU_ASSERT(module->hasModuleScope());
ScopePtr closest = module->getModuleScope(); ScopePtr closest = module->getModuleScope();
// find the scope the nearest statement belonged to.
for (auto [loc, sc] : module->scopes) for (auto [loc, sc] : module->scopes)
{ {
if (loc.begin <= cursorPos && closest->location.begin <= loc.begin) if (loc.encloses(nearestStatement->location) && closest->location.begin <= loc.begin)
closest = sc; closest = sc;
} }
@ -152,13 +174,27 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
opts.allowDeclarationSyntax = false; opts.allowDeclarationSyntax = false;
opts.captureComments = false; opts.captureComments = false;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)}; opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)};
AstStat* enclosingStatement = result.nearestStatement; AstStat* nearestStatement = result.nearestStatement;
const Position& endPos = cursorPos; const Location& rootSpan = srcModule.root->location;
// If the statement starts on a previous line, grab the statement beginning // Did we append vs did we insert inline
// otherwise, grab the statement end to whatever is being typed right now bool appended = cursorPos >= rootSpan.end;
const Position& startPos = // statement spans multiple lines
enclosingStatement->location.begin.line == cursorPos.line ? enclosingStatement->location.begin : enclosingStatement->location.end; bool multiline = nearestStatement->location.begin.line != nearestStatement->location.end.line;
const Position endPos = cursorPos;
// We start by re-parsing everything (we'll refine this as we go)
Position startPos = srcModule.root->location.begin;
// If we added to the end of the sourceModule, use the end of the nearest location
if (appended && multiline)
startPos = nearestStatement->location.end;
// Statement spans one line && cursorPos is on a different line
else if (!multiline && cursorPos.line != nearestStatement->location.end.line)
startPos = nearestStatement->location.end;
else
startPos = nearestStatement->location.begin;
auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos); auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos);
@ -173,10 +209,11 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry); std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end); std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end);
fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end()); fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end());
if (enclosingStatement == nullptr) if (nearestStatement == nullptr)
enclosingStatement = p.root; nearestStatement = p.root;
fragmentResult.root = std::move(p.root); fragmentResult.root = std::move(p.root);
fragmentResult.ancestry = std::move(fabricatedAncestry); fragmentResult.ancestry = std::move(fabricatedAncestry);
fragmentResult.nearestStatement = nearestStatement;
return fragmentResult; return fragmentResult;
} }
@ -205,7 +242,7 @@ ModulePtr copyModule(const ModulePtr& result, std::unique_ptr<Allocator> alloc)
return incrementalModule; return incrementalModule;
} }
FragmentTypeCheckResult typeCheckFragmentHelper( FragmentTypeCheckResult typecheckFragment_(
Frontend& frontend, Frontend& frontend,
AstStatBlock* root, AstStatBlock* root,
const ModulePtr& stale, const ModulePtr& stale,
@ -245,15 +282,18 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
/// Create a DataFlowGraph just for the surrounding context /// Create a DataFlowGraph just for the surrounding context
auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler); auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler);
SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes);
/// Contraint Generator /// Contraint Generator
ConstraintGenerator cg{ ConstraintGenerator cg{
incrementalModule, incrementalModule,
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
NotNull{&frontend.moduleResolver}, NotNull{&frontend.moduleResolver},
frontend.builtinTypes, frontend.builtinTypes,
iceHandler, iceHandler,
frontend.globals.globalScope, stale->getModuleScope(),
nullptr, nullptr,
nullptr, nullptr,
NotNull{&updatedDfg}, NotNull{&updatedDfg},
@ -262,7 +302,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
cg.rootScope = stale->getModuleScope().get(); cg.rootScope = stale->getModuleScope().get();
// Any additions to the scope must occur in a fresh scope // Any additions to the scope must occur in a fresh scope
auto freshChildOfNearestScope = std::make_shared<Scope>(closestScope); auto freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.push_back({root->location, freshChildOfNearestScope}); incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
// closest Scope -> children = { ...., freshChildOfNearestScope} // closest Scope -> children = { ...., freshChildOfNearestScope}
// We need to trim nearestChild from the scope hierarcy // We need to trim nearestChild from the scope hierarcy
@ -274,9 +314,11 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
LUAU_ASSERT(back == freshChildOfNearestScope.get()); LUAU_ASSERT(back == freshChildOfNearestScope.get());
closestScope->children.pop_back(); closestScope->children.pop_back();
/// Initialize the constraint solver and run it /// Initialize the constraint solver and run it
ConstraintSolver cs{ ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope), NotNull(cg.rootScope),
borrowConstraints(cg.constraints), borrowConstraints(cg.constraints),
@ -307,7 +349,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
freeze(incrementalModule->internalTypes); freeze(incrementalModule->internalTypes);
freeze(incrementalModule->interfaceTypes); freeze(incrementalModule->interfaceTypes);
return {std::move(incrementalModule), freshChildOfNearestScope.get()}; return {std::move(incrementalModule), std::move(freshChildOfNearestScope)};
} }
@ -327,27 +369,51 @@ FragmentTypeCheckResult typecheckFragment(
} }
ModulePtr module = frontend.moduleResolver.getModule(moduleName); ModulePtr module = frontend.moduleResolver.getModule(moduleName);
const ScopePtr& closestScope = findClosestScope(module, cursorPos); FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos);
FragmentParseResult r = parseFragment(*sourceModule, src, cursorPos);
FrontendOptions frontendOptions = opts.value_or(frontend.options); FrontendOptions frontendOptions = opts.value_or(frontend.options);
return typeCheckFragmentHelper(frontend, r.root, module, closestScope, cursorPos, std::move(r.alloc), frontendOptions); const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement);
FragmentTypeCheckResult result =
typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions);
result.ancestry = std::move(parseResult.ancestry);
return result;
} }
AutocompleteResult fragmentAutocomplete(
FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend, Frontend& frontend,
std::string_view src, std::string_view src,
const ModuleName& moduleName, const ModuleName& moduleName,
Position& cursorPosition, Position cursorPosition,
const FrontendOptions& opts, std::optional<FrontendOptions> opts,
StringCompletionCallback callback StringCompletionCallback callback
) )
{ {
LUAU_ASSERT(FFlag::LuauSolverV2); LUAU_ASSERT(FFlag::LuauSolverV2);
LUAU_ASSERT(FFlag::LuauAllowFragmentParsing); LUAU_ASSERT(FFlag::LuauAllowFragmentParsing);
LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2); LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2);
LUAU_ASSERT(FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
{
LUAU_ASSERT(!"Expected Source Module for fragment typecheck");
return {}; return {};
}
auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src);
TypeArena arenaForFragmentAutocomplete;
auto result = Luau::autocomplete_(
tcResult.incrementalModule,
frontend.builtinTypes,
&arenaForFragmentAutocomplete,
tcResult.ancestry,
frontend.globals.globalScope.get(),
tcResult.freshScope,
cursorPosition,
frontend.fileResolver,
callback
);
return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)};
} }
} // namespace Luau } // namespace Luau

View file

@ -10,6 +10,7 @@
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/FileResolver.h" #include "Luau/FileResolver.h"
#include "Luau/NonStrictTypeChecker.h" #include "Luau/NonStrictTypeChecker.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
@ -46,7 +47,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode)
LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode) LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionNoEvaluation)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection)
LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2) LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2)
@ -287,8 +287,7 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector<HotCom
std::vector<RequireCycle> getRequireCycles( std::vector<RequireCycle> getRequireCycles(
const FileResolver* resolver, const FileResolver* resolver,
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes,
const SourceNode* start, const SourceNode* start
bool stopAtFirst = false
) )
{ {
std::vector<RequireCycle> result; std::vector<RequireCycle> result;
@ -358,9 +357,6 @@ std::vector<RequireCycle> getRequireCycles(
{ {
result.push_back({depLocation, std::move(cycle)}); result.push_back({depLocation, std::move(cycle)});
if (stopAtFirst)
return result;
// note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start // note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start
// so it's safe to *only* clear seen vector when we find a cycle // so it's safe to *only* clear seen vector when we find a cycle
// if we don't do it, we will not have correct reporting for some cycles // if we don't do it, we will not have correct reporting for some cycles
@ -884,18 +880,11 @@ void Frontend::addBuildQueueItems(
data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete);
data.recordJsonLog = FFlag::DebugLuauLogSolverToJson; data.recordJsonLog = FFlag::DebugLuauLogSolverToJson;
const Mode mode = sourceModule->mode.value_or(data.config.mode);
// in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely // all correct programs must be acyclic so this code triggers rarely
if (cycleDetected) if (cycleDetected)
{ data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get());
if (FFlag::LuauMoreThoroughCycleDetection)
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), false);
else
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck);
}
data.options = frontendOptions; data.options = frontendOptions;
@ -1334,6 +1323,7 @@ ModulePtr check(
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes);
TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}}; TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}};
if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation)
@ -1342,6 +1332,7 @@ ModulePtr check(
ConstraintGenerator cg{ ConstraintGenerator cg{
result, result,
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
moduleResolver, moduleResolver,
builtinTypes, builtinTypes,
@ -1358,6 +1349,7 @@ ModulePtr check(
ConstraintSolver cs{ ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope), NotNull(cg.rootScope),
borrowConstraints(cg.constraints), borrowConstraints(cg.constraints),

View file

@ -132,7 +132,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
return dest.addType(NegationType{a.ty}); return dest.addType(NegationType{a.ty});
else if constexpr (std::is_same_v<T, TypeFunctionInstanceType>) else if constexpr (std::is_same_v<T, TypeFunctionInstanceType>)
{ {
TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName}; TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncData};
return dest.addType(std::move(clone)); return dest.addType(std::move(clone));
} }
else else

View file

@ -4,6 +4,7 @@
#include "Luau/Common.h" #include "Luau/Common.h"
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauSymbolEquality)
namespace Luau namespace Luau
{ {
@ -14,7 +15,7 @@ bool Symbol::operator==(const Symbol& rhs) const
return local == rhs.local; return local == rhs.local;
else if (global.value) else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
else if (FFlag::LuauSolverV2) else if (FFlag::LuauSolverV2 || FFlag::LuauSymbolEquality)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else else
return false; return false;

View file

@ -870,6 +870,8 @@ struct TypeStringifier
return; return;
} }
LUAU_ASSERT(uv.options.size() > 1);
bool optional = false; bool optional = false;
bool hasNonNilDisjunct = false; bool hasNonNilDisjunct = false;
@ -878,7 +880,7 @@ struct TypeStringifier
{ {
el = follow(el); el = follow(el);
if (isNil(el)) if (state.opts.useQuestionMarks && isNil(el))
{ {
optional = true; optional = true;
continue; continue;

View file

@ -51,6 +51,7 @@ LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation)
LUAU_FASTFLAG(LuauUserTypeFunFixRegister) LUAU_FASTFLAG(LuauUserTypeFunFixRegister)
LUAU_FASTFLAG(LuauRemoveNotAnyHack) LUAU_FASTFLAG(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState)
LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
@ -610,11 +611,30 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
NotNull<TypeFunctionContext> ctx NotNull<TypeFunctionContext> ctx
) )
{ {
auto typeFunction = getMutable<TypeFunctionInstanceType>(instance);
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
if (typeFunction->userFuncData.owner.expired())
{
ctx->ice->ice("user-defined type function module has expired");
return {std::nullopt, true, {}, {}};
}
if (!typeFunction->userFuncName || !typeFunction->userFuncData.definition)
{
ctx->ice->ice("all user-defined type functions must have an associated function definition");
return {std::nullopt, true, {}, {}};
}
}
else
{
if (!ctx->userFuncName) if (!ctx->userFuncName)
{ {
ctx->ice->ice("all user-defined type functions must have an associated function definition"); ctx->ice->ice("all user-defined type functions must have an associated function definition");
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
} }
}
if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation)
{ {
@ -632,7 +652,22 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
return {std::nullopt, false, {ty}, {}}; return {std::nullopt, false, {ty}, {}};
} }
AstName name = *ctx->userFuncName; if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Ensure that whole type function environment is registered
for (auto& [name, definition] : typeFunction->userFuncData.environment)
{
if (std::optional<std::string> error = ctx->typeFunctionRuntime->registerFunction(definition))
{
// Failure to register at this point means that original definition had to error out and should not have been present in the
// environment
ctx->ice->ice("user-defined type function reference cannot be registered");
return {std::nullopt, true, {}, {}};
}
}
}
AstName name = FFlag::LuauUserTypeFunExportedAndLocal ? typeFunction->userFuncData.definition->name : *ctx->userFuncName;
lua_State* global = ctx->typeFunctionRuntime->state.get(); lua_State* global = ctx->typeFunctionRuntime->state.get();
@ -643,8 +678,44 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
lua_State* L = lua_newthread(global); lua_State* L = lua_newthread(global);
LuauTempThreadPopper popper(global); LuauTempThreadPopper popper(global);
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Fetch the function we want to evaluate
lua_pushlightuserdata(L, typeFunction->userFuncData.definition);
lua_gettable(L, LUA_REGISTRYINDEX);
if (!lua_isfunction(L, -1))
{
ctx->ice->ice("user-defined type function reference cannot be found in the registry");
return {std::nullopt, true, {}, {}};
}
// Build up the environment
lua_getfenv(L, -1);
lua_setreadonly(L, -1, false);
for (auto& [name, definition] : typeFunction->userFuncData.environment)
{
lua_pushlightuserdata(L, definition);
lua_gettable(L, LUA_REGISTRYINDEX);
if (!lua_isfunction(L, -1))
{
ctx->ice->ice("user-defined type function reference cannot be found in the registry");
return {std::nullopt, true, {}, {}};
}
lua_setfield(L, -2, name.c_str());
}
lua_setreadonly(L, -1, true);
lua_pop(L, 1);
}
else
{
lua_getglobal(global, name.value); lua_getglobal(global, name.value);
lua_xmove(global, L, 1); lua_xmove(global, L, 1);
}
if (FFlag::LuauUserDefinedTypeFunctionResetState) if (FFlag::LuauUserDefinedTypeFunctionResetState)
resetTypeFunctionState(L); resetTypeFunctionState(L);
@ -693,7 +764,7 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get()); TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get());
// At least 1 error occured while deserializing // At least 1 error occurred while deserializing
if (runtimeBuilder->errors.size() > 0) if (runtimeBuilder->errors.size() > 0)
return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()};
@ -935,6 +1006,23 @@ std::optional<std::string> TypeFunctionRuntime::registerFunction(AstStatTypeFunc
prepareState(); prepareState();
lua_State* global = state.get();
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Fetch to check if function is already registered
lua_pushlightuserdata(global, function);
lua_gettable(global, LUA_REGISTRYINDEX);
if (!lua_isnil(global, -1))
{
lua_pop(global, 1);
return std::nullopt;
}
lua_pop(global, 1);
}
AstName name = function->name; AstName name = function->name;
// Construct ParseResult containing the type function // Construct ParseResult containing the type function
@ -961,7 +1049,6 @@ std::optional<std::string> TypeFunctionRuntime::registerFunction(AstStatTypeFunc
std::string bytecode = builder.getBytecode(); std::string bytecode = builder.getBytecode();
lua_State* global = state.get();
// Separate sandboxed thread for individual execution and private globals // Separate sandboxed thread for individual execution and private globals
lua_State* L = lua_newthread(global); lua_State* L = lua_newthread(global);
@ -989,9 +1076,19 @@ std::optional<std::string> TypeFunctionRuntime::registerFunction(AstStatTypeFunc
return format("Could not find '%s' type function in the global scope", name.value); return format("Could not find '%s' type function in the global scope", name.value);
} }
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Store resulting function in the registry
lua_pushlightuserdata(global, function);
lua_xmove(L, global, 1);
lua_settable(global, LUA_REGISTRYINDEX);
}
else
{
// Store resulting function in the global environment // Store resulting function in the global environment
lua_xmove(L, global, 1); lua_xmove(L, global, 1);
lua_setglobal(global, name.value); lua_setglobal(global, name.value);
}
return std::nullopt; return std::nullopt;
} }

View file

@ -0,0 +1,48 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Location.h"
#include "Luau/DenseHash.h"
#include "Luau/Common.h"
#include <vector>
namespace Luau
{
class Allocator
{
public:
Allocator();
Allocator(Allocator&&);
Allocator& operator=(Allocator&&) = delete;
~Allocator();
void* allocate(size_t size);
template<typename T, typename... Args>
T* alloc(Args&&... args)
{
static_assert(std::is_trivially_destructible<T>::value, "Objects allocated with this allocator will never have their destructors run!");
T* t = static_cast<T*>(allocate(sizeof(T)));
new (t) T(std::forward<Args>(args)...);
return t;
}
private:
struct Page
{
Page* next;
char data[8192];
};
Page* root;
size_t offset;
};
}

View file

@ -316,16 +316,18 @@ public:
enum QuoteStyle enum QuoteStyle
{ {
Quoted, QuotedSimple,
QuotedRaw,
Unquoted Unquoted
}; };
AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle = Quoted); AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
bool isQuoted() const;
AstArray<char> value; AstArray<char> value;
QuoteStyle quoteStyle = Quoted; QuoteStyle quoteStyle;
}; };
class AstExprLocal : public AstExpr class AstExprLocal : public AstExpr
@ -876,13 +878,14 @@ class AstStatTypeFunction : public AstStat
public: public:
LUAU_RTTI(AstStatTypeFunction); LUAU_RTTI(AstStatTypeFunction);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body); AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body, bool exported);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
AstName name; AstName name;
Location nameLocation; Location nameLocation;
AstExprFunction* body; AstExprFunction* body;
bool exported;
}; };
class AstStatDeclareGlobal : public AstStat class AstStatDeclareGlobal : public AstStat

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Allocator.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
@ -11,40 +12,6 @@
namespace Luau namespace Luau
{ {
class Allocator
{
public:
Allocator();
Allocator(Allocator&&);
Allocator& operator=(Allocator&&) = delete;
~Allocator();
void* allocate(size_t size);
template<typename T, typename... Args>
T* alloc(Args&&... args)
{
static_assert(std::is_trivially_destructible<T>::value, "Objects allocated with this allocator will never have their destructors run!");
T* t = static_cast<T*>(allocate(sizeof(T)));
new (t) T(std::forward<Args>(args)...);
return t;
}
private:
struct Page
{
Page* next;
char data[8192];
};
Page* root;
size_t offset;
};
struct Lexeme struct Lexeme
{ {
enum Type enum Type

66
Ast/src/Allocator.cpp Normal file
View file

@ -0,0 +1,66 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Allocator.h"
namespace Luau
{
Allocator::Allocator()
: root(static_cast<Page*>(operator new(sizeof(Page))))
, offset(0)
{
root->next = nullptr;
}
Allocator::Allocator(Allocator&& rhs)
: root(rhs.root)
, offset(rhs.offset)
{
rhs.root = nullptr;
rhs.offset = 0;
}
Allocator::~Allocator()
{
Page* page = root;
while (page)
{
Page* next = page->next;
operator delete(page);
page = next;
}
}
void* Allocator::allocate(size_t size)
{
constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double);
if (root)
{
uintptr_t data = reinterpret_cast<uintptr_t>(root->data);
uintptr_t result = (data + offset + align - 1) & ~(align - 1);
if (result + size <= data + sizeof(root->data))
{
offset = result - data + size;
return reinterpret_cast<void*>(result);
}
}
// allocate new page
size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data);
void* pageData = operator new(offsetof(Page, data) + pageSize);
Page* page = static_cast<Page*>(pageData);
page->next = root;
root = page;
offset = size;
return page->data;
}
}

View file

@ -92,6 +92,11 @@ void AstExprConstantString::visit(AstVisitor* visitor)
visitor->visit(this); visitor->visit(this);
} }
bool AstExprConstantString::isQuoted() const
{
return quoteStyle == QuoteStyle::QuotedSimple || quoteStyle == QuoteStyle::QuotedRaw;
}
AstExprLocal::AstExprLocal(const Location& location, AstLocal* local, bool upvalue) AstExprLocal::AstExprLocal(const Location& location, AstLocal* local, bool upvalue)
: AstExpr(ClassIndex(), location) : AstExpr(ClassIndex(), location)
, local(local) , local(local)
@ -760,11 +765,18 @@ void AstStatTypeAlias::visit(AstVisitor* visitor)
} }
} }
AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body) AstStatTypeFunction::AstStatTypeFunction(
const Location& location,
const AstName& name,
const Location& nameLocation,
AstExprFunction* body,
bool exported
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, name(name) , name(name)
, nameLocation(nameLocation) , nameLocation(nameLocation)
, body(body) , body(body)
, exported(exported)
{ {
} }

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Lexer.h" #include "Luau/Lexer.h"
#include "Luau/Allocator.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Confusables.h" #include "Luau/Confusables.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
@ -10,64 +11,6 @@
namespace Luau namespace Luau
{ {
Allocator::Allocator()
: root(static_cast<Page*>(operator new(sizeof(Page))))
, offset(0)
{
root->next = nullptr;
}
Allocator::Allocator(Allocator&& rhs)
: root(rhs.root)
, offset(rhs.offset)
{
rhs.root = nullptr;
rhs.offset = 0;
}
Allocator::~Allocator()
{
Page* page = root;
while (page)
{
Page* next = page->next;
operator delete(page);
page = next;
}
}
void* Allocator::allocate(size_t size)
{
constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double);
if (root)
{
uintptr_t data = reinterpret_cast<uintptr_t>(root->data);
uintptr_t result = (data + offset + align - 1) & ~(align - 1);
if (result + size <= data + sizeof(root->data))
{
offset = result - data + size;
return reinterpret_cast<void*>(result);
}
}
// allocate new page
size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data);
void* pageData = operator new(offsetof(Page, data) + pageSize);
Page* page = static_cast<Page*>(pageData);
page->next = root;
root = page;
offset = size;
return page->data;
}
Lexeme::Lexeme(const Location& location, Type type) Lexeme::Lexeme(const Location& location, Type type)
: type(type) : type(type)
, location(location) , location(location)

View file

@ -21,6 +21,7 @@ LUAU_FASTFLAGVARIABLE(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauNativeAttribute) LUAU_FASTFLAGVARIABLE(LuauNativeAttribute)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr) LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport)
LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing) LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing)
LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck) LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck)
@ -943,8 +944,11 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported)
Lexeme matchFn = lexer.current(); Lexeme matchFn = lexer.current();
nextLexeme(); nextLexeme();
if (!FFlag::LuauUserDefinedTypeFunParseExport)
{
if (exported) if (exported)
report(start, "Type function cannot be exported"); report(start, "Type function cannot be exported");
}
// parse the name of the type function // parse the name of the type function
std::optional<Name> fnName = parseNameOpt("type function name"); std::optional<Name> fnName = parseNameOpt("type function name");
@ -962,7 +966,7 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported)
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
return allocator.alloc<AstStatTypeFunction>(Location(start, body->location), fnName->name, fnName->location, body); return allocator.alloc<AstStatTypeFunction>(Location(start, body->location), fnName->name, fnName->location, body, exported);
} }
AstDeclaredClassProp Parser::parseDeclaredClassMethod() AstDeclaredClassProp Parser::parseDeclaredClassMethod()
@ -3012,8 +3016,23 @@ std::optional<AstArray<char>> Parser::parseCharArray()
AstExpr* Parser::parseString() AstExpr* Parser::parseString()
{ {
Location location = lexer.current().location; Location location = lexer.current().location;
AstExprConstantString::QuoteStyle style;
switch (lexer.current().type)
{
case Lexeme::QuotedString:
case Lexeme::InterpStringSimple:
style = AstExprConstantString::QuotedSimple;
break;
case Lexeme::RawString:
style = AstExprConstantString::QuotedRaw;
break;
default:
LUAU_ASSERT(false && "Invalid string type");
}
if (std::optional<AstArray<char>> value = parseCharArray()) if (std::optional<AstArray<char>> value = parseCharArray())
return allocator.alloc<AstExprConstantString>(location, *value); return allocator.alloc<AstExprConstantString>(location, *value, style);
else else
return reportExprError(location, {}, "String literal contains malformed escape sequence"); return reportExprError(location, {}, "String literal contains malformed escape sequence");
} }

View file

@ -1,4 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Config.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
@ -224,7 +225,14 @@ struct CliConfigResolver : Luau::ConfigResolver
if (std::optional<std::string> contents = readFile(configPath)) if (std::optional<std::string> contents = readFile(configPath))
{ {
std::optional<std::string> error = Luau::parseConfig(*contents, result); Luau::ConfigOptions::AliasOptions aliasOpts;
aliasOpts.configLocation = configPath;
aliasOpts.overwriteAliases = true;
Luau::ConfigOptions opts;
opts.aliasOptions = std::move(aliasOpts);
std::optional<std::string> error = Luau::parseConfig(*contents, result, opts);
if (error) if (error)
configErrors.push_back({configPath, *error}); configErrors.push_back({configPath, *error});
} }

View file

@ -181,6 +181,16 @@ std::string resolvePath(std::string_view path, std::string_view baseFilePath)
return resolvedPath; return resolvedPath;
} }
bool hasFileExtension(std::string_view name, const std::vector<std::string>& extensions)
{
for (const std::string& extension : extensions)
{
if (name.size() >= extension.size() && name.substr(name.size() - extension.size()) == extension)
return true;
}
return false;
}
std::optional<std::string> readFile(const std::string& name) std::optional<std::string> readFile(const std::string& name)
{ {
#ifdef _WIN32 #ifdef _WIN32

View file

@ -15,6 +15,8 @@ std::string resolvePath(std::string_view relativePath, std::string_view baseFile
std::optional<std::string> readFile(const std::string& name); std::optional<std::string> readFile(const std::string& name);
std::optional<std::string> readStdin(); std::optional<std::string> readStdin();
bool hasFileExtension(std::string_view name, const std::vector<std::string>& extensions);
bool isAbsolutePath(std::string_view path); bool isAbsolutePath(std::string_view path);
bool isFile(const std::string& path); bool isFile(const std::string& path);
bool isDirectory(const std::string& path); bool isDirectory(const std::string& path);

View file

@ -3,6 +3,7 @@
#include "FileUtils.h" #include "FileUtils.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Config.h"
#include <algorithm> #include <algorithm>
#include <array> #include <array>
@ -83,6 +84,9 @@ RequireResolver::ModuleStatus RequireResolver::findModuleImpl()
absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix
} }
if (hasFileExtension(absolutePath, {".luau", ".lua"}) && isFile(absolutePath))
luaL_argerrorL(L, 1, "error requiring module: consider removing the file extension");
return ModuleStatus::NotFound; return ModuleStatus::NotFound;
} }
@ -235,14 +239,15 @@ std::optional<std::string> RequireResolver::getAlias(std::string alias)
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
} }
); );
while (!config.aliases.count(alias) && !isConfigFullyResolved) while (!config.aliases.contains(alias) && !isConfigFullyResolved)
{ {
parseNextConfig(); parseNextConfig();
} }
if (!config.aliases.count(alias) && isConfigFullyResolved) if (!config.aliases.contains(alias) && isConfigFullyResolved)
return std::nullopt; // could not find alias return std::nullopt; // could not find alias
return resolvePath(config.aliases[alias], joinPaths(lastSearchedDir, Luau::kConfigName)); const Luau::Config::AliasInfo& aliasInfo = config.aliases[alias];
return resolvePath(aliasInfo.value, aliasInfo.configLocation);
} }
void RequireResolver::parseNextConfig() void RequireResolver::parseNextConfig()
@ -275,9 +280,16 @@ void RequireResolver::parseConfigInDirectory(const std::string& directory)
{ {
std::string configPath = joinPaths(directory, Luau::kConfigName); std::string configPath = joinPaths(directory, Luau::kConfigName);
Luau::ConfigOptions::AliasOptions aliasOpts;
aliasOpts.configLocation = configPath;
aliasOpts.overwriteAliases = false;
Luau::ConfigOptions opts;
opts.aliasOptions = std::move(aliasOpts);
if (std::optional<std::string> contents = readFile(configPath)) if (std::optional<std::string> contents = readFile(configPath))
{ {
std::optional<std::string> error = Luau::parseConfig(*contents, config); std::optional<std::string> error = Luau::parseConfig(*contents, config, opts);
if (error) if (error)
luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str()); luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str());
} }

View file

@ -8,7 +8,7 @@ Some questions help improve the language, implementation or documentation by ins
## Documentation ## Documentation
A [separate site repository](https://github.com/luau-lang/site) hosts the language documentation, which is accessible on https://luau-lang.org. A [separate site repository](https://github.com/luau-lang/site) hosts the language documentation, which is accessible on https://luau.org.
Changes to this documentation that improve clarity, fix grammatical issues, explain aspects that haven't been explained before and the like are warmly welcomed. Changes to this documentation that improve clarity, fix grammatical issues, explain aspects that haven't been explained before and the like are warmly welcomed.
Please feel free to [create a pull request](https://help.github.com/articles/about-pull-requests/) to improve our documentation. Note that at this point the documentation is English-only. Please feel free to [create a pull request](https://help.github.com/articles/about-pull-requests/) to improve our documentation. Note that at this point the documentation is English-only.

View file

@ -138,6 +138,7 @@ public:
void fneg(RegisterA64 dst, RegisterA64 src); void fneg(RegisterA64 dst, RegisterA64 src);
void fsqrt(RegisterA64 dst, RegisterA64 src); void fsqrt(RegisterA64 dst, RegisterA64 src);
void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void faddp(RegisterA64 dst, RegisterA64 src);
// Vector component manipulation // Vector component manipulation
void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index);

View file

@ -167,6 +167,8 @@ public:
void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle); void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle);
void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset); void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset);
void vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask);
// Run final checks // Run final checks
bool finalize(); bool finalize();

View file

@ -194,6 +194,10 @@ enum class IrCmd : uint8_t
// A: TValue // A: TValue
UNM_VEC, UNM_VEC,
// Compute dot product between two vectors
// A, B: TValue
DOT_VEC,
// Compute Luau 'not' operation on destructured TValue // Compute Luau 'not' operation on destructured TValue
// A: tag // A: tag
// B: int (value) // B: int (value)

View file

@ -176,6 +176,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::SUB_VEC: case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC: case IrCmd::MUL_VEC:
case IrCmd::DIV_VEC: case IrCmd::DIV_VEC:
case IrCmd::DOT_VEC:
case IrCmd::UNM_VEC: case IrCmd::UNM_VEC:
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
case IrCmd::CMP_ANY: case IrCmd::CMP_ANY:

View file

@ -586,6 +586,14 @@ void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src)
placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000); placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000);
} }
void AssemblyBuilderA64::faddp(RegisterA64 dst, RegisterA64 src)
{
CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::s);
CODEGEN_ASSERT(dst.kind == src.kind);
placeR1("faddp", dst, src, 0b011'11110'0'0'11000'01101'10 | ((dst.kind == KindA64::d) << 12));
}
void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{ {
if (dst.kind == KindA64::d) if (dst.kind == KindA64::d)

View file

@ -946,6 +946,11 @@ void AssemblyBuilderX64::vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 s
placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66); placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66);
} }
void AssemblyBuilderX64::vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask)
{
placeAvx("vdpps", dst, src1, src2, mask, 0x40, false, AVX_0F3A, AVX_66);
}
bool AssemblyBuilderX64::finalize() bool AssemblyBuilderX64::finalize()
{ {
code.resize(codePos - code.data()); code.resize(codePos - code.data());

View file

@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd)
return "DIV_VEC"; return "DIV_VEC";
case IrCmd::UNM_VEC: case IrCmd::UNM_VEC:
return "UNM_VEC"; return "UNM_VEC";
case IrCmd::DOT_VEC:
return "DOT_VEC";
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
return "NOT_ANY"; return "NOT_ANY";
case IrCmd::CMP_ANY: case IrCmd::CMP_ANY:

View file

@ -728,6 +728,21 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.fneg(inst.regA64, regOp(inst.a)); build.fneg(inst.regA64, regOp(inst.a));
break; break;
} }
case IrCmd::DOT_VEC:
{
inst.regA64 = regs.allocReg(KindA64::d, index);
RegisterA64 temp = regs.allocTemp(KindA64::q);
RegisterA64 temps = castReg(KindA64::s, temp);
RegisterA64 regs = castReg(KindA64::s, inst.regA64);
build.fmul(temp, regOp(inst.a), regOp(inst.b));
build.faddp(regs, temps); // x+y
build.dup_4s(temp, temp, 2);
build.fadd(regs, regs, temps); // +z
build.fcvt(inst.regA64, regs);
break;
}
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
{ {
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});

View file

@ -675,6 +675,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0)); build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0));
break; break;
} }
case IrCmd::DOT_VEC:
{
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
ScopedRegX64 tmp1{regs};
ScopedRegX64 tmp2{regs};
RegisterX64 tmpa = vecOp(inst.a, tmp1);
RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2);
build.vdpps(inst.regX64, tmpa, tmpb, 0x71); // 7 = 0b0111, sum first 3 products into first float
build.vcvtss2sd(inst.regX64, inst.regX64, inst.regX64);
break;
}
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
{ {
// TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target // TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target

View file

@ -14,6 +14,7 @@ static const int kMinMaxUnrolledParams = 5;
static const int kBit32BinaryOpUnrolledParams = 5; static const int kBit32BinaryOpUnrolledParams = 5;
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen); LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
namespace Luau namespace Luau
{ {
@ -907,6 +908,16 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
IrOp sum;
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
sum = build.inst(IrCmd::DOT_VEC, a, a);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
@ -915,7 +926,8 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
}
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
@ -945,6 +957,23 @@ static BuiltinImplResult translateBuiltinVectorNormalize(
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp sum = build.inst(IrCmd::DOT_VEC, a, a);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
IrOp invvec = build.inst(IrCmd::NUM_TO_VEC, inv);
IrOp result = build.inst(IrCmd::MUL_VEC, a, invvec);
result = build.inst(IrCmd::TAG_VECTOR, result);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
@ -964,6 +993,7 @@ static BuiltinImplResult translateBuiltinVectorNormalize(
build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr); build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
}
return {BuiltinImplType::Full, 1}; return {BuiltinImplType::Full, 1};
} }
@ -1019,6 +1049,17 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos));
IrOp sum;
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp b = build.inst(IrCmd::LOAD_TVALUE, args, build.constInt(0));
sum = build.inst(IrCmd::DOT_VEC, a, b);
}
else
{
IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0)); IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
@ -1031,7 +1072,8 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8)); IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2); IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz); sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
}
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));

View file

@ -75,6 +75,8 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::DIV_VEC: case IrCmd::DIV_VEC:
case IrCmd::UNM_VEC: case IrCmd::UNM_VEC:
return IrValueKind::Tvalue; return IrValueKind::Tvalue;
case IrCmd::DOT_VEC:
return IrValueKind::Double;
case IrCmd::NOT_ANY: case IrCmd::NOT_ANY:
case IrCmd::CMP_ANY: case IrCmd::CMP_ANY:
return IrValueKind::Int; return IrValueKind::Int;

View file

@ -768,7 +768,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
if (tag == LUA_TBOOLEAN && if (tag == LUA_TBOOLEAN &&
(value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int))) (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int)))
canSplitTvalueStore = true; canSplitTvalueStore = true;
else if (tag == LUA_TNUMBER && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) else if (tag == LUA_TNUMBER &&
(value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
canSplitTvalueStore = true; canSplitTvalueStore = true;
else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst) else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst)
canSplitTvalueStore = true; canSplitTvalueStore = true;
@ -1342,6 +1343,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
case IrCmd::SUB_VEC: case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC: case IrCmd::MUL_VEC:
case IrCmd::DIV_VEC: case IrCmd::DIV_VEC:
case IrCmd::DOT_VEC:
if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR) if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR)
replace(function, inst.a, a->a); replace(function, inst.a, a->a);

View file

@ -19,7 +19,7 @@ class Variant
static_assert(std::disjunction_v<std::is_reference<Ts>...> == false, "variant does not allow references as an alternative type"); static_assert(std::disjunction_v<std::is_reference<Ts>...> == false, "variant does not allow references as an alternative type");
static_assert(std::disjunction_v<std::is_array<Ts>...> == false, "variant does not allow arrays as an alternative type"); static_assert(std::disjunction_v<std::is_array<Ts>...> == false, "variant does not allow arrays as an alternative type");
private: public:
template<typename T> template<typename T>
static constexpr int getTypeId() static constexpr int getTypeId()
{ {
@ -35,6 +35,7 @@ private:
return -1; return -1;
} }
private:
template<typename T, typename... Tail> template<typename T, typename... Tail>
struct First struct First
{ {

View file

@ -1,12 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/DenseHash.h"
#include "Luau/LinterConfig.h" #include "Luau/LinterConfig.h"
#include "Luau/ParseOptions.h" #include "Luau/ParseOptions.h"
#include <memory>
#include <optional> #include <optional>
#include <string> #include <string>
#include <unordered_map> #include <string_view>
#include <vector> #include <vector>
namespace Luau namespace Luau
@ -19,6 +21,10 @@ constexpr const char* kConfigName = ".luaurc";
struct Config struct Config
{ {
Config(); Config();
Config(const Config& other);
Config& operator=(const Config& other);
Config(Config&& other) = default;
Config& operator=(Config&& other) = default;
Mode mode = Mode::Nonstrict; Mode mode = Mode::Nonstrict;
@ -32,7 +38,19 @@ struct Config
std::vector<std::string> globals; std::vector<std::string> globals;
std::unordered_map<std::string, std::string> aliases; struct AliasInfo
{
std::string value;
std::string_view configLocation;
};
DenseHashMap<std::string, AliasInfo> aliases{""};
void setAlias(std::string alias, const std::string& value, const std::string configLocation);
private:
// Prevents making unnecessary copies of the same config location string.
DenseHashMap<std::string, std::unique_ptr<std::string>> configLocationCache{""};
}; };
struct ConfigResolver struct ConfigResolver
@ -60,6 +78,18 @@ std::optional<std::string> parseLintRuleString(
bool isValidAlias(const std::string& alias); bool isValidAlias(const std::string& alias);
std::optional<std::string> parseConfig(const std::string& contents, Config& config, bool compat = false); struct ConfigOptions
{
bool compat = false;
struct AliasOptions
{
std::string configLocation;
bool overwriteAliases;
};
std::optional<AliasOptions> aliasOptions = std::nullopt;
};
std::optional<std::string> parseConfig(const std::string& contents, Config& config, const ConfigOptions& options = ConfigOptions{});
} // namespace Luau } // namespace Luau

View file

@ -15,7 +15,7 @@ struct HotComment;
struct LintWarning struct LintWarning
{ {
// Make sure any new lint codes are documented here: https://luau-lang.org/lint // Make sure any new lint codes are documented here: https://luau.org/lint
// Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints // Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints
enum Code enum Code
{ {

View file

@ -4,7 +4,8 @@
#include "Luau/Lexer.h" #include "Luau/Lexer.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <memory>
#include <string>
namespace Luau namespace Luau
{ {
@ -16,6 +17,50 @@ Config::Config()
enabledLint.setDefaults(); enabledLint.setDefaults();
} }
Config::Config(const Config& other)
: mode(other.mode)
, parseOptions(other.parseOptions)
, enabledLint(other.enabledLint)
, fatalLint(other.fatalLint)
, lintErrors(other.lintErrors)
, typeErrors(other.typeErrors)
, globals(other.globals)
{
for (const auto& [alias, aliasInfo] : other.aliases)
{
std::string configLocation = std::string(aliasInfo.configLocation);
if (!configLocationCache.contains(configLocation))
configLocationCache[configLocation] = std::make_unique<std::string>(configLocation);
AliasInfo newAliasInfo;
newAliasInfo.value = aliasInfo.value;
newAliasInfo.configLocation = *configLocationCache[configLocation];
aliases[alias] = std::move(newAliasInfo);
}
}
Config& Config::operator=(const Config& other)
{
if (this != &other)
{
Config copy(other);
std::swap(*this, copy);
}
return *this;
}
void Config::setAlias(std::string alias, const std::string& value, const std::string configLocation)
{
AliasInfo& info = aliases[alias];
info.value = value;
if (!configLocationCache.contains(configLocation))
configLocationCache[configLocation] = std::make_unique<std::string>(configLocation);
info.configLocation = *configLocationCache[configLocation];
}
static Error parseBoolean(bool& result, const std::string& value) static Error parseBoolean(bool& result, const std::string& value)
{ {
if (value == "true") if (value == "true")
@ -136,7 +181,12 @@ bool isValidAlias(const std::string& alias)
return true; return true;
} }
Error parseAlias(std::unordered_map<std::string, std::string>& aliases, std::string aliasKey, const std::string& aliasValue) static Error parseAlias(
Config& config,
std::string aliasKey,
const std::string& aliasValue,
const std::optional<ConfigOptions::AliasOptions>& aliasOptions
)
{ {
if (!isValidAlias(aliasKey)) if (!isValidAlias(aliasKey))
return Error{"Invalid alias " + aliasKey}; return Error{"Invalid alias " + aliasKey};
@ -150,8 +200,12 @@ Error parseAlias(std::unordered_map<std::string, std::string>& aliases, std::str
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
} }
); );
if (!aliases.count(aliasKey))
aliases[std::move(aliasKey)] = aliasValue; if (!aliasOptions)
return Error("Cannot parse aliases without alias options");
if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey))
config.setAlias(std::move(aliasKey), aliasValue, aliasOptions->configLocation);
return std::nullopt; return std::nullopt;
} }
@ -285,16 +339,16 @@ static Error parseJson(const std::string& contents, Action action)
return {}; return {};
} }
Error parseConfig(const std::string& contents, Config& config, bool compat) Error parseConfig(const std::string& contents, Config& config, const ConfigOptions& options)
{ {
return parseJson( return parseJson(
contents, contents,
[&](const std::vector<std::string>& keys, const std::string& value) -> Error [&](const std::vector<std::string>& keys, const std::string& value) -> Error
{ {
if (keys.size() == 1 && keys[0] == "languageMode") if (keys.size() == 1 && keys[0] == "languageMode")
return parseModeString(config.mode, value, compat); return parseModeString(config.mode, value, options.compat);
else if (keys.size() == 2 && keys[0] == "lint") else if (keys.size() == 2 && keys[0] == "lint")
return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, compat); return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, options.compat);
else if (keys.size() == 1 && keys[0] == "lintErrors") else if (keys.size() == 1 && keys[0] == "lintErrors")
return parseBoolean(config.lintErrors, value); return parseBoolean(config.lintErrors, value);
else if (keys.size() == 1 && keys[0] == "typeErrors") else if (keys.size() == 1 && keys[0] == "typeErrors")
@ -305,9 +359,9 @@ Error parseConfig(const std::string& contents, Config& config, bool compat)
return std::nullopt; return std::nullopt;
} }
else if (keys.size() == 2 && keys[0] == "aliases") else if (keys.size() == 2 && keys[0] == "aliases")
return parseAlias(config.aliases, keys[1], value); return parseAlias(config, keys[1], value, options.aliasOptions);
else if (compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode") else if (options.compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode")
return parseModeString(config.mode, value, compat); return parseModeString(config.mode, value, options.compat);
else else
{ {
std::vector<std::string_view> keysv(keys.begin(), keys.end()); std::vector<std::string_view> keysv(keys.begin(), keys.end());

View file

@ -23,6 +23,13 @@ struct Analysis final
using D = typename N::Data; using D = typename N::Data;
Analysis() = default;
Analysis(N a)
: analysis(std::move(a))
{
}
template<typename T> template<typename T>
static D fnMake(const N& analysis, const EGraph<L, N>& egraph, const L& enode) static D fnMake(const N& analysis, const EGraph<L, N>& egraph, const L& enode)
{ {
@ -59,6 +66,15 @@ struct EClass final
template<typename L, typename N> template<typename L, typename N>
struct EGraph final struct EGraph final
{ {
using EClassT = EClass<L, typename N::Data>;
EGraph() = default;
explicit EGraph(N analysis)
: analysis(std::move(analysis))
{
}
Id find(Id id) const Id find(Id id) const
{ {
return unionfind.find(id); return unionfind.find(id);
@ -85,33 +101,59 @@ struct EGraph final
return id; return id;
} }
void merge(Id id1, Id id2) // Returns true if the two IDs were not previously merged.
bool merge(Id id1, Id id2)
{ {
id1 = find(id1); id1 = find(id1);
id2 = find(id2); id2 = find(id2);
if (id1 == id2) if (id1 == id2)
return; return false;
unionfind.merge(id1, id2); const Id mergedId = unionfind.merge(id1, id2);
EClass<L, typename N::Data>& eclass1 = get(id1); // Ensure that id1 is the Id that we keep, and id2 is the id that we drop.
EClass<L, typename N::Data> eclass2 = std::move(get(id2)); if (mergedId == id2)
std::swap(id1, id2);
EClassT& eclass1 = get(id1);
EClassT eclass2 = std::move(get(id2));
classes.erase(id2); classes.erase(id2);
worklist.reserve(worklist.size() + eclass2.parents.size()); eclass1.nodes.insert(eclass1.nodes.end(), eclass2.nodes.begin(), eclass2.nodes.end());
for (auto [enode, id] : eclass2.parents) eclass1.parents.insert(eclass1.parents.end(), eclass2.parents.begin(), eclass2.parents.end());
worklist.push_back({std::move(enode), id});
std::sort(
eclass1.nodes.begin(),
eclass1.nodes.end(),
[](const L& left, const L& right)
{
return left.index() < right.index();
}
);
worklist.reserve(worklist.size() + eclass1.parents.size());
for (const auto& [eclass, id] : eclass1.parents)
worklist.push_back(id);
analysis.join(eclass1.data, eclass2.data); analysis.join(eclass1.data, eclass2.data);
return true;
} }
void rebuild() void rebuild()
{ {
std::unordered_set<Id> seen;
while (!worklist.empty()) while (!worklist.empty())
{ {
auto [enode, id] = worklist.back(); Id id = worklist.back();
worklist.pop_back(); worklist.pop_back();
repair(get(find(id)));
const bool isFresh = seen.insert(id).second;
if (!isFresh)
continue;
repair(find(id));
} }
} }
@ -120,16 +162,21 @@ struct EGraph final
return classes.size(); return classes.size();
} }
EClass<L, typename N::Data>& operator[](Id id) EClassT& operator[](Id id)
{ {
return get(find(id)); return get(find(id));
} }
const EClass<L, typename N::Data>& operator[](Id id) const const EClassT& operator[](Id id) const
{ {
return const_cast<EGraph*>(this)->get(find(id)); return const_cast<EGraph*>(this)->get(find(id));
} }
const std::unordered_map<Id, EClassT>& getAllClasses() const
{
return classes;
}
private: private:
Analysis<L, N> analysis; Analysis<L, N> analysis;
@ -139,19 +186,19 @@ private:
/// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same /// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same
/// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the /// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the
/// e-class 𝑀[find(𝑎)]. /// e-class 𝑀[find(𝑎)].
std::unordered_map<Id, EClass<L, typename N::Data>> classes; std::unordered_map<Id, EClassT> classes;
/// The hashcons 𝐻 is a map from e-nodes to e-class ids. /// The hashcons 𝐻 is a map from e-nodes to e-class ids.
std::unordered_map<L, Id, typename L::Hash> hashcons; std::unordered_map<L, Id, typename L::Hash> hashcons;
std::vector<std::pair<L, Id>> worklist; std::vector<Id> worklist;
private: private:
void canonicalize(L& enode) void canonicalize(L& enode)
{ {
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
for (Id& id : enode.operands()) for (Id& id : enode.mutableOperands())
id = find(id); id = find(id);
} }
@ -171,7 +218,7 @@ private:
classes.insert_or_assign( classes.insert_or_assign(
id, id,
EClass<L, typename N::Data>{ EClassT{
id, id,
{enode}, {enode},
analysis.make(*this, enode), analysis.make(*this, enode),
@ -182,7 +229,7 @@ private:
for (Id operand : enode.operands()) for (Id operand : enode.operands())
get(operand).parents.push_back({enode, id}); get(operand).parents.push_back({enode, id});
worklist.emplace_back(enode, id); worklist.emplace_back(id);
hashcons.insert_or_assign(enode, id); hashcons.insert_or_assign(enode, id);
return id; return id;
@ -190,12 +237,13 @@ private:
// Looks up for an eclass from a given non-canonicalized `id`. // Looks up for an eclass from a given non-canonicalized `id`.
// For a canonicalized eclass, use `get(find(id))` or `egraph[id]`. // For a canonicalized eclass, use `get(find(id))` or `egraph[id]`.
EClass<L, typename N::Data>& get(Id id) EClassT& get(Id id)
{ {
LUAU_ASSERT(classes.count(id));
return classes.at(id); return classes.at(id);
} }
void repair(EClass<L, typename N::Data>& eclass) void repair(Id id)
{ {
// In the egg paper, the `repair` function makes use of two loops over the `eclass.parents` // In the egg paper, the `repair` function makes use of two loops over the `eclass.parents`
// by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id. // by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id.
@ -204,26 +252,54 @@ private:
// Here, we unify the two loops. I think it's equivalent? // Here, we unify the two loops. I think it's equivalent?
// After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent. // After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent.
std::unordered_map<L, Id, typename L::Hash> map; std::unordered_map<L, Id, typename L::Hash> newParents;
for (auto& [enode, id] : eclass.parents)
// The eclass can be deallocated if it is merged into another eclass, so
// we take what we need from it and avoid retaining a pointer.
std::vector<std::pair<L, Id>> parents = get(id).parents;
for (auto& pair : parents)
{ {
L& enode = pair.first;
Id id = pair.second;
// By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id.
hashcons.erase(enode); hashcons.erase(enode);
canonicalize(enode); canonicalize(enode);
hashcons.insert_or_assign(enode, find(id)); hashcons.insert_or_assign(enode, find(id));
if (auto it = map.find(enode); it != map.end()) if (auto it = newParents.find(enode); it != newParents.end())
merge(id, it->second); merge(id, it->second);
map.insert_or_assign(enode, find(id)); newParents.insert_or_assign(enode, find(id));
} }
eclass.parents.clear(); // We reacquire the pointer because the prior loop potentially merges
for (auto it = map.begin(); it != map.end();) // the eclass into another, which might move it around in memory.
EClassT* eclass = &get(find(id));
eclass->parents.clear();
for (const auto& [node, id] : newParents)
eclass->parents.emplace_back(std::move(node), std::move(id));
std::unordered_set<L, typename L::Hash> newNodes;
for (L node : eclass->nodes)
{ {
auto node = map.extract(it++); canonicalize(node);
eclass.parents.emplace_back(std::move(node.key()), node.mapped()); newNodes.insert(std::move(node));
} }
eclass->nodes.assign(newNodes.begin(), newNodes.end());
// FIXME: Extract into sortByTag()
std::sort(
eclass->nodes.begin(),
eclass->nodes.end(),
[](const L& left, const L& right)
{
return left.index() < right.index();
}
);
} }
}; };

View file

@ -2,6 +2,7 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <cstdint>
#include <functional> #include <functional>
namespace Luau::EqSat namespace Luau::EqSat
@ -9,15 +10,17 @@ namespace Luau::EqSat
struct Id final struct Id final
{ {
explicit Id(size_t id); explicit Id(uint32_t id);
explicit operator size_t() const; explicit operator uint32_t() const;
bool operator==(Id rhs) const; bool operator==(Id rhs) const;
bool operator!=(Id rhs) const; bool operator!=(Id rhs) const;
bool operator<(Id rhs) const;
private: private:
size_t id; uint32_t id;
}; };
} // namespace Luau::EqSat } // namespace Luau::EqSat

View file

@ -6,9 +6,19 @@
#include "Luau/Slice.h" #include "Luau/Slice.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include <algorithm>
#include <array> #include <array>
#include <type_traits> #include <type_traits>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector>
#define LUAU_EQSAT_UNIT(name) \
struct name : ::Luau::EqSat::Unit<name> \
{ \
static constexpr const char* tag = #name; \
using Unit::Unit; \
}
#define LUAU_EQSAT_ATOM(name, t) \ #define LUAU_EQSAT_ATOM(name, t) \
struct name : public ::Luau::EqSat::Atom<name, t> \ struct name : public ::Luau::EqSat::Atom<name, t> \
@ -31,21 +41,57 @@
using NodeVector::NodeVector; \ using NodeVector::NodeVector; \
} }
#define LUAU_EQSAT_FIELD(name) \ #define LUAU_EQSAT_NODE_SET(name) \
struct name : public ::Luau::EqSat::Field<name> \ struct name : public ::Luau::EqSat::NodeSet<name, std::vector<::Luau::EqSat::Id>> \
{ \
}
#define LUAU_EQSAT_NODE_FIELDS(name, ...) \
struct name : public ::Luau::EqSat::NodeFields<name, __VA_ARGS__> \
{ \ { \
static constexpr const char* tag = #name; \ static constexpr const char* tag = #name; \
using NodeFields::NodeFields; \ using NodeSet::NodeSet; \
}
#define LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(name, t) \
struct name : public ::Luau::EqSat::NodeAtomAndVector<name, t, std::vector<::Luau::EqSat::Id>> \
{ \
static constexpr const char* tag = #name; \
using NodeAtomAndVector::NodeAtomAndVector; \
} }
namespace Luau::EqSat namespace Luau::EqSat
{ {
template<typename Phantom>
struct Unit
{
Slice<Id> mutableOperands()
{
return {};
}
Slice<const Id> operands() const
{
return {};
}
bool operator==(const Unit& rhs) const
{
return true;
}
bool operator!=(const Unit& rhs) const
{
return false;
}
struct Hash
{
size_t operator()(const Unit& value) const
{
// chosen by fair dice roll.
// guaranteed to be random.
return 4;
}
};
};
template<typename Phantom, typename T> template<typename Phantom, typename T>
struct Atom struct Atom
{ {
@ -60,7 +106,7 @@ struct Atom
} }
public: public:
Slice<Id> operands() Slice<Id> mutableOperands()
{ {
return {}; return {};
} }
@ -92,6 +138,62 @@ private:
T _value; T _value;
}; };
template<typename Phantom, typename X, typename T>
struct NodeAtomAndVector
{
template<typename... Args>
NodeAtomAndVector(const X& value, Args&&... args)
: _value(value)
, vector{std::forward<Args>(args)...}
{
}
Id operator[](size_t i) const
{
return vector[i];
}
public:
const X& value() const
{
return _value;
}
Slice<Id> mutableOperands()
{
return Slice{vector.data(), vector.size()};
}
Slice<const Id> operands() const
{
return Slice{vector.data(), vector.size()};
}
bool operator==(const NodeAtomAndVector& rhs) const
{
return _value == rhs._value && vector == rhs.vector;
}
bool operator!=(const NodeAtomAndVector& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const NodeAtomAndVector& value) const
{
size_t result = languageHash(value._value);
hashCombine(result, languageHash(value.vector));
return result;
}
};
private:
X _value;
T vector;
};
template<typename Phantom, typename T> template<typename Phantom, typename T>
struct NodeVector struct NodeVector
{ {
@ -107,7 +209,7 @@ struct NodeVector
} }
public: public:
Slice<Id> operands() Slice<Id> mutableOperands()
{ {
return Slice{vector.data(), vector.size()}; return Slice{vector.data(), vector.size()};
} }
@ -139,90 +241,61 @@ private:
T vector; T vector;
}; };
/// Empty base class just for static_asserts. template<typename Phantom, typename T>
struct FieldBase struct NodeSet
{ {
FieldBase() = delete; template<typename... Args>
NodeSet(Args&&... args)
FieldBase(FieldBase&&) = delete; : vector{std::forward<Args>(args)...}
FieldBase& operator=(FieldBase&&) = delete;
FieldBase(const FieldBase&) = delete;
FieldBase& operator=(const FieldBase&) = delete;
};
template<typename Phantom>
struct Field : FieldBase
{
};
template<typename Phantom, typename... Fields>
struct NodeFields
{
static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);
template<typename T>
static constexpr int getIndex()
{ {
constexpr int N = sizeof...(Fields); std::sort(begin(vector), end(vector));
constexpr bool is[N] = {std::is_same_v<std::decay_t<T>, Fields>...}; auto it = std::unique(begin(vector), end(vector));
vector.erase(it, end(vector));
}
for (int i = 0; i < N; ++i) Id operator[](size_t i) const
if (is[i]) {
return i; return vector[i];
return -1;
} }
public: public:
template<typename... Args> Slice<Id> mutableOperands()
NodeFields(Args&&... args)
: array{std::forward<Args>(args)...}
{ {
} return Slice{vector.data(), vector.size()};
Slice<Id> operands()
{
return Slice{array};
} }
Slice<const Id> operands() const Slice<const Id> operands() const
{ {
return Slice{array.data(), array.size()}; return Slice{vector.data(), vector.size()};
} }
template<typename T> bool operator==(const NodeSet& rhs) const
Id field() const
{ {
static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>); return vector == rhs.vector;
return array[getIndex<T>()];
} }
bool operator==(const NodeFields& rhs) const bool operator!=(const NodeSet& rhs) const
{
return array == rhs.array;
}
bool operator!=(const NodeFields& rhs) const
{ {
return !(*this == rhs); return !(*this == rhs);
} }
struct Hash struct Hash
{ {
size_t operator()(const NodeFields& value) const size_t operator()(const NodeSet& value) const
{ {
return languageHash(value.array); return languageHash(value.vector);
} }
}; };
private: protected:
std::array<Id, sizeof...(Fields)> array; T vector;
}; };
template<typename... Ts> template<typename... Ts>
struct Language final struct Language final
{ {
using VariantTy = Luau::Variant<Ts...>;
template<typename T> template<typename T>
using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>; using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;
@ -237,14 +310,14 @@ struct Language final
return v.index(); return v.index();
} }
/// You should never call this function with the intention of mutating the `Id`. /// This should only be used in canonicalization!
/// Reading is ok, but you should also never assume that these `Id`s are stable. /// Always prefer operands()
Slice<Id> operands() noexcept Slice<Id> mutableOperands() noexcept
{ {
return visit( return visit(
[](auto&& v) -> Slice<Id> [](auto&& v) -> Slice<Id>
{ {
return v.operands(); return v.mutableOperands();
}, },
v v
); );
@ -306,7 +379,7 @@ public:
}; };
private: private:
Variant<Ts...> v; VariantTy v;
}; };
} // namespace Luau::EqSat } // namespace Luau::EqSat

View file

@ -3,6 +3,7 @@
#include <cstddef> #include <cstddef>
#include <functional> #include <functional>
#include <unordered_set>
#include <vector> #include <vector>
namespace Luau::EqSat namespace Luau::EqSat

View file

@ -14,7 +14,9 @@ struct UnionFind final
Id makeSet(); Id makeSet();
Id find(Id id) const; Id find(Id id) const;
Id find(Id id); Id find(Id id);
void merge(Id a, Id b);
// Merge aSet with bSet and return the canonicalized Id into the merged set.
Id merge(Id aSet, Id bSet);
private: private:
std::vector<Id> parents; std::vector<Id> parents;

View file

@ -1,15 +1,16 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Id.h" #include "Luau/Id.h"
#include "Luau/Common.h"
namespace Luau::EqSat namespace Luau::EqSat
{ {
Id::Id(size_t id) Id::Id(uint32_t id)
: id(id) : id(id)
{ {
} }
Id::operator size_t() const Id::operator uint32_t() const
{ {
return id; return id;
} }
@ -24,9 +25,14 @@ bool Id::operator!=(Id rhs) const
return id != rhs.id; return id != rhs.id;
} }
bool Id::operator<(Id rhs) const
{
return id < rhs.id;
}
} // namespace Luau::EqSat } // namespace Luau::EqSat
size_t std::hash<Luau::EqSat::Id>::operator()(Luau::EqSat::Id id) const size_t std::hash<Luau::EqSat::Id>::operator()(Luau::EqSat::Id id) const
{ {
return std::hash<size_t>()(size_t(id)); return std::hash<uint32_t>()(uint32_t(id));
} }

View file

@ -3,12 +3,16 @@
#include "Luau/Common.h" #include "Luau/Common.h"
#include <limits>
namespace Luau::EqSat namespace Luau::EqSat
{ {
Id UnionFind::makeSet() Id UnionFind::makeSet()
{ {
Id id{parents.size()}; LUAU_ASSERT(parents.size() < std::numeric_limits<uint32_t>::max());
Id id{uint32_t(parents.size())};
parents.push_back(id); parents.push_back(id);
ranks.push_back(0); ranks.push_back(0);
@ -25,42 +29,44 @@ Id UnionFind::find(Id id)
Id set = canonicalize(id); Id set = canonicalize(id);
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)]) while (id != parents[uint32_t(id)])
{ {
// Note: we don't update the ranks here since a rank // Note: we don't update the ranks here since a rank
// represents the upper bound on the maximum depth of a tree // represents the upper bound on the maximum depth of a tree
Id parent = parents[size_t(id)]; Id parent = parents[uint32_t(id)];
parents[size_t(id)] = set; parents[uint32_t(id)] = set;
id = parent; id = parent;
} }
return set; return set;
} }
void UnionFind::merge(Id a, Id b) Id UnionFind::merge(Id a, Id b)
{ {
Id aSet = find(a); Id aSet = find(a);
Id bSet = find(b); Id bSet = find(b);
if (aSet == bSet) if (aSet == bSet)
return; return aSet;
// Ensure that the rank of set A is greater than the rank of set B // Ensure that the rank of set A is greater than the rank of set B
if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) if (ranks[uint32_t(aSet)] > ranks[uint32_t(bSet)])
std::swap(aSet, bSet); std::swap(aSet, bSet);
parents[size_t(bSet)] = aSet; parents[uint32_t(bSet)] = aSet;
if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) if (ranks[uint32_t(aSet)] == ranks[uint32_t(bSet)])
ranks[size_t(aSet)]++; ranks[uint32_t(aSet)]++;
return aSet;
} }
Id UnionFind::canonicalize(Id id) const Id UnionFind::canonicalize(Id id) const
{ {
LUAU_ASSERT(size_t(id) < parents.size()); LUAU_ASSERT(uint32_t(id) < parents.size());
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)]) while (id != parents[uint32_t(id)])
id = parents[size_t(id)]; id = parents[uint32_t(id)];
return id; return id;
} }

View file

@ -3,11 +3,11 @@ Luau ![CI](https://github.com/luau-lang/luau/actions/workflows/build.yml/badge.s
Luau (lowercase u, /ˈlu.aʊ/) is a fast, small, safe, gradually typed embeddable scripting language derived from [Lua](https://lua.org). Luau (lowercase u, /ˈlu.aʊ/) is a fast, small, safe, gradually typed embeddable scripting language derived from [Lua](https://lua.org).
It is designed to be backwards compatible with Lua 5.1, as well as incorporating [some features](https://luau-lang.org/compatibility) from future Lua releases, but also expands the feature set (most notably with type annotations). Luau is largely implemented from scratch, with the language runtime being a very heavily modified version of Lua 5.1 runtime, with completely rewritten interpreter and other [performance innovations](https://luau-lang.org/performance). The runtime mostly preserves Lua 5.1 API, so existing bindings should be more or less compatible with a few caveats. It is designed to be backwards compatible with Lua 5.1, as well as incorporating [some features](https://luau.org/compatibility) from future Lua releases, but also expands the feature set (most notably with type annotations). Luau is largely implemented from scratch, with the language runtime being a very heavily modified version of Lua 5.1 runtime, with completely rewritten interpreter and other [performance innovations](https://luau.org/performance). The runtime mostly preserves Lua 5.1 API, so existing bindings should be more or less compatible with a few caveats.
Luau is used by Roblox game developers to write game code, as well as by Roblox engineers to implement large parts of the user-facing application code as well as portions of the editor (Roblox Studio) as plugins. Roblox chose to open-source Luau to foster collaboration within the Roblox community as well as to allow other companies and communities to benefit from the ongoing language and runtime innovation. As a consequence, Luau is now also used by games like Alan Wake 2 and Warframe. Luau is used by Roblox game developers to write game code, as well as by Roblox engineers to implement large parts of the user-facing application code as well as portions of the editor (Roblox Studio) as plugins. Roblox chose to open-source Luau to foster collaboration within the Roblox community as well as to allow other companies and communities to benefit from the ongoing language and runtime innovation. As a consequence, Luau is now also used by games like Alan Wake 2 and Warframe.
This repository hosts source code for the language implementation and associated tooling. Documentation for the language is available at https://luau-lang.org/ and accepts contributions via [site repository](https://github.com/luau-lang/site); the language is evolved through RFCs that are located in [rfcs repository](https://github.com/luau-lang/rfcs). This repository hosts source code for the language implementation and associated tooling. Documentation for the language is available at https://luau.org/ and accepts contributions via [site repository](https://github.com/luau-lang/site); the language is evolved through RFCs that are located in [rfcs repository](https://github.com/luau-lang/rfcs).
# Usage # Usage
@ -15,7 +15,7 @@ Luau is an embeddable language, but it also comes with two command-line tools by
`luau` is a command-line REPL and can also run input files. Note that REPL runs in a sandboxed environment and as such doesn't have access to the underlying file system except for ability to `require` modules. `luau` is a command-line REPL and can also run input files. Note that REPL runs in a sandboxed environment and as such doesn't have access to the underlying file system except for ability to `require` modules.
`luau-analyze` is a command-line type checker and linter; given a set of input files, it produces errors/warnings according to the file configuration, which can be customized by using `--!` comments in the files or [`.luaurc`](https://rfcs.luau-lang.org/config-luaurc) files. For details please refer to [type checking]( https://luau-lang.org/typecheck) and [linting](https://luau-lang.org/lint) documentation. `luau-analyze` is a command-line type checker and linter; given a set of input files, it produces errors/warnings according to the file configuration, which can be customized by using `--!` comments in the files or [`.luaurc`](https://rfcs.luau.org/config-luaurc) files. For details please refer to [type checking]( https://luau.org/typecheck) and [linting](https://luau.org/lint) documentation.
# Installation # Installation
@ -28,7 +28,7 @@ Alternatively, you can use one of the packaged distributions (note that these ar
- Alpine Linux: [Enable community repositories](https://wiki.alpinelinux.org/w/index.php?title=Enable_Community_Repository) and run `apk add luau` - Alpine Linux: [Enable community repositories](https://wiki.alpinelinux.org/w/index.php?title=Enable_Community_Repository) and run `apk add luau`
- Gentoo Linux: Luau is [officially packaged by Gentoo](https://packages.gentoo.org/packages/dev-lang/luau) and can be installed using `emerge dev-lang/luau`. You may have to unmask the package first before installing it (which can be done by including the `--autounmask=y` option in the `emerge` command). - Gentoo Linux: Luau is [officially packaged by Gentoo](https://packages.gentoo.org/packages/dev-lang/luau) and can be installed using `emerge dev-lang/luau`. You may have to unmask the package first before installing it (which can be done by including the `--autounmask=y` option in the `emerge` command).
After installing, you will want to validate the installation was successful by running the test case [here](https://luau-lang.org/getting-started). After installing, you will want to validate the installation was successful by running the test case [here](https://luau.org/getting-started).
## Building ## Building

View file

@ -14,6 +14,7 @@ endif()
# Luau.Ast Sources # Luau.Ast Sources
target_sources(Luau.Ast PRIVATE target_sources(Luau.Ast PRIVATE
Ast/include/Luau/Allocator.h
Ast/include/Luau/Ast.h Ast/include/Luau/Ast.h
Ast/include/Luau/Confusables.h Ast/include/Luau/Confusables.h
Ast/include/Luau/Lexer.h Ast/include/Luau/Lexer.h
@ -24,6 +25,7 @@ target_sources(Luau.Ast PRIVATE
Ast/include/Luau/StringUtils.h Ast/include/Luau/StringUtils.h
Ast/include/Luau/TimeTrace.h Ast/include/Luau/TimeTrace.h
Ast/src/Allocator.cpp
Ast/src/Ast.cpp Ast/src/Ast.cpp
Ast/src/Confusables.cpp Ast/src/Confusables.cpp
Ast/src/Lexer.cpp Ast/src/Lexer.cpp
@ -168,6 +170,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstJsonEncoder.h
Analysis/include/Luau/AstQuery.h Analysis/include/Luau/AstQuery.h
Analysis/include/Luau/Autocomplete.h Analysis/include/Luau/Autocomplete.h
Analysis/include/Luau/AutocompleteTypes.h
Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/BuiltinDefinitions.h
Analysis/include/Luau/Cancellation.h Analysis/include/Luau/Cancellation.h
Analysis/include/Luau/Clone.h Analysis/include/Luau/Clone.h
@ -181,6 +184,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Differ.h Analysis/include/Luau/Differ.h
Analysis/include/Luau/Documentation.h Analysis/include/Luau/Documentation.h
Analysis/include/Luau/Error.h Analysis/include/Luau/Error.h
Analysis/include/Luau/EqSatSimplification.h
Analysis/include/Luau/FileResolver.h Analysis/include/Luau/FileResolver.h
Analysis/include/Luau/FragmentAutocomplete.h Analysis/include/Luau/FragmentAutocomplete.h
Analysis/include/Luau/Frontend.h Analysis/include/Luau/Frontend.h
@ -245,6 +249,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/AstJsonEncoder.cpp Analysis/src/AstJsonEncoder.cpp
Analysis/src/AstQuery.cpp Analysis/src/AstQuery.cpp
Analysis/src/Autocomplete.cpp Analysis/src/Autocomplete.cpp
Analysis/src/AutocompleteCore.cpp
Analysis/src/BuiltinDefinitions.cpp Analysis/src/BuiltinDefinitions.cpp
Analysis/src/Clone.cpp Analysis/src/Clone.cpp
Analysis/src/Constraint.cpp Analysis/src/Constraint.cpp
@ -256,6 +261,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Differ.cpp Analysis/src/Differ.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp Analysis/src/Error.cpp
Analysis/src/EqSatSimplification.cpp
Analysis/src/FragmentAutocomplete.cpp Analysis/src/FragmentAutocomplete.cpp
Analysis/src/Frontend.cpp Analysis/src/Frontend.cpp
Analysis/src/Generalization.cpp Analysis/src/Generalization.cpp
@ -444,6 +450,7 @@ if(TARGET Luau.UnitTest)
tests/EqSat.language.test.cpp tests/EqSat.language.test.cpp
tests/EqSat.propositional.test.cpp tests/EqSat.propositional.test.cpp
tests/EqSat.slice.test.cpp tests/EqSat.slice.test.cpp
tests/EqSatSimplification.test.cpp
tests/Error.test.cpp tests/Error.test.cpp
tests/Fixture.cpp tests/Fixture.cpp
tests/Fixture.h tests/Fixture.h

View file

@ -39,8 +39,8 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
"$URL: www.lua.org $\n"; "$URL: www.lua.org $\n";
const char* luau_ident = "$Luau: Copyright (C) 2019-2023 Roblox Corporation $\n" const char* luau_ident = "$Luau: Copyright (C) 2019-2024 Roblox Corporation $\n"
"$URL: luau-lang.org $\n"; "$URL: luau.org $\n";
#define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) #define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base))

View file

@ -86,7 +86,8 @@ function test()
function compute_triangle_cones() function compute_triangle_cones()
local mesh_area = 0 local mesh_area = 0
local i = 1 local pos = 1
for i = 1,#mesh.indices,3 do for i = 1,#mesh.indices,3 do
local p0 = mesh.vertices[mesh.indices[i]] local p0 = mesh.vertices[mesh.indices[i]]
local p1 = mesh.vertices[mesh.indices[i + 1]] local p1 = mesh.vertices[mesh.indices[i + 1]]
@ -100,9 +101,9 @@ function test()
local area = vector.magnitude(normal) local area = vector.magnitude(normal)
local invarea = (area == 0) and 0 or 1 / area; local invarea = (area == 0) and 0 or 1 / area;
mesh.triangle_cone_p[i] = (p0.p + p1.p + p2.p) / 3 mesh.triangle_cone_p[pos] = (p0.p + p1.p + p2.p) / 3
mesh.triangle_cone_n[i] = normal * invarea mesh.triangle_cone_n[pos] = normal * invarea
i += 1 pos += 1
mesh_area += area mesh_area += area
end end

View file

@ -400,6 +400,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath")
SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841); SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841);
SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD); SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD);
SINGLE_COMPARE(faddp(s29, s28), 0x7E30DB9D);
SINGLE_COMPARE(faddp(d29, d28), 0x7E70DB9D);
SINGLE_COMPARE(frinta(d1, d2), 0x1E664041); SINGLE_COMPARE(frinta(d1, d2), 0x1E664041);
SINGLE_COMPARE(frintm(d1, d2), 0x1E654041); SINGLE_COMPARE(frintm(d1, d2), 0x1E654041);
SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041); SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041);

View file

@ -577,6 +577,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms")
SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4); SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4);
SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02); SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02);
SINGLE_COMPARE(vdpps(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x40, 0x3c, 0x11, 0x02);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions")

View file

@ -67,7 +67,7 @@ TEST_CASE("encode_constants")
charString.data = const_cast<char*>("a\x1d\0\\\"b"); charString.data = const_cast<char*>("a\x1d\0\\\"b");
charString.size = 6; charString.size = 6;
AstExprConstantString needsEscaping{Location(), charString}; AstExprConstantString needsEscaping{Location(), charString, AstExprConstantString::QuotedSimple};
CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil)); CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil));
CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b));
@ -83,7 +83,7 @@ TEST_CASE("basic_escaping")
{ {
std::string s = "hello \"world\""; std::string s = "hello \"world\"";
AstArray<char> theString{s.data(), s.size()}; AstArray<char> theString{s.data(), s.size()};
AstExprConstantString str{Location(), theString}; AstExprConstantString str{Location(), theString, AstExprConstantString::QuotedSimple};
std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})"; std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})";
CHECK_EQ(expected, toJson(&str)); CHECK_EQ(expected, toJson(&str));

View file

@ -151,40 +151,6 @@ struct ACBuiltinsFixture : ACFixtureImpl<BuiltinsFixture>
{ {
}; };
#define LUAU_CHECK_HAS_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \
if (!count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)
#define LUAU_CHECK_HAS_NO_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \
if (count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)
TEST_SUITE_BEGIN("AutocompleteTest"); TEST_SUITE_BEGIN("AutocompleteTest");
TEST_CASE_FIXTURE(ACFixture, "empty_program") TEST_CASE_FIXTURE(ACFixture, "empty_program")

View file

@ -58,7 +58,11 @@ TEST_CASE("report_a_syntax_error")
TEST_CASE("noinfer_is_still_allowed") TEST_CASE("noinfer_is_still_allowed")
{ {
Config config; Config config;
auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, true);
ConfigOptions opts;
opts.compat = true;
auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, opts);
REQUIRE(!err); REQUIRE(!err);
CHECK_EQ(int(Luau::Mode::NoCheck), int(config.mode)); CHECK_EQ(int(Luau::Mode::NoCheck), int(config.mode));
@ -147,6 +151,10 @@ TEST_CASE("extra_globals")
TEST_CASE("lint_rules_compat") TEST_CASE("lint_rules_compat")
{ {
Config config; Config config;
ConfigOptions opts;
opts.compat = true;
auto err = parseConfig( auto err = parseConfig(
R"( R"(
{"lint": { {"lint": {
@ -156,7 +164,7 @@ TEST_CASE("lint_rules_compat")
}} }}
)", )",
config, config,
true opts
); );
REQUIRE(!err); REQUIRE(!err);

View file

@ -10,6 +10,7 @@ namespace Luau
ConstraintGeneratorFixture::ConstraintGeneratorFixture() ConstraintGeneratorFixture::ConstraintGeneratorFixture()
: Fixture() : Fixture()
, mainModule(new Module) , mainModule(new Module)
, simplifier(newSimplifier(NotNull{&arena}, builtinTypes))
, forceTheFlag{FFlag::LuauSolverV2, true} , forceTheFlag{FFlag::LuauSolverV2, true}
{ {
mainModule->name = "MainModule"; mainModule->name = "MainModule";
@ -25,6 +26,7 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code)
cg = std::make_unique<ConstraintGenerator>( cg = std::make_unique<ConstraintGenerator>(
mainModule, mainModule,
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
NotNull(&moduleResolver), NotNull(&moduleResolver),
builtinTypes, builtinTypes,
@ -44,8 +46,19 @@ void ConstraintGeneratorFixture::solve(const std::string& code)
{ {
generateConstraints(code); generateConstraints(code);
ConstraintSolver cs{ ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, NotNull{dfg.get()}, {} NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{rootScope},
constraints,
"MainModule",
NotNull(&moduleResolver),
{},
&logger,
NotNull{dfg.get()},
{}
}; };
cs.run(); cs.run();
} }

View file

@ -4,8 +4,9 @@
#include "Luau/ConstraintGenerator.h" #include "Luau/ConstraintGenerator.h"
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/TypeArena.h" #include "Luau/EqSatSimplification.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/TypeArena.h"
#include "Fixture.h" #include "Fixture.h"
#include "ScopedFlags.h" #include "ScopedFlags.h"
@ -20,6 +21,7 @@ struct ConstraintGeneratorFixture : Fixture
DcrLogger logger; DcrLogger logger;
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
SimplifierPtr simplifier;
TypeCheckLimits limits; TypeCheckLimits limits;
TypeFunctionRuntime typeFunctionRuntime{NotNull{&ice}, NotNull{&limits}}; TypeFunctionRuntime typeFunctionRuntime{NotNull{&ice}, NotNull{&limits}};

View file

@ -11,9 +11,7 @@ LUAU_EQSAT_ATOM(I32, int);
LUAU_EQSAT_ATOM(Bool, bool); LUAU_EQSAT_ATOM(Bool, bool);
LUAU_EQSAT_ATOM(Str, std::string); LUAU_EQSAT_ATOM(Str, std::string);
LUAU_EQSAT_FIELD(Left); LUAU_EQSAT_NODE_ARRAY(Add, 2);
LUAU_EQSAT_FIELD(Right);
LUAU_EQSAT_NODE_FIELDS(Add, Left, Right);
using namespace Luau; using namespace Luau;
@ -117,8 +115,8 @@ TEST_CASE("node_field")
Add add{left, right}; Add add{left, right};
EqSat::Id left2 = add.field<Left>(); EqSat::Id left2 = add.operands()[0];
EqSat::Id right2 = add.field<Right>(); EqSat::Id right2 = add.operands()[1];
CHECK(left == left2); CHECK(left == left2);
CHECK(left != right2); CHECK(left != right2);
@ -135,10 +133,10 @@ TEST_CASE("language_operands")
const Add* add = v2.get<Add>(); const Add* add = v2.get<Add>();
REQUIRE(add); REQUIRE(add);
EqSat::Slice<EqSat::Id> actual = v2.operands(); EqSat::Slice<const EqSat::Id> actual = v2.operands();
CHECK(actual.size() == 2); CHECK(actual.size() == 2);
CHECK(actual[0] == add->field<Left>()); CHECK(actual[0] == add->operands()[0]);
CHECK(actual[1] == add->field<Right>()); CHECK(actual[1] == add->operands()[1]);
} }
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -0,0 +1,728 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "Luau/EqSatSimplification.h"
using namespace Luau;
struct ESFixture : Fixture
{
ScopedFastFlag newSolverOnly{FFlag::LuauSolverV2, true};
TypeArena arena_;
const NotNull<TypeArena> arena{&arena_};
SimplifierPtr simplifier;
TypeId parentClass;
TypeId childClass;
TypeId anotherChild;
TypeId unrelatedClass;
TypeId genericT = arena_.addType(GenericType{"T"});
TypeId genericU = arena_.addType(GenericType{"U"});
TypeId numberToString = arena_.addType(FunctionType{
arena_.addTypePack({builtinTypes->numberType}),
arena_.addTypePack({builtinTypes->stringType})
});
TypeId stringToNumber = arena_.addType(FunctionType{
arena_.addTypePack({builtinTypes->stringType}),
arena_.addTypePack({builtinTypes->numberType})
});
ESFixture()
: simplifier(newSimplifier(arena, builtinTypes))
{
createSomeClasses(&frontend);
ScopePtr moduleScope = frontend.globals.globalScope;
parentClass = moduleScope->linearSearchForBinding("Parent")->typeId;
childClass = moduleScope->linearSearchForBinding("Child")->typeId;
anotherChild = moduleScope->linearSearchForBinding("AnotherChild")->typeId;
unrelatedClass = moduleScope->linearSearchForBinding("Unrelated")->typeId;
}
std::optional<std::string> simplifyStr(TypeId ty)
{
auto res = eqSatSimplify(NotNull{simplifier.get()}, ty);
LUAU_ASSERT(res);
return toString(res->result);
}
TypeId tbl(TableType::Props props)
{
return arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, TableState::Sealed});
}
};
TEST_SUITE_BEGIN("EqSatSimplification");
TEST_CASE_FIXTURE(ESFixture, "primitive")
{
CHECK("number" == simplifyStr(builtinTypes->numberType));
}
TEST_CASE_FIXTURE(ESFixture, "number | number")
{
TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->numberType}});
CHECK("number" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "number | string")
{
CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = number | t1")
{
TypeId ty = arena->freshType(nullptr);
asMutable(ty)->ty.emplace<UnionType>(std::vector<TypeId>{builtinTypes->numberType, ty});
CHECK("number" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "number | string | number")
{
TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType, builtinTypes->numberType}});
CHECK("number | string" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "string | (number | string) | number")
{
TypeId u1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}});
TypeId u2 = arena->addType(UnionType{{builtinTypes->stringType, u1, builtinTypes->numberType}});
CHECK("number | string" == simplifyStr(u2));
}
TEST_CASE_FIXTURE(ESFixture, "string | any")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->anyType}})));
}
TEST_CASE_FIXTURE(ESFixture, "any | string")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "any | never")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string | unknown")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->unknownType}})));
}
TEST_CASE_FIXTURE(ESFixture, "unknown | string")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "unknown | never")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string | never")
{
CHECK("string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string | never | number")
{
CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType, builtinTypes->numberType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & string")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & number")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->numberType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & unknown")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->unknownType}})));
}
TEST_CASE_FIXTURE(ESFixture, "never & string")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->neverType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & (unknown | never)")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->stringType,
arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "true | false")
{
CHECK("boolean" == simplifyStr(arena->addType(UnionType{{builtinTypes->trueType, builtinTypes->falseType}})));
}
/*
* Intuitively, if we have a type like
*
* x where x = A & B & (C | D | x)
*
* We know that x is certainly not larger than A & B.
* We also know that the union (C | D | x) can be rewritten `(C | D | (A & B & (C | D | x)))
* This tells us that the union part is not smaller than A & B.
* We can therefore discard the union entirely and simplify this type to A & B
*/
TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (number | t1)")
{
TypeId intersectionTy = arena->addType(BlockedType{});
TypeId unionTy = arena->addType(UnionType{{builtinTypes->numberType, intersectionTy}});
asMutable(intersectionTy)->ty.emplace<IntersectionType>(std::vector<TypeId>{builtinTypes->stringType, unionTy});
CHECK("string" == simplifyStr(intersectionTy));
}
TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (unknown | t1)")
{
TypeId intersectionTy = arena->addType(BlockedType{});
TypeId unionTy = arena->addType(UnionType{{builtinTypes->unknownType, intersectionTy}});
asMutable(intersectionTy)->ty.emplace<IntersectionType>(std::vector<TypeId>{builtinTypes->stringType, unionTy});
CHECK("string" == simplifyStr(intersectionTy));
}
TEST_CASE_FIXTURE(ESFixture, "error | unknown")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->errorType, builtinTypes->unknownType}})));
}
TEST_CASE_FIXTURE(ESFixture, "\"hello\" | string")
{
CHECK("string" == simplifyStr(arena->addType(UnionType{{
arena->addType(SingletonType{StringSingleton{"hello"}}), builtinTypes->stringType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "\"hello\" | \"world\" | \"hello\"")
{
CHECK("\"hello\" | \"world\"" == simplifyStr(arena->addType(UnionType{{
arena->addType(SingletonType{StringSingleton{"hello"}}),
arena->addType(SingletonType{StringSingleton{"world"}}),
arena->addType(SingletonType{StringSingleton{"hello"}}),
}})));
}
TEST_CASE_FIXTURE(ESFixture, "nil | boolean | number | string | thread | function | table | class | buffer")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{
builtinTypes->nilType,
builtinTypes->booleanType,
builtinTypes->numberType,
builtinTypes->stringType,
builtinTypes->threadType,
builtinTypes->functionType,
builtinTypes->tableType,
builtinTypes->classType,
builtinTypes->bufferType,
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Parent & number")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
parentClass, builtinTypes->numberType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child & Parent")
{
CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{
childClass, parentClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child & Unrelated")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
childClass, unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child | Parent")
{
CHECK("Parent" == simplifyStr(arena->addType(UnionType{{
childClass, parentClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "class | Child")
{
CHECK("class" == simplifyStr(arena->addType(UnionType{{
builtinTypes->classType, childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Parent | class | Child")
{
CHECK("class" == simplifyStr(arena->addType(UnionType{{
parentClass, builtinTypes->classType, childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Parent | Unrelated")
{
CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{
parentClass, unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "never | Parent | Unrelated")
{
CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{
builtinTypes->neverType, parentClass, unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "never | Parent | (number & string) | Unrelated")
{
CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{
builtinTypes->neverType, parentClass,
arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}),
unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "T & U")
{
CHECK("T & U" == simplifyStr(arena->addType(IntersectionType{{
genericT, genericU
}})));
}
TEST_CASE_FIXTURE(ESFixture, "boolean & true")
{
CHECK("true" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->booleanType, builtinTypes->trueType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "boolean & (true | number | string | thread | function | table | class | buffer)")
{
TypeId truthy = arena->addType(UnionType{{
builtinTypes->trueType,
builtinTypes->numberType,
builtinTypes->stringType,
builtinTypes->threadType,
builtinTypes->functionType,
builtinTypes->tableType,
builtinTypes->classType,
builtinTypes->bufferType,
}});
CHECK("true" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->booleanType, truthy
}})));
}
TEST_CASE_FIXTURE(ESFixture, "boolean & ~(false?)")
{
CHECK("true" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->booleanType, builtinTypes->truthyType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "false & ~(false?)")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->falseType, builtinTypes->truthyType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (number) -> string")
{
CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, numberToString}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (number) -> string")
{
CHECK("(number) -> string" == simplifyStr(arena->addType(UnionType{{numberToString, numberToString}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & function")
{
CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->functionType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & boolean")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->booleanType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & string")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & ~function")
{
TypeId notFunction = arena->addType(NegationType{builtinTypes->functionType});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, notFunction}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string | function")
{
CHECK("function" == simplifyStr(arena->addType(UnionType{{numberToString, builtinTypes->functionType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (string) -> number")
{
CHECK("((number) -> string) & ((string) -> number)" == simplifyStr(arena->addType(IntersectionType{{numberToString, stringToNumber}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (string) -> number")
{
CHECK("((number) -> string) | ((string) -> number)" == simplifyStr(arena->addType(UnionType{{numberToString, stringToNumber}})));
}
TEST_CASE_FIXTURE(ESFixture, "add<number, number>")
{
CHECK("number" == simplifyStr(arena->addType(
TypeFunctionInstanceType{builtinTypeFunctions().addFunc, {
builtinTypes->numberType, builtinTypes->numberType
}}
)));
}
TEST_CASE_FIXTURE(ESFixture, "union<number, number>")
{
CHECK("number" == simplifyStr(arena->addType(
TypeFunctionInstanceType{builtinTypeFunctions().unionFunc, {
builtinTypes->numberType, builtinTypes->numberType
}}
)));
}
TEST_CASE_FIXTURE(ESFixture, "never & ~string")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->neverType,
arena->addType(NegationType{builtinTypes->stringType})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "blocked & never")
{
const TypeId blocked = arena->addType(BlockedType{});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{blocked, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "blocked & ~number & function")
{
const TypeId blocked = arena->addType(BlockedType{});
const TypeId notNumber = arena->addType(NegationType{builtinTypes->numberType});
const TypeId ty = arena->addType(IntersectionType{{blocked, notNumber, builtinTypes->functionType}});
std::string expected = toString(blocked) + " & function";
CHECK(expected == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "(number | boolean | string | nil | table) & (false | nil)")
{
const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->stringType, builtinTypes->nilType, builtinTypes->tableType}});
CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number | boolean | nil) & (false | nil)")
{
const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->nilType}});
CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(boolean | nil) & (false | nil)")
{
const TypeId t1 = arena->addType(UnionType{{builtinTypes->booleanType, builtinTypes->nilType}});
CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}})));
}
// (('a & false) | ('a & nil)) | number
// Child & ~Parent
// ~Parent & Child
// ~Child & Parent
// Parent & ~Child
// ~Child & ~Parent
// ~Parent & ~Child
TEST_CASE_FIXTURE(ESFixture, "free & string & number")
{
Scope scope{builtinTypes->anyTypePack};
const TypeId freeTy = arena->addType(FreeType{&scope});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{freeTy, builtinTypes->numberType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(blocked & number) | (blocked & number)")
{
const TypeId blocked = arena->addType(BlockedType{});
const TypeId u = arena->addType(IntersectionType{{blocked, builtinTypes->numberType}});
const TypeId ty = arena->addType(UnionType{{u, u}});
const std::string blockedStr = toString(blocked);
CHECK(blockedStr + " & number" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "{} & unknown")
{
CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{
tbl({}),
builtinTypes->unknownType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{} & table")
{
CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{
tbl({}),
builtinTypes->tableType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{} & ~(false?)")
{
CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{
tbl({}),
builtinTypes->truthyType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: number}")
{
const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}});
const TypeId hasX = tbl({{"x", builtinTypes->numberType}});
const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}});
auto res = eqSatSimplify(NotNull{simplifier.get()}, ty);
CHECK("{ x: number }" == toString(res->result));
// Also assert that we don't allocate a fresh TableType in this case.
CHECK(follow(res->result) == hasX);
}
TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: ~(false?)}")
{
const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}});
const TypeId hasX = tbl({{"x", builtinTypes->truthyType}});
const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}});
auto res = eqSatSimplify(NotNull{simplifier.get()}, ty);
CHECK("{ x: number }" == toString(res->result));
}
TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) }")
{
// {x: number?}?
const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}});
// {x: ~(false?)}
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
const TypeId ty = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}});
CHECK("{ x: number }" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "never | (({ x: number? }?) & { x: ~(false?) })")
{
// {x: number?}?
const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}});
// {x: ~(false?)}
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
// ({x: number?}?) & {x: ~(false?)}
const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}});
const TypeId ty = arena->addType(UnionType{{builtinTypes->neverType, intersectionTy}});
CHECK("{ x: number }" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "({ x: number? }?) & { x: ~(false?) } & ~(false?)")
{
// {x: number?}?
const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}});
// {x: ~(false?)}
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
// ({x: number?}?) & {x: ~(false?)} & ~(false?)
const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}});
CHECK("{ x: number }" == simplifyStr(intersectionTy));
}
#if 0
// TODO
TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) } & ~(false?)) | number")
{
// ({ x: number? }?) & { x: ~(false?) } & ~(false?)
const TypeId xWithOptionalNumber = tbl({{"x", builtinTypes->optionalNumberType}});
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}});
const TypeId ty = arena->addType(UnionType{{intersectionTy, builtinTypes->numberType}});
CHECK("{ x: number } | number" == simplifyStr(ty));
}
#endif
TEST_CASE_FIXTURE(ESFixture, "number & no-refine")
{
CHECK("number" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->noRefineType}})));
}
TEST_CASE_FIXTURE(ESFixture, "{ x: number } & ~boolean")
{
const TypeId tblTy = tbl(TableType::Props{{"x", builtinTypes->numberType}});
const TypeId ty = arena->addType(IntersectionType{{
tblTy,
arena->addType(NegationType{builtinTypes->booleanType})
}});
CHECK("{ x: number }" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "(nil & string)?")
{
const TypeId nilAndString = arena->addType(IntersectionType{{builtinTypes->nilType, builtinTypes->stringType}});
const TypeId ty = arena->addType(UnionType{{nilAndString, builtinTypes->nilType}});
CHECK("nil" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "string & \"hi\"")
{
const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}});
CHECK("\"hi\"" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, hi}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & (\"hi\" | \"bye\")")
{
const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}});
const TypeId bye = arena->addType(SingletonType{StringSingleton{"bye"}});
CHECK("\"bye\" | \"hi\"" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->stringType,
arena->addType(UnionType{{hi, bye}})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & ~Child")
{
const TypeId ty = arena->addType(IntersectionType{{
arena->addType(UnionType{{childClass, unrelatedClass}}),
arena->addType(NegationType{childClass})
}});
CHECK("Unrelated" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "string & ~Child")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->stringType,
arena->addType(NegationType{childClass})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & Child")
{
CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{
arena->addType(UnionType{{childClass, unrelatedClass}}),
childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(Child | AnotherChild) & ~Child")
{
CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{
arena->addType(UnionType{{childClass, anotherChild}}),
childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: never }")
{
const TypeId ty = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->neverType}});
CHECK("never" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: number? } & { x: string }")
{
const TypeId leftTable = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->optionalNumberType}});
const TypeId rightTable = tbl({{"x", builtinTypes->stringType}});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{leftTable, rightTable}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child & add<Child | AnotherChild | string, Parent>")
{
const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}});
const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{
builtinTypeFunctions().addFunc,
{u, parentClass},
{}
});
const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}});
CHECK("Child & add<AnotherChild | Child | string, Parent>" == simplifyStr(intersection));
}
TEST_CASE_FIXTURE(ESFixture, "Child & intersect<Child | AnotherChild | string, Parent>")
{
const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}});
const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{
builtinTypeFunctions().intersectFunc,
{u, parentClass},
{}
});
const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}});
CHECK("Child" == simplifyStr(intersection));
}
// {someKey: ~any}
//
// Maybe something we could do here is to try to reduce the key, get the
// class->node mapping, and skip the extraction process if the class corresponds
// to TNever.
// t1 where t1 = add<union<number, t1>, number>
TEST_SUITE_END();

View file

@ -293,3 +293,37 @@ using DifferFixtureWithBuiltins = DifferFixtureGeneric<BuiltinsFixture>;
} while (false) } while (false)
#define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result) #define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result)
#define LUAU_CHECK_HAS_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \
if (!count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)
#define LUAU_CHECK_HAS_NO_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \
if (count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)

View file

@ -4,19 +4,37 @@
#include "Fixture.h" #include "Fixture.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Autocomplete.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/AutocompleteTypes.h"
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAllowFragmentParsing);
LUAU_FASTFLAG(LuauStoreDFGOnModule2); LUAU_FASTFLAG(LuauStoreDFGOnModule2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
static std::optional<AutocompleteEntryMap> nullCallback(std::string tag, std::optional<const ClassType*> ptr, std::optional<std::string> contents)
{
return std::nullopt;
}
struct FragmentAutocompleteFixture : Fixture struct FragmentAutocompleteFixture : Fixture
{ {
ScopedFastFlag sffs[3] = {{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}, {FFlag::LuauStoreDFGOnModule2, true}}; ScopedFastFlag sffs[4] = {
{FFlag::LuauAllowFragmentParsing, true},
{FFlag::LuauSolverV2, true},
{FFlag::LuauStoreDFGOnModule2, true},
{FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true}
};
FragmentAutocompleteFixture()
{
addGlobalBinding(frontend.globals, "table", Binding{builtinTypes->anyType});
addGlobalBinding(frontend.globals, "math", Binding{builtinTypes->anyType});
}
FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos)
{ {
ParseResult p = tryParse(source); // We don't care about parsing incomplete asts ParseResult p = tryParse(source); // We don't care about parsing incomplete asts
@ -26,7 +44,6 @@ struct FragmentAutocompleteFixture : Fixture
CheckResult checkBase(const std::string& document) CheckResult checkBase(const std::string& document)
{ {
ScopedFastFlag sff{FFlag::LuauSolverV2, true};
FrontendOptions opts; FrontendOptions opts;
opts.retainFullTypeGraphs = true; opts.retainFullTypeGraphs = true;
return this->frontend.check("MainModule", opts); return this->frontend.check("MainModule", opts);
@ -48,6 +65,16 @@ struct FragmentAutocompleteFixture : Fixture
options.runLintChecks = false; options.runLintChecks = false;
return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document); return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document);
} }
FragmentAutocompleteResult autocompleteFragment(const std::string& document, Position cursorPos)
{
FrontendOptions options;
options.retainFullTypeGraphs = true;
// Don't strictly need this in the new solver
options.forAutocomplete = true;
options.runLintChecks = false;
return Luau::fragmentAutocomplete(frontend, document, "MainModule", cursorPos, options, nullCallback);
}
}; };
TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests");
@ -172,6 +199,13 @@ TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); TEST_SUITE_BEGIN("FragmentAutocompleteParserTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer")
{
check("local a =");
auto fragment = parseFragment("local a =", Position(0, 10));
CHECK_EQ("local a =", fragment.fragmentToParse);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null")
{ {
auto res = check(R"( auto res = check(R"(
@ -278,6 +312,33 @@ local y = 5
CHECK_EQ("y", std::string(rhs->name.value)); CHECK_EQ("y", std::string(rhs->name.value));
} }
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope")
{
check(R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)");
auto fragment = parseFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{6, 0}
);
CHECK_EQ("function abc()\n local myInnerLocal = 1\n\n end\n", fragment.fragmentToParse);
}
TEST_SUITE_END(); TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests");
@ -302,7 +363,7 @@ local z = x + y
Position{3, 15} Position{3, 15}
); );
auto opt = linearSearchForBinding(fragment.freshScope, "z"); auto opt = linearSearchForBinding(fragment.freshScope.get(), "z");
REQUIRE(opt); REQUIRE(opt);
CHECK_EQ("number", toString(*opt)); CHECK_EQ("number", toString(*opt));
} }
@ -326,9 +387,222 @@ local y = 5
Position{2, 11} Position{2, 11}
); );
auto correct = linearSearchForBinding(fragment.freshScope, "z"); auto correct = linearSearchForBinding(fragment.freshScope.get(), "z");
REQUIRE(correct); REQUIRE(correct);
CHECK_EQ("number", toString(*correct)); CHECK_EQ("number", toString(*correct));
} }
TEST_SUITE_END(); TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_simple_property_access")
{
auto res = check(
R"(
local tbl = { abc = 1234}
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
local tbl = { abc = 1234}
tbl.
)",
Position{2, 5}
);
LUAU_ASSERT(fragment.freshScope);
CHECK_EQ(1, fragment.acResults.entryMap.size());
CHECK(fragment.acResults.entryMap.count("abc"));
CHECK_EQ(AutocompleteContext::Property, fragment.acResults.context);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_nested_property_access")
{
auto res = check(
R"(
local tbl = { abc = { def = 1234, egh = false } }
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
local tbl = { abc = { def = 1234, egh = false } }
tbl.abc.
)",
Position{2, 8}
);
LUAU_ASSERT(fragment.freshScope);
CHECK_EQ(2, fragment.acResults.entryMap.size());
CHECK(fragment.acResults.entryMap.count("def"));
CHECK(fragment.acResults.entryMap.count("egh"));
CHECK_EQ(fragment.acResults.context, AutocompleteContext::Property);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "inline_autocomplete_picks_the_right_scope")
{
auto res = check(
R"(
type Table = { a: number, b: number }
do
type Table = { x: string, y: string }
end
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
type Table = { a: number, b: number }
do
type Table = { x: string, y: string }
local a : T
end
)",
Position{4, 15}
);
LUAU_ASSERT(fragment.freshScope);
REQUIRE(fragment.acResults.entryMap.count("Table"));
REQUIRE(fragment.acResults.entryMap["Table"].type);
const TableType* tv = get<TableType>(follow(*fragment.acResults.entryMap["Table"].type));
REQUIRE(tv);
CHECK(tv->props.count("x"));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nested_recursive_function")
{
auto res = check(R"(
function foo()
end
)");
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
function foo()
end
)",
Position{2, 0}
);
CHECK(fragment.acResults.entryMap.count("foo"));
CHECK_EQ(AutocompleteContext::Statement, fragment.acResults.context);
}
// Start compatibility tests!
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "empty_program")
{
check("");
auto frag = autocompleteFragment(" ", Position{0, 1});
auto ac = frag.acResults;
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Statement);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer")
{
check("local a =");
auto frag = autocompleteFragment("local a =", Position{0, 9});
auto ac = frag.acResults;
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Expression);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "leave_numbers_alone")
{
check("local a = 3.");
auto frag = autocompleteFragment("local a = 3.", Position{0, 12});
auto ac = frag.acResults;
CHECK(ac.entryMap.empty());
CHECK_EQ(ac.context, AutocompleteContext::Unknown);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "user_defined_globals")
{
check("local myLocal = 4; ");
auto frag = autocompleteFragment("local myLocal = 4; ", Position{0, 18});
auto ac = frag.acResults;
CHECK(ac.entryMap.count("myLocal"));
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Statement);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "dont_suggest_local_before_its_definition")
{
check(R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)");
// autocomplete after abc but before myInnerLocal
auto fragment = autocompleteFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{3, 0}
);
auto ac = fragment.acResults;
CHECK(ac.entryMap.count("myLocal"));
LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal");
// autocomplete after my inner local
fragment = autocompleteFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{4, 0}
);
ac = fragment.acResults;
CHECK(ac.entryMap.count("myLocal"));
CHECK(ac.entryMap.count("myInnerLocal"));
fragment = autocompleteFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{6, 0}
);
ac = fragment.acResults;
CHECK(ac.entryMap.count("myLocal"));
LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal");
}
TEST_SUITE_END();

View file

@ -18,6 +18,7 @@ LUAU_FASTINT(LuauParseErrorLimit)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr) LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport)
namespace namespace
{ {
@ -2377,10 +2378,15 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms")
TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions")
{ {
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag sff2{FFlag::LuauUserDefinedTypeFunParseExport, true};
AstStat* stat = parse(R"( AstStat* stat = parse(R"(
type function foo() type function foo()
return return types.number
end
export type function bar()
return types.string
end end
)"); )");
@ -2417,7 +2423,6 @@ TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions")
{ {
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
matchParseError("export type function foo() end", "Type function cannot be exported");
matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'"); matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'");
matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'"); matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'");
} }

View file

@ -424,6 +424,13 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireUnprefixedPath")
assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"}); assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"});
} }
TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithExtension")
{
std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency.luau";
runProtectedRequire(path);
assertOutputContainsAll({"false", "error requiring module: consider removing the file extension"});
}
TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias")
{ {
std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer"; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer";

View file

@ -964,6 +964,7 @@ TEST_CASE_FIXTURE(Fixture, "correct_stringification_user_defined_type_functions"
std::vector<TypeId>{builtinTypes->numberType}, // Type Function Arguments std::vector<TypeId>{builtinTypes->numberType}, // Type Function Arguments
{}, {},
{AstName{"woohoo"}}, // Type Function Name {AstName{"woohoo"}}, // Type Function Name
{},
}; };
Type tv{tftt}; Type tv{tftt};

View file

@ -16,6 +16,8 @@ LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite)
LUAU_FASTFLAG(LuauUserTypeFunFixMetatable) LUAU_FASTFLAG(LuauUserTypeFunFixMetatable)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState)
LUAU_FASTFLAG(LuauUserTypeFunNonstrict) LUAU_FASTFLAG(LuauUserTypeFunNonstrict)
LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal)
LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport)
TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests");
@ -1298,4 +1300,92 @@ local a: foo<> = "a"
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "implicit_export")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true};
ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true};
fileResolver.source["game/A"] = R"(
type function concat(a, b)
return types.singleton(a:value() .. b:value())
end
export type Concat<T, U> = concat<T, U>
local a: concat<'first', 'second'>
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")");
CheckResult bResult = check(R"(
local Test = require(game.A);
local b: Test.Concat<'third', 'fourth'>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
CHECK(toString(requireType("b")) == R"("thirdfourth")");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "local_scope")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true};
ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true};
CheckResult result = check(R"(
type function foo()
return "hi"
end
local function test()
type function bar()
return types.singleton(foo())
end
return ("" :: any) :: bar<>
end
local a = test()
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK(toString(requireType("a")) == R"("hi")");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "explicit_export")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true};
ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true};
ScopedFastFlag luauUserDefinedTypeFunParseExport{FFlag::LuauUserDefinedTypeFunParseExport, true};
fileResolver.source["game/A"] = R"(
export type function concat(a, b)
return types.singleton(a:value() .. b:value())
end
local a: concat<'first', 'second'>
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")");
CheckResult bResult = check(R"(
local Test = require(game.A);
local b: Test.concat<'third', 'fourth'>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
CHECK(toString(requireType("b")) == R"("thirdfourth")");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -9,6 +9,8 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauNewSolverPrePopulateClasses)
TEST_SUITE_BEGIN("DefinitionTests"); TEST_SUITE_BEGIN("DefinitionTests");
TEST_CASE_FIXTURE(Fixture, "definition_file_simple") TEST_CASE_FIXTURE(Fixture, "definition_file_simple")
@ -492,11 +494,8 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_indexer")
TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes")
{ {
unfreeze(frontend.globals.globalTypes); ScopedFastFlag _{FFlag::LuauNewSolverPrePopulateClasses, true};
LoadDefinitionFileResult result = frontend.loadDefinitionFile( loadDefinition(R"(
frontend.globals,
frontend.globals.globalScope,
R"(
declare class Channel declare class Channel
Messages: { Message } Messages: { Message }
OnMessage: (message: Message) -> () OnMessage: (message: Message) -> ()
@ -506,13 +505,19 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes")
Text: string Text: string
Channel: Channel Channel: Channel
end end
)", )");
"@test",
/* captureComments */ false
);
freeze(frontend.globals.globalTypes);
REQUIRE(result.success); CheckResult result = check(R"(
local a: Channel
local b = a.Messages[1]
local c = b.Channel
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Channel");
CHECK_EQ(toString(requireType("b")), "Message");
CHECK_EQ(toString(requireType("c")), "Channel");
} }
TEST_CASE_FIXTURE(Fixture, "definition_file_has_source_module_name_set") TEST_CASE_FIXTURE(Fixture, "definition_file_has_source_module_name_set")

View file

@ -20,6 +20,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
TEST_SUITE_BEGIN("TypeInferFunctions"); TEST_SUITE_BEGIN("TypeInferFunctions");
@ -681,6 +682,11 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function")
TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") TEST_CASE_FIXTURE(Fixture, "higher_order_function_2")
{ {
// CLI-114134: this code *probably* wants the egraph in order
// to work properly. The new solver either falls over or
// forces so many constraints as to be unreliable.
DOES_NOT_PASS_NEW_SOLVER_GUARD();
CheckResult result = check(R"( CheckResult result = check(R"(
function bottomupmerge(comp, a, b, left, mid, right) function bottomupmerge(comp, a, b, left, mid, right)
local i, j = left, mid local i, j = left, mid
@ -743,6 +749,11 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3")
TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4")
{ {
// CLI-114134: this code *probably* wants the egraph in order
// to work properly. The new solver either falls over or
// forces so many constraints as to be unreliable.
DOES_NOT_PASS_NEW_SOLVER_GUARD();
CheckResult result = check(R"( CheckResult result = check(R"(
function bottomupmerge(comp, a, b, left, mid, right) function bottomupmerge(comp, a, b, left, mid, right)
local i, j = left, mid local i, j = left, mid
@ -2554,8 +2565,17 @@ end
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type")
{ {
if (!FFlag::LuauSolverV2) ScopedFastFlag sffs[] = {
return; {FFlag::LuauSolverV2, true},
{FFlag::LuauDontRefCountTypesInTypeFunctions, true}
};
// CLI-114134: This test:
// a) Has a kind of weird result (suggesting `number | false` is not great);
// b) Is force solving some constraints.
// We end up with a weird recursive type that, if you roughly look at it, is
// clearly `number`. Hopefully the egraph will be able to unfold this.
CheckResult result = check(R"( CheckResult result = check(R"(
function fib(n) function fib(n)
return n < 2 and 1 or fib(n-1) + fib(n-2) return n < 2 and 1 or fib(n-1) + fib(n-2)
@ -2565,9 +2585,7 @@ end
LUAU_REQUIRE_ERRORS(result); LUAU_REQUIRE_ERRORS(result);
auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors.back()); auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors.back());
LUAU_ASSERT(err); LUAU_ASSERT(err);
CHECK("number" == toString(err->recommendedReturn)); CHECK("false | number" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("number" == toString(err->recommendedArgs[0].second));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type") TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type")
@ -2862,6 +2880,8 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun")
TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types")
{ {
ScopedFastFlag _{FFlag::LuauDontRefCountTypesInTypeFunctions, true};
CheckResult result = check(R"( CheckResult result = check(R"(
function foo(player) function foo(player)
local success,result = player:thing() local success,result = player:thing()
@ -2889,7 +2909,7 @@ TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types")
auto tm2 = get<TypePackMismatch>(result.errors[1]); auto tm2 = get<TypePackMismatch>(result.errors[1]);
REQUIRE(tm2); REQUIRE(tm2);
CHECK(toString(tm2->wantedTp) == "string"); CHECK(toString(tm2->wantedTp) == "string");
CHECK(toString(tm2->givenTp) == "buffer | class | function | number | string | table | thread | true"); CHECK(toString(tm2->givenTp) == "(buffer | class | function | number | string | table | thread | true) & unknown");
} }
else else
{ {

View file

@ -14,6 +14,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauRequireCyclesDontAlwaysReturnAny) LUAU_FASTFLAG(LuauRequireCyclesDontAlwaysReturnAny)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAG(LuauTypestateBuiltins2)
LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations)
using namespace Luau; using namespace Luau;
@ -466,7 +467,15 @@ local b: B.T = a
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{
if (FFlag::LuauNewSolverPopulateTableLocations)
CHECK(
toString(result.errors.at(0)) ==
"Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'; at [read \"x\"], number is not exactly string"
);
else
CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string");
}
else else
{ {
const std::string expected = R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' const std::string expected = R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'
@ -507,7 +516,15 @@ local b: B.T = a
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{
if (FFlag::LuauNewSolverPopulateTableLocations)
CHECK(
toString(result.errors.at(0)) ==
"Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'; at [read \"x\"], number is not exactly string"
);
else
CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string");
}
else else
{ {
const std::string expected = R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' const std::string expected = R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'

View file

@ -24,6 +24,7 @@ LUAU_FASTINT(LuauNormalizeCacheLimit);
LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
using namespace Luau; using namespace Luau;
@ -1730,4 +1731,36 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue")
)")); )"));
} }
TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function")
{
ScopedFastFlag sffs[] = {
{FFlag::LuauSolverV2, true},
{FFlag::LuauDontRefCountTypesInTypeFunctions, true}
};
LUAU_CHECK_NO_ERRORS(check(R"(
--!strict
local function foo(a : string?)
local b = a or ""
return b:upper()
end
)"));
}
TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type")
{
ScopedFastFlag sffs[] = {
{FFlag::LuauSolverV2, true},
{FFlag::LuauDontRefCountTypesInTypeFunctions, true}
};
LUAU_CHECK_NO_ERRORS(check(R"(
--!strict
local function wtf(name: string?)
local message
message = "invalid alternate fiber: " .. (name or "UNNAMED alternate")
end
)"));
}
TEST_SUITE_END(); TEST_SUITE_END();