Merge branch 'master' into type_function_solver_dependency_fix2

This commit is contained in:
karl-police 2025-03-05 00:32:14 +01:00 committed by GitHub
commit 162fe2129b
Signed by: DevComp
GPG key ID: B5690EEEBB952194
479 changed files with 35991 additions and 7892 deletions

View file

@ -63,10 +63,10 @@ jobs:
}
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=DebugLuauDeferredConstraintResolution bench/other/LuauPolyfillMap.lua 2>&1 | filter map-dcr | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=LuauSolverV2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-dcr | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=DebugLuauDeferredConstraintResolution bench/other/regex.lua 2>&1 | filter regex-dcr | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=LuauSolverV2 bench/other/regex.lua 2>&1 | filter regex-dcr | tee -a analyze-output.txt
- name: Run benchmark (compile)
run: |

View file

@ -46,9 +46,9 @@ jobs:
- name: make cli
run: |
make -j2 config=sanitize werror=1 luau luau-analyze luau-compile # match config with tests to improve build time
./luau tests/conformance/assert.lua
./luau-analyze tests/conformance/assert.lua
./luau-compile tests/conformance/assert.lua
./luau tests/conformance/assert.luau
./luau-analyze tests/conformance/assert.luau
./luau-compile tests/conformance/assert.luau
windows:
runs-on: windows-latest
@ -81,12 +81,12 @@ jobs:
shell: bash # necessary for fail-fast
run: |
cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Debug # match config with tests to improve build time
Debug/luau tests/conformance/assert.lua
Debug/luau-analyze tests/conformance/assert.lua
Debug/luau-compile tests/conformance/assert.lua
Debug/luau tests/conformance/assert.luau
Debug/luau-analyze tests/conformance/assert.luau
Debug/luau-compile tests/conformance/assert.luau
coverage:
runs-on: ubuntu-20.04 # needed for clang++-10 to avoid gcov compatibility issues
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: install
@ -94,7 +94,7 @@ jobs:
sudo apt install llvm
- name: make coverage
run: |
CXX=clang++-10 make -j2 config=coverage native=1 coverage
CXX=clang++ make -j2 config=coverage native=1 coverage
- name: upload coverage
uses: codecov/codecov-action@v3
with:

View file

@ -29,8 +29,8 @@ jobs:
build:
needs: ["create-release"]
strategy:
matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility
os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
matrix: # not using ubuntu-latest to improve compatibility
os: [{name: ubuntu, version: ubuntu-22.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
name: ${{matrix.os.name}}
runs-on: ${{matrix.os.version}}
steps:
@ -38,7 +38,7 @@ jobs:
- name: configure
run: cmake . -DCMAKE_BUILD_TYPE=Release
- name: build
run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Release -j 2
run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI Luau.Ast.CLI --config Release -j 2
- name: pack
if: matrix.os.name != 'windows'
run: zip luau-${{matrix.os.name}}.zip luau*

View file

@ -13,8 +13,8 @@ on:
jobs:
build:
strategy:
matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility
os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
matrix: # not using ubuntu-latest to improve compatibility
os: [{name: ubuntu, version: ubuntu-22.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
name: ${{matrix.os.name}}
runs-on: ${{matrix.os.version}}
steps:

1
.gitignore vendored
View file

@ -13,6 +13,7 @@
/luau
/luau-tests
/luau-analyze
/luau-bytecode
/luau-compile
__pycache__
.cache

View file

@ -1,10 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/AutocompleteTypes.h"
#include "Luau/Location.h"
#include "Luau/Type.h"
#include <unordered_map>
#include <string>
#include <memory>
#include <optional>
@ -16,89 +16,8 @@ struct Frontend;
struct SourceModule;
struct Module;
struct TypeChecker;
using ModulePtr = std::shared_ptr<Module>;
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using ModuleName = std::string;
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
struct FileResolver;
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback);
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau

View file

@ -0,0 +1,92 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Type.h"
#include <unordered_map>
namespace Luau
{
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
RequirePath,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau

View file

@ -9,6 +9,8 @@
namespace Luau
{
static constexpr char kRequireTagName[] = "require";
struct Frontend;
struct GlobalTypes;
struct TypeChecker;
@ -63,10 +65,7 @@ TypeId makeFunction( // Polymorphic
bool checked = false
);
void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn);
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn);
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn);
void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> fn);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName);
@ -80,4 +79,16 @@ std::optional<Binding> tryGetGlobalBinding(GlobalTypes& globals, const std::stri
Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name);
TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name);
/** A number of built-in functions are magical enough that we need to match on them specifically by
* name when they are called. These are listed here to be used whenever necessary, instead of duplicating this logic repeatedly.
*/
bool matchSetMetatable(const AstExprCall& call);
bool matchTableFreeze(const AstExprCall& call);
bool matchAssert(const AstExprCall& call);
// Returns `true` if the function should introduce typestate for its first argument.
bool shouldTypestateForFirstArgument(const AstExprCall& call);
} // namespace Luau

View file

@ -4,6 +4,7 @@
#include <Luau/NotNull.h>
#include "Luau/TypeArena.h"
#include "Luau/Type.h"
#include "Luau/Scope.h"
#include <unordered_map>
@ -22,8 +23,21 @@ struct CloneState
SeenTypePacks seenTypePacks;
};
/** `shallowClone` will make a copy of only the _top level_ constructor of the type,
* while `clone` will make a deep copy of the entire type and its every component.
*
* Be mindful about which behavior you actually _want_.
*
* Persistent types are not cloned as an optimization.
* If a type is cloned in order to mutate it, 'ignorePersistent' has to be set
*/
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false);
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false);
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState);
} // namespace Luau

View file

@ -109,6 +109,21 @@ struct FunctionCheckConstraint
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
};
// table_check expectedType exprType
//
// If `expectedType` is a table type and `exprType` is _also_ a table type,
// propogate the member types of `expectedType` into the types of `exprType`.
// This is used to implement bidirectional inference on table assignment.
// Also see: FunctionCheckConstraint.
struct TableCheckConstraint
{
TypeId expectedType;
TypeId exprType;
AstExprTable* table = nullptr;
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes;
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
};
// prim FreeType ExpectedType PrimitiveType
//
// FreeType is bounded below by the singleton type and above by PrimitiveType
@ -273,7 +288,8 @@ using ConstraintV = Variant<
UnpackConstraint,
ReduceConstraint,
ReducePackConstraint,
EqualityConstraint>;
EqualityConstraint,
TableCheckConstraint>;
struct Constraint
{

View file

@ -5,6 +5,7 @@
#include "Luau/Constraint.h"
#include "Luau/ControlFlow.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
@ -15,7 +16,6 @@
#include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h"
#include "Luau/Variant.h"
#include "Luau/Normalize.h"
#include <memory>
#include <vector>
@ -28,6 +28,7 @@ struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct DcrLogger;
struct TypeFunctionRuntime;
struct Inference
{
@ -95,6 +96,9 @@ struct ConstraintGenerator
// will enqueue them during solving.
std::vector<ConstraintPtr> unqueuedConstraints;
// Map a function's signature scope back to its signature type.
DenseHashMap<Scope*, TypeId> scopeToFunction{nullptr};
// The private scope of type aliases for which the type parameters belong to.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
@ -108,6 +112,11 @@ struct ConstraintGenerator
// Needed to be able to enable error-suppression preservation for immediate refinements.
NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
// Needed to register all available type functions for execution at later stages.
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// Needed to resolve modules to make 'require' import types properly.
NotNull<ModuleResolver> moduleResolver;
// Occasionally constraint generation needs to produce an ICE.
@ -125,6 +134,8 @@ struct ConstraintGenerator
ConstraintGenerator(
ModulePtr module,
NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
@ -142,6 +153,8 @@ struct ConstraintGenerator
*/
void visitModuleRoot(AstStatBlock* block);
void visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block);
private:
std::vector<std::vector<TypeId>> interiorTypes;
@ -223,7 +236,10 @@ private:
);
void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement);
LUAU_NOINLINE void checkAliases(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visitBlockWithoutChildScope_DEPRECATED(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visit(const ScopePtr& scope, AstStat* stat);
ControlFlow visit(const ScopePtr& scope, AstStatBlock* block);
@ -282,11 +298,25 @@ private:
Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional<TypeId> expectedType, bool generalize);
Inference check(const ScopePtr& scope, AstExprUnary* unary);
Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
Inference checkAstExprBinary(
const ScopePtr& scope,
const Location& location,
AstExprBinary::Op op,
AstExpr* left,
AstExpr* right,
std::optional<TypeId> expectedType
);
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprInterpString* interpString);
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(
const ScopePtr& scope,
AstExprBinary::Op op,
AstExpr* left,
AstExpr* right,
std::optional<TypeId> expectedType
);
void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType);
@ -321,6 +351,11 @@ private:
*/
void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn);
// Specializations of 'resolveType' below
TypeId resolveReferenceType(const ScopePtr& scope, AstType* ty, AstTypeReference* ref, bool inTypeArguments, bool replaceErrorWithFresh);
TypeId resolveTableType(const ScopePtr& scope, AstType* ty, AstTypeTable* tab, bool inTypeArguments, bool replaceErrorWithFresh);
TypeId resolveFunctionType(const ScopePtr& scope, AstType* ty, AstTypeFunction* fn, bool inTypeArguments, bool replaceErrorWithFresh);
/**
* Resolves a type from its AST annotation.
* @param scope the scope that the type annotation appears within.
@ -360,7 +395,7 @@ private:
**/
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(
const ScopePtr& scope,
AstArray<AstGenericType> generics,
AstArray<AstGenericType*> generics,
bool useCache = false,
bool addTypes = true
);
@ -377,7 +412,7 @@ private:
**/
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(
const ScopePtr& scope,
AstArray<AstGenericTypePack> packs,
AstArray<AstGenericTypePack*> packs,
bool useCache = false,
bool addTypes = true
);
@ -391,6 +426,7 @@ private:
TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
// make an intersect type function of these two types
TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program);
/** Scan the program for global definitions.
*
@ -421,6 +457,8 @@ private:
const ScopePtr& scope,
Location location
);
TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right);
};
/** Borrow a vector of pointers from a vector of owning pointers to constraints.

View file

@ -3,7 +3,9 @@
#pragma once
#include "Luau/Constraint.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DenseHash.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/Location.h"
#include "Luau/Module.h"
@ -12,6 +14,7 @@
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h"
@ -56,17 +59,42 @@ struct HashInstantiationSignature
size_t operator()(const InstantiationSignature& signature) const;
};
struct TablePropLookupResult
{
// What types are we blocked on for determining this type?
std::vector<TypeId> blockedTypes;
// The type of the property (if we were able to determine it).
std::optional<TypeId> propType;
// Whether or not this is _definitely_ derived as the result of an indexer.
// We use this to determine whether or not code like:
//
// t.lol = nil;
//
// ... is legal. If `t: { [string]: ~nil }` then this is legal as
// there's no guarantee on whether "lol" specifically exists.
// However, if `t: { lol: ~nil }`, then we cannot allow assignment as
// that would remove "lol" from the table entirely.
bool isIndex = false;
};
struct ConstraintSolver
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints;
NotNull<DenseHashMap<Scope*, TypeId>> scopeToFunction;
NotNull<Scope> rootScope;
ModuleName currentModuleName;
// The dataflow graph of the program, used in constraint generation and for magic functions.
NotNull<const DataFlowGraph> dfg;
// Constraints that the solver has generated, rather than sourcing from the
// scope tree.
std::vector<std::unique_ptr<Constraint>> solverConstraints;
@ -91,6 +119,9 @@ struct ConstraintSolver
// A mapping from free types to the number of unresolved constraints that mention them.
DenseHashMap<TypeId, size_t> unresolvedConstraints{{}};
std::unordered_map<NotNull<const Constraint>, DenseHashSet<TypeId>> maybeMutatedFreeTypes;
std::unordered_map<TypeId, DenseHashSet<const Constraint*>> mutatedFreeTypeToConstraint;
// Irreducible/uninhabited type functions or type pack functions.
DenseHashSet<const void*> uninhabitedTypeFunctions{{}};
@ -114,12 +145,16 @@ struct ConstraintSolver
explicit ConstraintSolver(
NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints,
NotNull<DenseHashMap<Scope*, TypeId>> scopeToFunction,
ModuleName moduleName,
NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles,
DcrLogger* logger,
NotNull<const DataFlowGraph> dfg,
TypeCheckLimits limits
);
@ -139,9 +174,11 @@ struct ConstraintSolver
**/
void finalizeTypeFunctions();
bool isDone();
bool isDone() const;
private:
void generalizeOneType(TypeId ty);
/**
* Bind a type variable to another type.
*
@ -167,13 +204,14 @@ public:
*/
bool tryDispatch(NotNull<const Constraint> c, bool force);
bool tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TableCheckConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
@ -194,16 +232,16 @@ public:
bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do
// also handles __iter metamethod
bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
// for a, ... in next_function, t, ... do
bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
TablePropLookupResult lookupTableProp(
NotNull<const Constraint> constraint,
TypeId subjectType,
const std::string& propName,
@ -211,7 +249,8 @@ public:
bool inConditional = false,
bool suppressSimplification = false
);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
TablePropLookupResult lookupTableProp(
NotNull<const Constraint> constraint,
TypeId subjectType,
const std::string& propName,
@ -270,10 +309,10 @@ public:
// FIXME: This use of a boolean for the return result is an appalling
// interface.
bool blockOnPendingTypes(TypeId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypePackId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypePackId targetPack, NotNull<const Constraint> constraint);
void unblock(NotNull<const Constraint> progressed);
void unblock(TypeId progressed, Location location);
void unblock(TypeId ty, Location location);
void unblock(TypePackId progressed, Location location);
void unblock(const std::vector<TypeId>& types, Location location);
void unblock(const std::vector<TypePackId>& packs, Location location);
@ -281,18 +320,18 @@ public:
/**
* @returns true if the TypeId is in a blocked state.
*/
bool isBlocked(TypeId ty);
bool isBlocked(TypeId ty) const;
/**
* @returns true if the TypePackId is in a blocked state.
*/
bool isBlocked(TypePackId tp);
bool isBlocked(TypePackId tp) const;
/**
* Returns whether the constraint is blocked on anything.
* @param constraint the constraint to check.
*/
bool isBlocked(NotNull<const Constraint> constraint);
bool isBlocked(NotNull<const Constraint> constraint) const;
/** Pushes a new solver constraint to the solver.
* @param cv the body of the constraint.
@ -308,7 +347,7 @@ public:
* @param location the location where the require is taking place; used for
* error locations.
**/
TypeId resolveModule(const ModuleInfo& module, const Location& location);
TypeId resolveModule(const ModuleInfo& info, const Location& location);
void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e);
@ -379,15 +418,21 @@ public:
**/
void reproduceConstraints(NotNull<Scope> scope, const Location& location, const Substitution& subst);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, std::set<TypeId> parts);
TypeId simplifyUnion(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const;
TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp);
void throwTimeLimitError();
void throwUserCancelError();
void throwTimeLimitError() const;
void throwUserCancelError() const;
ToStringOptions opts;
void fillInDiscriminantTypes(NotNull<const Constraint> constraint, const std::vector<std::optional<TypeId>>& discriminantTypes);
};
void dump(NotNull<Scope> rootScope, struct ToStringOptions& opts);

View file

@ -6,6 +6,7 @@
#include "Luau/ControlFlow.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Symbol.h"
#include "Luau/TypedAllocator.h"
@ -35,6 +36,8 @@ struct DataFlowGraph
DataFlowGraph& operator=(DataFlowGraph&&) = default;
DefId getDef(const AstExpr* expr) const;
// Look up the definition optionally, knowing it may not be present.
std::optional<DefId> getDefOptional(const AstExpr* expr) const;
// Look up for the rvalue def for a compound assignment.
std::optional<DefId> getRValueDefForCompoundAssign(const AstExpr* expr) const;
@ -46,13 +49,13 @@ struct DataFlowGraph
const RefinementKey* getRefinementKey(const AstExpr* expr) const;
private:
DataFlowGraph() = default;
DataFlowGraph(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena defArena;
RefinementKeyArena keyArena;
NotNull<DefArena> defArena;
NotNull<RefinementKeyArena> keyArena;
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
@ -68,7 +71,6 @@ private:
DenseHashMap<const AstExpr*, const Def*> compoundAssignDefs{nullptr};
DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr};
friend struct DataFlowGraphBuilder;
};
@ -105,25 +107,37 @@ struct DataFlowResult
const RefinementKey* parent = nullptr;
};
using ScopeStack = std::vector<DfgScope*>;
struct DataFlowGraphBuilder
{
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
static DataFlowGraph build(
AstStatBlock* block,
NotNull<DefArena> defArena,
NotNull<RefinementKeyArena> keyArena,
NotNull<struct InternalErrorReporter> handle
);
private:
DataFlowGraphBuilder() = default;
DataFlowGraphBuilder(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete;
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph;
NotNull<DefArena> defArena{&graph.defArena};
NotNull<RefinementKeyArena> keyArena{&graph.keyArena};
NotNull<DefArena> defArena;
NotNull<RefinementKeyArena> keyArena;
struct InternalErrorReporter* handle = nullptr;
DfgScope* moduleScope = nullptr;
/// The arena owning all of the scope allocations for the dataflow graph being built.
std::vector<std::unique_ptr<DfgScope>> scopes;
/// A stack of scopes used by the visitor to see where we are.
ScopeStack scopeStack;
DfgScope* currentScope();
struct FunctionCapture
{
std::vector<DefId> captureDefs;
@ -134,81 +148,81 @@ private:
DenseHashMap<Symbol, FunctionCapture> captures{Symbol{}};
void resolveCaptures();
DfgScope* childScope(DfgScope* scope, DfgScope::ScopeType scopeType = DfgScope::Linear);
DfgScope* makeChildScope(DfgScope::ScopeType scopeType = DfgScope::Linear);
void join(DfgScope* p, DfgScope* a, DfgScope* b);
void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b);
void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b);
DefId lookup(DfgScope* scope, Symbol symbol);
DefId lookup(DfgScope* scope, DefId def, const std::string& key);
DefId lookup(Symbol symbol);
DefId lookup(DefId def, const std::string& key);
ControlFlow visit(DfgScope* scope, AstStatBlock* b);
ControlFlow visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b);
ControlFlow visit(AstStatBlock* b);
ControlFlow visitBlockWithoutChildScope(AstStatBlock* b);
ControlFlow visit(DfgScope* scope, AstStat* s);
ControlFlow visit(DfgScope* scope, AstStatIf* i);
ControlFlow visit(DfgScope* scope, AstStatWhile* w);
ControlFlow visit(DfgScope* scope, AstStatRepeat* r);
ControlFlow visit(DfgScope* scope, AstStatBreak* b);
ControlFlow visit(DfgScope* scope, AstStatContinue* c);
ControlFlow visit(DfgScope* scope, AstStatReturn* r);
ControlFlow visit(DfgScope* scope, AstStatExpr* e);
ControlFlow visit(DfgScope* scope, AstStatLocal* l);
ControlFlow visit(DfgScope* scope, AstStatFor* f);
ControlFlow visit(DfgScope* scope, AstStatForIn* f);
ControlFlow visit(DfgScope* scope, AstStatAssign* a);
ControlFlow visit(DfgScope* scope, AstStatCompoundAssign* c);
ControlFlow visit(DfgScope* scope, AstStatFunction* f);
ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l);
ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t);
ControlFlow visit(DfgScope* scope, AstStatTypeFunction* f);
ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d);
ControlFlow visit(DfgScope* scope, AstStatError* error);
ControlFlow visit(AstStat* s);
ControlFlow visit(AstStatIf* i);
ControlFlow visit(AstStatWhile* w);
ControlFlow visit(AstStatRepeat* r);
ControlFlow visit(AstStatBreak* b);
ControlFlow visit(AstStatContinue* c);
ControlFlow visit(AstStatReturn* r);
ControlFlow visit(AstStatExpr* e);
ControlFlow visit(AstStatLocal* l);
ControlFlow visit(AstStatFor* f);
ControlFlow visit(AstStatForIn* f);
ControlFlow visit(AstStatAssign* a);
ControlFlow visit(AstStatCompoundAssign* c);
ControlFlow visit(AstStatFunction* f);
ControlFlow visit(AstStatLocalFunction* l);
ControlFlow visit(AstStatTypeAlias* t);
ControlFlow visit(AstStatTypeFunction* f);
ControlFlow visit(AstStatDeclareGlobal* d);
ControlFlow visit(AstStatDeclareFunction* d);
ControlFlow visit(AstStatDeclareClass* d);
ControlFlow visit(AstStatError* error);
DataFlowResult visitExpr(DfgScope* scope, AstExpr* e);
DataFlowResult visitExpr(DfgScope* scope, AstExprGroup* group);
DataFlowResult visitExpr(DfgScope* scope, AstExprLocal* l);
DataFlowResult visitExpr(DfgScope* scope, AstExprGlobal* g);
DataFlowResult visitExpr(DfgScope* scope, AstExprCall* c);
DataFlowResult visitExpr(DfgScope* scope, AstExprIndexName* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprIndexExpr* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprFunction* f);
DataFlowResult visitExpr(DfgScope* scope, AstExprTable* t);
DataFlowResult visitExpr(DfgScope* scope, AstExprUnary* u);
DataFlowResult visitExpr(DfgScope* scope, AstExprBinary* b);
DataFlowResult visitExpr(DfgScope* scope, AstExprTypeAssertion* t);
DataFlowResult visitExpr(DfgScope* scope, AstExprIfElse* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprInterpString* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprError* error);
DataFlowResult visitExpr(AstExpr* e);
DataFlowResult visitExpr(AstExprGroup* group);
DataFlowResult visitExpr(AstExprLocal* l);
DataFlowResult visitExpr(AstExprGlobal* g);
DataFlowResult visitExpr(AstExprCall* c);
DataFlowResult visitExpr(AstExprIndexName* i);
DataFlowResult visitExpr(AstExprIndexExpr* i);
DataFlowResult visitExpr(AstExprFunction* f);
DataFlowResult visitExpr(AstExprTable* t);
DataFlowResult visitExpr(AstExprUnary* u);
DataFlowResult visitExpr(AstExprBinary* b);
DataFlowResult visitExpr(AstExprTypeAssertion* t);
DataFlowResult visitExpr(AstExprIfElse* i);
DataFlowResult visitExpr(AstExprInterpString* i);
DataFlowResult visitExpr(AstExprError* error);
void visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef);
void visitLValue(AstExpr* e, DefId incomingDef);
DefId visitLValue(AstExprLocal* l, DefId incomingDef);
DefId visitLValue(AstExprGlobal* g, DefId incomingDef);
DefId visitLValue(AstExprIndexName* i, DefId incomingDef);
DefId visitLValue(AstExprIndexExpr* i, DefId incomingDef);
DefId visitLValue(AstExprError* e, DefId incomingDef);
void visitType(DfgScope* scope, AstType* t);
void visitType(DfgScope* scope, AstTypeReference* r);
void visitType(DfgScope* scope, AstTypeTable* t);
void visitType(DfgScope* scope, AstTypeFunction* f);
void visitType(DfgScope* scope, AstTypeTypeof* t);
void visitType(DfgScope* scope, AstTypeUnion* u);
void visitType(DfgScope* scope, AstTypeIntersection* i);
void visitType(DfgScope* scope, AstTypeError* error);
void visitType(AstType* t);
void visitType(AstTypeReference* r);
void visitType(AstTypeTable* t);
void visitType(AstTypeFunction* f);
void visitType(AstTypeTypeof* t);
void visitType(AstTypeUnion* u);
void visitType(AstTypeIntersection* i);
void visitType(AstTypeError* error);
void visitTypePack(DfgScope* scope, AstTypePack* p);
void visitTypePack(DfgScope* scope, AstTypePackExplicit* e);
void visitTypePack(DfgScope* scope, AstTypePackVariadic* v);
void visitTypePack(DfgScope* scope, AstTypePackGeneric* g);
void visitTypePack(AstTypePack* p);
void visitTypePack(AstTypePackExplicit* e);
void visitTypePack(AstTypePackVariadic* v);
void visitTypePack(AstTypePackGeneric* g);
void visitTypeList(DfgScope* scope, AstTypeList l);
void visitTypeList(AstTypeList l);
void visitGenerics(DfgScope* scope, AstArray<AstGenericType> g);
void visitGenericPacks(DfgScope* scope, AstArray<AstGenericTypePack> g);
void visitGenerics(AstArray<AstGenericType*> g);
void visitGenericPacks(AstArray<AstGenericTypePack*> g);
};
} // namespace Luau

View file

@ -0,0 +1,50 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeFwd.h"
#include "Luau/NotNull.h"
#include "Luau/DenseHash.h"
#include <memory>
#include <optional>
#include <vector>
namespace Luau
{
struct TypeArena;
}
// The EqSat stuff is pretty template heavy, so we go to some lengths to prevent
// the complexity from leaking outside its implementation sources.
namespace Luau::EqSatSimplification
{
struct Simplifier;
using SimplifierPtr = std::unique_ptr<Simplifier, void (*)(Simplifier*)>;
SimplifierPtr newSimplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
} // namespace Luau::EqSatSimplification
namespace Luau
{
struct EqSatSimplificationResult
{
TypeId result;
// New type function applications that were created by the reduction phase.
// We return these so that the ConstraintSolver can know to try to reduce
// them.
std::vector<TypeId> newTypeFunctions;
};
using EqSatSimplification::newSimplifier; // NOLINT: clang-tidy thinks these are unused. It is incorrect.
using Luau::EqSatSimplification::Simplifier; // NOLINT
using Luau::EqSatSimplification::SimplifierPtr;
std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simplifier, TypeId ty);
} // namespace Luau

View file

@ -0,0 +1,376 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/EGraph.h"
#include "Luau/Id.h"
#include "Luau/Language.h"
#include "Luau/Lexer.h" // For Allocator
#include "Luau/NotNull.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
struct TypeFunction;
}
namespace Luau::EqSatSimplification
{
using StringId = uint32_t;
using Id = Luau::EqSat::Id;
LUAU_EQSAT_UNIT(TNil);
LUAU_EQSAT_UNIT(TBoolean);
LUAU_EQSAT_UNIT(TNumber);
LUAU_EQSAT_UNIT(TString);
LUAU_EQSAT_UNIT(TThread);
LUAU_EQSAT_UNIT(TTopFunction);
LUAU_EQSAT_UNIT(TTopTable);
LUAU_EQSAT_UNIT(TTopClass);
LUAU_EQSAT_UNIT(TBuffer);
// Used for any type that eqsat can't do anything interesting with.
LUAU_EQSAT_ATOM(TOpaque, TypeId);
LUAU_EQSAT_ATOM(SBoolean, bool);
LUAU_EQSAT_ATOM(SString, StringId);
LUAU_EQSAT_ATOM(TFunction, TypeId);
LUAU_EQSAT_ATOM(TImportedTable, TypeId);
LUAU_EQSAT_ATOM(TClass, TypeId);
LUAU_EQSAT_UNIT(TAny);
LUAU_EQSAT_UNIT(TError);
LUAU_EQSAT_UNIT(TUnknown);
LUAU_EQSAT_UNIT(TNever);
LUAU_EQSAT_NODE_SET(Union);
LUAU_EQSAT_NODE_SET(Intersection);
LUAU_EQSAT_NODE_ARRAY(Negation, 1);
LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, std::shared_ptr<const TypeFunctionInstanceType>);
LUAU_EQSAT_UNIT(TNoRefine);
LUAU_EQSAT_UNIT(Invalid);
// enodes are immutable, but types are cyclic. We need a way to tie the knot.
// We handle this by generating TBound nodes at points where we encounter cycles.
// Each TBound has an ordinal that we later map onto the type.
// We use a substitution rule to replace all TBound nodes with their referrent.
LUAU_EQSAT_ATOM(TBound, size_t);
// Tables are sufficiently unlike other enodes that the Language.h macros won't cut it.
struct TTable
{
explicit TTable(Id basis);
TTable(Id basis, std::vector<StringId> propNames_, std::vector<Id> propTypes_);
// All TTables extend some other table. This may be TTopTable.
//
// It will frequently be a TImportedTable, in which case we can reuse things
// like source location and documentation info.
Id getBasis() const;
EqSat::Slice<const Id> propTypes() const;
// TODO: Also support read-only table props
// TODO: Indexer type, index result type.
std::vector<StringId> propNames;
// The enode interface
EqSat::Slice<Id> mutableOperands();
EqSat::Slice<const Id> operands() const;
bool operator==(const TTable& rhs) const;
bool operator!=(const TTable& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const TTable& value) const;
};
private:
// The first element of this vector is the basis. Subsequent elements are
// property types. As we add other things like read-only properties and
// indexers, the structure of this array is likely to change.
//
// We encode our data in this way so that the operands() method can properly
// return a Slice<Id>.
std::vector<Id> storage;
};
template<typename L>
using Node = EqSat::Node<L>;
using EType = EqSat::Language<
TNil,
TBoolean,
TNumber,
TString,
TThread,
TTopFunction,
TTopTable,
TTopClass,
TBuffer,
TOpaque,
SBoolean,
SString,
TFunction,
TTable,
TImportedTable,
TClass,
TAny,
TError,
TUnknown,
TNever,
Union,
Intersection,
Negation,
TTypeFun,
Invalid,
TNoRefine,
TBound>;
struct StringCache
{
Allocator allocator;
DenseHashMap<std::string_view, StringId> strings{{}};
std::vector<std::string_view> views;
StringId add(std::string_view s);
std::string_view asStringView(StringId id) const;
std::string asString(StringId id) const;
};
using EGraph = Luau::EqSat::EGraph<EType, struct Simplify>;
struct Simplify
{
using Data = bool;
template<typename T>
Data make(const EGraph&, const T&) const;
void join(Data& left, const Data& right) const;
};
struct Subst
{
Id eclass;
Id newClass;
// The node into eclass which is boring, if any
std::optional<size_t> boringIndex;
std::string desc;
Subst(Id eclass, Id newClass, std::string desc = "");
};
struct Simplifier
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
EGraph egraph;
StringCache stringCache;
// enodes are immutable but types can be cyclic, so we need some way to
// encode the cycle. This map is used to connect TBound nodes to the right
// eclass.
//
// The cyclicIntersection rewrite rule uses this to sense when a cycle can
// be deleted from an intersection or union.
std::unordered_map<size_t, Id> mappingIdToClass;
std::vector<Subst> substs;
using RewriteRuleFn = void (Simplifier::*)(Id id);
Simplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
// Utilities
const EqSat::EClass<EType, Simplify::Data>& get(Id id) const;
Id find(Id id) const;
Id add(EType enode);
template<typename Tag>
const Tag* isTag(Id id) const;
template<typename Tag>
const Tag* isTag(const EType& enode) const;
void subst(Id from, Id to);
void subst(Id from, Id to, const std::string& ruleName);
void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void unionClasses(std::vector<Id>& hereParts, Id there);
// Rewrite rules
void simplifyUnion(Id id);
void uninhabitedIntersection(Id id);
void intersectWithNegatedClass(Id id);
void intersectWithNegatedAtom(Id id);
void intersectWithNoRefine(Id id);
void cyclicIntersectionOfUnion(Id id);
void cyclicUnionOfIntersection(Id id);
void expandNegation(Id id);
void intersectionOfUnion(Id id);
void intersectTableProperty(Id id);
void uninhabitedTable(Id id);
void unneededTableModification(Id id);
void builtinTypeFunctions(Id id);
void iffyTypeFunctions(Id id);
void strictMetamethods(Id id);
};
template<typename Tag>
struct QueryIterator
{
QueryIterator();
QueryIterator(EGraph* egraph, Id eclass);
bool operator==(const QueryIterator& other) const;
bool operator!=(const QueryIterator& other) const;
std::pair<const Tag*, size_t> operator*() const;
QueryIterator& operator++();
QueryIterator& operator++(int);
private:
EGraph* egraph = nullptr;
Id eclass;
size_t index = 0;
};
template<typename Tag>
struct Query
{
EGraph* egraph;
Id eclass;
Query(EGraph* egraph, Id eclass)
: egraph(egraph)
, eclass(eclass)
{
}
QueryIterator<Tag> begin()
{
return QueryIterator<Tag>{egraph, eclass};
}
QueryIterator<Tag> end()
{
return QueryIterator<Tag>{};
}
};
template<typename Tag>
QueryIterator<Tag>::QueryIterator()
: egraph(nullptr)
, eclass(Id{0})
, index(0)
{
}
template<typename Tag>
QueryIterator<Tag>::QueryIterator(EGraph* egraph_, Id eclass)
: egraph(egraph_)
, eclass(eclass)
, index(0)
{
const auto& ecl = (*egraph)[eclass];
static constexpr const int idx = EType::VariantTy::getTypeId<Tag>();
for (const auto& enode : ecl.nodes)
{
if (enode.node.index() < idx)
++index;
else
break;
}
if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != idx)
{
egraph = nullptr;
index = 0;
}
}
template<typename Tag>
bool QueryIterator<Tag>::operator==(const QueryIterator<Tag>& rhs) const
{
if (egraph == nullptr && rhs.egraph == nullptr)
return true;
return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index;
}
template<typename Tag>
bool QueryIterator<Tag>::operator!=(const QueryIterator<Tag>& rhs) const
{
return !(*this == rhs);
}
template<typename Tag>
std::pair<const Tag*, size_t> QueryIterator<Tag>::operator*() const
{
LUAU_ASSERT(egraph != nullptr);
EGraph::EClassT& ecl = (*egraph)[eclass];
LUAU_ASSERT(index < ecl.nodes.size());
auto& enode = ecl.nodes[index].node;
Tag* result = enode.template get<Tag>();
LUAU_ASSERT(result);
return {result, index};
}
// pre-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++()
{
const auto& ecl = (*egraph)[eclass];
do
{
++index;
if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != EType::VariantTy::getTypeId<Tag>())
{
egraph = nullptr;
index = 0;
break;
}
} while (ecl.nodes[index].boring);
return *this;
}
// post-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++(int)
{
QueryIterator<Tag> res = *this;
++res;
return res;
}
} // namespace Luau::EqSatSimplification

