Sync to upstream/release/648 (#1477)

## What's new

* Added `math.map` function to the standard library, based on
https://rfcs.luau-lang.org/function-math-map.html
* `FileResolver` can provide an implementation of
`getRequireSuggestions` to provide auto-complete suggestions for
require-by-string

## New Solver

* In user-defined type functions, `readproperty` and `writeproperty`
will return `nil` instead of erroring if property is not found
* Fixed incorrect scope of variadic arguments in the data-flow graph
* Fixed multiple assertion failures

---

Internal Contributors:

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
This commit is contained in:
vegorov-rbx 2024-10-18 10:27:15 -07:00 committed by GitHub
parent d7842e08ae
commit e491128f95
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 1590 additions and 524 deletions

View file

@ -39,6 +39,7 @@ enum class AutocompleteEntryKind
Type, Type,
Module, Module,
GeneratedFunction, GeneratedFunction,
RequirePath,
}; };
enum class ParenthesesRecommendation enum class ParenthesesRecommendation

View file

@ -9,6 +9,8 @@
namespace Luau namespace Luau
{ {
static constexpr char kRequireTagName[] = "require";
struct Frontend; struct Frontend;
struct GlobalTypes; struct GlobalTypes;
struct TypeChecker; struct TypeChecker;

View file

@ -68,7 +68,6 @@ private:
DenseHashMap<const AstExpr*, const Def*> compoundAssignDefs{nullptr}; DenseHashMap<const AstExpr*, const Def*> compoundAssignDefs{nullptr};
DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr}; DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr};
friend struct DataFlowGraphBuilder; friend struct DataFlowGraphBuilder;
}; };
@ -83,6 +82,7 @@ struct DfgScope
DfgScope* parent; DfgScope* parent;
ScopeType scopeType; ScopeType scopeType;
Location location;
using Bindings = DenseHashMap<Symbol, const Def*>; using Bindings = DenseHashMap<Symbol, const Def*>;
using Props = DenseHashMap<const Def*, std::unordered_map<std::string, const Def*>>; using Props = DenseHashMap<const Def*, std::unordered_map<std::string, const Def*>>;
@ -105,10 +105,44 @@ struct DataFlowResult
const RefinementKey* parent = nullptr; const RefinementKey* parent = nullptr;
}; };
using ScopeStack = std::vector<DfgScope*>;
struct DataFlowGraphBuilder struct DataFlowGraphBuilder
{ {
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle); static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
/**
* This method is identical to the build method above, but returns a pair of dfg, scopes as the data flow graph
* here is intended to live on the module between runs of typechecking. Before, the DFG only needed to live as
* long as the typecheck, but in a world with incremental typechecking, we need the information on the dfg to incrementally
* typecheck small fragments of code.
* @param block - pointer to the ast to build the dfg for
* @param handle - for raising internal errors while building the dfg
*/
static std::pair<std::shared_ptr<DataFlowGraph>, std::vector<std::unique_ptr<DfgScope>>> buildShared(
AstStatBlock* block,
NotNull<InternalErrorReporter> handle
);
/**
* Takes a stale graph along with a list of scopes, a small fragment of the ast, and a cursor position
* and constructs the DataFlowGraph for just that fragment. This method will fabricate defs in the final
* DFG for things that have been referenced and exist in the stale dfg.
* For example, the fragment local z = x + y will populate defs for x and y from the stale graph.
* @param staleGraph - the old DFG
* @param scopes - the old DfgScopes in the graph
* @param fragment - the Ast Fragment to re-build the root for
* @param cursorPos - the current location of the cursor - used to determine which scope we are currently in
* @param handle - for internal compiler errors
*/
static DataFlowGraph updateGraph(
const DataFlowGraph& staleGraph,
const std::vector<std::unique_ptr<DfgScope>>& scopes,
AstStatBlock* fragment,
const Position& cursorPos,
NotNull<InternalErrorReporter> handle
);
private: private:
DataFlowGraphBuilder() = default; DataFlowGraphBuilder() = default;
@ -120,10 +154,15 @@ private:
NotNull<RefinementKeyArena> keyArena{&graph.keyArena}; NotNull<RefinementKeyArena> keyArena{&graph.keyArena};
struct InternalErrorReporter* handle = nullptr; struct InternalErrorReporter* handle = nullptr;
DfgScope* moduleScope = nullptr;
/// The arena owning all of the scope allocations for the dataflow graph being built.
std::vector<std::unique_ptr<DfgScope>> scopes; std::vector<std::unique_ptr<DfgScope>> scopes;
/// A stack of scopes used by the visitor to see where we are.
ScopeStack scopeStack;
DfgScope* currentScope();
struct FunctionCapture struct FunctionCapture
{ {
std::vector<DefId> captureDefs; std::vector<DefId> captureDefs;
@ -134,81 +173,81 @@ private:
DenseHashMap<Symbol, FunctionCapture> captures{Symbol{}}; DenseHashMap<Symbol, FunctionCapture> captures{Symbol{}};
void resolveCaptures(); void resolveCaptures();
DfgScope* childScope(DfgScope* scope, DfgScope::ScopeType scopeType = DfgScope::Linear); DfgScope* makeChildScope(Location loc, DfgScope::ScopeType scopeType = DfgScope::Linear);
void join(DfgScope* p, DfgScope* a, DfgScope* b); void join(DfgScope* p, DfgScope* a, DfgScope* b);
void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b); void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b);
void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b); void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b);
DefId lookup(DfgScope* scope, Symbol symbol); DefId lookup(Symbol symbol);
DefId lookup(DfgScope* scope, DefId def, const std::string& key); DefId lookup(DefId def, const std::string& key);
ControlFlow visit(DfgScope* scope, AstStatBlock* b); ControlFlow visit(AstStatBlock* b);
ControlFlow visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); ControlFlow visitBlockWithoutChildScope(AstStatBlock* b);
ControlFlow visit(DfgScope* scope, AstStat* s); ControlFlow visit(AstStat* s);
ControlFlow visit(DfgScope* scope, AstStatIf* i); ControlFlow visit(AstStatIf* i);
ControlFlow visit(DfgScope* scope, AstStatWhile* w); ControlFlow visit(AstStatWhile* w);
ControlFlow visit(DfgScope* scope, AstStatRepeat* r); ControlFlow visit(AstStatRepeat* r);
ControlFlow visit(DfgScope* scope, AstStatBreak* b); ControlFlow visit(AstStatBreak* b);
ControlFlow visit(DfgScope* scope, AstStatContinue* c); ControlFlow visit(AstStatContinue* c);
ControlFlow visit(DfgScope* scope, AstStatReturn* r); ControlFlow visit(AstStatReturn* r);
ControlFlow visit(DfgScope* scope, AstStatExpr* e); ControlFlow visit(AstStatExpr* e);
ControlFlow visit(DfgScope* scope, AstStatLocal* l); ControlFlow visit(AstStatLocal* l);
ControlFlow visit(DfgScope* scope, AstStatFor* f); ControlFlow visit(AstStatFor* f);
ControlFlow visit(DfgScope* scope, AstStatForIn* f); ControlFlow visit(AstStatForIn* f);
ControlFlow visit(DfgScope* scope, AstStatAssign* a); ControlFlow visit(AstStatAssign* a);
ControlFlow visit(DfgScope* scope, AstStatCompoundAssign* c); ControlFlow visit(AstStatCompoundAssign* c);
ControlFlow visit(DfgScope* scope, AstStatFunction* f); ControlFlow visit(AstStatFunction* f);
ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l); ControlFlow visit(AstStatLocalFunction* l);
ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t); ControlFlow visit(AstStatTypeAlias* t);
ControlFlow visit(DfgScope* scope, AstStatTypeFunction* f); ControlFlow visit(AstStatTypeFunction* f);
ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d); ControlFlow visit(AstStatDeclareGlobal* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d); ControlFlow visit(AstStatDeclareFunction* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d); ControlFlow visit(AstStatDeclareClass* d);
ControlFlow visit(DfgScope* scope, AstStatError* error); ControlFlow visit(AstStatError* error);
DataFlowResult visitExpr(DfgScope* scope, AstExpr* e); DataFlowResult visitExpr(AstExpr* e);
DataFlowResult visitExpr(DfgScope* scope, AstExprGroup* group); DataFlowResult visitExpr(AstExprGroup* group);
DataFlowResult visitExpr(DfgScope* scope, AstExprLocal* l); DataFlowResult visitExpr(AstExprLocal* l);
DataFlowResult visitExpr(DfgScope* scope, AstExprGlobal* g); DataFlowResult visitExpr(AstExprGlobal* g);
DataFlowResult visitExpr(DfgScope* scope, AstExprCall* c); DataFlowResult visitExpr(AstExprCall* c);
DataFlowResult visitExpr(DfgScope* scope, AstExprIndexName* i); DataFlowResult visitExpr(AstExprIndexName* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprIndexExpr* i); DataFlowResult visitExpr(AstExprIndexExpr* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprFunction* f); DataFlowResult visitExpr(AstExprFunction* f);
DataFlowResult visitExpr(DfgScope* scope, AstExprTable* t); DataFlowResult visitExpr(AstExprTable* t);
DataFlowResult visitExpr(DfgScope* scope, AstExprUnary* u); DataFlowResult visitExpr(AstExprUnary* u);
DataFlowResult visitExpr(DfgScope* scope, AstExprBinary* b); DataFlowResult visitExpr(AstExprBinary* b);
DataFlowResult visitExpr(DfgScope* scope, AstExprTypeAssertion* t); DataFlowResult visitExpr(AstExprTypeAssertion* t);
DataFlowResult visitExpr(DfgScope* scope, AstExprIfElse* i); DataFlowResult visitExpr(AstExprIfElse* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprInterpString* i); DataFlowResult visitExpr(AstExprInterpString* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprError* error); DataFlowResult visitExpr(AstExprError* error);
void visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef); void visitLValue(AstExpr* e, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef); DefId visitLValue(AstExprLocal* l, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef); DefId visitLValue(AstExprGlobal* g, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef); DefId visitLValue(AstExprIndexName* i, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef); DefId visitLValue(AstExprIndexExpr* i, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef); DefId visitLValue(AstExprError* e, DefId incomingDef);
void visitType(DfgScope* scope, AstType* t); void visitType(AstType* t);
void visitType(DfgScope* scope, AstTypeReference* r); void visitType(AstTypeReference* r);
void visitType(DfgScope* scope, AstTypeTable* t); void visitType(AstTypeTable* t);
void visitType(DfgScope* scope, AstTypeFunction* f); void visitType(AstTypeFunction* f);
void visitType(DfgScope* scope, AstTypeTypeof* t); void visitType(AstTypeTypeof* t);
void visitType(DfgScope* scope, AstTypeUnion* u); void visitType(AstTypeUnion* u);
void visitType(DfgScope* scope, AstTypeIntersection* i); void visitType(AstTypeIntersection* i);
void visitType(DfgScope* scope, AstTypeError* error); void visitType(AstTypeError* error);
void visitTypePack(DfgScope* scope, AstTypePack* p); void visitTypePack(AstTypePack* p);
void visitTypePack(DfgScope* scope, AstTypePackExplicit* e); void visitTypePack(AstTypePackExplicit* e);
void visitTypePack(DfgScope* scope, AstTypePackVariadic* v); void visitTypePack(AstTypePackVariadic* v);
void visitTypePack(DfgScope* scope, AstTypePackGeneric* g); void visitTypePack(AstTypePackGeneric* g);
void visitTypeList(DfgScope* scope, AstTypeList l); void visitTypeList(AstTypeList l);
void visitGenerics(DfgScope* scope, AstArray<AstGenericType> g); void visitGenerics(AstArray<AstGenericType> g);
void visitGenericPacks(DfgScope* scope, AstArray<AstGenericTypePack> g); void visitGenericPacks(AstArray<AstGenericTypePack> g);
}; };
} // namespace Luau } // namespace Luau

View file

@ -3,6 +3,7 @@
#include <string> #include <string>
#include <optional> #include <optional>
#include <vector>
namespace Luau namespace Luau
{ {
@ -31,6 +32,9 @@ struct ModuleInfo
bool optional = false; bool optional = false;
}; };
using RequireSuggestion = std::string;
using RequireSuggestions = std::vector<RequireSuggestion>;
struct FileResolver struct FileResolver
{ {
virtual ~FileResolver() {} virtual ~FileResolver() {}
@ -51,6 +55,11 @@ struct FileResolver
{ {
return std::nullopt; return std::nullopt;
} }
virtual std::optional<RequireSuggestions> getRequireSuggestions(const ModuleName& requirer, const std::optional<std::string>& pathString) const
{
return std::nullopt;
}
}; };
struct NullFileResolver : FileResolver struct NullFileResolver : FileResolver

View file

@ -1,12 +1,15 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/DenseHash.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Parser.h"
#include "Luau/Autocomplete.h"
#include "Luau/DenseHash.h"
#include "Luau/Module.h"
#include <memory>
#include <vector> #include <vector>
namespace Luau namespace Luau
{ {
@ -15,9 +18,28 @@ struct FragmentAutocompleteAncestryResult
DenseHashMap<AstName, AstLocal*> localMap{AstName()}; DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack; std::vector<AstLocal*> localStack;
std::vector<AstNode*> ancestry; std::vector<AstNode*> ancestry;
AstStat* nearestStatement; AstStat* nearestStatement = nullptr;
};
struct FragmentParseResult
{
std::string fragmentToParse;
AstStatBlock* root = nullptr;
std::vector<AstNode*> ancestry;
std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>();
}; };
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos);
AutocompleteResult fragmentAutocomplete(
Frontend& frontend,
std::string_view src,
const ModuleName& moduleName,
Position& cursorPosition,
StringCompletionCallback callback
);
} // namespace Luau } // namespace Luau

View file

@ -9,6 +9,7 @@
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/AnyTypeSummary.h" #include "Luau/AnyTypeSummary.h"
#include "Luau/DataFlowGraph.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -131,6 +132,9 @@ struct Module
TypePackId returnType = nullptr; TypePackId returnType = nullptr;
std::unordered_map<Name, TypeFun> exportedTypeBindings; std::unordered_map<Name, TypeFun> exportedTypeBindings;
// We also need to keep DFG data alive between runs
std::shared_ptr<DataFlowGraph> dataFlowGraph = nullptr;
std::vector<std::unique_ptr<DfgScope>> dfgScopes;
bool hasModuleScope() const; bool hasModuleScope() const;
ScopePtr getModuleScope() const; ScopePtr getModuleScope() const;

View file

@ -667,6 +667,11 @@ struct AnyType
{ {
}; };
// A special, trivial type for the refinement system that is always eliminated from intersections.
struct NoRefineType
{
};
// `T | U` // `T | U`
struct UnionType struct UnionType
{ {
@ -755,6 +760,7 @@ using TypeVariant = Unifiable::Variant<
UnknownType, UnknownType,
NeverType, NeverType,
NegationType, NegationType,
NoRefineType,
TypeFunctionInstanceType>; TypeFunctionInstanceType>;
struct Type final struct Type final
@ -949,6 +955,7 @@ public:
const TypeId unknownType; const TypeId unknownType;
const TypeId neverType; const TypeId neverType;
const TypeId errorType; const TypeId errorType;
const TypeId noRefineType;
const TypeId falsyType; const TypeId falsyType;
const TypeId truthyType; const TypeId truthyType;

View file

@ -248,4 +248,36 @@ std::optional<Ty> follow(std::optional<Ty> ty)
return std::nullopt; return std::nullopt;
} }
/**
* Returns whether or not expr is a literal expression, for example:
* - Scalar literals (numbers, booleans, strings, nil)
* - Table literals
* - Lambdas (a "function literal")
*/
bool isLiteral(const AstExpr* expr);
/**
* Given a table literal and a mapping from expression to type, determine
* whether any literal expression in this table depends on any blocked types.
* This is used as a precondition for bidirectional inference: be warned that
* the behavior of this algorithm is tightly coupled to that of bidirectional
* inference.
* @param expr Expression to search
* @param astTypes Mapping from AST node to TypeID
* @returns A vector of blocked types
*/
std::vector<TypeId> findBlockedTypesIn(AstExprTable* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes);
/**
* Given a function call and a mapping from expression to type, determine
* whether the type of any argument in said call in depends on a blocked types.
* This is used as a precondition for bidirectional inference: be warned that
* the behavior of this algorithm is tightly coupled to that of bidirectional
* inference.
* @param expr Expression to search
* @param astTypes Mapping from AST node to TypeID
* @returns A vector of blocked types
*/
std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes);
} // namespace Luau } // namespace Luau

View file

@ -133,6 +133,10 @@ struct GenericTypeVisitor
{ {
return visit(ty); return visit(ty);
} }
virtual bool visit(TypeId ty, const NoRefineType& nrt)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UnknownType& utv) virtual bool visit(TypeId ty, const UnknownType& utv)
{ {
return visit(ty); return visit(ty);
@ -345,6 +349,8 @@ struct GenericTypeVisitor
} }
else if (auto atv = get<AnyType>(ty)) else if (auto atv = get<AnyType>(ty))
visit(ty, *atv); visit(ty, *atv);
else if (auto nrt = get<NoRefineType>(ty))
visit(ty, *nrt);
else if (auto utv = get<UnionType>(ty)) else if (auto utv = get<UnionType>(ty))
{ {
if (visit(ty, *utv)) if (visit(ty, *utv))

View file

@ -3,6 +3,8 @@
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h"
#include "Luau/FileResolver.h"
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/Subtyping.h" #include "Luau/Subtyping.h"
@ -15,6 +17,7 @@
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAutocompleteNewSolverLimit) LUAU_FASTFLAG(LuauAutocompleteNewSolverLimit)
LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions, false)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTypeInferIterationLimit)
@ -215,8 +218,7 @@ static TypeCorrectKind checkTypeCorrectKind(
{ {
for (TypeId id : itv->parts) for (TypeId id : itv->parts)
{ {
if (DFInt::LuauTypeSolverRelease >= 644) id = follow(id);
id = follow(id);
if (const FunctionType* ftv = get<FunctionType>(id); ftv && checkFunctionType(ftv)) if (const FunctionType* ftv = get<FunctionType>(id); ftv && checkFunctionType(ftv))
{ {
@ -1444,11 +1446,25 @@ static std::optional<std::string> getStringContents(const AstNode* node)
} }
} }
static std::optional<AutocompleteEntryMap> convertRequireSuggestionsToAutocompleteEntryMap(std::optional<RequireSuggestions> suggestions)
{
if (!suggestions)
return std::nullopt;
AutocompleteEntryMap result;
for (const RequireSuggestion& suggestion : *suggestions)
{
result[suggestion] = {AutocompleteEntryKind::RequirePath};
}
return result;
}
static std::optional<AutocompleteEntryMap> autocompleteStringParams( static std::optional<AutocompleteEntryMap> autocompleteStringParams(
const SourceModule& sourceModule, const SourceModule& sourceModule,
const ModulePtr& module, const ModulePtr& module,
const std::vector<AstNode*>& nodes, const std::vector<AstNode*>& nodes,
Position position, Position position,
FileResolver* fileResolver,
StringCompletionCallback callback StringCompletionCallback callback
) )
{ {
@ -1495,6 +1511,13 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(
{ {
for (const std::string& tag : funcType->tags) for (const std::string& tag : funcType->tags)
{ {
if (FFlag::AutocompleteRequirePathSuggestions)
{
if (tag == kRequireTagName && fileResolver)
{
return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString));
}
}
if (std::optional<AutocompleteEntryMap> ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) if (std::optional<AutocompleteEntryMap> ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString))
{ {
return ret; return ret;
@ -1679,6 +1702,7 @@ static AutocompleteResult autocomplete(
TypeArena* typeArena, TypeArena* typeArena,
Scope* globalScope, Scope* globalScope,
Position position, Position position,
FileResolver* fileResolver,
StringCompletionCallback callback StringCompletionCallback callback
) )
{ {
@ -1922,7 +1946,7 @@ static AutocompleteResult autocomplete(
else if (isIdentifier(node) && (parent->is<AstStatExpr>() || parent->is<AstStatError>())) else if (isIdentifier(node) && (parent->is<AstStatExpr>() || parent->is<AstStatError>()))
return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement};
if (std::optional<AutocompleteEntryMap> ret = autocompleteStringParams(sourceModule, module, ancestry, position, callback)) if (std::optional<AutocompleteEntryMap> ret = autocompleteStringParams(sourceModule, module, ancestry, position, fileResolver, callback))
{ {
return {*ret, ancestry, AutocompleteContext::String}; return {*ret, ancestry, AutocompleteContext::String};
} }
@ -1999,7 +2023,7 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName
globalScope = frontend.globalsForAutocomplete.globalScope.get(); globalScope = frontend.globalsForAutocomplete.globalScope.get();
TypeArena typeArena; TypeArena typeArena;
return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, frontend.fileResolver, callback);
} }
} // namespace Luau } // namespace Luau

View file

@ -27,6 +27,8 @@
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions);
namespace Luau namespace Luau
{ {
@ -413,8 +415,18 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack);
} }
attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); if (FFlag::AutocompleteRequirePathSuggestions)
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); {
TypeId requireTy = getGlobalBinding(globals, "require");
attachTag(requireTy, kRequireTagName);
attachMagicFunction(requireTy, magicFunctionRequire);
attachDcrMagicFunction(requireTy, dcrMagicFunctionRequire);
}
else
{
attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire);
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire);
}
} }
static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size) static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)