View file

@ -448,6 +448,13 @@ struct UnexpectedTypePackInSubtyping
bool operator==(const UnexpectedTypePackInSubtyping& rhs) const;
};
struct UserDefinedTypeFunctionError
{
std::string message;
bool operator==(const UserDefinedTypeFunctionError& rhs) const;
};
using TypeErrorData = Variant<
TypeMismatch,
UnknownSymbol,
@ -496,7 +503,8 @@ using TypeErrorData = Variant<
CheckedFunctionIncorrectArgs,
UnexpectedTypeInSubtyping,
UnexpectedTypePackInSubtyping,
ExplicitFunctionAnnotationRecommended>;
ExplicitFunctionAnnotationRecommended,
UserDefinedTypeFunctionError>;
struct TypeErrorSummary
{

View file

@ -3,6 +3,7 @@
#include <string>
#include <optional>
#include <vector>
namespace Luau
{
@ -31,6 +32,13 @@ struct ModuleInfo
bool optional = false;
};
struct RequireSuggestion
{
std::string label;
std::string fullPath;
};
using RequireSuggestions = std::vector<RequireSuggestion>;
struct FileResolver
{
virtual ~FileResolver() {}
@ -51,6 +59,11 @@ struct FileResolver
{
return std::nullopt;
}
virtual std::optional<RequireSuggestions> getRequireSuggestions(const ModuleName& requirer, const std::optional<std::string>& pathString) const
{
return std::nullopt;
}
};
struct NullFileResolver : FileResolver

View file

@ -0,0 +1,146 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Parser.h"
#include "Luau/AutocompleteTypes.h"
#include "Luau/DenseHash.h"
#include "Luau/Module.h"
#include "Luau/Frontend.h"
#include <memory>
#include <vector>
namespace Luau
{
struct FrontendOptions;
enum class FragmentTypeCheckStatus
{
SkipAutocomplete,
Success,
};
struct FragmentAutocompleteAncestryResult
{
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
std::vector<AstNode*> ancestry;
AstStat* nearestStatement = nullptr;
};
struct FragmentParseResult
{
std::string fragmentToParse;
AstStatBlock* root = nullptr;
std::vector<AstNode*> ancestry;
AstStat* nearestStatement = nullptr;
std::vector<Comment> commentLocations;
std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>();
};
struct FragmentTypeCheckResult
{
ModulePtr incrementalModule = nullptr;
ScopePtr freshScope;
std::vector<AstNode*> ancestry;
};
struct FragmentAutocompleteResult
{
ModulePtr incrementalModule;
Scope* freshScope;
TypeArena arenaForAutocomplete;
AutocompleteResult acResults;
};
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,
std::optional<Position> fragmentEndPosition
);
std::pair<FragmentTypeCheckStatus, FragmentTypeCheckResult> typecheckFragment(
Frontend& frontend,
const ModuleName& moduleName,
const Position& cursorPos,
std::optional<FrontendOptions> opts,
std::string_view src,
std::optional<Position> fragmentEndPosition
);
FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend,
std::string_view src,
const ModuleName& moduleName,
Position cursorPosition,
std::optional<FrontendOptions> opts,
StringCompletionCallback callback,
std::optional<Position> fragmentEndPosition = std::nullopt
);
enum class FragmentAutocompleteStatus
{
Success,
FragmentTypeCheckFail,
InternalIce
};
struct FragmentAutocompleteStatusResult
{
FragmentAutocompleteStatus status;
std::optional<FragmentAutocompleteResult> result;
};
struct FragmentContext
{
std::string_view newSrc;
const ParseResult& newAstRoot;
std::optional<FrontendOptions> opts;
std::optional<Position> DEPRECATED_fragmentEndPosition;
};
/**
* @brief Attempts to compute autocomplete suggestions from the fragment context.
*
* This function computes autocomplete suggestions using outdated frontend typechecking data
* by patching the fragment context of the new script source content.
*
* @param frontend The Luau Frontend data structure, which may contain outdated typechecking data.
*
* @param moduleName The name of the target module, specifying which script the caller wants to request autocomplete for.
*
* @param cursorPosition The position in the script where the caller wants to trigger autocomplete.
*
* @param context The fragment context that this API will use to patch the outdated typechecking data.
*
* @param stringCompletionCB A callback function that provides autocomplete suggestions for string contexts.
*
* @return
* The status indicating whether `fragmentAutocomplete` ran successfully or failed, along with the reason for failure.
* Also includes autocomplete suggestions if the status is successful.
*
* @usage
* FragmentAutocompleteStatusResult acStatusResult;
* if (shouldFragmentAC)
* acStatusResult = Luau::tryFragmentAutocomplete(...);
*
* if (acStatusResult.status != Successful)
* {
* frontend.check(moduleName, options);
* acStatusResult.acResult = Luau::autocomplete(...);
* }
* return convertResultWithContext(acStatusResult.acResult);
*/
FragmentAutocompleteStatusResult tryFragmentAutocomplete(
Frontend& frontend,
const ModuleName& moduleName,
Position cursorPosition,
FragmentContext context,
StringCompletionCallback stringCompletionCB
);
} // namespace Luau

View file

@ -7,6 +7,7 @@
#include "Luau/ModuleResolver.h"
#include "Luau/RequireTracer.h"
#include "Luau/Scope.h"
#include "Luau/Set.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/Variant.h"
#include "Luau/AnyTypeSummary.h"
@ -44,21 +45,6 @@ struct LoadDefinitionFileResult
std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments);
std::vector<std::string_view> parsePathExpr(const AstExpr& pathExpr);
// Exported only for convenient testing.
std::optional<ModuleName> pathExprToModuleName(const ModuleName& currentModuleName, const std::vector<std::string_view>& expr);
/** Try to convert an AST fragment into a ModuleName.
* Returns std::nullopt if the expression cannot be resolved. This will most likely happen in cases where
* the import path involves some dynamic computation that we cannot see into at typechecking time.
*
* Unintuitively, weirdly-formulated modules (like game.Parent.Parent.Parent.Foo) will successfully produce a ModuleName
* as long as it falls within the permitted syntax. This is ok because we will fail to find the module and produce an
* error when we try during typechecking.
*/
std::optional<ModuleName> pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr);
struct SourceNode
{
bool hasDirtySourceModule() const
@ -71,13 +57,32 @@ struct SourceNode
return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule;
}
bool hasInvalidModuleDependency(bool forAutocomplete) const
{
return forAutocomplete ? invalidModuleDependencyForAutocomplete : invalidModuleDependency;
}
void setInvalidModuleDependency(bool value, bool forAutocomplete)
{
if (forAutocomplete)
invalidModuleDependencyForAutocomplete = value;
else
invalidModuleDependency = value;
}
ModuleName name;
std::string humanReadableName;
DenseHashSet<ModuleName> requireSet{{}};
std::vector<std::pair<ModuleName, Location>> requireLocations;
Set<ModuleName> dependents{{}};
bool dirtySourceModule = true;
bool dirtyModule = true;
bool dirtyModuleForAutocomplete = true;
bool invalidModuleDependency = true;
bool invalidModuleDependencyForAutocomplete = true;
double autocompleteLimitsMult = 1.0;
};
@ -132,7 +137,7 @@ struct FrontendModuleResolver : ModuleResolver
std::optional<ModuleInfo> resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override;
std::string getHumanReadableModuleName(const ModuleName& moduleName) const override;
void setModule(const ModuleName& moduleName, ModulePtr module);
bool setModule(const ModuleName& moduleName, ModulePtr module);
void clearModules();
private:
@ -166,9 +171,13 @@ struct Frontend
// Parse and typecheck module graph
CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
bool allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete = false) const;
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;
void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr);
void traverseDependents(const ModuleName& name, std::function<bool(SourceNode&)> processSubtree);
/** Borrow a pointer into the SourceModule cache.
*
* Returns nullptr if we don't have it. This could mean that the script
@ -209,6 +218,7 @@ struct Frontend
);
std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);
std::vector<ModuleName> getRequiredScripts(const ModuleName& name);
private:
ModulePtr check(

View file

@ -60,7 +60,7 @@ struct ReplaceGenerics : Substitution
};
// A substitution which replaces generic functions by monomorphic functions
struct Instantiation : Substitution
struct Instantiation final : Substitution
{
Instantiation(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope)
: Substitution(log, arena)

View file

@ -53,7 +53,7 @@ struct Replacer : Substitution
};
// A substitution which replaces generic functions by monomorphic functions
struct Instantiation2 : Substitution
struct Instantiation2 final : Substitution
{
// Mapping from generic types to free types to be used in instantiation.
DenseHashMap<TypeId, TypeId> genericSubstitutions{nullptr};

View file

@ -9,15 +9,24 @@
#include "Luau/Scope.h"
#include "Luau/TypeArena.h"
#include "Luau/AnyTypeSummary.h"
#include "Luau/DataFlowGraph.h"
#include <memory>
#include <vector>
#include <unordered_map>
#include <optional>
LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection)
namespace Luau
{
using LogLuauProc = void (*)(std::string_view);
extern LogLuauProc logLuau;
void setLogLuau(LogLuauProc ll);
void resetLogLuauProc();
struct Module;
struct AnyTypeSummary;
@ -54,6 +63,7 @@ struct SourceModule
}
};
bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos);
bool isWithinComment(const SourceModule& sourceModule, Position pos);
bool isWithinComment(const ParseResult& result, Position pos);
@ -67,6 +77,9 @@ struct Module
{
~Module();
// TODO: Clip this when we clip FFlagLuauSolverV2
bool checkedInNewSolver = false;
ModuleName name;
std::string humanReadableName;
@ -132,6 +145,11 @@ struct Module
TypePackId returnType = nullptr;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
// Arenas related to the DFG must persist after the DFG no longer exists, as
// Module objects maintain raw pointers to objects in these arenas.
DefArena defArena;
RefinementKeyArena keyArena;
bool hasModuleScope() const;
ScopePtr getModuleScope() const;

View file

@ -20,8 +20,6 @@ struct ModuleResolver
virtual ~ModuleResolver() {}
/** Compute a ModuleName from an AST fragment. This AST fragment is generally the argument to the require() function.
*
* You probably want to implement this with some variation of pathExprToModuleName.
*
* @returns The ModuleInfo if the expression is a syntactically legal path.
* @returns std::nullopt if we are unable to determine whether or not the expression is a valid path. Type inference will

View file

@ -9,11 +9,14 @@ namespace Luau
{
struct BuiltinTypes;
struct TypeFunctionRuntime;
struct UnifierSharedState;
struct TypeCheckLimits;
void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg,

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/EqSatSimplification.h"
#include "Luau/NotNull.h"
#include "Luau/Set.h"
#include "Luau/TypeFwd.h"
@ -21,10 +22,22 @@ struct Scope;
using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isSubtype(
TypeId subTy,
TypeId superTy,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
);
bool isSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
);
class TypeIds
{
@ -336,6 +349,7 @@ struct NormalizedType
};
using SeenTablePropPairs = Set<std::pair<TypeId, TypeId>, TypeIdPairHash>;
class Normalizer
{
@ -390,7 +404,13 @@ public:
void unionTablesWithTable(TypeIds& heres, TypeId there);
void unionTables(TypeIds& heres, const TypeIds& theres);
NormalizationResult unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
NormalizationResult unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars = -1);
NormalizationResult unionNormalWithTy(
NormalizedType& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes,
int ignoreSmallerTyvars = -1
);
// ------- Negations
std::optional<NormalizedType> negateNormal(const NormalizedType& here);
@ -407,16 +427,26 @@ public:
void intersectClassesWithClass(NormalizedClassType& heres, TypeId there);
void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there);
std::optional<TypePackId> intersectionOfTypePacks(TypePackId here, TypePackId there);
std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet);
void intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes);
std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet);
void intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSetTypes);
void intersectTables(TypeIds& heres, const TypeIds& theres);
std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress);
NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes);
NormalizationResult intersectTyvarsWithTy(
NormalizedTyvars& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes
);
NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes);
NormalizationResult normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet);
NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSetTypes);
NormalizationResult normalizeIntersections(
const std::vector<TypeId>& intersections,
NormalizedType& outType,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSet
);
// Check for inhabitance
NormalizationResult isInhabited(TypeId ty);
@ -426,7 +456,7 @@ public:
// Check for intersections being inhabited
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right);
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet);
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet);
// -------- Convert back from a normalized type to a type
TypeId typeFromNormal(const NormalizedType& norm);

View file

@ -2,12 +2,13 @@
#pragma once
#include "Luau/Ast.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
#include "Luau/Location.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Location.h"
#include "Luau/NotNull.h"
#include "Luau/Subtyping.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
@ -34,7 +35,9 @@ struct OverloadResolver
OverloadResolver(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits,
@ -43,7 +46,9 @@ struct OverloadResolver
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<Scope> scope;
NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits;
@ -108,7 +113,9 @@ struct SolveResult
SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,

View file

@ -11,14 +11,12 @@
namespace Luau
{
class AstStat;
class AstExpr;
class AstNode;
class AstStatBlock;
struct AstLocal;
struct RequireTraceResult
{
DenseHashMap<const AstExpr*, ModuleInfo> exprs{nullptr};
DenseHashMap<const AstNode*, ModuleInfo> exprs{nullptr};
std::vector<std::pair<ModuleName, Location>> requireList;
};

View file

@ -85,12 +85,18 @@ struct Scope
void inheritAssignments(const ScopePtr& childScope);
void inheritRefinements(const ScopePtr& childScope);
// Track globals that should emit warnings during type checking.
DenseHashSet<std::string> globalsToWarn{""};
bool shouldWarnGlobal(std::string name) const;
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasTypeParameters;
std::unordered_map<Name, TypePackId> typeAliasTypePackParameters;
std::optional<std::vector<TypeId>> interiorFreeTypes;
};
// Returns true iff the left scope encloses the right scope. A Scope* equal to

View file

@ -19,10 +19,10 @@ struct SimplifyResult
DenseHashSet<TypeId> blockedTypes;
};
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right);
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts);
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right);
enum class Relation
{

View file

@ -1,13 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Set.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypePairHash.h"
#include "Luau/TypePath.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/DenseHash.h"
#include <vector>
#include <optional>
@ -96,6 +97,22 @@ struct SubtypingEnvironment
DenseHashSet<TypeId> upperBound{nullptr};
};
/* For nested subtyping relationship tests of mapped generic bounds, we keep the outer environment immutable */
SubtypingEnvironment* parent = nullptr;
/// Applies `mappedGenerics` to the given type.
/// This is used specifically to substitute for generics in type function instances.
std::optional<TypeId> applyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty);
const TypeId* tryFindSubstitution(TypeId ty) const;
const SubtypingResult* tryFindSubtypingResult(std::pair<TypeId, TypeId> subAndSuper) const;
bool containsMappedType(TypeId ty) const;
bool containsMappedPack(TypePackId tp) const;
GenericBounds& getMappedTypeBounds(TypeId ty);
TypePackId* getMappedPackBounds(TypePackId tp);
/*
* When we encounter a generic over the course of a subtyping test, we need
* to tentatively map that generic onto a type on the other side.
@ -112,17 +129,15 @@ struct SubtypingEnvironment
DenseHashMap<TypeId, TypeId> substitutions{nullptr};
DenseHashMap<std::pair<TypeId, TypeId>, SubtypingResult, TypePairHash> ephemeralCache{{}};
/// Applies `mappedGenerics` to the given type.
/// This is used specifically to substitute for generics in type function instances.
std::optional<TypeId> applyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty);
};
struct Subtyping
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> iceReporter;
TypeCheckLimits limits;
@ -142,7 +157,9 @@ struct Subtyping
Subtyping(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter
);

View file

@ -6,6 +6,8 @@
#include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
#include <vector>
namespace Luau
{

View file

@ -44,6 +44,7 @@ struct ToStringOptions
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self
bool useQuestionMarks = true; // If true, use a postfix ? for options, else write them out as unions that include nil.
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections

View file

@ -65,11 +65,10 @@ T* getMutable(PendingTypePack* pending)
// Log of what TypeIds we are rebinding, to be committed later.
struct TxnLog
{
explicit TxnLog(bool useScopes = false)
explicit TxnLog()
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, ownedSeen()
, useScopes(useScopes)
, sharedSeen(&ownedSeen)
{
}

View file

@ -31,6 +31,7 @@ namespace Luau
struct TypeArena;
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct Module;
struct TypeFunction;
struct Constraint;
@ -68,12 +69,16 @@ using Name = std::string;
// A free type is one whose exact shape has yet to be fully determined.
struct FreeType
{
// New constructors
explicit FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound);
// This one got promoted to explicit
explicit FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound);
explicit FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound);
// Old constructors
explicit FreeType(TypeLevel level);
explicit FreeType(Scope* scope);
FreeType(Scope* scope, TypeLevel level);
FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound);
int index;
TypeLevel level;
Scope* scope = nullptr;
@ -130,14 +135,14 @@ struct BlockedType
BlockedType();
int index;
Constraint* getOwner() const;
void setOwner(Constraint* newOwner);
void replaceOwner(Constraint* newOwner);
const Constraint* getOwner() const;
void setOwner(const Constraint* newOwner);
void replaceOwner(const Constraint* newOwner);
private:
// The constraint that is intended to unblock this type. Other constraints
// should block on this constraint if present.
Constraint* owner = nullptr;
const Constraint* owner = nullptr;
};
struct PrimitiveType
@ -278,9 +283,6 @@ struct WithPredicate
}
};
using MagicFunction = std::function<std::optional<
WithPredicate<TypePackId>>(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
struct MagicFunctionCallContext
{
NotNull<struct ConstraintSolver> solver;
@ -290,7 +292,6 @@ struct MagicFunctionCallContext
TypePackId result;
};
using DcrMagicFunction = std::function<bool(MagicFunctionCallContext)>;
struct MagicRefinementContext
{
NotNull<Scope> scope;
@ -307,8 +308,30 @@ struct MagicFunctionTypeCheckContext
NotNull<Scope> checkScope;
};
using DcrMagicRefinement = void (*)(const MagicRefinementContext&);
using DcrMagicFunctionTypeCheck = std::function<void(const MagicFunctionTypeCheckContext&)>;
struct MagicFunction
{
virtual std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) = 0;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
virtual bool infer(const MagicFunctionCallContext&) = 0;
virtual void refine(const MagicRefinementContext&) {}
// If a magic function needs to do its own special typechecking, do it here.
// Returns true if magic typechecking was performed. Return false if the
// default typechecking logic should run.
virtual bool typeCheck(const MagicFunctionTypeCheckContext&)
{
return false;
}
virtual ~MagicFunction() {}
};
struct FunctionType
{
// Global monomorphic function
@ -366,16 +389,7 @@ struct FunctionType
Scope* scope = nullptr;
TypePackId argTypes;
TypePackId retTypes;
MagicFunction magicFunction = nullptr;
DcrMagicFunction dcrMagicFunction = nullptr;
DcrMagicRefinement dcrMagicRefinement = nullptr;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
DcrMagicFunctionTypeCheck dcrMagicTypeCheck = nullptr;
std::shared_ptr<MagicFunction> magic = nullptr;
bool hasSelf;
// `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it.
@ -598,6 +612,19 @@ struct ClassType
}
};
// Data required to initialize a user-defined function and its environment
struct UserDefinedFunctionData
{
// Store a weak module reference to ensure the lifetime requirements are preserved
std::weak_ptr<Module> owner;
// References to AST elements are owned by the Module allocator which also stores this type
AstStatTypeFunction* definition = nullptr;
DenseHashMap<Name, std::pair<AstStatTypeFunction*, size_t>> environment{""};
DenseHashMap<Name, AstStatTypeFunction*> environment_DEPRECATED{""};
};
/**
* An instance of a type function that has not yet been reduced to a more concrete
* type. The constraint solver receives a constraint to reduce each
@ -612,21 +639,21 @@ struct TypeFunctionInstanceType
std::vector<TypeId> typeArguments;
std::vector<TypePackId> packArguments;
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
std::optional<AstExprFunction*> userFuncBody; // Body of the user-defined type function; only available for UDTFs
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
UserDefinedFunctionData userFuncData;
TypeFunctionInstanceType(
NotNull<const TypeFunction> function,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments,
std::optional<AstName> userFuncName = std::nullopt,
std::optional<AstExprFunction*> userFuncBody = std::nullopt
std::optional<AstName> userFuncName,
UserDefinedFunctionData userFuncData
)
: function(function)
, typeArguments(typeArguments)
, packArguments(packArguments)
, userFuncName(userFuncName)
, userFuncBody(userFuncBody)
, userFuncData(userFuncData)
{
}
@ -643,6 +670,13 @@ struct TypeFunctionInstanceType
, packArguments(packArguments)
{
}
TypeFunctionInstanceType(NotNull<const TypeFunction> function, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
: function{function}
, typeArguments(typeArguments)
, packArguments(packArguments)
{
}
};
/** Represents a pending type alias instantiation.
@ -670,6 +704,11 @@ struct AnyType
{
};
// A special, trivial type for the refinement system that is always eliminated from intersections.
struct NoRefineType
{
};
// `T | U`
struct UnionType
{
@ -737,7 +776,7 @@ struct NegationType
TypeId ty;
};
using ErrorType = Unifiable::Error;
using ErrorType = Unifiable::Error<TypeId>;
using TypeVariant = Unifiable::Variant<
TypeId,
@ -758,6 +797,7 @@ using TypeVariant = Unifiable::Variant<
UnknownType,
NeverType,
NegationType,
NoRefineType,
TypeFunctionInstanceType>;
struct Type final
@ -803,6 +843,13 @@ struct Type final
Type& operator=(const TypeVariant& rhs);
Type& operator=(TypeVariant&& rhs);
Type(Type&&) = default;
Type& operator=(Type&&) = default;
Type clone() const;
private:
Type(const Type&) = default;
Type& operator=(const Type& rhs);
};
@ -952,6 +999,7 @@ public:
const TypeId unknownType;
const TypeId neverType;
const TypeId errorType;
const TypeId noRefineType;
const TypeId falsyType;
const TypeId truthyType;
@ -1159,6 +1207,10 @@ TypeId freshType(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, S
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
// A tag to mark a type which doesn't derive directly from the root type as overriding the return of `typeof`.
// Any classes which derive from this type will have typeof return this type.
static constexpr char kTypeofRootTag[] = "typeofRoot";
void attachTag(TypeId ty, const std::string& tagName);
void attachTag(Property& prop, const std::string& tagName);

View file

@ -32,9 +32,13 @@ struct TypeArena
TypeId addTV(Type&& tv);
TypeId freshType(TypeLevel level);
TypeId freshType(Scope* scope);
TypeId freshType(Scope* scope, TypeLevel level);
TypeId freshType(NotNull<BuiltinTypes> builtins, TypeLevel level);
TypeId freshType(NotNull<BuiltinTypes> builtins, Scope* scope);
TypeId freshType(NotNull<BuiltinTypes> builtins, Scope* scope, TypeLevel level);
TypeId freshType_DEPRECATED(TypeLevel level);
TypeId freshType_DEPRECATED(Scope* scope);
TypeId freshType_DEPRECATED(Scope* scope, TypeLevel level);
TypePackId freshTypePack(Scope* scope);

View file

@ -2,15 +2,16 @@
#pragma once
#include "Luau/Error.h"
#include "Luau/NotNull.h"
#include "Luau/Common.h"
#include "Luau/TypeUtils.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/Normalize.h"
#include "Luau/NotNull.h"
#include "Luau/Subtyping.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeOrPack.h"
#include "Luau/Normalize.h"
#include "Luau/Subtyping.h"
#include "Luau/TypeUtils.h"
namespace Luau
{
@ -60,7 +61,9 @@ struct Reasonings
void check(
NotNull<BuiltinTypes> builtinTypes,
NotNull<UnifierSharedState> sharedState,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
DcrLogger* logger,
const SourceModule& sourceModule,
@ -70,6 +73,8 @@ void check(
struct TypeChecker2
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
DcrLogger* logger;
const NotNull<TypeCheckLimits> limits;
const NotNull<InternalErrorReporter> ice;
@ -88,6 +93,8 @@ struct TypeChecker2
TypeChecker2(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
DcrLogger* logger,
@ -109,14 +116,14 @@ private:
std::optional<StackPusher> pushStack(AstNode* node);
void checkForInternalTypeFunction(TypeId ty, Location location);
TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location);
TypePackId lookupPack(AstExpr* expr);
TypePackId lookupPack(AstExpr* expr) const;
TypeId lookupType(AstExpr* expr);
TypeId lookupAnnotation(AstType* annotation);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation);
TypeId lookupExpectedType(AstExpr* expr);
TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation) const;
TypeId lookupExpectedType(AstExpr* expr) const;
TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena) const;
TypePackId reconstructPack(AstArray<AstExpr*> exprs, TypeArena& arena);
Scope* findInnermostScope(Location location);
Scope* findInnermostScope(Location location) const;
void visit(AstStat* stat);
void visit(AstStatIf* ifStatement);
void visit(AstStatWhile* whileStatement);
@ -153,7 +160,7 @@ private:
void visit(AstExprVarargs* expr);
void visitCall(AstExprCall* call);
void visit(AstExprCall* call);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty) const;
TypeId stripFromNilAndReport(TypeId ty, const Location& location);
void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy);
void visit(AstExprIndexName* indexName, ValueContext context);
@ -168,7 +175,7 @@ private:
void visit(AstExprInterpString* interpString);
void visit(AstExprError* expr);
TypeId flattenPack(TypePackId pack);
void visitGenerics(AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks);
void visitGenerics(AstArray<AstGenericType*> generics, AstArray<AstGenericTypePack*> genericPacks);
void visit(AstType* ty);
void visit(AstTypeReference* ty);
void visit(AstTypeTable* table);
@ -210,6 +217,9 @@ private:
std::vector<TypeError>& errors
);
// Avoid duplicate warnings being emitted for the same global variable.
DenseHashSet<std::string> warnedGlobals{""};
void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const;
bool isErrorSuppressing(Location loc, TypeId ty);
bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2);

View file

@ -1,29 +1,68 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/ConstraintSolver.h"
#include "Luau/Constraint.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/NotNull.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunctionRuntime.h"
#include "Luau/TypeFwd.h"
#include <functional>
#include <string>
#include <optional>
struct lua_State;
namespace Luau
{
struct TypeArena;
struct TxnLog;
struct ConstraintSolver;
class Normalizer;
using StateRef = std::unique_ptr<lua_State, void (*)(lua_State*)>;
struct TypeFunctionRuntime
{
TypeFunctionRuntime(NotNull<InternalErrorReporter> ice, NotNull<TypeCheckLimits> limits);
~TypeFunctionRuntime();
// Return value is an error message if registration failed
std::optional<std::string> registerFunction(AstStatTypeFunction* function);
// For user-defined type functions, we store all generated types and packs for the duration of the typecheck
TypedAllocator<TypeFunctionType> typeArena;
TypedAllocator<TypeFunctionTypePackVar> typePackArena;
NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits;
StateRef state;
// Set of functions which have their environment table initialized
DenseHashSet<AstStatTypeFunction*> initialized{nullptr};
// Evaluation of type functions should only be performed in the absence of parse errors in the source module
bool allowEvaluation = true;
// Output created by 'print' function
std::vector<std::string> messages;
private:
void prepareState();
};
struct TypeFunctionContext
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtins;
NotNull<Scope> scope;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits;
@ -32,33 +71,26 @@ struct TypeFunctionContext
// The constraint being reduced in this run of the reduction
const Constraint* constraint;
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
std::optional<AstExprFunction*> userFuncBody; // Body of the user-defined type function; only available for UDTFs
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint)
: arena(cs->arena)
, builtins(cs->builtinTypes)
, scope(scope)
, normalizer(cs->normalizer)
, ice(NotNull{&cs->iceReporter})
, limits(NotNull{&cs->limits})
, solver(cs.get())
, constraint(constraint.get())
{
}
TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint);
TypeFunctionContext(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtins,
NotNull<Scope> scope,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice,
NotNull<TypeCheckLimits> limits
)
: arena(arena)
, builtins(builtins)
, scope(scope)
, simplifier(simplifier)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, ice(ice)
, limits(limits)
, solver(nullptr)
@ -66,7 +98,17 @@ struct TypeFunctionContext
{
}
NotNull<Constraint> pushConstraint(ConstraintV&& c);
NotNull<Constraint> pushConstraint(ConstraintV&& c) const;
};
enum class Reduction
{
// The type function is either known to be reducible or the determination is blocked.
MaybeOk,
// The type function is known to be irreducible, but maybe not be erroneous, e.g. when it's over generics or free types.
Irreducible,
// The type function is known to be irreducible, and is definitely erroneous.
Erroneous,
};
/// Represents a reduction result, which may have successfully reduced the type,
@ -75,19 +117,25 @@ struct TypeFunctionContext
template<typename Ty>
struct TypeFunctionReductionResult
{
/// The result of the reduction, if any. If this is nullopt, the type function
/// could not be reduced.
std::optional<Ty> result;
/// Whether the result is uninhabited: whether we know, unambiguously and
/// permanently, whether this type function reduction results in an
/// uninhabitable type. This will trigger an error to be reported.
bool uninhabited;
/// Indicates the status of this reduction: is `Reduction::Irreducible` if
/// the this result indicates the type function is irreducible, and
/// `Reduction::Erroneous` if this result indicates the type function is
/// erroneous. `Reduction::MaybeOk` otherwise.
Reduction reductionStatus;
/// Any types that need to be progressed or mutated before the reduction may
/// proceed.
std::vector<TypeId> blockedTypes;
/// Any type packs that need to be progressed or mutated before the
/// reduction may proceed.
std::vector<TypePackId> blockedPacks;
/// A runtime error message from user-defined type functions
std::optional<std::string> error;
/// Messages printed out from user-defined type functions
std::vector<std::string> messages;
};
template<typename T>
@ -121,6 +169,7 @@ struct TypePackFunction
struct FunctionGraphReductionResult
{
ErrorVec errors;
ErrorVec messages;
DenseHashSet<TypeId> blockedTypes{nullptr};
DenseHashSet<TypePackId> blockedPacks{nullptr};
DenseHashSet<TypeId> reducedTypes{nullptr};
@ -192,6 +241,9 @@ struct BuiltinTypeFunctions
TypeFunction indexFunc;
TypeFunction rawgetFunc;
TypeFunction setmetatableFunc;
TypeFunction getmetatableFunc;
void addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const;
};

View file

@ -0,0 +1,298 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include <optional>
#include <string>
#include <map>
#include <vector>
using lua_State = struct lua_State;
namespace Luau
{
void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize);
// Replica of types from Type.h
struct TypeFunctionType;
using TypeFunctionTypeId = const TypeFunctionType*;
struct TypeFunctionTypePackVar;
using TypeFunctionTypePackId = const TypeFunctionTypePackVar*;
struct TypeFunctionPrimitiveType
{
enum Type
{
NilType,
Boolean,
Number,
String,
Thread,
Buffer,
};
Type type;
TypeFunctionPrimitiveType(Type type)
: type(type)
{
}
};
struct TypeFunctionBooleanSingleton
{
bool value = false;
};
struct TypeFunctionStringSingleton
{
std::string value;
};
using TypeFunctionSingletonVariant = Variant<TypeFunctionBooleanSingleton, TypeFunctionStringSingleton>;
struct TypeFunctionSingletonType
{
TypeFunctionSingletonVariant variant;
explicit TypeFunctionSingletonType(TypeFunctionSingletonVariant variant)
: variant(std::move(variant))
{
}
};
template<typename T>
const T* get(const TypeFunctionSingletonType* tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&tv->variant) : nullptr;
}
template<typename T>
T* getMutable(const TypeFunctionSingletonType* tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&const_cast<TypeFunctionSingletonType*>(tv)->variant) : nullptr;
}
struct TypeFunctionUnionType
{
std::vector<TypeFunctionTypeId> components;
};
struct TypeFunctionIntersectionType
{
std::vector<TypeFunctionTypeId> components;
};
struct TypeFunctionAnyType
{
};
struct TypeFunctionUnknownType
{
};
struct TypeFunctionNeverType
{
};
struct TypeFunctionNegationType
{
TypeFunctionTypeId type;
};
struct TypeFunctionTypePack
{
std::vector<TypeFunctionTypeId> head;
std::optional<TypeFunctionTypePackId> tail;
};
struct TypeFunctionVariadicTypePack
{
TypeFunctionTypeId type;
};
struct TypeFunctionGenericTypePack
{
bool isNamed = false;
std::string name;
};
using TypeFunctionTypePackVariant = Variant<TypeFunctionTypePack, TypeFunctionVariadicTypePack, TypeFunctionGenericTypePack>;
struct TypeFunctionTypePackVar
{
TypeFunctionTypePackVariant type;
TypeFunctionTypePackVar(TypeFunctionTypePackVariant type)
: type(std::move(type))
{
}
bool operator==(const TypeFunctionTypePackVar& rhs) const;
};
struct TypeFunctionFunctionType
{
std::vector<TypeFunctionTypeId> generics;
std::vector<TypeFunctionTypePackId> genericPacks;
TypeFunctionTypePackId argTypes;
TypeFunctionTypePackId retTypes;
};
template<typename T>
const T* get(TypeFunctionTypePackId tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&tv->type) : nullptr;
}
template<typename T>
T* getMutable(TypeFunctionTypePackId tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&const_cast<TypeFunctionTypePackVar*>(tv)->type) : nullptr;
}
struct TypeFunctionTableIndexer
{
TypeFunctionTableIndexer(TypeFunctionTypeId keyType, TypeFunctionTypeId valueType)
: keyType(keyType)
, valueType(valueType)
{
}
TypeFunctionTypeId keyType;
TypeFunctionTypeId valueType;
};
struct TypeFunctionProperty
{
static TypeFunctionProperty readonly(TypeFunctionTypeId ty);
static TypeFunctionProperty writeonly(TypeFunctionTypeId ty);
static TypeFunctionProperty rw(TypeFunctionTypeId ty); // Shared read-write type.
static TypeFunctionProperty rw(TypeFunctionTypeId read, TypeFunctionTypeId write); // Separate read-write type.
bool isReadOnly() const;
bool isWriteOnly() const;
std::optional<TypeFunctionTypeId> readTy;
std::optional<TypeFunctionTypeId> writeTy;
};
struct TypeFunctionTableType
{
using Name = std::string;
using Props = std::map<Name, TypeFunctionProperty>;
Props props;
std::optional<TypeFunctionTableIndexer> indexer;
// Should always be a TypeFunctionTableType
std::optional<TypeFunctionTypeId> metatable;
};
struct TypeFunctionClassType
{
using Name = std::string;
using Props = std::map<Name, TypeFunctionProperty>;
Props props;
std::optional<TypeFunctionTableIndexer> indexer;
std::optional<TypeFunctionTypeId> metatable; // metaclass?
// this was mistaken, and we should actually be keeping separate read/write types here.
std::optional<TypeFunctionTypeId> parent_DEPRECATED;
std::optional<TypeFunctionTypeId> readParent;
std::optional<TypeFunctionTypeId> writeParent;
TypeId classTy;
std::string name_DEPRECATED;
};
struct TypeFunctionGenericType
{
bool isNamed = false;
bool isPack = false;
std::string name;
};
using TypeFunctionTypeVariant = Luau::Variant<
TypeFunctionPrimitiveType,
TypeFunctionAnyType,
TypeFunctionUnknownType,
TypeFunctionNeverType,
TypeFunctionSingletonType,
TypeFunctionUnionType,
TypeFunctionIntersectionType,
TypeFunctionNegationType,
TypeFunctionFunctionType,
TypeFunctionTableType,
TypeFunctionClassType,
TypeFunctionGenericType>;
struct TypeFunctionType
{
TypeFunctionTypeVariant type;
TypeFunctionType(TypeFunctionTypeVariant type)
: type(std::move(type))
{
}
bool operator==(const TypeFunctionType& rhs) const;
};
template<typename T>
const T* get(TypeFunctionTypeId tv)
{
LUAU_ASSERT(tv);
return tv ? Luau::get_if<T>(&tv->type) : nullptr;
}
template<typename T>
T* getMutable(TypeFunctionTypeId tv)
{
LUAU_ASSERT(tv);
return tv ? Luau::get_if<T>(&const_cast<TypeFunctionType*>(tv)->type) : nullptr;
}
std::optional<std::string> checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult);
TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type);
TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type);
void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type);
bool isTypeUserData(lua_State* L, int idx);
TypeFunctionTypeId getTypeUserData(lua_State* L, int idx);
std::optional<TypeFunctionTypeId> optionalTypeUserData(lua_State* L, int idx);
void registerTypesLibrary(lua_State* L);
void registerTypeUserData(lua_State* L);
void setTypeFunctionEnvironment(lua_State* L);
void resetTypeFunctionState(lua_State* L);
} // namespace Luau

View file

@ -0,0 +1,50 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFunctionRuntime.h"
namespace Luau
{
using Kind = Variant<TypeId, TypePackId>;
template<typename T>
const T* get(const Kind& kind)
{
return get_if<T>(&kind);
}
using TypeFunctionKind = Variant<TypeFunctionTypeId, TypeFunctionTypePackId>;
template<typename T>
const T* get(const TypeFunctionKind& tfkind)
{
return get_if<T>(&tfkind);
}
struct TypeFunctionRuntimeBuilderState
{
NotNull<TypeFunctionContext> ctx;
// Mapping of class name to ClassType
// Invariant: users can not create a new class types -> any class types that get deserialized must have been an argument to the type function
// Using this invariant, whenever a ClassType is serialized, we can put it into this map
// whenever a ClassType is deserialized, we can use this map to return the corresponding value
DenseHashMap<std::string, TypeId> classesSerialized_DEPRECATED{{}};
// List of errors that occur during serialization/deserialization
// At every iteration of serialization/deserialzation, if this list.size() != 0, we halt the process
std::vector<std::string> errors{};
TypeFunctionRuntimeBuilderState(NotNull<TypeFunctionContext> ctx)
: ctx(ctx)
{
}
};
TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state);
TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state);
} // namespace Luau

View file

@ -399,8 +399,8 @@ private:
const ScopePtr& scope,
std::optional<TypeLevel> levelOpt,
const AstNode& node,
const AstArray<AstGenericType>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames,
const AstArray<AstGenericType*>& genericNames,
const AstArray<AstGenericTypePack*>& genericPackNames,
bool useCache = false
);

View file

@ -52,7 +52,7 @@ struct GenericTypePack
};
using BoundTypePack = Unifiable::Bound<TypePackId>;
using ErrorTypePack = Unifiable::Error;
using ErrorTypePack = Unifiable::Error<TypePackId>;
using TypePackVariant =
Unifiable::Variant<TypePackId, FreeTypePack, GenericTypePack, TypePack, VariadicTypePack, BlockedTypePack, TypeFunctionInstanceTypePack>;

View file

@ -51,6 +51,8 @@ struct Index
/// Represents fields of a type or pack that contain a type.
enum class TypeField
{
/// The table of a metatable type.
Table,
/// The metatable of a type. This could be a metatable type, a primitive
/// type, a class type, or perhaps even a string singleton type.
Metatable,

View file

@ -40,7 +40,7 @@ struct InConditionalContext
TypeContext* typeContext;
TypeContext oldValue;
InConditionalContext(TypeContext* c)
explicit InConditionalContext(TypeContext* c)
: typeContext(c)
, oldValue(*c)
{
@ -248,4 +248,45 @@ std::optional<Ty> follow(std::optional<Ty> ty)
return std::nullopt;
}
/**
* Returns whether or not expr is a literal expression, for example:
* - Scalar literals (numbers, booleans, strings, nil)
* - Table literals
* - Lambdas (a "function literal")
*/
bool isLiteral(const AstExpr* expr);
/**
* Given a table literal and a mapping from expression to type, determine
* whether any literal expression in this table depends on any blocked types.
* This is used as a precondition for bidirectional inference: be warned that
* the behavior of this algorithm is tightly coupled to that of bidirectional
* inference.
* @param expr Expression to search
* @param astTypes Mapping from AST node to TypeID
* @returns A vector of blocked types
*/
std::vector<TypeId> findBlockedTypesIn(AstExprTable* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes);
/**
* Given a function call and a mapping from expression to type, determine
* whether the type of any argument in said call in depends on a blocked types.
* This is used as a precondition for bidirectional inference: be warned that
* the behavior of this algorithm is tightly coupled to that of bidirectional
* inference.
* @param expr Expression to search
* @param astTypes Mapping from AST node to TypeID
* @returns A vector of blocked types
*/
std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes);
/**
* Given a scope and a free type, find the closest parent that has a present
* `interiorFreeTypes` and append the given type to said list. This list will
* be generalized when the requiste `GeneralizationConstraint` is resolved.
* @param scope Initial scope this free type was attached to
* @param ty Free type to track.
*/
void trackInteriorFreeType(Scope* scope, TypeId ty);
} // namespace Luau