View file

@ -359,6 +359,11 @@ private:
// noop. // noop.
} }
void cloneChildren(NoRefineType* t)
{
// noop.
}
void cloneChildren(UnionType* t) void cloneChildren(UnionType* t)
{ {
for (TypeId& ty : t->options) for (TypeId& ty : t->options)

View file

@ -26,10 +26,10 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauCheckRecursionLimit)
LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauLogSolverToJson)
LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease); LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
namespace Luau namespace Luau
{ {
@ -2883,9 +2883,45 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
{ {
Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice}; Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice};
std::vector<TypeId> toBlock; std::vector<TypeId> toBlock;
matchLiteralType( if (DFInt::LuauTypeSolverRelease >= 648)
NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr, toBlock {
); // This logic is incomplete as we want to re-run this
// _after_ blocked types have resolved, but this
// allows us to do some bidirectional inference.
toBlock = findBlockedTypesIn(expr, NotNull{&module->astTypes});
if (toBlock.empty())
{
matchLiteralType(
NotNull{&module->astTypes},
NotNull{&module->astExpectedTypes},
builtinTypes,
arena,
NotNull{&unifier},
*expectedType,
ty,
expr,
toBlock
);
// The visitor we ran prior should ensure that there are no
// blocked types that we would encounter while matching on
// this expression.
LUAU_ASSERT(toBlock.empty());
}
}
else
{
matchLiteralType(
NotNull{&module->astTypes},
NotNull{&module->astExpectedTypes},
builtinTypes,
arena,
NotNull{&unifier},
*expectedType,
ty,
expr,
toBlock
);
}
} }
return Inference{ty}; return Inference{ty};

View file

@ -32,6 +32,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies, false)
LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings, false)
LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack, false)
namespace Luau namespace Luau
{ {
@ -1238,14 +1239,22 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
continue; continue;
} }
// We use `any` here because the discriminant type may be pointed at by both branches, if (FFlag::LuauRemoveNotAnyHack)
// where the discriminant type is not negated, and the other where it is negated, i.e. {
// `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored.
// v.s. emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->noRefineType);
// `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` }
// else
// In practice, users cannot negate `any`, so this is an implementation detail we can always change. {
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->anyType); // We use `any` here because the discriminant type may be pointed at by both branches,
// where the discriminant type is not negated, and the other where it is negated, i.e.
// `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never`
// v.s.
// `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T`
//
// In practice, users cannot negate `any`, so this is an implementation detail we can always change.
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->anyType);
}
} }
OverloadResolver resolver{ OverloadResolver resolver{
@ -1322,6 +1331,22 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
if (isBlocked(argsPack)) if (isBlocked(argsPack))
return true; return true;
if (DFInt::LuauTypeSolverRelease >= 648)
{
// This is expensive as we need to traverse a (potentially large)
// literal up front in order to determine if there are any blocked
// types, otherwise we may run `matchTypeLiteral` multiple times,
// which right now may fail due to being non-idempotent (it
// destructively updates the underlying literal type).
auto blockedTypes = findBlockedArgTypesIn(c.callSite, c.astTypes);
for (const auto ty : blockedTypes)
{
block(ty, constraint);
}
if (!blockedTypes.empty())
return false;
}
// We know the type of the function and the arguments it expects to receive. // We know the type of the function and the arguments it expects to receive.
// We also know the TypeIds of the actual arguments that will be passed. // We also know the TypeIds of the actual arguments that will be passed.
// //
@ -1384,7 +1409,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
{ {
const TypeId expectedArgTy = follow(expectedArgs[i + typeOffset]); const TypeId expectedArgTy = follow(expectedArgs[i + typeOffset]);
const TypeId actualArgTy = follow(argPackHead[i + typeOffset]); const TypeId actualArgTy = follow(argPackHead[i + typeOffset]);
const AstExpr* expr = unwrapGroup(c.callSite->args.data[i]); AstExpr* expr = unwrapGroup(c.callSite->args.data[i]);
(*c.astExpectedTypes)[expr] = expectedArgTy; (*c.astExpectedTypes)[expr] = expectedArgTy;
@ -1416,10 +1441,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
std::vector<TypeId> toBlock; std::vector<TypeId> toBlock;
(void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr, toBlock); (void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr, toBlock);
for (auto t : toBlock) if (DFInt::LuauTypeSolverRelease >= 648)
block(t, constraint); {
if (!toBlock.empty()) LUAU_ASSERT(toBlock.empty());
return false; }
else
{
for (auto t : toBlock)
block(t, constraint);
if (!toBlock.empty())
return false;
}
} }
} }
@ -1748,8 +1780,9 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
if (auto lhsFree = getMutable<FreeType>(lhsType)) if (auto lhsFree = getMutable<FreeType>(lhsType))
{ {
if (get<TableType>(lhsFree->upperBound) || get<MetatableType>(lhsFree->upperBound)) auto lhsFreeUpperBound = DFInt::LuauTypeSolverRelease >= 648 ? follow(lhsFree->upperBound) : lhsFree->upperBound;
lhsType = lhsFree->upperBound; if (get<TableType>(lhsFreeUpperBound) || get<MetatableType>(lhsFreeUpperBound))
lhsType = lhsFreeUpperBound;
else else
{ {
TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope}); TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope});
@ -1759,7 +1792,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
upperTable->props[c.propName] = rhsType; upperTable->props[c.propName] = rhsType;
// Food for thought: Could we block if simplification encounters a blocked type? // Food for thought: Could we block if simplification encounters a blocked type?
lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFree->upperBound, newUpperBound).result; lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFreeUpperBound, newUpperBound).result;
bind(constraint, c.propType, rhsType); bind(constraint, c.propType, rhsType);
return true; return true;

File diff suppressed because it is too large Load diff

View file

@ -13,6 +13,7 @@
namespace Luau namespace Luau
{ {
std::string DiffPathNode::toString() const std::string DiffPathNode::toString() const
{ {
switch (kind) switch (kind)
@ -944,12 +945,14 @@ std::vector<std::pair<TypeId, TypeId>>::const_reverse_iterator DifferEnvironment
return visitingStack.crend(); return visitingStack.crend();
} }
DifferResult diff(TypeId ty1, TypeId ty2) DifferResult diff(TypeId ty1, TypeId ty2)
{ {
DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt}; DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt};
return diffUsingEnv(differEnv, ty1, ty2); return diffUsingEnv(differEnv, ty1, ty2);
} }
DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional<std::string> symbol1, std::optional<std::string> symbol2) DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional<std::string> symbol1, std::optional<std::string> symbol2)
{ {
DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2}; DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2};

View file