View file

@ -3,6 +3,7 @@
#include "Luau/Variant.h"
#include <optional>
#include <string>
namespace Luau
@ -94,19 +95,29 @@ struct Bound
Id boundTo;
};
template<typename Id>
struct Error
{
// This constructor has to be public, since it's used in Type and TypePack,
// but shouldn't be called directly. Please use errorRecoveryType() instead.
Error();
explicit Error();
explicit Error(Id synthetic)
: synthetic{synthetic}
{
}
int index;
// This is used to create an error that can be rendered out using this field
// as appropriate metadata for communicating it to the user.
std::optional<Id> synthetic;
private:
static int nextIndex;
};
template<typename Id, typename... Value>
using Variant = Luau::Variant<Bound<Id>, Error, Value...>;
using Variant = Luau::Variant<Bound<Id>, Error<Id>, Value...>;
} // namespace Luau::Unifiable

View file

@ -93,10 +93,6 @@ struct Unifier
Unifier(NotNull<Normalizer> normalizer, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr);
// Configure the Unifier to test for scope subsumption via embedded Scope
// pointers rather than TypeLevels.
void enableNewSolver();
// Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId subTy, TypeId superTy);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false);
@ -169,7 +165,6 @@ private:
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
TxnLog combineLogsIntoIntersection(std::vector<TxnLog> logs);
TxnLog combineLogsIntoUnion(std::vector<TxnLog> logs);
public:
@ -179,7 +174,7 @@ public:
bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed);
bool occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
Unifier makeChildUnifier();
std::unique_ptr<Unifier> makeChildUnifier();
void reportError(TypeError err);
LUAU_NOINLINE void reportError(Location location, TypeErrorData data);
@ -195,11 +190,6 @@ private:
// Available after regular type pack unification errors
std::optional<int> firstPackErrorPos;
// If true, we do a bunch of small things differently to work better with
// the new type inference engine. Most notably, we use the Scope hierarchy
// directly rather than using TypeLevels.
bool useNewSolver = false;
};
void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp);

View file

@ -49,6 +49,26 @@ struct UnifierSharedState
DenseHashSet<TypePackId> tempSeenTp{nullptr};
UnifierCounters counters;
bool reentrantTypeReduction = false;
};
struct TypeReductionRentrancyGuard final
{
explicit TypeReductionRentrancyGuard(NotNull<UnifierSharedState> sharedState)
: sharedState{sharedState}
{
sharedState->reentrantTypeReduction = true;
}
~TypeReductionRentrancyGuard()
{
sharedState->reentrantTypeReduction = false;
}
TypeReductionRentrancyGuard(const TypeReductionRentrancyGuard&) = delete;
TypeReductionRentrancyGuard(TypeReductionRentrancyGuard&&) = delete;
private:
NotNull<UnifierSharedState> sharedState;
};
} // namespace Luau

View file

@ -10,7 +10,6 @@
#include "Type.h"
LUAU_FASTINT(LuauVisitRecursionLimit)
LUAU_FASTFLAG(LuauBoundLazyTypes2)
LUAU_FASTFLAG(LuauSolverV2)
namespace Luau
@ -86,6 +85,8 @@ struct GenericTypeVisitor
{
}
virtual ~GenericTypeVisitor() {}
virtual void cycle(TypeId) {}
virtual void cycle(TypePackId) {}
@ -133,6 +134,10 @@ struct GenericTypeVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const NoRefineType& nrt)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UnknownType& utv)
{
return visit(ty);
@ -186,7 +191,7 @@ struct GenericTypeVisitor
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const Unifiable::Error& etp)
virtual bool visit(TypePackId tp, const ErrorTypePack& etp)
{
return visit(tp);
}
@ -345,6 +350,8 @@ struct GenericTypeVisitor
}
else if (auto atv = get<AnyType>(ty))
visit(ty, *atv);
else if (auto nrt = get<NoRefineType>(ty))
visit(ty, *nrt);
else if (auto utv = get<UnionType>(ty))
{
if (visit(ty, *utv))
@ -455,7 +462,7 @@ struct GenericTypeVisitor
else if (auto gtv = get<GenericTypePack>(tp))
visit(tp, *gtv);
else if (auto etv = get<Unifiable::Error>(tp))
else if (auto etv = get<ErrorTypePack>(tp))
visit(tp, *etv);
else if (auto pack = get<TypePack>(tp))

View file

@ -38,7 +38,7 @@
#include <stdio.h>
LUAU_FASTFLAGVARIABLE(StudioReportLuauAny2, false);
LUAU_FASTFLAGVARIABLE(StudioReportLuauAny2);
LUAU_FASTINTVARIABLE(LuauAnySummaryRecursionLimit, 300);
LUAU_FASTFLAG(DebugLuauMagicTypes);
@ -161,7 +161,7 @@ void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module*
typeInfo.push_back(ti);
}
}
if (ret->list.size > 1 && !seenTP)
{
if (containsAny(retScope->returnType))
@ -177,7 +177,6 @@ void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module*
}
}
}
}
void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes)

View file

@ -425,6 +425,7 @@ struct AstJsonEncoder : public AstVisitor
"AstExprFunction",
[&]()
{
PROP(attributes);
PROP(generics);
PROP(genericPacks);
if (node->self)
@ -881,7 +882,7 @@ struct AstJsonEncoder : public AstVisitor
PROP(name);
PROP(generics);
PROP(genericPacks);
PROP(type);
write("value", node->type);
PROP(exported);
}
);
@ -894,7 +895,7 @@ struct AstJsonEncoder : public AstVisitor
"AstStatDeclareFunction",
[&]()
{
// TODO: attributes
PROP(attributes);
PROP(name);
PROP(nameLocation);
PROP(params);
@ -1042,6 +1043,7 @@ struct AstJsonEncoder : public AstVisitor
"AstTypeFunction",
[&]()
{
PROP(attributes);
PROP(generics);
PROP(genericPacks);
PROP(argTypes);
@ -1136,6 +1138,42 @@ struct AstJsonEncoder : public AstVisitor
);
}
void write(AstAttr::Type type)
{
switch (type)
{
case AstAttr::Type::Checked:
return writeString("checked");
case AstAttr::Type::Native:
return writeString("native");
}
}
void write(class AstAttr* node)
{
writeNode(
node,
"AstAttr",
[&]()
{
write("name", node->type);
}
);
}
bool visit(class AstTypeGroup* node) override
{
writeNode(
node,
"AstTypeGroup",
[&]()
{
write("inner", node->type);
}
);
return false;
}
bool visit(class AstTypeSingletonBool* node) override
{
writeNode(

View file

@ -13,6 +13,8 @@
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon)
namespace Luau
{
@ -41,11 +43,26 @@ struct AutocompleteNodeFinder : public AstVisitor
bool visit(AstStat* stat) override
{
if (stat->location.begin < pos && pos <= stat->location.end)
if (FFlag::LuauExtendStatEndPosWithSemicolon)
{
ancestry.push_back(stat);
return true;
// Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal
// to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case
// (no semicolon) we are still part of the AstStatLocal, hence the different comparison check.
if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end))
{
ancestry.push_back(stat);
return true;
}
}
else
{
if (stat->location.begin < pos && pos <= stat->location.end)
{
ancestry.push_back(stat);
return true;
}
}
return false;
}
@ -509,6 +526,37 @@ static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
return documentationSymbol;
}
static std::optional<DocumentationSymbol> getMetatableDocumentation(
const Module& module,
AstExpr* parentExpr,
const TableType* mtable,
const AstName& index
)
{
auto indexIt = mtable->props.find("__index");
if (indexIt == mtable->props.end())
return std::nullopt;
TypeId followed = follow(indexIt->second.type());
const TableType* ttv = get<TableType>(followed);
if (!ttv)
return std::nullopt;
auto propIt = ttv->props.find(index.value);
if (propIt == ttv->props.end())
return std::nullopt;
if (FFlag::LuauSolverV2)
{
if (auto ty = propIt->second.readTy)
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
}
else
return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol);
return std::nullopt;
}
std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position)
{
std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position);
@ -541,15 +589,29 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
}
else if (const ClassType* ctv = get<ClassType>(parentTy))
{
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
while (ctv)
{
if (FFlag::LuauSolverV2)
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
{
if (auto ty = propIt->second.readTy)
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
if (FFlag::LuauSolverV2)
{
if (auto ty = propIt->second.readTy)
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
}
else
return checkOverloadedDocumentationSymbol(
module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol
);
}
else
return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol);
ctv = ctv->parent ? Luau::get<Luau::ClassType>(*ctv->parent) : nullptr;
}
}
else if (const PrimitiveType* ptv = get<PrimitiveType>(parentTy); ptv && ptv->metatable)
{
if (auto mtable = get<TableType>(*ptv->metatable))
{
if (std::optional<std::string> docSymbol = getMetatableDocumentation(module, parentExpr, mtable, indexName->index))
return docSymbol;
}
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,27 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/AutocompleteTypes.h"
namespace Luau
{
struct Module;
struct FileResolver;
using ModulePtr = std::shared_ptr<Module>;
using ModuleName = std::string;
AutocompleteResult autocomplete_(
const ModulePtr& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
std::vector<AstNode*>& ancestry,
Scope* globalScope,
const ScopePtr& scopeAtPosition,
Position position,
FileResolver* fileResolver,
StringCompletionCallback callback
);
} // namespace Luau

View file

@ -2,6 +2,9 @@
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Ast.h"
#include "Luau/Clone.h"
#include "Luau/DenseHash.h"
#include "Luau/Error.h"
#include "Luau/Frontend.h"
#include "Luau/Symbol.h"
#include "Luau/Common.h"
@ -25,47 +28,93 @@
* about a function that takes any number of values, but where each value must have some specific type.
*/
LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAGVARIABLE(LuauDCRMagicFunctionTypeChecker, false);
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression)
LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType3)
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope)
LUAU_FASTFLAGVARIABLE(LuauFreezeIgnorePersistent)
LUAU_FASTFLAGVARIABLE(LuauFollowTableFreeze)
namespace Luau
{
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
struct MagicSelect final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicSetMetatable final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context);
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context);
static bool dcrMagicFunctionPack(MagicFunctionCallContext context);
struct MagicAssert final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicPack final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicRequire final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicClone final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFreeze final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFormat final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override;
};
struct MagicMatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicGmatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFind final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{
@ -160,34 +209,10 @@ TypeId makeFunction(
return arena.addType(std::move(ftv));
}
void attachMagicFunction(TypeId ty, MagicFunction fn)
void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> magic)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->magicFunction = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicFunction = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicRefinement = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicTypeCheck = fn;
ftv->magic = std::move(magic);
else
LUAU_ASSERT(!"Got a non functional type");
}
@ -293,6 +318,28 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
addGlobalBinding(globals, "string", it->second.type(), "@luau");
// Setup 'vector' metatable
if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end())
{
TypeId vectorTy = it->second.type;
ClassType* vectorCls = getMutable<ClassType>(vectorTy);
vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed});
TableType* metatableTy = Luau::getMutable<TableType>(vectorCls->metatable);
metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})};
std::initializer_list<TypeId> mulOverloads{
makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}),
makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}),
};
metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)};
}
// next<K, V>(t: Table<K, V>, i: K?) -> (K?, V)
TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}});
TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}});
@ -363,7 +410,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
}
}
attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert);
attachMagicFunction(getGlobalBinding(globals, "assert"), std::make_shared<MagicAssert>());
if (FFlag::LuauSolverV2)
{
@ -379,9 +426,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
addGlobalBinding(globals, "assert", assertTy, "@luau");
}
attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable);
attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect);
attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect);
attachMagicFunction(getGlobalBinding(globals, "setmetatable"), std::make_shared<MagicSetMetatable>());
attachMagicFunction(getGlobalBinding(globals, "select"), std::make_shared<MagicSelect>());
if (TableType* ttv = getMutable<TableType>(getGlobalBinding(globals, "table")))
{
@ -394,8 +440,10 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
// but it'll be ok for now.
TypeId genericTy = arena.addType(GenericType{"T"});
TypePackId thePack = arena.addTypePack({genericTy});
TypeId idTyWithMagic = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack});
ttv->props["freeze"] = makeProperty(idTyWithMagic, "@luau/global/table.freeze");
TypeId idTy = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack});
ttv->props["freeze"] = makeProperty(idTy, "@luau/global/table.freeze");
ttv->props["clone"] = makeProperty(idTy, "@luau/global/table.clone");
}
else
@ -410,12 +458,15 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
ttv->props["foreach"].deprecated = true;
ttv->props["foreachi"].deprecated = true;
attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack);
attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack);
attachMagicFunction(ttv->props["pack"].type(), std::make_shared<MagicPack>());
if (FFlag::LuauTableCloneClonesType3)
attachMagicFunction(ttv->props["clone"].type(), std::make_shared<MagicClone>());
attachMagicFunction(ttv->props["freeze"].type(), std::make_shared<MagicFreeze>());
}
attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire);
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire);
TypeId requireTy = getGlobalBinding(globals, "require");
attachTag(requireTy, kRequireTagName);
attachMagicFunction(requireTy, std::make_shared<MagicRequire>());
}
static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
@ -454,7 +505,7 @@ static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes,
return result;
}
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
std::optional<WithPredicate<TypePackId>> MagicFormat::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -504,7 +555,7 @@ std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
bool MagicFormat::infer(const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
@ -548,7 +599,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
return true;
}
static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext context)
bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context)
{
AstExprConstantString* fmt = nullptr;
if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
@ -563,7 +614,10 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
fmt = context.callSite->args.data[0]->as<AstExprConstantString>();
if (!fmt)
return;
{
context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location);
return true;
}
std::vector<TypeId> expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(context.arguments);
@ -579,12 +633,33 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
Location location = context.callSite->args.data[i + (calledWithSelf ? 0 : paramOffset)]->location;
// use subtyping instead here
SubtypingResult result = context.typechecker->subtyping->isSubtype(actualTy, expectedTy, context.checkScope);
if (!result.isSubtype)
{
Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result);
context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location);
if (FFlag::LuauStringFormatErrorSuppression)
{
switch (shouldSuppressErrors(NotNull{&context.typechecker->normalizer}, actualTy))
{
case ErrorSuppression::Suppress:
break;
case ErrorSuppression::NormalizationFailed:
break;
case ErrorSuppression::DoNotSuppress:
Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result);
if (!reasonings.suppressed)
context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location);
}
}
else
{
Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result);
context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location);
}
}
}
return true;
}
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
@ -647,7 +722,7 @@ static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes
return result;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
std::optional<WithPredicate<TypePackId>> MagicGmatch::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -683,7 +758,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
}
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
bool MagicGmatch::infer(const MagicFunctionCallContext& context)
{
const auto& [params, tail] = flatten(context.arguments);
@ -716,7 +791,7 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
std::optional<WithPredicate<TypePackId>> MagicMatch::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -756,7 +831,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
bool MagicMatch::infer(const MagicFunctionCallContext& context)
{
const auto& [params, tail] = flatten(context.arguments);
@ -792,7 +867,7 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
std::optional<WithPredicate<TypePackId>> MagicFind::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -850,7 +925,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionFind(MagicFunctionCallContext context)
bool MagicFind::infer(const MagicFunctionCallContext& context)
{
const auto& [params, tail] = flatten(context.arguments);
@ -927,12 +1002,9 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
formatFTV.isCheckedFunction = true;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
if (FFlag::LuauDCRMagicFunctionTypeChecker)
attachDcrMagicFunctionTypeCheck(formatFn, dcrMagicFunctionTypeCheckFormat);
attachMagicFunction(formatFn, std::make_shared<MagicFormat>());
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
@ -946,16 +1018,14 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true);
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
attachMagicFunction(gmatchFunc, std::make_shared<MagicGmatch>());
FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})
};
matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
attachMagicFunction(matchFunc, std::make_shared<MagicMatch>());
FunctionType findFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
@ -963,8 +1033,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
};
findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
attachMagicFunction(findFunc, std::make_shared<MagicFind>());
// string.byte : string -> number? -> number? -> ...number
FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList};
@ -1025,7 +1094,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
std::optional<WithPredicate<TypePackId>> MagicSelect::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1070,7 +1139,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
return std::nullopt;
}
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
bool MagicSelect::infer(const MagicFunctionCallContext& context)
{
if (context.callSite->args.size <= 0)
{
@ -1115,7 +1184,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
return false;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
std::optional<WithPredicate<TypePackId>> MagicSetMetatable::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1197,7 +1266,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
return WithPredicate<TypePackId>{arena.addTypePack({target})};
}
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
bool MagicSetMetatable::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicAssert::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1231,7 +1305,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
return WithPredicate<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})};
}
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
bool MagicAssert::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicPack::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1274,7 +1353,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
return WithPredicate<TypePackId>{arena.addTypePack({packedTable})};
}
static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
bool MagicPack::infer(const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
@ -1314,6 +1393,162 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
return true;
}
std::optional<WithPredicate<TypePackId>> MagicClone::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType3);
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
const auto& [paramTypes, paramTail] = flatten(paramPack);
if (paramTypes.empty() || expr.args.size == 0)
{
typechecker.reportError(expr.argLocation, CountMismatch{1, std::nullopt, 0});
return std::nullopt;
}
TypeId inputType = follow(paramTypes[0]);
if (!get<TableType>(inputType))
return std::nullopt;
CloneState cloneState{typechecker.builtinTypes};
TypeId resultType = shallowClone(inputType, arena, cloneState);
TypePackId clonedTypePack = arena.addTypePack({resultType});
return WithPredicate<TypePackId>{clonedTypePack};
}
bool MagicClone::infer(const MagicFunctionCallContext& context)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType3);
TypeArena* arena = context.solver->arena;
const auto& [paramTypes, paramTail] = flatten(context.arguments);
if (paramTypes.empty() || context.callSite->args.size == 0)
{
context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation);
return false;
}
TypeId inputType = follow(paramTypes[0]);
if (!get<TableType>(inputType))
return false;
CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent);
if (auto tableType = getMutable<TableType>(resultType))
{
tableType->scope = context.constraint->scope.get();
}
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(context.constraint->scope.get(), resultType);
TypePackId clonedTypePack = arena->addTypePack({resultType});
asMutable(context.result)->ty.emplace<BoundTypePack>(clonedTypePack);
return true;
}
static std::optional<TypeId> freezeTable(TypeId inputType, const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
if (FFlag::LuauFollowTableFreeze)
inputType = follow(inputType);
if (auto mt = get<MetatableType>(inputType))
{
std::optional<TypeId> frozenTable = freezeTable(mt->table, context);
if (!frozenTable)
return std::nullopt;
TypeId resultType = arena->addType(MetatableType{*frozenTable, mt->metatable, mt->syntheticName});
return resultType;
}
if (get<TableType>(inputType))
{
// Clone the input type, this will become our final result type after we mutate it.
CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent);
auto tableTy = getMutable<TableType>(resultType);
// `clone` should not break this.
LUAU_ASSERT(tableTy);
tableTy->state = TableState::Sealed;
// We'll mutate the table to make every property type read-only.
for (auto iter = tableTy->props.begin(); iter != tableTy->props.end();)
{
if (iter->second.isWriteOnly())
iter = tableTy->props.erase(iter);
else
{
iter->second.writeTy = std::nullopt;
iter++;
}
}
return resultType;
}
context.solver->reportError(TypeMismatch{context.solver->builtinTypes->tableType, inputType}, context.callSite->argLocation);
return std::nullopt;
}
std::optional<WithPredicate<TypePackId>> MagicFreeze::
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)
{
return std::nullopt;
}
bool MagicFreeze::infer(const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
const DataFlowGraph* dfg = context.solver->dfg.get();
Scope* scope = context.constraint->scope.get();
const auto& [paramTypes, paramTail] = extendTypePack(*arena, context.solver->builtinTypes, context.arguments, 1);
if (paramTypes.empty() || context.callSite->args.size == 0)
{
context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation);
return false;
}
TypeId inputType = follow(paramTypes[0]);
AstExpr* targetExpr = context.callSite->args.data[0];
std::optional<DefId> resultDef = dfg->getDefOptional(targetExpr);
std::optional<TypeId> resultTy = resultDef ? scope->lookup(*resultDef) : std::nullopt;
std::optional<TypeId> frozenType = freezeTable(inputType, context);
if (!frozenType)
{
if (resultTy)
asMutable(*resultTy)->ty.emplace<BoundType>(context.solver->builtinTypes->errorType);
asMutable(context.result)->ty.emplace<BoundTypePack>(context.solver->builtinTypes->errorTypePack);
return true;
}
if (resultTy)
asMutable(*resultTy)->ty.emplace<BoundType>(*frozenType);
asMutable(context.result)->ty.emplace<BoundTypePack>(arena->addTypePack({*frozenType}));
return true;
}
static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
{
// require(foo.parent.bar) will technically work, but it depends on legacy goop that
@ -1336,7 +1571,7 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
return good;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
std::optional<WithPredicate<TypePackId>> MagicRequire::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
@ -1382,7 +1617,7 @@ static bool checkRequirePathDcr(NotNull<ConstraintSolver> solver, AstExpr* expr)
return good;
}
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context)
bool MagicRequire::infer(const MagicFunctionCallContext& context)
{
if (context.callSite->args.size != 1)
{
@ -1405,4 +1640,52 @@ static bool dcrMagicFunctionRequire(MagicFunctionCallContext context)
return false;
}
bool matchSetMetatable(const AstExprCall& call)
{
const char* smt = "setmetatable";
if (call.args.size != 2)
return false;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != smt)
return false;
return true;
}
bool matchTableFreeze(const AstExprCall& call)
{
if (call.args.size < 1)
return false;
const AstExprIndexName* index = call.func->as<AstExprIndexName>();
if (!index || index->index != "freeze")
return false;
const AstExprGlobal* global = index->expr->as<AstExprGlobal>();
if (!global || global->name != "table")
return false;
return true;
}
bool matchAssert(const AstExprCall& call)
{
if (call.args.size < 1)
return false;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != "assert")
return false;
return true;
}
bool shouldTypestateForFirstArgument(const AstExprCall& call)
{
// TODO: magic function for setmetatable and assert and then add them
return matchTableFreeze(call);
}
} // namespace Luau