@ -1,10 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAG(LuauMathMap)
namespace Luau namespace Luau
{ {
static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC( // TODO: there has to be a better way, like splitting up per library
static const std::string kBuiltinDefinitionLuaSrcChecked_DEPRECATED = R"BUILTIN_SRC(
declare bit32: { declare bit32: {
band: @checked (...number) -> number, band: @checked (...number) -> number,
@ -195,6 +198,228 @@ declare utf8: {
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
--- Buffer API
declare buffer: {
create: @checked (size: number) -> buffer,
fromstring: @checked (str: string) -> buffer,
tostring: @checked (b: buffer) -> string,
len: @checked (b: buffer) -> number,
copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (),
fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (),
readi8: @checked (b: buffer, offset: number) -> number,
readu8: @checked (b: buffer, offset: number) -> number,
readi16: @checked (b: buffer, offset: number) -> number,
readu16: @checked (b: buffer, offset: number) -> number,
readi32: @checked (b: buffer, offset: number) -> number,
readu32: @checked (b: buffer, offset: number) -> number,
readf32: @checked (b: buffer, offset: number) -> number,
readf64: @checked (b: buffer, offset: number) -> number,
writei8: @checked (b: buffer, offset: number, value: number) -> (),
writeu8: @checked (b: buffer, offset: number, value: number) -> (),
writei16: @checked (b: buffer, offset: number, value: number) -> (),
writeu16: @checked (b: buffer, offset: number, value: number) -> (),
writei32: @checked (b: buffer, offset: number, value: number) -> (),
writeu32: @checked (b: buffer, offset: number, value: number) -> (),
writef32: @checked (b: buffer, offset: number, value: number) -> (),
writef64: @checked (b: buffer, offset: number, value: number) -> (),
readstring: @checked (b: buffer, offset: number, count: number) -> string,
writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC(
declare bit32: {
band: @checked (...number) -> number,
bor: @checked (...number) -> number,
bxor: @checked (...number) -> number,
btest: @checked (number, ...number) -> boolean,
rrotate: @checked (x: number, disp: number) -> number,
lrotate: @checked (x: number, disp: number) -> number,
lshift: @checked (x: number, disp: number) -> number,
arshift: @checked (x: number, disp: number) -> number,
rshift: @checked (x: number, disp: number) -> number,
bnot: @checked (x: number) -> number,
extract: @checked (n: number, field: number, width: number?) -> number,
replace: @checked (n: number, v: number, field: number, width: number?) -> number,
countlz: @checked (n: number) -> number,
countrz: @checked (n: number) -> number,
byteswap: @checked (n: number) -> number,
}
declare math: {
frexp: @checked (n: number) -> (number, number),
ldexp: @checked (s: number, e: number) -> number,
fmod: @checked (x: number, y: number) -> number,
modf: @checked (n: number) -> (number, number),
pow: @checked (x: number, y: number) -> number,
exp: @checked (n: number) -> number,
ceil: @checked (n: number) -> number,
floor: @checked (n: number) -> number,
abs: @checked (n: number) -> number,
sqrt: @checked (n: number) -> number,
log: @checked (n: number, base: number?) -> number,
log10: @checked (n: number) -> number,
rad: @checked (n: number) -> number,
deg: @checked (n: number) -> number,
sin: @checked (n: number) -> number,
cos: @checked (n: number) -> number,
tan: @checked (n: number) -> number,
sinh: @checked (n: number) -> number,
cosh: @checked (n: number) -> number,
tanh: @checked (n: number) -> number,
atan: @checked (n: number) -> number,
acos: @checked (n: number) -> number,
asin: @checked (n: number) -> number,
atan2: @checked (y: number, x: number) -> number,
min: @checked (number, ...number) -> number,
max: @checked (number, ...number) -> number,
pi: number,
huge: number,
randomseed: @checked (seed: number) -> (),
random: @checked (number?, number?) -> number,
sign: @checked (n: number) -> number,
clamp: @checked (n: number, min: number, max: number) -> number,
noise: @checked (x: number, y: number?, z: number?) -> number,
round: @checked (n: number) -> number,
map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number,
}
type DateTypeArg = {
year: number,
month: number,
day: number,
hour: number?,
min: number?,
sec: number?,
isdst: boolean?,
}
type DateTypeResult = {
year: number,
month: number,
wday: number,
yday: number,
day: number,
hour: number,
min: number,
sec: number,
isdst: boolean,
}
declare os: {
time: (time: DateTypeArg?) -> number,
date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string),
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
@checked declare function require(target: any): any
@checked declare function getfenv(target: any): { [string]: any }
declare _G: any
declare _VERSION: string
declare function gcinfo(): number
declare function print<T...>(...: T...)
declare function type<T>(value: T): string
declare function typeof<T>(value: T): string
-- `assert` has a magic function attached that will give more detailed type information
declare function assert<T>(value: T, errorMessage: string?): T
declare function error<T>(message: T, level: number?): never
declare function tostring<T>(value: T): string
declare function tonumber<T>(value: T, radix: number?): number?
declare function rawequal<T1, T2>(a: T1, b: T2): boolean
declare function rawget<K, V>(tab: {[K]: V}, k: K): V
declare function rawset<K, V>(tab: {[K]: V}, k: K, v: V): {[K]: V}
declare function rawlen<K, V>(obj: {[K]: V} | string): number
declare function setfenv<T..., R...>(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)?
declare function ipairs<V>(tab: {V}): (({V}, number) -> (number?, V), {V}, number)
declare function pcall<A..., R...>(f: (A...) -> R..., ...: A...): (boolean, R...)
-- FIXME: The actual type of `xpcall` is:
-- <E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...)
-- Since we can't represent the return value, we use (boolean, R1...).
declare function xpcall<E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...)
-- `select` has a magic function attached to provide more detailed type information
declare function select<A...>(i: string | number, ...: A...): ...any
-- FIXME: This type is not entirely correct - `loadstring` returns a function or
-- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)
@checked declare function newproxy(mt: boolean?): any
declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread,
status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended",
wrap: <A..., R...>(f: (A...) -> R...) -> ((A...) -> R...),
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: @checked (co: thread) -> (boolean, any)
}
declare table: {
concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
maxn: <V>(t: {V}) -> number,
remove: <V>(t: {V}, number?) -> V?,
sort: <V>(t: {V}, comp: ((V, V) -> boolean)?) -> (),
create: <V>(count: number, value: V?) -> {V},
find: <V>(haystack: {V}, needle: V, init: number?) -> number?,
unpack: <V>(list: {V}, i: number?, j: number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>(t: {V}) -> number,
foreach: <K, V>(t: {[K]: V}, f: (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>(src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V},
clear: <K, V>(table: {[K]: V}) -> (),
isfrozen: <K, V>(t: {[K]: V}) -> boolean,
}
declare debug: {
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
declare utf8: {
char: @checked (...number) -> string,
charpattern: string,
codes: @checked (str: string) -> ((string, number) -> (number, number), string, number),
codepoint: @checked (str: string, i: number?, j: number?) -> ...number,
len: @checked (s: string, i: number?, j: number?) -> (number?, number?),
offset: @checked (s: string, n: number?, i: number?) -> number,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
--- Buffer API --- Buffer API
declare buffer: { declare buffer: {
create: @checked (size: number) -> buffer, create: @checked (size: number) -> buffer,
@ -227,7 +452,7 @@ declare buffer: {
std::string getBuiltinDefinitionSource() std::string getBuiltinDefinitionSource()
{ {
std::string result = kBuiltinDefinitionLuaSrcChecked; std::string result = FFlag::LuauMathMap ? kBuiltinDefinitionLuaSrcChecked : kBuiltinDefinitionLuaSrcChecked_DEPRECATED;
return result; return result;
} }

View file

@ -3,6 +3,11 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Common.h"
#include "Luau/Frontend.h"
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
namespace Luau namespace Luau
{ {
@ -10,6 +15,8 @@ namespace Luau
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos) FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos)
{ {
std::vector<AstNode*> ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos); std::vector<AstNode*> ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos);
// Should always contain the root AstStat
LUAU_ASSERT(ancestry.size() >= 1);
DenseHashMap<AstName, AstLocal*> localMap{AstName()}; DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack; std::vector<AstLocal*> localStack;
AstStat* nearestStatement = nullptr; AstStat* nearestStatement = nullptr;
@ -21,7 +28,7 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* ro
{ {
if (stat->location.begin <= cursorPos) if (stat->location.begin <= cursorPos)
nearestStatement = stat; nearestStatement = stat;
if (stat->location.begin <= cursorPos) if (stat->location.begin < cursorPos && stat->location.begin.line < cursorPos.line)
{ {
// This statement precedes the current one // This statement precedes the current one
if (auto loc = stat->as<AstStatLocal>()) if (auto loc = stat->as<AstStatLocal>())
@ -42,7 +49,116 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* ro
} }
} }
if (!nearestStatement)
nearestStatement = ancestry[0]->asStat();
LUAU_ASSERT(nearestStatement);
return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)}; return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)};
} }
std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos)
{
unsigned int lineCount = 0;
unsigned int colCount = 0;
unsigned int docOffset = 0;
unsigned int startOffset = 0;
unsigned int endOffset = 0;
bool foundStart = false;
bool foundEnd = false;
for (char c : src)
{
if (foundStart && foundEnd)
break;
if (startPos.line == lineCount && startPos.column == colCount)
{
foundStart = true;
startOffset = docOffset;
}
if (endPos.line == lineCount && endPos.column == colCount)
{
endOffset = docOffset;
foundEnd = true;
}
if (c == '\n')
{
lineCount++;
colCount = 0;
}
else
colCount++;
docOffset++;
}
unsigned int min = std::min(startOffset, endOffset);
unsigned int len = std::max(startOffset, endOffset) - min;
return {min, len};
}
ScopePtr findClosestScope(const ModulePtr& module, const Position& cursorPos)
{
LUAU_ASSERT(module->hasModuleScope());
ScopePtr closest = module->getModuleScope();
for (auto [loc, sc] : module->scopes)
{
if (loc.begin <= cursorPos && closest->location.begin <= loc.begin)
closest = sc;
}
return closest;
}
FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos)
{
FragmentAutocompleteAncestryResult result = findAncestryForFragmentParse(srcModule.root, cursorPos);
ParseOptions opts;
opts.allowDeclarationSyntax = false;
opts.captureComments = false;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)};
AstStat* enclosingStatement = result.nearestStatement;
const Position& endPos = cursorPos;
// If the statement starts on a previous line, grab the statement beginning
// otherwise, grab the statement end to whatever is being typed right now
const Position& startPos =
enclosingStatement->location.begin.line == cursorPos.line ? enclosingStatement->location.begin : enclosingStatement->location.end;
auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos);
const char* srcStart = src.data() + offsetStart;
std::string_view dbg = src.substr(offsetStart, parseLength);
const std::shared_ptr<AstNameTable>& nameTbl = srcModule.names;
FragmentParseResult fragmentResult;
fragmentResult.fragmentToParse = std::string(dbg.data(), parseLength);
// For the duration of the incremental parse, we want to allow the name table to re-use duplicate names
ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts);
std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end);
fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end());
if (enclosingStatement == nullptr)
enclosingStatement = p.root;
fragmentResult.root = std::move(p.root);
fragmentResult.ancestry = std::move(fabricatedAncestry);
return fragmentResult;
}
AutocompleteResult fragmentAutocomplete(
Frontend& frontend,
std::string_view src,
const ModuleName& moduleName,
Position& cursorPosition,
StringCompletionCallback callback
)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
// TODO
return {};
}
} // namespace Luau } // namespace Luau

View file

@ -49,6 +49,7 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection, false) LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection, false)
LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule, false)
namespace Luau namespace Luau
{ {
@ -1315,6 +1316,18 @@ ModulePtr check(
} }
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler);
DataFlowGraph* dfgForConstraintGeneration = nullptr;
if (FFlag::LuauStoreDFGOnModule)
{
auto [dfg, scopes] = DataFlowGraphBuilder::buildShared(sourceModule.root, iceHandler);
result->dataFlowGraph = std::move(dfg);
result->dfgScopes = std::move(scopes);
dfgForConstraintGeneration = result->dataFlowGraph.get();
}
else
{
dfgForConstraintGeneration = &dfg;
}
UnifierSharedState unifierState{iceHandler}; UnifierSharedState unifierState{iceHandler};
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
@ -1336,7 +1349,7 @@ ModulePtr check(
parentScope, parentScope,
std::move(prepareModuleScope), std::move(prepareModuleScope),
logger.get(), logger.get(),
NotNull{&dfg}, NotNull{dfgForConstraintGeneration},
requireCycles requireCycles
}; };

View file