View file

@ -7,6 +7,7 @@
#include "Luau/Unifiable.h"
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauFreezeIgnorePersistent)
// For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit.
LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000)
@ -38,14 +39,26 @@ class TypeCloner
NotNull<SeenTypes> types;
NotNull<SeenTypePacks> packs;
TypeId forceTy = nullptr;
TypePackId forceTp = nullptr;
int steps = 0;
public:
TypeCloner(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<SeenTypes> types, NotNull<SeenTypePacks> packs)
TypeCloner(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<SeenTypes> types,
NotNull<SeenTypePacks> packs,
TypeId forceTy,
TypePackId forceTp
)
: arena(arena)
, builtinTypes(builtinTypes)
, types(types)
, packs(packs)
, forceTy(forceTy)
, forceTp(forceTp)
{
}
@ -112,7 +125,7 @@ private:
ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto it = types->find(ty); it != types->end())
return it->second;
else if (ty->persistent)
else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy))
return ty;
return std::nullopt;
}
@ -122,7 +135,7 @@ private:
tp = follow(tp);
if (auto it = packs->find(tp); it != packs->end())
return it->second;
else if (tp->persistent)
else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp))
return tp;
return std::nullopt;
}
@ -140,7 +153,7 @@ private:
}
}
private:
public:
TypeId shallowClone(TypeId ty)
{
// We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s.
@ -148,7 +161,7 @@ private:
if (auto clone = find(ty))
return *clone;
else if (ty->persistent)
else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy))
return ty;
TypeId target = arena->addType(ty->ty);
@ -174,7 +187,7 @@ private:
if (auto clone = find(tp))
return *clone;
else if (tp->persistent)
else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp))
return tp;
TypePackId target = arena->addTypePack(tp->ty);
@ -189,6 +202,7 @@ private:
return target;
}
private:
Property shallowClone(const Property& p)
{
if (FFlag::LuauSolverV2)
@ -256,8 +270,7 @@ private:
LUAU_ASSERT(!"Item holds neither TypeId nor TypePackId when enqueuing its children?");
}
// ErrorType and ErrorTypePack is an alias to this type.
void cloneChildren(Unifiable::Error* t)
void cloneChildren(ErrorType* t)
{
// noop.
}
@ -359,6 +372,11 @@ private:
// noop.
}
void cloneChildren(NoRefineType* t)
{
// noop.
}
void cloneChildren(UnionType* t)
{
for (TypeId& ty : t->options)
@ -422,6 +440,11 @@ private:
t->boundTo = shallowClone(t->boundTo);
}
void cloneChildren(ErrorTypePack* t)
{
// noop.
}
void cloneChildren(VariadicTypePack* t)
{
t->ty = shallowClone(t->ty);
@ -448,12 +471,46 @@ private:
} // namespace
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent)
{
if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent))
return tp;
TypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
nullptr,
FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? tp : nullptr
};
return cloner.shallowClone(tp);
}
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent)
{
if (typeId->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent))
return typeId;
TypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? typeId : nullptr,
nullptr
};
return cloner.shallowClone(typeId);
}
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
{
if (tp->persistent)
return tp;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
return cloner.clone(tp);
}
@ -462,13 +519,13 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
if (typeId->persistent)
return typeId;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
return cloner.clone(typeId);
}
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
TypeFun copy = typeFun;
@ -493,4 +550,18 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
return copy;
}
Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState)
{
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
Binding b;
b.deprecated = binding.deprecated;
b.deprecatedSuggestion = binding.deprecatedSuggestion;
b.documentationSymbol = binding.documentationSymbol;
b.location = binding.location;
b.typeId = cloner.clone(binding.typeId);
return b;
}
} // namespace Luau

View file

@ -3,6 +3,8 @@
#include "Luau/Constraint.h"
#include "Luau/VisitType.h"
LUAU_FASTFLAG(DebugLuauGreedyGeneralization)
namespace Luau
{
@ -46,6 +48,20 @@ struct ReferenceCountInitializer : TypeOnceVisitor
// ClassTypes never contain free types.
return false;
}
bool visit(TypeId, const TypeFunctionInstanceType&) override
{
// We do not consider reference counted types that are inside a type
// function to be part of the reachable reference counted types.
// Otherwise, code can be constructed in just the right way such
// that two type functions both claim to mutate a free type, which
// prevents either type function from trying to generalize it, so
// we potentially get stuck.
//
// The default behavior here is `true` for "visit the child types"
// of this type, hence:
return false;
}
};
bool isReferenceCountedType(const TypeId typ)
@ -97,6 +113,11 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
{
rci.traverse(fchc->argsPack);
}
else if (auto fcc = get<FunctionCallConstraint>(*this); fcc && FFlag::DebugLuauGreedyGeneralization)
{
rci.traverse(fcc->fn);
rci.traverse(fcc->argsPack);
}
else if (auto ptc = get<PrimitiveTypeConstraint>(*this))
{
rci.traverse(ptc->freeType);
@ -104,7 +125,8 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
else if (auto hpc = get<HasPropConstraint>(*this))
{
rci.traverse(hpc->resultType);
// `HasPropConstraints` should not mutate `subjectType`.
if (FFlag::DebugLuauGreedyGeneralization)
rci.traverse(hpc->subjectType);
}
else if (auto hic = get<HasIndexerConstraint>(*this))
{
@ -132,6 +154,10 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
{
rci.traverse(rpc->tp);
}
else if (auto tcc = get<TableCheckConstraint>(*this))
{
rci.traverse(tcc->exprType);
}
return types;
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -718,7 +718,7 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig
env.popVisiting();
return diffRes;
}
if (auto le = get<Luau::Unifiable::Error>(left))
if (auto le = get<ErrorType>(left))
{
// TODO: return debug-friendly result state
env.popVisiting();

View file

@ -1,102 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAGVARIABLE(LuauDebugInfoDefn)
namespace Luau
{
static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC(
declare bit32: {
band: @checked (...number) -> number,
bor: @checked (...number) -> number,
bxor: @checked (...number) -> number,
btest: @checked (number, ...number) -> boolean,
rrotate: @checked (x: number, disp: number) -> number,
lrotate: @checked (x: number, disp: number) -> number,
lshift: @checked (x: number, disp: number) -> number,
arshift: @checked (x: number, disp: number) -> number,
rshift: @checked (x: number, disp: number) -> number,
bnot: @checked (x: number) -> number,
extract: @checked (n: number, field: number, width: number?) -> number,
replace: @checked (n: number, v: number, field: number, width: number?) -> number,
countlz: @checked (n: number) -> number,
countrz: @checked (n: number) -> number,
byteswap: @checked (n: number) -> number,
}
declare math: {
frexp: @checked (n: number) -> (number, number),
ldexp: @checked (s: number, e: number) -> number,
fmod: @checked (x: number, y: number) -> number,
modf: @checked (n: number) -> (number, number),
pow: @checked (x: number, y: number) -> number,
exp: @checked (n: number) -> number,
ceil: @checked (n: number) -> number,
floor: @checked (n: number) -> number,
abs: @checked (n: number) -> number,
sqrt: @checked (n: number) -> number,
log: @checked (n: number, base: number?) -> number,
log10: @checked (n: number) -> number,
rad: @checked (n: number) -> number,
deg: @checked (n: number) -> number,
sin: @checked (n: number) -> number,
cos: @checked (n: number) -> number,
tan: @checked (n: number) -> number,
sinh: @checked (n: number) -> number,
cosh: @checked (n: number) -> number,
tanh: @checked (n: number) -> number,
atan: @checked (n: number) -> number,
acos: @checked (n: number) -> number,
asin: @checked (n: number) -> number,
atan2: @checked (y: number, x: number) -> number,
min: @checked (number, ...number) -> number,
max: @checked (number, ...number) -> number,
pi: number,
huge: number,
randomseed: @checked (seed: number) -> (),
random: @checked (number?, number?) -> number,
sign: @checked (n: number) -> number,
clamp: @checked (n: number, min: number, max: number) -> number,
noise: @checked (x: number, y: number?, z: number?) -> number,
round: @checked (n: number) -> number,
}
type DateTypeArg = {
year: number,
month: number,
day: number,
hour: number?,
min: number?,
sec: number?,
isdst: boolean?,
}
type DateTypeResult = {
year: number,
month: number,
wday: number,
yday: number,
day: number,
hour: number,
min: number,
sec: number,
isdst: boolean,
}
declare os: {
time: (time: DateTypeArg?) -> number,
date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string),
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
static const std::string kBuiltinDefinitionBaseSrc = R"BUILTIN_SRC(
@checked declare function require(target: any): any
@ -144,6 +54,119 @@ declare function loadstring<A...>(src: string, chunkname: string?): (((A...) ->
@checked declare function newproxy(mt: boolean?): any
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBit32Src = R"BUILTIN_SRC(
declare bit32: {
band: @checked (...number) -> number,
bor: @checked (...number) -> number,
bxor: @checked (...number) -> number,
btest: @checked (number, ...number) -> boolean,
rrotate: @checked (x: number, disp: number) -> number,
lrotate: @checked (x: number, disp: number) -> number,
lshift: @checked (x: number, disp: number) -> number,
arshift: @checked (x: number, disp: number) -> number,
rshift: @checked (x: number, disp: number) -> number,
bnot: @checked (x: number) -> number,
extract: @checked (n: number, field: number, width: number?) -> number,
replace: @checked (n: number, v: number, field: number, width: number?) -> number,
countlz: @checked (n: number) -> number,
countrz: @checked (n: number) -> number,
byteswap: @checked (n: number) -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionMathSrc = R"BUILTIN_SRC(
declare math: {
frexp: @checked (n: number) -> (number, number),
ldexp: @checked (s: number, e: number) -> number,
fmod: @checked (x: number, y: number) -> number,
modf: @checked (n: number) -> (number, number),
pow: @checked (x: number, y: number) -> number,
exp: @checked (n: number) -> number,
ceil: @checked (n: number) -> number,
floor: @checked (n: number) -> number,
abs: @checked (n: number) -> number,
sqrt: @checked (n: number) -> number,
log: @checked (n: number, base: number?) -> number,
log10: @checked (n: number) -> number,
rad: @checked (n: number) -> number,
deg: @checked (n: number) -> number,
sin: @checked (n: number) -> number,
cos: @checked (n: number) -> number,
tan: @checked (n: number) -> number,
sinh: @checked (n: number) -> number,
cosh: @checked (n: number) -> number,
tanh: @checked (n: number) -> number,
atan: @checked (n: number) -> number,
acos: @checked (n: number) -> number,
asin: @checked (n: number) -> number,
atan2: @checked (y: number, x: number) -> number,
min: @checked (number, ...number) -> number,
max: @checked (number, ...number) -> number,
pi: number,
huge: number,
randomseed: @checked (seed: number) -> (),
random: @checked (number?, number?) -> number,
sign: @checked (n: number) -> number,
clamp: @checked (n: number, min: number, max: number) -> number,
noise: @checked (x: number, y: number?, z: number?) -> number,
round: @checked (n: number) -> number,
map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number,
lerp: @checked (a: number, b: number, t: number) -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionOsSrc = R"BUILTIN_SRC(
type DateTypeArg = {
year: number,
month: number,
day: number,
hour: number?,
min: number?,
sec: number?,
isdst: boolean?,
}
type DateTypeResult = {
year: number,
month: number,
wday: number,
yday: number,
day: number,
hour: number,
min: number,
sec: number,
isdst: boolean,
}
declare os: {
time: (time: DateTypeArg?) -> number,
date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string),
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionCoroutineSrc = R"BUILTIN_SRC(
declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
@ -155,6 +178,10 @@ declare coroutine: {
close: @checked (co: thread) -> (boolean, any)
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionTableSrc = R"BUILTIN_SRC(
declare table: {
concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
@ -177,11 +204,28 @@ declare table: {
isfrozen: <K, V>(t: {[K]: V}) -> boolean,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionDebugSrc = R"BUILTIN_SRC(
declare debug: {
info: ((thread: thread, level: number, options: string) -> ...any) & ((level: number, options: string) -> ...any) & (<A..., R1...>(func: (A...) -> R1..., options: string) -> ...any),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionDebugSrc_DEPRECATED = R"BUILTIN_SRC(
declare debug: {
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionUtf8Src = R"BUILTIN_SRC(
declare utf8: {
char: @checked (...number) -> string,
charpattern: string,
@ -191,10 +235,9 @@ declare utf8: {
offset: @checked (s: string, n: number?, i: number?) -> number,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC(
--- Buffer API
declare buffer: {
create: @checked (size: number) -> buffer,
@ -221,13 +264,56 @@ declare buffer: {
writef64: @checked (b: buffer, offset: number, value: number) -> (),
readstring: @checked (b: buffer, offset: number, count: number) -> string,
writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (),
readbits: @checked (b: buffer, bitOffset: number, bitCount: number) -> number,
writebits: @checked (b: buffer, bitOffset: number, bitCount: number, value: number) -> (),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC(
-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties
declare class vector
x: number
y: number
z: number
end
declare vector: {
create: @checked (x: number, y: number, z: number?) -> vector,
magnitude: @checked (vec: vector) -> number,
normalize: @checked (vec: vector) -> vector,
cross: @checked (vec1: vector, vec2: vector) -> vector,
dot: @checked (vec1: vector, vec2: vector) -> number,
angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number,
floor: @checked (vec: vector) -> vector,
ceil: @checked (vec: vector) -> vector,
abs: @checked (vec: vector) -> vector,
sign: @checked (vec: vector) -> vector,
clamp: @checked (vec: vector, min: vector, max: vector) -> vector,
max: @checked (vector, ...vector) -> vector,
min: @checked (vector, ...vector) -> vector,
zero: vector,
one: vector,
}
)BUILTIN_SRC";
std::string getBuiltinDefinitionSource()
{
std::string result = kBuiltinDefinitionLuaSrcChecked;
std::string result = kBuiltinDefinitionBaseSrc;
result += kBuiltinDefinitionBit32Src;
result += kBuiltinDefinitionMathSrc;
result += kBuiltinDefinitionOsSrc;
result += kBuiltinDefinitionCoroutineSrc;
result += kBuiltinDefinitionTableSrc;
result += FFlag::LuauDebugInfoDefn ? kBuiltinDefinitionDebugSrc : kBuiltinDefinitionDebugSrc_DEPRECATED;
result += kBuiltinDefinitionUtf8Src;
result += kBuiltinDefinitionBufferSrc;
result += kBuiltinDefinitionVectorSrc;
return result;
}

File diff suppressed because it is too large Load diff

View file

@ -18,8 +18,6 @@
LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauImproveNonFunctionCallError, false)
static std::string wrongNumberOfArgsString(
size_t expectedCount,
std::optional<size_t> maximumCount,
@ -408,35 +406,30 @@ struct ErrorConverter
std::string operator()(const Luau::CannotCallNonFunction& e) const
{
if (DFFlag::LuauImproveNonFunctionCallError)
if (auto unionTy = get<UnionType>(follow(e.ty)))
{
if (auto unionTy = get<UnionType>(follow(e.ty)))
std::string err = "Cannot call a value of the union type:";
for (auto option : unionTy)
{
std::string err = "Cannot call a value of the union type:";
option = follow(option);
for (auto option : unionTy)
if (get<FunctionType>(option) || findCallMetamethod(option))
{
option = follow(option);
if (get<FunctionType>(option) || findCallMetamethod(option))
{
err += "\n | " + toString(option);
continue;
}
// early-exit if we find something that isn't callable in the union.
return "Cannot call a value of type " + toString(option) + " in union:\n " + toString(e.ty);
err += "\n | " + toString(option);
continue;
}
err += "\nWe are unable to determine the appropriate result type for such a call.";
return err;
// early-exit if we find something that isn't callable in the union.
return "Cannot call a value of type " + toString(option) + " in union:\n " + toString(e.ty);
}
return "Cannot call a value of type " + toString(e.ty);
err += "\nWe are unable to determine the appropriate result type for such a call.";
return err;
}
return "Cannot call non-function " + toString(e.ty);
return "Cannot call a value of type " + toString(e.ty);
}
std::string operator()(const Luau::ExtraInformation& e) const
{
@ -793,6 +786,11 @@ struct ErrorConverter
return "Encountered an unexpected type pack in subtyping: " + toString(e.tp);
}
std::string operator()(const UserDefinedTypeFunctionError& e) const
{
return e.message;
}
std::string operator()(const CannotAssignToNever& e) const
{
std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never";
@ -1175,6 +1173,11 @@ bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtypi
return tp == rhs.tp;
}
bool UserDefinedTypeFunctionError::operator==(const UserDefinedTypeFunctionError& rhs) const
{
return message == rhs.message;
}
bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const
{
if (cause.size() != rhs.cause.size())
@ -1384,6 +1387,9 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
e.ty = clone(e.ty);
else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>)
e.tp = clone(e.tp);
else if constexpr (std::is_same_v<T, UserDefinedTypeFunctionError>)
{
}
else if constexpr (std::is_same_v<T, CannotAssignToNever>)
{
e.rhsType = clone(e.rhsType);

View file

@ -0,0 +1,708 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/FragmentAutocomplete.h"
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Autocomplete.h"
#include "Luau/Common.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
#include "Luau/TimeTrace.h"
#include "Luau/UnifierSharedState.h"
#include "Luau/TypeFunction.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/ConstraintGenerator.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/Frontend.h"
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
#include "Luau/Clone.h"
#include "AutocompleteCore.h"
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteBugfixes)
LUAU_FASTFLAGVARIABLE(LuauMixedModeDefFinderTraversesTypeOf)
LUAU_FASTFLAG(LuauBetterReverseDependencyTracking)
LUAU_FASTFLAGVARIABLE(LuauCloneIncrementalModule)
LUAU_FASTFLAGVARIABLE(LogFragmentsFromAutocomplete)
namespace
{
template<typename T>
void copyModuleVec(std::vector<T>& result, const std::vector<T>& input)
{
result.insert(result.end(), input.begin(), input.end());
}
template<typename K, typename V>
void copyModuleMap(Luau::DenseHashMap<K, V>& result, const Luau::DenseHashMap<K, V>& input)
{
for (auto [k, v] : input)
result[k] = v;
}
} // namespace
namespace Luau
{
template<typename K, typename V>
void cloneModuleMap(TypeArena& destArena, CloneState& cloneState, const Luau::DenseHashMap<K, V>& source, Luau::DenseHashMap<K, V>& dest)
{
for (auto [k, v] : source)
{
dest[k] = Luau::clone(v, destArena, cloneState);
}
}
struct MixedModeIncrementalTCDefFinder : public AstVisitor
{
bool visit(AstExprLocal* local) override
{
referencedLocalDefs.emplace_back(local->local, local);
return true;
}
bool visit(AstTypeTypeof* node) override
{
// We need to traverse typeof expressions because they may refer to locals that we need
// to populate the local environment for fragment typechecking. For example, `typeof(m)`
// requires that we find the local/global `m` and place it in the environment.
// The default behaviour here is to return false, and have individual visitors override
// the specific behaviour they need.
return FFlag::LuauMixedModeDefFinderTraversesTypeOf;
}
// ast defs is just a mapping from expr -> def in general
// will get built up by the dfg builder
// localDefs, we need to copy over
std::vector<std::pair<AstLocal*, AstExpr*>> referencedLocalDefs;
};
void cloneAndSquashScopes(
CloneState& cloneState,
const Scope* staleScope,
const ModulePtr& staleModule,
NotNull<TypeArena> destArena,
NotNull<DataFlowGraph> dfg,
AstStatBlock* program,
Scope* destScope
)
{
LUAU_TIMETRACE_SCOPE("Luau::cloneAndSquashScopes", "FragmentAutocomplete");
std::vector<const Scope*> scopes;
for (const Scope* current = staleScope; current; current = current->parent.get())
{
scopes.emplace_back(current);
}
// in reverse order (we need to clone the parents and override defs as we go down the list)
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
const Scope* curr = *it;
// Clone the lvalue types
for (const auto& [def, ty] : curr->lvalueTypes)
destScope->lvalueTypes[def] = Luau::clone(ty, *destArena, cloneState);
// Clone the rvalueRefinements
for (const auto& [def, ty] : curr->rvalueRefinements)
destScope->rvalueRefinements[def] = Luau::clone(ty, *destArena, cloneState);
for (const auto& [n, m] : curr->importedTypeBindings)
{
std::unordered_map<Name, TypeFun> importedBindingTypes;
for (const auto& [v, tf] : m)
importedBindingTypes[v] = Luau::clone(tf, *destArena, cloneState);
destScope->importedTypeBindings[n] = m;
}
// Finally, clone up the bindings
for (const auto& [s, b] : curr->bindings)
{
destScope->bindings[s] = Luau::clone(b, *destArena, cloneState);
}
}
// The above code associates defs with TypeId's in the scope
// so that lookup to locals will succeed.
MixedModeIncrementalTCDefFinder finder;
program->visit(&finder);
std::vector<std::pair<AstLocal*, AstExpr*>> locals = std::move(finder.referencedLocalDefs);
for (auto [loc, expr] : locals)
{
if (std::optional<Binding> binding = staleScope->linearSearchForBinding(loc->name.value, true))
{
destScope->lvalueTypes[dfg->getDef(expr)] = Luau::clone(binding->typeId, *destArena, cloneState);
}
}
return;
}
static FrontendModuleResolver& getModuleResolver(Frontend& frontend, std::optional<FrontendOptions> options)
{
if (FFlag::LuauSolverV2 || !options)
return frontend.moduleResolver;
return options->forAutocomplete ? frontend.moduleResolverForAutocomplete : frontend.moduleResolver;
}
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos)
{
std::vector<AstNode*> ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos);
// Should always contain the root AstStat
LUAU_ASSERT(ancestry.size() >= 1);
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
AstStat* nearestStatement = nullptr;
for (AstNode* node : ancestry)
{
if (auto block = node->as<AstStatBlock>())
{
for (auto stat : block->body)
{
if (stat->location.begin <= cursorPos)
nearestStatement = stat;
if (stat->location.begin < cursorPos && stat->location.begin.line < cursorPos.line)
{
// This statement precedes the current one
if (auto loc = stat->as<AstStatLocal>())
{
for (auto v : loc->vars)
{
localStack.push_back(v);
localMap[v->name] = v;
}
}
else if (auto locFun = stat->as<AstStatLocalFunction>())
{
localStack.push_back(locFun->name);
localMap[locFun->name->name] = locFun->name;
if (locFun->location.contains(cursorPos))
{
for (AstLocal* loc : locFun->func->args)
{
localStack.push_back(loc);
localMap[loc->name] = loc;
}
}
}
else if (auto globFun = stat->as<AstStatFunction>())
{
if (globFun->location.contains(cursorPos))
{
for (AstLocal* loc : globFun->func->args)
{
localStack.push_back(loc);
localMap[loc->name] = loc;
}
}
}
}
}
}
}
if (!nearestStatement)
nearestStatement = ancestry[0]->asStat();
LUAU_ASSERT(nearestStatement);
return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)};
}
/**
* Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that
* document and attempts to get the concrete text between those points. It returns a pair of:
* - start offset that represents an index in the source `char*` corresponding to startPos
* - length, that represents how many more bytes to read to get to endPos.
* Example - your document is "foo bar baz" and getDocumentOffsets is passed (0, 4), (0, 8). This function returns the pair {3, 5}
* which corresponds to the string " bar "
*/
std::pair<size_t, size_t> getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos)
{
size_t lineCount = 0;
size_t colCount = 0;
size_t docOffset = 0;
size_t startOffset = 0;
size_t endOffset = 0;
bool foundStart = false;
bool foundEnd = false;
for (char c : src)
{
if (foundStart && foundEnd)
break;
if (startPos.line == lineCount && startPos.column == colCount)
{
foundStart = true;
startOffset = docOffset;
}
if (endPos.line == lineCount && endPos.column == colCount)
{
endOffset = docOffset;
while (endOffset < src.size() && src[endOffset] != '\n')
endOffset++;
foundEnd = true;
}
// We put a cursor position that extends beyond the extents of the current line
if (foundStart && !foundEnd && (lineCount > endPos.line))
{
foundEnd = true;
endOffset = docOffset - 1;
}
if (c == '\n')
{
lineCount++;
colCount = 0;
}
else
{
colCount++;
}
docOffset++;
}
if (foundStart && !foundEnd)
endOffset = src.length();
size_t min = std::min(startOffset, endOffset);
size_t len = std::max(startOffset, endOffset) - min;
return {min, len};
}
ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement)
{
LUAU_ASSERT(module->hasModuleScope());
ScopePtr closest = module->getModuleScope();
// find the scope the nearest statement belonged to.
for (auto [loc, sc] : module->scopes)
{
if (loc.encloses(nearestStatement->location) && closest->location.begin <= loc.begin)
closest = sc;
}
return closest;
}
std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,
std::optional<Position> fragmentEndPosition
)
{
FragmentAutocompleteAncestryResult result = findAncestryForFragmentParse(srcModule.root, cursorPos);
AstStat* nearestStatement = result.nearestStatement;
const Location& rootSpan = srcModule.root->location;
// Did we append vs did we insert inline
bool appended = cursorPos >= rootSpan.end;
// statement spans multiple lines
bool multiline = nearestStatement->location.begin.line != nearestStatement->location.end.line;
const Position endPos = fragmentEndPosition.value_or(cursorPos);
// We start by re-parsing everything (we'll refine this as we go)
Position startPos = srcModule.root->location.begin;
// If we added to the end of the sourceModule, use the end of the nearest location
if (appended && multiline)
startPos = nearestStatement->location.end;
// Statement spans one line && cursorPos is either on the same line or after
else if (!multiline && cursorPos.line >= nearestStatement->location.end.line)
startPos = nearestStatement->location.begin;
else if (multiline && nearestStatement->location.end.line < cursorPos.line)
startPos = nearestStatement->location.end;
else
startPos = nearestStatement->location.begin;
auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos);
const char* srcStart = src.data() + offsetStart;
std::string_view dbg = src.substr(offsetStart, parseLength);
const std::shared_ptr<AstNameTable>& nameTbl = srcModule.names;
FragmentParseResult fragmentResult;
fragmentResult.fragmentToParse = std::string(dbg.data(), parseLength);
// For the duration of the incremental parse, we want to allow the name table to re-use duplicate names
if (FFlag::LogFragmentsFromAutocomplete)
logLuau(dbg);
ParseOptions opts;
opts.allowDeclarationSyntax = false;
opts.captureComments = true;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos};
ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts);
// This means we threw a ParseError and we should decline to offer autocomplete here.
if (p.root == nullptr)
return std::nullopt;
std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
// Get the ancestry for the fragment at the offset cursor position.
// Consumers have the option to request with fragment end position, so we cannot just use the end position of our parse result as the
// cursor position. Instead, use the cursor position calculated as an offset from our start position.
std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, cursorPos);
fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end());
if (nearestStatement == nullptr)
nearestStatement = p.root;
fragmentResult.root = std::move(p.root);
fragmentResult.ancestry = std::move(fabricatedAncestry);
fragmentResult.nearestStatement = nearestStatement;
fragmentResult.commentLocations = std::move(p.commentLocations);
return fragmentResult;
}
ModulePtr cloneModule(CloneState& cloneState, const ModulePtr& source, std::unique_ptr<Allocator> alloc)
{
LUAU_TIMETRACE_SCOPE("Luau::cloneModule", "FragmentAutocomplete");
freeze(source->internalTypes);
freeze(source->interfaceTypes);
ModulePtr incremental = std::make_shared<Module>();
incremental->name = source->name;
incremental->humanReadableName = source->humanReadableName;
incremental->allocator = std::move(alloc);
// Clone types
cloneModuleMap(incremental->internalTypes, cloneState, source->astTypes, incremental->astTypes);
cloneModuleMap(incremental->internalTypes, cloneState, source->astTypePacks, incremental->astTypePacks);
cloneModuleMap(incremental->internalTypes, cloneState, source->astExpectedTypes, incremental->astExpectedTypes);
cloneModuleMap(incremental->internalTypes, cloneState, source->astOverloadResolvedTypes, incremental->astOverloadResolvedTypes);
cloneModuleMap(incremental->internalTypes, cloneState, source->astForInNextTypes, incremental->astForInNextTypes);
copyModuleMap(incremental->astScopes, source->astScopes);
return incremental;
}
ModulePtr copyModule(const ModulePtr& result, std::unique_ptr<Allocator> alloc)
{
ModulePtr incrementalModule = std::make_shared<Module>();
incrementalModule->name = result->name;
incrementalModule->humanReadableName = "Incremental$" + result->humanReadableName;
incrementalModule->internalTypes.owningModule = incrementalModule.get();
incrementalModule->interfaceTypes.owningModule = incrementalModule.get();
incrementalModule->allocator = std::move(alloc);
// Don't need to keep this alive (it's already on the source module)
copyModuleVec(incrementalModule->scopes, result->scopes);
copyModuleMap(incrementalModule->astTypes, result->astTypes);
copyModuleMap(incrementalModule->astTypePacks, result->astTypePacks);
copyModuleMap(incrementalModule->astExpectedTypes, result->astExpectedTypes);
// Don't need to clone astOriginalCallTypes
copyModuleMap(incrementalModule->astOverloadResolvedTypes, result->astOverloadResolvedTypes);
// Don't need to clone astForInNextTypes
copyModuleMap(incrementalModule->astForInNextTypes, result->astForInNextTypes);
// Don't need to clone astResolvedTypes
// Don't need to clone astResolvedTypePacks
// Don't need to clone upperBoundContributors
copyModuleMap(incrementalModule->astScopes, result->astScopes);
// Don't need to clone declared Globals;
return incrementalModule;
}
void mixedModeCompatibility(
const ScopePtr& bottomScopeStale,
const ScopePtr& myFakeScope,
const ModulePtr& stale,
NotNull<DataFlowGraph> dfg,
AstStatBlock* program
)
{
// This code does the following
// traverse program
// look for ast refs for locals
// ask for the corresponding defId from dfg
// given that defId, and that expression, in the incremental module, map lvalue types from defID to
MixedModeIncrementalTCDefFinder finder;
program->visit(&finder);
std::vector<std::pair<AstLocal*, AstExpr*>> locals = std::move(finder.referencedLocalDefs);
for (auto [loc, expr] : locals)
{
if (std::optional<Binding> binding = bottomScopeStale->linearSearchForBinding(loc->name.value, true))
{
myFakeScope->lvalueTypes[dfg->getDef(expr)] = binding->typeId;
}
}
}
FragmentTypeCheckResult typecheckFragment_(
Frontend& frontend,
AstStatBlock* root,
const ModulePtr& stale,
const ScopePtr& closestScope,
const Position& cursorPos,
std::unique_ptr<Allocator> astAllocator,
const FrontendOptions& opts
)
{
LUAU_TIMETRACE_SCOPE("Luau::typecheckFragment_", "FragmentAutocomplete");
freeze(stale->internalTypes);
freeze(stale->interfaceTypes);
CloneState cloneState{frontend.builtinTypes};
ModulePtr incrementalModule =
FFlag::LuauCloneIncrementalModule ? cloneModule(cloneState, stale, std::move(astAllocator)) : copyModule(stale, std::move(astAllocator));
incrementalModule->checkedInNewSolver = true;
unfreeze(incrementalModule->internalTypes);
unfreeze(incrementalModule->interfaceTypes);
/// Setup typecheck limits
TypeCheckLimits limits;
if (opts.moduleTimeLimitSec)
limits.finishTime = TimeTrace::getClock() + *opts.moduleTimeLimitSec;
else
limits.finishTime = std::nullopt;
limits.cancellationToken = opts.cancellationToken;
/// Icehandler
NotNull<InternalErrorReporter> iceHandler{&frontend.iceHandler};
/// Make the shared state for the unifier (recursion + iteration limits)
UnifierSharedState unifierState{iceHandler};
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
/// Initialize the normalizer
Normalizer normalizer{&incrementalModule->internalTypes, frontend.builtinTypes, NotNull{&unifierState}};
/// User defined type functions runtime
TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits});
/// Create a DataFlowGraph just for the surrounding context
DataFlowGraph dfg = DataFlowGraphBuilder::build(root, NotNull{&incrementalModule->defArena}, NotNull{&incrementalModule->keyArena}, iceHandler);
SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes);
FrontendModuleResolver& resolver = getModuleResolver(frontend, opts);
/// Contraint Generator
ConstraintGenerator cg{
incrementalModule,
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{&resolver},
frontend.builtinTypes,
iceHandler,
stale->getModuleScope(),
nullptr,
nullptr,
NotNull{&dfg},
{}
};
std::shared_ptr<Scope> freshChildOfNearestScope = nullptr;
if (FFlag::LuauCloneIncrementalModule)
{
freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
cg.rootScope = freshChildOfNearestScope.get();
cloneAndSquashScopes(
cloneState, closestScope.get(), stale, NotNull{&incrementalModule->internalTypes}, NotNull{&dfg}, root, freshChildOfNearestScope.get()
);
cg.visitFragmentRoot(freshChildOfNearestScope, root);
}
else
{
// Any additions to the scope must occur in a fresh scope
cg.rootScope = stale->getModuleScope().get();
freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root);
// closest Scope -> children = { ...., freshChildOfNearestScope}
// We need to trim nearestChild from the scope hierarcy
closestScope->children.emplace_back(freshChildOfNearestScope.get());
cg.visitFragmentRoot(freshChildOfNearestScope, root);
// Trim nearestChild from the closestScope
Scope* back = closestScope->children.back().get();
LUAU_ASSERT(back == freshChildOfNearestScope.get());
closestScope->children.pop_back();
}
/// Initialize the constraint solver and run it
ConstraintSolver cs{
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),
NotNull{&cg.scopeToFunction},
incrementalModule->name,
NotNull{&resolver},
{},
nullptr,
NotNull{&dfg},
limits
};
try
{
cs.run();
}
catch (const TimeLimitError&)
{
stale->timeout = true;
}
catch (const UserCancelError&)
{
stale->cancelled = true;
}
// In frontend we would forbid internal types
// because this is just for autocomplete, we don't actually care
// We also don't even need to typecheck - just synthesize types as best as we can
freeze(incrementalModule->internalTypes);
freeze(incrementalModule->interfaceTypes);
return {std::move(incrementalModule), std::move(freshChildOfNearestScope)};
}
std::pair<FragmentTypeCheckStatus, FragmentTypeCheckResult> typecheckFragment(
Frontend& frontend,
const ModuleName& moduleName,
const Position& cursorPos,
std::optional<FrontendOptions> opts,
std::string_view src,
std::optional<Position> fragmentEndPosition
)
{
LUAU_TIMETRACE_SCOPE("Luau::typecheckFragment", "FragmentAutocomplete");
LUAU_TIMETRACE_ARGUMENT("name", moduleName.c_str());
if (FFlag::LuauBetterReverseDependencyTracking)
{
if (!frontend.allModuleDependenciesValid(moduleName, opts && opts->forAutocomplete))
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
}
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
{
LUAU_ASSERT(!"Expected Source Module for fragment typecheck");
return {};
}
FrontendModuleResolver& resolver = getModuleResolver(frontend, opts);
ModulePtr module = resolver.getModule(moduleName);
if (!module)
{
LUAU_ASSERT(!"Expected Module for fragment typecheck");
return {};
}
if (FFlag::LuauIncrementalAutocompleteBugfixes)
{
if (sourceModule->allocator.get() != module->allocator.get())
{
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
}
}
auto tryParse = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition);
if (!tryParse)
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
FragmentParseResult& parseResult = *tryParse;
if (isWithinComment(parseResult.commentLocations, fragmentEndPosition.value_or(cursorPos)))
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
FrontendOptions frontendOptions = opts.value_or(frontend.options);
const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement);
FragmentTypeCheckResult result =
typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions);
result.ancestry = std::move(parseResult.ancestry);
return {FragmentTypeCheckStatus::Success, result};
}
FragmentAutocompleteStatusResult tryFragmentAutocomplete(
Frontend& frontend,
const ModuleName& moduleName,
Position cursorPosition,
FragmentContext context,
StringCompletionCallback stringCompletionCB
)
{
// TODO: we should calculate fragmentEnd position here, by using context.newAstRoot and cursorPosition
try
{
Luau::FragmentAutocompleteResult fragmentAutocomplete = Luau::fragmentAutocomplete(
frontend,
context.newSrc,
moduleName,
cursorPosition,
context.opts,
std::move(stringCompletionCB),
context.DEPRECATED_fragmentEndPosition
);
return {FragmentAutocompleteStatus::Success, std::move(fragmentAutocomplete)};
}
catch (const Luau::InternalCompilerError& e)
{
if (FFlag::LogFragmentsFromAutocomplete)
logLuau(e.what());
return {FragmentAutocompleteStatus::InternalIce, std::nullopt};
}
}
FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend,
std::string_view src,
const ModuleName& moduleName,
Position cursorPosition,
std::optional<FrontendOptions> opts,
StringCompletionCallback callback,
std::optional<Position> fragmentEndPosition
)
{
LUAU_ASSERT(FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
LUAU_TIMETRACE_SCOPE("Luau::fragmentAutocomplete", "FragmentAutocomplete");
LUAU_TIMETRACE_ARGUMENT("name", moduleName.c_str());
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
{
LUAU_ASSERT(!"Expected Source Module for fragment typecheck");
return {};
}
// If the cursor is within a comment in the stale source module we should avoid providing a recommendation
if (isWithinComment(*sourceModule, fragmentEndPosition.value_or(cursorPosition)))
return {};
auto [tcStatus, tcResult] = typecheckFragment(frontend, moduleName, cursorPosition, opts, src, fragmentEndPosition);
if (tcStatus == FragmentTypeCheckStatus::SkipAutocomplete)
return {};
auto globalScope = (opts && opts->forAutocomplete) ? frontend.globalsForAutocomplete.globalScope.get() : frontend.globals.globalScope.get();
if (FFlag::LogFragmentsFromAutocomplete)
logLuau(src);
TypeArena arenaForFragmentAutocomplete;
auto result = Luau::autocomplete_(
tcResult.incrementalModule,
frontend.builtinTypes,
&arenaForFragmentAutocomplete,
tcResult.ancestry,
globalScope,
tcResult.freshScope,
cursorPosition,
frontend.fileResolver,
callback
);
return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)};
}
} // namespace Luau