@ -801,6 +801,12 @@ struct TypeCacher : TypeOnceVisitor
return false; return false;
} }
bool visit(TypeId ty, const NoRefineType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const UnionType& ut) override bool visit(TypeId ty, const UnionType& ut) override
{ {
if (isUncacheable(ty) || isCached(ty)) if (isUncacheable(ty) || isCached(ty))

View file

@ -594,7 +594,7 @@ struct NonStrictTypeChecker
std::shared_ptr<const NormalizedType> norm = normalizer.normalize(expectedArgType); std::shared_ptr<const NormalizedType> norm = normalizer.normalize(expectedArgType);
DefId def = dfg->getDef(arg); DefId def = dfg->getDef(arg);
TypeId runTimeErrorTy; TypeId runTimeErrorTy;
// If we're dealing with any, negating any will cause all subtype tests to fail, since ~any is any // If we're dealing with any, negating any will cause all subtype tests to fail
// However, when someone calls this function, they're going to want to be able to pass it anything, // However, when someone calls this function, they're going to want to be able to pass it anything,
// for that reason, we manually inject never into the context so that the runtime test will always pass. // for that reason, we manually inject never into the context so that the runtime test will always pass.
if (!norm) if (!norm)

View file

@ -1872,7 +1872,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
else if (get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there)) else if (get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there) || get<NoRefineType>(there))
{ {
// nothing // nothing
} }
@ -3217,6 +3217,11 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
// assumption that it is the same as any. // assumption that it is the same as any.
return NormalizationResult::True; return NormalizationResult::True;
} }
else if (get<NoRefineType>(t))
{
// `*no-refine*` means we will never do anything to affect the intersection.
return NormalizationResult::True;
}
else if (get<NeverType>(t)) else if (get<NeverType>(t))
{ {
// if we're intersecting with `~never`, this is equivalent to intersecting with `unknown` // if we're intersecting with `~never`, this is equivalent to intersecting with `unknown`
@ -3243,6 +3248,11 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{ {
here.classes.resetToNever(); here.classes.resetToNever();
} }
else if (get<NoRefineType>(there))
{
// `*no-refine*` means we will never do anything to affect the intersection.
return NormalizationResult::True;
}
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");

View file

@ -50,6 +50,11 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
LUAU_ASSERT(ty->persistent); LUAU_ASSERT(ty->persistent);
return ty; return ty;
} }
else if constexpr (std::is_same_v<T, NoRefineType>)
{
LUAU_ASSERT(ty->persistent);
return ty;
}
else if constexpr (std::is_same_v<T, ErrorType>) else if constexpr (std::is_same_v<T, ErrorType>)
{ {
LUAU_ASSERT(ty->persistent); LUAU_ASSERT(ty->persistent);

View file

@ -261,92 +261,50 @@ SubtypingResult SubtypingResult::any(const std::vector<SubtypingResult>& results
struct ApplyMappedGenerics : Substitution struct ApplyMappedGenerics : Substitution
{ {
using MappedGenerics = DenseHashMap<TypeId, SubtypingEnvironment::GenericBounds>;
using MappedGenericPacks = DenseHashMap<TypePackId, TypePackId>;
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
SubtypingEnvironment& env; SubtypingEnvironment& env;
MappedGenerics& mappedGenerics_DEPRECATED; ApplyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, SubtypingEnvironment& env)
MappedGenericPacks& mappedGenericPacks_DEPRECATED;
ApplyMappedGenerics(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
SubtypingEnvironment& env,
MappedGenerics& mappedGenerics,
MappedGenericPacks& mappedGenericPacks
)
: Substitution(TxnLog::empty(), arena) : Substitution(TxnLog::empty(), arena)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, arena(arena) , arena(arena)
, env(env) , env(env)
, mappedGenerics_DEPRECATED(mappedGenerics)
, mappedGenericPacks_DEPRECATED(mappedGenericPacks)
{ {
} }
bool isDirty(TypeId ty) override bool isDirty(TypeId ty) override
{ {
if (DFInt::LuauTypeSolverRelease >= 644) return env.containsMappedType(ty);
return env.containsMappedType(ty);
else
return mappedGenerics_DEPRECATED.contains(ty);
} }
bool isDirty(TypePackId tp) override bool isDirty(TypePackId tp) override
{ {
if (DFInt::LuauTypeSolverRelease >= 644) return env.containsMappedPack(tp);
return env.containsMappedPack(tp);
else
return mappedGenericPacks_DEPRECATED.contains(tp);
} }
TypeId clean(TypeId ty) override TypeId clean(TypeId ty) override
{ {
if (DFInt::LuauTypeSolverRelease >= 644) const auto& bounds = env.getMappedTypeBounds(ty);
{
const auto& bounds = env.getMappedTypeBounds(ty);
if (bounds.upperBound.empty()) if (bounds.upperBound.empty())
return builtinTypes->unknownType; return builtinTypes->unknownType;
if (bounds.upperBound.size() == 1) if (bounds.upperBound.size() == 1)
return *begin(bounds.upperBound); return *begin(bounds.upperBound);
return arena->addType(IntersectionType{std::vector<TypeId>(begin(bounds.upperBound), end(bounds.upperBound))}); return arena->addType(IntersectionType{std::vector<TypeId>(begin(bounds.upperBound), end(bounds.upperBound))});
}
else
{
const auto& bounds = mappedGenerics_DEPRECATED[ty];
if (bounds.upperBound.empty())
return builtinTypes->unknownType;
if (bounds.upperBound.size() == 1)
return *begin(bounds.upperBound);
return arena->addType(IntersectionType{std::vector<TypeId>(begin(bounds.upperBound), end(bounds.upperBound))});
}
} }
TypePackId clean(TypePackId tp) override TypePackId clean(TypePackId tp) override
{ {
if (DFInt::LuauTypeSolverRelease >= 644) if (auto it = env.getMappedPackBounds(tp))
{ return *it;
if (auto it = env.getMappedPackBounds(tp))
return *it;
// Clean is only called when isDirty found a pack bound // Clean is only called when isDirty found a pack bound
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
return nullptr; return nullptr;
}
else
{
return mappedGenericPacks_DEPRECATED[tp];
}
} }
bool ignoreChildren(TypeId ty) override bool ignoreChildren(TypeId ty) override
@ -364,7 +322,7 @@ struct ApplyMappedGenerics : Substitution
std::optional<TypeId> SubtypingEnvironment::applyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty) std::optional<TypeId> SubtypingEnvironment::applyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty)
{ {
ApplyMappedGenerics amg{builtinTypes, arena, *this, mappedGenerics, mappedGenericPacks}; ApplyMappedGenerics amg{builtinTypes, arena, *this};
return amg.substitute(ty); return amg.substitute(ty);
} }
@ -489,22 +447,12 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope
} }
if (DFInt::LuauTypeSolverRelease >= 644) SubtypingEnvironment boundsEnv;
{ boundsEnv.parent = &env;
SubtypingEnvironment boundsEnv; SubtypingResult boundsResult = isCovariantWith(boundsEnv, lowerBound, upperBound, scope);
boundsEnv.parent = &env; boundsResult.reasoning.clear();
SubtypingResult boundsResult = isCovariantWith(boundsEnv, lowerBound, upperBound, scope);
boundsResult.reasoning.clear();
result.andAlso(boundsResult); result.andAlso(boundsResult);
}
else
{
SubtypingResult boundsResult = isCovariantWith(env, lowerBound, upperBound, scope);
boundsResult.reasoning.clear();
result.andAlso(boundsResult);
}
} }
/* TODO: We presently don't store subtype test results in the persistent /* TODO: We presently don't store subtype test results in the persistent
@ -582,18 +530,17 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
subTy = follow(subTy); subTy = follow(subTy);
superTy = follow(superTy); superTy = follow(superTy);
if (const TypeId* subIt = (DFInt::LuauTypeSolverRelease >= 644 ? env.tryFindSubstitution(subTy) : env.substitutions.find(subTy)); subIt && *subIt) if (const TypeId* subIt = env.tryFindSubstitution(subTy); subIt && *subIt)
subTy = *subIt; subTy = *subIt;
if (const TypeId* superIt = (DFInt::LuauTypeSolverRelease >= 644 ? env.tryFindSubstitution(superTy) : env.substitutions.find(superTy)); if (const TypeId* superIt = env.tryFindSubstitution(superTy); superIt && *superIt)
superIt && *superIt)
superTy = *superIt; superTy = *superIt;
const SubtypingResult* cachedResult = resultCache.find({subTy, superTy}); const SubtypingResult* cachedResult = resultCache.find({subTy, superTy});
if (cachedResult) if (cachedResult)
return *cachedResult; return *cachedResult;
cachedResult = DFInt::LuauTypeSolverRelease >= 644 ? env.tryFindSubtypingResult({subTy, superTy}) : env.ephemeralCache.find({subTy, superTy}); cachedResult = env.tryFindSubtypingResult({subTy, superTy});
if (cachedResult) if (cachedResult)
return *cachedResult; return *cachedResult;
@ -838,8 +785,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
std::vector<TypeId> headSlice(begin(superHead), begin(superHead) + headSize); std::vector<TypeId> headSlice(begin(superHead), begin(superHead) + headSize);
TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail);
if (TypePackId* other = if (TypePackId* other = env.getMappedPackBounds(*subTail))
(DFInt::LuauTypeSolverRelease >= 644 ? env.getMappedPackBounds(*subTail) : env.mappedGenericPacks.find(*subTail)))
// TODO: TypePath can't express "slice of a pack + its tail". // TODO: TypePath can't express "slice of a pack + its tail".
results.push_back(isCovariantWith(env, *other, superTailPack, scope).withSubComponent(TypePath::PackField::Tail)); results.push_back(isCovariantWith(env, *other, superTailPack, scope).withSubComponent(TypePath::PackField::Tail));
else else
@ -894,8 +840,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
std::vector<TypeId> headSlice(begin(subHead), begin(subHead) + headSize); std::vector<TypeId> headSlice(begin(subHead), begin(subHead) + headSize);
TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail);
if (TypePackId* other = if (TypePackId* other = env.getMappedPackBounds(*superTail))
(DFInt::LuauTypeSolverRelease >= 644 ? env.getMappedPackBounds(*superTail) : env.mappedGenericPacks.find(*superTail)))
// TODO: TypePath can't express "slice of a pack + its tail". // TODO: TypePath can't express "slice of a pack + its tail".
results.push_back(isContravariantWith(env, subTailPack, *other, scope).withSuperComponent(TypePath::PackField::Tail)); results.push_back(isContravariantWith(env, subTailPack, *other, scope).withSuperComponent(TypePath::PackField::Tail));
else else
@ -1837,11 +1782,8 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe
if (!get<GenericType>(subTy)) if (!get<GenericType>(subTy))
return false; return false;
if (DFInt::LuauTypeSolverRelease >= 644) if (!env.mappedGenerics.find(subTy) && env.containsMappedType(subTy))
{ iceReporter->ice("attempting to modify bounds of a potentially visited generic");
if (!env.mappedGenerics.find(subTy) && env.containsMappedType(subTy))
iceReporter->ice("attempting to modify bounds of a potentially visited generic");
}
env.mappedGenerics[subTy].upperBound.insert(superTy); env.mappedGenerics[subTy].upperBound.insert(superTy);
} }
@ -1850,11 +1792,8 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe
if (!get<GenericType>(superTy)) if (!get<GenericType>(superTy))
return false; return false;
if (DFInt::LuauTypeSolverRelease >= 644) if (!env.mappedGenerics.find(superTy) && env.containsMappedType(superTy))
{ iceReporter->ice("attempting to modify bounds of a potentially visited generic");
if (!env.mappedGenerics.find(superTy) && env.containsMappedType(superTy))
iceReporter->ice("attempting to modify bounds of a potentially visited generic");
}
env.mappedGenerics[superTy].lowerBound.insert(subTy); env.mappedGenerics[superTy].lowerBound.insert(subTy);
} }
@ -1901,7 +1840,7 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePac
if (!get<GenericTypePack>(subTp)) if (!get<GenericTypePack>(subTp))
return false; return false;
if (TypePackId* m = (DFInt::LuauTypeSolverRelease >= 644 ? env.getMappedPackBounds(subTp) : env.mappedGenericPacks.find(subTp))) if (TypePackId* m = env.getMappedPackBounds(subTp))
return *m == superTp; return *m == superTp;
env.mappedGenericPacks[subTp] = superTp; env.mappedGenericPacks[subTp] = superTp;

View file

@ -6,19 +6,14 @@
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h" #include "Luau/Unifier2.h"
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
namespace Luau namespace Luau
{ {
static bool isLiteral(const AstExpr* expr)
{
return (
expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() ||
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>()
);
}
// A fast approximation of subTy <: superTy // A fast approximation of subTy <: superTy
static bool fastIsSubtype(TypeId subTy, TypeId superTy) static bool fastIsSubtype(TypeId subTy, TypeId superTy)
{ {
@ -381,15 +376,21 @@ TypeId matchLiteralType(
const TypeId* keyTy = astTypes->find(item.key); const TypeId* keyTy = astTypes->find(item.key);
LUAU_ASSERT(keyTy); LUAU_ASSERT(keyTy);
TypeId tKey = follow(*keyTy); TypeId tKey = follow(*keyTy);
if (get<BlockedType>(tKey)) if (DFInt::LuauTypeSolverRelease >= 648)
{
LUAU_ASSERT(!is<BlockedType>(tKey));
}
else if (get<BlockedType>(tKey))
toBlock.push_back(tKey); toBlock.push_back(tKey);
const TypeId* propTy = astTypes->find(item.value); const TypeId* propTy = astTypes->find(item.value);
LUAU_ASSERT(propTy); LUAU_ASSERT(propTy);
TypeId tProp = follow(*propTy); TypeId tProp = follow(*propTy);
if (get<BlockedType>(tProp)) if (DFInt::LuauTypeSolverRelease >= 648)
{
LUAU_ASSERT(!is<BlockedType>(tKey));
}
else if (get<BlockedType>(tProp))
toBlock.push_back(tProp); toBlock.push_back(tProp);
// Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings) // Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings)
if (!item.key->as<AstExprConstantString>() && expectedTableTy->indexer) if (!item.key->as<AstExprConstantString>() && expectedTableTy->indexer)
(*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType;

View file

@ -269,6 +269,12 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if constexpr (std::is_same_v<T, NoRefineType>)
{
formatAppend(result, "NoRefineType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, UnknownType>) else if constexpr (std::is_same_v<T, UnknownType>)
{ {
formatAppend(result, "UnknownType %d", index); formatAppend(result, "UnknownType %d", index);

View file

@ -856,6 +856,11 @@ struct TypeStringifier
state.emit("any"); state.emit("any");
} }
void operator()(TypeId, const NoRefineType&)
{
state.emit("*no-refine*");
}
void operator()(TypeId, const UnionType& uv) void operator()(TypeId, const UnionType& uv)
{ {
if (state.hasSeen(&uv)) if (state.hasSeen(&uv))

View file

@ -1030,6 +1030,7 @@ BuiltinTypes::BuiltinTypes()
, unknownType(arena->addType(Type{UnknownType{}, /*persistent*/ true})) , unknownType(arena->addType(Type{UnknownType{}, /*persistent*/ true}))
, neverType(arena->addType(Type{NeverType{}, /*persistent*/ true})) , neverType(arena->addType(Type{NeverType{}, /*persistent*/ true}))
, errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true})) , errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true}))
, noRefineType(arena->addType(Type{NoRefineType{}, /*persistent*/ true}))
, falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true})) , falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true}))
, truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true}))
, optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true})) , optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true}))