View file

@ -10,8 +10,10 @@
#include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DcrLogger.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/FileResolver.h"
#include "Luau/NonStrictTypeChecker.h"
#include "Luau/NotNull.h"
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h"
@ -36,18 +38,21 @@ LUAU_FASTINT(LuauTypeInferIterationLimit)
LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauStoreCommentsForDefinitionFiles, false)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile, false)
LUAU_FASTFLAGVARIABLE(DebugLuauForbidInternalTypes, false)
LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode, false)
LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode, false)
LUAU_FASTFLAGVARIABLE(LuauSourceModuleUpdatedWithSelectedMode, false)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile)
LUAU_FASTFLAGVARIABLE(DebugLuauForbidInternalTypes)
LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode)
LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAGVARIABLE(LuauBetterReverseDependencyTracking)
LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule)
LUAU_FASTFLAGVARIABLE(LuauSelectivelyRetainDFGArena)
namespace Luau
{
@ -134,7 +139,7 @@ static ParseResult parseSourceForModule(std::string_view source, Luau::SourceMod
sourceModule.root = parseResult.root;
sourceModule.mode = Mode::Definition;
if (FFlag::LuauStoreCommentsForDefinitionFiles && options.captureComments)
if (options.captureComments)
{
sourceModule.hotcomments = parseResult.hotcomments;
sourceModule.commentLocations = parseResult.commentLocations;
@ -205,72 +210,6 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(
return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule};
}
std::vector<std::string_view> parsePathExpr(const AstExpr& pathExpr)
{
const AstExprIndexName* indexName = pathExpr.as<AstExprIndexName>();
if (!indexName)
return {};
std::vector<std::string_view> segments{indexName->index.value};
while (true)
{
if (AstExprIndexName* in = indexName->expr->as<AstExprIndexName>())
{
segments.push_back(in->index.value);
indexName = in;
continue;
}
else if (AstExprGlobal* indexNameAsGlobal = indexName->expr->as<AstExprGlobal>())
{
segments.push_back(indexNameAsGlobal->name.value);
break;
}
else if (AstExprLocal* indexNameAsLocal = indexName->expr->as<AstExprLocal>())
{
segments.push_back(indexNameAsLocal->local->name.value);
break;
}
else
return {};
}
std::reverse(segments.begin(), segments.end());
return segments;
}
std::optional<std::string> pathExprToModuleName(const ModuleName& currentModuleName, const std::vector<std::string_view>& segments)
{
if (segments.empty())
return std::nullopt;
std::vector<std::string_view> result;
auto it = segments.begin();
if (*it == "script" && !currentModuleName.empty())
{
result = split(currentModuleName, '/');
++it;
}
for (; it != segments.end(); ++it)
{
if (result.size() > 1 && *it == "Parent")
result.pop_back();
else
result.push_back(*it);
}
return join(result, "/");
}
std::optional<std::string> pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr)
{
std::vector<std::string_view> segments = parsePathExpr(pathExpr);
return pathExprToModuleName(currentModuleName, segments);
}
namespace
{
@ -351,8 +290,7 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector<HotCom
std::vector<RequireCycle> getRequireCycles(
const FileResolver* resolver,
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes,
const SourceNode* start,
bool stopAtFirst = false
const SourceNode* start
)
{
std::vector<RequireCycle> result;
@ -422,9 +360,6 @@ std::vector<RequireCycle> getRequireCycles(
{
result.push_back({depLocation, std::move(cycle)});
if (stopAtFirst)
return result;
// note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start
// so it's safe to *only* clear seen vector when we find a cycle
// if we don't do it, we will not have correct reporting for some cycles
@ -812,6 +747,32 @@ std::optional<CheckResult> Frontend::getCheckResult(const ModuleName& name, bool
return checkResult;
}
std::vector<ModuleName> Frontend::getRequiredScripts(const ModuleName& name)
{
RequireTraceResult require = requireTrace[name];
if (isDirty(name))
{
std::optional<SourceCode> source = fileResolver->readSource(name);
if (!source)
{
return {};
}
const Config& config = configResolver->getConfig(name);
ParseOptions opts = config.parseOptions;
opts.captureComments = true;
SourceModule result = parse(name, source->source, opts);
result.type = source->type;
require = traceRequires(fileResolver, result.root, name);
}
std::vector<std::string> requiredModuleNames;
requiredModuleNames.reserve(require.requireList.size());
for (const auto& [moduleName, _] : require.requireList)
{
requiredModuleNames.push_back(moduleName);
}
return requiredModuleNames;
}
bool Frontend::parseGraph(
std::vector<ModuleName>& buildQueue,
const ModuleName& root,
@ -860,6 +821,16 @@ bool Frontend::parseGraph(
topseen = Permanent;
buildQueue.push_back(top->name);
if (FFlag::LuauBetterReverseDependencyTracking)
{
// at this point we know all valid dependencies are processed into SourceNodes
for (const ModuleName& dep : top->requireSet)
{
if (auto it = sourceNodes.find(dep); it != sourceNodes.end())
it->second->dependents.insert(top->name);
}
}
}
else
{
@ -948,14 +919,11 @@ void Frontend::addBuildQueueItems(
data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete);
data.recordJsonLog = FFlag::DebugLuauLogSolverToJson;
Mode mode = sourceModule->mode.value_or(data.config.mode);
// in NoCheck mode we only need to compute the value of .cyclic for typeck
// in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely
if (cycleDetected)
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck);
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get());
data.options = frontendOptions;
@ -987,8 +955,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
else
mode = sourceModule.mode.value_or(config.mode);
if (FFlag::LuauSourceModuleUpdatedWithSelectedMode)
item.sourceModule->mode = {mode};
item.sourceModule->mode = {mode};
ScopePtr environmentScope = item.environmentScope;
double timestamp = getTimestamp();
const std::vector<RequireCycle>& requireCycles = item.requireCycles;
@ -1093,6 +1060,11 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
freeze(module->interfaceTypes);
module->internalTypes.clear();
if (FFlag::LuauSelectivelyRetainDFGArena)
{
module->defArena.allocator.clear();
module->keyArena.allocator.clear();
}
module->astTypes.clear();
module->astTypePacks.clear();
@ -1146,15 +1118,49 @@ void Frontend::recordItemResult(const BuildQueueItem& item)
if (item.exception)
std::rethrow_exception(item.exception);
if (item.options.forAutocomplete)
if (FFlag::LuauBetterReverseDependencyTracking)
{
moduleResolverForAutocomplete.setModule(item.name, item.module);
item.sourceNode->dirtyModuleForAutocomplete = false;
bool replacedModule = false;
if (item.options.forAutocomplete)
{
replacedModule = moduleResolverForAutocomplete.setModule(item.name, item.module);
item.sourceNode->dirtyModuleForAutocomplete = false;
}
else
{
replacedModule = moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false;
}
if (replacedModule)
{
LUAU_TIMETRACE_SCOPE("Frontend::invalidateDependentModules", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", item.name.c_str());
traverseDependents(
item.name,
[forAutocomplete = item.options.forAutocomplete](SourceNode& sourceNode)
{
bool traverseSubtree = !sourceNode.hasInvalidModuleDependency(forAutocomplete);
sourceNode.setInvalidModuleDependency(true, forAutocomplete);
return traverseSubtree;
}
);
}
item.sourceNode->setInvalidModuleDependency(false, item.options.forAutocomplete);
}
else
{
moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false;
if (item.options.forAutocomplete)
{
moduleResolverForAutocomplete.setModule(item.name, item.module);
item.sourceNode->dirtyModuleForAutocomplete = false;
}
else
{
moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false;
}
}
stats.timeCheck += item.stats.timeCheck;
@ -1191,6 +1197,13 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config
return result;
}
bool Frontend::allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete) const
{
LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking);
auto it = sourceNodes.find(name);
return it != sourceNodes.end() && !it->second->hasInvalidModuleDependency(forAutocomplete);
}
bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
{
auto it = sourceNodes.find(name);
@ -1205,16 +1218,80 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
*/
void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty)
{
LUAU_TIMETRACE_SCOPE("Frontend::markDirty", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
if (FFlag::LuauBetterReverseDependencyTracking)
{
traverseDependents(
name,
[markedDirty](SourceNode& sourceNode)
{
if (markedDirty)
markedDirty->push_back(sourceNode.name);
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
return false;
sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
return true;
}
);
}
else
{
if (sourceNodes.count(name) == 0)
return;
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes)
{
for (const auto& dep : module.second->requireSet)
reverseDeps[dep].push_back(module.first);
}
std::vector<ModuleName> queue{name};
while (!queue.empty())
{
ModuleName next = std::move(queue.back());
queue.pop_back();
LUAU_ASSERT(sourceNodes.count(next) > 0);
SourceNode& sourceNode = *sourceNodes[next];
if (markedDirty)
markedDirty->push_back(next);
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
continue;
sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
if (0 == reverseDeps.count(next))
continue;
sourceModules.erase(next);
const std::vector<ModuleName>& dependents = reverseDeps[next];
queue.insert(queue.end(), dependents.begin(), dependents.end());
}
}
}
void Frontend::traverseDependents(const ModuleName& name, std::function<bool(SourceNode&)> processSubtree)
{
LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking);
LUAU_TIMETRACE_SCOPE("Frontend::traverseDependents", "Frontend");
if (sourceNodes.count(name) == 0)
return;
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes)
{
for (const auto& dep : module.second->requireSet)
reverseDeps[dep].push_back(module.first);
}
std::vector<ModuleName> queue{name};
while (!queue.empty())
@ -1225,22 +1302,10 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
LUAU_ASSERT(sourceNodes.count(next) > 0);
SourceNode& sourceNode = *sourceNodes[next];
if (markedDirty)
markedDirty->push_back(next);
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
if (!processSubtree(sourceNode))
continue;
sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
if (0 == reverseDeps.count(next))
continue;
sourceModules.erase(next);
const std::vector<ModuleName>& dependents = reverseDeps[next];
const Set<ModuleName>& dependents = sourceNode.dependents;
queue.insert(queue.end(), dependents.begin(), dependents.end());
}
}
@ -1357,11 +1422,15 @@ ModulePtr check(
LUAU_TIMETRACE_ARGUMENT("name", sourceModule.humanReadableName.c_str());
ModulePtr result = std::make_shared<Module>();
if (FFlag::LuauStoreSolverTypeOnModule)
result->checkedInNewSolver = true;
result->name = sourceModule.name;
result->humanReadableName = sourceModule.humanReadableName;
result->mode = mode;
result->internalTypes.owningModule = result.get();
result->interfaceTypes.owningModule = result.get();
result->allocator = sourceModule.allocator;
result->names = sourceModule.names;
iceHandler->moduleName = sourceModule.name;
@ -1376,17 +1445,23 @@ ModulePtr check(
}
}
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler);
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&result->defArena}, NotNull{&result->keyArena}, iceHandler);
UnifierSharedState unifierState{iceHandler};
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes);
TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}};
typeFunctionRuntime.allowEvaluation = sourceModule.parseErrors.empty();
ConstraintGenerator cg{
result,
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
moduleResolver,
builtinTypes,
iceHandler,
@ -1402,12 +1477,16 @@ ModulePtr check(
ConstraintSolver cs{
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),
NotNull{&cg.scopeToFunction},
result->name,
moduleResolver,
requireCycles,
logger.get(),
NotNull{&dfg},
limits
};
@ -1461,12 +1540,31 @@ ModulePtr check(
switch (mode)
{
case Mode::Nonstrict:
Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get());
Luau::checkNonStrict(
builtinTypes,
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
iceHandler,
NotNull{&unifierState},
NotNull{&dfg},
NotNull{&limits},
sourceModule,
result.get()
);
break;
case Mode::Definition:
// fallthrough intentional
case Mode::Strict:
Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get());
Luau::check(
builtinTypes,
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{&unifierState},
NotNull{&limits},
logger.get(),
sourceModule,
result.get()
);
break;
case Mode::NoCheck:
break;
@ -1647,6 +1745,17 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(const ModuleName&
sourceNode->name = sourceModule->name;
sourceNode->humanReadableName = sourceModule->humanReadableName;
if (FFlag::LuauBetterReverseDependencyTracking)
{
// clear all prior dependents. we will re-add them after parsing the rest of the graph
for (const auto& [moduleName, _] : sourceNode->requireLocations)
{
if (auto depIt = sourceNodes.find(moduleName); depIt != sourceNodes.end())
depIt->second->dependents.erase(sourceNode->name);
}
}
sourceNode->requireSet.clear();
sourceNode->requireLocations.clear();
sourceNode->dirtySourceModule = false;
@ -1768,11 +1877,21 @@ std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName&
return frontend->fileResolver->getHumanReadableModuleName(moduleName);
}
void FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module)
bool FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module)
{
std::scoped_lock lock(moduleMutex);
modules[moduleName] = std::move(module);
if (FFlag::LuauBetterReverseDependencyTracking)
{
bool replaced = modules.count(moduleName) > 0;
modules[moduleName] = std::move(module);
return replaced;
}
else
{
modules[moduleName] = std::move(module);
return false;
}
}
void FrontendModuleResolver::clearModules()

View file

@ -2,6 +2,8 @@
#include "Luau/Generalization.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
#include "Luau/Scope.h"
#include "Luau/Type.h"
#include "Luau/ToString.h"
@ -9,11 +11,15 @@
#include "Luau/TypePack.h"
#include "Luau/VisitType.h"
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound2)
namespace Luau
{
struct MutatingGeneralizer : TypeOnceVisitor
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
NotNull<Scope> scope;
@ -27,6 +33,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool avoidSealingTables = false;
MutatingGeneralizer(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes,
@ -35,6 +42,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool avoidSealingTables
)
: TypeOnceVisitor(/* skipBoundTypes */ true)
, arena(arena)
, builtinTypes(builtinTypes)
, scope(scope)
, cachedTypes(cachedTypes)
@ -44,7 +52,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
{
}
static void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement)
void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement)
{
haystack = follow(haystack);
@ -91,6 +99,10 @@ struct MutatingGeneralizer : TypeOnceVisitor
LUAU_ASSERT(onlyType != haystack);
emplaceType<BoundType>(asMutable(haystack), onlyType);
}
else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && ut->options.empty())
{
emplaceType<BoundType>(asMutable(haystack), builtinTypes->neverType);
}
return;
}
@ -133,6 +145,10 @@ struct MutatingGeneralizer : TypeOnceVisitor
TypeId onlyType = it->parts[0];
LUAU_ASSERT(onlyType != needle);
emplaceType<BoundType>(asMutable(needle), onlyType);
}
else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && it->parts.empty())
{
emplaceType<BoundType>(asMutable(needle), builtinTypes->unknownType);
}
return;
@ -445,7 +461,7 @@ struct FreeTypeSearcher : TypeVisitor
traverse(*prop.readTy);
else
{
LUAU_ASSERT(prop.isShared());
LUAU_ASSERT(prop.isShared() || FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
Polarity p = polarity;
polarity = Both;
@ -526,7 +542,7 @@ struct TypeCacher : TypeOnceVisitor
DenseHashSet<TypePackId> uncacheablePacks{nullptr};
explicit TypeCacher(NotNull<DenseHashSet<TypeId>> cachedTypes)
: TypeOnceVisitor(/* skipBoundTypes */ true)
: TypeOnceVisitor(/* skipBoundTypes */ false)
, cachedTypes(cachedTypes)
{
}
@ -563,9 +579,19 @@ struct TypeCacher : TypeOnceVisitor
bool visit(TypeId ty) override
{
if (isUncacheable(ty) || isCached(ty))
return false;
return true;
// NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
// otherwise it's prone to marking types that cannot be cached as
// cacheable.
LUAU_ASSERT(false);
LUAU_UNREACHABLE();
}
bool visit(TypeId ty, const BoundType& btv) override
{
traverse(btv.boundTo);
if (isUncacheable(btv.boundTo))
markUncacheable(ty);
return false;
}
bool visit(TypeId ty, const FreeType& ft) override
@ -590,6 +616,12 @@ struct TypeCacher : TypeOnceVisitor
return false;
}
bool visit(TypeId ty, const ErrorType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const PrimitiveType&) override
{
cache(ty);
@ -727,6 +759,17 @@ struct TypeCacher : TypeOnceVisitor
return false;
}
bool visit(TypeId ty, const MetatableType& mtv) override
{
traverse(mtv.table);
traverse(mtv.metatable);
if (isUncacheable(mtv.table) || isUncacheable(mtv.metatable))
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const ClassType&) override
{
cache(ty);
@ -739,6 +782,12 @@ struct TypeCacher : TypeOnceVisitor
return false;
}
bool visit(TypeId ty, const NoRefineType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const UnionType& ut) override
{
if (isUncacheable(ty) || isCached(ty))
@ -841,12 +890,31 @@ struct TypeCacher : TypeOnceVisitor
return false;
}
bool visit(TypePackId tp) override
{
// NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
// otherwise it's prone to marking types that cannot be cached as
// cacheable, which will segfault down the line.
LUAU_ASSERT(false);
LUAU_UNREACHABLE();
}
bool visit(TypePackId tp, const FreeTypePack&) override
{
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const GenericTypePack& gtp) override
{
return true;
}
bool visit(TypePackId tp, const ErrorTypePack& etp) override
{
return true;
}
bool visit(TypePackId tp, const VariadicTypePack& vtp) override
{
if (isUncacheable(tp))
@ -871,6 +939,32 @@ struct TypeCacher : TypeOnceVisitor
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const BoundTypePack& btp) override
{
traverse(btp.boundTo);
if (isUncacheable(btp.boundTo))
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const TypePack& typ) override
{
bool uncacheable = false;
for (TypeId ty : typ.head)
{
traverse(ty);
uncacheable |= isUncacheable(ty);
}
if (typ.tail)
{
traverse(*typ.tail);
uncacheable |= isUncacheable(*typ.tail);
}
if (uncacheable)
markUncacheable(tp);
return false;
}
};
std::optional<TypeId> generalize(
@ -890,7 +984,7 @@ std::optional<TypeId> generalize(
FreeTypeSearcher fts{scope, cachedTypes};
fts.traverse(ty);
MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables};
MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables};
gen.traverse(ty);

View file

@ -11,6 +11,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -61,9 +62,7 @@ TypeId Instantiation::clean(TypeId ty)
LUAU_ASSERT(ftv);
FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.magicFunction = ftv->magicFunction;
clone.dcrMagicFunction = ftv->dcrMagicFunction;
clone.dcrMagicRefinement = ftv->dcrMagicRefinement;
clone.magic = ftv->magic;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
TypeId result = addType(std::move(clone));
@ -165,7 +164,7 @@ TypeId ReplaceGenerics::clean(TypeId ty)
}
else
{
return addType(FreeType{scope, level});
return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, scope, level) : addType(FreeType{scope, level});
}
}

View file

@ -227,6 +227,8 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }";
else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>)
stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }";
else if constexpr (std::is_same_v<T, UserDefinedTypeFunctionError>)
stream << "UserDefinedTypeFunctionError { " << err.message << " }";
else if constexpr (std::is_same_v<T, CannotAssignToNever>)
{
stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { ";

View file

@ -17,8 +17,7 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAttribute)
LUAU_FASTFLAG(LuauNativeAttribute)
LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute, false)
LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute)
namespace Luau
{
@ -3239,7 +3238,6 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
static bool hasNativeCommentDirective(const std::vector<HotComment>& hotcomments)
{
LUAU_ASSERT(FFlag::LuauNativeAttribute);
LUAU_ASSERT(FFlag::LintRedundantNativeAttribute);
for (const HotComment& hc : hotcomments)
@ -3265,7 +3263,6 @@ struct LintRedundantNativeAttribute : AstVisitor
public:
LUAU_NOINLINE static void process(LintContext& context)
{
LUAU_ASSERT(FFlag::LuauNativeAttribute);
LUAU_ASSERT(FFlag::LintRedundantNativeAttribute);
LintRedundantNativeAttribute pass;
@ -3389,7 +3386,7 @@ std::vector<LintWarning> lint(
if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence))
LintComparisonPrecedence::process(context);
if (FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute))
if (FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute))
{
if (hasNativeCommentDirective(hotcomments))
LintRedundantNativeAttribute::process(context);

View file

@ -15,11 +15,32 @@
#include <algorithm>
LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteCommentDetection)
namespace Luau
{
static bool contains(Position pos, Comment comment)
static void defaultLogLuau(std::string_view input)
{
// The default is to do nothing because we don't want to mess with
// the xml parsing done by the dcr script.
}
Luau::LogLuauProc logLuau = &defaultLogLuau;
void setLogLuau(LogLuauProc ll)
{
logLuau = ll;
}
void resetLogLuauProc()
{
logLuau = &defaultLogLuau;
}
static bool contains_DEPRECATED(Position pos, Comment comment)
{
if (comment.location.contains(pos))
return true;
@ -32,7 +53,22 @@ static bool contains(Position pos, Comment comment)
return false;
}
static bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos)
static bool contains(Position pos, Comment comment)
{
if (comment.location.contains(pos))
return true;
else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't
// have an end
return true;
// comments actually span the whole line - in incremental mode, we could pass a cursor outside of the current parsed comment range span, but it
// would still be 'within' the comment So, the cursor must be on the same line and the comment itself must come strictly after the `begin`
else if (comment.type == Lexeme::Comment && comment.location.end.line == pos.line && comment.location.begin <= pos)
return true;
else
return false;
}
bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos)
{
auto iter = std::lower_bound(
commentLocations.begin(),
@ -40,6 +76,11 @@ static bool isWithinComment(const std::vector<Comment>& commentLocations, Positi
Comment{Lexeme::Comment, Location{pos, pos}},
[](const Comment& a, const Comment& b)
{
if (FFlag::LuauIncrementalAutocompleteCommentDetection)
{
if (a.type == Lexeme::Comment)
return a.location.end.line < b.location.end.line;
}
return a.location.end < b.location.end;
}
);
@ -47,7 +88,7 @@ static bool isWithinComment(const std::vector<Comment>& commentLocations, Positi
if (iter == commentLocations.end())
return false;
if (contains(pos, *iter))
if (FFlag::LuauIncrementalAutocompleteCommentDetection ? contains(pos, *iter) : contains_DEPRECATED(pos, *iter))
return true;
// Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends
@ -131,10 +172,32 @@ struct ClonePublicInterface : Substitution
}
ftv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2)
ftv->scope = nullptr;
}
else if (TableType* ttv = getMutable<TableType>(result))
{
ttv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2)
ttv->scope = nullptr;
}
if (FFlag::LuauSolverV2)
{
if (auto freety = getMutable<FreeType>(result))
{
module->errors.emplace_back(
freety->scope->location,
module->name,
InternalError{"Free type is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
);
result = builtinTypes->errorRecoveryType();
}
else if (auto genericty = getMutable<GenericType>(result))
{
genericty->scope = nullptr;
}
}
return result;
@ -142,7 +205,27 @@ struct ClonePublicInterface : Substitution
TypePackId clean(TypePackId tp) override
{
return clone(tp);
if (FFlag::LuauSolverV2)
{
auto clonedTp = clone(tp);
if (auto ftp = getMutable<FreeTypePack>(clonedTp))
{
module->errors.emplace_back(
ftp->scope->location,
module->name,
InternalError{"Free type pack is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
);
clonedTp = builtinTypes->errorRecoveryTypePack();
}
else if (auto gtp = getMutable<GenericTypePack>(clonedTp))
gtp->scope = nullptr;
return clonedTp;
}
else
{
return clone(tp);
}
}
TypeId cloneType(TypeId ty)

View file

@ -14,11 +14,15 @@
#include "Luau/TypeFunction.h"
#include "Luau/Def.h"
#include "Luau/ToString.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h"
#include <iostream>
#include <iterator>
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
LUAU_FASTFLAGVARIABLE(LuauNonStrictVisitorImprovements)
LUAU_FASTFLAGVARIABLE(LuauNewNonStrictWarnOnUnknownGlobals)
namespace Luau
{
@ -154,8 +158,9 @@ private:
struct NonStrictTypeChecker
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
const NotNull<InternalErrorReporter> ice;
NotNull<TypeArena> arena;
Module* module;
@ -171,6 +176,8 @@ struct NonStrictTypeChecker
NonStrictTypeChecker(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
const NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg,
@ -178,11 +185,13 @@ struct NonStrictTypeChecker
Module* module
)
: builtinTypes(builtinTypes)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime)
, ice(ice)
, arena(arena)
, module(module)
, normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true}
, subtyping{builtinTypes, arena, NotNull(&normalizer), ice}
, subtyping{builtinTypes, arena, simplifier, NotNull(&normalizer), typeFunctionRuntime, ice}
, dfg(dfg)
, limits(limits)
{
@ -204,7 +213,7 @@ struct NonStrictTypeChecker
return *fst;
else if (auto ftp = get<FreeTypePack>(pack))
{
TypeId result = arena->addType(FreeType{ftp->scope});
TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, ftp->scope) : arena->addType(FreeType{ftp->scope});
TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope});
TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack));
@ -213,7 +222,7 @@ struct NonStrictTypeChecker
return result;
}
else if (get<Unifiable::Error>(pack))
else if (get<ErrorTypePack>(pack))
return builtinTypes->errorRecoveryType();
else if (finite(pack) && size(pack) == 0)
return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil`
@ -228,7 +237,12 @@ struct NonStrictTypeChecker
return instance;
ErrorVec errors =
reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
reduceTypeFunctions(
instance,
location,
TypeFunctionContext{arena, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits},
true
)
.errors;
if (errors.empty())
@ -329,8 +343,9 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatIf* ifStatement)
{
NonStrictContext condB = visit(ifStatement->condition);
NonStrictContext condB = visit(ifStatement->condition, ValueContext::RValue);
NonStrictContext branchContext;
// If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error
if (ifStatement->elsebody)
{
@ -338,17 +353,32 @@ struct NonStrictTypeChecker
NonStrictContext elseBody = visit(ifStatement->elsebody);
branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody);
}
return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext);
}
NonStrictContext visit(AstStatWhile* whileStatement)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext condition = visit(whileStatement->condition, ValueContext::RValue);
NonStrictContext body = visit(whileStatement->body);
return NonStrictContext::disjunction(builtinTypes, arena, condition, body);
}
else
return {};
}
NonStrictContext visit(AstStatRepeat* repeatStatement)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext body = visit(repeatStatement->body);
NonStrictContext condition = visit(repeatStatement->condition, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, body, condition);
}
else
return {};
}
NonStrictContext visit(AstStatBreak* breakStatement)
@ -363,49 +393,94 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatReturn* returnStatement)
{
if (FFlag::LuauNonStrictVisitorImprovements)
{
// TODO: this is believing existing code, but i'm not sure if this makes sense
// for how the contexts are handled
for (AstExpr* expr : returnStatement->list)
visit(expr, ValueContext::RValue);
}
return {};
}
NonStrictContext visit(AstStatExpr* expr)
{
return visit(expr->expr);
return visit(expr->expr, ValueContext::RValue);
}
NonStrictContext visit(AstStatLocal* local)
{
for (AstExpr* rhs : local->values)
visit(rhs);
visit(rhs, ValueContext::RValue);
return {};
}
NonStrictContext visit(AstStatFor* forStatement)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
{
// TODO: throwing out context based on same principle as existing code?
if (forStatement->from)
visit(forStatement->from, ValueContext::RValue);
if (forStatement->to)
visit(forStatement->to, ValueContext::RValue);
if (forStatement->step)
visit(forStatement->step, ValueContext::RValue);
return visit(forStatement->body);
}
else
{
return {};
}
}
NonStrictContext visit(AstStatForIn* forInStatement)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* rhs : forInStatement->values)
visit(rhs, ValueContext::RValue);
return visit(forInStatement->body);
}
else
{
return {};
}
}
NonStrictContext visit(AstStatAssign* assign)
{
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* lhs : assign->vars)
visit(lhs, ValueContext::LValue);
for (AstExpr* rhs : assign->values)
visit(rhs, ValueContext::RValue);
}
return {};
}
NonStrictContext visit(AstStatCompoundAssign* compoundAssign)
{
if (FFlag::LuauNonStrictVisitorImprovements)
{
visit(compoundAssign->var, ValueContext::LValue);
visit(compoundAssign->value, ValueContext::RValue);
}
return {};
}
NonStrictContext visit(AstStatFunction* statFn)
{
return visit(statFn->func);
return visit(statFn->func, ValueContext::RValue);
}
NonStrictContext visit(AstStatLocalFunction* localFn)
{
return visit(localFn->func);
return visit(localFn->func, ValueContext::RValue);
}
NonStrictContext visit(AstStatTypeAlias* typeAlias)
@ -415,7 +490,6 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatTypeFunction* typeFunc)
{
reportError(GenericError{"This syntax is not supported"}, typeFunc->location);
return {};
}
@ -436,14 +510,22 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatError* error)
{
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstStat* stat : error->statements)
visit(stat);
for (AstExpr* expr : error->expressions)
visit(expr, ValueContext::RValue);
}
return {};
}
NonStrictContext visit(AstExpr* expr)
NonStrictContext visit(AstExpr* expr, ValueContext context)
{
auto pusher = pushStack(expr);
if (auto e = expr->as<AstExprGroup>())
return visit(e);
return visit(e, context);
else if (auto e = expr->as<AstExprConstantNil>())
return visit(e);
else if (auto e = expr->as<AstExprConstantBool>())
@ -453,17 +535,17 @@ struct NonStrictTypeChecker
else if (auto e = expr->as<AstExprConstantString>())
return visit(e);
else if (auto e = expr->as<AstExprLocal>())
return visit(e);
return visit(e, context);
else if (auto e = expr->as<AstExprGlobal>())
return visit(e);
return visit(e, context);
else if (auto e = expr->as<AstExprVarargs>())
return visit(e);
else if (auto e = expr->as<AstExprCall>())
return visit(e);
else if (auto e = expr->as<AstExprIndexName>())
return visit(e);
return visit(e, context);
else if (auto e = expr->as<AstExprIndexExpr>())
return visit(e);
return visit(e, context);
else if (auto e = expr->as<AstExprFunction>())
return visit(e);
else if (auto e = expr->as<AstExprTable>())
@ -487,9 +569,12 @@ struct NonStrictTypeChecker
}
}
NonStrictContext visit(AstExprGroup* group)
NonStrictContext visit(AstExprGroup* group, ValueContext context)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
return visit(group->expr, context);
else
return {};
}
NonStrictContext visit(AstExprConstantNil* expr)
@ -512,22 +597,34 @@ struct NonStrictTypeChecker
return {};
}
NonStrictContext visit(AstExprLocal* local)
NonStrictContext visit(AstExprLocal* local, ValueContext context)
{
return {};
}
NonStrictContext visit(AstExprGlobal* global)
NonStrictContext visit(AstExprGlobal* global, ValueContext context)
{
if (FFlag::LuauNewNonStrictWarnOnUnknownGlobals)
{
// We don't file unknown symbols for LValues.
if (context == ValueContext::LValue)
return {};
NotNull<Scope> scope = stack.back();
if (!scope->lookup(global->name))
{
reportError(UnknownSymbol{global->name.value, UnknownSymbol::Binding}, global->location);
}
}
return {};
}
NonStrictContext visit(AstExprVarargs* global)
NonStrictContext visit(AstExprVarargs* varargs)
{
return {};
}
NonStrictContext visit(AstExprCall* call)
{
NonStrictContext fresh{};
@ -536,106 +633,126 @@ struct NonStrictTypeChecker
return fresh;
TypeId fnTy = *originalCallTy;
if (auto fn = get<FunctionType>(follow(fnTy)))
if (auto fn = get<FunctionType>(follow(fnTy)); fn && fn->isCheckedFunction)
{
if (fn->isCheckedFunction)
// We know fn is a checked function, which means it looks like:
// (S1, ... SN) -> T &
// (~S1, unknown^N-1) -> error &
// (unknown, ~S2, unknown^N-2) -> error
// ...
// ...
// (unknown^N-1, ~S_N) -> error
std::vector<AstExpr*> arguments;
arguments.reserve(call->args.size + (call->self ? 1 : 0));
if (call->self)
{
// We know fn is a checked function, which means it looks like:
// (S1, ... SN) -> T &
// (~S1, unknown^N-1) -> error &
// (unknown, ~S2, unknown^N-2) -> error
// ...
// ...
// (unknown^N-1, ~S_N) -> error
std::vector<TypeId> argTypes;
argTypes.reserve(call->args.size);
// Pad out the arg types array with the types you would expect to see
TypePackIterator curr = begin(fn->argTypes);
TypePackIterator fin = end(fn->argTypes);
while (curr != fin)
if (auto indexExpr = call->func->as<AstExprIndexName>())
arguments.push_back(indexExpr->expr);
else
ice->ice("method call expression has no 'self'");
}
arguments.insert(arguments.end(), call->args.begin(), call->args.end());
std::vector<TypeId> argTypes;
argTypes.reserve(arguments.size());
// Move all the types over from the argument typepack for `fn`
TypePackIterator curr = begin(fn->argTypes);
TypePackIterator fin = end(fn->argTypes);
for (; curr != fin; curr++)
argTypes.push_back(*curr);
// Pad out the rest with the variadic as needed.
if (auto argTail = curr.tail())
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*argTail)))
{
argTypes.push_back(*curr);
++curr;
}
if (auto argTail = curr.tail())
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*argTail)))
while (argTypes.size() < arguments.size())
{
while (argTypes.size() < call->args.size)
{
argTypes.push_back(vtp->ty);
}
argTypes.push_back(vtp->ty);
}
}
}
std::string functionName = getFunctionNameAsString(*call->func).value_or("");
if (call->args.size > argTypes.size())
std::string functionName = getFunctionNameAsString(*call->func).value_or("");
if (arguments.size() > argTypes.size())
{
// We are passing more arguments than we expect, so we should error
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), arguments.size()}, call->location);
return fresh;
}
for (size_t i = 0; i < arguments.size(); i++)
{
// For example, if the arg is "hi"
// The actual arg type is string
// The expected arg type is number
// The type of the argument in the overload is ~number
// We will compare arg and ~number
AstExpr* arg = arguments[i];
TypeId expectedArgType = argTypes[i];
std::shared_ptr<const NormalizedType> norm = normalizer.normalize(expectedArgType);
DefId def = dfg->getDef(arg);
TypeId runTimeErrorTy;
// If we're dealing with any, negating any will cause all subtype tests to fail
// However, when someone calls this function, they're going to want to be able to pass it anything,
// for that reason, we manually inject never into the context so that the runtime test will always pass.
if (!norm)
reportError(NormalizationTooComplex{}, arg->location);
if (norm && get<AnyType>(norm->tops))
runTimeErrorTy = builtinTypes->neverType;
else
runTimeErrorTy = getOrCreateNegation(expectedArgType);
fresh.addContext(def, runTimeErrorTy);
}
// Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types
for (size_t i = 0; i < arguments.size(); i++)
{
AstExpr* arg = arguments[i];
if (auto runTimeFailureType = willRunTimeError(arg, fresh))
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location);
}
if (arguments.size() < argTypes.size())
{
// We are passing fewer arguments than we expect
// so we need to ensure that the rest of the args are optional.
bool remainingArgsOptional = true;
for (size_t i = arguments.size(); i < argTypes.size(); i++)
remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]);
if (!remainingArgsOptional)
{
// We are passing more arguments than we expect, so we should error
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location);
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), arguments.size()}, call->location);
return fresh;
}
for (size_t i = 0; i < call->args.size; i++)
{
// For example, if the arg is "hi"
// The actual arg type is string
// The expected arg type is number
// The type of the argument in the overload is ~number
// We will compare arg and ~number
AstExpr* arg = call->args.data[i];
TypeId expectedArgType = argTypes[i];
std::shared_ptr<const NormalizedType> norm = normalizer.normalize(expectedArgType);
DefId def = dfg->getDef(arg);
TypeId runTimeErrorTy;
// If we're dealing with any, negating any will cause all subtype tests to fail, since ~any is any
// However, when someone calls this function, they're going to want to be able to pass it anything,
// for that reason, we manually inject never into the context so that the runtime test will always pass.
if (!norm)
reportError(NormalizationTooComplex{}, arg->location);
if (norm && get<AnyType>(norm->tops))
runTimeErrorTy = builtinTypes->neverType;
else
runTimeErrorTy = getOrCreateNegation(expectedArgType);
fresh.addContext(def, runTimeErrorTy);
}
// Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types
for (size_t i = 0; i < call->args.size; i++)
{
AstExpr* arg = call->args.data[i];
if (auto runTimeFailureType = willRunTimeError(arg, fresh))
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location);
}
if (call->args.size < argTypes.size())
{
// We are passing fewer arguments than we expect
// so we need to ensure that the rest of the args are optional.
bool remainingArgsOptional = true;
for (size_t i = call->args.size; i < argTypes.size(); i++)
remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]);
if (!remainingArgsOptional)
{
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location);
return fresh;
}
}
}
}
return fresh;
}
NonStrictContext visit(AstExprIndexName* indexName)
NonStrictContext visit(AstExprIndexName* indexName, ValueContext context)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
return visit(indexName->expr, context);
else
return {};
}
NonStrictContext visit(AstExprIndexExpr* indexExpr)
NonStrictContext visit(AstExprIndexExpr* indexExpr, ValueContext context)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext expr = visit(indexExpr->expr, context);
NonStrictContext index = visit(indexExpr->index, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, expr, index);
}
else
return {};
}
NonStrictContext visit(AstExprFunction* exprFn)
@ -654,39 +771,74 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstExprTable* table)
{
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (auto [_, key, value] : table->items)
{
if (key)
visit(key, ValueContext::RValue);
visit(value, ValueContext::RValue);
}
}
return {};
}
NonStrictContext visit(AstExprUnary* unary)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
return visit(unary->expr, ValueContext::RValue);
else
return {};
}
NonStrictContext visit(AstExprBinary* binary)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext lhs = visit(binary->left, ValueContext::RValue);
NonStrictContext rhs = visit(binary->right, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, lhs, rhs);
}
else
return {};
}
NonStrictContext visit(AstExprTypeAssertion* typeAssertion)
{
return {};
if (FFlag::LuauNonStrictVisitorImprovements)
return visit(typeAssertion->expr, ValueContext::RValue);
else
return {};
}
NonStrictContext visit(AstExprIfElse* ifElse)
{
NonStrictContext condB = visit(ifElse->condition);
NonStrictContext thenB = visit(ifElse->trueExpr);
NonStrictContext elseB = visit(ifElse->falseExpr);
NonStrictContext condB = visit(ifElse->condition, ValueContext::RValue);
NonStrictContext thenB = visit(ifElse->trueExpr, ValueContext::RValue);
NonStrictContext elseB = visit(ifElse->falseExpr, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB));
}
NonStrictContext visit(AstExprInterpString* interpString)
{
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* expr : interpString->expressions)
visit(expr, ValueContext::RValue);
}
return {};
}
NonStrictContext visit(AstExprError* error)
{
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* expr : error->expressions)
visit(expr, ValueContext::RValue);
}
return {};
}
@ -754,6 +906,8 @@ private:
void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg,
@ -764,7 +918,9 @@ void checkNonStrict(
{
LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking");
NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, ice, unifierState, dfg, limits, module};
NonStrictTypeChecker typeChecker{
NotNull{&module->internalTypes}, builtinTypes, simplifier, typeFunctionRuntime, ice, unifierState, dfg, limits, module
};
typeChecker.visit(sourceModule.root);
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes, builtinTypes);

View file

@ -15,36 +15,17 @@
#include "Luau/TypeFwd.h"
#include "Luau/Unifier.h"
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false)
LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false);
LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false);
LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false);
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant)
// This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAG(LuauSolverV2);
static bool fixReduceStackPressure()
{
return FFlag::LuauFixReduceStackPressure || FFlag::LuauSolverV2;
}
static bool fixCyclicTablesBlowingStack()
{
return FFlag::LuauFixCyclicTablesBlowingStack || FFlag::LuauSolverV2;
}
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000)
LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauFixInfiniteRecursionInNormalization)
LUAU_FASTFLAGVARIABLE(LuauFixNormalizedIntersectionOfNegatedClass)
namespace Luau
{
// helper to make `FFlag::LuauNormalizeAwayUninhabitableTables` not explicitly required when DCR is enabled.
static bool normalizeAwayUninhabitableTables()
{
return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::LuauSolverV2;
}
static bool shouldEarlyExit(NormalizationResult res)
{
// if res is hit limits, return control flow
@ -589,10 +570,11 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set<TypeId>& seen)
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right)
{
Set<TypeId> seen{nullptr};
return isIntersectionInhabited(left, right, seen);
SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}};
return isIntersectionInhabited(left, right, seenTablePropPairs, seen);
}
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet)
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet)
{
left = follow(left);
right = follow(right);
@ -605,7 +587,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
}
NormalizedType norm{builtinTypes};
NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet);
NormalizationResult res = normalizeIntersections({left, right}, norm, seenTablePropPairs, seenSet);
if (res != NormalizationResult::True)
{
if (cacheInhabitance && res == NormalizationResult::False)
@ -956,7 +938,8 @@ std::shared_ptr<const NormalizedType> Normalizer::normalize(TypeId ty)
NormalizedType norm{builtinTypes};
Set<TypeId> seenSetTypes{nullptr};
NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes);
SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}};
NormalizationResult res = unionNormalWithTy(norm, ty, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True)
return nullptr;
@ -974,7 +957,12 @@ std::shared_ptr<const NormalizedType> Normalizer::normalize(TypeId ty)
return shared;
}
NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet)
NormalizationResult Normalizer::normalizeIntersections(
const std::vector<TypeId>& intersections,
NormalizedType& outType,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSet
)
{
if (!arena)
sharedState->iceHandler->ice("Normalizing types outside a module");
@ -983,7 +971,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>
// Now we need to intersect the two types
for (auto ty : intersections)
{
NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet);
NormalizationResult res = intersectNormalWithTy(norm, ty, seenTablePropPairs, seenSet);
if (res != NormalizationResult::True)
return res;
}
@ -1620,7 +1608,7 @@ void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there)
// TODO: remove unions of tables where possible
// we can always skip `never`
if (normalizeAwayUninhabitableTables() && get<NeverType>(there))
if (get<NeverType>(there))
return;
heres.insert(there);
@ -1747,7 +1735,13 @@ NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, N
}
// See above for an explaination of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars)
NormalizationResult Normalizer::unionNormalWithTy(
NormalizedType& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes,
int ignoreSmallerTyvars
)
{
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
@ -1779,7 +1773,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
for (UnionTypeIterator it = begin(utv); it != end(utv); ++it)
{
NormalizationResult res = unionNormalWithTy(here, *it, seenSetTypes);
NormalizationResult res = unionNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True)
{
seenSetTypes.erase(there);
@ -1800,7 +1794,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
norm.tops = builtinTypes->anyType;
for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it)
{
NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes);
NormalizationResult res = intersectNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True)
{
seenSetTypes.erase(there);
@ -1814,7 +1808,8 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
}
else if (get<UnknownType>(here.tops))
return NormalizationResult::True;
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there))
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFunctionInstanceType>(there))
{
if (tyvarIndex(there) <= ignoreSmallerTyvars)
return NormalizationResult::True;
@ -1891,7 +1886,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
if (res != NormalizationResult::True)
return res;
}
else if (get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there))
else if (get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there) || get<NoRefineType>(there))
{
// nothing
}
@ -1900,7 +1895,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
for (auto& [tyvar, intersect] : here.tyvars)
{
NormalizationResult res = unionNormalWithTy(*intersect, there, seenSetTypes, tyvarIndex(tyvar));
NormalizationResult res = unionNormalWithTy(*intersect, there, seenTablePropPairs, seenSetTypes, tyvarIndex(tyvar));
if (res != NormalizationResult::True)
return res;
}
@ -2289,9 +2284,24 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th
else if (isSubclass(there, hereTy))
{
TypeIds negations = std::move(hereNegations);
bool emptyIntersectWithNegation = false;
for (auto nIt = negations.begin(); nIt != negations.end();)
{
if (FFlag::LuauFixNormalizedIntersectionOfNegatedClass && isSubclass(there, *nIt))
{
// Hitting this block means that the incoming class is a
// subclass of this type, _and_ one of its negations is a
// superclass of this type, e.g.:
//
// Dog & ~Animal
//
// Clearly this intersects to never, so we mark this class as
// being removed from the normalized class type.
emptyIntersectWithNegation = true;
break;
}
if (!isSubclass(*nIt, there))
{
nIt = negations.erase(nIt);
@ -2304,7 +2314,8 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th
it = heres.ordering.erase(it);
heres.classes.erase(hereTy);
heres.pushPair(there, std::move(negations));
if (!emptyIntersectWithNegation)
heres.pushPair(there, std::move(negations));
break;
}
// If the incoming class is a superclass of the current class, we don't
@ -2510,7 +2521,7 @@ std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, T
return arena->addTypePack({});
}
std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet)
std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet)
{
if (here == there)
return here;
@ -2589,49 +2600,60 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
{
if (tprop.readTy.has_value())
{
// if the intersection of the read types of a property is uninhabited, the whole table is `never`.
if (fixReduceStackPressure())
if (FFlag::LuauFixInfiniteRecursionInNormalization)
{
// We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited
if (fixCyclicTablesBlowingStack())
{
if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy))
{
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
return {builtinTypes->neverType};
}
else
{
seenSet.insert(*hprop.readTy);
seenSet.insert(*tprop.readTy);
}
}
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy);
// Cleanup
if (fixCyclicTablesBlowingStack())
{
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
}
if (normalizeAwayUninhabitableTables() && NormalizationResult::True != res)
// If any property is going to get mapped to `never`, we can just call the entire table `never`.
// Since this check is syntactic, we may sometimes miss simplifying tables with complex uninhabited properties.
// Prior versions of this code attempted to do this semantically using the normalization machinery, but this
// mistakenly causes infinite loops when giving more complex recursive table types. As it stands, this approach
// will continue to scale as simplification is improved, but we may wish to reintroduce the semantic approach
// once we have revisited the usage of seen sets systematically (and possibly with some additional guarding to recognize
// when types are infinitely-recursive with non-pointer identical instances of them, or some guard to prevent that
// construction altogether). See also: `gh1632_no_infinite_recursion_in_normalization`
if (get<NeverType>(ty))
return {builtinTypes->neverType};
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
}
else
{
if (normalizeAwayUninhabitableTables() &&
NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
return {builtinTypes->neverType};
}
// if the intersection of the read types of a property is uninhabited, the whole table is `never`.
// We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
auto pair1 = std::pair{*hprop.readTy, *tprop.readTy};
auto pair2 = std::pair{*tprop.readTy, *hprop.readTy};
if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2))
{
seenTablePropPairs.erase(pair1);
seenTablePropPairs.erase(pair2);
return {builtinTypes->neverType};
}
else
{
seenTablePropPairs.insert(pair1);
seenTablePropPairs.insert(pair2);
}
// FIXME(ariel): this is being added in a flag removal, so not changing the semantics here, but worth noting that this
// fresh `seenSet` is definitely a bug. we already have `seenSet` from the parameter that _should_ have been used here.
Set<TypeId> seenSet{nullptr};
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet);
seenTablePropPairs.erase(pair1);
seenTablePropPairs.erase(pair2);
if (NormalizationResult::True != res)
return {builtinTypes->neverType};
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
}
}
else
{
@ -2737,7 +2759,7 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
if (tmtable && hmtable)
{
// NOTE: this assumes metatables are ivariant
if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable, seenSet))
if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable, seenTablePropPairs, seenSet))
{
if (table == htable && *mtable == hmtable)
return here;
@ -2767,12 +2789,12 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
return table;
}
void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes)
void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSetTypes)
{
TypeIds tmp;
for (TypeId here : heres)
{
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes))
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes))
tmp.insert(*inter);
}
heres.retain(tmp);
@ -2787,7 +2809,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres)
for (TypeId there : theres)
{
Set<TypeId> seenSetTypes{nullptr};
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes))
SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}};
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes))
tmp.insert(*inter);
}
}
@ -3005,12 +3028,17 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali
}
}
NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes)
NormalizationResult Normalizer::intersectTyvarsWithTy(
NormalizedTyvars& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes
)
{
for (auto it = here.begin(); it != here.end();)
{
NormalizedType& inter = *it->second;
NormalizationResult res = intersectNormalWithTy(inter, there, seenSetTypes);
NormalizationResult res = intersectNormalWithTy(inter, there, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True)
return res;
if (isShallowInhabited(inter))
@ -3024,6 +3052,10 @@ NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, Ty
// See above for an explaination of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars)
{
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
return NormalizationResult::HitLimits;
if (!get<NeverType>(there.tops))
{
here.tops = intersectionOfTops(here.tops, there.tops);
@ -3035,6 +3067,11 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor
return unionNormals(here, there, ignoreSmallerTyvars);
}
// Limit based on worst-case expansion of the table intersection
// This restriction can be relaxed when table intersection simplification is improved
if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit))
return NormalizationResult::HitLimits;
here.booleans = intersectionOfBools(here.booleans, there.booleans);
intersectClasses(here.classes, there.classes);
@ -3088,7 +3125,12 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor
return NormalizationResult::True;
}
NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes)
NormalizationResult Normalizer::intersectNormalWithTy(
NormalizedType& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes
)
{
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
@ -3104,14 +3146,14 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
else if (!get<NeverType>(here.tops))
{
clearNormal(here);
return unionNormalWithTy(here, there, seenSetTypes);
return unionNormalWithTy(here, there, seenTablePropPairs, seenSetTypes);
}
else if (const UnionType* utv = get<UnionType>(there))
{
NormalizedType norm{builtinTypes};
for (UnionTypeIterator it = begin(utv); it != end(utv); ++it)
{
NormalizationResult res = unionNormalWithTy(norm, *it, seenSetTypes);
NormalizationResult res = unionNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True)
return res;
}
@ -3121,13 +3163,14 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{
for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it)
{
NormalizationResult res = intersectNormalWithTy(here, *it, seenSetTypes);
NormalizationResult res = intersectNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True)
return res;
}
return NormalizationResult::True;
}
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there))
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFunctionInstanceType>(there))
{
NormalizedType thereNorm{builtinTypes};
NormalizedType topNorm{builtinTypes};
@ -3150,7 +3193,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{
TypeIds tables = std::move(here.tables);
clearNormal(here);
intersectTablesWithTable(tables, there, seenSetTypes);
intersectTablesWithTable(tables, there, seenTablePropPairs, seenSetTypes);
here.tables = std::move(tables);
}
else if (get<ClassType>(there))
@ -3243,13 +3286,18 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
// assumption that it is the same as any.
return NormalizationResult::True;
}
else if (get<NoRefineType>(t))
{
// `*no-refine*` means we will never do anything to affect the intersection.
return NormalizationResult::True;
}
else if (get<NeverType>(t))
{
// if we're intersecting with `~never`, this is equivalent to intersecting with `unknown`
// this is a noop since an intersection with `unknown` is trivial.
return NormalizationResult::True;
}
else if ((FFlag::LuauNormalizeNotUnknownIntersection || FFlag::LuauSolverV2) && get<UnknownType>(t))
else if (get<UnknownType>(t))
{
// if we're intersecting with `~unknown`, this is equivalent to intersecting with `never`
// this means we should clear the type entirely.
@ -3257,7 +3305,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
return NormalizationResult::True;
}
else if (auto nt = get<NegationType>(t))
return intersectNormalWithTy(here, nt->ty, seenSetTypes);
return intersectNormalWithTy(here, nt->ty, seenTablePropPairs, seenSetTypes);
else
{
// TODO negated unions, intersections, table, and function.
@ -3269,10 +3317,15 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{
here.classes.resetToNever();
}
else if (get<NoRefineType>(there))
{
// `*no-refine*` means we will never do anything to affect the intersection.
return NormalizationResult::True;
}
else
LUAU_ASSERT(!"Unreachable");
NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenSetTypes);
NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True)
return res;
here.tyvars = std::move(tyvars);
@ -3420,16 +3473,27 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
return arena->addType(UnionType{std::move(result)});
}
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
bool isSubtype(
TypeId subTy,
TypeId superTy,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
)
{
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeCheckLimits limits;
TypeFunctionRuntime typeFunctionRuntime{
NotNull{&ice}, NotNull{&limits}
}; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
// Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2)
{
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}};
Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
}
@ -3442,16 +3506,27 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<Built
}
}
bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
bool isSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
)
{
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeCheckLimits limits;
TypeFunctionRuntime typeFunctionRuntime{
NotNull{&ice}, NotNull{&limits}
}; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
// Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2)
{
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}};
Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subPack, superPack, scope).isSubtype;
}
@ -3464,38 +3539,4 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, N
}
}
bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
{
LUAU_ASSERT(!FFlag::LuauSolverV2);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subTy, superTy);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
bool isConsistentSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter& ice
)
{
LUAU_ASSERT(!FFlag::LuauSolverV2);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subPack, superPack);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
} // namespace Luau

View file

@ -16,7 +16,9 @@ namespace Luau
OverloadResolver::OverloadResolver(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits,
@ -24,11 +26,13 @@ OverloadResolver::OverloadResolver(
)
: builtinTypes(builtinTypes)
, arena(arena)
, simplifier(simplifier)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, scope(scope)
, ice(reporter)
, limits(limits)
, subtyping({builtinTypes, arena, normalizer, ice})
, subtyping({builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, ice})
, callLoc(callLocation)
{
}
@ -199,8 +203,9 @@ std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload_
const std::vector<AstExpr*>* argExprs
)
{
FunctionGraphReductionResult result =
reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true);
FunctionGraphReductionResult result = reduceTypeFunctions(
fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true
);
if (!result.errors.empty())
return {OverloadIsNonviable, result.errors};
@ -401,10 +406,12 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors)
// we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`.
// this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed.
std::optional<TypeId> selectOverload(
static std::optional<TypeId> selectOverload(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
@ -413,8 +420,9 @@ std::optional<TypeId> selectOverload(
TypePackId argsPack
)
{
OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location};
auto [status, overload] = resolver.selectOverload(fn, argsPack);
auto resolver =
std::make_unique<OverloadResolver>(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location);
auto [status, overload] = resolver->selectOverload(fn, argsPack);
if (status == OverloadResolver::Analysis::Ok)
return overload;
@ -428,7 +436,9 @@ std::optional<TypeId> selectOverload(
SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
@ -437,7 +447,8 @@ SolveResult solveFunctionCall(
TypePackId argsPack
)
{
std::optional<TypeId> overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack);
std::optional<TypeId> overloadToUse =
selectOverload(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack);
if (!overloadToUse)
return {SolveResult::NoMatchingOverload};
@ -450,9 +461,9 @@ SolveResult solveFunctionCall(
if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty())
{
Instantiation2 instantiation{arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions)};
auto instantiation = std::make_unique<Instantiation2>(arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions));
std::optional<TypePackId> subst = instantiation.substitute(resultPack);
std::optional<TypePackId> subst = instantiation->substitute(resultPack);
if (!subst)
return {SolveResult::CodeTooComplex};

View file

@ -4,6 +4,8 @@
#include "Luau/Ast.h"
#include "Luau/Module.h"
LUAU_FASTFLAGVARIABLE(LuauExtendedSimpleRequire)
namespace Luau
{
@ -65,7 +67,7 @@ struct RequireTracer : AstVisitor
return true;
}
AstExpr* getDependent(AstExpr* node)
AstExpr* getDependent_DEPRECATED(AstExpr* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local];
@ -78,50 +80,122 @@ struct RequireTracer : AstVisitor
else
return nullptr;
}
AstNode* getDependent(AstNode* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local];
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return expr->expr;
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return expr->expr;
else if (AstExprCall* expr = node->as<AstExprCall>(); expr && expr->self)
return expr->func->as<AstExprIndexName>()->expr;
else if (AstExprGroup* expr = node->as<AstExprGroup>())
return expr->expr;
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
return expr->annotation;
else if (AstTypeGroup* expr = node->as<AstTypeGroup>())
return expr->type;
else if (AstTypeTypeof* expr = node->as<AstTypeTypeof>())
return expr->expr;
else
return nullptr;
}
void process()
{
ModuleInfo moduleContext{currentModuleName};
// seed worklist with require arguments
work.reserve(requireCalls.size());
for (AstExprCall* require : requireCalls)
work.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work.size(); ++i)
if (AstExpr* dep = getDependent(work[i]))
work.push_back(dep);
// resolve all expressions to a module info
for (size_t i = work.size(); i > 0; --i)
if (FFlag::LuauExtendedSimpleRequire)
{
AstExpr* expr = work[i - 1];
// seed worklist with require arguments
work.reserve(requireCalls.size());
// when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr))
continue;
for (AstExprCall* require : requireCalls)
work.push_back(require->args.data[0]);
std::optional<ModuleInfo> info;
if (AstExpr* dep = getDependent(expr))
// push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work.size(); ++i)
{
const ModuleInfo* context = result.exprs.find(dep);
if (AstNode* dep = getDependent(work[i]))
work.push_back(dep);
}
// locals just inherit their dependent context, no resolution required
if (expr->is<AstExprLocal>())
info = context ? std::optional<ModuleInfo>(*context) : std::nullopt;
// resolve all expressions to a module info
for (size_t i = work.size(); i > 0; --i)
{
AstNode* expr = work[i - 1];
// when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr))
continue;
std::optional<ModuleInfo> info;
if (AstNode* dep = getDependent(expr))
{
const ModuleInfo* context = result.exprs.find(dep);
if (context && expr->is<AstExprLocal>())
info = *context; // locals just inherit their dependent context, no resolution required
else if (context && (expr->is<AstExprGroup>() || expr->is<AstTypeGroup>()))
info = *context; // simple group nodes propagate their value
else if (context && (expr->is<AstTypeTypeof>() || expr->is<AstExprTypeAssertion>()))
info = *context; // typeof type annotations will resolve to the typeof content
else if (AstExpr* asExpr = expr->asExpr())
info = fileResolver->resolveModule(context, asExpr);
}
else if (AstExpr* asExpr = expr->asExpr())
{
info = fileResolver->resolveModule(&moduleContext, asExpr);
}
if (info)
result.exprs[expr] = std::move(*info);
}
}
else
{
// seed worklist with require arguments
work_DEPRECATED.reserve(requireCalls.size());
for (AstExprCall* require : requireCalls)
work_DEPRECATED.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work_DEPRECATED.size(); ++i)
if (AstExpr* dep = getDependent_DEPRECATED(work_DEPRECATED[i]))
work_DEPRECATED.push_back(dep);
// resolve all expressions to a module info
for (size_t i = work_DEPRECATED.size(); i > 0; --i)
{
AstExpr* expr = work_DEPRECATED[i - 1];
// when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr))
continue;
std::optional<ModuleInfo> info;
if (AstExpr* dep = getDependent_DEPRECATED(expr))
{
const ModuleInfo* context = result.exprs.find(dep);
// locals just inherit their dependent context, no resolution required
if (expr->is<AstExprLocal>())
info = context ? std::optional<ModuleInfo>(*context) : std::nullopt;
else
info = fileResolver->resolveModule(context, expr);
}
else
info = fileResolver->resolveModule(context, expr);
}
else
{
info = fileResolver->resolveModule(&moduleContext, expr);
}
{
info = fileResolver->resolveModule(&moduleContext, expr);
}
if (info)
result.exprs[expr] = std::move(*info);
if (info)
result.exprs[expr] = std::move(*info);
}
}
// resolve all requires according to their argument
@ -150,7 +224,8 @@ struct RequireTracer : AstVisitor
ModuleName currentModuleName;
DenseHashMap<AstLocal*, AstExpr*> locals;
std::vector<AstExpr*> work;
std::vector<AstExpr*> work_DEPRECATED;
std::vector<AstNode*> work;
std::vector<AstExprCall*> requireCalls;
};