View file

@ -145,6 +145,12 @@ public:
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"), std::nullopt, Location()); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"), std::nullopt, Location());
} }
AstType* operator()(const NoRefineType&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*no-refine*"), std::nullopt, Location());
}
AstType* operator()(const TableType& ttv) AstType* operator()(const TableType& ttv)
{ {
RecursionCounter counter(&count); RecursionCounter counter(&count);

View file

@ -3022,20 +3022,9 @@ PropertyType TypeChecker2::hasIndexTypeFromType(
if (tt->indexer) if (tt->indexer)
{ {
TypeId indexType = follow(tt->indexer->indexType); TypeId indexType = follow(tt->indexer->indexType);
if (DFInt::LuauTypeSolverRelease >= 644) TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}});
{ if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, *ice))
TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}}); return {NormalizationResult::True, {tt->indexer->indexResultType}};
if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, *ice))
return {NormalizationResult::True, {tt->indexer->indexResultType}};
}
else
{
if (isPrim(indexType, PrimitiveType::String))
return {NormalizationResult::True, {tt->indexer->indexResultType}};
// If the indexer looks like { [any] : _} - the prop lookup should be allowed!
else if (get<AnyType>(indexType) || get<UnknownType>(indexType))
return {NormalizationResult::True, {tt->indexer->indexResultType}};
}
} }

View file

@ -49,6 +49,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions2, false) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions2, false)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation)
LUAU_FASTFLAG(LuauUserTypeFunFixRegister) LUAU_FASTFLAG(LuauUserTypeFunFixRegister)
LUAU_FASTFLAG(LuauRemoveNotAnyHack)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
@ -777,16 +778,8 @@ TypeFunctionReductionResult<TypeId> lenTypeFunction(
if (normTy->hasTopTable() || get<TableType>(normalizedOperand)) if (normTy->hasTopTable() || get<TableType>(normalizedOperand))
return {ctx->builtins->numberType, false, {}, {}}; return {ctx->builtins->numberType, false, {}, {}};
if (DFInt::LuauTypeSolverRelease >= 644) if (auto result = tryDistributeTypeFunctionApp(lenTypeFunction, instance, typeParams, packParams, ctx))
{ return *result;
if (auto result = tryDistributeTypeFunctionApp(lenTypeFunction, instance, typeParams, packParams, ctx))
return *result;
}
else
{
if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx))
return *result;
}
// findMetatableEntry demands the ability to emit errors, so we must give it // findMetatableEntry demands the ability to emit errors, so we must give it
// the necessary state to do that, even if we intend to just eat the errors. // the necessary state to do that, even if we intend to just eat the errors.
@ -874,16 +867,8 @@ TypeFunctionReductionResult<TypeId> unmTypeFunction(
if (normTy->isExactlyNumber()) if (normTy->isExactlyNumber())
return {ctx->builtins->numberType, false, {}, {}}; return {ctx->builtins->numberType, false, {}, {}};
if (DFInt::LuauTypeSolverRelease >= 644) if (auto result = tryDistributeTypeFunctionApp(unmTypeFunction, instance, typeParams, packParams, ctx))
{ return *result;
if (auto result = tryDistributeTypeFunctionApp(unmTypeFunction, instance, typeParams, packParams, ctx))
return *result;
}
else
{
if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx))
return *result;
}
// findMetatableEntry demands the ability to emit errors, so we must give it // findMetatableEntry demands the ability to emit errors, so we must give it
// the necessary state to do that, even if we intend to just eat the errors. // the necessary state to do that, even if we intend to just eat the errors.
@ -1810,7 +1795,6 @@ struct FindRefinementBlockers : TypeOnceVisitor
} }
}; };
TypeFunctionReductionResult<TypeId> refineTypeFunction( TypeFunctionReductionResult<TypeId> refineTypeFunction(
TypeId instance, TypeId instance,
const std::vector<TypeId>& typeParams, const std::vector<TypeId>& typeParams,
@ -1878,8 +1862,18 @@ TypeFunctionReductionResult<TypeId> refineTypeFunction(
* We need to treat T & ~any as T in this case. * We need to treat T & ~any as T in this case.
*/ */
if (auto nt = get<NegationType>(discriminant)) if (auto nt = get<NegationType>(discriminant))
if (get<AnyType>(follow(nt->ty))) {
return {target, {}}; if (FFlag::LuauRemoveNotAnyHack)
{
if (get<NoRefineType>(follow(nt->ty)))
return {target, {}};
}
else
{
if (get<AnyType>(follow(nt->ty)))
return {target, {}};
}
}
// If the target type is a table, then simplification already implements the logic to deal with refinements properly since the // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the
// type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type.
@ -2059,6 +2053,15 @@ TypeFunctionReductionResult<TypeId> intersectTypeFunction(
for (auto ty : typeParams) for (auto ty : typeParams)
types.emplace_back(follow(ty)); types.emplace_back(follow(ty));
if (FFlag::LuauRemoveNotAnyHack)
{
// if we only have two parameters and one is `*no-refine*`, we're all done.
if (types.size() == 2 && get<NoRefineType>(types[1]))
return {types[0], false, {}, {}};
else if (types.size() == 2 && get<NoRefineType>(types[0]))
return {types[1], false, {}, {}};
}
// check to see if the operand types are resolved enough, and wait to reduce if not // check to see if the operand types are resolved enough, and wait to reduce if not
// if any of them are `never`, the intersection will always be `never`, so we can reduce directly. // if any of them are `never`, the intersection will always be `never`, so we can reduce directly.
for (auto ty : types) for (auto ty : types)
@ -2073,6 +2076,10 @@ TypeFunctionReductionResult<TypeId> intersectTypeFunction(
TypeId resultTy = ctx->builtins->unknownType; TypeId resultTy = ctx->builtins->unknownType;
for (auto ty : types) for (auto ty : types)
{ {
// skip any `*no-refine*` types.
if (FFlag::LuauRemoveNotAnyHack && get<NoRefineType>(ty))
continue;
SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, resultTy, ty); SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, resultTy, ty);
if (!result.blockedTypes.empty()) if (!result.blockedTypes.empty())
return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}};

View file

@ -16,6 +16,7 @@
LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixRegister, false) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixRegister, false)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite, false)
namespace Luau namespace Luau
{ {
@ -634,6 +635,8 @@ static int readTableProp(lua_State* L)
auto prop = tftt->props.at(tfsst->value); auto prop = tftt->props.at(tfsst->value);
if (prop.readTy) if (prop.readTy)
allocTypeUserData(L, (*prop.readTy)->type); allocTypeUserData(L, (*prop.readTy)->type);
else if (FFlag::LuauUserTypeFunFixNoReadWrite)
lua_pushnil(L);
else else
luaL_error(L, "type.readproperty: property %s is write-only, and therefore does not have a read type.", tfsst->value.c_str()); luaL_error(L, "type.readproperty: property %s is write-only, and therefore does not have a read type.", tfsst->value.c_str());
@ -672,6 +675,8 @@ static int writeTableProp(lua_State* L)
auto prop = tftt->props.at(tfsst->value); auto prop = tftt->props.at(tfsst->value);
if (prop.writeTy) if (prop.writeTy)
allocTypeUserData(L, (*prop.writeTy)->type); allocTypeUserData(L, (*prop.writeTy)->type);
else if (FFlag::LuauUserTypeFunFixNoReadWrite)
lua_pushnil(L);
else else
luaL_error(L, "type.writeproperty: property %s is read-only, and therefore does not have a write type.", tfsst->value.c_str()); luaL_error(L, "type.writeproperty: property %s is read-only, and therefore does not have a write type.", tfsst->value.c_str());

View file

@ -479,4 +479,68 @@ ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypePackId
return result; return result;
} }
bool isLiteral(const AstExpr* expr)
{
return (
expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() ||
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>()
);
}
/**
* Visitor which, given an expression and a mapping from expression to TypeId,
* determines if there are any literal expressions that contain blocked types.
* This is used for bi-directional inference: we want to "apply" a type from
* a function argument or a type annotation to a literal.
*/
class BlockedTypeInLiteralVisitor : public AstVisitor
{
public:
explicit BlockedTypeInLiteralVisitor(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes, NotNull<std::vector<TypeId>> toBlock)
: astTypes_{astTypes}
, toBlock_{toBlock}
{
}
bool visit(AstNode*) override
{
return false;
}
bool visit(AstExpr* e) override
{
auto ty = astTypes_->find(e);
if (ty && (get<BlockedType>(follow(*ty)) != nullptr))
{
toBlock_->push_back(*ty);
}
return isLiteral(e) || e->is<AstExprGroup>();
}
private:
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes_;
NotNull<std::vector<TypeId>> toBlock_;
};
std::vector<TypeId> findBlockedTypesIn(AstExprTable* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes)
{
std::vector<TypeId> toBlock;
BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}};
expr->visit(&v);
return toBlock;
}
std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes)
{
std::vector<TypeId> toBlock;
BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}};
for (auto arg: expr->args)
{
if (isLiteral(arg) || arg->is<AstExprGroup>())
{
arg->visit(&v);
}
}
return toBlock;
}
} // namespace Luau } // namespace Luau