View file

@ -211,6 +211,16 @@ void Scope::inheritRefinements(const ScopePtr& childScope)
}
}
bool Scope::shouldWarnGlobal(std::string name) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
if (current->globalsToWarn.contains(name))
return true;
}
return false;
}
bool subsumesStrict(Scope* left, Scope* right)
{
while (right)

View file

@ -2,6 +2,7 @@
#include "Luau/Simplify.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Set.h"
@ -14,6 +15,7 @@
LUAU_FASTINT(LuauTypeReductionRecursionLimit)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_DYNAMIC_FASTINTVARIABLE(LuauSimplificationComplexityLimit, 8);
LUAU_FASTFLAGVARIABLE(LuauFlagBasicIntersectFollows);
namespace Luau
{
@ -29,16 +31,16 @@ struct TypeSimplifier
int recursionDepth = 0;
TypeId mkNegation(TypeId ty);
TypeId mkNegation(TypeId ty) const;
TypeId intersectFromParts(std::set<TypeId> parts);
TypeId intersectUnionWithType(TypeId unionTy, TypeId right);
TypeId intersectUnionWithType(TypeId left, TypeId right);
TypeId intersectUnions(TypeId left, TypeId right);
TypeId intersectNegatedUnion(TypeId unionTy, TypeId right);
TypeId intersectNegatedUnion(TypeId left, TypeId right);
TypeId intersectTypeWithNegation(TypeId a, TypeId b);
TypeId intersectNegations(TypeId a, TypeId b);
TypeId intersectTypeWithNegation(TypeId left, TypeId right);
TypeId intersectNegations(TypeId left, TypeId right);
TypeId intersectIntersectionWithType(TypeId left, TypeId right);
@ -46,8 +48,8 @@ struct TypeSimplifier
// unions, intersections, or negations.
std::optional<TypeId> basicIntersect(TypeId left, TypeId right);
TypeId intersect(TypeId ty, TypeId discriminant);
TypeId union_(TypeId ty, TypeId discriminant);
TypeId intersect(TypeId left, TypeId right);
TypeId union_(TypeId left, TypeId right);
TypeId simplify(TypeId ty);
TypeId simplify(TypeId ty, DenseHashSet<TypeId>& seen);
@ -571,7 +573,7 @@ Relation relate(TypeId left, TypeId right)
return relate(left, right, seen);
}
TypeId TypeSimplifier::mkNegation(TypeId ty)
TypeId TypeSimplifier::mkNegation(TypeId ty) const
{
TypeId result = nullptr;
@ -1064,6 +1066,12 @@ TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right)
std::optional<TypeId> TypeSimplifier::basicIntersect(TypeId left, TypeId right)
{
if (FFlag::LuauFlagBasicIntersectFollows)
{
left = follow(left);
right = follow(right);
}
if (get<AnyType>(left) && get<ErrorType>(right))
return right;
if (get<AnyType>(right) && get<ErrorType>(left))
@ -1403,8 +1411,6 @@ TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet<TypeId>& seen)
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
TypeSimplifier s{builtinTypes, arena};
// fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str());
@ -1418,8 +1424,6 @@ SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
TypeSimplifier s{builtinTypes, arena};
TypeId res = s.intersectFromParts(std::move(parts));
@ -1429,8 +1433,6 @@ SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
TypeSimplifier s{builtinTypes, arena};
TypeId res = s.union_(left, right);

View file

@ -4,13 +4,15 @@
#include "Luau/Common.h"
#include "Luau/Clone.h"
#include "Luau/TxnLog.h"
#include "Luau/Type.h"
#include <algorithm>
#include <stdexcept>
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256);
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256)
LUAU_FASTFLAG(LuauSyntheticErrors)
namespace Luau
{
@ -50,11 +52,33 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
LUAU_ASSERT(ty->persistent);
return ty;
}
else if constexpr (std::is_same_v<T, ErrorType>)
else if constexpr (std::is_same_v<T, NoRefineType>)
{
LUAU_ASSERT(ty->persistent);
return ty;
}
else if constexpr (std::is_same_v<T, ErrorType>)
{
if (FFlag::LuauSyntheticErrors)
{
LUAU_ASSERT(ty->persistent || a.synthetic);
if (ty->persistent)
return ty;
// While this code intentionally works (and clones) even if `a.synthetic` is `std::nullopt`,
// we still assert above because we consider it a bug to have a non-persistent error type
// without any associated metadata. We should always use the persistent version in such cases.
ErrorType clone = ErrorType{};
clone.synthetic = a.synthetic;
return dest.addType(clone);
}
else
{
LUAU_ASSERT(ty->persistent);
return ty;
}
}
else if constexpr (std::is_same_v<T, UnknownType>)
{
LUAU_ASSERT(ty->persistent);
@ -74,9 +98,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf};
clone.generics = a.generics;
clone.genericPacks = a.genericPacks;
clone.magicFunction = a.magicFunction;
clone.dcrMagicFunction = a.dcrMagicFunction;
clone.dcrMagicRefinement = a.dcrMagicRefinement;
clone.magic = a.magic;
clone.tags = a.tags;
clone.argNames = a.argNames;
clone.isCheckedFunction = a.isCheckedFunction;
@ -127,7 +149,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
return dest.addType(NegationType{a.ty});
else if constexpr (std::is_same_v<T, TypeFunctionInstanceType>)
{
TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncBody};
TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncData};
return dest.addType(std::move(clone));
}
else

View file

@ -5,6 +5,7 @@
#include "Luau/Common.h"
#include "Luau/Error.h"
#include "Luau/Normalize.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h"
#include "Luau/Substitution.h"
@ -20,7 +21,8 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity, false);
LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity)
LUAU_FASTFLAGVARIABLE(LuauSubtypingFixTailPack)
namespace Luau
{
@ -258,43 +260,32 @@ SubtypingResult SubtypingResult::any(const std::vector<SubtypingResult>& results
struct ApplyMappedGenerics : Substitution
{
using MappedGenerics = DenseHashMap<TypeId, SubtypingEnvironment::GenericBounds>;
using MappedGenericPacks = DenseHashMap<TypePackId, TypePackId>;
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
MappedGenerics& mappedGenerics;
MappedGenericPacks& mappedGenericPacks;
SubtypingEnvironment& env;
ApplyMappedGenerics(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
MappedGenerics& mappedGenerics,
MappedGenericPacks& mappedGenericPacks
)
ApplyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, SubtypingEnvironment& env)
: Substitution(TxnLog::empty(), arena)
, builtinTypes(builtinTypes)
, arena(arena)
, mappedGenerics(mappedGenerics)
, mappedGenericPacks(mappedGenericPacks)
, env(env)
{
}
bool isDirty(TypeId ty) override
{
return mappedGenerics.contains(ty);
return env.containsMappedType(ty);
}
bool isDirty(TypePackId tp) override
{
return mappedGenericPacks.contains(tp);
return env.containsMappedPack(tp);
}
TypeId clean(TypeId ty) override
{
const auto& bounds = mappedGenerics[ty];
const auto& bounds = env.getMappedTypeBounds(ty);
if (bounds.upperBound.empty())
return builtinTypes->unknownType;
@ -307,7 +298,12 @@ struct ApplyMappedGenerics : Substitution
TypePackId clean(TypePackId tp) override
{
return mappedGenericPacks[tp];
if (auto it = env.getMappedPackBounds(tp))
return *it;
// Clean is only called when isDirty found a pack bound
LUAU_ASSERT(!"Unreachable");
return nullptr;
}
bool ignoreChildren(TypeId ty) override
@ -325,19 +321,91 @@ struct ApplyMappedGenerics : Substitution
std::optional<TypeId> SubtypingEnvironment::applyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty)
{
ApplyMappedGenerics amg{builtinTypes, arena, mappedGenerics, mappedGenericPacks};
ApplyMappedGenerics amg{builtinTypes, arena, *this};
return amg.substitute(ty);
}
const TypeId* SubtypingEnvironment::tryFindSubstitution(TypeId ty) const
{
if (auto it = substitutions.find(ty))
return it;
if (parent)
return parent->tryFindSubstitution(ty);
return nullptr;
}
const SubtypingResult* SubtypingEnvironment::tryFindSubtypingResult(std::pair<TypeId, TypeId> subAndSuper) const
{
if (auto it = ephemeralCache.find(subAndSuper))
return it;
if (parent)
return parent->tryFindSubtypingResult(subAndSuper);
return nullptr;
}
bool SubtypingEnvironment::containsMappedType(TypeId ty) const
{
if (mappedGenerics.contains(ty))
return true;
if (parent)
return parent->containsMappedType(ty);
return false;
}
bool SubtypingEnvironment::containsMappedPack(TypePackId tp) const
{
if (mappedGenericPacks.contains(tp))
return true;
if (parent)
return parent->containsMappedPack(tp);
return false;
}
SubtypingEnvironment::GenericBounds& SubtypingEnvironment::getMappedTypeBounds(TypeId ty)
{
if (auto it = mappedGenerics.find(ty))
return *it;
if (parent)
return parent->getMappedTypeBounds(ty);
LUAU_ASSERT(!"Use containsMappedType before asking for bounds!");
return mappedGenerics[ty];
}
TypePackId* SubtypingEnvironment::getMappedPackBounds(TypePackId tp)
{
if (auto it = mappedGenericPacks.find(tp))
return it;
if (parent)
return parent->getMappedPackBounds(tp);
// This fallback is reachable in valid cases, unlike the final part of getMappedTypeBounds
return nullptr;
}
Subtyping::Subtyping(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter
)
: builtinTypes(builtinTypes)
, arena(typeArena)
, simplifier(simplifier)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, iceReporter(iceReporter)
{
}
@ -379,7 +447,10 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope
result.isSubtype = false;
}
SubtypingResult boundsResult = isCovariantWith(env, lowerBound, upperBound, scope);
SubtypingEnvironment boundsEnv;
boundsEnv.parent = &env;
SubtypingResult boundsResult = isCovariantWith(boundsEnv, lowerBound, upperBound, scope);
boundsResult.reasoning.clear();
result.andAlso(boundsResult);
@ -442,20 +513,30 @@ struct SeenSetPopper
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy, NotNull<Scope> scope)
{
UnifierCounters& counters = normalizer->sharedState->counters;
RecursionCounter rc(&counters.recursionCount);
if (counters.recursionLimit > 0 && counters.recursionLimit < counters.recursionCount)
{
SubtypingResult result;
result.normalizationTooComplex = true;
return result;
}
subTy = follow(subTy);
superTy = follow(superTy);
if (TypeId* subIt = env.substitutions.find(subTy); subIt && *subIt)
if (const TypeId* subIt = env.tryFindSubstitution(subTy); subIt && *subIt)
subTy = *subIt;
if (TypeId* superIt = env.substitutions.find(superTy); superIt && *superIt)
if (const TypeId* superIt = env.tryFindSubstitution(superTy); superIt && *superIt)
superTy = *superIt;
SubtypingResult* cachedResult = resultCache.find({subTy, superTy});
const SubtypingResult* cachedResult = resultCache.find({subTy, superTy});
if (cachedResult)
return *cachedResult;
cachedResult = env.ephemeralCache.find({subTy, superTy});
cachedResult = env.tryFindSubtypingResult({subTy, superTy});
if (cachedResult)
return *cachedResult;
@ -700,7 +781,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
std::vector<TypeId> headSlice(begin(superHead), begin(superHead) + headSize);
TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail);
if (TypePackId* other = env.mappedGenericPacks.find(*subTail))
if (TypePackId* other = env.getMappedPackBounds(*subTail))
// TODO: TypePath can't express "slice of a pack + its tail".
results.push_back(isCovariantWith(env, *other, superTailPack, scope).withSubComponent(TypePath::PackField::Tail));
else
@ -755,7 +836,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
std::vector<TypeId> headSlice(begin(subHead), begin(subHead) + headSize);
TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail);
if (TypePackId* other = env.mappedGenericPacks.find(*superTail))
if (TypePackId* other = env.getMappedPackBounds(*superTail))
// TODO: TypePath can't express "slice of a pack + its tail".
results.push_back(isContravariantWith(env, subTailPack, *other, scope).withSuperComponent(TypePath::PackField::Tail));
else
@ -778,7 +859,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
else
return SubtypingResult{false}
.withSuperComponent(TypePath::PackField::Tail)
.withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}});
.withError({scope->location, UnexpectedTypePackInSubtyping{FFlag::LuauSubtypingFixTailPack ? *superTail : *subTail}});
}
else
return {false};
@ -1316,6 +1397,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt, NotNull<Scope> scope)
{
return isCovariantWith(env, subMt->table, superMt->table, scope)
.withBothComponent(TypePath::TypeField::Table)
.andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable, scope).withBothComponent(TypePath::TypeField::Metatable));
}
@ -1389,6 +1471,19 @@ SubtypingResult Subtyping::isCovariantWith(
result.orElse(
isContravariantWith(env, subFunction->argTypes, superFunction->argTypes, scope).withBothComponent(TypePath::PackField::Arguments)
);
// If subtyping failed in the argument packs, we should check if there's a hidden variadic tail and try ignoring it.
// This might cause subtyping correctly because the sub type here may not have a hidden variadic tail or equivalent.
if (!result.isSubtype)
{
auto [arguments, tail] = flatten(superFunction->argTypes);
if (auto variadic = get<VariadicTypePack>(tail); variadic && variadic->hidden)
{
result.orElse(isContravariantWith(env, subFunction->argTypes, arena->addTypePack(TypePack{arguments}), scope)
.withBothComponent(TypePath::PackField::Arguments));
}
}
}
result.andAlso(isCovariantWith(env, subFunction->retTypes, superFunction->retTypes, scope).withBothComponent(TypePath::PackField::Returns));
@ -1688,6 +1783,9 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe
if (!get<GenericType>(subTy))
return false;
if (!env.mappedGenerics.find(subTy) && env.containsMappedType(subTy))
iceReporter->ice("attempting to modify bounds of a potentially visited generic");
env.mappedGenerics[subTy].upperBound.insert(superTy);
}
else
@ -1695,6 +1793,9 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe
if (!get<GenericType>(superTy))
return false;
if (!env.mappedGenerics.find(superTy) && env.containsMappedType(superTy))
iceReporter->ice("attempting to modify bounds of a potentially visited generic");
env.mappedGenerics[superTy].lowerBound.insert(subTy);
}
@ -1740,7 +1841,7 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePac
if (!get<GenericTypePack>(subTp))
return false;
if (TypePackId* m = env.mappedGenericPacks.find(subTp))
if (TypePackId* m = env.getMappedPackBounds(subTp))
return *m == superTp;
env.mappedGenericPacks[subTp] = superTp;
@ -1761,7 +1862,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse)
std::pair<TypeId, ErrorVec> Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull<Scope> scope)
{
TypeFunctionContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}};
TypeFunctionContext context{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}};
TypeId function = arena->addType(*functionInstance);
FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true);
ErrorVec errors;

View file

@ -4,6 +4,7 @@
#include "Luau/Common.h"
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauSymbolEquality)
namespace Luau
{
@ -14,7 +15,7 @@ bool Symbol::operator==(const Symbol& rhs) const
return local == rhs.local;
else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
else if (FFlag::LuauSolverV2)
else if (FFlag::LuauSolverV2 || FFlag::LuauSymbolEquality)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else
return false;

View file

@ -6,19 +6,15 @@
#include "Luau/Type.h"
#include "Luau/ToString.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h"
LUAU_FASTFLAGVARIABLE(LuauDontInPlaceMutateTableType)
LUAU_FASTFLAGVARIABLE(LuauAllowNonSharedTableTypesInLiteral)
namespace Luau
{
static bool isLiteral(const AstExpr* expr)
{
return (
expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() ||
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>()
);
}
// A fast approximation of subTy <: superTy
static bool fastIsSubtype(TypeId subTy, TypeId superTy)
{
@ -243,6 +239,8 @@ TypeId matchLiteralType(
return exprType;
}
DenseHashSet<AstExprConstantString*> keysToDelete{nullptr};
for (const AstExprTable::Item& item : exprTable->items)
{
if (isRecord(item))
@ -254,8 +252,19 @@ TypeId matchLiteralType(
Property& prop = it->second;
// Table literals always initially result in shared read-write types
LUAU_ASSERT(prop.isShared());
if (FFlag::LuauAllowNonSharedTableTypesInLiteral)
{
// If we encounter a duplcate property, we may have already
// set it to be read-only. If that's the case, the only thing
// that will definitely crash is trying to access a write
// only property.
LUAU_ASSERT(!prop.isWriteOnly());
}
else
{
// Table literals always initially result in shared read-write types
LUAU_ASSERT(prop.isShared());
}
TypeId propTy = *prop.readTy;
auto it2 = expectedTableTy->props.find(keyStr);
@ -287,7 +296,10 @@ TypeId matchLiteralType(
else
tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType};
tableTy->props.erase(keyStr);
if (FFlag::LuauDontInPlaceMutateTableType)
keysToDelete.insert(item.key->as<AstExprConstantString>());
else
tableTy->props.erase(keyStr);
}
// If it's just an extra property and the expected type
@ -381,15 +393,11 @@ TypeId matchLiteralType(
const TypeId* keyTy = astTypes->find(item.key);
LUAU_ASSERT(keyTy);
TypeId tKey = follow(*keyTy);
if (get<BlockedType>(tKey))
toBlock.push_back(tKey);
LUAU_ASSERT(!is<BlockedType>(tKey));
const TypeId* propTy = astTypes->find(item.value);
LUAU_ASSERT(propTy);
TypeId tProp = follow(*propTy);
if (get<BlockedType>(tProp))
toBlock.push_back(tProp);
LUAU_ASSERT(!is<BlockedType>(tProp));
// Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings)
if (!item.key->as<AstExprConstantString>() && expectedTableTy->indexer)
(*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType;
@ -398,6 +406,16 @@ TypeId matchLiteralType(
LUAU_ASSERT(!"Unexpected");
}
if (FFlag::LuauDontInPlaceMutateTableType)
{
for (const auto& key : keysToDelete)
{
const AstArray<char>& s = key->value;
std::string keyStr{s.data, s.data + s.size};
tableTy->props.erase(keyStr);
}
}
// Keys that the expectedType says we should have, but that aren't
// specified by the AST fragment.
//

View file

@ -269,6 +269,12 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, NoRefineType>)
{
formatAppend(result, "NoRefineType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, UnknownType>)
{
formatAppend(result, "UnknownType %d", index);
@ -414,7 +420,7 @@ void StateDot::visitChildren(TypePackId tp, int index)
finishNodeLabel(tp);
finishNode();
}
else if (get<Unifiable::Error>(tp))
else if (get<ErrorTypePack>(tp))
{
formatAppend(result, "ErrorTypePack %d", index);
finishNodeLabel(tp);

View file

@ -20,6 +20,7 @@
#include <string>
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauSyntheticErrors)
/*
* Enables increasing levels of verbosity for Luau type names when stringifying.
@ -38,7 +39,7 @@ LUAU_FASTFLAG(LuauSolverV2)
* 3: Suffix free/generic types with their scope pointer, if present.
*/
LUAU_FASTINTVARIABLE(DebugLuauVerboseTypeNames, 0)
LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort, false)
LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort)
namespace Luau
{
@ -856,6 +857,11 @@ struct TypeStringifier
state.emit("any");
}
void operator()(TypeId, const NoRefineType&)
{
state.emit("*no-refine*");
}
void operator()(TypeId, const UnionType& uv)
{
if (state.hasSeen(&uv))
@ -865,6 +871,8 @@ struct TypeStringifier
return;
}
LUAU_ASSERT(uv.options.size() > 1);
bool optional = false;
bool hasNonNilDisjunct = false;
@ -873,7 +881,7 @@ struct TypeStringifier
{
el = follow(el);
if (isNil(el))
if (state.opts.useQuestionMarks && isNil(el))
{
optional = true;
continue;
@ -991,7 +999,15 @@ struct TypeStringifier
void operator()(TypeId, const ErrorType& tv)
{
state.result.error = true;
state.emit("*error-type*");
if (FFlag::LuauSyntheticErrors && tv.synthetic)
{
state.emit("*error-type<");
stringify(*tv.synthetic);
state.emit(">*");
}
else
state.emit("*error-type*");
}
void operator()(TypeId, const LazyType& ltv)
@ -1040,6 +1056,7 @@ struct TypeStringifier
state.emit(tfitv.userFuncName->value);
else
state.emit(tfitv.function->name);
state.emit("<");
bool comma = false;
@ -1165,10 +1182,18 @@ struct TypePackStringifier
state.unsee(&tp);
}
void operator()(TypePackId, const Unifiable::Error& error)
void operator()(TypePackId, const ErrorTypePack& error)
{
state.result.error = true;
state.emit("*error-type*");
if (FFlag::LuauSyntheticErrors && error.synthetic)
{
state.emit("*");
stringify(*error.synthetic);
state.emit("*");
}
else
state.emit("*error-type*");
}
void operator()(TypePackId, const VariadicTypePack& pack)
@ -1840,6 +1865,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
}
else if constexpr (std::is_same_v<T, EqualityConstraint>)
return "equality: " + tos(c.resultType) + " ~ " + tos(c.assignmentType);
else if constexpr (std::is_same_v<T, TableCheckConstraint>)
return "table_check " + tos(c.expectedType) + " :> " + tos(c.exprType);
else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
};

File diff suppressed because it is too large Load diff

View file

@ -93,8 +93,8 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull<TypeArena> arena)
if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead)
{
TypeId leftTy = arena->addType((*leftRep)->pending);
TypeId rightTy = arena->addType(rightRep->pending);
TypeId leftTy = arena->addType((*leftRep)->pending.clone());
TypeId rightTy = arena->addType(rightRep->pending.clone());
typeVarChanges[ty]->pending.ty = IntersectionType{{leftTy, rightTy}};
}
else
@ -170,8 +170,8 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull<TypeArena> arena)
if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead)
{
TypeId leftTy = arena->addType((*leftRep)->pending);
TypeId rightTy = arena->addType(rightRep->pending);
TypeId leftTy = arena->addType((*leftRep)->pending.clone());
TypeId rightTy = arena->addType(rightRep->pending.clone());
if (follow(leftTy) == follow(rightTy))
typeVarChanges[ty] = std::move(rightRep);
@ -217,7 +217,7 @@ TxnLog TxnLog::inverse()
for (auto& [ty, _rep] : typeVarChanges)
{
if (!_rep->dead)
inversed.typeVarChanges[ty] = std::make_unique<PendingType>(*ty);
inversed.typeVarChanges[ty] = std::make_unique<PendingType>(ty->clone());
}
for (auto& [tp, _rep] : typePackChanges)
@ -292,7 +292,7 @@ PendingType* TxnLog::queue(TypeId ty)
auto& pending = typeVarChanges[ty];
if (!pending || (*pending).dead)
{
pending = std::make_unique<PendingType>(*ty);
pending = std::make_unique<PendingType>(ty->clone());
pending->pending.owningArena = nullptr;
}

View file

@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -478,24 +479,12 @@ bool hasLength(TypeId ty, DenseHashSet<TypeId>& seen, int* recursionCount)
return false;
}
FreeType::FreeType(TypeLevel level)
// New constructors
FreeType::FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound)
: index(Unifiable::freshIndex())
, level(level)
, scope(nullptr)
{
}
FreeType::FreeType(Scope* scope)
: index(Unifiable::freshIndex())
, level{}
, scope(scope)
{
}
FreeType::FreeType(Scope* scope, TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
, lowerBound(lowerBound)
, upperBound(upperBound)
{
}
@ -507,6 +496,40 @@ FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound)
{
}
FreeType::FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
, lowerBound(lowerBound)
, upperBound(upperBound)
{
}
// Old constructors
FreeType::FreeType(TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(nullptr)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
FreeType::FreeType(Scope* scope)
: index(Unifiable::freshIndex())
, level{}
, scope(scope)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
FreeType::FreeType(Scope* scope, TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
GenericType::GenericType()
: index(Unifiable::freshIndex())
, name("g" + std::to_string(index))
@ -554,12 +577,12 @@ BlockedType::BlockedType()
{
}
Constraint* BlockedType::getOwner() const
const Constraint* BlockedType::getOwner() const
{
return owner;
}
void BlockedType::setOwner(Constraint* newOwner)
void BlockedType::setOwner(const Constraint* newOwner)
{
LUAU_ASSERT(owner == nullptr);
@ -569,7 +592,7 @@ void BlockedType::setOwner(Constraint* newOwner)
owner = newOwner;
}
void BlockedType::replaceOwner(Constraint* newOwner)
void BlockedType::replaceOwner(const Constraint* newOwner)
{
owner = newOwner;
}
@ -999,6 +1022,11 @@ Type& Type::operator=(const Type& rhs)
return *this;
}
Type Type::clone() const
{
return *this;
}
TypeId makeFunction(
TypeArena& arena,
std::optional<TypeId> selfType,
@ -1030,6 +1058,7 @@ BuiltinTypes::BuiltinTypes()
, unknownType(arena->addType(Type{UnknownType{}, /*persistent*/ true}))
, neverType(arena->addType(Type{NeverType{}, /*persistent*/ true}))
, errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true}))
, noRefineType(arena->addType(Type{NoRefineType{}, /*persistent*/ true}))
, falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true}))
, truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true}))
, optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true}))
@ -1039,7 +1068,7 @@ BuiltinTypes::BuiltinTypes()
, unknownTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{unknownType}, /*persistent*/ true}))
, neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true}))
, uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true}))
, errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true}))
, errorTypePack(arena->addTypePack(TypePackVar{ErrorTypePack{}, /*persistent*/ true}))
{
freeze(*arena);
}

View file

@ -2,7 +2,8 @@
#include "Luau/TypeArena.h"
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false);
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena);
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -22,7 +23,34 @@ TypeId TypeArena::addTV(Type&& tv)
return allocated;
}
TypeId TypeArena::freshType(TypeLevel level)
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{level, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, Scope* scope)
{
TypeId allocated = types.allocate(FreeType{scope, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, Scope* scope, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{scope, level, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType_DEPRECATED(TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{level});
@ -31,7 +59,7 @@ TypeId TypeArena::freshType(TypeLevel level)
return allocated;
}
TypeId TypeArena::freshType(Scope* scope)
TypeId TypeArena::freshType_DEPRECATED(Scope* scope)
{
TypeId allocated = types.allocate(FreeType{scope});
@ -40,7 +68,7 @@ TypeId TypeArena::freshType(Scope* scope)
return allocated;
}
TypeId TypeArena::freshType(Scope* scope, TypeLevel level)
TypeId TypeArena::freshType_DEPRECATED(Scope* scope, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{scope, level});

View file

@ -145,6 +145,12 @@ public:
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"), std::nullopt, Location());
}
AstType* operator()(const NoRefineType&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*no-refine*"), std::nullopt, Location());
}
AstType* operator()(const TableType& ttv)
{
RecursionCounter counter(&count);
@ -255,24 +261,24 @@ public:
if (hasSeen(&ftv))
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>"), std::nullopt, Location());
AstArray<AstGenericType> generics;
AstArray<AstGenericType*> generics;
generics.size = ftv.generics.size();
generics.data = static_cast<AstGenericType*>(allocator->allocate(sizeof(AstGenericType) * generics.size));
generics.data = static_cast<AstGenericType**>(allocator->allocate(sizeof(AstGenericType) * generics.size));
size_t numGenerics = 0;
for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it)
{
if (auto gtv = get<GenericType>(*it))
generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr};
generics.data[numGenerics++] = allocator->alloc<AstGenericType>(Location(), AstName(gtv->name.c_str()), nullptr);
}
AstArray<AstGenericTypePack> genericPacks;
AstArray<AstGenericTypePack*> genericPacks;
genericPacks.size = ftv.genericPacks.size();
genericPacks.data = static_cast<AstGenericTypePack*>(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size));
genericPacks.data = static_cast<AstGenericTypePack**>(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size));
size_t numGenericPacks = 0;
for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it)
{
if (auto gtv = get<GenericTypePack>(*it))
genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr};
genericPacks.data[numGenericPacks++] = allocator->alloc<AstGenericTypePack>(Location(), AstName(gtv->name.c_str()), nullptr);
}
AstArray<AstType*> argTypes;
@ -323,7 +329,7 @@ public:
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}
);
}
AstType* operator()(const Unifiable::Error&)
AstType* operator()(const ErrorType&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>"), std::nullopt, Location());
}
@ -380,8 +386,12 @@ public:
}
AstType* operator()(const NegationType& ntv)
{
// FIXME: do the same thing we do with ErrorType
throw InternalCompilerError("Cannot convert NegationType into AstNode");
AstArray<AstTypeOrPack> params;
params.size = 1;
params.data = static_cast<AstTypeOrPack*>(allocator->allocate(sizeof(AstType*)));
params.data[0] = AstTypeOrPack{Luau::visit(*this, ntv.ty->ty), nullptr};
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("negate"), std::nullopt, Location(), true, params);
}
AstType* operator()(const TypeFunctionInstanceType& tfit)
{
@ -452,7 +462,7 @@ public:
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("free"));
}
AstTypePack* operator()(const Unifiable::Error&) const
AstTypePack* operator()(const ErrorTypePack&) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("Unifiable<Error>"));
}

View file

@ -7,7 +7,6 @@
#include "Luau/DcrLogger.h"
#include "Luau/DenseHash.h"
#include "Luau/Error.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Instantiation.h"
#include "Luau/Metamethods.h"
#include "Luau/Normalize.h"
@ -27,11 +26,11 @@
#include "Luau/VisitType.h"
#include <algorithm>
#include <iostream>
#include <ostream>
LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -173,7 +172,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor
DenseHashSet<TypeId> mentionedFunctions{nullptr};
DenseHashSet<TypePackId> mentionedFunctionPacks{nullptr};
InternalTypeFunctionFinder(std::vector<TypeId>& declStack)
explicit InternalTypeFunctionFinder(std::vector<TypeId>& declStack)
{
TypeFunctionFinder f;
for (TypeId fn : declStack)
@ -266,6 +265,8 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor
void check(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
DcrLogger* logger,
@ -275,7 +276,7 @@ void check(
{
LUAU_TIMETRACE_SCOPE("check", "Typechecking");
TypeChecker2 typeChecker{builtinTypes, unifierState, limits, logger, &sourceModule, module};
TypeChecker2 typeChecker{builtinTypes, simplifier, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module};
typeChecker.visit(sourceModule.root);
@ -292,6 +293,8 @@ void check(
TypeChecker2::TypeChecker2(
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
DcrLogger* logger,
@ -299,13 +302,15 @@ TypeChecker2::TypeChecker2(
Module* module
)
: builtinTypes(builtinTypes)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime)
, logger(logger)
, limits(limits)
, ice(unifierState->iceHandler)
, sourceModule(sourceModule)
, module(module)
, normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true}
, _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{unifierState->iceHandler}}
, _subtyping{builtinTypes, NotNull{&module->internalTypes}, simplifier, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}}
, subtyping(&_subtyping)
{
}
@ -483,19 +488,22 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l
return instance;
seenTypeFunctionInstances.insert(instance);
ErrorVec errors = reduceTypeFunctions(
instance,
location,
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits},
true
)
.errors;
ErrorVec errors =
reduceTypeFunctions(
instance,
location,
TypeFunctionContext{
NotNull{&module->internalTypes}, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits
},
true
)
.errors;
if (!isErrorSuppressing(location, instance))
reportErrors(std::move(errors));
return instance;
}
TypePackId TypeChecker2::lookupPack(AstExpr* expr)
TypePackId TypeChecker2::lookupPack(AstExpr* expr) const
{
// If a type isn't in the type graph, it probably means that a recursion limit was exceeded.
// We'll just return anyType in these cases. Typechecking against any is very fast and this
@ -545,7 +553,7 @@ TypeId TypeChecker2::lookupAnnotation(AstType* annotation)
return checkForTypeFunctionInhabitance(follow(*ty), annotation->location);
}
std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annotation)
std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) const
{
TypePackId* tp = module->astResolvedTypePacks.find(annotation);
if (tp != nullptr)
@ -553,7 +561,7 @@ std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annota
return {};
}
TypeId TypeChecker2::lookupExpectedType(AstExpr* expr)
TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) const
{
if (TypeId* ty = module->astExpectedTypes.find(expr))
return follow(*ty);
@ -561,7 +569,7 @@ TypeId TypeChecker2::lookupExpectedType(AstExpr* expr)
return builtinTypes->anyType;
}
TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena)
TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) const
{
if (TypeId* ty = module->astExpectedTypes.find(expr))
return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt});
@ -585,7 +593,7 @@ TypePackId TypeChecker2::reconstructPack(AstArray<AstExpr*> exprs, TypeArena& ar
return arena.addTypePack(TypePack{head, tail});
}
Scope* TypeChecker2::findInnermostScope(Location location)
Scope* TypeChecker2::findInnermostScope(Location location) const
{
Scope* bestScope = module->getModuleScope().get();
@ -1008,7 +1016,8 @@ void TypeChecker2::visit(AstStatForIn* forInStatement)
{
reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location);
}
else if (std::optional<TypeId> iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location))
else if (std::optional<TypeId> iterMmTy =
findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location))
{
Instantiation instantiation{TxnLog::empty(), &arena, builtinTypes, TypeLevel{}, scope};
@ -1193,8 +1202,6 @@ void TypeChecker2::visit(AstStatTypeAlias* stat)
void TypeChecker2::visit(AstStatTypeFunction* stat)
{
// TODO: add type checking for user-defined type functions
reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}});
}
void TypeChecker2::visit(AstTypeList types)
@ -1345,7 +1352,17 @@ void TypeChecker2::visit(AstExprGlobal* expr)
{
NotNull<Scope> scope = stack.back();
if (!scope->lookup(expr->name))
{
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
}
else
{
if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value))
{
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
warnedGlobals.insert(expr->name.value);
}
}
}
void TypeChecker2::visit(AstExprVarargs* expr)
@ -1433,10 +1450,11 @@ void TypeChecker2::visitCall(AstExprCall* call)
TypePackId argsTp = module->internalTypes.addTypePack(args);
if (auto ftv = get<FunctionType>(follow(*originalCallTy)))
{
if (ftv->dcrMagicTypeCheck)
if (ftv->magic)
{
ftv->dcrMagicTypeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope});
return;
bool usedMagic = ftv->magic->typeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope});
if (usedMagic)
return;
}
}
@ -1444,7 +1462,9 @@ void TypeChecker2::visitCall(AstExprCall* call)
OverloadResolver resolver{
builtinTypes,
NotNull{&module->internalTypes},
simplifier,
NotNull{&normalizer},
typeFunctionRuntime,
NotNull{stack.back()},
ice,
limits,
@ -1540,7 +1560,7 @@ void TypeChecker2::visit(AstExprCall* call)
visitCall(call);
}
std::optional<TypeId> TypeChecker2::tryStripUnionFromNil(TypeId ty)
std::optional<TypeId> TypeChecker2::tryStripUnionFromNil(TypeId ty) const
{
if (const UnionType* utv = get<UnionType>(ty))
{
@ -1618,8 +1638,7 @@ void TypeChecker2::indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const M
indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType);
else
{
LUAU_ASSERT(tt || get<PrimitiveType>(follow(metaTable->table)));
// CLI-122161: We're not handling unions correctly (probably).
reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location);
}
}
@ -1826,11 +1845,10 @@ void TypeChecker2::visit(AstExprFunction* fn)
void TypeChecker2::visit(AstExprTable* expr)
{
// TODO!
for (const AstExprTable::Item& item : expr->items)
{
if (item.key)
visit(item.key, ValueContext::LValue);
visit(item.key, ValueContext::RValue);
visit(item.value, ValueContext::RValue);
}
}
@ -2078,7 +2096,10 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey)
}
else
{
expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})});
expectedRets = module->internalTypes.addTypePack(
{FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, scope, TypeLevel{})
: module->internalTypes.freshType_DEPRECATED(scope, TypeLevel{})}
);
}
TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets));
@ -2330,7 +2351,8 @@ TypeId TypeChecker2::flattenPack(TypePackId pack)
return *fst;
else if (auto ftp = get<FreeTypePack>(pack))
{
TypeId result = module->internalTypes.addType(FreeType{ftp->scope});
TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, ftp->scope)
: module->internalTypes.addType(FreeType{ftp->scope});
TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope});
TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack));
@ -2339,7 +2361,7 @@ TypeId TypeChecker2::flattenPack(TypePackId pack)
return result;
}
else if (get<Unifiable::Error>(pack))
else if (get<ErrorTypePack>(pack))
return builtinTypes->errorRecoveryType();
else if (finite(pack) && size(pack) == 0)
return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil`
@ -2347,30 +2369,30 @@ TypeId TypeChecker2::flattenPack(TypePackId pack)
ice->ice("flattenPack got a weird pack!");
}
void TypeChecker2::visitGenerics(AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks)
void TypeChecker2::visitGenerics(AstArray<AstGenericType*> generics, AstArray<AstGenericTypePack*> genericPacks)
{
DenseHashSet<AstName> seen{AstName{}};
for (const auto& g : generics)
for (const auto* g : generics)
{
if (seen.contains(g.name))
reportError(DuplicateGenericParameter{g.name.value}, g.location);
if (seen.contains(g->name))
reportError(DuplicateGenericParameter{g->name.value}, g->location);
else
seen.insert(g.name);
seen.insert(g->name);
if (g.defaultValue)
visit(g.defaultValue);
if (g->defaultValue)
visit(g->defaultValue);
}
for (const auto& g : genericPacks)
for (const auto* g : genericPacks)
{
if (seen.contains(g.name))
reportError(DuplicateGenericParameter{g.name.value}, g.location);
if (seen.contains(g->name))
reportError(DuplicateGenericParameter{g->name.value}, g->location);
else
seen.insert(g.name);
seen.insert(g->name);
if (g.defaultValue)
visit(g.defaultValue);
if (g->defaultValue)
visit(g->defaultValue);
}
}
@ -2392,6 +2414,8 @@ void TypeChecker2::visit(AstType* ty)
return visit(t);
else if (auto t = ty->as<AstTypeIntersection>())
return visit(t);
else if (auto t = ty->as<AstTypeGroup>())
return visit(t->type);
}
void TypeChecker2::visit(AstTypeReference* ty)
@ -3012,10 +3036,8 @@ PropertyType TypeChecker2::hasIndexTypeFromType(
if (tt->indexer)
{
TypeId indexType = follow(tt->indexer->indexType);
if (isPrim(indexType, PrimitiveType::String))
return {NormalizationResult::True, {tt->indexer->indexResultType}};
// If the indexer looks like { [any] : _} - the prop lookup should be allowed!
else if (get<AnyType>(indexType) || get<UnknownType>(indexType))
TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}});
if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, simplifier, *ice))
return {NormalizationResult::True, {tt->indexer->indexResultType}};
}

File diff suppressed because it is too large Load diff

View file

@ -3,6 +3,7 @@
#include "Luau/DenseHash.h"
#include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/TypeFunction.h"
#include "Luau/Type.h"
#include "Luau/TypePack.h"

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -23,18 +23,18 @@
#include <algorithm>
#include <iterator>
LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false)
LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes)
LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165)
LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000)
LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000)
LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300)
LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false)
LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false)
LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections, false)
LUAU_FASTFLAGVARIABLE(LuauOldSolverCreatesChildScopePointers)
LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau
{
@ -265,11 +265,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
ScopePtr parentScope = environmentScope.value_or(globalScope);
ScopePtr moduleScope = std::make_shared<Scope>(parentScope);
if (module.cyclic)
moduleScope->returnType = addTypePack(TypePack{{anyType}, std::nullopt});
else
moduleScope->returnType = freshTypePack(moduleScope);
moduleScope->returnType = freshTypePack(moduleScope);
moduleScope->varargPack = anyTypePack;
currentModule->scopes.push_back(std::make_pair(module.root->location, moduleScope));
@ -767,8 +763,12 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& state
struct Demoter : Substitution
{
Demoter(TypeArena* arena)
TypeArena* arena = nullptr;
NotNull<BuiltinTypes> builtins;
Demoter(TypeArena* arena, NotNull<BuiltinTypes> builtins)
: Substitution(TxnLog::empty(), arena)
, arena(arena)
, builtins(builtins)
{
}
@ -794,7 +794,8 @@ struct Demoter : Substitution
{
auto ftv = get<FreeType>(ty);
LUAU_ASSERT(ftv);
return addType(FreeType{demotedLevel(ftv->level)});
return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtins, demotedLevel(ftv->level))
: addType(FreeType{demotedLevel(ftv->level)});
}
TypePackId clean(TypePackId tp) override
@ -841,7 +842,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& retur
}
}
Demoter demoter{&currentModule->internalTypes};
Demoter demoter{&currentModule->internalTypes, builtinTypes};
demoter.demote(expectedTypes);
TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type;
@ -958,7 +959,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assig
else if (auto tail = valueIter.tail())
{
TypePackId tailPack = follow(*tail);
if (get<Unifiable::Error>(tailPack))
if (get<ErrorTypePack>(tailPack))
right = errorRecoveryType(scope);
else if (auto vtp = get<VariadicTypePack>(tailPack))
right = vtp->ty;
@ -1238,7 +1239,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
iterTy = freshType(scope);
unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location);
}
else if (get<Unifiable::Error>(callRetPack) || !first(callRetPack))
else if (get<ErrorTypePack>(callRetPack) || !first(callRetPack))
{
for (TypeId var : varTypes)
unify(errorRecoveryType(scope), var, scope, forin.location);
@ -1284,20 +1285,11 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
for (size_t i = 2; i < varTypes.size(); ++i)
unify(nilType, varTypes[i], scope, forin.location);
}
else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties)
else
{
for (TypeId var : varTypes)
unify(unknownType, var, scope, forin.location);
}
else
{
TypeId varTy = errorRecoveryType(loopScope);
for (TypeId var : varTypes)
unify(varTy, var, scope, forin.location);
reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"});
}
return check(loopScope, *forin.body);
}
@ -1975,7 +1967,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
*asMutable(varargPack) = TypePack{{head}, tail};
return WithPredicate{head};
}
if (get<ErrorType>(varargPack))
if (get<ErrorTypePack>(varargPack))
return WithPredicate{errorRecoveryType(scope)};
else if (auto vtp = get<VariadicTypePack>(varargPack))
return WithPredicate{vtp->ty};
@ -2005,7 +1997,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
unify(pack, retPack, scope, expr.location);
return {head, std::move(result.predicates)};
}
if (get<Unifiable::Error>(retPack))
if (get<ErrorTypePack>(retPack))
return {errorRecoveryType(scope), std::move(result.predicates)};
else if (auto vtp = get<VariadicTypePack>(retPack))
return {vtp->ty, std::move(result.predicates)};
@ -2804,34 +2796,19 @@ TypeId TypeChecker::checkRelationalOperation(
{
reportErrors(state.errors);
if (FFlag::LuauRemoveBadRelationalOperatorWarning)
// The original version of this check also produced this error when we had a union type.
// However, the old solver does not readily have the ability to discern if the union is comparable.
// This is the case when the lhs is e.g. a union of singletons and the rhs is the combined type.
// The new solver has much more powerful logic for resolving relational operators, but for now,
// we need to be conservative in the old solver to deliver a reasonable developer experience.
if (!isEquality && state.errors.empty() && isBoolean(leftType))
{
// The original version of this check also produced this error when we had a union type.
// However, the old solver does not readily have the ability to discern if the union is comparable.
// This is the case when the lhs is e.g. a union of singletons and the rhs is the combined type.
// The new solver has much more powerful logic for resolving relational operators, but for now,
// we need to be conservative in the old solver to deliver a reasonable developer experience.
if (!isEquality && state.errors.empty() && isBoolean(leftType))
{
reportError(
expr.location,
GenericError{
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
}
}
else
{
if (!isEquality && state.errors.empty() && (get<UnionType>(leftType) || isBoolean(leftType)))
{
reportError(
expr.location,
GenericError{
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
}
reportError(
expr.location,
GenericError{
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
}
return booleanType;
@ -2896,7 +2873,7 @@ TypeId TypeChecker::checkRelationalOperation(
std::optional<TypeId> metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true);
if (metamethod)
{
if (const FunctionType* ftv = get<FunctionType>(*metamethod))
if (const FunctionType* ftv = get<FunctionType>(follow(*metamethod)))
{
if (isEquality)
{
@ -3507,7 +3484,6 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
}
}
if (FFlag::LuauAcceptIndexingTableUnionsIntersections)
{
// We're going to have a whole vector.
std::vector<TableType*> tableTypes{};
@ -3658,57 +3634,6 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
return addType(IntersectionType{{resultTypes.begin(), resultTypes.end()}});
}
else
{
TableType* exprTable = getMutableTableType(exprType);
if (!exprTable)
{
reportError(TypeError{expr.expr->location, NotATable{exprType}});
return errorRecoveryType(scope);
}
if (value)
{
const auto& it = exprTable->props.find(value->value.data);
if (it != exprTable->props.end())
{
return it->second.type();
}
else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)
{
TypeId resultType = freshType(scope);
Property& property = exprTable->props[value->value.data];
property.setType(resultType);
property.location = expr.index->location;
return resultType;
}
}
if (exprTable->indexer)
{
const TableIndexer& indexer = *exprTable->indexer;
unify(indexType, indexer.indexType, scope, expr.index->location);
return indexer.indexResultType;
}
else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)
{
TypeId indexerType = freshType(exprTable->level);
unify(indexType, indexerType, scope, expr.location);
TypeId indexResultType = freshType(exprTable->level);
exprTable->indexer = TableIndexer{anyIfNonstrict(indexerType), anyIfNonstrict(indexResultType)};
return indexResultType;
}
else
{
/*
* If we use [] indexing to fetch a property from a sealed table that
* has no indexer, we have no idea if it will work so we just return any
* and hope for the best.
*/
return anyType;
}
}
}
// Answers the question: "Can I define another function with this name?"
@ -4163,7 +4088,7 @@ void TypeChecker::checkArgumentList(
if (argIter.tail())
{
TypePackId tail = *argIter.tail();
if (state.log.getMutable<Unifiable::Error>(tail))
if (state.log.getMutable<ErrorTypePack>(tail))
{
// Unify remaining parameters so we don't leave any free-types hanging around.
while (paramIter != endIter)
@ -4248,7 +4173,7 @@ void TypeChecker::checkArgumentList(
}
TypePackId tail = state.log.follow(*paramIter.tail());
if (state.log.getMutable<Unifiable::Error>(tail))
if (state.log.getMutable<ErrorTypePack>(tail))
{
// Function is variadic. Ok.
return;
@ -4384,7 +4309,7 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
WithPredicate<TypePackId> argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes);
TypePackId argPack = argListResult.type;
if (get<Unifiable::Error>(argPack))
if (get<ErrorTypePack>(argPack))
return WithPredicate{errorRecoveryTypePack(scope)};
TypePack* args = nullptr;
@ -4490,7 +4415,7 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
}
}
Demoter demoter{&currentModule->internalTypes};
Demoter demoter{&currentModule->internalTypes, builtinTypes};
demoter.demote(expectedTypes);
return expectedTypes;
@ -4588,10 +4513,10 @@ std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(
// When this function type has magic functions and did return something, we select that overload instead.
// TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution.
if (ftv->magicFunction)
if (ftv->magic)
{
// TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magicFunction(*this, scope, expr, argListResult))
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magic->handleOldSolver(*this, scope, expr, argListResult))
return std::make_unique<WithPredicate<TypePackId>>(std::move(*ret));
}
@ -4974,7 +4899,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
TypePackId modulePack = module->returnType;
if (get<Unifiable::Error>(modulePack))
if (get<ErrorTypePack>(modulePack))
return errorRecoveryType(scope);
std::optional<TypeId> moduleType = first(modulePack);
@ -5063,17 +4988,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c
{
// First try unifying with the original uninstantiated type
// but if that fails, try the instantiated one.
Unifier child = state.makeChildUnifier();
child.tryUnify(subTy, superTy, /*isFunctionCall*/ false);
if (!child.errors.empty())
std::unique_ptr<Unifier> child = state.makeChildUnifier();
child->tryUnify(subTy, superTy, /*isFunctionCall*/ false);
if (!child->errors.empty())
{
TypeId instantiated = instantiate(scope, subTy, state.location, &child.log);
TypeId instantiated = instantiate(scope, subTy, state.location, &child->log);
if (subTy == instantiated)
{
// Instantiating the argument made no difference, so just report any child errors
state.log.concat(std::move(child.log));
state.log.concat(std::move(child->log));
state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end());
state.errors.insert(state.errors.end(), child->errors.begin(), child->errors.end());
}
else
{
@ -5082,7 +5007,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c
}
else
{
state.log.concat(std::move(child.log));
state.log.concat(std::move(child->log));
}
}
}
@ -5287,6 +5212,13 @@ LUAU_NOINLINE void TypeChecker::reportErrorCodeTooComplex(const Location& locati
ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel)
{
ScopePtr scope = std::make_shared<Scope>(parent, subLevel);
if (FFlag::LuauOldSolverCreatesChildScopePointers)
{
scope->location = location;
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope));
return scope;
}
@ -5297,6 +5229,12 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio
ScopePtr scope = std::make_shared<Scope>(parent);
scope->level = parent->level;
scope->varargPack = parent->varargPack;
if (FFlag::LuauOldSolverCreatesChildScopePointers)
{
scope->location = location;
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope));
return scope;
@ -5342,7 +5280,8 @@ TypeId TypeChecker::freshType(const ScopePtr& scope)
TypeId TypeChecker::freshType(TypeLevel level)
{
return currentModule->internalTypes.addType(Type(FreeType(level)));
return FFlag::LuauFreeTypesMustHaveBounds ? currentModule->internalTypes.freshType(builtinTypes, level)
: currentModule->internalTypes.addType(Type(FreeType(level)));
}
TypeId TypeChecker::singletonType(bool value)
@ -5787,6 +5726,12 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
}
else if (const auto& un = annotation.as<AstTypeUnion>())
{
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (un->types.size == 1)
return resolveType(scope, *un->types.data[0]);
}
std::vector<TypeId> types;
for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann));
@ -5795,12 +5740,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
}
else if (const auto& un = annotation.as<AstTypeIntersection>())
{
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (un->types.size == 1)
return resolveType(scope, *un->types.data[0]);
}
std::vector<TypeId> types;
for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann));
return addType(IntersectionType{types});
}
else if (const auto& g = annotation.as<AstTypeGroup>())
{
return resolveType(scope, *g->type);
}
else if (const auto& tsb = annotation.as<AstTypeSingletonBool>())
{
return singletonType(tsb->value);
@ -5958,8 +5913,8 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(
const ScopePtr& scope,
std::optional<TypeLevel> levelOpt,
const AstNode& node,
const AstArray<AstGenericType>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames,
const AstArray<AstGenericType*>& genericNames,
const AstArray<AstGenericTypePack*>& genericPackNames,
bool useCache
)
{
@ -5969,14 +5924,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(
std::vector<GenericTypeDefinition> generics;
for (const AstGenericType& generic : genericNames)
for (const AstGenericType* generic : genericNames)
{
std::optional<TypeId> defaultValue;
if (generic.defaultValue)
defaultValue = resolveType(scope, *generic.defaultValue);
if (generic->defaultValue)
defaultValue = resolveType(scope, *generic->defaultValue);
Name n = generic.name.value;
Name n = generic->name.value;
// These generics are the only thing that will ever be added to scope, so we can be certain that
// a collision can only occur when two generic types have the same name.
@ -6005,14 +5960,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(
std::vector<GenericTypePackDefinition> genericPacks;
for (const AstGenericTypePack& genericPack : genericPackNames)
for (const AstGenericTypePack* genericPack : genericPackNames)
{
std::optional<TypePackId> defaultValue;
if (genericPack.defaultValue)
defaultValue = resolveTypePack(scope, *genericPack.defaultValue);
if (genericPack->defaultValue)
defaultValue = resolveTypePack(scope, *genericPack->defaultValue);
Name n = genericPack.name.value;
Name n = genericPack->name.value;
// These generics are the only thing that will ever be added to scope, so we can be certain that
// a collision can only occur when two generic types have the same name.
@ -6418,7 +6373,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
}
// We're only interested in the root class of any classes.
if (auto ctv = get<ClassType>(type); !ctv || ctv->parent != builtinTypes->classType)
if (auto ctv = get<ClassType>(type); !ctv || (ctv->parent != builtinTypes->classType && !hasTag(type, kTypeofRootTag)))
return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
// This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA.

View file

@ -14,7 +14,7 @@
#include <type_traits>
LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAGVARIABLE(LuauDisableNewSolverAssertsInMixedMode);
// Maximum number of steps to follow when traversing a path. May not always
// equate to the number of components in a path, depending on the traversal
// logic.
@ -156,14 +156,16 @@ Path PathBuilder::build()
PathBuilder& PathBuilder::readProp(std::string name)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
if (!FFlag::LuauDisableNewSolverAssertsInMixedMode)
LUAU_ASSERT(FFlag::LuauSolverV2);
components.push_back(Property{std::move(name), true});
return *this;
}
PathBuilder& PathBuilder::writeProp(std::string name)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
if (!FFlag::LuauDisableNewSolverAssertsInMixedMode)
LUAU_ASSERT(FFlag::LuauSolverV2);
components.push_back(Property{std::move(name), false});
return *this;
}
@ -415,6 +417,14 @@ struct TraversalState
switch (field)
{
case TypePath::TypeField::Table:
if (auto mt = get<MetatableType>(current))
{
updateCurrent(mt->table);
return true;
}
return false;
case TypePath::TypeField::Metatable:
if (auto currentType = get<TypeId>(current))
{
@ -561,6 +571,9 @@ std::string toString(const TypePath::Path& path, bool prefixDot)
switch (c)
{
case TypePath::TypeField::Table:
result << "table";
break;
case TypePath::TypeField::Metatable:
result << "metatable";
break;

View file

@ -5,12 +5,16 @@
#include "Luau/Normalize.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeInfer.h"
#include <algorithm>
LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete);
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope);
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
LUAU_FASTFLAG(LuauDisableNewSolverAssertsInMixedMode)
namespace Luau
{
@ -317,9 +321,11 @@ TypePack extendTypePack(
{
FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType};
t = arena.addType(ft);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(ftp->scope, t);
}
else
t = arena.freshType(ftp->scope);
t = FFlag::LuauFreeTypesMustHaveBounds ? arena.freshType(builtinTypes, ftp->scope) : arena.freshType_DEPRECATED(ftp->scope);
}
newPack.head.push_back(t);
@ -331,7 +337,7 @@ TypePack extendTypePack(
return result;
}
else if (const Unifiable::Error* etp = getMutable<Unifiable::Error>(pack))
else if (auto etp = getMutable<ErrorTypePack>(pack))
{
while (result.head.size() < length)
result.head.push_back(builtinTypes->errorRecoveryType());
@ -426,7 +432,7 @@ TypeId stripNil(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId ty)
ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypeId ty)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
LUAU_ASSERT(FFlag::LuauSolverV2 || FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
std::shared_ptr<const NormalizedType> normType = normalizer->normalize(ty);
if (!normType)
@ -479,4 +485,87 @@ ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypePackId
return result;
}
bool isLiteral(const AstExpr* expr)
{
return (
expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() ||
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>()
);
}
/**
* Visitor which, given an expression and a mapping from expression to TypeId,
* determines if there are any literal expressions that contain blocked types.
* This is used for bi-directional inference: we want to "apply" a type from
* a function argument or a type annotation to a literal.
*/
class BlockedTypeInLiteralVisitor : public AstVisitor
{
public:
explicit BlockedTypeInLiteralVisitor(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes, NotNull<std::vector<TypeId>> toBlock)
: astTypes_{astTypes}
, toBlock_{toBlock}
{
}
bool visit(AstNode*) override
{
return false;
}
bool visit(AstExpr* e) override
{
auto ty = astTypes_->find(e);
if (ty && (get<BlockedType>(follow(*ty)) != nullptr))
{
toBlock_->push_back(*ty);
}
return isLiteral(e) || e->is<AstExprGroup>();
}
private:
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes_;
NotNull<std::vector<TypeId>> toBlock_;
};
std::vector<TypeId> findBlockedTypesIn(AstExprTable* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes)
{
std::vector<TypeId> toBlock;
BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}};
expr->visit(&v);
return toBlock;
}
std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes)
{
std::vector<TypeId> toBlock;
BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}};
for (auto arg : expr->args)
{
if (isLiteral(arg) || arg->is<AstExprGroup>())
{
arg->visit(&v);
}
}
return toBlock;
}
void trackInteriorFreeType(Scope* scope, TypeId ty)
{
if (FFlag::LuauDisableNewSolverAssertsInMixedMode)
LUAU_ASSERT(FFlag::LuauTrackInteriorFreeTypesOnScope);
else
LUAU_ASSERT(FFlag::LuauSolverV2 && FFlag::LuauTrackInteriorFreeTypesOnScope);
for (; scope; scope = scope->parent.get())
{
if (scope->interiorFreeTypes)
{
scope->interiorFreeTypes->push_back(ty);
return;
}
}
// There should at least be *one* generalization constraint per module
// where `interiorFreeTypes` is present, which would be the one made
// by ConstraintGenerator::visitModuleRoot.
LUAU_ASSERT(!"No scopes in parent chain had a present `interiorFreeTypes` member.");
}
} // namespace Luau

View file

@ -1,5 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Unifiable.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypePack.h"
namespace Luau
{
@ -13,12 +15,17 @@ int freshIndex()
return ++nextIndex;
}
Error::Error()
template<typename Id>
Error<Id>::Error()
: index(++nextIndex)
{
}
int Error::nextIndex = 0;
template<typename Id>
int Error<Id>::nextIndex = 0;
template struct Error<TypeId>;
template struct Error<TypePackId>;
} // namespace Unifiable
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -908,7 +908,7 @@ OccursCheckResult Unifier2::occursCheck(DenseHashSet<TypePackId>& seen, TypePack
RecursionLimiter _ra(&recursionCount, recursionLimit);
while (!getMutable<Unifiable::Error>(haystack))
while (!getMutable<ErrorTypePack>(haystack))
{
if (needle == haystack)
return OccursCheckResult::Fail;

View file

@ -0,0 +1,48 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Location.h"
#include "Luau/DenseHash.h"
#include "Luau/Common.h"
#include <vector>
namespace Luau
{
class Allocator
{
public:
Allocator();
Allocator(Allocator&&);
Allocator& operator=(Allocator&&) = delete;
~Allocator();
void* allocate(size_t size);
template<typename T, typename... Args>
T* alloc(Args&&... args)
{
static_assert(std::is_trivially_destructible<T>::value, "Objects allocated with this allocator will never have their destructors run!");
T* t = static_cast<T*>(allocate(sizeof(T)));
new (t) T(std::forward<Args>(args)...);
return t;
}
private:
struct Page
{
Page* next;
alignas(8) char data[8192];
};
Page* root;
size_t offset;
};
} // namespace Luau

View file

@ -120,20 +120,6 @@ struct AstTypeList
using AstArgumentName = std::pair<AstName, Location>; // TODO: remove and replace when we get a common struct for this pair instead of AstName
struct AstGenericType
{
AstName name;
Location location;
AstType* defaultValue = nullptr;
};
struct AstGenericTypePack
{
AstName name;
Location location;
AstTypePack* defaultValue = nullptr;
};
extern int gAstRttiIndex;
template<typename T>
@ -253,6 +239,32 @@ public:
bool hasSemicolon;
};
class AstGenericType : public AstNode
{
public:
LUAU_RTTI(AstGenericType)
explicit AstGenericType(const Location& location, AstName name, AstType* defaultValue = nullptr);
void visit(AstVisitor* visitor) override;
AstName name;
AstType* defaultValue = nullptr;
};
class AstGenericTypePack : public AstNode
{
public:
LUAU_RTTI(AstGenericTypePack)
explicit AstGenericTypePack(const Location& location, AstName name, AstTypePack* defaultValue = nullptr);
void visit(AstVisitor* visitor) override;
AstName name;
AstTypePack* defaultValue = nullptr;
};
class AstExprGroup : public AstExpr
{
public:
@ -316,16 +328,18 @@ public:
enum QuoteStyle
{
Quoted,
QuotedSimple,
QuotedRaw,
Unquoted
};
AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle = Quoted);
AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle);
void visit(AstVisitor* visitor) override;
bool isQuoted() const;
AstArray<char> value;
QuoteStyle quoteStyle = Quoted;
QuoteStyle quoteStyle;
};
class AstExprLocal : public AstExpr
@ -422,8 +436,8 @@ public:
AstExprFunction(
const Location& location,
const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack*>& genericPacks,
AstLocal* self,
const AstArray<AstLocal*>& args,
bool vararg,
@ -441,8 +455,8 @@ public:
bool hasNativeAttribute() const;
AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack*> genericPacks;
AstLocal* self;
AstArray<AstLocal*> args;
std::optional<AstTypeList> returnAnnotation;
@ -855,8 +869,8 @@ public:
const Location& location,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack*>& genericPacks,
AstType* type,
bool exported
);
@ -865,8 +879,8 @@ public:
AstName name;
Location nameLocation;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack*> genericPacks;
AstType* type;
bool exported;
};
@ -876,13 +890,14 @@ class AstStatTypeFunction : public AstStat
public:
LUAU_RTTI(AstStatTypeFunction);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body, bool exported);
void visit(AstVisitor* visitor) override;
AstName name;
Location nameLocation;
AstExprFunction* body;
bool exported;
};
class AstStatDeclareGlobal : public AstStat
@ -908,8 +923,8 @@ public:
const Location& location,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
@ -922,8 +937,8 @@ public:
const AstArray<AstAttr*>& attributes,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
@ -939,8 +954,8 @@ public:
AstArray<AstAttr*> attributes;
AstName name;
Location nameLocation;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack*> genericPacks;
AstTypeList params;
AstArray<AstArgumentName> paramNames;
bool vararg = false;
@ -1071,8 +1086,8 @@ public:
AstTypeFunction(
const Location& location,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
@ -1081,8 +1096,8 @@ public:
AstTypeFunction(
const Location& location,
const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
@ -1093,8 +1108,8 @@ public:
bool isCheckedFunction() const;
AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack*> genericPacks;
AstTypeList argTypes;
AstArray<std::optional<AstArgumentName>> argNames;
AstTypeList returnTypes;
@ -1201,6 +1216,18 @@ public:
const AstArray<char> value;
};
class AstTypeGroup : public AstType
{
public:
LUAU_RTTI(AstTypeGroup)
explicit AstTypeGroup(const Location& location, AstType* type);
void visit(AstVisitor* visitor) override;
AstType* type;
};
class AstTypePack : public AstNode
{
public:
@ -1261,6 +1288,16 @@ public:
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstGenericType* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstGenericTypePack* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstExpr* node)
{
return visit(static_cast<AstNode*>(node));
@ -1467,6 +1504,10 @@ public:
{
return visit(static_cast<AstType*>(node));
}
virtual bool visit(class AstTypeGroup* node)
{
return visit(static_cast<AstType*>(node));
}
virtual bool visit(class AstTypeError* node)
{
return visit(static_cast<AstType*>(node));
@ -1490,6 +1531,7 @@ public:
}
};
bool isLValue(const AstExpr*);
AstName getIdentifier(AstExpr*);
Location getLocation(const AstTypeList& typeList);
@ -1520,4 +1562,4 @@ struct hash<Luau::AstName>
}
};
} // namespace std
} // namespace std

Some files were not shown because too many files have changed in this diff Show more