View file

@ -188,9 +188,18 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc
functionStack.reserve(8); functionStack.reserve(8);
functionStack.push_back(top); functionStack.push_back(top);
nameSelf = names.addStatic("self"); if (FFlag::LuauAllowFragmentParsing)
nameNumber = names.addStatic("number"); {
nameError = names.addStatic(kParseNameError); nameSelf = names.getOrAdd("self");
nameNumber = names.getOrAdd("number");
nameError = names.getOrAdd(kParseNameError);
}
else
{
nameSelf = names.addStatic("self");
nameNumber = names.addStatic("number");
nameError = names.addStatic(kParseNameError);
}
nameNil = names.getOrAdd("nil"); // nil is a reserved keyword nameNil = names.getOrAdd("nil"); // nil is a reserved keyword
matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0);

View file

@ -17,8 +17,6 @@
#include <string.h> #include <string.h>
LUAU_FASTFLAGVARIABLE(LuauErrorResumeCleanupArgs, false)
/* /*
** {====================================================== ** {======================================================
** Error-recovery functions ** Error-recovery functions
@ -430,11 +428,7 @@ static void resume_handle(lua_State* L, void* ud)
static int resume_error(lua_State* L, const char* msg, int narg) static int resume_error(lua_State* L, const char* msg, int narg)
{ {
if (FFlag::LuauErrorResumeCleanupArgs) L->top -= narg;
L->top -= narg;
else
L->top = L->ci->base;
setsvalue(L, L->top, luaS_new(L, msg)); setsvalue(L, L->top, luaS_new(L, msg));
incr_top(L); incr_top(L);
return LUA_ERRRUN; return LUA_ERRRUN;

View file

@ -7,6 +7,8 @@
#include <math.h> #include <math.h>
#include <time.h> #include <time.h>
LUAU_FASTFLAGVARIABLE(LuauMathMap, false)
#undef PI #undef PI
#define PI (3.14159265358979323846) #define PI (3.14159265358979323846)
#define RADIANS_PER_DEGREE (PI / 180.0) #define RADIANS_PER_DEGREE (PI / 180.0)
@ -403,6 +405,19 @@ static int math_round(lua_State* L)
return 1; return 1;
} }
static int math_map(lua_State* L)
{
double x = luaL_checknumber(L, 1);
double inmin = luaL_checknumber(L, 2);
double inmax = luaL_checknumber(L, 3);
double outmin = luaL_checknumber(L, 4);
double outmax = luaL_checknumber(L, 5);
double result = outmin + (x - inmin) * (outmax - outmin) / (inmax - inmin);
lua_pushnumber(L, result);
return 1;
}
static const luaL_Reg mathlib[] = { static const luaL_Reg mathlib[] = {
{"abs", math_abs}, {"abs", math_abs},
{"acos", math_acos}, {"acos", math_acos},
@ -455,5 +470,12 @@ int luaopen_math(lua_State* L)
lua_setfield(L, -2, "pi"); lua_setfield(L, -2, "pi");
lua_pushnumber(L, HUGE_VAL); lua_pushnumber(L, HUGE_VAL);
lua_setfield(L, -2, "huge"); lua_setfield(L, -2, "huge");
if (FFlag::LuauMathMap)
{
lua_pushcfunction(L, math_map, "map");
lua_setfield(L, -2, "map");
}
return 1; return 1;
} }

View file

@ -508,9 +508,6 @@ def runTest(subdir, filename, filepath):
filepath = os.path.abspath(filepath) filepath = os.path.abspath(filepath)
mainVm = os.path.abspath(arguments.vm) mainVm = os.path.abspath(arguments.vm)
if not os.path.isfile(mainVm):
print(f"{colored(Color.RED, 'ERROR')}: VM executable '{mainVm}' does not exist.")
sys.exit(1)
# Process output will contain the test name and execution times # Process output will contain the test name and execution times
mainOutput = getVmOutput(substituteArguments(mainVm, getExtraArguments(filepath)) + " " + filepath) mainOutput = getVmOutput(substituteArguments(mainVm, getExtraArguments(filepath)) + " " + filepath)
@ -890,11 +887,9 @@ def run(args, argsubcb):
analyzeResult('', mainResult, compareResults) analyzeResult('', mainResult, compareResults)
else: else:
all_files = [subdir + os.sep + filename for subdir, dirs, files in os.walk(arguments.folder) for filename in files] all_files = [subdir + os.sep + filename for subdir, dirs, files in os.walk(arguments.folder) for filename in files]
if len(all_files) == 0:
print(f"{colored(Color.YELLOW, 'WARNING')}: No test files found in '{arguments.folder}'.")
for filepath in sorted(all_files): for filepath in sorted(all_files):
subdir, filename = os.path.split(filepath) subdir, filename = os.path.split(filepath)
if filename.endswith(".lua") or filename.endswith(".luau"): if filename.endswith(".lua"):
if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]): if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]):
runTest(subdir, filename, filepath) runTest(subdir, filename, filepath)

View file

@ -31,6 +31,7 @@ extern int optimizationLevel;
void luaC_fullgc(lua_State* L); void luaC_fullgc(lua_State* L);
void luaC_validate(lua_State* L); void luaC_validate(lua_State* L);
LUAU_FASTFLAG(LuauMathMap)
LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTFLAG(DebugLuauAbortingChecks)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_FASTFLAG(LuauNativeAttribute) LUAU_FASTFLAG(LuauNativeAttribute)
@ -652,6 +653,8 @@ TEST_CASE("Buffers")
TEST_CASE("Math") TEST_CASE("Math")
{ {
ScopedFastFlag LuauMathMap{FFlag::LuauMathMap, true};
runConformance("math.lua"); runConformance("math.lua");
} }

View file

@ -4,10 +4,13 @@
#include "Fixture.h" #include "Fixture.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Common.h"
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauAllowFragmentParsing);
struct FragmentAutocompleteFixture : Fixture struct FragmentAutocompleteFixture : Fixture
{ {
@ -17,9 +20,25 @@ struct FragmentAutocompleteFixture : Fixture
REQUIRE(p.root); REQUIRE(p.root);
return findAncestryForFragmentParse(p.root, cursorPos); return findAncestryForFragmentParse(p.root, cursorPos);
} }
CheckResult checkBase(const std::string& document)
{
ScopedFastFlag sff{FFlag::LuauSolverV2, true};
FrontendOptions opts;
opts.retainFullTypeGraphs = true;
return this->frontend.check("MainModule", opts);
}
FragmentParseResult parseFragment(const std::string& document, const Position& cursorPos)
{
ScopedFastFlag sffs[]{{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}};
SourceModule* srcModule = this->getMainSourceModule();
std::string_view srcString = document;
return Luau::parseFragment(*srcModule, srcString, cursorPos);
}
}; };
TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTest"); TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals")
{ {
@ -32,7 +51,7 @@ local y = 5
); );
CHECK_EQ(3, result.ancestry.size()); CHECK_EQ(3, result.ancestry.size());
CHECK_EQ(2, result.localStack.size()); CHECK_EQ(1, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size()); CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement); REQUIRE(result.nearestStatement);
@ -56,10 +75,10 @@ end
); );
CHECK_EQ(5, result.ancestry.size()); CHECK_EQ(5, result.ancestry.size());
CHECK_EQ(3, result.localStack.size()); CHECK_EQ(2, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size()); CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement); REQUIRE(result.nearestStatement);
CHECK_EQ("e", std::string(result.localStack.back()->name.value)); CHECK_EQ("y", std::string(result.localStack.back()->name.value));
AstStatLocal* local = result.nearestStatement->as<AstStatLocal>(); AstStatLocal* local = result.nearestStatement->as<AstStatLocal>();
REQUIRE(local); REQUIRE(local);
@ -85,10 +104,10 @@ end
); );
CHECK_EQ(6, result.ancestry.size()); CHECK_EQ(6, result.ancestry.size());
CHECK_EQ(4, result.localStack.size()); CHECK_EQ(3, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size()); CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement); REQUIRE(result.nearestStatement);
CHECK_EQ("q", std::string(result.localStack.back()->name.value)); CHECK_EQ("z", std::string(result.localStack.back()->name.value));
AstStatLocal* local = result.nearestStatement->as<AstStatLocal>(); AstStatLocal* local = result.nearestStatement->as<AstStatLocal>();
REQUIRE(local); REQUIRE(local);
@ -129,11 +148,122 @@ local function bar() return x + foo() end
); );
CHECK_EQ(8, result.ancestry.size()); CHECK_EQ(8, result.ancestry.size());
CHECK_EQ(3, result.localStack.size()); CHECK_EQ(2, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size()); CHECK_EQ(result.localMap.size(), result.localStack.size());
CHECK_EQ("bar", std::string(result.localStack.back()->name.value)); CHECK_EQ("x", std::string(result.localStack.back()->name.value));
auto returnSt = result.nearestStatement->as<AstStatReturn>(); auto returnSt = result.nearestStatement->as<AstStatReturn>();
CHECK(returnSt != nullptr); CHECK(returnSt != nullptr);
} }
TEST_SUITE_END(); TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteParserTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null")
{
auto res = check(R"(
)");
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = parseFragment(
R"(
)",
Position(1, 0)
);
CHECK_EQ("\n", fragment.fragmentToParse);
CHECK_EQ(2, fragment.ancestry.size());
REQUIRE(fragment.root);
CHECK_EQ(0, fragment.root->body.size);
auto statBody = fragment.root->as<AstStatBlock>();
CHECK(statBody != nullptr);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_complete_fragments")
{
auto res = check(
R"(
local x = 4
local y = 5
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = parseFragment(
R"(
local x = 4
local y = 5
local z = x + y
)",
Position{3, 15}
);
CHECK_EQ("\nlocal z = x + y", fragment.fragmentToParse);
CHECK_EQ(5, fragment.ancestry.size());
REQUIRE(fragment.root);
CHECK_EQ(1, fragment.root->body.size);
auto stat = fragment.root->body.data[0]->as<AstStatLocal>();
REQUIRE(stat);
CHECK_EQ(1, stat->vars.size);
CHECK_EQ(1, stat->values.size);
CHECK_EQ("z", std::string(stat->vars.data[0]->name.value));
auto bin = stat->values.data[0]->as<AstExprBinary>();
REQUIRE(bin);
CHECK_EQ(AstExprBinary::Op::Add, bin->op);
auto lhs = bin->left->as<AstExprLocal>();
auto rhs = bin->right->as<AstExprLocal>();
REQUIRE(lhs);
REQUIRE(rhs);
CHECK_EQ("x", std::string(lhs->local->name.value));
CHECK_EQ("y", std::string(rhs->local->name.value));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_fragments_in_line")
{
auto res = check(
R"(
local x = 4
local y = 5
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = parseFragment(
R"(
local x = 4
local z = x + y
local y = 5
)",
Position{2, 15}
);
CHECK_EQ("local z = x + y", fragment.fragmentToParse);
CHECK_EQ(5, fragment.ancestry.size());
REQUIRE(fragment.root);
CHECK_EQ(1, fragment.root->body.size);
auto stat = fragment.root->body.data[0]->as<AstStatLocal>();
REQUIRE(stat);
CHECK_EQ(1, stat->vars.size);
CHECK_EQ(1, stat->values.size);
CHECK_EQ("z", std::string(stat->vars.data[0]->name.value));
auto bin = stat->values.data[0]->as<AstExprBinary>();
REQUIRE(bin);
CHECK_EQ(AstExprBinary::Op::Add, bin->op);
auto lhs = bin->left->as<AstExprLocal>();
auto rhs = bin->right->as<AstExprGlobal>();
REQUIRE(lhs);
REQUIRE(rhs);
CHECK_EQ("x", std::string(lhs->local->name.value));
CHECK_EQ("y", std::string(rhs->name.value));
}
TEST_SUITE_END();

View file

@ -3,6 +3,7 @@
#include "lualib.h" #include "lualib.h"
#include "Repl.h" #include "Repl.h"
#include "ScopedFlags.h"
#include "doctest.h" #include "doctest.h"
@ -12,6 +13,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
LUAU_FASTFLAG(LuauMathMap)
struct Completion struct Completion
{ {
std::string completion; std::string completion;
@ -172,15 +175,17 @@ TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables")
CHECK(checkCompletion(completions, prefix, "myvariable1")); CHECK(checkCompletion(completions, prefix, "myvariable1"));
CHECK(checkCompletion(completions, prefix, "myvariable2")); CHECK(checkCompletion(completions, prefix, "myvariable2"));
} }
if (FFlag::LuauMathMap)
{ {
// Try completing some builtin functions // Try completing some builtin functions
CompletionSet completions = getCompletionSet("math.m"); CompletionSet completions = getCompletionSet("math.m");
std::string prefix = "math."; std::string prefix = "math.";
CHECK(completions.size() == 3); CHECK(completions.size() == 4);
CHECK(checkCompletion(completions, prefix, "max(")); CHECK(checkCompletion(completions, prefix, "max("));
CHECK(checkCompletion(completions, prefix, "min(")); CHECK(checkCompletion(completions, prefix, "min("));
CHECK(checkCompletion(completions, prefix, "modf(")); CHECK(checkCompletion(completions, prefix, "modf("));
CHECK(checkCompletion(completions, prefix, "map("));
} }
} }

View file

@ -12,6 +12,7 @@ LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions2) LUAU_FASTFLAG(LuauUserDefinedTypeFunctions2)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation)
LUAU_FASTFLAG(LuauUserTypeFunFixRegister) LUAU_FASTFLAG(LuauUserTypeFunFixRegister)
LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite)
TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests");
@ -674,6 +675,36 @@ TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works")
CHECK(toString(tpm->givenTp) == "{ BaseField: number, read BaseMethod: (BaseClass, number) -> (), read Touched: Connection }"); CHECK(toString(tpm->givenTp) == "{ BaseField: number, read BaseMethod: (BaseClass, number) -> (), read Touched: Connection }");
} }
TEST_CASE_FIXTURE(ClassFixture, "write_of_readonly_is_nil")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag udtfRwFix{FFlag::LuauUserTypeFunFixNoReadWrite, true};
CheckResult result = check(R"(
type function getclass(arg)
local props = arg:properties()
local table = types.newtable(props)
local singleton = types.singleton("BaseMethod")
if table:writeproperty(singleton) then
return types.singleton(true)
else
return types.singleton(false)
end
end
-- forcing an error here to check the exact type of the metatable
local function ok(idx: getclass<BaseClass>): nil return idx end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypePackMismatch* tpm = get<TypePackMismatch>(result.errors[0]);
REQUIRE(tpm);
CHECK(toString(tpm->givenTp) == "false");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability")
{ {
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};

View file

@ -4891,4 +4891,41 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_union_type")
); );
} }
TEST_CASE_FIXTURE(Fixture, "function_check_constraint_too_eager")
{
ScopedFastFlag _{FFlag::LuauSolverV2, true};
// CLI-121540: All of these examples should have no errors.
LUAU_CHECK_ERROR_COUNT(3, check(R"(
local function doTheThing(_: { [string]: unknown }) end
doTheThing({
['foo'] = 5,
['bar'] = 'heyo',
})
)"));
LUAU_CHECK_ERROR_COUNT(1, check(R"(
type Input = { [string]: unknown }
local i : Input = {
[('%s'):format('3.14')]=5,
['stringField']='Heyo'
}
)"));
// This example previously asserted due to eagerly mutating the underlying
// table type.
LUAU_CHECK_ERROR_COUNT(3, check(R"(
type Input = { [string]: unknown }
local function doTheThing(_: Input) end
doTheThing({
[('%s'):format('3.14')]=5,
['stringField']='Heyo'
})
)"));
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -1683,4 +1683,27 @@ TEST_CASE_FIXTURE(Fixture, "leading_ampersand_no_type")
CHECK("*error-type*" == toString(requireTypeAlias("Amp"))); CHECK("*error-type*" == toString(requireTypeAlias("Amp")));
} }
TEST_CASE_FIXTURE(Fixture, "react_lua_follow_free_type_ub")
{
ScopedFastFlag _{FFlag::LuauSolverV2, true};
LUAU_REQUIRE_NO_ERRORS(check(R"(
return function(Roact)
local Tree = Roact.Component:extend("Tree")
function Tree:render()
local breadth, components, depth, id, wrap =
self.props.breadth, self.props.components, self.props.depth, self.props.id, self.props.wrap
local Box = components.Box
if depth == 0 then
Roact.createElement(Box, {})
else
Roact.createElement(Tree, {})
end
end
end
)"));
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -388,6 +388,20 @@ assert(math.pow(noinline(2), 2) == 4)
assert(math.pow(noinline(4), 0.5) == 2) assert(math.pow(noinline(4), 0.5) == 2)
assert(math.pow(noinline(-2), 2) == 4) assert(math.pow(noinline(-2), 2) == 4)
-- map
assert(math.map(0, -1, 1, 0, 2) == 1)
assert(math.map(1, 1, 4, 0, 2) == 0)
assert(math.map(2.5, 1, 4, 0, 2) == 1)
assert(math.map(4, 1, 4, 0, 2) == 2)
assert(math.map(1, 1, 4, 2, 0) == 2)
assert(math.map(2.5, 1, 4, 2, 0) == 1)
assert(math.map(4, 1, 4, 2, 0) == 0)
assert(math.map(1, 4, 1, 2, 0) == 0)
assert(math.map(2.5, 4, 1, 2, 0) == 1)
assert(math.map(4, 4, 1, 2, 0) == 2)
assert(math.map(-8, 0, 4, 0, 2) == -4)
assert(math.map(16, 0, 4, 0, 2) == 8)
assert(tostring(math.pow(-2, 0.5)) == "nan") assert(tostring(math.pow(-2, 0.5)) == "nan")
-- test that fastcalls return correct number of results -- test that fastcalls return correct number of results