Merge branch 'master' into type_function_solver_dependency_fix2

This commit is contained in:
karl-police 2025-03-05 00:32:14 +01:00 committed by GitHub
commit 162fe2129b
Signed by: DevComp
GPG key ID: B5690EEEBB952194
479 changed files with 35991 additions and 7892 deletions

View file

@ -63,10 +63,10 @@ jobs:
} }
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=DebugLuauDeferredConstraintResolution bench/other/LuauPolyfillMap.lua 2>&1 | filter map-dcr | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=LuauSolverV2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-dcr | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=DebugLuauDeferredConstraintResolution bench/other/regex.lua 2>&1 | filter regex-dcr | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=LuauSolverV2 bench/other/regex.lua 2>&1 | filter regex-dcr | tee -a analyze-output.txt
- name: Run benchmark (compile) - name: Run benchmark (compile)
run: | run: |

View file

@ -46,9 +46,9 @@ jobs:
- name: make cli - name: make cli
run: | run: |
make -j2 config=sanitize werror=1 luau luau-analyze luau-compile # match config with tests to improve build time make -j2 config=sanitize werror=1 luau luau-analyze luau-compile # match config with tests to improve build time
./luau tests/conformance/assert.lua ./luau tests/conformance/assert.luau
./luau-analyze tests/conformance/assert.lua ./luau-analyze tests/conformance/assert.luau
./luau-compile tests/conformance/assert.lua ./luau-compile tests/conformance/assert.luau
windows: windows:
runs-on: windows-latest runs-on: windows-latest
@ -81,12 +81,12 @@ jobs:
shell: bash # necessary for fail-fast shell: bash # necessary for fail-fast
run: | run: |
cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Debug # match config with tests to improve build time cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Debug # match config with tests to improve build time
Debug/luau tests/conformance/assert.lua Debug/luau tests/conformance/assert.luau
Debug/luau-analyze tests/conformance/assert.lua Debug/luau-analyze tests/conformance/assert.luau
Debug/luau-compile tests/conformance/assert.lua Debug/luau-compile tests/conformance/assert.luau
coverage: coverage:
runs-on: ubuntu-20.04 # needed for clang++-10 to avoid gcov compatibility issues runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: install - name: install
@ -94,7 +94,7 @@ jobs:
sudo apt install llvm sudo apt install llvm
- name: make coverage - name: make coverage
run: | run: |
CXX=clang++-10 make -j2 config=coverage native=1 coverage CXX=clang++ make -j2 config=coverage native=1 coverage
- name: upload coverage - name: upload coverage
uses: codecov/codecov-action@v3 uses: codecov/codecov-action@v3
with: with:

View file

@ -29,8 +29,8 @@ jobs:
build: build:
needs: ["create-release"] needs: ["create-release"]
strategy: strategy:
matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility matrix: # not using ubuntu-latest to improve compatibility
os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] os: [{name: ubuntu, version: ubuntu-22.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
name: ${{matrix.os.name}} name: ${{matrix.os.name}}
runs-on: ${{matrix.os.version}} runs-on: ${{matrix.os.version}}
steps: steps:
@ -38,7 +38,7 @@ jobs:
- name: configure - name: configure
run: cmake . -DCMAKE_BUILD_TYPE=Release run: cmake . -DCMAKE_BUILD_TYPE=Release
- name: build - name: build
run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Release -j 2 run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI Luau.Ast.CLI --config Release -j 2
- name: pack - name: pack
if: matrix.os.name != 'windows' if: matrix.os.name != 'windows'
run: zip luau-${{matrix.os.name}}.zip luau* run: zip luau-${{matrix.os.name}}.zip luau*

View file

@ -13,8 +13,8 @@ on:
jobs: jobs:
build: build:
strategy: strategy:
matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility matrix: # not using ubuntu-latest to improve compatibility
os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] os: [{name: ubuntu, version: ubuntu-22.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
name: ${{matrix.os.name}} name: ${{matrix.os.name}}
runs-on: ${{matrix.os.version}} runs-on: ${{matrix.os.version}}
steps: steps:

1
.gitignore vendored
View file

@ -13,6 +13,7 @@
/luau /luau
/luau-tests /luau-tests
/luau-analyze /luau-analyze
/luau-bytecode
/luau-compile /luau-compile
__pycache__ __pycache__
.cache .cache

View file

@ -1,10 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/AutocompleteTypes.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include <unordered_map>
#include <string> #include <string>
#include <memory> #include <memory>
#include <optional> #include <optional>
@ -16,89 +16,8 @@ struct Frontend;
struct SourceModule; struct SourceModule;
struct Module; struct Module;
struct TypeChecker; struct TypeChecker;
struct FileResolver;
using ModulePtr = std::shared_ptr<Module>;
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using ModuleName = std::string;
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback);
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,92 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Type.h"
#include <unordered_map>
namespace Luau
{
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
RequirePath,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau

View file

@ -9,6 +9,8 @@
namespace Luau namespace Luau
{ {
static constexpr char kRequireTagName[] = "require";
struct Frontend; struct Frontend;
struct GlobalTypes; struct GlobalTypes;
struct TypeChecker; struct TypeChecker;
@ -63,10 +65,7 @@ TypeId makeFunction( // Polymorphic
bool checked = false bool checked = false
); );
void attachMagicFunction(TypeId ty, MagicFunction fn); void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> fn);
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn);
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn);
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt); Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName); void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName);
@ -80,4 +79,16 @@ std::optional<Binding> tryGetGlobalBinding(GlobalTypes& globals, const std::stri
Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name); Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name);
TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name); TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name);
/** A number of built-in functions are magical enough that we need to match on them specifically by
* name when they are called. These are listed here to be used whenever necessary, instead of duplicating this logic repeatedly.
*/
bool matchSetMetatable(const AstExprCall& call);
bool matchTableFreeze(const AstExprCall& call);
bool matchAssert(const AstExprCall& call);
// Returns `true` if the function should introduce typestate for its first argument.
bool shouldTypestateForFirstArgument(const AstExprCall& call);
} // namespace Luau } // namespace Luau

View file

@ -4,6 +4,7 @@
#include <Luau/NotNull.h> #include <Luau/NotNull.h>
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/Scope.h"
#include <unordered_map> #include <unordered_map>
@ -22,8 +23,21 @@ struct CloneState
SeenTypePacks seenTypePacks; SeenTypePacks seenTypePacks;
}; };
/** `shallowClone` will make a copy of only the _top level_ constructor of the type,
* while `clone` will make a deep copy of the entire type and its every component.
*
* Be mindful about which behavior you actually _want_.
*
* Persistent types are not cloned as an optimization.
* If a type is cloned in order to mutate it, 'ignorePersistent' has to be set
*/
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false);
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false);
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState);
} // namespace Luau } // namespace Luau

View file

@ -109,6 +109,21 @@ struct FunctionCheckConstraint
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes; NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
}; };
// table_check expectedType exprType
//
// If `expectedType` is a table type and `exprType` is _also_ a table type,
// propogate the member types of `expectedType` into the types of `exprType`.
// This is used to implement bidirectional inference on table assignment.
// Also see: FunctionCheckConstraint.
struct TableCheckConstraint
{
TypeId expectedType;
TypeId exprType;
AstExprTable* table = nullptr;
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes;
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
};
// prim FreeType ExpectedType PrimitiveType // prim FreeType ExpectedType PrimitiveType
// //
// FreeType is bounded below by the singleton type and above by PrimitiveType // FreeType is bounded below by the singleton type and above by PrimitiveType
@ -273,7 +288,8 @@ using ConstraintV = Variant<
UnpackConstraint, UnpackConstraint,
ReduceConstraint, ReduceConstraint,
ReducePackConstraint, ReducePackConstraint,
EqualityConstraint>; EqualityConstraint,
TableCheckConstraint>;
struct Constraint struct Constraint
{ {

View file

@ -5,6 +5,7 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/ControlFlow.h" #include "Luau/ControlFlow.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/InsertionOrderedMap.h" #include "Luau/InsertionOrderedMap.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
@ -15,7 +16,6 @@
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/Normalize.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -28,6 +28,7 @@ struct Scope;
using ScopePtr = std::shared_ptr<Scope>; using ScopePtr = std::shared_ptr<Scope>;
struct DcrLogger; struct DcrLogger;
struct TypeFunctionRuntime;
struct Inference struct Inference
{ {
@ -95,6 +96,9 @@ struct ConstraintGenerator
// will enqueue them during solving. // will enqueue them during solving.
std::vector<ConstraintPtr> unqueuedConstraints; std::vector<ConstraintPtr> unqueuedConstraints;
// Map a function's signature scope back to its signature type.
DenseHashMap<Scope*, TypeId> scopeToFunction{nullptr};
// The private scope of type aliases for which the type parameters belong to. // The private scope of type aliases for which the type parameters belong to.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr}; DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
@ -108,6 +112,11 @@ struct ConstraintGenerator
// Needed to be able to enable error-suppression preservation for immediate refinements. // Needed to be able to enable error-suppression preservation for immediate refinements.
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
// Needed to register all available type functions for execution at later stages.
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// Needed to resolve modules to make 'require' import types properly. // Needed to resolve modules to make 'require' import types properly.
NotNull<ModuleResolver> moduleResolver; NotNull<ModuleResolver> moduleResolver;
// Occasionally constraint generation needs to produce an ICE. // Occasionally constraint generation needs to produce an ICE.
@ -125,6 +134,8 @@ struct ConstraintGenerator
ConstraintGenerator( ConstraintGenerator(
ModulePtr module, ModulePtr module,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice, NotNull<InternalErrorReporter> ice,
@ -142,6 +153,8 @@ struct ConstraintGenerator
*/ */
void visitModuleRoot(AstStatBlock* block); void visitModuleRoot(AstStatBlock* block);
void visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block);
private: private:
std::vector<std::vector<TypeId>> interiorTypes; std::vector<std::vector<TypeId>> interiorTypes;
@ -223,7 +236,10 @@ private:
); );
void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement);
LUAU_NOINLINE void checkAliases(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visitBlockWithoutChildScope_DEPRECATED(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visit(const ScopePtr& scope, AstStat* stat); ControlFlow visit(const ScopePtr& scope, AstStat* stat);
ControlFlow visit(const ScopePtr& scope, AstStatBlock* block); ControlFlow visit(const ScopePtr& scope, AstStatBlock* block);
@ -282,11 +298,25 @@ private:
Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional<TypeId> expectedType, bool generalize); Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional<TypeId> expectedType, bool generalize);
Inference check(const ScopePtr& scope, AstExprUnary* unary); Inference check(const ScopePtr& scope, AstExprUnary* unary);
Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
Inference checkAstExprBinary(
const ScopePtr& scope,
const Location& location,
AstExprBinary::Op op,
AstExpr* left,
AstExpr* right,
std::optional<TypeId> expectedType
);
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprInterpString* interpString); Inference check(const ScopePtr& scope, AstExprInterpString* interpString);
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType); std::tuple<TypeId, TypeId, RefinementId> checkBinary(
const ScopePtr& scope,
AstExprBinary::Op op,
AstExpr* left,
AstExpr* right,
std::optional<TypeId> expectedType
);
void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType); void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType); void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType);
@ -321,6 +351,11 @@ private:
*/ */
void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn); void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn);
// Specializations of 'resolveType' below
TypeId resolveReferenceType(const ScopePtr& scope, AstType* ty, AstTypeReference* ref, bool inTypeArguments, bool replaceErrorWithFresh);
TypeId resolveTableType(const ScopePtr& scope, AstType* ty, AstTypeTable* tab, bool inTypeArguments, bool replaceErrorWithFresh);
TypeId resolveFunctionType(const ScopePtr& scope, AstType* ty, AstTypeFunction* fn, bool inTypeArguments, bool replaceErrorWithFresh);
/** /**
* Resolves a type from its AST annotation. * Resolves a type from its AST annotation.
* @param scope the scope that the type annotation appears within. * @param scope the scope that the type annotation appears within.
@ -360,7 +395,7 @@ private:
**/ **/
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics( std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(
const ScopePtr& scope, const ScopePtr& scope,
AstArray<AstGenericType> generics, AstArray<AstGenericType*> generics,
bool useCache = false, bool useCache = false,
bool addTypes = true bool addTypes = true
); );
@ -377,7 +412,7 @@ private:
**/ **/
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks( std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(
const ScopePtr& scope, const ScopePtr& scope,
AstArray<AstGenericTypePack> packs, AstArray<AstGenericTypePack*> packs,
bool useCache = false, bool useCache = false,
bool addTypes = true bool addTypes = true
); );
@ -391,6 +426,7 @@ private:
TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
// make an intersect type function of these two types // make an intersect type function of these two types
TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program);
/** Scan the program for global definitions. /** Scan the program for global definitions.
* *
@ -421,6 +457,8 @@ private:
const ScopePtr& scope, const ScopePtr& scope,
Location location Location location
); );
TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right);
}; };
/** Borrow a vector of pointers from a vector of owning pointers to constraints. /** Borrow a vector of pointers from a vector of owning pointers to constraints.

View file

@ -3,7 +3,9 @@
#pragma once #pragma once
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Module.h" #include "Luau/Module.h"
@ -12,6 +14,7 @@
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeCheckLimits.h" #include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
@ -56,17 +59,42 @@ struct HashInstantiationSignature
size_t operator()(const InstantiationSignature& signature) const; size_t operator()(const InstantiationSignature& signature) const;
}; };
struct TablePropLookupResult
{
// What types are we blocked on for determining this type?
std::vector<TypeId> blockedTypes;
// The type of the property (if we were able to determine it).
std::optional<TypeId> propType;
// Whether or not this is _definitely_ derived as the result of an indexer.
// We use this to determine whether or not code like:
//
// t.lol = nil;
//
// ... is legal. If `t: { [string]: ~nil }` then this is legal as
// there's no guarantee on whether "lol" specifically exists.
// However, if `t: { lol: ~nil }`, then we cannot allow assignment as
// that would remove "lol" from the table entirely.
bool isIndex = false;
};
struct ConstraintSolver struct ConstraintSolver
{ {
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter; InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// The entire set of constraints that the solver is trying to resolve. // The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints; std::vector<NotNull<Constraint>> constraints;
NotNull<DenseHashMap<Scope*, TypeId>> scopeToFunction;
NotNull<Scope> rootScope; NotNull<Scope> rootScope;
ModuleName currentModuleName; ModuleName currentModuleName;
// The dataflow graph of the program, used in constraint generation and for magic functions.
NotNull<const DataFlowGraph> dfg;
// Constraints that the solver has generated, rather than sourcing from the // Constraints that the solver has generated, rather than sourcing from the
// scope tree. // scope tree.
std::vector<std::unique_ptr<Constraint>> solverConstraints; std::vector<std::unique_ptr<Constraint>> solverConstraints;
@ -91,6 +119,9 @@ struct ConstraintSolver
// A mapping from free types to the number of unresolved constraints that mention them. // A mapping from free types to the number of unresolved constraints that mention them.
DenseHashMap<TypeId, size_t> unresolvedConstraints{{}}; DenseHashMap<TypeId, size_t> unresolvedConstraints{{}};
std::unordered_map<NotNull<const Constraint>, DenseHashSet<TypeId>> maybeMutatedFreeTypes;
std::unordered_map<TypeId, DenseHashSet<const Constraint*>> mutatedFreeTypeToConstraint;
// Irreducible/uninhabited type functions or type pack functions. // Irreducible/uninhabited type functions or type pack functions.
DenseHashSet<const void*> uninhabitedTypeFunctions{{}}; DenseHashSet<const void*> uninhabitedTypeFunctions{{}};
@ -114,12 +145,16 @@ struct ConstraintSolver
explicit ConstraintSolver( explicit ConstraintSolver(
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope, NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints, std::vector<NotNull<Constraint>> constraints,
NotNull<DenseHashMap<Scope*, TypeId>> scopeToFunction,
ModuleName moduleName, ModuleName moduleName,
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles, std::vector<RequireCycle> requireCycles,
DcrLogger* logger, DcrLogger* logger,
NotNull<const DataFlowGraph> dfg,
TypeCheckLimits limits TypeCheckLimits limits
); );
@ -139,9 +174,11 @@ struct ConstraintSolver
**/ **/
void finalizeTypeFunctions(); void finalizeTypeFunctions();
bool isDone(); bool isDone() const;
private: private:
void generalizeOneType(TypeId ty);
/** /**
* Bind a type variable to another type. * Bind a type variable to another type.
* *
@ -167,13 +204,14 @@ public:
*/ */
bool tryDispatch(NotNull<const Constraint> c, bool force); bool tryDispatch(NotNull<const Constraint> c, bool force);
bool tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TableCheckConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
@ -194,16 +232,16 @@ public:
bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do // for a, ... in some_table do
// also handles __iter metamethod // also handles __iter metamethod
bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
// for a, ... in next_function, t, ... do // for a, ... in next_function, t, ... do
bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp( TablePropLookupResult lookupTableProp(
NotNull<const Constraint> constraint, NotNull<const Constraint> constraint,
TypeId subjectType, TypeId subjectType,
const std::string& propName, const std::string& propName,
@ -211,7 +249,8 @@ public:
bool inConditional = false, bool inConditional = false,
bool suppressSimplification = false bool suppressSimplification = false
); );
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
TablePropLookupResult lookupTableProp(
NotNull<const Constraint> constraint, NotNull<const Constraint> constraint,
TypeId subjectType, TypeId subjectType,
const std::string& propName, const std::string& propName,
@ -270,10 +309,10 @@ public:
// FIXME: This use of a boolean for the return result is an appalling // FIXME: This use of a boolean for the return result is an appalling
// interface. // interface.
bool blockOnPendingTypes(TypeId target, NotNull<const Constraint> constraint); bool blockOnPendingTypes(TypeId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypePackId target, NotNull<const Constraint> constraint); bool blockOnPendingTypes(TypePackId targetPack, NotNull<const Constraint> constraint);
void unblock(NotNull<const Constraint> progressed); void unblock(NotNull<const Constraint> progressed);
void unblock(TypeId progressed, Location location); void unblock(TypeId ty, Location location);
void unblock(TypePackId progressed, Location location); void unblock(TypePackId progressed, Location location);
void unblock(const std::vector<TypeId>& types, Location location); void unblock(const std::vector<TypeId>& types, Location location);
void unblock(const std::vector<TypePackId>& packs, Location location); void unblock(const std::vector<TypePackId>& packs, Location location);
@ -281,18 +320,18 @@ public:
/** /**
* @returns true if the TypeId is in a blocked state. * @returns true if the TypeId is in a blocked state.
*/ */
bool isBlocked(TypeId ty); bool isBlocked(TypeId ty) const;
/** /**
* @returns true if the TypePackId is in a blocked state. * @returns true if the TypePackId is in a blocked state.
*/ */
bool isBlocked(TypePackId tp); bool isBlocked(TypePackId tp) const;
/** /**
* Returns whether the constraint is blocked on anything. * Returns whether the constraint is blocked on anything.
* @param constraint the constraint to check. * @param constraint the constraint to check.
*/ */
bool isBlocked(NotNull<const Constraint> constraint); bool isBlocked(NotNull<const Constraint> constraint) const;
/** Pushes a new solver constraint to the solver. /** Pushes a new solver constraint to the solver.
* @param cv the body of the constraint. * @param cv the body of the constraint.
@ -308,7 +347,7 @@ public:
* @param location the location where the require is taking place; used for * @param location the location where the require is taking place; used for
* error locations. * error locations.
**/ **/
TypeId resolveModule(const ModuleInfo& module, const Location& location); TypeId resolveModule(const ModuleInfo& info, const Location& location);
void reportError(TypeErrorData&& data, const Location& location); void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e); void reportError(TypeError e);
@ -379,15 +418,21 @@ public:
**/ **/
void reproduceConstraints(NotNull<Scope> scope, const Location& location, const Substitution& subst); void reproduceConstraints(NotNull<Scope> scope, const Location& location, const Substitution& subst);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, std::set<TypeId> parts);
TypeId simplifyUnion(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId errorRecoveryType() const; TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const; TypePackId errorRecoveryTypePack() const;
TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp);
void throwTimeLimitError(); void throwTimeLimitError() const;
void throwUserCancelError(); void throwUserCancelError() const;
ToStringOptions opts; ToStringOptions opts;
void fillInDiscriminantTypes(NotNull<const Constraint> constraint, const std::vector<std::optional<TypeId>>& discriminantTypes);
}; };
void dump(NotNull<Scope> rootScope, struct ToStringOptions& opts); void dump(NotNull<Scope> rootScope, struct ToStringOptions& opts);

View file

@ -6,6 +6,7 @@
#include "Luau/ControlFlow.h" #include "Luau/ControlFlow.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Symbol.h" #include "Luau/Symbol.h"
#include "Luau/TypedAllocator.h" #include "Luau/TypedAllocator.h"
@ -35,6 +36,8 @@ struct DataFlowGraph
DataFlowGraph& operator=(DataFlowGraph&&) = default; DataFlowGraph& operator=(DataFlowGraph&&) = default;
DefId getDef(const AstExpr* expr) const; DefId getDef(const AstExpr* expr) const;
// Look up the definition optionally, knowing it may not be present.
std::optional<DefId> getDefOptional(const AstExpr* expr) const;
// Look up for the rvalue def for a compound assignment. // Look up for the rvalue def for a compound assignment.
std::optional<DefId> getRValueDefForCompoundAssign(const AstExpr* expr) const; std::optional<DefId> getRValueDefForCompoundAssign(const AstExpr* expr) const;
@ -46,13 +49,13 @@ struct DataFlowGraph
const RefinementKey* getRefinementKey(const AstExpr* expr) const; const RefinementKey* getRefinementKey(const AstExpr* expr) const;
private: private:
DataFlowGraph() = default; DataFlowGraph(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena defArena; NotNull<DefArena> defArena;
RefinementKeyArena keyArena; NotNull<RefinementKeyArena> keyArena;
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr}; DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
@ -68,7 +71,6 @@ private:
DenseHashMap<const AstExpr*, const Def*> compoundAssignDefs{nullptr}; DenseHashMap<const AstExpr*, const Def*> compoundAssignDefs{nullptr};
DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr}; DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr};
friend struct DataFlowGraphBuilder; friend struct DataFlowGraphBuilder;
}; };
@ -105,25 +107,37 @@ struct DataFlowResult
const RefinementKey* parent = nullptr; const RefinementKey* parent = nullptr;
}; };
using ScopeStack = std::vector<DfgScope*>;
struct DataFlowGraphBuilder struct DataFlowGraphBuilder
{ {
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle); static DataFlowGraph build(
AstStatBlock* block,
NotNull<DefArena> defArena,
NotNull<RefinementKeyArena> keyArena,
NotNull<struct InternalErrorReporter> handle
);
private: private:
DataFlowGraphBuilder() = default; DataFlowGraphBuilder(NotNull<DefArena> defArena, NotNull<RefinementKeyArena> keyArena);
DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete; DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete;
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph; DataFlowGraph graph;
NotNull<DefArena> defArena{&graph.defArena}; NotNull<DefArena> defArena;
NotNull<RefinementKeyArena> keyArena{&graph.keyArena}; NotNull<RefinementKeyArena> keyArena;
struct InternalErrorReporter* handle = nullptr; struct InternalErrorReporter* handle = nullptr;
DfgScope* moduleScope = nullptr;
/// The arena owning all of the scope allocations for the dataflow graph being built.
std::vector<std::unique_ptr<DfgScope>> scopes; std::vector<std::unique_ptr<DfgScope>> scopes;
/// A stack of scopes used by the visitor to see where we are.
ScopeStack scopeStack;
DfgScope* currentScope();
struct FunctionCapture struct FunctionCapture
{ {
std::vector<DefId> captureDefs; std::vector<DefId> captureDefs;
@ -134,81 +148,81 @@ private:
DenseHashMap<Symbol, FunctionCapture> captures{Symbol{}}; DenseHashMap<Symbol, FunctionCapture> captures{Symbol{}};
void resolveCaptures(); void resolveCaptures();
DfgScope* childScope(DfgScope* scope, DfgScope::ScopeType scopeType = DfgScope::Linear); DfgScope* makeChildScope(DfgScope::ScopeType scopeType = DfgScope::Linear);
void join(DfgScope* p, DfgScope* a, DfgScope* b); void join(DfgScope* p, DfgScope* a, DfgScope* b);
void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b); void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b);
void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b); void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b);
DefId lookup(DfgScope* scope, Symbol symbol); DefId lookup(Symbol symbol);
DefId lookup(DfgScope* scope, DefId def, const std::string& key); DefId lookup(DefId def, const std::string& key);
ControlFlow visit(DfgScope* scope, AstStatBlock* b); ControlFlow visit(AstStatBlock* b);
ControlFlow visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); ControlFlow visitBlockWithoutChildScope(AstStatBlock* b);
ControlFlow visit(DfgScope* scope, AstStat* s); ControlFlow visit(AstStat* s);
ControlFlow visit(DfgScope* scope, AstStatIf* i); ControlFlow visit(AstStatIf* i);
ControlFlow visit(DfgScope* scope, AstStatWhile* w); ControlFlow visit(AstStatWhile* w);
ControlFlow visit(DfgScope* scope, AstStatRepeat* r); ControlFlow visit(AstStatRepeat* r);
ControlFlow visit(DfgScope* scope, AstStatBreak* b); ControlFlow visit(AstStatBreak* b);
ControlFlow visit(DfgScope* scope, AstStatContinue* c); ControlFlow visit(AstStatContinue* c);
ControlFlow visit(DfgScope* scope, AstStatReturn* r); ControlFlow visit(AstStatReturn* r);
ControlFlow visit(DfgScope* scope, AstStatExpr* e); ControlFlow visit(AstStatExpr* e);
ControlFlow visit(DfgScope* scope, AstStatLocal* l); ControlFlow visit(AstStatLocal* l);
ControlFlow visit(DfgScope* scope, AstStatFor* f); ControlFlow visit(AstStatFor* f);
ControlFlow visit(DfgScope* scope, AstStatForIn* f); ControlFlow visit(AstStatForIn* f);
ControlFlow visit(DfgScope* scope, AstStatAssign* a); ControlFlow visit(AstStatAssign* a);
ControlFlow visit(DfgScope* scope, AstStatCompoundAssign* c); ControlFlow visit(AstStatCompoundAssign* c);
ControlFlow visit(DfgScope* scope, AstStatFunction* f); ControlFlow visit(AstStatFunction* f);
ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l); ControlFlow visit(AstStatLocalFunction* l);
ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t); ControlFlow visit(AstStatTypeAlias* t);
ControlFlow visit(DfgScope* scope, AstStatTypeFunction* f); ControlFlow visit(AstStatTypeFunction* f);
ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d); ControlFlow visit(AstStatDeclareGlobal* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d); ControlFlow visit(AstStatDeclareFunction* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d); ControlFlow visit(AstStatDeclareClass* d);
ControlFlow visit(DfgScope* scope, AstStatError* error); ControlFlow visit(AstStatError* error);
DataFlowResult visitExpr(DfgScope* scope, AstExpr* e); DataFlowResult visitExpr(AstExpr* e);
DataFlowResult visitExpr(DfgScope* scope, AstExprGroup* group); DataFlowResult visitExpr(AstExprGroup* group);
DataFlowResult visitExpr(DfgScope* scope, AstExprLocal* l); DataFlowResult visitExpr(AstExprLocal* l);
DataFlowResult visitExpr(DfgScope* scope, AstExprGlobal* g); DataFlowResult visitExpr(AstExprGlobal* g);
DataFlowResult visitExpr(DfgScope* scope, AstExprCall* c); DataFlowResult visitExpr(AstExprCall* c);
DataFlowResult visitExpr(DfgScope* scope, AstExprIndexName* i); DataFlowResult visitExpr(AstExprIndexName* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprIndexExpr* i); DataFlowResult visitExpr(AstExprIndexExpr* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprFunction* f); DataFlowResult visitExpr(AstExprFunction* f);
DataFlowResult visitExpr(DfgScope* scope, AstExprTable* t); DataFlowResult visitExpr(AstExprTable* t);
DataFlowResult visitExpr(DfgScope* scope, AstExprUnary* u); DataFlowResult visitExpr(AstExprUnary* u);
DataFlowResult visitExpr(DfgScope* scope, AstExprBinary* b); DataFlowResult visitExpr(AstExprBinary* b);
DataFlowResult visitExpr(DfgScope* scope, AstExprTypeAssertion* t); DataFlowResult visitExpr(AstExprTypeAssertion* t);
DataFlowResult visitExpr(DfgScope* scope, AstExprIfElse* i); DataFlowResult visitExpr(AstExprIfElse* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprInterpString* i); DataFlowResult visitExpr(AstExprInterpString* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprError* error); DataFlowResult visitExpr(AstExprError* error);
void visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef); void visitLValue(AstExpr* e, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef); DefId visitLValue(AstExprLocal* l, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef); DefId visitLValue(AstExprGlobal* g, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef); DefId visitLValue(AstExprIndexName* i, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef); DefId visitLValue(AstExprIndexExpr* i, DefId incomingDef);
DefId visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef); DefId visitLValue(AstExprError* e, DefId incomingDef);
void visitType(DfgScope* scope, AstType* t); void visitType(AstType* t);
void visitType(DfgScope* scope, AstTypeReference* r); void visitType(AstTypeReference* r);
void visitType(DfgScope* scope, AstTypeTable* t); void visitType(AstTypeTable* t);
void visitType(DfgScope* scope, AstTypeFunction* f); void visitType(AstTypeFunction* f);
void visitType(DfgScope* scope, AstTypeTypeof* t); void visitType(AstTypeTypeof* t);
void visitType(DfgScope* scope, AstTypeUnion* u); void visitType(AstTypeUnion* u);
void visitType(DfgScope* scope, AstTypeIntersection* i); void visitType(AstTypeIntersection* i);
void visitType(DfgScope* scope, AstTypeError* error); void visitType(AstTypeError* error);
void visitTypePack(DfgScope* scope, AstTypePack* p); void visitTypePack(AstTypePack* p);
void visitTypePack(DfgScope* scope, AstTypePackExplicit* e); void visitTypePack(AstTypePackExplicit* e);
void visitTypePack(DfgScope* scope, AstTypePackVariadic* v); void visitTypePack(AstTypePackVariadic* v);
void visitTypePack(DfgScope* scope, AstTypePackGeneric* g); void visitTypePack(AstTypePackGeneric* g);
void visitTypeList(DfgScope* scope, AstTypeList l); void visitTypeList(AstTypeList l);
void visitGenerics(DfgScope* scope, AstArray<AstGenericType> g); void visitGenerics(AstArray<AstGenericType*> g);
void visitGenericPacks(DfgScope* scope, AstArray<AstGenericTypePack> g); void visitGenericPacks(AstArray<AstGenericTypePack*> g);
}; };
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,50 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeFwd.h"
#include "Luau/NotNull.h"
#include "Luau/DenseHash.h"
#include <memory>
#include <optional>
#include <vector>
namespace Luau
{
struct TypeArena;
}
// The EqSat stuff is pretty template heavy, so we go to some lengths to prevent
// the complexity from leaking outside its implementation sources.
namespace Luau::EqSatSimplification
{
struct Simplifier;
using SimplifierPtr = std::unique_ptr<Simplifier, void (*)(Simplifier*)>;
SimplifierPtr newSimplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
} // namespace Luau::EqSatSimplification
namespace Luau
{
struct EqSatSimplificationResult
{
TypeId result;
// New type function applications that were created by the reduction phase.
// We return these so that the ConstraintSolver can know to try to reduce
// them.
std::vector<TypeId> newTypeFunctions;
};
using EqSatSimplification::newSimplifier; // NOLINT: clang-tidy thinks these are unused. It is incorrect.
using Luau::EqSatSimplification::Simplifier; // NOLINT
using Luau::EqSatSimplification::SimplifierPtr;
std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simplifier, TypeId ty);
} // namespace Luau

View file

@ -0,0 +1,376 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/EGraph.h"
#include "Luau/Id.h"
#include "Luau/Language.h"
#include "Luau/Lexer.h" // For Allocator
#include "Luau/NotNull.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
struct TypeFunction;
}
namespace Luau::EqSatSimplification
{
using StringId = uint32_t;
using Id = Luau::EqSat::Id;
LUAU_EQSAT_UNIT(TNil);
LUAU_EQSAT_UNIT(TBoolean);
LUAU_EQSAT_UNIT(TNumber);
LUAU_EQSAT_UNIT(TString);
LUAU_EQSAT_UNIT(TThread);
LUAU_EQSAT_UNIT(TTopFunction);
LUAU_EQSAT_UNIT(TTopTable);
LUAU_EQSAT_UNIT(TTopClass);
LUAU_EQSAT_UNIT(TBuffer);
// Used for any type that eqsat can't do anything interesting with.
LUAU_EQSAT_ATOM(TOpaque, TypeId);
LUAU_EQSAT_ATOM(SBoolean, bool);
LUAU_EQSAT_ATOM(SString, StringId);
LUAU_EQSAT_ATOM(TFunction, TypeId);
LUAU_EQSAT_ATOM(TImportedTable, TypeId);
LUAU_EQSAT_ATOM(TClass, TypeId);
LUAU_EQSAT_UNIT(TAny);
LUAU_EQSAT_UNIT(TError);
LUAU_EQSAT_UNIT(TUnknown);
LUAU_EQSAT_UNIT(TNever);
LUAU_EQSAT_NODE_SET(Union);
LUAU_EQSAT_NODE_SET(Intersection);
LUAU_EQSAT_NODE_ARRAY(Negation, 1);
LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, std::shared_ptr<const TypeFunctionInstanceType>);
LUAU_EQSAT_UNIT(TNoRefine);
LUAU_EQSAT_UNIT(Invalid);
// enodes are immutable, but types are cyclic. We need a way to tie the knot.
// We handle this by generating TBound nodes at points where we encounter cycles.
// Each TBound has an ordinal that we later map onto the type.
// We use a substitution rule to replace all TBound nodes with their referrent.
LUAU_EQSAT_ATOM(TBound, size_t);
// Tables are sufficiently unlike other enodes that the Language.h macros won't cut it.
struct TTable
{
explicit TTable(Id basis);
TTable(Id basis, std::vector<StringId> propNames_, std::vector<Id> propTypes_);
// All TTables extend some other table. This may be TTopTable.
//
// It will frequently be a TImportedTable, in which case we can reuse things
// like source location and documentation info.
Id getBasis() const;
EqSat::Slice<const Id> propTypes() const;
// TODO: Also support read-only table props
// TODO: Indexer type, index result type.
std::vector<StringId> propNames;
// The enode interface
EqSat::Slice<Id> mutableOperands();
EqSat::Slice<const Id> operands() const;
bool operator==(const TTable& rhs) const;
bool operator!=(const TTable& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const TTable& value) const;
};
private:
// The first element of this vector is the basis. Subsequent elements are
// property types. As we add other things like read-only properties and
// indexers, the structure of this array is likely to change.
//
// We encode our data in this way so that the operands() method can properly
// return a Slice<Id>.
std::vector<Id> storage;
};
template<typename L>
using Node = EqSat::Node<L>;
using EType = EqSat::Language<
TNil,
TBoolean,
TNumber,
TString,
TThread,
TTopFunction,
TTopTable,
TTopClass,
TBuffer,
TOpaque,
SBoolean,
SString,
TFunction,
TTable,
TImportedTable,
TClass,
TAny,
TError,
TUnknown,
TNever,
Union,
Intersection,
Negation,
TTypeFun,
Invalid,
TNoRefine,
TBound>;
struct StringCache
{
Allocator allocator;
DenseHashMap<std::string_view, StringId> strings{{}};
std::vector<std::string_view> views;
StringId add(std::string_view s);
std::string_view asStringView(StringId id) const;
std::string asString(StringId id) const;
};
using EGraph = Luau::EqSat::EGraph<EType, struct Simplify>;
struct Simplify
{
using Data = bool;
template<typename T>
Data make(const EGraph&, const T&) const;
void join(Data& left, const Data& right) const;
};
struct Subst
{
Id eclass;
Id newClass;
// The node into eclass which is boring, if any
std::optional<size_t> boringIndex;
std::string desc;
Subst(Id eclass, Id newClass, std::string desc = "");
};
struct Simplifier
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
EGraph egraph;
StringCache stringCache;
// enodes are immutable but types can be cyclic, so we need some way to
// encode the cycle. This map is used to connect TBound nodes to the right
// eclass.
//
// The cyclicIntersection rewrite rule uses this to sense when a cycle can
// be deleted from an intersection or union.
std::unordered_map<size_t, Id> mappingIdToClass;
std::vector<Subst> substs;
using RewriteRuleFn = void (Simplifier::*)(Id id);
Simplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
// Utilities
const EqSat::EClass<EType, Simplify::Data>& get(Id id) const;
Id find(Id id) const;
Id add(EType enode);
template<typename Tag>
const Tag* isTag(Id id) const;
template<typename Tag>
const Tag* isTag(const EType& enode) const;
void subst(Id from, Id to);
void subst(Id from, Id to, const std::string& ruleName);
void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void unionClasses(std::vector<Id>& hereParts, Id there);
// Rewrite rules
void simplifyUnion(Id id);
void uninhabitedIntersection(Id id);
void intersectWithNegatedClass(Id id);
void intersectWithNegatedAtom(Id id);
void intersectWithNoRefine(Id id);
void cyclicIntersectionOfUnion(Id id);
void cyclicUnionOfIntersection(Id id);
void expandNegation(Id id);
void intersectionOfUnion(Id id);
void intersectTableProperty(Id id);
void uninhabitedTable(Id id);
void unneededTableModification(Id id);
void builtinTypeFunctions(Id id);
void iffyTypeFunctions(Id id);
void strictMetamethods(Id id);
};
template<typename Tag>
struct QueryIterator
{
QueryIterator();
QueryIterator(EGraph* egraph, Id eclass);
bool operator==(const QueryIterator& other) const;
bool operator!=(const QueryIterator& other) const;
std::pair<const Tag*, size_t> operator*() const;
QueryIterator& operator++();
QueryIterator& operator++(int);
private:
EGraph* egraph = nullptr;
Id eclass;
size_t index = 0;
};
template<typename Tag>
struct Query
{
EGraph* egraph;
Id eclass;
Query(EGraph* egraph, Id eclass)
: egraph(egraph)
, eclass(eclass)
{
}
QueryIterator<Tag> begin()
{
return QueryIterator<Tag>{egraph, eclass};
}
QueryIterator<Tag> end()
{
return QueryIterator<Tag>{};
}
};
template<typename Tag>
QueryIterator<Tag>::QueryIterator()
: egraph(nullptr)
, eclass(Id{0})
, index(0)
{
}
template<typename Tag>
QueryIterator<Tag>::QueryIterator(EGraph* egraph_, Id eclass)
: egraph(egraph_)
, eclass(eclass)
, index(0)
{
const auto& ecl = (*egraph)[eclass];
static constexpr const int idx = EType::VariantTy::getTypeId<Tag>();
for (const auto& enode : ecl.nodes)
{
if (enode.node.index() < idx)
++index;
else
break;
}
if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != idx)
{
egraph = nullptr;
index = 0;
}
}
template<typename Tag>
bool QueryIterator<Tag>::operator==(const QueryIterator<Tag>& rhs) const
{
if (egraph == nullptr && rhs.egraph == nullptr)
return true;
return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index;
}
template<typename Tag>
bool QueryIterator<Tag>::operator!=(const QueryIterator<Tag>& rhs) const
{
return !(*this == rhs);
}
template<typename Tag>
std::pair<const Tag*, size_t> QueryIterator<Tag>::operator*() const
{
LUAU_ASSERT(egraph != nullptr);
EGraph::EClassT& ecl = (*egraph)[eclass];
LUAU_ASSERT(index < ecl.nodes.size());
auto& enode = ecl.nodes[index].node;
Tag* result = enode.template get<Tag>();
LUAU_ASSERT(result);
return {result, index};
}
// pre-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++()
{
const auto& ecl = (*egraph)[eclass];
do
{
++index;
if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != EType::VariantTy::getTypeId<Tag>())
{
egraph = nullptr;
index = 0;
break;
}
} while (ecl.nodes[index].boring);
return *this;
}
// post-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++(int)
{
QueryIterator<Tag> res = *this;
++res;
return res;
}
} // namespace Luau::EqSatSimplification

View file

@ -448,6 +448,13 @@ struct UnexpectedTypePackInSubtyping
bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; bool operator==(const UnexpectedTypePackInSubtyping& rhs) const;
}; };
struct UserDefinedTypeFunctionError
{
std::string message;
bool operator==(const UserDefinedTypeFunctionError& rhs) const;
};
using TypeErrorData = Variant< using TypeErrorData = Variant<
TypeMismatch, TypeMismatch,
UnknownSymbol, UnknownSymbol,
@ -496,7 +503,8 @@ using TypeErrorData = Variant<
CheckedFunctionIncorrectArgs, CheckedFunctionIncorrectArgs,
UnexpectedTypeInSubtyping, UnexpectedTypeInSubtyping,
UnexpectedTypePackInSubtyping, UnexpectedTypePackInSubtyping,
ExplicitFunctionAnnotationRecommended>; ExplicitFunctionAnnotationRecommended,
UserDefinedTypeFunctionError>;
struct TypeErrorSummary struct TypeErrorSummary
{ {

View file

@ -3,6 +3,7 @@
#include <string> #include <string>
#include <optional> #include <optional>
#include <vector>
namespace Luau namespace Luau
{ {
@ -31,6 +32,13 @@ struct ModuleInfo
bool optional = false; bool optional = false;
}; };
struct RequireSuggestion
{
std::string label;
std::string fullPath;
};
using RequireSuggestions = std::vector<RequireSuggestion>;
struct FileResolver struct FileResolver
{ {
virtual ~FileResolver() {} virtual ~FileResolver() {}
@ -51,6 +59,11 @@ struct FileResolver
{ {
return std::nullopt; return std::nullopt;
} }
virtual std::optional<RequireSuggestions> getRequireSuggestions(const ModuleName& requirer, const std::optional<std::string>& pathString) const
{
return std::nullopt;
}
}; };
struct NullFileResolver : FileResolver struct NullFileResolver : FileResolver

View file

@ -0,0 +1,146 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Parser.h"
#include "Luau/AutocompleteTypes.h"
#include "Luau/DenseHash.h"
#include "Luau/Module.h"
#include "Luau/Frontend.h"
#include <memory>
#include <vector>
namespace Luau
{
struct FrontendOptions;
enum class FragmentTypeCheckStatus
{
SkipAutocomplete,
Success,
};
struct FragmentAutocompleteAncestryResult
{
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
std::vector<AstNode*> ancestry;
AstStat* nearestStatement = nullptr;
};
struct FragmentParseResult
{
std::string fragmentToParse;
AstStatBlock* root = nullptr;
std::vector<AstNode*> ancestry;
AstStat* nearestStatement = nullptr;
std::vector<Comment> commentLocations;
std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>();
};
struct FragmentTypeCheckResult
{
ModulePtr incrementalModule = nullptr;
ScopePtr freshScope;
std::vector<AstNode*> ancestry;
};
struct FragmentAutocompleteResult
{
ModulePtr incrementalModule;
Scope* freshScope;
TypeArena arenaForAutocomplete;
AutocompleteResult acResults;
};
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,
std::optional<Position> fragmentEndPosition
);
std::pair<FragmentTypeCheckStatus, FragmentTypeCheckResult> typecheckFragment(
Frontend& frontend,
const ModuleName& moduleName,
const Position& cursorPos,
std::optional<FrontendOptions> opts,
std::string_view src,
std::optional<Position> fragmentEndPosition
);
FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend,
std::string_view src,
const ModuleName& moduleName,
Position cursorPosition,
std::optional<FrontendOptions> opts,
StringCompletionCallback callback,
std::optional<Position> fragmentEndPosition = std::nullopt
);
enum class FragmentAutocompleteStatus
{
Success,
FragmentTypeCheckFail,
InternalIce
};
struct FragmentAutocompleteStatusResult
{
FragmentAutocompleteStatus status;
std::optional<FragmentAutocompleteResult> result;
};
struct FragmentContext
{
std::string_view newSrc;
const ParseResult& newAstRoot;
std::optional<FrontendOptions> opts;
std::optional<Position> DEPRECATED_fragmentEndPosition;
};
/**
* @brief Attempts to compute autocomplete suggestions from the fragment context.
*
* This function computes autocomplete suggestions using outdated frontend typechecking data
* by patching the fragment context of the new script source content.
*
* @param frontend The Luau Frontend data structure, which may contain outdated typechecking data.
*
* @param moduleName The name of the target module, specifying which script the caller wants to request autocomplete for.
*
* @param cursorPosition The position in the script where the caller wants to trigger autocomplete.
*
* @param context The fragment context that this API will use to patch the outdated typechecking data.
*
* @param stringCompletionCB A callback function that provides autocomplete suggestions for string contexts.
*
* @return
* The status indicating whether `fragmentAutocomplete` ran successfully or failed, along with the reason for failure.
* Also includes autocomplete suggestions if the status is successful.
*
* @usage
* FragmentAutocompleteStatusResult acStatusResult;
* if (shouldFragmentAC)
* acStatusResult = Luau::tryFragmentAutocomplete(...);
*
* if (acStatusResult.status != Successful)
* {
* frontend.check(moduleName, options);
* acStatusResult.acResult = Luau::autocomplete(...);
* }
* return convertResultWithContext(acStatusResult.acResult);
*/
FragmentAutocompleteStatusResult tryFragmentAutocomplete(
Frontend& frontend,
const ModuleName& moduleName,
Position cursorPosition,
FragmentContext context,
StringCompletionCallback stringCompletionCB
);
} // namespace Luau

View file

@ -7,6 +7,7 @@
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/RequireTracer.h" #include "Luau/RequireTracer.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/Set.h"
#include "Luau/TypeCheckLimits.h" #include "Luau/TypeCheckLimits.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/AnyTypeSummary.h" #include "Luau/AnyTypeSummary.h"
@ -44,21 +45,6 @@ struct LoadDefinitionFileResult
std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments); std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments);
std::vector<std::string_view> parsePathExpr(const AstExpr& pathExpr);
// Exported only for convenient testing.
std::optional<ModuleName> pathExprToModuleName(const ModuleName& currentModuleName, const std::vector<std::string_view>& expr);
/** Try to convert an AST fragment into a ModuleName.
* Returns std::nullopt if the expression cannot be resolved. This will most likely happen in cases where
* the import path involves some dynamic computation that we cannot see into at typechecking time.
*
* Unintuitively, weirdly-formulated modules (like game.Parent.Parent.Parent.Foo) will successfully produce a ModuleName
* as long as it falls within the permitted syntax. This is ok because we will fail to find the module and produce an
* error when we try during typechecking.
*/
std::optional<ModuleName> pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr);
struct SourceNode struct SourceNode
{ {
bool hasDirtySourceModule() const bool hasDirtySourceModule() const
@ -71,13 +57,32 @@ struct SourceNode
return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule;
} }
bool hasInvalidModuleDependency(bool forAutocomplete) const
{
return forAutocomplete ? invalidModuleDependencyForAutocomplete : invalidModuleDependency;
}
void setInvalidModuleDependency(bool value, bool forAutocomplete)
{
if (forAutocomplete)
invalidModuleDependencyForAutocomplete = value;
else
invalidModuleDependency = value;
}
ModuleName name; ModuleName name;
std::string humanReadableName; std::string humanReadableName;
DenseHashSet<ModuleName> requireSet{{}}; DenseHashSet<ModuleName> requireSet{{}};
std::vector<std::pair<ModuleName, Location>> requireLocations; std::vector<std::pair<ModuleName, Location>> requireLocations;
Set<ModuleName> dependents{{}};
bool dirtySourceModule = true; bool dirtySourceModule = true;
bool dirtyModule = true; bool dirtyModule = true;
bool dirtyModuleForAutocomplete = true; bool dirtyModuleForAutocomplete = true;
bool invalidModuleDependency = true;
bool invalidModuleDependencyForAutocomplete = true;
double autocompleteLimitsMult = 1.0; double autocompleteLimitsMult = 1.0;
}; };
@ -132,7 +137,7 @@ struct FrontendModuleResolver : ModuleResolver
std::optional<ModuleInfo> resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; std::optional<ModuleInfo> resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override;
std::string getHumanReadableModuleName(const ModuleName& moduleName) const override; std::string getHumanReadableModuleName(const ModuleName& moduleName) const override;
void setModule(const ModuleName& moduleName, ModulePtr module); bool setModule(const ModuleName& moduleName, ModulePtr module);
void clearModules(); void clearModules();
private: private:
@ -166,9 +171,13 @@ struct Frontend
// Parse and typecheck module graph // Parse and typecheck module graph
CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
bool allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete = false) const;
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;
void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr); void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr);
void traverseDependents(const ModuleName& name, std::function<bool(SourceNode&)> processSubtree);
/** Borrow a pointer into the SourceModule cache. /** Borrow a pointer into the SourceModule cache.
* *
* Returns nullptr if we don't have it. This could mean that the script * Returns nullptr if we don't have it. This could mean that the script
@ -209,6 +218,7 @@ struct Frontend
); );
std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);
std::vector<ModuleName> getRequiredScripts(const ModuleName& name);
private: private:
ModulePtr check( ModulePtr check(

View file

@ -60,7 +60,7 @@ struct ReplaceGenerics : Substitution
}; };
// A substitution which replaces generic functions by monomorphic functions // A substitution which replaces generic functions by monomorphic functions
struct Instantiation : Substitution struct Instantiation final : Substitution
{ {
Instantiation(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope) Instantiation(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope)
: Substitution(log, arena) : Substitution(log, arena)

View file

@ -53,7 +53,7 @@ struct Replacer : Substitution
}; };
// A substitution which replaces generic functions by monomorphic functions // A substitution which replaces generic functions by monomorphic functions
struct Instantiation2 : Substitution struct Instantiation2 final : Substitution
{ {
// Mapping from generic types to free types to be used in instantiation. // Mapping from generic types to free types to be used in instantiation.
DenseHashMap<TypeId, TypeId> genericSubstitutions{nullptr}; DenseHashMap<TypeId, TypeId> genericSubstitutions{nullptr};

View file

@ -9,15 +9,24 @@
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/AnyTypeSummary.h" #include "Luau/AnyTypeSummary.h"
#include "Luau/DataFlowGraph.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <optional> #include <optional>
LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection)
namespace Luau namespace Luau
{ {
using LogLuauProc = void (*)(std::string_view);
extern LogLuauProc logLuau;
void setLogLuau(LogLuauProc ll);
void resetLogLuauProc();
struct Module; struct Module;
struct AnyTypeSummary; struct AnyTypeSummary;
@ -54,6 +63,7 @@ struct SourceModule
} }
}; };
bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos);
bool isWithinComment(const SourceModule& sourceModule, Position pos); bool isWithinComment(const SourceModule& sourceModule, Position pos);
bool isWithinComment(const ParseResult& result, Position pos); bool isWithinComment(const ParseResult& result, Position pos);
@ -67,6 +77,9 @@ struct Module
{ {
~Module(); ~Module();
// TODO: Clip this when we clip FFlagLuauSolverV2
bool checkedInNewSolver = false;
ModuleName name; ModuleName name;
std::string humanReadableName; std::string humanReadableName;
@ -132,6 +145,11 @@ struct Module
TypePackId returnType = nullptr; TypePackId returnType = nullptr;
std::unordered_map<Name, TypeFun> exportedTypeBindings; std::unordered_map<Name, TypeFun> exportedTypeBindings;
// Arenas related to the DFG must persist after the DFG no longer exists, as
// Module objects maintain raw pointers to objects in these arenas.
DefArena defArena;
RefinementKeyArena keyArena;
bool hasModuleScope() const; bool hasModuleScope() const;
ScopePtr getModuleScope() const; ScopePtr getModuleScope() const;

View file

@ -20,8 +20,6 @@ struct ModuleResolver
virtual ~ModuleResolver() {} virtual ~ModuleResolver() {}
/** Compute a ModuleName from an AST fragment. This AST fragment is generally the argument to the require() function. /** Compute a ModuleName from an AST fragment. This AST fragment is generally the argument to the require() function.
*
* You probably want to implement this with some variation of pathExprToModuleName.
* *
* @returns The ModuleInfo if the expression is a syntactically legal path. * @returns The ModuleInfo if the expression is a syntactically legal path.
* @returns std::nullopt if we are unable to determine whether or not the expression is a valid path. Type inference will * @returns std::nullopt if we are unable to determine whether or not the expression is a valid path. Type inference will

View file

@ -9,11 +9,14 @@ namespace Luau
{ {
struct BuiltinTypes; struct BuiltinTypes;
struct TypeFunctionRuntime;
struct UnifierSharedState; struct UnifierSharedState;
struct TypeCheckLimits; struct TypeCheckLimits;
void checkNonStrict( void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice, NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<const DataFlowGraph> dfg,

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/EqSatSimplification.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Set.h" #include "Luau/Set.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
@ -21,10 +22,22 @@ struct Scope;
using ModulePtr = std::shared_ptr<Module>; using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice); bool isSubtype(
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice); TypeId subTy,
bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice); TypeId superTy,
bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice); NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
);
bool isSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
);
class TypeIds class TypeIds
{ {
@ -336,6 +349,7 @@ struct NormalizedType
}; };
using SeenTablePropPairs = Set<std::pair<TypeId, TypeId>, TypeIdPairHash>;
class Normalizer class Normalizer
{ {
@ -390,7 +404,13 @@ public:
void unionTablesWithTable(TypeIds& heres, TypeId there); void unionTablesWithTable(TypeIds& heres, TypeId there);
void unionTables(TypeIds& heres, const TypeIds& theres); void unionTables(TypeIds& heres, const TypeIds& theres);
NormalizationResult unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); NormalizationResult unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
NormalizationResult unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars = -1); NormalizationResult unionNormalWithTy(
NormalizedType& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes,
int ignoreSmallerTyvars = -1
);
// ------- Negations // ------- Negations
std::optional<NormalizedType> negateNormal(const NormalizedType& here); std::optional<NormalizedType> negateNormal(const NormalizedType& here);
@ -407,16 +427,26 @@ public:
void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectClassesWithClass(NormalizedClassType& heres, TypeId there);
void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there);
std::optional<TypePackId> intersectionOfTypePacks(TypePackId here, TypePackId there); std::optional<TypePackId> intersectionOfTypePacks(TypePackId here, TypePackId there);
std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet); std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet);
void intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes); void intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSetTypes);
void intersectTables(TypeIds& heres, const TypeIds& theres); void intersectTables(TypeIds& heres, const TypeIds& theres);
std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there); std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress);
NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes); NormalizationResult intersectTyvarsWithTy(
NormalizedTyvars& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes
);
NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes); NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSetTypes);
NormalizationResult normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet); NormalizationResult normalizeIntersections(
const std::vector<TypeId>& intersections,
NormalizedType& outType,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSet
);
// Check for inhabitance // Check for inhabitance
NormalizationResult isInhabited(TypeId ty); NormalizationResult isInhabited(TypeId ty);
@ -426,7 +456,7 @@ public:
// Check for intersections being inhabited // Check for intersections being inhabited
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right); NormalizationResult isIntersectionInhabited(TypeId left, TypeId right);
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet); NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet);
// -------- Convert back from a normalized type to a type // -------- Convert back from a normalized type to a type
TypeId typeFromNormal(const NormalizedType& norm); TypeId typeFromNormal(const NormalizedType& norm);

View file

@ -2,12 +2,13 @@
#pragma once #pragma once
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/InsertionOrderedMap.h" #include "Luau/EqSatSimplification.h"
#include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
#include "Luau/Location.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Location.h"
#include "Luau/NotNull.h"
#include "Luau/Subtyping.h" #include "Luau/Subtyping.h"
#include "Luau/TypeFwd.h"
namespace Luau namespace Luau
{ {
@ -34,7 +35,9 @@ struct OverloadResolver
OverloadResolver( OverloadResolver(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter, NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -43,7 +46,9 @@ struct OverloadResolver
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<Scope> scope; NotNull<Scope> scope;
NotNull<InternalErrorReporter> ice; NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits; NotNull<TypeCheckLimits> limits;
@ -108,7 +113,9 @@ struct SolveResult
SolveResult solveFunctionCall( SolveResult solveFunctionCall(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter, NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope, NotNull<Scope> scope,

View file

@ -11,14 +11,12 @@
namespace Luau namespace Luau
{ {
class AstStat; class AstNode;
class AstExpr;
class AstStatBlock; class AstStatBlock;
struct AstLocal;
struct RequireTraceResult struct RequireTraceResult
{ {
DenseHashMap<const AstExpr*, ModuleInfo> exprs{nullptr}; DenseHashMap<const AstNode*, ModuleInfo> exprs{nullptr};
std::vector<std::pair<ModuleName, Location>> requireList; std::vector<std::pair<ModuleName, Location>> requireList;
}; };

View file

@ -85,12 +85,18 @@ struct Scope
void inheritAssignments(const ScopePtr& childScope); void inheritAssignments(const ScopePtr& childScope);
void inheritRefinements(const ScopePtr& childScope); void inheritRefinements(const ScopePtr& childScope);
// Track globals that should emit warnings during type checking.
DenseHashSet<std::string> globalsToWarn{""};
bool shouldWarnGlobal(std::string name) const;
// For mutually recursive type aliases, it's important that // For mutually recursive type aliases, it's important that
// they use the same types for the same names. // they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}` // For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache. // we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasTypeParameters; std::unordered_map<Name, TypeId> typeAliasTypeParameters;
std::unordered_map<Name, TypePackId> typeAliasTypePackParameters; std::unordered_map<Name, TypePackId> typeAliasTypePackParameters;
std::optional<std::vector<TypeId>> interiorFreeTypes;
}; };
// Returns true iff the left scope encloses the right scope. A Scope* equal to // Returns true iff the left scope encloses the right scope. A Scope* equal to

View file

@ -19,10 +19,10 @@ struct SimplifyResult
DenseHashSet<TypeId> blockedTypes; DenseHashSet<TypeId> blockedTypes;
}; };
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant); SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right);
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts); SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts);
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant); SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right);
enum class Relation enum class Relation
{ {

View file

@ -1,13 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/DenseHash.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Set.h" #include "Luau/Set.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/TypePairHash.h" #include "Luau/TypePairHash.h"
#include "Luau/TypePath.h" #include "Luau/TypePath.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/DenseHash.h"
#include <vector> #include <vector>
#include <optional> #include <optional>
@ -96,6 +97,22 @@ struct SubtypingEnvironment
DenseHashSet<TypeId> upperBound{nullptr}; DenseHashSet<TypeId> upperBound{nullptr};
}; };
/* For nested subtyping relationship tests of mapped generic bounds, we keep the outer environment immutable */
SubtypingEnvironment* parent = nullptr;
/// Applies `mappedGenerics` to the given type.
/// This is used specifically to substitute for generics in type function instances.
std::optional<TypeId> applyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty);
const TypeId* tryFindSubstitution(TypeId ty) const;
const SubtypingResult* tryFindSubtypingResult(std::pair<TypeId, TypeId> subAndSuper) const;
bool containsMappedType(TypeId ty) const;
bool containsMappedPack(TypePackId tp) const;
GenericBounds& getMappedTypeBounds(TypeId ty);
TypePackId* getMappedPackBounds(TypePackId tp);
/* /*
* When we encounter a generic over the course of a subtyping test, we need * When we encounter a generic over the course of a subtyping test, we need
* to tentatively map that generic onto a type on the other side. * to tentatively map that generic onto a type on the other side.
@ -112,17 +129,15 @@ struct SubtypingEnvironment
DenseHashMap<TypeId, TypeId> substitutions{nullptr}; DenseHashMap<TypeId, TypeId> substitutions{nullptr};
DenseHashMap<std::pair<TypeId, TypeId>, SubtypingResult, TypePairHash> ephemeralCache{{}}; DenseHashMap<std::pair<TypeId, TypeId>, SubtypingResult, TypePairHash> ephemeralCache{{}};
/// Applies `mappedGenerics` to the given type.
/// This is used specifically to substitute for generics in type function instances.
std::optional<TypeId> applyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty);
}; };
struct Subtyping struct Subtyping
{ {
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> iceReporter; NotNull<InternalErrorReporter> iceReporter;
TypeCheckLimits limits; TypeCheckLimits limits;
@ -142,7 +157,9 @@ struct Subtyping
Subtyping( Subtyping(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena, NotNull<TypeArena> typeArena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter NotNull<InternalErrorReporter> iceReporter
); );

View file

@ -6,6 +6,8 @@
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include <vector>
namespace Luau namespace Luau
{ {

View file

@ -44,6 +44,7 @@ struct ToStringOptions
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self
bool useQuestionMarks = true; // If true, use a postfix ? for options, else write them out as unions that include nil.
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections

View file

@ -65,11 +65,10 @@ T* getMutable(PendingTypePack* pending)
// Log of what TypeIds we are rebinding, to be committed later. // Log of what TypeIds we are rebinding, to be committed later.
struct TxnLog struct TxnLog
{ {
explicit TxnLog(bool useScopes = false) explicit TxnLog()
: typeVarChanges(nullptr) : typeVarChanges(nullptr)
, typePackChanges(nullptr) , typePackChanges(nullptr)
, ownedSeen() , ownedSeen()
, useScopes(useScopes)
, sharedSeen(&ownedSeen) , sharedSeen(&ownedSeen)
{ {
} }

View file

@ -31,6 +31,7 @@ namespace Luau
struct TypeArena; struct TypeArena;
struct Scope; struct Scope;
using ScopePtr = std::shared_ptr<Scope>; using ScopePtr = std::shared_ptr<Scope>;
struct Module;
struct TypeFunction; struct TypeFunction;
struct Constraint; struct Constraint;
@ -68,12 +69,16 @@ using Name = std::string;
// A free type is one whose exact shape has yet to be fully determined. // A free type is one whose exact shape has yet to be fully determined.
struct FreeType struct FreeType
{ {
// New constructors
explicit FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound);
// This one got promoted to explicit
explicit FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound);
explicit FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound);
// Old constructors
explicit FreeType(TypeLevel level); explicit FreeType(TypeLevel level);
explicit FreeType(Scope* scope); explicit FreeType(Scope* scope);
FreeType(Scope* scope, TypeLevel level); FreeType(Scope* scope, TypeLevel level);
FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound);
int index; int index;
TypeLevel level; TypeLevel level;
Scope* scope = nullptr; Scope* scope = nullptr;
@ -130,14 +135,14 @@ struct BlockedType
BlockedType(); BlockedType();
int index; int index;
Constraint* getOwner() const; const Constraint* getOwner() const;
void setOwner(Constraint* newOwner); void setOwner(const Constraint* newOwner);
void replaceOwner(Constraint* newOwner); void replaceOwner(const Constraint* newOwner);
private: private:
// The constraint that is intended to unblock this type. Other constraints // The constraint that is intended to unblock this type. Other constraints
// should block on this constraint if present. // should block on this constraint if present.
Constraint* owner = nullptr; const Constraint* owner = nullptr;
}; };
struct PrimitiveType struct PrimitiveType
@ -278,9 +283,6 @@ struct WithPredicate
} }
}; };
using MagicFunction = std::function<std::optional<
WithPredicate<TypePackId>>(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
struct MagicFunctionCallContext struct MagicFunctionCallContext
{ {
NotNull<struct ConstraintSolver> solver; NotNull<struct ConstraintSolver> solver;
@ -290,7 +292,6 @@ struct MagicFunctionCallContext
TypePackId result; TypePackId result;
}; };
using DcrMagicFunction = std::function<bool(MagicFunctionCallContext)>;
struct MagicRefinementContext struct MagicRefinementContext
{ {
NotNull<Scope> scope; NotNull<Scope> scope;
@ -307,8 +308,30 @@ struct MagicFunctionTypeCheckContext
NotNull<Scope> checkScope; NotNull<Scope> checkScope;
}; };
using DcrMagicRefinement = void (*)(const MagicRefinementContext&); struct MagicFunction
using DcrMagicFunctionTypeCheck = std::function<void(const MagicFunctionTypeCheckContext&)>; {
virtual std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) = 0;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
virtual bool infer(const MagicFunctionCallContext&) = 0;
virtual void refine(const MagicRefinementContext&) {}
// If a magic function needs to do its own special typechecking, do it here.
// Returns true if magic typechecking was performed. Return false if the
// default typechecking logic should run.
virtual bool typeCheck(const MagicFunctionTypeCheckContext&)
{
return false;
}
virtual ~MagicFunction() {}
};
struct FunctionType struct FunctionType
{ {
// Global monomorphic function // Global monomorphic function
@ -366,16 +389,7 @@ struct FunctionType
Scope* scope = nullptr; Scope* scope = nullptr;
TypePackId argTypes; TypePackId argTypes;
TypePackId retTypes; TypePackId retTypes;
MagicFunction magicFunction = nullptr; std::shared_ptr<MagicFunction> magic = nullptr;
DcrMagicFunction dcrMagicFunction = nullptr;
DcrMagicRefinement dcrMagicRefinement = nullptr;
// Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format
// have types that can only be decided after parsing the format string and unifying
// with the passed in values, but the correctness of the call can only be decided after
// all the types have been finalized.
DcrMagicFunctionTypeCheck dcrMagicTypeCheck = nullptr;
bool hasSelf; bool hasSelf;
// `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it. // `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it.
@ -598,6 +612,19 @@ struct ClassType
} }
}; };
// Data required to initialize a user-defined function and its environment
struct UserDefinedFunctionData
{
// Store a weak module reference to ensure the lifetime requirements are preserved
std::weak_ptr<Module> owner;
// References to AST elements are owned by the Module allocator which also stores this type
AstStatTypeFunction* definition = nullptr;
DenseHashMap<Name, std::pair<AstStatTypeFunction*, size_t>> environment{""};
DenseHashMap<Name, AstStatTypeFunction*> environment_DEPRECATED{""};
};
/** /**
* An instance of a type function that has not yet been reduced to a more concrete * An instance of a type function that has not yet been reduced to a more concrete
* type. The constraint solver receives a constraint to reduce each * type. The constraint solver receives a constraint to reduce each
@ -613,20 +640,20 @@ struct TypeFunctionInstanceType
std::vector<TypePackId> packArguments; std::vector<TypePackId> packArguments;
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
std::optional<AstExprFunction*> userFuncBody; // Body of the user-defined type function; only available for UDTFs UserDefinedFunctionData userFuncData;
TypeFunctionInstanceType( TypeFunctionInstanceType(
NotNull<const TypeFunction> function, NotNull<const TypeFunction> function,
std::vector<TypeId> typeArguments, std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments, std::vector<TypePackId> packArguments,
std::optional<AstName> userFuncName = std::nullopt, std::optional<AstName> userFuncName,
std::optional<AstExprFunction*> userFuncBody = std::nullopt UserDefinedFunctionData userFuncData
) )
: function(function) : function(function)
, typeArguments(typeArguments) , typeArguments(typeArguments)
, packArguments(packArguments) , packArguments(packArguments)
, userFuncName(userFuncName) , userFuncName(userFuncName)
, userFuncBody(userFuncBody) , userFuncData(userFuncData)
{ {
} }
@ -643,6 +670,13 @@ struct TypeFunctionInstanceType
, packArguments(packArguments) , packArguments(packArguments)
{ {
} }
TypeFunctionInstanceType(NotNull<const TypeFunction> function, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
: function{function}
, typeArguments(typeArguments)
, packArguments(packArguments)
{
}
}; };
/** Represents a pending type alias instantiation. /** Represents a pending type alias instantiation.
@ -670,6 +704,11 @@ struct AnyType
{ {
}; };
// A special, trivial type for the refinement system that is always eliminated from intersections.
struct NoRefineType
{
};
// `T | U` // `T | U`
struct UnionType struct UnionType
{ {
@ -737,7 +776,7 @@ struct NegationType
TypeId ty; TypeId ty;
}; };
using ErrorType = Unifiable::Error; using ErrorType = Unifiable::Error<TypeId>;
using TypeVariant = Unifiable::Variant< using TypeVariant = Unifiable::Variant<
TypeId, TypeId,
@ -758,6 +797,7 @@ using TypeVariant = Unifiable::Variant<
UnknownType, UnknownType,
NeverType, NeverType,
NegationType, NegationType,
NoRefineType,
TypeFunctionInstanceType>; TypeFunctionInstanceType>;
struct Type final struct Type final
@ -803,6 +843,13 @@ struct Type final
Type& operator=(const TypeVariant& rhs); Type& operator=(const TypeVariant& rhs);
Type& operator=(TypeVariant&& rhs); Type& operator=(TypeVariant&& rhs);
Type(Type&&) = default;
Type& operator=(Type&&) = default;
Type clone() const;
private:
Type(const Type&) = default;
Type& operator=(const Type& rhs); Type& operator=(const Type& rhs);
}; };
@ -952,6 +999,7 @@ public:
const TypeId unknownType; const TypeId unknownType;
const TypeId neverType; const TypeId neverType;
const TypeId errorType; const TypeId errorType;
const TypeId noRefineType;
const TypeId falsyType; const TypeId falsyType;
const TypeId truthyType; const TypeId truthyType;
@ -1159,6 +1207,10 @@ TypeId freshType(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, S
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>; using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate); std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
// A tag to mark a type which doesn't derive directly from the root type as overriding the return of `typeof`.
// Any classes which derive from this type will have typeof return this type.
static constexpr char kTypeofRootTag[] = "typeofRoot";
void attachTag(TypeId ty, const std::string& tagName); void attachTag(TypeId ty, const std::string& tagName);
void attachTag(Property& prop, const std::string& tagName); void attachTag(Property& prop, const std::string& tagName);

View file

@ -32,9 +32,13 @@ struct TypeArena
TypeId addTV(Type&& tv); TypeId addTV(Type&& tv);
TypeId freshType(TypeLevel level); TypeId freshType(NotNull<BuiltinTypes> builtins, TypeLevel level);
TypeId freshType(Scope* scope); TypeId freshType(NotNull<BuiltinTypes> builtins, Scope* scope);
TypeId freshType(Scope* scope, TypeLevel level); TypeId freshType(NotNull<BuiltinTypes> builtins, Scope* scope, TypeLevel level);
TypeId freshType_DEPRECATED(TypeLevel level);
TypeId freshType_DEPRECATED(Scope* scope);
TypeId freshType_DEPRECATED(Scope* scope, TypeLevel level);
TypePackId freshTypePack(Scope* scope); TypePackId freshTypePack(Scope* scope);

View file

@ -2,15 +2,16 @@
#pragma once #pragma once
#include "Luau/Error.h"
#include "Luau/NotNull.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/TypeUtils.h" #include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/Normalize.h"
#include "Luau/NotNull.h"
#include "Luau/Subtyping.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/TypeOrPack.h" #include "Luau/TypeOrPack.h"
#include "Luau/Normalize.h" #include "Luau/TypeUtils.h"
#include "Luau/Subtyping.h"
namespace Luau namespace Luau
{ {
@ -60,7 +61,9 @@ struct Reasonings
void check( void check(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<UnifierSharedState> sharedState, NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
DcrLogger* logger, DcrLogger* logger,
const SourceModule& sourceModule, const SourceModule& sourceModule,
@ -70,6 +73,8 @@ void check(
struct TypeChecker2 struct TypeChecker2
{ {
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
DcrLogger* logger; DcrLogger* logger;
const NotNull<TypeCheckLimits> limits; const NotNull<TypeCheckLimits> limits;
const NotNull<InternalErrorReporter> ice; const NotNull<InternalErrorReporter> ice;
@ -88,6 +93,8 @@ struct TypeChecker2
TypeChecker2( TypeChecker2(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
DcrLogger* logger, DcrLogger* logger,
@ -109,14 +116,14 @@ private:
std::optional<StackPusher> pushStack(AstNode* node); std::optional<StackPusher> pushStack(AstNode* node);
void checkForInternalTypeFunction(TypeId ty, Location location); void checkForInternalTypeFunction(TypeId ty, Location location);
TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location); TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location);
TypePackId lookupPack(AstExpr* expr); TypePackId lookupPack(AstExpr* expr) const;
TypeId lookupType(AstExpr* expr); TypeId lookupType(AstExpr* expr);
TypeId lookupAnnotation(AstType* annotation); TypeId lookupAnnotation(AstType* annotation);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation); std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation) const;
TypeId lookupExpectedType(AstExpr* expr); TypeId lookupExpectedType(AstExpr* expr) const;
TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena); TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena) const;
TypePackId reconstructPack(AstArray<AstExpr*> exprs, TypeArena& arena); TypePackId reconstructPack(AstArray<AstExpr*> exprs, TypeArena& arena);
Scope* findInnermostScope(Location location); Scope* findInnermostScope(Location location) const;
void visit(AstStat* stat); void visit(AstStat* stat);
void visit(AstStatIf* ifStatement); void visit(AstStatIf* ifStatement);
void visit(AstStatWhile* whileStatement); void visit(AstStatWhile* whileStatement);
@ -153,7 +160,7 @@ private:
void visit(AstExprVarargs* expr); void visit(AstExprVarargs* expr);
void visitCall(AstExprCall* call); void visitCall(AstExprCall* call);
void visit(AstExprCall* call); void visit(AstExprCall* call);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty); std::optional<TypeId> tryStripUnionFromNil(TypeId ty) const;
TypeId stripFromNilAndReport(TypeId ty, const Location& location); TypeId stripFromNilAndReport(TypeId ty, const Location& location);
void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy); void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy);
void visit(AstExprIndexName* indexName, ValueContext context); void visit(AstExprIndexName* indexName, ValueContext context);
@ -168,7 +175,7 @@ private:
void visit(AstExprInterpString* interpString); void visit(AstExprInterpString* interpString);
void visit(AstExprError* expr); void visit(AstExprError* expr);
TypeId flattenPack(TypePackId pack); TypeId flattenPack(TypePackId pack);
void visitGenerics(AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks); void visitGenerics(AstArray<AstGenericType*> generics, AstArray<AstGenericTypePack*> genericPacks);
void visit(AstType* ty); void visit(AstType* ty);
void visit(AstTypeReference* ty); void visit(AstTypeReference* ty);
void visit(AstTypeTable* table); void visit(AstTypeTable* table);
@ -210,6 +217,9 @@ private:
std::vector<TypeError>& errors std::vector<TypeError>& errors
); );
// Avoid duplicate warnings being emitted for the same global variable.
DenseHashSet<std::string> warnedGlobals{""};
void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const; void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const;
bool isErrorSuppressing(Location loc, TypeId ty); bool isErrorSuppressing(Location loc, TypeId ty);
bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2); bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2);

View file

@ -1,29 +1,68 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/ConstraintSolver.h" #include "Luau/Constraint.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/TypeCheckLimits.h" #include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunctionRuntime.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include <functional> #include <functional>
#include <string> #include <string>
#include <optional> #include <optional>
struct lua_State;
namespace Luau namespace Luau
{ {
struct TypeArena; struct TypeArena;
struct TxnLog; struct TxnLog;
struct ConstraintSolver;
class Normalizer; class Normalizer;
using StateRef = std::unique_ptr<lua_State, void (*)(lua_State*)>;
struct TypeFunctionRuntime
{
TypeFunctionRuntime(NotNull<InternalErrorReporter> ice, NotNull<TypeCheckLimits> limits);
~TypeFunctionRuntime();
// Return value is an error message if registration failed
std::optional<std::string> registerFunction(AstStatTypeFunction* function);
// For user-defined type functions, we store all generated types and packs for the duration of the typecheck
TypedAllocator<TypeFunctionType> typeArena;
TypedAllocator<TypeFunctionTypePackVar> typePackArena;
NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits;
StateRef state;
// Set of functions which have their environment table initialized
DenseHashSet<AstStatTypeFunction*> initialized{nullptr};
// Evaluation of type functions should only be performed in the absence of parse errors in the source module
bool allowEvaluation = true;
// Output created by 'print' function
std::vector<std::string> messages;
private:
void prepareState();
};
struct TypeFunctionContext struct TypeFunctionContext
{ {
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtins; NotNull<BuiltinTypes> builtins;
NotNull<Scope> scope; NotNull<Scope> scope;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> ice; NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits; NotNull<TypeCheckLimits> limits;
@ -33,32 +72,25 @@ struct TypeFunctionContext
const Constraint* constraint; const Constraint* constraint;
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
std::optional<AstExprFunction*> userFuncBody; // Body of the user-defined type function; only available for UDTFs
TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint) TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint);
: arena(cs->arena)
, builtins(cs->builtinTypes)
, scope(scope)
, normalizer(cs->normalizer)
, ice(NotNull{&cs->iceReporter})
, limits(NotNull{&cs->limits})
, solver(cs.get())
, constraint(constraint.get())
{
}
TypeFunctionContext( TypeFunctionContext(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtins, NotNull<BuiltinTypes> builtins,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice, NotNull<InternalErrorReporter> ice,
NotNull<TypeCheckLimits> limits NotNull<TypeCheckLimits> limits
) )
: arena(arena) : arena(arena)
, builtins(builtins) , builtins(builtins)
, scope(scope) , scope(scope)
, simplifier(simplifier)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, ice(ice) , ice(ice)
, limits(limits) , limits(limits)
, solver(nullptr) , solver(nullptr)
@ -66,7 +98,17 @@ struct TypeFunctionContext
{ {
} }
NotNull<Constraint> pushConstraint(ConstraintV&& c); NotNull<Constraint> pushConstraint(ConstraintV&& c) const;
};
enum class Reduction
{
// The type function is either known to be reducible or the determination is blocked.
MaybeOk,
// The type function is known to be irreducible, but maybe not be erroneous, e.g. when it's over generics or free types.
Irreducible,
// The type function is known to be irreducible, and is definitely erroneous.
Erroneous,
}; };
/// Represents a reduction result, which may have successfully reduced the type, /// Represents a reduction result, which may have successfully reduced the type,
@ -75,19 +117,25 @@ struct TypeFunctionContext
template<typename Ty> template<typename Ty>
struct TypeFunctionReductionResult struct TypeFunctionReductionResult
{ {
/// The result of the reduction, if any. If this is nullopt, the type function /// The result of the reduction, if any. If this is nullopt, the type function
/// could not be reduced. /// could not be reduced.
std::optional<Ty> result; std::optional<Ty> result;
/// Whether the result is uninhabited: whether we know, unambiguously and /// Indicates the status of this reduction: is `Reduction::Irreducible` if
/// permanently, whether this type function reduction results in an /// the this result indicates the type function is irreducible, and
/// uninhabitable type. This will trigger an error to be reported. /// `Reduction::Erroneous` if this result indicates the type function is
bool uninhabited; /// erroneous. `Reduction::MaybeOk` otherwise.
Reduction reductionStatus;
/// Any types that need to be progressed or mutated before the reduction may /// Any types that need to be progressed or mutated before the reduction may
/// proceed. /// proceed.
std::vector<TypeId> blockedTypes; std::vector<TypeId> blockedTypes;
/// Any type packs that need to be progressed or mutated before the /// Any type packs that need to be progressed or mutated before the
/// reduction may proceed. /// reduction may proceed.
std::vector<TypePackId> blockedPacks; std::vector<TypePackId> blockedPacks;
/// A runtime error message from user-defined type functions
std::optional<std::string> error;
/// Messages printed out from user-defined type functions
std::vector<std::string> messages;
}; };
template<typename T> template<typename T>
@ -121,6 +169,7 @@ struct TypePackFunction
struct FunctionGraphReductionResult struct FunctionGraphReductionResult
{ {
ErrorVec errors; ErrorVec errors;
ErrorVec messages;
DenseHashSet<TypeId> blockedTypes{nullptr}; DenseHashSet<TypeId> blockedTypes{nullptr};
DenseHashSet<TypePackId> blockedPacks{nullptr}; DenseHashSet<TypePackId> blockedPacks{nullptr};
DenseHashSet<TypeId> reducedTypes{nullptr}; DenseHashSet<TypeId> reducedTypes{nullptr};
@ -192,6 +241,9 @@ struct BuiltinTypeFunctions
TypeFunction indexFunc; TypeFunction indexFunc;
TypeFunction rawgetFunc; TypeFunction rawgetFunc;
TypeFunction setmetatableFunc;
TypeFunction getmetatableFunc;
void addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const; void addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const;
}; };

View file

@ -0,0 +1,298 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include <optional>
#include <string>
#include <map>
#include <vector>
using lua_State = struct lua_State;
namespace Luau
{
void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize);
// Replica of types from Type.h
struct TypeFunctionType;
using TypeFunctionTypeId = const TypeFunctionType*;
struct TypeFunctionTypePackVar;
using TypeFunctionTypePackId = const TypeFunctionTypePackVar*;
struct TypeFunctionPrimitiveType
{
enum Type
{
NilType,
Boolean,
Number,
String,
Thread,
Buffer,
};
Type type;
TypeFunctionPrimitiveType(Type type)
: type(type)
{
}
};
struct TypeFunctionBooleanSingleton
{
bool value = false;
};
struct TypeFunctionStringSingleton
{
std::string value;
};
using TypeFunctionSingletonVariant = Variant<TypeFunctionBooleanSingleton, TypeFunctionStringSingleton>;
struct TypeFunctionSingletonType
{
TypeFunctionSingletonVariant variant;
explicit TypeFunctionSingletonType(TypeFunctionSingletonVariant variant)
: variant(std::move(variant))
{
}
};
template<typename T>
const T* get(const TypeFunctionSingletonType* tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&tv->variant) : nullptr;
}
template<typename T>
T* getMutable(const TypeFunctionSingletonType* tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&const_cast<TypeFunctionSingletonType*>(tv)->variant) : nullptr;
}
struct TypeFunctionUnionType
{
std::vector<TypeFunctionTypeId> components;
};
struct TypeFunctionIntersectionType
{
std::vector<TypeFunctionTypeId> components;
};
struct TypeFunctionAnyType
{
};
struct TypeFunctionUnknownType
{
};
struct TypeFunctionNeverType
{
};
struct TypeFunctionNegationType
{
TypeFunctionTypeId type;
};
struct TypeFunctionTypePack
{
std::vector<TypeFunctionTypeId> head;
std::optional<TypeFunctionTypePackId> tail;
};
struct TypeFunctionVariadicTypePack
{
TypeFunctionTypeId type;
};
struct TypeFunctionGenericTypePack
{
bool isNamed = false;
std::string name;
};
using TypeFunctionTypePackVariant = Variant<TypeFunctionTypePack, TypeFunctionVariadicTypePack, TypeFunctionGenericTypePack>;
struct TypeFunctionTypePackVar
{
TypeFunctionTypePackVariant type;
TypeFunctionTypePackVar(TypeFunctionTypePackVariant type)
: type(std::move(type))
{
}
bool operator==(const TypeFunctionTypePackVar& rhs) const;
};
struct TypeFunctionFunctionType
{
std::vector<TypeFunctionTypeId> generics;
std::vector<TypeFunctionTypePackId> genericPacks;
TypeFunctionTypePackId argTypes;
TypeFunctionTypePackId retTypes;
};
template<typename T>
const T* get(TypeFunctionTypePackId tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&tv->type) : nullptr;
}
template<typename T>
T* getMutable(TypeFunctionTypePackId tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&const_cast<TypeFunctionTypePackVar*>(tv)->type) : nullptr;
}
struct TypeFunctionTableIndexer
{
TypeFunctionTableIndexer(TypeFunctionTypeId keyType, TypeFunctionTypeId valueType)
: keyType(keyType)
, valueType(valueType)
{
}
TypeFunctionTypeId keyType;
TypeFunctionTypeId valueType;
};
struct TypeFunctionProperty
{
static TypeFunctionProperty readonly(TypeFunctionTypeId ty);
static TypeFunctionProperty writeonly(TypeFunctionTypeId ty);
static TypeFunctionProperty rw(TypeFunctionTypeId ty); // Shared read-write type.
static TypeFunctionProperty rw(TypeFunctionTypeId read, TypeFunctionTypeId write); // Separate read-write type.
bool isReadOnly() const;
bool isWriteOnly() const;
std::optional<TypeFunctionTypeId> readTy;
std::optional<TypeFunctionTypeId> writeTy;
};
struct TypeFunctionTableType
{
using Name = std::string;
using Props = std::map<Name, TypeFunctionProperty>;
Props props;
std::optional<TypeFunctionTableIndexer> indexer;
// Should always be a TypeFunctionTableType
std::optional<TypeFunctionTypeId> metatable;
};
struct TypeFunctionClassType
{
using Name = std::string;
using Props = std::map<Name, TypeFunctionProperty>;
Props props;
std::optional<TypeFunctionTableIndexer> indexer;
std::optional<TypeFunctionTypeId> metatable; // metaclass?
// this was mistaken, and we should actually be keeping separate read/write types here.
std::optional<TypeFunctionTypeId> parent_DEPRECATED;
std::optional<TypeFunctionTypeId> readParent;
std::optional<TypeFunctionTypeId> writeParent;
TypeId classTy;
std::string name_DEPRECATED;
};
struct TypeFunctionGenericType
{
bool isNamed = false;
bool isPack = false;
std::string name;
};
using TypeFunctionTypeVariant = Luau::Variant<
TypeFunctionPrimitiveType,
TypeFunctionAnyType,
TypeFunctionUnknownType,
TypeFunctionNeverType,
TypeFunctionSingletonType,
TypeFunctionUnionType,
TypeFunctionIntersectionType,
TypeFunctionNegationType,
TypeFunctionFunctionType,
TypeFunctionTableType,
TypeFunctionClassType,
TypeFunctionGenericType>;
struct TypeFunctionType
{
TypeFunctionTypeVariant type;
TypeFunctionType(TypeFunctionTypeVariant type)
: type(std::move(type))
{
}
bool operator==(const TypeFunctionType& rhs) const;
};
template<typename T>
const T* get(TypeFunctionTypeId tv)
{
LUAU_ASSERT(tv);
return tv ? Luau::get_if<T>(&tv->type) : nullptr;
}
template<typename T>
T* getMutable(TypeFunctionTypeId tv)
{
LUAU_ASSERT(tv);
return tv ? Luau::get_if<T>(&const_cast<TypeFunctionType*>(tv)->type) : nullptr;
}
std::optional<std::string> checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult);
TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type);
TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type);
void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type);
bool isTypeUserData(lua_State* L, int idx);
TypeFunctionTypeId getTypeUserData(lua_State* L, int idx);
std::optional<TypeFunctionTypeId> optionalTypeUserData(lua_State* L, int idx);
void registerTypesLibrary(lua_State* L);
void registerTypeUserData(lua_State* L);
void setTypeFunctionEnvironment(lua_State* L);
void resetTypeFunctionState(lua_State* L);
} // namespace Luau

View file

@ -0,0 +1,50 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFunctionRuntime.h"
namespace Luau
{
using Kind = Variant<TypeId, TypePackId>;
template<typename T>
const T* get(const Kind& kind)
{
return get_if<T>(&kind);
}
using TypeFunctionKind = Variant<TypeFunctionTypeId, TypeFunctionTypePackId>;
template<typename T>
const T* get(const TypeFunctionKind& tfkind)
{
return get_if<T>(&tfkind);
}
struct TypeFunctionRuntimeBuilderState
{
NotNull<TypeFunctionContext> ctx;
// Mapping of class name to ClassType
// Invariant: users can not create a new class types -> any class types that get deserialized must have been an argument to the type function
// Using this invariant, whenever a ClassType is serialized, we can put it into this map
// whenever a ClassType is deserialized, we can use this map to return the corresponding value
DenseHashMap<std::string, TypeId> classesSerialized_DEPRECATED{{}};
// List of errors that occur during serialization/deserialization
// At every iteration of serialization/deserialzation, if this list.size() != 0, we halt the process
std::vector<std::string> errors{};
TypeFunctionRuntimeBuilderState(NotNull<TypeFunctionContext> ctx)
: ctx(ctx)
{
}
};
TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state);
TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state);
} // namespace Luau

View file

@ -399,8 +399,8 @@ private:
const ScopePtr& scope, const ScopePtr& scope,
std::optional<TypeLevel> levelOpt, std::optional<TypeLevel> levelOpt,
const AstNode& node, const AstNode& node,
const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericType*>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames, const AstArray<AstGenericTypePack*>& genericPackNames,
bool useCache = false bool useCache = false
); );

View file

@ -52,7 +52,7 @@ struct GenericTypePack
}; };
using BoundTypePack = Unifiable::Bound<TypePackId>; using BoundTypePack = Unifiable::Bound<TypePackId>;
using ErrorTypePack = Unifiable::Error; using ErrorTypePack = Unifiable::Error<TypePackId>;
using TypePackVariant = using TypePackVariant =
Unifiable::Variant<TypePackId, FreeTypePack, GenericTypePack, TypePack, VariadicTypePack, BlockedTypePack, TypeFunctionInstanceTypePack>; Unifiable::Variant<TypePackId, FreeTypePack, GenericTypePack, TypePack, VariadicTypePack, BlockedTypePack, TypeFunctionInstanceTypePack>;

View file

@ -51,6 +51,8 @@ struct Index
/// Represents fields of a type or pack that contain a type. /// Represents fields of a type or pack that contain a type.
enum class TypeField enum class TypeField
{ {
/// The table of a metatable type.
Table,
/// The metatable of a type. This could be a metatable type, a primitive /// The metatable of a type. This could be a metatable type, a primitive
/// type, a class type, or perhaps even a string singleton type. /// type, a class type, or perhaps even a string singleton type.
Metatable, Metatable,

View file

@ -40,7 +40,7 @@ struct InConditionalContext
TypeContext* typeContext; TypeContext* typeContext;
TypeContext oldValue; TypeContext oldValue;
InConditionalContext(TypeContext* c) explicit InConditionalContext(TypeContext* c)
: typeContext(c) : typeContext(c)
, oldValue(*c) , oldValue(*c)
{ {
@ -248,4 +248,45 @@ std::optional<Ty> follow(std::optional<Ty> ty)
return std::nullopt; return std::nullopt;
} }
/**
* Returns whether or not expr is a literal expression, for example:
* - Scalar literals (numbers, booleans, strings, nil)
* - Table literals
* - Lambdas (a "function literal")
*/
bool isLiteral(const AstExpr* expr);
/**
* Given a table literal and a mapping from expression to type, determine
* whether any literal expression in this table depends on any blocked types.
* This is used as a precondition for bidirectional inference: be warned that
* the behavior of this algorithm is tightly coupled to that of bidirectional
* inference.
* @param expr Expression to search
* @param astTypes Mapping from AST node to TypeID
* @returns A vector of blocked types
*/
std::vector<TypeId> findBlockedTypesIn(AstExprTable* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes);
/**
* Given a function call and a mapping from expression to type, determine
* whether the type of any argument in said call in depends on a blocked types.
* This is used as a precondition for bidirectional inference: be warned that
* the behavior of this algorithm is tightly coupled to that of bidirectional
* inference.
* @param expr Expression to search
* @param astTypes Mapping from AST node to TypeID
* @returns A vector of blocked types
*/
std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes);
/**
* Given a scope and a free type, find the closest parent that has a present
* `interiorFreeTypes` and append the given type to said list. This list will
* be generalized when the requiste `GeneralizationConstraint` is resolved.
* @param scope Initial scope this free type was attached to
* @param ty Free type to track.
*/
void trackInteriorFreeType(Scope* scope, TypeId ty);
} // namespace Luau } // namespace Luau

View file

@ -3,6 +3,7 @@
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include <optional>
#include <string> #include <string>
namespace Luau namespace Luau
@ -94,19 +95,29 @@ struct Bound
Id boundTo; Id boundTo;
}; };
template<typename Id>
struct Error struct Error
{ {
// This constructor has to be public, since it's used in Type and TypePack, // This constructor has to be public, since it's used in Type and TypePack,
// but shouldn't be called directly. Please use errorRecoveryType() instead. // but shouldn't be called directly. Please use errorRecoveryType() instead.
Error(); explicit Error();
explicit Error(Id synthetic)
: synthetic{synthetic}
{
}
int index; int index;
// This is used to create an error that can be rendered out using this field
// as appropriate metadata for communicating it to the user.
std::optional<Id> synthetic;
private: private:
static int nextIndex; static int nextIndex;
}; };
template<typename Id, typename... Value> template<typename Id, typename... Value>
using Variant = Luau::Variant<Bound<Id>, Error, Value...>; using Variant = Luau::Variant<Bound<Id>, Error<Id>, Value...>;
} // namespace Luau::Unifiable } // namespace Luau::Unifiable

View file

@ -93,10 +93,6 @@ struct Unifier
Unifier(NotNull<Normalizer> normalizer, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); Unifier(NotNull<Normalizer> normalizer, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr);
// Configure the Unifier to test for scope subsumption via embedded Scope
// pointers rather than TypeLevels.
void enableNewSolver();
// Test whether the two type vars unify. Never commits the result. // Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId subTy, TypeId superTy); ErrorVec canUnify(TypeId subTy, TypeId superTy);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false);
@ -169,7 +165,6 @@ private:
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name); std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
TxnLog combineLogsIntoIntersection(std::vector<TxnLog> logs);
TxnLog combineLogsIntoUnion(std::vector<TxnLog> logs); TxnLog combineLogsIntoUnion(std::vector<TxnLog> logs);
public: public:
@ -179,7 +174,7 @@ public:
bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed); bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed);
bool occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack); bool occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
Unifier makeChildUnifier(); std::unique_ptr<Unifier> makeChildUnifier();
void reportError(TypeError err); void reportError(TypeError err);
LUAU_NOINLINE void reportError(Location location, TypeErrorData data); LUAU_NOINLINE void reportError(Location location, TypeErrorData data);
@ -195,11 +190,6 @@ private:
// Available after regular type pack unification errors // Available after regular type pack unification errors
std::optional<int> firstPackErrorPos; std::optional<int> firstPackErrorPos;
// If true, we do a bunch of small things differently to work better with
// the new type inference engine. Most notably, we use the Scope hierarchy
// directly rather than using TypeLevels.
bool useNewSolver = false;
}; };
void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp);

View file

@ -49,6 +49,26 @@ struct UnifierSharedState
DenseHashSet<TypePackId> tempSeenTp{nullptr}; DenseHashSet<TypePackId> tempSeenTp{nullptr};
UnifierCounters counters; UnifierCounters counters;
bool reentrantTypeReduction = false;
};
struct TypeReductionRentrancyGuard final
{
explicit TypeReductionRentrancyGuard(NotNull<UnifierSharedState> sharedState)
: sharedState{sharedState}
{
sharedState->reentrantTypeReduction = true;
}
~TypeReductionRentrancyGuard()
{
sharedState->reentrantTypeReduction = false;
}
TypeReductionRentrancyGuard(const TypeReductionRentrancyGuard&) = delete;
TypeReductionRentrancyGuard(TypeReductionRentrancyGuard&&) = delete;
private:
NotNull<UnifierSharedState> sharedState;
}; };
} // namespace Luau } // namespace Luau

View file

@ -10,7 +10,6 @@
#include "Type.h" #include "Type.h"
LUAU_FASTINT(LuauVisitRecursionLimit) LUAU_FASTINT(LuauVisitRecursionLimit)
LUAU_FASTFLAG(LuauBoundLazyTypes2)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
namespace Luau namespace Luau
@ -86,6 +85,8 @@ struct GenericTypeVisitor
{ {
} }
virtual ~GenericTypeVisitor() {}
virtual void cycle(TypeId) {} virtual void cycle(TypeId) {}
virtual void cycle(TypePackId) {} virtual void cycle(TypePackId) {}
@ -133,6 +134,10 @@ struct GenericTypeVisitor
{ {
return visit(ty); return visit(ty);
} }
virtual bool visit(TypeId ty, const NoRefineType& nrt)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UnknownType& utv) virtual bool visit(TypeId ty, const UnknownType& utv)
{ {
return visit(ty); return visit(ty);
@ -186,7 +191,7 @@ struct GenericTypeVisitor
{ {
return visit(tp); return visit(tp);
} }
virtual bool visit(TypePackId tp, const Unifiable::Error& etp) virtual bool visit(TypePackId tp, const ErrorTypePack& etp)
{ {
return visit(tp); return visit(tp);
} }
@ -345,6 +350,8 @@ struct GenericTypeVisitor
} }
else if (auto atv = get<AnyType>(ty)) else if (auto atv = get<AnyType>(ty))
visit(ty, *atv); visit(ty, *atv);
else if (auto nrt = get<NoRefineType>(ty))
visit(ty, *nrt);
else if (auto utv = get<UnionType>(ty)) else if (auto utv = get<UnionType>(ty))
{ {
if (visit(ty, *utv)) if (visit(ty, *utv))
@ -455,7 +462,7 @@ struct GenericTypeVisitor
else if (auto gtv = get<GenericTypePack>(tp)) else if (auto gtv = get<GenericTypePack>(tp))
visit(tp, *gtv); visit(tp, *gtv);
else if (auto etv = get<Unifiable::Error>(tp)) else if (auto etv = get<ErrorTypePack>(tp))
visit(tp, *etv); visit(tp, *etv);
else if (auto pack = get<TypePack>(tp)) else if (auto pack = get<TypePack>(tp))

View file

@ -38,7 +38,7 @@
#include <stdio.h> #include <stdio.h>
LUAU_FASTFLAGVARIABLE(StudioReportLuauAny2, false); LUAU_FASTFLAGVARIABLE(StudioReportLuauAny2);
LUAU_FASTINTVARIABLE(LuauAnySummaryRecursionLimit, 300); LUAU_FASTINTVARIABLE(LuauAnySummaryRecursionLimit, 300);
LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(DebugLuauMagicTypes);
@ -177,7 +177,6 @@ void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module*
} }
} }
} }
} }
void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes)

View file

@ -425,6 +425,7 @@ struct AstJsonEncoder : public AstVisitor
"AstExprFunction", "AstExprFunction",
[&]() [&]()
{ {
PROP(attributes);
PROP(generics); PROP(generics);
PROP(genericPacks); PROP(genericPacks);
if (node->self) if (node->self)
@ -881,7 +882,7 @@ struct AstJsonEncoder : public AstVisitor
PROP(name); PROP(name);
PROP(generics); PROP(generics);
PROP(genericPacks); PROP(genericPacks);
PROP(type); write("value", node->type);
PROP(exported); PROP(exported);
} }
); );
@ -894,7 +895,7 @@ struct AstJsonEncoder : public AstVisitor
"AstStatDeclareFunction", "AstStatDeclareFunction",
[&]() [&]()
{ {
// TODO: attributes PROP(attributes);
PROP(name); PROP(name);
PROP(nameLocation); PROP(nameLocation);
PROP(params); PROP(params);
@ -1042,6 +1043,7 @@ struct AstJsonEncoder : public AstVisitor
"AstTypeFunction", "AstTypeFunction",
[&]() [&]()
{ {
PROP(attributes);
PROP(generics); PROP(generics);
PROP(genericPacks); PROP(genericPacks);
PROP(argTypes); PROP(argTypes);
@ -1136,6 +1138,42 @@ struct AstJsonEncoder : public AstVisitor
); );
} }
void write(AstAttr::Type type)
{
switch (type)
{
case AstAttr::Type::Checked:
return writeString("checked");
case AstAttr::Type::Native:
return writeString("native");
}
}
void write(class AstAttr* node)
{
writeNode(
node,
"AstAttr",
[&]()
{
write("name", node->type);
}
);
}
bool visit(class AstTypeGroup* node) override
{
writeNode(
node,
"AstTypeGroup",
[&]()
{
write("inner", node->type);
}
);
return false;
}
bool visit(class AstTypeSingletonBool* node) override bool visit(class AstTypeSingletonBool* node) override
{ {
writeNode( writeNode(

View file

@ -13,6 +13,8 @@
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon)
namespace Luau namespace Luau
{ {
@ -40,12 +42,27 @@ struct AutocompleteNodeFinder : public AstVisitor
} }
bool visit(AstStat* stat) override bool visit(AstStat* stat) override
{
if (FFlag::LuauExtendStatEndPosWithSemicolon)
{
// Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal
// to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case
// (no semicolon) we are still part of the AstStatLocal, hence the different comparison check.
if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end))
{
ancestry.push_back(stat);
return true;
}
}
else
{ {
if (stat->location.begin < pos && pos <= stat->location.end) if (stat->location.begin < pos && pos <= stat->location.end)
{ {
ancestry.push_back(stat); ancestry.push_back(stat);
return true; return true;
} }
}
return false; return false;
} }
@ -509,6 +526,37 @@ static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
return documentationSymbol; return documentationSymbol;
} }
static std::optional<DocumentationSymbol> getMetatableDocumentation(
const Module& module,
AstExpr* parentExpr,
const TableType* mtable,
const AstName& index
)
{
auto indexIt = mtable->props.find("__index");
if (indexIt == mtable->props.end())
return std::nullopt;
TypeId followed = follow(indexIt->second.type());
const TableType* ttv = get<TableType>(followed);
if (!ttv)
return std::nullopt;
auto propIt = ttv->props.find(index.value);
if (propIt == ttv->props.end())
return std::nullopt;
if (FFlag::LuauSolverV2)
{
if (auto ty = propIt->second.readTy)
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
}
else
return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol);
return std::nullopt;
}
std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position)
{ {
std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position); std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position);
@ -540,6 +588,8 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
} }
} }
else if (const ClassType* ctv = get<ClassType>(parentTy)) else if (const ClassType* ctv = get<ClassType>(parentTy))
{
while (ctv)
{ {
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
{ {
@ -549,7 +599,19 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
} }
else else
return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); return checkOverloadedDocumentationSymbol(
module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol
);
}
ctv = ctv->parent ? Luau::get<Luau::ClassType>(*ctv->parent) : nullptr;
}
}
else if (const PrimitiveType* ptv = get<PrimitiveType>(parentTy); ptv && ptv->metatable)
{
if (auto mtable = get<TableType>(*ptv->metatable))
{
if (std::optional<std::string> docSymbol = getMetatableDocumentation(module, parentExpr, mtable, indexName->index))
return docSymbol;
} }
} }
} }

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,27 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/AutocompleteTypes.h"
namespace Luau
{
struct Module;
struct FileResolver;
using ModulePtr = std::shared_ptr<Module>;
using ModuleName = std::string;
AutocompleteResult autocomplete_(
const ModulePtr& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
std::vector<AstNode*>& ancestry,
Scope* globalScope,
const ScopePtr& scopeAtPosition,
Position position,
FileResolver* fileResolver,
StringCompletionCallback callback
);
} // namespace Luau

View file

@ -2,6 +2,9 @@
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Clone.h"
#include "Luau/DenseHash.h"
#include "Luau/Error.h"
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/Symbol.h" #include "Luau/Symbol.h"
#include "Luau/Common.h" #include "Luau/Common.h"
@ -25,47 +28,93 @@
* about a function that takes any number of values, but where each value must have some specific type. * about a function that takes any number of values, but where each value must have some specific type.
*/ */
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauDCRMagicFunctionTypeChecker, false); LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression)
LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType3)
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope)
LUAU_FASTFLAGVARIABLE(LuauFreezeIgnorePersistent)
LUAU_FASTFLAGVARIABLE(LuauFollowTableFreeze)
namespace Luau namespace Luau
{ {
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect( struct MagicSelect final : MagicFunction
TypeChecker& typechecker, {
const ScopePtr& scope, std::optional<WithPredicate<TypePackId>>
const AstExprCall& expr, handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
WithPredicate<TypePackId> withPredicate bool infer(const MagicFunctionCallContext& ctx) override;
); };
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
struct MagicSetMetatable final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); struct MagicAssert final : MagicFunction
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); {
static bool dcrMagicFunctionPack(MagicFunctionCallContext context); std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicPack final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicRequire final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicClone final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFreeze final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFormat final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override;
};
struct MagicMatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicGmatch final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
struct MagicFind final : MagicFunction
{
std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override;
};
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types) TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{ {
@ -160,34 +209,10 @@ TypeId makeFunction(
return arena.addType(std::move(ftv)); return arena.addType(std::move(ftv));
} }
void attachMagicFunction(TypeId ty, MagicFunction fn) void attachMagicFunction(TypeId ty, std::shared_ptr<MagicFunction> magic)
{ {
if (auto ftv = getMutable<FunctionType>(ty)) if (auto ftv = getMutable<FunctionType>(ty))
ftv->magicFunction = fn; ftv->magic = std::move(magic);
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicFunction = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicRefinement = fn;
else
LUAU_ASSERT(!"Got a non functional type");
}
void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn)
{
if (auto ftv = getMutable<FunctionType>(ty))
ftv->dcrMagicTypeCheck = fn;
else else
LUAU_ASSERT(!"Got a non functional type"); LUAU_ASSERT(!"Got a non functional type");
} }
@ -293,6 +318,28 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
addGlobalBinding(globals, "string", it->second.type(), "@luau"); addGlobalBinding(globals, "string", it->second.type(), "@luau");
// Setup 'vector' metatable
if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end())
{
TypeId vectorTy = it->second.type;
ClassType* vectorCls = getMutable<ClassType>(vectorTy);
vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed});
TableType* metatableTy = Luau::getMutable<TableType>(vectorCls->metatable);
metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})};
std::initializer_list<TypeId> mulOverloads{
makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}),
makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}),
};
metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)};
}
// next<K, V>(t: Table<K, V>, i: K?) -> (K?, V) // next<K, V>(t: Table<K, V>, i: K?) -> (K?, V)
TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}});
TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}});
@ -363,7 +410,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
} }
} }
attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); attachMagicFunction(getGlobalBinding(globals, "assert"), std::make_shared<MagicAssert>());
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{ {
@ -379,9 +426,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
addGlobalBinding(globals, "assert", assertTy, "@luau"); addGlobalBinding(globals, "assert", assertTy, "@luau");
} }
attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(globals, "setmetatable"), std::make_shared<MagicSetMetatable>());
attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); attachMagicFunction(getGlobalBinding(globals, "select"), std::make_shared<MagicSelect>());
attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect);
if (TableType* ttv = getMutable<TableType>(getGlobalBinding(globals, "table"))) if (TableType* ttv = getMutable<TableType>(getGlobalBinding(globals, "table")))
{ {
@ -394,8 +440,10 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
// but it'll be ok for now. // but it'll be ok for now.
TypeId genericTy = arena.addType(GenericType{"T"}); TypeId genericTy = arena.addType(GenericType{"T"});
TypePackId thePack = arena.addTypePack({genericTy}); TypePackId thePack = arena.addTypePack({genericTy});
TypeId idTyWithMagic = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack});
ttv->props["freeze"] = makeProperty(idTyWithMagic, "@luau/global/table.freeze");
TypeId idTy = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); TypeId idTy = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack});
ttv->props["freeze"] = makeProperty(idTy, "@luau/global/table.freeze");
ttv->props["clone"] = makeProperty(idTy, "@luau/global/table.clone"); ttv->props["clone"] = makeProperty(idTy, "@luau/global/table.clone");
} }
else else
@ -410,12 +458,15 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
ttv->props["foreach"].deprecated = true; ttv->props["foreach"].deprecated = true;
ttv->props["foreachi"].deprecated = true; ttv->props["foreachi"].deprecated = true;
attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); attachMagicFunction(ttv->props["pack"].type(), std::make_shared<MagicPack>());
attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); if (FFlag::LuauTableCloneClonesType3)
attachMagicFunction(ttv->props["clone"].type(), std::make_shared<MagicClone>());
attachMagicFunction(ttv->props["freeze"].type(), std::make_shared<MagicFreeze>());
} }
attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); TypeId requireTy = getGlobalBinding(globals, "require");
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); attachTag(requireTy, kRequireTagName);
attachMagicFunction(requireTy, std::make_shared<MagicRequire>());
} }
static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size) static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
@ -454,7 +505,7 @@ static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes,
return result; return result;
} }
std::optional<WithPredicate<TypePackId>> magicFunctionFormat( std::optional<WithPredicate<TypePackId>> MagicFormat::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -504,7 +555,7 @@ std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})}; return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
} }
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) bool MagicFormat::infer(const MagicFunctionCallContext& context)
{ {
TypeArena* arena = context.solver->arena; TypeArena* arena = context.solver->arena;
@ -548,7 +599,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
return true; return true;
} }
static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext context) bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context)
{ {
AstExprConstantString* fmt = nullptr; AstExprConstantString* fmt = nullptr;
if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self) if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
@ -563,7 +614,10 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
fmt = context.callSite->args.data[0]->as<AstExprConstantString>(); fmt = context.callSite->args.data[0]->as<AstExprConstantString>();
if (!fmt) if (!fmt)
return; {
context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location);
return true;
}
std::vector<TypeId> expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size); std::vector<TypeId> expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(context.arguments); const auto& [params, tail] = flatten(context.arguments);
@ -579,7 +633,25 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
Location location = context.callSite->args.data[i + (calledWithSelf ? 0 : paramOffset)]->location; Location location = context.callSite->args.data[i + (calledWithSelf ? 0 : paramOffset)]->location;
// use subtyping instead here // use subtyping instead here
SubtypingResult result = context.typechecker->subtyping->isSubtype(actualTy, expectedTy, context.checkScope); SubtypingResult result = context.typechecker->subtyping->isSubtype(actualTy, expectedTy, context.checkScope);
if (!result.isSubtype) if (!result.isSubtype)
{
if (FFlag::LuauStringFormatErrorSuppression)
{
switch (shouldSuppressErrors(NotNull{&context.typechecker->normalizer}, actualTy))
{
case ErrorSuppression::Suppress:
break;
case ErrorSuppression::NormalizationFailed:
break;
case ErrorSuppression::DoNotSuppress:
Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result);
if (!reasonings.suppressed)
context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location);
}
}
else
{ {
Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result);
context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location);
@ -587,6 +659,9 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
} }
} }
return true;
}
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size) static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{ {
std::vector<TypeId> result; std::vector<TypeId> result;
@ -647,7 +722,7 @@ static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes
return result; return result;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch( std::optional<WithPredicate<TypePackId>> MagicGmatch::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -683,7 +758,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})}; return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
} }
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) bool MagicGmatch::infer(const MagicFunctionCallContext& context)
{ {
const auto& [params, tail] = flatten(context.arguments); const auto& [params, tail] = flatten(context.arguments);
@ -716,7 +791,7 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
return true; return true;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch( std::optional<WithPredicate<TypePackId>> MagicMatch::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -756,7 +831,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
return WithPredicate<TypePackId>{returnList}; return WithPredicate<TypePackId>{returnList};
} }
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) bool MagicMatch::infer(const MagicFunctionCallContext& context)
{ {
const auto& [params, tail] = flatten(context.arguments); const auto& [params, tail] = flatten(context.arguments);
@ -792,7 +867,7 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
return true; return true;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionFind( std::optional<WithPredicate<TypePackId>> MagicFind::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -850,7 +925,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
return WithPredicate<TypePackId>{returnList}; return WithPredicate<TypePackId>{returnList};
} }
static bool dcrMagicFunctionFind(MagicFunctionCallContext context) bool MagicFind::infer(const MagicFunctionCallContext& context)
{ {
const auto& [params, tail] = flatten(context.arguments); const auto& [params, tail] = flatten(context.arguments);
@ -927,12 +1002,9 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
formatFTV.isCheckedFunction = true; formatFTV.isCheckedFunction = true;
const TypeId formatFn = arena->addType(formatFTV); const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); attachMagicFunction(formatFn, std::make_shared<MagicFormat>());
if (FFlag::LuauDCRMagicFunctionTypeChecker)
attachDcrMagicFunctionTypeCheck(formatFn, dcrMagicFunctionTypeCheckFormat);
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
@ -946,16 +1018,14 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
const TypeId gmatchFunc = const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true);
attachMagicFunction(gmatchFunc, magicFunctionGmatch); attachMagicFunction(gmatchFunc, std::make_shared<MagicGmatch>());
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
FunctionType matchFuncTy{ FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}) arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})
}; };
matchFuncTy.isCheckedFunction = true; matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy); const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch); attachMagicFunction(matchFunc, std::make_shared<MagicMatch>());
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
FunctionType findFuncTy{ FunctionType findFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
@ -963,8 +1033,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
}; };
findFuncTy.isCheckedFunction = true; findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy); const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind); attachMagicFunction(findFunc, std::make_shared<MagicFind>());
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
// string.byte : string -> number? -> number? -> ...number // string.byte : string -> number? -> number? -> ...number
FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList};
@ -1025,7 +1094,7 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect( std::optional<WithPredicate<TypePackId>> MagicSelect::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1070,7 +1139,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
return std::nullopt; return std::nullopt;
} }
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) bool MagicSelect::infer(const MagicFunctionCallContext& context)
{ {
if (context.callSite->args.size <= 0) if (context.callSite->args.size <= 0)
{ {
@ -1115,7 +1184,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
return false; return false;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable( std::optional<WithPredicate<TypePackId>> MagicSetMetatable::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1197,7 +1266,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
return WithPredicate<TypePackId>{arena.addTypePack({target})}; return WithPredicate<TypePackId>{arena.addTypePack({target})};
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert( bool MagicSetMetatable::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicAssert::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1231,7 +1305,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
return WithPredicate<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})}; return WithPredicate<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})};
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionPack( bool MagicAssert::infer(const MagicFunctionCallContext&)
{
return false;
}
std::optional<WithPredicate<TypePackId>> MagicPack::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1274,7 +1353,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
return WithPredicate<TypePackId>{arena.addTypePack({packedTable})}; return WithPredicate<TypePackId>{arena.addTypePack({packedTable})};
} }
static bool dcrMagicFunctionPack(MagicFunctionCallContext context) bool MagicPack::infer(const MagicFunctionCallContext& context)
{ {
TypeArena* arena = context.solver->arena; TypeArena* arena = context.solver->arena;
@ -1314,6 +1393,162 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
return true; return true;
} }
std::optional<WithPredicate<TypePackId>> MagicClone::handleOldSolver(
TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType3);
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
const auto& [paramTypes, paramTail] = flatten(paramPack);
if (paramTypes.empty() || expr.args.size == 0)
{
typechecker.reportError(expr.argLocation, CountMismatch{1, std::nullopt, 0});
return std::nullopt;
}
TypeId inputType = follow(paramTypes[0]);
if (!get<TableType>(inputType))
return std::nullopt;
CloneState cloneState{typechecker.builtinTypes};
TypeId resultType = shallowClone(inputType, arena, cloneState);
TypePackId clonedTypePack = arena.addTypePack({resultType});
return WithPredicate<TypePackId>{clonedTypePack};
}
bool MagicClone::infer(const MagicFunctionCallContext& context)
{
LUAU_ASSERT(FFlag::LuauTableCloneClonesType3);
TypeArena* arena = context.solver->arena;
const auto& [paramTypes, paramTail] = flatten(context.arguments);
if (paramTypes.empty() || context.callSite->args.size == 0)
{
context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation);
return false;
}
TypeId inputType = follow(paramTypes[0]);
if (!get<TableType>(inputType))
return false;
CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent);
if (auto tableType = getMutable<TableType>(resultType))
{
tableType->scope = context.constraint->scope.get();
}
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(context.constraint->scope.get(), resultType);
TypePackId clonedTypePack = arena->addTypePack({resultType});
asMutable(context.result)->ty.emplace<BoundTypePack>(clonedTypePack);
return true;
}
static std::optional<TypeId> freezeTable(TypeId inputType, const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
if (FFlag::LuauFollowTableFreeze)
inputType = follow(inputType);
if (auto mt = get<MetatableType>(inputType))
{
std::optional<TypeId> frozenTable = freezeTable(mt->table, context);
if (!frozenTable)
return std::nullopt;
TypeId resultType = arena->addType(MetatableType{*frozenTable, mt->metatable, mt->syntheticName});
return resultType;
}
if (get<TableType>(inputType))
{
// Clone the input type, this will become our final result type after we mutate it.
CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent);
auto tableTy = getMutable<TableType>(resultType);
// `clone` should not break this.
LUAU_ASSERT(tableTy);
tableTy->state = TableState::Sealed;
// We'll mutate the table to make every property type read-only.
for (auto iter = tableTy->props.begin(); iter != tableTy->props.end();)
{
if (iter->second.isWriteOnly())
iter = tableTy->props.erase(iter);
else
{
iter->second.writeTy = std::nullopt;
iter++;
}
}
return resultType;
}
context.solver->reportError(TypeMismatch{context.solver->builtinTypes->tableType, inputType}, context.callSite->argLocation);
return std::nullopt;
}
std::optional<WithPredicate<TypePackId>> MagicFreeze::
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)
{
return std::nullopt;
}
bool MagicFreeze::infer(const MagicFunctionCallContext& context)
{
TypeArena* arena = context.solver->arena;
const DataFlowGraph* dfg = context.solver->dfg.get();
Scope* scope = context.constraint->scope.get();
const auto& [paramTypes, paramTail] = extendTypePack(*arena, context.solver->builtinTypes, context.arguments, 1);
if (paramTypes.empty() || context.callSite->args.size == 0)
{
context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation);
return false;
}
TypeId inputType = follow(paramTypes[0]);
AstExpr* targetExpr = context.callSite->args.data[0];
std::optional<DefId> resultDef = dfg->getDefOptional(targetExpr);
std::optional<TypeId> resultTy = resultDef ? scope->lookup(*resultDef) : std::nullopt;
std::optional<TypeId> frozenType = freezeTable(inputType, context);
if (!frozenType)
{
if (resultTy)
asMutable(*resultTy)->ty.emplace<BoundType>(context.solver->builtinTypes->errorType);
asMutable(context.result)->ty.emplace<BoundTypePack>(context.solver->builtinTypes->errorTypePack);
return true;
}
if (resultTy)
asMutable(*resultTy)->ty.emplace<BoundType>(*frozenType);
asMutable(context.result)->ty.emplace<BoundTypePack>(arena->addTypePack({*frozenType}));
return true;
}
static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
{ {
// require(foo.parent.bar) will technically work, but it depends on legacy goop that // require(foo.parent.bar) will technically work, but it depends on legacy goop that
@ -1336,7 +1571,7 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
return good; return good;
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire( std::optional<WithPredicate<TypePackId>> MagicRequire::handleOldSolver(
TypeChecker& typechecker, TypeChecker& typechecker,
const ScopePtr& scope, const ScopePtr& scope,
const AstExprCall& expr, const AstExprCall& expr,
@ -1382,7 +1617,7 @@ static bool checkRequirePathDcr(NotNull<ConstraintSolver> solver, AstExpr* expr)
return good; return good;
} }
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) bool MagicRequire::infer(const MagicFunctionCallContext& context)
{ {
if (context.callSite->args.size != 1) if (context.callSite->args.size != 1)
{ {
@ -1405,4 +1640,52 @@ static bool dcrMagicFunctionRequire(MagicFunctionCallContext context)
return false; return false;
} }
bool matchSetMetatable(const AstExprCall& call)
{
const char* smt = "setmetatable";
if (call.args.size != 2)
return false;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != smt)
return false;
return true;
}
bool matchTableFreeze(const AstExprCall& call)
{
if (call.args.size < 1)
return false;
const AstExprIndexName* index = call.func->as<AstExprIndexName>();
if (!index || index->index != "freeze")
return false;
const AstExprGlobal* global = index->expr->as<AstExprGlobal>();
if (!global || global->name != "table")
return false;
return true;
}
bool matchAssert(const AstExprCall& call)
{
if (call.args.size < 1)
return false;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != "assert")
return false;
return true;
}
bool shouldTypestateForFirstArgument(const AstExprCall& call)
{
// TODO: magic function for setmetatable and assert and then add them
return matchTableFreeze(call);
}
} // namespace Luau } // namespace Luau

View file

@ -7,6 +7,7 @@
#include "Luau/Unifiable.h" #include "Luau/Unifiable.h"
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauFreezeIgnorePersistent)
// For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit. // For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit.
LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000) LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000)
@ -38,14 +39,26 @@ class TypeCloner
NotNull<SeenTypes> types; NotNull<SeenTypes> types;
NotNull<SeenTypePacks> packs; NotNull<SeenTypePacks> packs;
TypeId forceTy = nullptr;
TypePackId forceTp = nullptr;
int steps = 0; int steps = 0;
public: public:
TypeCloner(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<SeenTypes> types, NotNull<SeenTypePacks> packs) TypeCloner(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<SeenTypes> types,
NotNull<SeenTypePacks> packs,
TypeId forceTy,
TypePackId forceTp
)
: arena(arena) : arena(arena)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, types(types) , types(types)
, packs(packs) , packs(packs)
, forceTy(forceTy)
, forceTp(forceTp)
{ {
} }
@ -112,7 +125,7 @@ private:
ty = follow(ty, FollowOption::DisableLazyTypeThunks); ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto it = types->find(ty); it != types->end()) if (auto it = types->find(ty); it != types->end())
return it->second; return it->second;
else if (ty->persistent) else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy))
return ty; return ty;
return std::nullopt; return std::nullopt;
} }
@ -122,7 +135,7 @@ private:
tp = follow(tp); tp = follow(tp);
if (auto it = packs->find(tp); it != packs->end()) if (auto it = packs->find(tp); it != packs->end())
return it->second; return it->second;
else if (tp->persistent) else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp))
return tp; return tp;
return std::nullopt; return std::nullopt;
} }
@ -140,7 +153,7 @@ private:
} }
} }
private: public:
TypeId shallowClone(TypeId ty) TypeId shallowClone(TypeId ty)
{ {
// We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s. // We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s.
@ -148,7 +161,7 @@ private:
if (auto clone = find(ty)) if (auto clone = find(ty))
return *clone; return *clone;
else if (ty->persistent) else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy))
return ty; return ty;
TypeId target = arena->addType(ty->ty); TypeId target = arena->addType(ty->ty);
@ -174,7 +187,7 @@ private:
if (auto clone = find(tp)) if (auto clone = find(tp))
return *clone; return *clone;
else if (tp->persistent) else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp))
return tp; return tp;
TypePackId target = arena->addTypePack(tp->ty); TypePackId target = arena->addTypePack(tp->ty);
@ -189,6 +202,7 @@ private:
return target; return target;
} }
private:
Property shallowClone(const Property& p) Property shallowClone(const Property& p)
{ {
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
@ -256,8 +270,7 @@ private:
LUAU_ASSERT(!"Item holds neither TypeId nor TypePackId when enqueuing its children?"); LUAU_ASSERT(!"Item holds neither TypeId nor TypePackId when enqueuing its children?");
} }
// ErrorType and ErrorTypePack is an alias to this type. void cloneChildren(ErrorType* t)
void cloneChildren(Unifiable::Error* t)
{ {
// noop. // noop.
} }
@ -359,6 +372,11 @@ private:
// noop. // noop.
} }
void cloneChildren(NoRefineType* t)
{
// noop.
}
void cloneChildren(UnionType* t) void cloneChildren(UnionType* t)
{ {
for (TypeId& ty : t->options) for (TypeId& ty : t->options)
@ -422,6 +440,11 @@ private:
t->boundTo = shallowClone(t->boundTo); t->boundTo = shallowClone(t->boundTo);
} }
void cloneChildren(ErrorTypePack* t)
{
// noop.
}
void cloneChildren(VariadicTypePack* t) void cloneChildren(VariadicTypePack* t)
{ {
t->ty = shallowClone(t->ty); t->ty = shallowClone(t->ty);
@ -448,12 +471,46 @@ private:
} // namespace } // namespace
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent)
{
if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent))
return tp;
TypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
nullptr,
FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? tp : nullptr
};
return cloner.shallowClone(tp);
}
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent)
{
if (typeId->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent))
return typeId;
TypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? typeId : nullptr,
nullptr
};
return cloner.shallowClone(typeId);
}
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
{ {
if (tp->persistent) if (tp->persistent)
return tp; return tp;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
return cloner.clone(tp); return cloner.clone(tp);
} }
@ -462,13 +519,13 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
if (typeId->persistent) if (typeId->persistent)
return typeId; return typeId;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
return cloner.clone(typeId); return cloner.clone(typeId);
} }
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{ {
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
TypeFun copy = typeFun; TypeFun copy = typeFun;
@ -493,4 +550,18 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
return copy; return copy;
} }
Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState)
{
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
Binding b;
b.deprecated = binding.deprecated;
b.deprecatedSuggestion = binding.deprecatedSuggestion;
b.documentationSymbol = binding.documentationSymbol;
b.location = binding.location;
b.typeId = cloner.clone(binding.typeId);
return b;
}
} // namespace Luau } // namespace Luau

View file

@ -3,6 +3,8 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_FASTFLAG(DebugLuauGreedyGeneralization)
namespace Luau namespace Luau
{ {
@ -46,6 +48,20 @@ struct ReferenceCountInitializer : TypeOnceVisitor
// ClassTypes never contain free types. // ClassTypes never contain free types.
return false; return false;
} }
bool visit(TypeId, const TypeFunctionInstanceType&) override
{
// We do not consider reference counted types that are inside a type
// function to be part of the reachable reference counted types.
// Otherwise, code can be constructed in just the right way such
// that two type functions both claim to mutate a free type, which
// prevents either type function from trying to generalize it, so
// we potentially get stuck.
//
// The default behavior here is `true` for "visit the child types"
// of this type, hence:
return false;
}
}; };
bool isReferenceCountedType(const TypeId typ) bool isReferenceCountedType(const TypeId typ)
@ -97,6 +113,11 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
{ {
rci.traverse(fchc->argsPack); rci.traverse(fchc->argsPack);
} }
else if (auto fcc = get<FunctionCallConstraint>(*this); fcc && FFlag::DebugLuauGreedyGeneralization)
{
rci.traverse(fcc->fn);
rci.traverse(fcc->argsPack);
}
else if (auto ptc = get<PrimitiveTypeConstraint>(*this)) else if (auto ptc = get<PrimitiveTypeConstraint>(*this))
{ {
rci.traverse(ptc->freeType); rci.traverse(ptc->freeType);
@ -104,7 +125,8 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
else if (auto hpc = get<HasPropConstraint>(*this)) else if (auto hpc = get<HasPropConstraint>(*this))
{ {
rci.traverse(hpc->resultType); rci.traverse(hpc->resultType);
// `HasPropConstraints` should not mutate `subjectType`. if (FFlag::DebugLuauGreedyGeneralization)
rci.traverse(hpc->subjectType);
} }
else if (auto hic = get<HasIndexerConstraint>(*this)) else if (auto hic = get<HasIndexerConstraint>(*this))
{ {
@ -132,6 +154,10 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
{ {
rci.traverse(rpc->tp); rci.traverse(rpc->tp);
} }
else if (auto tcc = get<TableCheckConstraint>(*this))
{
rci.traverse(tcc->exprType);
}
return types; return types;
} }

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -718,7 +718,7 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig
env.popVisiting(); env.popVisiting();
return diffRes; return diffRes;
} }
if (auto le = get<Luau::Unifiable::Error>(left)) if (auto le = get<ErrorType>(left))
{ {
// TODO: return debug-friendly result state // TODO: return debug-friendly result state
env.popVisiting(); env.popVisiting();

View file

@ -1,102 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAGVARIABLE(LuauDebugInfoDefn)
namespace Luau namespace Luau
{ {
static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC( static const std::string kBuiltinDefinitionBaseSrc = R"BUILTIN_SRC(
declare bit32: {
band: @checked (...number) -> number,
bor: @checked (...number) -> number,
bxor: @checked (...number) -> number,
btest: @checked (number, ...number) -> boolean,
rrotate: @checked (x: number, disp: number) -> number,
lrotate: @checked (x: number, disp: number) -> number,
lshift: @checked (x: number, disp: number) -> number,
arshift: @checked (x: number, disp: number) -> number,
rshift: @checked (x: number, disp: number) -> number,
bnot: @checked (x: number) -> number,
extract: @checked (n: number, field: number, width: number?) -> number,
replace: @checked (n: number, v: number, field: number, width: number?) -> number,
countlz: @checked (n: number) -> number,
countrz: @checked (n: number) -> number,
byteswap: @checked (n: number) -> number,
}
declare math: {
frexp: @checked (n: number) -> (number, number),
ldexp: @checked (s: number, e: number) -> number,
fmod: @checked (x: number, y: number) -> number,
modf: @checked (n: number) -> (number, number),
pow: @checked (x: number, y: number) -> number,
exp: @checked (n: number) -> number,
ceil: @checked (n: number) -> number,
floor: @checked (n: number) -> number,
abs: @checked (n: number) -> number,
sqrt: @checked (n: number) -> number,
log: @checked (n: number, base: number?) -> number,
log10: @checked (n: number) -> number,
rad: @checked (n: number) -> number,
deg: @checked (n: number) -> number,
sin: @checked (n: number) -> number,
cos: @checked (n: number) -> number,
tan: @checked (n: number) -> number,
sinh: @checked (n: number) -> number,
cosh: @checked (n: number) -> number,
tanh: @checked (n: number) -> number,
atan: @checked (n: number) -> number,
acos: @checked (n: number) -> number,
asin: @checked (n: number) -> number,
atan2: @checked (y: number, x: number) -> number,
min: @checked (number, ...number) -> number,
max: @checked (number, ...number) -> number,
pi: number,
huge: number,
randomseed: @checked (seed: number) -> (),
random: @checked (number?, number?) -> number,
sign: @checked (n: number) -> number,
clamp: @checked (n: number, min: number, max: number) -> number,
noise: @checked (x: number, y: number?, z: number?) -> number,
round: @checked (n: number) -> number,
}
type DateTypeArg = {
year: number,
month: number,
day: number,
hour: number?,
min: number?,
sec: number?,
isdst: boolean?,
}
type DateTypeResult = {
year: number,
month: number,
wday: number,
yday: number,
day: number,
hour: number,
min: number,
sec: number,
isdst: boolean,
}
declare os: {
time: (time: DateTypeArg?) -> number,
date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string),
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
@checked declare function require(target: any): any @checked declare function require(target: any): any
@ -144,6 +54,119 @@ declare function loadstring<A...>(src: string, chunkname: string?): (((A...) ->
@checked declare function newproxy(mt: boolean?): any @checked declare function newproxy(mt: boolean?): any
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBit32Src = R"BUILTIN_SRC(
declare bit32: {
band: @checked (...number) -> number,
bor: @checked (...number) -> number,
bxor: @checked (...number) -> number,
btest: @checked (number, ...number) -> boolean,
rrotate: @checked (x: number, disp: number) -> number,
lrotate: @checked (x: number, disp: number) -> number,
lshift: @checked (x: number, disp: number) -> number,
arshift: @checked (x: number, disp: number) -> number,
rshift: @checked (x: number, disp: number) -> number,
bnot: @checked (x: number) -> number,
extract: @checked (n: number, field: number, width: number?) -> number,
replace: @checked (n: number, v: number, field: number, width: number?) -> number,
countlz: @checked (n: number) -> number,
countrz: @checked (n: number) -> number,
byteswap: @checked (n: number) -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionMathSrc = R"BUILTIN_SRC(
declare math: {
frexp: @checked (n: number) -> (number, number),
ldexp: @checked (s: number, e: number) -> number,
fmod: @checked (x: number, y: number) -> number,
modf: @checked (n: number) -> (number, number),
pow: @checked (x: number, y: number) -> number,
exp: @checked (n: number) -> number,
ceil: @checked (n: number) -> number,
floor: @checked (n: number) -> number,
abs: @checked (n: number) -> number,
sqrt: @checked (n: number) -> number,
log: @checked (n: number, base: number?) -> number,
log10: @checked (n: number) -> number,
rad: @checked (n: number) -> number,
deg: @checked (n: number) -> number,
sin: @checked (n: number) -> number,
cos: @checked (n: number) -> number,
tan: @checked (n: number) -> number,
sinh: @checked (n: number) -> number,
cosh: @checked (n: number) -> number,
tanh: @checked (n: number) -> number,
atan: @checked (n: number) -> number,
acos: @checked (n: number) -> number,
asin: @checked (n: number) -> number,
atan2: @checked (y: number, x: number) -> number,
min: @checked (number, ...number) -> number,
max: @checked (number, ...number) -> number,
pi: number,
huge: number,
randomseed: @checked (seed: number) -> (),
random: @checked (number?, number?) -> number,
sign: @checked (n: number) -> number,
clamp: @checked (n: number, min: number, max: number) -> number,
noise: @checked (x: number, y: number?, z: number?) -> number,
round: @checked (n: number) -> number,
map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number,
lerp: @checked (a: number, b: number, t: number) -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionOsSrc = R"BUILTIN_SRC(
type DateTypeArg = {
year: number,
month: number,
day: number,
hour: number?,
min: number?,
sec: number?,
isdst: boolean?,
}
type DateTypeResult = {
year: number,
month: number,
wday: number,
yday: number,
day: number,
hour: number,
min: number,
sec: number,
isdst: boolean,
}
declare os: {
time: (time: DateTypeArg?) -> number,
date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string),
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionCoroutineSrc = R"BUILTIN_SRC(
declare coroutine: { declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread, create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(co: thread, A...) -> (boolean, R...), resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
@ -155,6 +178,10 @@ declare coroutine: {
close: @checked (co: thread) -> (boolean, any) close: @checked (co: thread) -> (boolean, any)
} }
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionTableSrc = R"BUILTIN_SRC(
declare table: { declare table: {
concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string, concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()), insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
@ -177,11 +204,28 @@ declare table: {
isfrozen: <K, V>(t: {[K]: V}) -> boolean, isfrozen: <K, V>(t: {[K]: V}) -> boolean,
} }
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionDebugSrc = R"BUILTIN_SRC(
declare debug: {
info: ((thread: thread, level: number, options: string) -> ...any) & ((level: number, options: string) -> ...any) & (<A..., R1...>(func: (A...) -> R1..., options: string) -> ...any),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionDebugSrc_DEPRECATED = R"BUILTIN_SRC(
declare debug: { declare debug: {
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...), info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
} }
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionUtf8Src = R"BUILTIN_SRC(
declare utf8: { declare utf8: {
char: @checked (...number) -> string, char: @checked (...number) -> string,
charpattern: string, charpattern: string,
@ -191,10 +235,9 @@ declare utf8: {
offset: @checked (s: string, n: number?, i: number?) -> number, offset: @checked (s: string, n: number?, i: number?) -> number,
} }
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. )BUILTIN_SRC";
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC(
--- Buffer API --- Buffer API
declare buffer: { declare buffer: {
create: @checked (size: number) -> buffer, create: @checked (size: number) -> buffer,
@ -221,13 +264,56 @@ declare buffer: {
writef64: @checked (b: buffer, offset: number, value: number) -> (), writef64: @checked (b: buffer, offset: number, value: number) -> (),
readstring: @checked (b: buffer, offset: number, count: number) -> string, readstring: @checked (b: buffer, offset: number, count: number) -> string,
writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (), writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (),
readbits: @checked (b: buffer, bitOffset: number, bitCount: number) -> number,
writebits: @checked (b: buffer, bitOffset: number, bitCount: number, value: number) -> (),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC(
-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties
declare class vector
x: number
y: number
z: number
end
declare vector: {
create: @checked (x: number, y: number, z: number?) -> vector,
magnitude: @checked (vec: vector) -> number,
normalize: @checked (vec: vector) -> vector,
cross: @checked (vec1: vector, vec2: vector) -> vector,
dot: @checked (vec1: vector, vec2: vector) -> number,
angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number,
floor: @checked (vec: vector) -> vector,
ceil: @checked (vec: vector) -> vector,
abs: @checked (vec: vector) -> vector,
sign: @checked (vec: vector) -> vector,
clamp: @checked (vec: vector, min: vector, max: vector) -> vector,
max: @checked (vector, ...vector) -> vector,
min: @checked (vector, ...vector) -> vector,
zero: vector,
one: vector,
} }
)BUILTIN_SRC"; )BUILTIN_SRC";
std::string getBuiltinDefinitionSource() std::string getBuiltinDefinitionSource()
{ {
std::string result = kBuiltinDefinitionLuaSrcChecked; std::string result = kBuiltinDefinitionBaseSrc;
result += kBuiltinDefinitionBit32Src;
result += kBuiltinDefinitionMathSrc;
result += kBuiltinDefinitionOsSrc;
result += kBuiltinDefinitionCoroutineSrc;
result += kBuiltinDefinitionTableSrc;
result += FFlag::LuauDebugInfoDefn ? kBuiltinDefinitionDebugSrc : kBuiltinDefinitionDebugSrc_DEPRECATED;
result += kBuiltinDefinitionUtf8Src;
result += kBuiltinDefinitionBufferSrc;
result += kBuiltinDefinitionVectorSrc;
return result; return result;
} }

File diff suppressed because it is too large Load diff

View file

@ -18,8 +18,6 @@
LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10) LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauImproveNonFunctionCallError, false)
static std::string wrongNumberOfArgsString( static std::string wrongNumberOfArgsString(
size_t expectedCount, size_t expectedCount,
std::optional<size_t> maximumCount, std::optional<size_t> maximumCount,
@ -407,8 +405,6 @@ struct ErrorConverter
} }
std::string operator()(const Luau::CannotCallNonFunction& e) const std::string operator()(const Luau::CannotCallNonFunction& e) const
{
if (DFFlag::LuauImproveNonFunctionCallError)
{ {
if (auto unionTy = get<UnionType>(follow(e.ty))) if (auto unionTy = get<UnionType>(follow(e.ty)))
{ {
@ -435,9 +431,6 @@ struct ErrorConverter
return "Cannot call a value of type " + toString(e.ty); return "Cannot call a value of type " + toString(e.ty);
} }
return "Cannot call non-function " + toString(e.ty);
}
std::string operator()(const Luau::ExtraInformation& e) const std::string operator()(const Luau::ExtraInformation& e) const
{ {
return e.message; return e.message;
@ -793,6 +786,11 @@ struct ErrorConverter
return "Encountered an unexpected type pack in subtyping: " + toString(e.tp); return "Encountered an unexpected type pack in subtyping: " + toString(e.tp);
} }
std::string operator()(const UserDefinedTypeFunctionError& e) const
{
return e.message;
}
std::string operator()(const CannotAssignToNever& e) const std::string operator()(const CannotAssignToNever& e) const
{ {
std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never"; std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never";
@ -1175,6 +1173,11 @@ bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtypi
return tp == rhs.tp; return tp == rhs.tp;
} }
bool UserDefinedTypeFunctionError::operator==(const UserDefinedTypeFunctionError& rhs) const
{
return message == rhs.message;
}
bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const
{ {
if (cause.size() != rhs.cause.size()) if (cause.size() != rhs.cause.size())
@ -1384,6 +1387,9 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
e.ty = clone(e.ty); e.ty = clone(e.ty);
else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>) else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>)
e.tp = clone(e.tp); e.tp = clone(e.tp);
else if constexpr (std::is_same_v<T, UserDefinedTypeFunctionError>)
{
}
else if constexpr (std::is_same_v<T, CannotAssignToNever>) else if constexpr (std::is_same_v<T, CannotAssignToNever>)
{ {
e.rhsType = clone(e.rhsType); e.rhsType = clone(e.rhsType);

View file

@ -0,0 +1,708 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/FragmentAutocomplete.h"
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Autocomplete.h"
#include "Luau/Common.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
#include "Luau/TimeTrace.h"
#include "Luau/UnifierSharedState.h"
#include "Luau/TypeFunction.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/ConstraintGenerator.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/Frontend.h"
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
#include "Luau/Clone.h"
#include "AutocompleteCore.h"
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteBugfixes)
LUAU_FASTFLAGVARIABLE(LuauMixedModeDefFinderTraversesTypeOf)
LUAU_FASTFLAG(LuauBetterReverseDependencyTracking)
LUAU_FASTFLAGVARIABLE(LuauCloneIncrementalModule)
LUAU_FASTFLAGVARIABLE(LogFragmentsFromAutocomplete)
namespace
{
template<typename T>
void copyModuleVec(std::vector<T>& result, const std::vector<T>& input)
{
result.insert(result.end(), input.begin(), input.end());
}
template<typename K, typename V>
void copyModuleMap(Luau::DenseHashMap<K, V>& result, const Luau::DenseHashMap<K, V>& input)
{
for (auto [k, v] : input)
result[k] = v;
}
} // namespace
namespace Luau
{
template<typename K, typename V>
void cloneModuleMap(TypeArena& destArena, CloneState& cloneState, const Luau::DenseHashMap<K, V>& source, Luau::DenseHashMap<K, V>& dest)
{
for (auto [k, v] : source)
{
dest[k] = Luau::clone(v, destArena, cloneState);
}
}
struct MixedModeIncrementalTCDefFinder : public AstVisitor
{
bool visit(AstExprLocal* local) override
{
referencedLocalDefs.emplace_back(local->local, local);
return true;
}
bool visit(AstTypeTypeof* node) override
{
// We need to traverse typeof expressions because they may refer to locals that we need
// to populate the local environment for fragment typechecking. For example, `typeof(m)`
// requires that we find the local/global `m` and place it in the environment.
// The default behaviour here is to return false, and have individual visitors override
// the specific behaviour they need.
return FFlag::LuauMixedModeDefFinderTraversesTypeOf;
}
// ast defs is just a mapping from expr -> def in general
// will get built up by the dfg builder
// localDefs, we need to copy over
std::vector<std::pair<AstLocal*, AstExpr*>> referencedLocalDefs;
};
void cloneAndSquashScopes(
CloneState& cloneState,
const Scope* staleScope,
const ModulePtr& staleModule,
NotNull<TypeArena> destArena,
NotNull<DataFlowGraph> dfg,
AstStatBlock* program,
Scope* destScope
)
{
LUAU_TIMETRACE_SCOPE("Luau::cloneAndSquashScopes", "FragmentAutocomplete");
std::vector<const Scope*> scopes;
for (const Scope* current = staleScope; current; current = current->parent.get())
{
scopes.emplace_back(current);
}
// in reverse order (we need to clone the parents and override defs as we go down the list)
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
const Scope* curr = *it;
// Clone the lvalue types
for (const auto& [def, ty] : curr->lvalueTypes)
destScope->lvalueTypes[def] = Luau::clone(ty, *destArena, cloneState);
// Clone the rvalueRefinements
for (const auto& [def, ty] : curr->rvalueRefinements)
destScope->rvalueRefinements[def] = Luau::clone(ty, *destArena, cloneState);
for (const auto& [n, m] : curr->importedTypeBindings)
{
std::unordered_map<Name, TypeFun> importedBindingTypes;
for (const auto& [v, tf] : m)
importedBindingTypes[v] = Luau::clone(tf, *destArena, cloneState);
destScope->importedTypeBindings[n] = m;
}
// Finally, clone up the bindings
for (const auto& [s, b] : curr->bindings)
{
destScope->bindings[s] = Luau::clone(b, *destArena, cloneState);
}
}
// The above code associates defs with TypeId's in the scope
// so that lookup to locals will succeed.
MixedModeIncrementalTCDefFinder finder;
program->visit(&finder);
std::vector<std::pair<AstLocal*, AstExpr*>> locals = std::move(finder.referencedLocalDefs);
for (auto [loc, expr] : locals)
{
if (std::optional<Binding> binding = staleScope->linearSearchForBinding(loc->name.value, true))
{
destScope->lvalueTypes[dfg->getDef(expr)] = Luau::clone(binding->typeId, *destArena, cloneState);
}
}
return;
}
static FrontendModuleResolver& getModuleResolver(Frontend& frontend, std::optional<FrontendOptions> options)
{
if (FFlag::LuauSolverV2 || !options)
return frontend.moduleResolver;
return options->forAutocomplete ? frontend.moduleResolverForAutocomplete : frontend.moduleResolver;
}
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos)
{
std::vector<AstNode*> ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos);
// Should always contain the root AstStat
LUAU_ASSERT(ancestry.size() >= 1);
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
AstStat* nearestStatement = nullptr;
for (AstNode* node : ancestry)
{
if (auto block = node->as<AstStatBlock>())
{
for (auto stat : block->body)
{
if (stat->location.begin <= cursorPos)
nearestStatement = stat;
if (stat->location.begin < cursorPos && stat->location.begin.line < cursorPos.line)
{
// This statement precedes the current one
if (auto loc = stat->as<AstStatLocal>())
{
for (auto v : loc->vars)
{
localStack.push_back(v);
localMap[v->name] = v;
}
}
else if (auto locFun = stat->as<AstStatLocalFunction>())
{
localStack.push_back(locFun->name);
localMap[locFun->name->name] = locFun->name;
if (locFun->location.contains(cursorPos))
{
for (AstLocal* loc : locFun->func->args)
{
localStack.push_back(loc);
localMap[loc->name] = loc;
}
}
}
else if (auto globFun = stat->as<AstStatFunction>())
{
if (globFun->location.contains(cursorPos))
{
for (AstLocal* loc : globFun->func->args)
{
localStack.push_back(loc);
localMap[loc->name] = loc;
}
}
}
}
}
}
}
if (!nearestStatement)
nearestStatement = ancestry[0]->asStat();
LUAU_ASSERT(nearestStatement);
return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)};
}
/**
* Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that
* document and attempts to get the concrete text between those points. It returns a pair of:
* - start offset that represents an index in the source `char*` corresponding to startPos
* - length, that represents how many more bytes to read to get to endPos.
* Example - your document is "foo bar baz" and getDocumentOffsets is passed (0, 4), (0, 8). This function returns the pair {3, 5}
* which corresponds to the string " bar "
*/
std::pair<size_t, size_t> getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos)
{
size_t lineCount = 0;
size_t colCount = 0;
size_t docOffset = 0;
size_t startOffset = 0;
size_t endOffset = 0;
bool foundStart = false;
bool foundEnd = false;
for (char c : src)
{
if (foundStart && foundEnd)
break;
if (startPos.line == lineCount && startPos.column == colCount)
{
foundStart = true;
startOffset = docOffset;
}
if (endPos.line == lineCount && endPos.column == colCount)
{
endOffset = docOffset;
while (endOffset < src.size() && src[endOffset] != '\n')
endOffset++;
foundEnd = true;
}
// We put a cursor position that extends beyond the extents of the current line
if (foundStart && !foundEnd && (lineCount > endPos.line))
{
foundEnd = true;
endOffset = docOffset - 1;
}
if (c == '\n')
{
lineCount++;
colCount = 0;
}
else
{
colCount++;
}
docOffset++;
}
if (foundStart && !foundEnd)
endOffset = src.length();
size_t min = std::min(startOffset, endOffset);
size_t len = std::max(startOffset, endOffset) - min;
return {min, len};
}
ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement)
{
LUAU_ASSERT(module->hasModuleScope());
ScopePtr closest = module->getModuleScope();
// find the scope the nearest statement belonged to.
for (auto [loc, sc] : module->scopes)
{
if (loc.encloses(nearestStatement->location) && closest->location.begin <= loc.begin)
closest = sc;
}
return closest;
}
std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,
std::optional<Position> fragmentEndPosition
)
{
FragmentAutocompleteAncestryResult result = findAncestryForFragmentParse(srcModule.root, cursorPos);
AstStat* nearestStatement = result.nearestStatement;
const Location& rootSpan = srcModule.root->location;
// Did we append vs did we insert inline
bool appended = cursorPos >= rootSpan.end;
// statement spans multiple lines
bool multiline = nearestStatement->location.begin.line != nearestStatement->location.end.line;
const Position endPos = fragmentEndPosition.value_or(cursorPos);
// We start by re-parsing everything (we'll refine this as we go)
Position startPos = srcModule.root->location.begin;
// If we added to the end of the sourceModule, use the end of the nearest location
if (appended && multiline)
startPos = nearestStatement->location.end;
// Statement spans one line && cursorPos is either on the same line or after
else if (!multiline && cursorPos.line >= nearestStatement->location.end.line)
startPos = nearestStatement->location.begin;
else if (multiline && nearestStatement->location.end.line < cursorPos.line)
startPos = nearestStatement->location.end;
else
startPos = nearestStatement->location.begin;
auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos);
const char* srcStart = src.data() + offsetStart;
std::string_view dbg = src.substr(offsetStart, parseLength);
const std::shared_ptr<AstNameTable>& nameTbl = srcModule.names;
FragmentParseResult fragmentResult;
fragmentResult.fragmentToParse = std::string(dbg.data(), parseLength);
// For the duration of the incremental parse, we want to allow the name table to re-use duplicate names
if (FFlag::LogFragmentsFromAutocomplete)
logLuau(dbg);
ParseOptions opts;
opts.allowDeclarationSyntax = false;
opts.captureComments = true;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos};
ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts);
// This means we threw a ParseError and we should decline to offer autocomplete here.
if (p.root == nullptr)
return std::nullopt;
std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
// Get the ancestry for the fragment at the offset cursor position.
// Consumers have the option to request with fragment end position, so we cannot just use the end position of our parse result as the
// cursor position. Instead, use the cursor position calculated as an offset from our start position.
std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, cursorPos);
fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end());
if (nearestStatement == nullptr)
nearestStatement = p.root;
fragmentResult.root = std::move(p.root);
fragmentResult.ancestry = std::move(fabricatedAncestry);
fragmentResult.nearestStatement = nearestStatement;
fragmentResult.commentLocations = std::move(p.commentLocations);
return fragmentResult;
}
ModulePtr cloneModule(CloneState& cloneState, const ModulePtr& source, std::unique_ptr<Allocator> alloc)
{
LUAU_TIMETRACE_SCOPE("Luau::cloneModule", "FragmentAutocomplete");
freeze(source->internalTypes);
freeze(source->interfaceTypes);
ModulePtr incremental = std::make_shared<Module>();
incremental->name = source->name;
incremental->humanReadableName = source->humanReadableName;
incremental->allocator = std::move(alloc);
// Clone types
cloneModuleMap(incremental->internalTypes, cloneState, source->astTypes, incremental->astTypes);
cloneModuleMap(incremental->internalTypes, cloneState, source->astTypePacks, incremental->astTypePacks);
cloneModuleMap(incremental->internalTypes, cloneState, source->astExpectedTypes, incremental->astExpectedTypes);
cloneModuleMap(incremental->internalTypes, cloneState, source->astOverloadResolvedTypes, incremental->astOverloadResolvedTypes);
cloneModuleMap(incremental->internalTypes, cloneState, source->astForInNextTypes, incremental->astForInNextTypes);
copyModuleMap(incremental->astScopes, source->astScopes);
return incremental;
}
ModulePtr copyModule(const ModulePtr& result, std::unique_ptr<Allocator> alloc)
{
ModulePtr incrementalModule = std::make_shared<Module>();
incrementalModule->name = result->name;
incrementalModule->humanReadableName = "Incremental$" + result->humanReadableName;
incrementalModule->internalTypes.owningModule = incrementalModule.get();
incrementalModule->interfaceTypes.owningModule = incrementalModule.get();
incrementalModule->allocator = std::move(alloc);
// Don't need to keep this alive (it's already on the source module)
copyModuleVec(incrementalModule->scopes, result->scopes);
copyModuleMap(incrementalModule->astTypes, result->astTypes);
copyModuleMap(incrementalModule->astTypePacks, result->astTypePacks);
copyModuleMap(incrementalModule->astExpectedTypes, result->astExpectedTypes);
// Don't need to clone astOriginalCallTypes
copyModuleMap(incrementalModule->astOverloadResolvedTypes, result->astOverloadResolvedTypes);
// Don't need to clone astForInNextTypes
copyModuleMap(incrementalModule->astForInNextTypes, result->astForInNextTypes);
// Don't need to clone astResolvedTypes
// Don't need to clone astResolvedTypePacks
// Don't need to clone upperBoundContributors
copyModuleMap(incrementalModule->astScopes, result->astScopes);
// Don't need to clone declared Globals;
return incrementalModule;
}
void mixedModeCompatibility(
const ScopePtr& bottomScopeStale,
const ScopePtr& myFakeScope,
const ModulePtr& stale,
NotNull<DataFlowGraph> dfg,
AstStatBlock* program
)
{
// This code does the following
// traverse program
// look for ast refs for locals
// ask for the corresponding defId from dfg
// given that defId, and that expression, in the incremental module, map lvalue types from defID to
MixedModeIncrementalTCDefFinder finder;
program->visit(&finder);
std::vector<std::pair<AstLocal*, AstExpr*>> locals = std::move(finder.referencedLocalDefs);
for (auto [loc, expr] : locals)
{
if (std::optional<Binding> binding = bottomScopeStale->linearSearchForBinding(loc->name.value, true))
{
myFakeScope->lvalueTypes[dfg->getDef(expr)] = binding->typeId;
}
}
}
FragmentTypeCheckResult typecheckFragment_(
Frontend& frontend,
AstStatBlock* root,
const ModulePtr& stale,
const ScopePtr& closestScope,
const Position& cursorPos,
std::unique_ptr<Allocator> astAllocator,
const FrontendOptions& opts
)
{
LUAU_TIMETRACE_SCOPE("Luau::typecheckFragment_", "FragmentAutocomplete");
freeze(stale->internalTypes);
freeze(stale->interfaceTypes);
CloneState cloneState{frontend.builtinTypes};
ModulePtr incrementalModule =
FFlag::LuauCloneIncrementalModule ? cloneModule(cloneState, stale, std::move(astAllocator)) : copyModule(stale, std::move(astAllocator));
incrementalModule->checkedInNewSolver = true;
unfreeze(incrementalModule->internalTypes);
unfreeze(incrementalModule->interfaceTypes);
/// Setup typecheck limits
TypeCheckLimits limits;
if (opts.moduleTimeLimitSec)
limits.finishTime = TimeTrace::getClock() + *opts.moduleTimeLimitSec;
else
limits.finishTime = std::nullopt;
limits.cancellationToken = opts.cancellationToken;
/// Icehandler
NotNull<InternalErrorReporter> iceHandler{&frontend.iceHandler};
/// Make the shared state for the unifier (recursion + iteration limits)
UnifierSharedState unifierState{iceHandler};
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
/// Initialize the normalizer
Normalizer normalizer{&incrementalModule->internalTypes, frontend.builtinTypes, NotNull{&unifierState}};
/// User defined type functions runtime
TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits});
/// Create a DataFlowGraph just for the surrounding context
DataFlowGraph dfg = DataFlowGraphBuilder::build(root, NotNull{&incrementalModule->defArena}, NotNull{&incrementalModule->keyArena}, iceHandler);
SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes);
FrontendModuleResolver& resolver = getModuleResolver(frontend, opts);
/// Contraint Generator
ConstraintGenerator cg{
incrementalModule,
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{&resolver},
frontend.builtinTypes,
iceHandler,
stale->getModuleScope(),
nullptr,
nullptr,
NotNull{&dfg},
{}
};
std::shared_ptr<Scope> freshChildOfNearestScope = nullptr;
if (FFlag::LuauCloneIncrementalModule)
{
freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
cg.rootScope = freshChildOfNearestScope.get();
cloneAndSquashScopes(
cloneState, closestScope.get(), stale, NotNull{&incrementalModule->internalTypes}, NotNull{&dfg}, root, freshChildOfNearestScope.get()
);
cg.visitFragmentRoot(freshChildOfNearestScope, root);
}
else
{
// Any additions to the scope must occur in a fresh scope
cg.rootScope = stale->getModuleScope().get();
freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root);
// closest Scope -> children = { ...., freshChildOfNearestScope}
// We need to trim nearestChild from the scope hierarcy
closestScope->children.emplace_back(freshChildOfNearestScope.get());
cg.visitFragmentRoot(freshChildOfNearestScope, root);
// Trim nearestChild from the closestScope
Scope* back = closestScope->children.back().get();
LUAU_ASSERT(back == freshChildOfNearestScope.get());
closestScope->children.pop_back();
}
/// Initialize the constraint solver and run it
ConstraintSolver cs{
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),
NotNull{&cg.scopeToFunction},
incrementalModule->name,
NotNull{&resolver},
{},
nullptr,
NotNull{&dfg},
limits
};
try
{
cs.run();
}
catch (const TimeLimitError&)
{
stale->timeout = true;
}
catch (const UserCancelError&)
{
stale->cancelled = true;
}
// In frontend we would forbid internal types
// because this is just for autocomplete, we don't actually care
// We also don't even need to typecheck - just synthesize types as best as we can
freeze(incrementalModule->internalTypes);
freeze(incrementalModule->interfaceTypes);
return {std::move(incrementalModule), std::move(freshChildOfNearestScope)};
}
std::pair<FragmentTypeCheckStatus, FragmentTypeCheckResult> typecheckFragment(
Frontend& frontend,
const ModuleName& moduleName,
const Position& cursorPos,
std::optional<FrontendOptions> opts,
std::string_view src,
std::optional<Position> fragmentEndPosition
)
{
LUAU_TIMETRACE_SCOPE("Luau::typecheckFragment", "FragmentAutocomplete");
LUAU_TIMETRACE_ARGUMENT("name", moduleName.c_str());
if (FFlag::LuauBetterReverseDependencyTracking)
{
if (!frontend.allModuleDependenciesValid(moduleName, opts && opts->forAutocomplete))
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
}
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
{
LUAU_ASSERT(!"Expected Source Module for fragment typecheck");
return {};
}
FrontendModuleResolver& resolver = getModuleResolver(frontend, opts);
ModulePtr module = resolver.getModule(moduleName);
if (!module)
{
LUAU_ASSERT(!"Expected Module for fragment typecheck");
return {};
}
if (FFlag::LuauIncrementalAutocompleteBugfixes)
{
if (sourceModule->allocator.get() != module->allocator.get())
{
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
}
}
auto tryParse = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition);
if (!tryParse)
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
FragmentParseResult& parseResult = *tryParse;
if (isWithinComment(parseResult.commentLocations, fragmentEndPosition.value_or(cursorPos)))
return {FragmentTypeCheckStatus::SkipAutocomplete, {}};
FrontendOptions frontendOptions = opts.value_or(frontend.options);
const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement);
FragmentTypeCheckResult result =
typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions);
result.ancestry = std::move(parseResult.ancestry);
return {FragmentTypeCheckStatus::Success, result};
}
FragmentAutocompleteStatusResult tryFragmentAutocomplete(
Frontend& frontend,
const ModuleName& moduleName,
Position cursorPosition,
FragmentContext context,
StringCompletionCallback stringCompletionCB
)
{
// TODO: we should calculate fragmentEnd position here, by using context.newAstRoot and cursorPosition
try
{
Luau::FragmentAutocompleteResult fragmentAutocomplete = Luau::fragmentAutocomplete(
frontend,
context.newSrc,
moduleName,
cursorPosition,
context.opts,
std::move(stringCompletionCB),
context.DEPRECATED_fragmentEndPosition
);
return {FragmentAutocompleteStatus::Success, std::move(fragmentAutocomplete)};
}
catch (const Luau::InternalCompilerError& e)
{
if (FFlag::LogFragmentsFromAutocomplete)
logLuau(e.what());
return {FragmentAutocompleteStatus::InternalIce, std::nullopt};
}
}
FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend,
std::string_view src,
const ModuleName& moduleName,
Position cursorPosition,
std::optional<FrontendOptions> opts,
StringCompletionCallback callback,
std::optional<Position> fragmentEndPosition
)
{
LUAU_ASSERT(FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
LUAU_TIMETRACE_SCOPE("Luau::fragmentAutocomplete", "FragmentAutocomplete");
LUAU_TIMETRACE_ARGUMENT("name", moduleName.c_str());
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
{
LUAU_ASSERT(!"Expected Source Module for fragment typecheck");
return {};
}
// If the cursor is within a comment in the stale source module we should avoid providing a recommendation
if (isWithinComment(*sourceModule, fragmentEndPosition.value_or(cursorPosition)))
return {};
auto [tcStatus, tcResult] = typecheckFragment(frontend, moduleName, cursorPosition, opts, src, fragmentEndPosition);
if (tcStatus == FragmentTypeCheckStatus::SkipAutocomplete)
return {};
auto globalScope = (opts && opts->forAutocomplete) ? frontend.globalsForAutocomplete.globalScope.get() : frontend.globals.globalScope.get();
if (FFlag::LogFragmentsFromAutocomplete)
logLuau(src);
TypeArena arenaForFragmentAutocomplete;
auto result = Luau::autocomplete_(
tcResult.incrementalModule,
frontend.builtinTypes,
&arenaForFragmentAutocomplete,
tcResult.ancestry,
globalScope,
tcResult.freshScope,
cursorPosition,
frontend.fileResolver,
callback
);
return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)};
}
} // namespace Luau

View file

@ -10,8 +10,10 @@
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/FileResolver.h" #include "Luau/FileResolver.h"
#include "Luau/NonStrictTypeChecker.h" #include "Luau/NonStrictTypeChecker.h"
#include "Luau/NotNull.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
@ -36,18 +38,21 @@ LUAU_FASTINT(LuauTypeInferIterationLimit)
LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(LuauStoreCommentsForDefinitionFiles, false)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile)
LUAU_FASTFLAGVARIABLE(DebugLuauForbidInternalTypes, false) LUAU_FASTFLAGVARIABLE(DebugLuauForbidInternalTypes)
LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode, false) LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode)
LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode, false) LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode)
LUAU_FASTFLAGVARIABLE(LuauSourceModuleUpdatedWithSelectedMode, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAGVARIABLE(LuauBetterReverseDependencyTracking)
LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule)
LUAU_FASTFLAGVARIABLE(LuauSelectivelyRetainDFGArena)
namespace Luau namespace Luau
{ {
@ -134,7 +139,7 @@ static ParseResult parseSourceForModule(std::string_view source, Luau::SourceMod
sourceModule.root = parseResult.root; sourceModule.root = parseResult.root;
sourceModule.mode = Mode::Definition; sourceModule.mode = Mode::Definition;
if (FFlag::LuauStoreCommentsForDefinitionFiles && options.captureComments) if (options.captureComments)
{ {
sourceModule.hotcomments = parseResult.hotcomments; sourceModule.hotcomments = parseResult.hotcomments;
sourceModule.commentLocations = parseResult.commentLocations; sourceModule.commentLocations = parseResult.commentLocations;
@ -205,72 +210,6 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(
return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule};
} }
std::vector<std::string_view> parsePathExpr(const AstExpr& pathExpr)
{
const AstExprIndexName* indexName = pathExpr.as<AstExprIndexName>();
if (!indexName)
return {};
std::vector<std::string_view> segments{indexName->index.value};
while (true)
{
if (AstExprIndexName* in = indexName->expr->as<AstExprIndexName>())
{
segments.push_back(in->index.value);
indexName = in;
continue;
}
else if (AstExprGlobal* indexNameAsGlobal = indexName->expr->as<AstExprGlobal>())
{
segments.push_back(indexNameAsGlobal->name.value);
break;
}
else if (AstExprLocal* indexNameAsLocal = indexName->expr->as<AstExprLocal>())
{
segments.push_back(indexNameAsLocal->local->name.value);
break;
}
else
return {};
}
std::reverse(segments.begin(), segments.end());
return segments;
}
std::optional<std::string> pathExprToModuleName(const ModuleName& currentModuleName, const std::vector<std::string_view>& segments)
{
if (segments.empty())
return std::nullopt;
std::vector<std::string_view> result;
auto it = segments.begin();
if (*it == "script" && !currentModuleName.empty())
{
result = split(currentModuleName, '/');
++it;
}
for (; it != segments.end(); ++it)
{
if (result.size() > 1 && *it == "Parent")
result.pop_back();
else
result.push_back(*it);
}
return join(result, "/");
}
std::optional<std::string> pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr)
{
std::vector<std::string_view> segments = parsePathExpr(pathExpr);
return pathExprToModuleName(currentModuleName, segments);
}
namespace namespace
{ {
@ -351,8 +290,7 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector<HotCom
std::vector<RequireCycle> getRequireCycles( std::vector<RequireCycle> getRequireCycles(
const FileResolver* resolver, const FileResolver* resolver,
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes,
const SourceNode* start, const SourceNode* start
bool stopAtFirst = false
) )
{ {
std::vector<RequireCycle> result; std::vector<RequireCycle> result;
@ -422,9 +360,6 @@ std::vector<RequireCycle> getRequireCycles(
{ {
result.push_back({depLocation, std::move(cycle)}); result.push_back({depLocation, std::move(cycle)});
if (stopAtFirst)
return result;
// note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start // note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start
// so it's safe to *only* clear seen vector when we find a cycle // so it's safe to *only* clear seen vector when we find a cycle
// if we don't do it, we will not have correct reporting for some cycles // if we don't do it, we will not have correct reporting for some cycles
@ -812,6 +747,32 @@ std::optional<CheckResult> Frontend::getCheckResult(const ModuleName& name, bool
return checkResult; return checkResult;
} }
std::vector<ModuleName> Frontend::getRequiredScripts(const ModuleName& name)
{
RequireTraceResult require = requireTrace[name];
if (isDirty(name))
{
std::optional<SourceCode> source = fileResolver->readSource(name);
if (!source)
{
return {};
}
const Config& config = configResolver->getConfig(name);
ParseOptions opts = config.parseOptions;
opts.captureComments = true;
SourceModule result = parse(name, source->source, opts);
result.type = source->type;
require = traceRequires(fileResolver, result.root, name);
}
std::vector<std::string> requiredModuleNames;
requiredModuleNames.reserve(require.requireList.size());
for (const auto& [moduleName, _] : require.requireList)
{
requiredModuleNames.push_back(moduleName);
}
return requiredModuleNames;
}
bool Frontend::parseGraph( bool Frontend::parseGraph(
std::vector<ModuleName>& buildQueue, std::vector<ModuleName>& buildQueue,
const ModuleName& root, const ModuleName& root,
@ -860,6 +821,16 @@ bool Frontend::parseGraph(
topseen = Permanent; topseen = Permanent;
buildQueue.push_back(top->name); buildQueue.push_back(top->name);
if (FFlag::LuauBetterReverseDependencyTracking)
{
// at this point we know all valid dependencies are processed into SourceNodes
for (const ModuleName& dep : top->requireSet)
{
if (auto it = sourceNodes.find(dep); it != sourceNodes.end())
it->second->dependents.insert(top->name);
}
}
} }
else else
{ {
@ -948,14 +919,11 @@ void Frontend::addBuildQueueItems(
data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete);
data.recordJsonLog = FFlag::DebugLuauLogSolverToJson; data.recordJsonLog = FFlag::DebugLuauLogSolverToJson;
Mode mode = sourceModule->mode.value_or(data.config.mode);
// in NoCheck mode we only need to compute the value of .cyclic for typeck
// in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely // all correct programs must be acyclic so this code triggers rarely
if (cycleDetected) if (cycleDetected)
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck); data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get());
data.options = frontendOptions; data.options = frontendOptions;
@ -987,7 +955,6 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
else else
mode = sourceModule.mode.value_or(config.mode); mode = sourceModule.mode.value_or(config.mode);
if (FFlag::LuauSourceModuleUpdatedWithSelectedMode)
item.sourceModule->mode = {mode}; item.sourceModule->mode = {mode};
ScopePtr environmentScope = item.environmentScope; ScopePtr environmentScope = item.environmentScope;
double timestamp = getTimestamp(); double timestamp = getTimestamp();
@ -1093,6 +1060,11 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
freeze(module->interfaceTypes); freeze(module->interfaceTypes);
module->internalTypes.clear(); module->internalTypes.clear();
if (FFlag::LuauSelectivelyRetainDFGArena)
{
module->defArena.allocator.clear();
module->keyArena.allocator.clear();
}
module->astTypes.clear(); module->astTypes.clear();
module->astTypePacks.clear(); module->astTypePacks.clear();
@ -1146,6 +1118,39 @@ void Frontend::recordItemResult(const BuildQueueItem& item)
if (item.exception) if (item.exception)
std::rethrow_exception(item.exception); std::rethrow_exception(item.exception);
if (FFlag::LuauBetterReverseDependencyTracking)
{
bool replacedModule = false;
if (item.options.forAutocomplete)
{
replacedModule = moduleResolverForAutocomplete.setModule(item.name, item.module);
item.sourceNode->dirtyModuleForAutocomplete = false;
}
else
{
replacedModule = moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false;
}
if (replacedModule)
{
LUAU_TIMETRACE_SCOPE("Frontend::invalidateDependentModules", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", item.name.c_str());
traverseDependents(
item.name,
[forAutocomplete = item.options.forAutocomplete](SourceNode& sourceNode)
{
bool traverseSubtree = !sourceNode.hasInvalidModuleDependency(forAutocomplete);
sourceNode.setInvalidModuleDependency(true, forAutocomplete);
return traverseSubtree;
}
);
}
item.sourceNode->setInvalidModuleDependency(false, item.options.forAutocomplete);
}
else
{
if (item.options.forAutocomplete) if (item.options.forAutocomplete)
{ {
moduleResolverForAutocomplete.setModule(item.name, item.module); moduleResolverForAutocomplete.setModule(item.name, item.module);
@ -1156,6 +1161,7 @@ void Frontend::recordItemResult(const BuildQueueItem& item)
moduleResolver.setModule(item.name, item.module); moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false; item.sourceNode->dirtyModule = false;
} }
}
stats.timeCheck += item.stats.timeCheck; stats.timeCheck += item.stats.timeCheck;
stats.timeLint += item.stats.timeLint; stats.timeLint += item.stats.timeLint;
@ -1191,6 +1197,13 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config
return result; return result;
} }
bool Frontend::allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete) const
{
LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking);
auto it = sourceNodes.find(name);
return it != sourceNodes.end() && !it->second->hasInvalidModuleDependency(forAutocomplete);
}
bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
{ {
auto it = sourceNodes.find(name); auto it = sourceNodes.find(name);
@ -1204,6 +1217,31 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
* It would be nice for this function to be O(1) * It would be nice for this function to be O(1)
*/ */
void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty) void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty)
{
LUAU_TIMETRACE_SCOPE("Frontend::markDirty", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
if (FFlag::LuauBetterReverseDependencyTracking)
{
traverseDependents(
name,
[markedDirty](SourceNode& sourceNode)
{
if (markedDirty)
markedDirty->push_back(sourceNode.name);
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
return false;
sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
return true;
}
);
}
else
{ {
if (sourceNodes.count(name) == 0) if (sourceNodes.count(name) == 0)
return; return;
@ -1244,6 +1282,33 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
queue.insert(queue.end(), dependents.begin(), dependents.end()); queue.insert(queue.end(), dependents.begin(), dependents.end());
} }
} }
}
void Frontend::traverseDependents(const ModuleName& name, std::function<bool(SourceNode&)> processSubtree)
{
LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking);
LUAU_TIMETRACE_SCOPE("Frontend::traverseDependents", "Frontend");
if (sourceNodes.count(name) == 0)
return;
std::vector<ModuleName> queue{name};
while (!queue.empty())
{
ModuleName next = std::move(queue.back());
queue.pop_back();
LUAU_ASSERT(sourceNodes.count(next) > 0);
SourceNode& sourceNode = *sourceNodes[next];
if (!processSubtree(sourceNode))
continue;
const Set<ModuleName>& dependents = sourceNode.dependents;
queue.insert(queue.end(), dependents.begin(), dependents.end());
}
}
SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) SourceModule* Frontend::getSourceModule(const ModuleName& moduleName)
{ {
@ -1357,11 +1422,15 @@ ModulePtr check(
LUAU_TIMETRACE_ARGUMENT("name", sourceModule.humanReadableName.c_str()); LUAU_TIMETRACE_ARGUMENT("name", sourceModule.humanReadableName.c_str());
ModulePtr result = std::make_shared<Module>(); ModulePtr result = std::make_shared<Module>();
if (FFlag::LuauStoreSolverTypeOnModule)
result->checkedInNewSolver = true;
result->name = sourceModule.name; result->name = sourceModule.name;
result->humanReadableName = sourceModule.humanReadableName; result->humanReadableName = sourceModule.humanReadableName;
result->mode = mode; result->mode = mode;
result->internalTypes.owningModule = result.get(); result->internalTypes.owningModule = result.get();
result->interfaceTypes.owningModule = result.get(); result->interfaceTypes.owningModule = result.get();
result->allocator = sourceModule.allocator;
result->names = sourceModule.names;
iceHandler->moduleName = sourceModule.name; iceHandler->moduleName = sourceModule.name;
@ -1376,17 +1445,23 @@ ModulePtr check(
} }
} }
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&result->defArena}, NotNull{&result->keyArena}, iceHandler);
UnifierSharedState unifierState{iceHandler}; UnifierSharedState unifierState{iceHandler};
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes);
TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}};
typeFunctionRuntime.allowEvaluation = sourceModule.parseErrors.empty();
ConstraintGenerator cg{ ConstraintGenerator cg{
result, result,
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
moduleResolver, moduleResolver,
builtinTypes, builtinTypes,
iceHandler, iceHandler,
@ -1402,12 +1477,16 @@ ModulePtr check(
ConstraintSolver cs{ ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope), NotNull(cg.rootScope),
borrowConstraints(cg.constraints), borrowConstraints(cg.constraints),
NotNull{&cg.scopeToFunction},
result->name, result->name,
moduleResolver, moduleResolver,
requireCycles, requireCycles,
logger.get(), logger.get(),
NotNull{&dfg},
limits limits
}; };
@ -1461,12 +1540,31 @@ ModulePtr check(
switch (mode) switch (mode)
{ {
case Mode::Nonstrict: case Mode::Nonstrict:
Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get()); Luau::checkNonStrict(
builtinTypes,
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
iceHandler,
NotNull{&unifierState},
NotNull{&dfg},
NotNull{&limits},
sourceModule,
result.get()
);
break; break;
case Mode::Definition: case Mode::Definition:
// fallthrough intentional // fallthrough intentional
case Mode::Strict: case Mode::Strict:
Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get()); Luau::check(
builtinTypes,
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{&unifierState},
NotNull{&limits},
logger.get(),
sourceModule,
result.get()
);
break; break;
case Mode::NoCheck: case Mode::NoCheck:
break; break;
@ -1647,6 +1745,17 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(const ModuleName&
sourceNode->name = sourceModule->name; sourceNode->name = sourceModule->name;
sourceNode->humanReadableName = sourceModule->humanReadableName; sourceNode->humanReadableName = sourceModule->humanReadableName;
if (FFlag::LuauBetterReverseDependencyTracking)
{
// clear all prior dependents. we will re-add them after parsing the rest of the graph
for (const auto& [moduleName, _] : sourceNode->requireLocations)
{
if (auto depIt = sourceNodes.find(moduleName); depIt != sourceNodes.end())
depIt->second->dependents.erase(sourceNode->name);
}
}
sourceNode->requireSet.clear(); sourceNode->requireSet.clear();
sourceNode->requireLocations.clear(); sourceNode->requireLocations.clear();
sourceNode->dirtySourceModule = false; sourceNode->dirtySourceModule = false;
@ -1768,11 +1877,21 @@ std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName&
return frontend->fileResolver->getHumanReadableModuleName(moduleName); return frontend->fileResolver->getHumanReadableModuleName(moduleName);
} }
void FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) bool FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module)
{ {
std::scoped_lock lock(moduleMutex); std::scoped_lock lock(moduleMutex);
if (FFlag::LuauBetterReverseDependencyTracking)
{
bool replaced = modules.count(moduleName) > 0;
modules[moduleName] = std::move(module); modules[moduleName] = std::move(module);
return replaced;
}
else
{
modules[moduleName] = std::move(module);
return false;
}
} }
void FrontendModuleResolver::clearModules() void FrontendModuleResolver::clearModules()

View file

@ -2,6 +2,8 @@
#include "Luau/Generalization.h" #include "Luau/Generalization.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
@ -9,11 +11,15 @@
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound2)
namespace Luau namespace Luau
{ {
struct MutatingGeneralizer : TypeOnceVisitor struct MutatingGeneralizer : TypeOnceVisitor
{ {
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<Scope> scope; NotNull<Scope> scope;
@ -27,6 +33,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool avoidSealingTables = false; bool avoidSealingTables = false;
MutatingGeneralizer( MutatingGeneralizer(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes, NotNull<DenseHashSet<TypeId>> cachedTypes,
@ -35,6 +42,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool avoidSealingTables bool avoidSealingTables
) )
: TypeOnceVisitor(/* skipBoundTypes */ true) : TypeOnceVisitor(/* skipBoundTypes */ true)
, arena(arena)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, scope(scope) , scope(scope)
, cachedTypes(cachedTypes) , cachedTypes(cachedTypes)
@ -44,7 +52,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
{ {
} }
static void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement) void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement)
{ {
haystack = follow(haystack); haystack = follow(haystack);
@ -91,6 +99,10 @@ struct MutatingGeneralizer : TypeOnceVisitor
LUAU_ASSERT(onlyType != haystack); LUAU_ASSERT(onlyType != haystack);
emplaceType<BoundType>(asMutable(haystack), onlyType); emplaceType<BoundType>(asMutable(haystack), onlyType);
} }
else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && ut->options.empty())
{
emplaceType<BoundType>(asMutable(haystack), builtinTypes->neverType);
}
return; return;
} }
@ -134,6 +146,10 @@ struct MutatingGeneralizer : TypeOnceVisitor
LUAU_ASSERT(onlyType != needle); LUAU_ASSERT(onlyType != needle);
emplaceType<BoundType>(asMutable(needle), onlyType); emplaceType<BoundType>(asMutable(needle), onlyType);
} }
else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && it->parts.empty())
{
emplaceType<BoundType>(asMutable(needle), builtinTypes->unknownType);
}
return; return;
} }
@ -445,7 +461,7 @@ struct FreeTypeSearcher : TypeVisitor
traverse(*prop.readTy); traverse(*prop.readTy);
else else
{ {
LUAU_ASSERT(prop.isShared()); LUAU_ASSERT(prop.isShared() || FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
Polarity p = polarity; Polarity p = polarity;
polarity = Both; polarity = Both;
@ -526,7 +542,7 @@ struct TypeCacher : TypeOnceVisitor
DenseHashSet<TypePackId> uncacheablePacks{nullptr}; DenseHashSet<TypePackId> uncacheablePacks{nullptr};
explicit TypeCacher(NotNull<DenseHashSet<TypeId>> cachedTypes) explicit TypeCacher(NotNull<DenseHashSet<TypeId>> cachedTypes)
: TypeOnceVisitor(/* skipBoundTypes */ true) : TypeOnceVisitor(/* skipBoundTypes */ false)
, cachedTypes(cachedTypes) , cachedTypes(cachedTypes)
{ {
} }
@ -563,9 +579,19 @@ struct TypeCacher : TypeOnceVisitor
bool visit(TypeId ty) override bool visit(TypeId ty) override
{ {
if (isUncacheable(ty) || isCached(ty)) // NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
// otherwise it's prone to marking types that cannot be cached as
// cacheable.
LUAU_ASSERT(false);
LUAU_UNREACHABLE();
}
bool visit(TypeId ty, const BoundType& btv) override
{
traverse(btv.boundTo);
if (isUncacheable(btv.boundTo))
markUncacheable(ty);
return false; return false;
return true;
} }
bool visit(TypeId ty, const FreeType& ft) override bool visit(TypeId ty, const FreeType& ft) override
@ -590,6 +616,12 @@ struct TypeCacher : TypeOnceVisitor
return false; return false;
} }
bool visit(TypeId ty, const ErrorType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const PrimitiveType&) override bool visit(TypeId ty, const PrimitiveType&) override
{ {
cache(ty); cache(ty);
@ -727,6 +759,17 @@ struct TypeCacher : TypeOnceVisitor
return false; return false;
} }
bool visit(TypeId ty, const MetatableType& mtv) override
{
traverse(mtv.table);
traverse(mtv.metatable);
if (isUncacheable(mtv.table) || isUncacheable(mtv.metatable))
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const ClassType&) override bool visit(TypeId ty, const ClassType&) override
{ {
cache(ty); cache(ty);
@ -739,6 +782,12 @@ struct TypeCacher : TypeOnceVisitor
return false; return false;
} }
bool visit(TypeId ty, const NoRefineType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const UnionType& ut) override bool visit(TypeId ty, const UnionType& ut) override
{ {
if (isUncacheable(ty) || isCached(ty)) if (isUncacheable(ty) || isCached(ty))
@ -841,12 +890,31 @@ struct TypeCacher : TypeOnceVisitor
return false; return false;
} }
bool visit(TypePackId tp) override
{
// NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
// otherwise it's prone to marking types that cannot be cached as
// cacheable, which will segfault down the line.
LUAU_ASSERT(false);
LUAU_UNREACHABLE();
}
bool visit(TypePackId tp, const FreeTypePack&) override bool visit(TypePackId tp, const FreeTypePack&) override
{ {
markUncacheable(tp); markUncacheable(tp);
return false; return false;
} }
bool visit(TypePackId tp, const GenericTypePack& gtp) override
{
return true;
}
bool visit(TypePackId tp, const ErrorTypePack& etp) override
{
return true;
}
bool visit(TypePackId tp, const VariadicTypePack& vtp) override bool visit(TypePackId tp, const VariadicTypePack& vtp) override
{ {
if (isUncacheable(tp)) if (isUncacheable(tp))
@ -871,6 +939,32 @@ struct TypeCacher : TypeOnceVisitor
markUncacheable(tp); markUncacheable(tp);
return false; return false;
} }
bool visit(TypePackId tp, const BoundTypePack& btp) override
{
traverse(btp.boundTo);
if (isUncacheable(btp.boundTo))
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const TypePack& typ) override
{
bool uncacheable = false;
for (TypeId ty : typ.head)
{
traverse(ty);
uncacheable |= isUncacheable(ty);
}
if (typ.tail)
{
traverse(*typ.tail);
uncacheable |= isUncacheable(*typ.tail);
}
if (uncacheable)
markUncacheable(tp);
return false;
}
}; };
std::optional<TypeId> generalize( std::optional<TypeId> generalize(
@ -890,7 +984,7 @@ std::optional<TypeId> generalize(
FreeTypeSearcher fts{scope, cachedTypes}; FreeTypeSearcher fts{scope, cachedTypes};
fts.traverse(ty); fts.traverse(ty);
MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables};
gen.traverse(ty); gen.traverse(ty);

View file

@ -11,6 +11,7 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau namespace Luau
{ {
@ -61,9 +62,7 @@ TypeId Instantiation::clean(TypeId ty)
LUAU_ASSERT(ftv); LUAU_ASSERT(ftv);
FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.magicFunction = ftv->magicFunction; clone.magic = ftv->magic;
clone.dcrMagicFunction = ftv->dcrMagicFunction;
clone.dcrMagicRefinement = ftv->dcrMagicRefinement;
clone.tags = ftv->tags; clone.tags = ftv->tags;
clone.argNames = ftv->argNames; clone.argNames = ftv->argNames;
TypeId result = addType(std::move(clone)); TypeId result = addType(std::move(clone));
@ -165,7 +164,7 @@ TypeId ReplaceGenerics::clean(TypeId ty)
} }
else else
{ {
return addType(FreeType{scope, level}); return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, scope, level) : addType(FreeType{scope, level});
} }
} }

View file

@ -227,6 +227,8 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }"; stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }";
else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>) else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>)
stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }"; stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }";
else if constexpr (std::is_same_v<T, UserDefinedTypeFunctionError>)
stream << "UserDefinedTypeFunctionError { " << err.message << " }";
else if constexpr (std::is_same_v<T, CannotAssignToNever>) else if constexpr (std::is_same_v<T, CannotAssignToNever>)
{ {
stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { "; stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { ";

View file

@ -17,8 +17,7 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAttribute) LUAU_FASTFLAG(LuauAttribute)
LUAU_FASTFLAG(LuauNativeAttribute) LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute)
LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute, false)
namespace Luau namespace Luau
{ {
@ -3239,7 +3238,6 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
static bool hasNativeCommentDirective(const std::vector<HotComment>& hotcomments) static bool hasNativeCommentDirective(const std::vector<HotComment>& hotcomments)
{ {
LUAU_ASSERT(FFlag::LuauNativeAttribute);
LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); LUAU_ASSERT(FFlag::LintRedundantNativeAttribute);
for (const HotComment& hc : hotcomments) for (const HotComment& hc : hotcomments)
@ -3265,7 +3263,6 @@ struct LintRedundantNativeAttribute : AstVisitor
public: public:
LUAU_NOINLINE static void process(LintContext& context) LUAU_NOINLINE static void process(LintContext& context)
{ {
LUAU_ASSERT(FFlag::LuauNativeAttribute);
LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); LUAU_ASSERT(FFlag::LintRedundantNativeAttribute);
LintRedundantNativeAttribute pass; LintRedundantNativeAttribute pass;
@ -3389,7 +3386,7 @@ std::vector<LintWarning> lint(
if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence)) if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence))
LintComparisonPrecedence::process(context); LintComparisonPrecedence::process(context);
if (FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute)) if (FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute))
{ {
if (hasNativeCommentDirective(hotcomments)) if (hasNativeCommentDirective(hotcomments))
LintRedundantNativeAttribute::process(context); LintRedundantNativeAttribute::process(context);

View file

@ -15,11 +15,32 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteCommentDetection)
namespace Luau namespace Luau
{ {
static bool contains(Position pos, Comment comment) static void defaultLogLuau(std::string_view input)
{
// The default is to do nothing because we don't want to mess with
// the xml parsing done by the dcr script.
}
Luau::LogLuauProc logLuau = &defaultLogLuau;
void setLogLuau(LogLuauProc ll)
{
logLuau = ll;
}
void resetLogLuauProc()
{
logLuau = &defaultLogLuau;
}
static bool contains_DEPRECATED(Position pos, Comment comment)
{ {
if (comment.location.contains(pos)) if (comment.location.contains(pos))
return true; return true;
@ -32,7 +53,22 @@ static bool contains(Position pos, Comment comment)
return false; return false;
} }
static bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos) static bool contains(Position pos, Comment comment)
{
if (comment.location.contains(pos))
return true;
else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't
// have an end
return true;
// comments actually span the whole line - in incremental mode, we could pass a cursor outside of the current parsed comment range span, but it
// would still be 'within' the comment So, the cursor must be on the same line and the comment itself must come strictly after the `begin`
else if (comment.type == Lexeme::Comment && comment.location.end.line == pos.line && comment.location.begin <= pos)
return true;
else
return false;
}
bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos)
{ {
auto iter = std::lower_bound( auto iter = std::lower_bound(
commentLocations.begin(), commentLocations.begin(),
@ -40,6 +76,11 @@ static bool isWithinComment(const std::vector<Comment>& commentLocations, Positi
Comment{Lexeme::Comment, Location{pos, pos}}, Comment{Lexeme::Comment, Location{pos, pos}},
[](const Comment& a, const Comment& b) [](const Comment& a, const Comment& b)
{ {
if (FFlag::LuauIncrementalAutocompleteCommentDetection)
{
if (a.type == Lexeme::Comment)
return a.location.end.line < b.location.end.line;
}
return a.location.end < b.location.end; return a.location.end < b.location.end;
} }
); );
@ -47,7 +88,7 @@ static bool isWithinComment(const std::vector<Comment>& commentLocations, Positi
if (iter == commentLocations.end()) if (iter == commentLocations.end())
return false; return false;
if (contains(pos, *iter)) if (FFlag::LuauIncrementalAutocompleteCommentDetection ? contains(pos, *iter) : contains_DEPRECATED(pos, *iter))
return true; return true;
// Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends // Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends
@ -131,19 +172,61 @@ struct ClonePublicInterface : Substitution
} }
ftv->level = TypeLevel{0, 0}; ftv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2)
ftv->scope = nullptr;
} }
else if (TableType* ttv = getMutable<TableType>(result)) else if (TableType* ttv = getMutable<TableType>(result))
{ {
ttv->level = TypeLevel{0, 0}; ttv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2)
ttv->scope = nullptr;
}
if (FFlag::LuauSolverV2)
{
if (auto freety = getMutable<FreeType>(result))
{
module->errors.emplace_back(
freety->scope->location,
module->name,
InternalError{"Free type is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
);
result = builtinTypes->errorRecoveryType();
}
else if (auto genericty = getMutable<GenericType>(result))
{
genericty->scope = nullptr;
}
} }
return result; return result;
} }
TypePackId clean(TypePackId tp) override TypePackId clean(TypePackId tp) override
{
if (FFlag::LuauSolverV2)
{
auto clonedTp = clone(tp);
if (auto ftp = getMutable<FreeTypePack>(clonedTp))
{
module->errors.emplace_back(
ftp->scope->location,
module->name,
InternalError{"Free type pack is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
);
clonedTp = builtinTypes->errorRecoveryTypePack();
}
else if (auto gtp = getMutable<GenericTypePack>(clonedTp))
gtp->scope = nullptr;
return clonedTp;
}
else
{ {
return clone(tp); return clone(tp);
} }
}
TypeId cloneType(TypeId ty) TypeId cloneType(TypeId ty)
{ {

View file

@ -14,11 +14,15 @@
#include "Luau/TypeFunction.h" #include "Luau/TypeFunction.h"
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h"
#include <iostream> #include <iostream>
#include <iterator> #include <iterator>
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
LUAU_FASTFLAGVARIABLE(LuauNonStrictVisitorImprovements)
LUAU_FASTFLAGVARIABLE(LuauNewNonStrictWarnOnUnknownGlobals)
namespace Luau namespace Luau
{ {
@ -154,8 +158,9 @@ private:
struct NonStrictTypeChecker struct NonStrictTypeChecker
{ {
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
const NotNull<InternalErrorReporter> ice; const NotNull<InternalErrorReporter> ice;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
Module* module; Module* module;
@ -171,6 +176,8 @@ struct NonStrictTypeChecker
NonStrictTypeChecker( NonStrictTypeChecker(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
const NotNull<InternalErrorReporter> ice, const NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<const DataFlowGraph> dfg,
@ -178,11 +185,13 @@ struct NonStrictTypeChecker
Module* module Module* module
) )
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime)
, ice(ice) , ice(ice)
, arena(arena) , arena(arena)
, module(module) , module(module)
, normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true}
, subtyping{builtinTypes, arena, NotNull(&normalizer), ice} , subtyping{builtinTypes, arena, simplifier, NotNull(&normalizer), typeFunctionRuntime, ice}
, dfg(dfg) , dfg(dfg)
, limits(limits) , limits(limits)
{ {
@ -204,7 +213,7 @@ struct NonStrictTypeChecker
return *fst; return *fst;
else if (auto ftp = get<FreeTypePack>(pack)) else if (auto ftp = get<FreeTypePack>(pack))
{ {
TypeId result = arena->addType(FreeType{ftp->scope}); TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, ftp->scope) : arena->addType(FreeType{ftp->scope});
TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope}); TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope});
TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack)); TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack));
@ -213,7 +222,7 @@ struct NonStrictTypeChecker
return result; return result;
} }
else if (get<Unifiable::Error>(pack)) else if (get<ErrorTypePack>(pack))
return builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType();
else if (finite(pack) && size(pack) == 0) else if (finite(pack) && size(pack) == 0)
return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil` return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil`
@ -228,7 +237,12 @@ struct NonStrictTypeChecker
return instance; return instance;
ErrorVec errors = ErrorVec errors =
reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true) reduceTypeFunctions(
instance,
location,
TypeFunctionContext{arena, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits},
true
)
.errors; .errors;
if (errors.empty()) if (errors.empty())
@ -329,8 +343,9 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatIf* ifStatement) NonStrictContext visit(AstStatIf* ifStatement)
{ {
NonStrictContext condB = visit(ifStatement->condition); NonStrictContext condB = visit(ifStatement->condition, ValueContext::RValue);
NonStrictContext branchContext; NonStrictContext branchContext;
// If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error // If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error
if (ifStatement->elsebody) if (ifStatement->elsebody)
{ {
@ -338,16 +353,31 @@ struct NonStrictTypeChecker
NonStrictContext elseBody = visit(ifStatement->elsebody); NonStrictContext elseBody = visit(ifStatement->elsebody);
branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody); branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody);
} }
return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext); return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext);
} }
NonStrictContext visit(AstStatWhile* whileStatement) NonStrictContext visit(AstStatWhile* whileStatement)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext condition = visit(whileStatement->condition, ValueContext::RValue);
NonStrictContext body = visit(whileStatement->body);
return NonStrictContext::disjunction(builtinTypes, arena, condition, body);
}
else
return {}; return {};
} }
NonStrictContext visit(AstStatRepeat* repeatStatement) NonStrictContext visit(AstStatRepeat* repeatStatement)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext body = visit(repeatStatement->body);
NonStrictContext condition = visit(repeatStatement->condition, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, body, condition);
}
else
return {}; return {};
} }
@ -363,49 +393,94 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatReturn* returnStatement) NonStrictContext visit(AstStatReturn* returnStatement)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
// TODO: this is believing existing code, but i'm not sure if this makes sense
// for how the contexts are handled
for (AstExpr* expr : returnStatement->list)
visit(expr, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstStatExpr* expr) NonStrictContext visit(AstStatExpr* expr)
{ {
return visit(expr->expr); return visit(expr->expr, ValueContext::RValue);
} }
NonStrictContext visit(AstStatLocal* local) NonStrictContext visit(AstStatLocal* local)
{ {
for (AstExpr* rhs : local->values) for (AstExpr* rhs : local->values)
visit(rhs); visit(rhs, ValueContext::RValue);
return {}; return {};
} }
NonStrictContext visit(AstStatFor* forStatement) NonStrictContext visit(AstStatFor* forStatement)
{
if (FFlag::LuauNonStrictVisitorImprovements)
{
// TODO: throwing out context based on same principle as existing code?
if (forStatement->from)
visit(forStatement->from, ValueContext::RValue);
if (forStatement->to)
visit(forStatement->to, ValueContext::RValue);
if (forStatement->step)
visit(forStatement->step, ValueContext::RValue);
return visit(forStatement->body);
}
else
{ {
return {}; return {};
} }
}
NonStrictContext visit(AstStatForIn* forInStatement) NonStrictContext visit(AstStatForIn* forInStatement)
{
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* rhs : forInStatement->values)
visit(rhs, ValueContext::RValue);
return visit(forInStatement->body);
}
else
{ {
return {}; return {};
} }
}
NonStrictContext visit(AstStatAssign* assign) NonStrictContext visit(AstStatAssign* assign)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* lhs : assign->vars)
visit(lhs, ValueContext::LValue);
for (AstExpr* rhs : assign->values)
visit(rhs, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstStatCompoundAssign* compoundAssign) NonStrictContext visit(AstStatCompoundAssign* compoundAssign)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
visit(compoundAssign->var, ValueContext::LValue);
visit(compoundAssign->value, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstStatFunction* statFn) NonStrictContext visit(AstStatFunction* statFn)
{ {
return visit(statFn->func); return visit(statFn->func, ValueContext::RValue);
} }
NonStrictContext visit(AstStatLocalFunction* localFn) NonStrictContext visit(AstStatLocalFunction* localFn)
{ {
return visit(localFn->func); return visit(localFn->func, ValueContext::RValue);
} }
NonStrictContext visit(AstStatTypeAlias* typeAlias) NonStrictContext visit(AstStatTypeAlias* typeAlias)
@ -415,7 +490,6 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatTypeFunction* typeFunc) NonStrictContext visit(AstStatTypeFunction* typeFunc)
{ {
reportError(GenericError{"This syntax is not supported"}, typeFunc->location);
return {}; return {};
} }
@ -436,14 +510,22 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatError* error) NonStrictContext visit(AstStatError* error)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstStat* stat : error->statements)
visit(stat);
for (AstExpr* expr : error->expressions)
visit(expr, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstExpr* expr) NonStrictContext visit(AstExpr* expr, ValueContext context)
{ {
auto pusher = pushStack(expr); auto pusher = pushStack(expr);
if (auto e = expr->as<AstExprGroup>()) if (auto e = expr->as<AstExprGroup>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprConstantNil>()) else if (auto e = expr->as<AstExprConstantNil>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprConstantBool>()) else if (auto e = expr->as<AstExprConstantBool>())
@ -453,17 +535,17 @@ struct NonStrictTypeChecker
else if (auto e = expr->as<AstExprConstantString>()) else if (auto e = expr->as<AstExprConstantString>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprLocal>()) else if (auto e = expr->as<AstExprLocal>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprGlobal>()) else if (auto e = expr->as<AstExprGlobal>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprVarargs>()) else if (auto e = expr->as<AstExprVarargs>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprCall>()) else if (auto e = expr->as<AstExprCall>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprIndexName>()) else if (auto e = expr->as<AstExprIndexName>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprIndexExpr>()) else if (auto e = expr->as<AstExprIndexExpr>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprFunction>()) else if (auto e = expr->as<AstExprFunction>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprTable>()) else if (auto e = expr->as<AstExprTable>())
@ -487,8 +569,11 @@ struct NonStrictTypeChecker
} }
} }
NonStrictContext visit(AstExprGroup* group) NonStrictContext visit(AstExprGroup* group, ValueContext context)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
return visit(group->expr, context);
else
return {}; return {};
} }
@ -512,22 +597,34 @@ struct NonStrictTypeChecker
return {}; return {};
} }
NonStrictContext visit(AstExprLocal* local) NonStrictContext visit(AstExprLocal* local, ValueContext context)
{ {
return {}; return {};
} }
NonStrictContext visit(AstExprGlobal* global) NonStrictContext visit(AstExprGlobal* global, ValueContext context)
{ {
if (FFlag::LuauNewNonStrictWarnOnUnknownGlobals)
{
// We don't file unknown symbols for LValues.
if (context == ValueContext::LValue)
return {};
NotNull<Scope> scope = stack.back();
if (!scope->lookup(global->name))
{
reportError(UnknownSymbol{global->name.value, UnknownSymbol::Binding}, global->location);
}
}
return {}; return {};
} }
NonStrictContext visit(AstExprVarargs* global) NonStrictContext visit(AstExprVarargs* varargs)
{ {
return {}; return {};
} }
NonStrictContext visit(AstExprCall* call) NonStrictContext visit(AstExprCall* call)
{ {
NonStrictContext fresh{}; NonStrictContext fresh{};
@ -536,9 +633,7 @@ struct NonStrictTypeChecker
return fresh; return fresh;
TypeId fnTy = *originalCallTy; TypeId fnTy = *originalCallTy;
if (auto fn = get<FunctionType>(follow(fnTy))) if (auto fn = get<FunctionType>(follow(fnTy)); fn && fn->isCheckedFunction)
{
if (fn->isCheckedFunction)
{ {
// We know fn is a checked function, which means it looks like: // We know fn is a checked function, which means it looks like:
// (S1, ... SN) -> T & // (S1, ... SN) -> T &
@ -547,21 +642,33 @@ struct NonStrictTypeChecker
// ... // ...
// ... // ...
// (unknown^N-1, ~S_N) -> error // (unknown^N-1, ~S_N) -> error
std::vector<AstExpr*> arguments;
arguments.reserve(call->args.size + (call->self ? 1 : 0));
if (call->self)
{
if (auto indexExpr = call->func->as<AstExprIndexName>())
arguments.push_back(indexExpr->expr);
else
ice->ice("method call expression has no 'self'");
}
arguments.insert(arguments.end(), call->args.begin(), call->args.end());
std::vector<TypeId> argTypes; std::vector<TypeId> argTypes;
argTypes.reserve(call->args.size); argTypes.reserve(arguments.size());
// Pad out the arg types array with the types you would expect to see
// Move all the types over from the argument typepack for `fn`
TypePackIterator curr = begin(fn->argTypes); TypePackIterator curr = begin(fn->argTypes);
TypePackIterator fin = end(fn->argTypes); TypePackIterator fin = end(fn->argTypes);
while (curr != fin) for (; curr != fin; curr++)
{
argTypes.push_back(*curr); argTypes.push_back(*curr);
++curr;
} // Pad out the rest with the variadic as needed.
if (auto argTail = curr.tail()) if (auto argTail = curr.tail())
{ {
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*argTail))) if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*argTail)))
{ {
while (argTypes.size() < call->args.size) while (argTypes.size() < arguments.size())
{ {
argTypes.push_back(vtp->ty); argTypes.push_back(vtp->ty);
} }
@ -569,26 +676,26 @@ struct NonStrictTypeChecker
} }
std::string functionName = getFunctionNameAsString(*call->func).value_or(""); std::string functionName = getFunctionNameAsString(*call->func).value_or("");
if (call->args.size > argTypes.size()) if (arguments.size() > argTypes.size())
{ {
// We are passing more arguments than we expect, so we should error // We are passing more arguments than we expect, so we should error
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), arguments.size()}, call->location);
return fresh; return fresh;
} }
for (size_t i = 0; i < call->args.size; i++) for (size_t i = 0; i < arguments.size(); i++)
{ {
// For example, if the arg is "hi" // For example, if the arg is "hi"
// The actual arg type is string // The actual arg type is string
// The expected arg type is number // The expected arg type is number
// The type of the argument in the overload is ~number // The type of the argument in the overload is ~number
// We will compare arg and ~number // We will compare arg and ~number
AstExpr* arg = call->args.data[i]; AstExpr* arg = arguments[i];
TypeId expectedArgType = argTypes[i]; TypeId expectedArgType = argTypes[i];
std::shared_ptr<const NormalizedType> norm = normalizer.normalize(expectedArgType); std::shared_ptr<const NormalizedType> norm = normalizer.normalize(expectedArgType);
DefId def = dfg->getDef(arg); DefId def = dfg->getDef(arg);
TypeId runTimeErrorTy; TypeId runTimeErrorTy;
// If we're dealing with any, negating any will cause all subtype tests to fail, since ~any is any // If we're dealing with any, negating any will cause all subtype tests to fail
// However, when someone calls this function, they're going to want to be able to pass it anything, // However, when someone calls this function, they're going to want to be able to pass it anything,
// for that reason, we manually inject never into the context so that the runtime test will always pass. // for that reason, we manually inject never into the context so that the runtime test will always pass.
if (!norm) if (!norm)
@ -602,39 +709,49 @@ struct NonStrictTypeChecker
} }
// Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types // Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types
for (size_t i = 0; i < call->args.size; i++) for (size_t i = 0; i < arguments.size(); i++)
{ {
AstExpr* arg = call->args.data[i]; AstExpr* arg = arguments[i];
if (auto runTimeFailureType = willRunTimeError(arg, fresh)) if (auto runTimeFailureType = willRunTimeError(arg, fresh))
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location);
} }
if (call->args.size < argTypes.size()) if (arguments.size() < argTypes.size())
{ {
// We are passing fewer arguments than we expect // We are passing fewer arguments than we expect
// so we need to ensure that the rest of the args are optional. // so we need to ensure that the rest of the args are optional.
bool remainingArgsOptional = true; bool remainingArgsOptional = true;
for (size_t i = call->args.size; i < argTypes.size(); i++) for (size_t i = arguments.size(); i < argTypes.size(); i++)
remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]); remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]);
if (!remainingArgsOptional) if (!remainingArgsOptional)
{ {
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), arguments.size()}, call->location);
return fresh; return fresh;
} }
} }
} }
}
return fresh; return fresh;
} }
NonStrictContext visit(AstExprIndexName* indexName) NonStrictContext visit(AstExprIndexName* indexName, ValueContext context)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
return visit(indexName->expr, context);
else
return {}; return {};
} }
NonStrictContext visit(AstExprIndexExpr* indexExpr) NonStrictContext visit(AstExprIndexExpr* indexExpr, ValueContext context)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext expr = visit(indexExpr->expr, context);
NonStrictContext index = visit(indexExpr->index, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, expr, index);
}
else
return {}; return {};
} }
@ -654,39 +771,74 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstExprTable* table) NonStrictContext visit(AstExprTable* table)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (auto [_, key, value] : table->items)
{
if (key)
visit(key, ValueContext::RValue);
visit(value, ValueContext::RValue);
}
}
return {}; return {};
} }
NonStrictContext visit(AstExprUnary* unary) NonStrictContext visit(AstExprUnary* unary)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
return visit(unary->expr, ValueContext::RValue);
else
return {}; return {};
} }
NonStrictContext visit(AstExprBinary* binary) NonStrictContext visit(AstExprBinary* binary)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext lhs = visit(binary->left, ValueContext::RValue);
NonStrictContext rhs = visit(binary->right, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, lhs, rhs);
}
else
return {}; return {};
} }
NonStrictContext visit(AstExprTypeAssertion* typeAssertion) NonStrictContext visit(AstExprTypeAssertion* typeAssertion)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
return visit(typeAssertion->expr, ValueContext::RValue);
else
return {}; return {};
} }
NonStrictContext visit(AstExprIfElse* ifElse) NonStrictContext visit(AstExprIfElse* ifElse)
{ {
NonStrictContext condB = visit(ifElse->condition); NonStrictContext condB = visit(ifElse->condition, ValueContext::RValue);
NonStrictContext thenB = visit(ifElse->trueExpr); NonStrictContext thenB = visit(ifElse->trueExpr, ValueContext::RValue);
NonStrictContext elseB = visit(ifElse->falseExpr); NonStrictContext elseB = visit(ifElse->falseExpr, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB)); return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB));
} }
NonStrictContext visit(AstExprInterpString* interpString) NonStrictContext visit(AstExprInterpString* interpString)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* expr : interpString->expressions)
visit(expr, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstExprError* error) NonStrictContext visit(AstExprError* error)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* expr : error->expressions)
visit(expr, ValueContext::RValue);
}
return {}; return {};
} }
@ -754,6 +906,8 @@ private:
void checkNonStrict( void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice, NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<const DataFlowGraph> dfg,
@ -764,7 +918,9 @@ void checkNonStrict(
{ {
LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking"); LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking");
NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, ice, unifierState, dfg, limits, module}; NonStrictTypeChecker typeChecker{
NotNull{&module->internalTypes}, builtinTypes, simplifier, typeFunctionRuntime, ice, unifierState, dfg, limits, module
};
typeChecker.visit(sourceModule.root); typeChecker.visit(sourceModule.root);
unfreeze(module->interfaceTypes); unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes, builtinTypes); copyErrors(module->errors, module->interfaceTypes, builtinTypes);

View file

@ -15,36 +15,17 @@
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/Unifier.h" #include "Luau/Unifier.h"
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant)
LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false)
LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false);
LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false);
LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false);
// This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000)
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200)
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAGVARIABLE(LuauFixInfiniteRecursionInNormalization)
LUAU_FASTFLAGVARIABLE(LuauFixNormalizedIntersectionOfNegatedClass)
static bool fixReduceStackPressure()
{
return FFlag::LuauFixReduceStackPressure || FFlag::LuauSolverV2;
}
static bool fixCyclicTablesBlowingStack()
{
return FFlag::LuauFixCyclicTablesBlowingStack || FFlag::LuauSolverV2;
}
namespace Luau namespace Luau
{ {
// helper to make `FFlag::LuauNormalizeAwayUninhabitableTables` not explicitly required when DCR is enabled.
static bool normalizeAwayUninhabitableTables()
{
return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::LuauSolverV2;
}
static bool shouldEarlyExit(NormalizationResult res) static bool shouldEarlyExit(NormalizationResult res)
{ {
// if res is hit limits, return control flow // if res is hit limits, return control flow
@ -589,10 +570,11 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set<TypeId>& seen)
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right) NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right)
{ {
Set<TypeId> seen{nullptr}; Set<TypeId> seen{nullptr};
return isIntersectionInhabited(left, right, seen); SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}};
return isIntersectionInhabited(left, right, seenTablePropPairs, seen);
} }
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet) NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet)
{ {
left = follow(left); left = follow(left);
right = follow(right); right = follow(right);
@ -605,7 +587,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
} }
NormalizedType norm{builtinTypes}; NormalizedType norm{builtinTypes};
NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet); NormalizationResult res = normalizeIntersections({left, right}, norm, seenTablePropPairs, seenSet);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
{ {
if (cacheInhabitance && res == NormalizationResult::False) if (cacheInhabitance && res == NormalizationResult::False)
@ -956,7 +938,8 @@ std::shared_ptr<const NormalizedType> Normalizer::normalize(TypeId ty)
NormalizedType norm{builtinTypes}; NormalizedType norm{builtinTypes};
Set<TypeId> seenSetTypes{nullptr}; Set<TypeId> seenSetTypes{nullptr};
NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes); SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}};
NormalizationResult res = unionNormalWithTy(norm, ty, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return nullptr; return nullptr;
@ -974,7 +957,12 @@ std::shared_ptr<const NormalizedType> Normalizer::normalize(TypeId ty)
return shared; return shared;
} }
NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet) NormalizationResult Normalizer::normalizeIntersections(
const std::vector<TypeId>& intersections,
NormalizedType& outType,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSet
)
{ {
if (!arena) if (!arena)
sharedState->iceHandler->ice("Normalizing types outside a module"); sharedState->iceHandler->ice("Normalizing types outside a module");
@ -983,7 +971,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>
// Now we need to intersect the two types // Now we need to intersect the two types
for (auto ty : intersections) for (auto ty : intersections)
{ {
NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet); NormalizationResult res = intersectNormalWithTy(norm, ty, seenTablePropPairs, seenSet);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
@ -1620,7 +1608,7 @@ void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there)
// TODO: remove unions of tables where possible // TODO: remove unions of tables where possible
// we can always skip `never` // we can always skip `never`
if (normalizeAwayUninhabitableTables() && get<NeverType>(there)) if (get<NeverType>(there))
return; return;
heres.insert(there); heres.insert(there);
@ -1747,7 +1735,13 @@ NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, N
} }
// See above for an explaination of `ignoreSmallerTyvars`. // See above for an explaination of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars) NormalizationResult Normalizer::unionNormalWithTy(
NormalizedType& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes,
int ignoreSmallerTyvars
)
{ {
RecursionCounter _rc(&sharedState->counters.recursionCount); RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits()) if (!withinResourceLimits())
@ -1779,7 +1773,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) for (UnionTypeIterator it = begin(utv); it != end(utv); ++it)
{ {
NormalizationResult res = unionNormalWithTy(here, *it, seenSetTypes); NormalizationResult res = unionNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
{ {
seenSetTypes.erase(there); seenSetTypes.erase(there);
@ -1800,7 +1794,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
norm.tops = builtinTypes->anyType; norm.tops = builtinTypes->anyType;
for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it)
{ {
NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes); NormalizationResult res = intersectNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
{ {
seenSetTypes.erase(there); seenSetTypes.erase(there);
@ -1814,7 +1808,8 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
} }
else if (get<UnknownType>(here.tops)) else if (get<UnknownType>(here.tops))
return NormalizationResult::True; return NormalizationResult::True;
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there)) else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFunctionInstanceType>(there))
{ {
if (tyvarIndex(there) <= ignoreSmallerTyvars) if (tyvarIndex(there) <= ignoreSmallerTyvars)
return NormalizationResult::True; return NormalizationResult::True;
@ -1891,7 +1886,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
else if (get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there)) else if (get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there) || get<NoRefineType>(there))
{ {
// nothing // nothing
} }
@ -1900,7 +1895,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
for (auto& [tyvar, intersect] : here.tyvars) for (auto& [tyvar, intersect] : here.tyvars)
{ {
NormalizationResult res = unionNormalWithTy(*intersect, there, seenSetTypes, tyvarIndex(tyvar)); NormalizationResult res = unionNormalWithTy(*intersect, there, seenTablePropPairs, seenSetTypes, tyvarIndex(tyvar));
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
@ -2289,9 +2284,24 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th
else if (isSubclass(there, hereTy)) else if (isSubclass(there, hereTy))
{ {
TypeIds negations = std::move(hereNegations); TypeIds negations = std::move(hereNegations);
bool emptyIntersectWithNegation = false;
for (auto nIt = negations.begin(); nIt != negations.end();) for (auto nIt = negations.begin(); nIt != negations.end();)
{ {
if (FFlag::LuauFixNormalizedIntersectionOfNegatedClass && isSubclass(there, *nIt))
{
// Hitting this block means that the incoming class is a
// subclass of this type, _and_ one of its negations is a
// superclass of this type, e.g.:
//
// Dog & ~Animal
//
// Clearly this intersects to never, so we mark this class as
// being removed from the normalized class type.
emptyIntersectWithNegation = true;
break;
}
if (!isSubclass(*nIt, there)) if (!isSubclass(*nIt, there))
{ {
nIt = negations.erase(nIt); nIt = negations.erase(nIt);
@ -2304,6 +2314,7 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th
it = heres.ordering.erase(it); it = heres.ordering.erase(it);
heres.classes.erase(hereTy); heres.classes.erase(hereTy);
if (!emptyIntersectWithNegation)
heres.pushPair(there, std::move(negations)); heres.pushPair(there, std::move(negations));
break; break;
} }
@ -2510,7 +2521,7 @@ std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, T
return arena->addTypePack({}); return arena->addTypePack({});
} }
std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet) std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet)
{ {
if (here == there) if (here == there)
return here; return here;
@ -2589,50 +2600,61 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
{ {
if (tprop.readTy.has_value()) if (tprop.readTy.has_value())
{ {
// if the intersection of the read types of a property is uninhabited, the whole table is `never`. if (FFlag::LuauFixInfiniteRecursionInNormalization)
if (fixReduceStackPressure())
{ {
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
// If any property is going to get mapped to `never`, we can just call the entire table `never`.
// Since this check is syntactic, we may sometimes miss simplifying tables with complex uninhabited properties.
// Prior versions of this code attempted to do this semantically using the normalization machinery, but this
// mistakenly causes infinite loops when giving more complex recursive table types. As it stands, this approach
// will continue to scale as simplification is improved, but we may wish to reintroduce the semantic approach
// once we have revisited the usage of seen sets systematically (and possibly with some additional guarding to recognize
// when types are infinitely-recursive with non-pointer identical instances of them, or some guard to prevent that
// construction altogether). See also: `gh1632_no_infinite_recursion_in_normalization`
if (get<NeverType>(ty))
return {builtinTypes->neverType};
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
}
else
{
// if the intersection of the read types of a property is uninhabited, the whole table is `never`.
// We've seen these table prop elements before and we're about to ask if their intersection // We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited // is inhabited
if (fixCyclicTablesBlowingStack())
auto pair1 = std::pair{*hprop.readTy, *tprop.readTy};
auto pair2 = std::pair{*tprop.readTy, *hprop.readTy};
if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2))
{ {
if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) seenTablePropPairs.erase(pair1);
{ seenTablePropPairs.erase(pair2);
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
return {builtinTypes->neverType}; return {builtinTypes->neverType};
} }
else else
{ {
seenSet.insert(*hprop.readTy); seenTablePropPairs.insert(pair1);
seenSet.insert(*tprop.readTy); seenTablePropPairs.insert(pair2);
}
} }
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy); // FIXME(ariel): this is being added in a flag removal, so not changing the semantics here, but worth noting that this
// fresh `seenSet` is definitely a bug. we already have `seenSet` from the parameter that _should_ have been used here.
Set<TypeId> seenSet{nullptr};
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet);
// Cleanup seenTablePropPairs.erase(pair1);
if (fixCyclicTablesBlowingStack()) seenTablePropPairs.erase(pair2);
{ if (NormalizationResult::True != res)
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
}
if (normalizeAwayUninhabitableTables() && NormalizationResult::True != res)
return {builtinTypes->neverType}; return {builtinTypes->neverType};
}
else
{
if (normalizeAwayUninhabitableTables() &&
NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
return {builtinTypes->neverType};
}
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty; prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy); hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy); thereSubHere &= (ty == tprop.readTy);
} }
}
else else
{ {
prop.readTy = *hprop.readTy; prop.readTy = *hprop.readTy;
@ -2737,7 +2759,7 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
if (tmtable && hmtable) if (tmtable && hmtable)
{ {
// NOTE: this assumes metatables are ivariant // NOTE: this assumes metatables are ivariant
if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable, seenSet)) if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable, seenTablePropPairs, seenSet))
{ {
if (table == htable && *mtable == hmtable) if (table == htable && *mtable == hmtable)
return here; return here;
@ -2767,12 +2789,12 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
return table; return table;
} }
void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes) void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSetTypes)
{ {
TypeIds tmp; TypeIds tmp;
for (TypeId here : heres) for (TypeId here : heres)
{ {
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes)) if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes))
tmp.insert(*inter); tmp.insert(*inter);
} }
heres.retain(tmp); heres.retain(tmp);
@ -2787,7 +2809,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres)
for (TypeId there : theres) for (TypeId there : theres)
{ {
Set<TypeId> seenSetTypes{nullptr}; Set<TypeId> seenSetTypes{nullptr};
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes)) SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}};
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes))
tmp.insert(*inter); tmp.insert(*inter);
} }
} }
@ -3005,12 +3028,17 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali
} }
} }
NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes) NormalizationResult Normalizer::intersectTyvarsWithTy(
NormalizedTyvars& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes
)
{ {
for (auto it = here.begin(); it != here.end();) for (auto it = here.begin(); it != here.end();)
{ {
NormalizedType& inter = *it->second; NormalizedType& inter = *it->second;
NormalizationResult res = intersectNormalWithTy(inter, there, seenSetTypes); NormalizationResult res = intersectNormalWithTy(inter, there, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
if (isShallowInhabited(inter)) if (isShallowInhabited(inter))
@ -3024,6 +3052,10 @@ NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, Ty
// See above for an explaination of `ignoreSmallerTyvars`. // See above for an explaination of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars)
{ {
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
return NormalizationResult::HitLimits;
if (!get<NeverType>(there.tops)) if (!get<NeverType>(there.tops))
{ {
here.tops = intersectionOfTops(here.tops, there.tops); here.tops = intersectionOfTops(here.tops, there.tops);
@ -3035,6 +3067,11 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor
return unionNormals(here, there, ignoreSmallerTyvars); return unionNormals(here, there, ignoreSmallerTyvars);
} }
// Limit based on worst-case expansion of the table intersection
// This restriction can be relaxed when table intersection simplification is improved
if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit))
return NormalizationResult::HitLimits;
here.booleans = intersectionOfBools(here.booleans, there.booleans); here.booleans = intersectionOfBools(here.booleans, there.booleans);
intersectClasses(here.classes, there.classes); intersectClasses(here.classes, there.classes);
@ -3088,7 +3125,12 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor
return NormalizationResult::True; return NormalizationResult::True;
} }
NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes) NormalizationResult Normalizer::intersectNormalWithTy(
NormalizedType& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes
)
{ {
RecursionCounter _rc(&sharedState->counters.recursionCount); RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits()) if (!withinResourceLimits())
@ -3104,14 +3146,14 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
else if (!get<NeverType>(here.tops)) else if (!get<NeverType>(here.tops))
{ {
clearNormal(here); clearNormal(here);
return unionNormalWithTy(here, there, seenSetTypes); return unionNormalWithTy(here, there, seenTablePropPairs, seenSetTypes);
} }
else if (const UnionType* utv = get<UnionType>(there)) else if (const UnionType* utv = get<UnionType>(there))
{ {
NormalizedType norm{builtinTypes}; NormalizedType norm{builtinTypes};
for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) for (UnionTypeIterator it = begin(utv); it != end(utv); ++it)
{ {
NormalizationResult res = unionNormalWithTy(norm, *it, seenSetTypes); NormalizationResult res = unionNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
@ -3121,13 +3163,14 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{ {
for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it)
{ {
NormalizationResult res = intersectNormalWithTy(here, *it, seenSetTypes); NormalizationResult res = intersectNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
return NormalizationResult::True; return NormalizationResult::True;
} }
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there)) else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFunctionInstanceType>(there))
{ {
NormalizedType thereNorm{builtinTypes}; NormalizedType thereNorm{builtinTypes};
NormalizedType topNorm{builtinTypes}; NormalizedType topNorm{builtinTypes};
@ -3150,7 +3193,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{ {
TypeIds tables = std::move(here.tables); TypeIds tables = std::move(here.tables);
clearNormal(here); clearNormal(here);
intersectTablesWithTable(tables, there, seenSetTypes); intersectTablesWithTable(tables, there, seenTablePropPairs, seenSetTypes);
here.tables = std::move(tables); here.tables = std::move(tables);
} }
else if (get<ClassType>(there)) else if (get<ClassType>(there))
@ -3243,13 +3286,18 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
// assumption that it is the same as any. // assumption that it is the same as any.
return NormalizationResult::True; return NormalizationResult::True;
} }
else if (get<NoRefineType>(t))
{
// `*no-refine*` means we will never do anything to affect the intersection.
return NormalizationResult::True;
}
else if (get<NeverType>(t)) else if (get<NeverType>(t))
{ {
// if we're intersecting with `~never`, this is equivalent to intersecting with `unknown` // if we're intersecting with `~never`, this is equivalent to intersecting with `unknown`
// this is a noop since an intersection with `unknown` is trivial. // this is a noop since an intersection with `unknown` is trivial.
return NormalizationResult::True; return NormalizationResult::True;
} }
else if ((FFlag::LuauNormalizeNotUnknownIntersection || FFlag::LuauSolverV2) && get<UnknownType>(t)) else if (get<UnknownType>(t))
{ {
// if we're intersecting with `~unknown`, this is equivalent to intersecting with `never` // if we're intersecting with `~unknown`, this is equivalent to intersecting with `never`
// this means we should clear the type entirely. // this means we should clear the type entirely.
@ -3257,7 +3305,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
return NormalizationResult::True; return NormalizationResult::True;
} }
else if (auto nt = get<NegationType>(t)) else if (auto nt = get<NegationType>(t))
return intersectNormalWithTy(here, nt->ty, seenSetTypes); return intersectNormalWithTy(here, nt->ty, seenTablePropPairs, seenSetTypes);
else else
{ {
// TODO negated unions, intersections, table, and function. // TODO negated unions, intersections, table, and function.
@ -3269,10 +3317,15 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{ {
here.classes.resetToNever(); here.classes.resetToNever();
} }
else if (get<NoRefineType>(there))
{
// `*no-refine*` means we will never do anything to affect the intersection.
return NormalizationResult::True;
}
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenSetTypes); NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
here.tyvars = std::move(tyvars); here.tyvars = std::move(tyvars);
@ -3420,16 +3473,27 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
return arena->addType(UnionType{std::move(result)}); return arena->addType(UnionType{std::move(result)});
} }
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice) bool isSubtype(
TypeId subTy,
TypeId superTy,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
)
{ {
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeCheckLimits limits;
TypeFunctionRuntime typeFunctionRuntime{
NotNull{&ice}, NotNull{&limits}
}; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
// Subtyping under DCR is not implemented using unification! // Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{ {
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}}; Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype; return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
} }
@ -3442,16 +3506,27 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<Built
} }
} }
bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice) bool isSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
)
{ {
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeCheckLimits limits;
TypeFunctionRuntime typeFunctionRuntime{
NotNull{&ice}, NotNull{&limits}
}; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
// Subtyping under DCR is not implemented using unification! // Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{ {
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}}; Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subPack, superPack, scope).isSubtype; return subtyping.isSubtype(subPack, superPack, scope).isSubtype;
} }
@ -3464,38 +3539,4 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, N
} }
} }
bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
{
LUAU_ASSERT(!FFlag::LuauSolverV2);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subTy, superTy);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
bool isConsistentSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter& ice
)
{
LUAU_ASSERT(!FFlag::LuauSolverV2);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subPack, superPack);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
} // namespace Luau } // namespace Luau

View file

@ -16,7 +16,9 @@ namespace Luau
OverloadResolver::OverloadResolver( OverloadResolver::OverloadResolver(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter, NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -24,11 +26,13 @@ OverloadResolver::OverloadResolver(
) )
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, arena(arena) , arena(arena)
, simplifier(simplifier)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, scope(scope) , scope(scope)
, ice(reporter) , ice(reporter)
, limits(limits) , limits(limits)
, subtyping({builtinTypes, arena, normalizer, ice}) , subtyping({builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, ice})
, callLoc(callLocation) , callLoc(callLocation)
{ {
} }
@ -199,8 +203,9 @@ std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload_
const std::vector<AstExpr*>* argExprs const std::vector<AstExpr*>* argExprs
) )
{ {
FunctionGraphReductionResult result = FunctionGraphReductionResult result = reduceTypeFunctions(
reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true); fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true
);
if (!result.errors.empty()) if (!result.errors.empty())
return {OverloadIsNonviable, result.errors}; return {OverloadIsNonviable, result.errors};
@ -401,10 +406,12 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors)
// we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`. // we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`.
// this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed. // this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed.
std::optional<TypeId> selectOverload( static std::optional<TypeId> selectOverload(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<InternalErrorReporter> iceReporter, NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -413,8 +420,9 @@ std::optional<TypeId> selectOverload(
TypePackId argsPack TypePackId argsPack
) )
{ {
OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location}; auto resolver =
auto [status, overload] = resolver.selectOverload(fn, argsPack); std::make_unique<OverloadResolver>(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location);
auto [status, overload] = resolver->selectOverload(fn, argsPack);
if (status == OverloadResolver::Analysis::Ok) if (status == OverloadResolver::Analysis::Ok)
return overload; return overload;
@ -428,7 +436,9 @@ std::optional<TypeId> selectOverload(
SolveResult solveFunctionCall( SolveResult solveFunctionCall(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter, NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope, NotNull<Scope> scope,
@ -437,7 +447,8 @@ SolveResult solveFunctionCall(
TypePackId argsPack TypePackId argsPack
) )
{ {
std::optional<TypeId> overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack); std::optional<TypeId> overloadToUse =
selectOverload(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack);
if (!overloadToUse) if (!overloadToUse)
return {SolveResult::NoMatchingOverload}; return {SolveResult::NoMatchingOverload};
@ -450,9 +461,9 @@ SolveResult solveFunctionCall(
if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty()) if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty())
{ {
Instantiation2 instantiation{arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions)}; auto instantiation = std::make_unique<Instantiation2>(arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions));
std::optional<TypePackId> subst = instantiation.substitute(resultPack); std::optional<TypePackId> subst = instantiation->substitute(resultPack);
if (!subst) if (!subst)
return {SolveResult::CodeTooComplex}; return {SolveResult::CodeTooComplex};

View file

@ -4,6 +4,8 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Module.h" #include "Luau/Module.h"
LUAU_FASTFLAGVARIABLE(LuauExtendedSimpleRequire)
namespace Luau namespace Luau
{ {
@ -65,7 +67,7 @@ struct RequireTracer : AstVisitor
return true; return true;
} }
AstExpr* getDependent(AstExpr* node) AstExpr* getDependent_DEPRECATED(AstExpr* node)
{ {
if (AstExprLocal* expr = node->as<AstExprLocal>()) if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local]; return locals[expr->local];
@ -78,11 +80,34 @@ struct RequireTracer : AstVisitor
else else
return nullptr; return nullptr;
} }
AstNode* getDependent(AstNode* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local];
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return expr->expr;
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return expr->expr;
else if (AstExprCall* expr = node->as<AstExprCall>(); expr && expr->self)
return expr->func->as<AstExprIndexName>()->expr;
else if (AstExprGroup* expr = node->as<AstExprGroup>())
return expr->expr;
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
return expr->annotation;
else if (AstTypeGroup* expr = node->as<AstTypeGroup>())
return expr->type;
else if (AstTypeTypeof* expr = node->as<AstTypeTypeof>())
return expr->expr;
else
return nullptr;
}
void process() void process()
{ {
ModuleInfo moduleContext{currentModuleName}; ModuleInfo moduleContext{currentModuleName};
if (FFlag::LuauExtendedSimpleRequire)
{
// seed worklist with require arguments // seed worklist with require arguments
work.reserve(requireCalls.size()); work.reserve(requireCalls.size());
@ -91,13 +116,15 @@ struct RequireTracer : AstVisitor
// push all dependent expressions to the work stack; note that the vector is modified during traversal // push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work.size(); ++i) for (size_t i = 0; i < work.size(); ++i)
if (AstExpr* dep = getDependent(work[i])) {
if (AstNode* dep = getDependent(work[i]))
work.push_back(dep); work.push_back(dep);
}
// resolve all expressions to a module info // resolve all expressions to a module info
for (size_t i = work.size(); i > 0; --i) for (size_t i = work.size(); i > 0; --i)
{ {
AstExpr* expr = work[i - 1]; AstNode* expr = work[i - 1];
// when multiple expressions depend on the same one we push it to work queue multiple times // when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr)) if (result.exprs.contains(expr))
@ -105,7 +132,53 @@ struct RequireTracer : AstVisitor
std::optional<ModuleInfo> info; std::optional<ModuleInfo> info;
if (AstExpr* dep = getDependent(expr)) if (AstNode* dep = getDependent(expr))
{
const ModuleInfo* context = result.exprs.find(dep);
if (context && expr->is<AstExprLocal>())
info = *context; // locals just inherit their dependent context, no resolution required
else if (context && (expr->is<AstExprGroup>() || expr->is<AstTypeGroup>()))
info = *context; // simple group nodes propagate their value
else if (context && (expr->is<AstTypeTypeof>() || expr->is<AstExprTypeAssertion>()))
info = *context; // typeof type annotations will resolve to the typeof content
else if (AstExpr* asExpr = expr->asExpr())
info = fileResolver->resolveModule(context, asExpr);
}
else if (AstExpr* asExpr = expr->asExpr())
{
info = fileResolver->resolveModule(&moduleContext, asExpr);
}
if (info)
result.exprs[expr] = std::move(*info);
}
}
else
{
// seed worklist with require arguments
work_DEPRECATED.reserve(requireCalls.size());
for (AstExprCall* require : requireCalls)
work_DEPRECATED.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work_DEPRECATED.size(); ++i)
if (AstExpr* dep = getDependent_DEPRECATED(work_DEPRECATED[i]))
work_DEPRECATED.push_back(dep);
// resolve all expressions to a module info
for (size_t i = work_DEPRECATED.size(); i > 0; --i)
{
AstExpr* expr = work_DEPRECATED[i - 1];
// when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr))
continue;
std::optional<ModuleInfo> info;
if (AstExpr* dep = getDependent_DEPRECATED(expr))
{ {
const ModuleInfo* context = result.exprs.find(dep); const ModuleInfo* context = result.exprs.find(dep);
@ -123,6 +196,7 @@ struct RequireTracer : AstVisitor
if (info) if (info)
result.exprs[expr] = std::move(*info); result.exprs[expr] = std::move(*info);
} }
}
// resolve all requires according to their argument // resolve all requires according to their argument
result.requireList.reserve(requireCalls.size()); result.requireList.reserve(requireCalls.size());
@ -150,7 +224,8 @@ struct RequireTracer : AstVisitor
ModuleName currentModuleName; ModuleName currentModuleName;
DenseHashMap<AstLocal*, AstExpr*> locals; DenseHashMap<AstLocal*, AstExpr*> locals;
std::vector<AstExpr*> work; std::vector<AstExpr*> work_DEPRECATED;
std::vector<AstNode*> work;
std::vector<AstExprCall*> requireCalls; std::vector<AstExprCall*> requireCalls;
}; };

View file

@ -211,6 +211,16 @@ void Scope::inheritRefinements(const ScopePtr& childScope)
} }
} }
bool Scope::shouldWarnGlobal(std::string name) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
if (current->globalsToWarn.contains(name))
return true;
}
return false;
}
bool subsumesStrict(Scope* left, Scope* right) bool subsumesStrict(Scope* left, Scope* right)
{ {
while (right) while (right)

View file

@ -2,6 +2,7 @@
#include "Luau/Simplify.h" #include "Luau/Simplify.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/Set.h" #include "Luau/Set.h"
@ -14,6 +15,7 @@
LUAU_FASTINT(LuauTypeReductionRecursionLimit) LUAU_FASTINT(LuauTypeReductionRecursionLimit)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_DYNAMIC_FASTINTVARIABLE(LuauSimplificationComplexityLimit, 8); LUAU_DYNAMIC_FASTINTVARIABLE(LuauSimplificationComplexityLimit, 8);
LUAU_FASTFLAGVARIABLE(LuauFlagBasicIntersectFollows);
namespace Luau namespace Luau
{ {
@ -29,16 +31,16 @@ struct TypeSimplifier
int recursionDepth = 0; int recursionDepth = 0;
TypeId mkNegation(TypeId ty); TypeId mkNegation(TypeId ty) const;
TypeId intersectFromParts(std::set<TypeId> parts); TypeId intersectFromParts(std::set<TypeId> parts);
TypeId intersectUnionWithType(TypeId unionTy, TypeId right); TypeId intersectUnionWithType(TypeId left, TypeId right);
TypeId intersectUnions(TypeId left, TypeId right); TypeId intersectUnions(TypeId left, TypeId right);
TypeId intersectNegatedUnion(TypeId unionTy, TypeId right); TypeId intersectNegatedUnion(TypeId left, TypeId right);
TypeId intersectTypeWithNegation(TypeId a, TypeId b); TypeId intersectTypeWithNegation(TypeId left, TypeId right);
TypeId intersectNegations(TypeId a, TypeId b); TypeId intersectNegations(TypeId left, TypeId right);
TypeId intersectIntersectionWithType(TypeId left, TypeId right); TypeId intersectIntersectionWithType(TypeId left, TypeId right);
@ -46,8 +48,8 @@ struct TypeSimplifier
// unions, intersections, or negations. // unions, intersections, or negations.
std::optional<TypeId> basicIntersect(TypeId left, TypeId right); std::optional<TypeId> basicIntersect(TypeId left, TypeId right);
TypeId intersect(TypeId ty, TypeId discriminant); TypeId intersect(TypeId left, TypeId right);
TypeId union_(TypeId ty, TypeId discriminant); TypeId union_(TypeId left, TypeId right);
TypeId simplify(TypeId ty); TypeId simplify(TypeId ty);
TypeId simplify(TypeId ty, DenseHashSet<TypeId>& seen); TypeId simplify(TypeId ty, DenseHashSet<TypeId>& seen);
@ -571,7 +573,7 @@ Relation relate(TypeId left, TypeId right)
return relate(left, right, seen); return relate(left, right, seen);
} }
TypeId TypeSimplifier::mkNegation(TypeId ty) TypeId TypeSimplifier::mkNegation(TypeId ty) const
{ {
TypeId result = nullptr; TypeId result = nullptr;
@ -1064,6 +1066,12 @@ TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right)
std::optional<TypeId> TypeSimplifier::basicIntersect(TypeId left, TypeId right) std::optional<TypeId> TypeSimplifier::basicIntersect(TypeId left, TypeId right)
{ {
if (FFlag::LuauFlagBasicIntersectFollows)
{
left = follow(left);
right = follow(right);
}
if (get<AnyType>(left) && get<ErrorType>(right)) if (get<AnyType>(left) && get<ErrorType>(right))
return right; return right;
if (get<AnyType>(right) && get<ErrorType>(left)) if (get<AnyType>(right) && get<ErrorType>(left))
@ -1403,8 +1411,6 @@ TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet<TypeId>& seen)
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right) SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right)
{ {
LUAU_ASSERT(FFlag::LuauSolverV2);
TypeSimplifier s{builtinTypes, arena}; TypeSimplifier s{builtinTypes, arena};
// fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str()); // fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str());
@ -1418,8 +1424,6 @@ SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts) SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts)
{ {
LUAU_ASSERT(FFlag::LuauSolverV2);
TypeSimplifier s{builtinTypes, arena}; TypeSimplifier s{builtinTypes, arena};
TypeId res = s.intersectFromParts(std::move(parts)); TypeId res = s.intersectFromParts(std::move(parts));
@ -1429,8 +1433,6 @@ SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right) SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right)
{ {
LUAU_ASSERT(FFlag::LuauSolverV2);
TypeSimplifier s{builtinTypes, arena}; TypeSimplifier s{builtinTypes, arena};
TypeId res = s.union_(left, right); TypeId res = s.union_(left, right);

View file

@ -4,13 +4,15 @@
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/Type.h"
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256); LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256)
LUAU_FASTFLAG(LuauSyntheticErrors)
namespace Luau namespace Luau
{ {
@ -50,11 +52,33 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
LUAU_ASSERT(ty->persistent); LUAU_ASSERT(ty->persistent);
return ty; return ty;
} }
else if constexpr (std::is_same_v<T, ErrorType>) else if constexpr (std::is_same_v<T, NoRefineType>)
{ {
LUAU_ASSERT(ty->persistent); LUAU_ASSERT(ty->persistent);
return ty; return ty;
} }
else if constexpr (std::is_same_v<T, ErrorType>)
{
if (FFlag::LuauSyntheticErrors)
{
LUAU_ASSERT(ty->persistent || a.synthetic);
if (ty->persistent)
return ty;
// While this code intentionally works (and clones) even if `a.synthetic` is `std::nullopt`,
// we still assert above because we consider it a bug to have a non-persistent error type
// without any associated metadata. We should always use the persistent version in such cases.
ErrorType clone = ErrorType{};
clone.synthetic = a.synthetic;
return dest.addType(clone);
}
else
{
LUAU_ASSERT(ty->persistent);
return ty;
}
}
else if constexpr (std::is_same_v<T, UnknownType>) else if constexpr (std::is_same_v<T, UnknownType>)
{ {
LUAU_ASSERT(ty->persistent); LUAU_ASSERT(ty->persistent);
@ -74,9 +98,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf}; FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf};
clone.generics = a.generics; clone.generics = a.generics;
clone.genericPacks = a.genericPacks; clone.genericPacks = a.genericPacks;
clone.magicFunction = a.magicFunction; clone.magic = a.magic;
clone.dcrMagicFunction = a.dcrMagicFunction;
clone.dcrMagicRefinement = a.dcrMagicRefinement;
clone.tags = a.tags; clone.tags = a.tags;
clone.argNames = a.argNames; clone.argNames = a.argNames;
clone.isCheckedFunction = a.isCheckedFunction; clone.isCheckedFunction = a.isCheckedFunction;
@ -127,7 +149,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
return dest.addType(NegationType{a.ty}); return dest.addType(NegationType{a.ty});
else if constexpr (std::is_same_v<T, TypeFunctionInstanceType>) else if constexpr (std::is_same_v<T, TypeFunctionInstanceType>)
{ {
TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncBody}; TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncData};
return dest.addType(std::move(clone)); return dest.addType(std::move(clone));
} }
else else

View file

@ -5,6 +5,7 @@
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include "Luau/Substitution.h" #include "Luau/Substitution.h"
@ -20,7 +21,8 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity, false); LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity)
LUAU_FASTFLAGVARIABLE(LuauSubtypingFixTailPack)
namespace Luau namespace Luau
{ {
@ -258,43 +260,32 @@ SubtypingResult SubtypingResult::any(const std::vector<SubtypingResult>& results
struct ApplyMappedGenerics : Substitution struct ApplyMappedGenerics : Substitution
{ {
using MappedGenerics = DenseHashMap<TypeId, SubtypingEnvironment::GenericBounds>;
using MappedGenericPacks = DenseHashMap<TypePackId, TypePackId>;
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
MappedGenerics& mappedGenerics; SubtypingEnvironment& env;
MappedGenericPacks& mappedGenericPacks;
ApplyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, SubtypingEnvironment& env)
ApplyMappedGenerics(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
MappedGenerics& mappedGenerics,
MappedGenericPacks& mappedGenericPacks
)
: Substitution(TxnLog::empty(), arena) : Substitution(TxnLog::empty(), arena)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, arena(arena) , arena(arena)
, mappedGenerics(mappedGenerics) , env(env)
, mappedGenericPacks(mappedGenericPacks)
{ {
} }
bool isDirty(TypeId ty) override bool isDirty(TypeId ty) override
{ {
return mappedGenerics.contains(ty); return env.containsMappedType(ty);
} }
bool isDirty(TypePackId tp) override bool isDirty(TypePackId tp) override
{ {
return mappedGenericPacks.contains(tp); return env.containsMappedPack(tp);
} }
TypeId clean(TypeId ty) override TypeId clean(TypeId ty) override
{ {
const auto& bounds = mappedGenerics[ty]; const auto& bounds = env.getMappedTypeBounds(ty);
if (bounds.upperBound.empty()) if (bounds.upperBound.empty())
return builtinTypes->unknownType; return builtinTypes->unknownType;
@ -307,7 +298,12 @@ struct ApplyMappedGenerics : Substitution
TypePackId clean(TypePackId tp) override TypePackId clean(TypePackId tp) override
{ {
return mappedGenericPacks[tp]; if (auto it = env.getMappedPackBounds(tp))
return *it;
// Clean is only called when isDirty found a pack bound
LUAU_ASSERT(!"Unreachable");
return nullptr;
} }
bool ignoreChildren(TypeId ty) override bool ignoreChildren(TypeId ty) override
@ -325,19 +321,91 @@ struct ApplyMappedGenerics : Substitution
std::optional<TypeId> SubtypingEnvironment::applyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty) std::optional<TypeId> SubtypingEnvironment::applyMappedGenerics(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty)
{ {
ApplyMappedGenerics amg{builtinTypes, arena, mappedGenerics, mappedGenericPacks}; ApplyMappedGenerics amg{builtinTypes, arena, *this};
return amg.substitute(ty); return amg.substitute(ty);
} }
const TypeId* SubtypingEnvironment::tryFindSubstitution(TypeId ty) const
{
if (auto it = substitutions.find(ty))
return it;
if (parent)
return parent->tryFindSubstitution(ty);
return nullptr;
}
const SubtypingResult* SubtypingEnvironment::tryFindSubtypingResult(std::pair<TypeId, TypeId> subAndSuper) const
{
if (auto it = ephemeralCache.find(subAndSuper))
return it;
if (parent)
return parent->tryFindSubtypingResult(subAndSuper);
return nullptr;
}
bool SubtypingEnvironment::containsMappedType(TypeId ty) const
{
if (mappedGenerics.contains(ty))
return true;
if (parent)
return parent->containsMappedType(ty);
return false;
}
bool SubtypingEnvironment::containsMappedPack(TypePackId tp) const
{
if (mappedGenericPacks.contains(tp))
return true;
if (parent)
return parent->containsMappedPack(tp);
return false;
}
SubtypingEnvironment::GenericBounds& SubtypingEnvironment::getMappedTypeBounds(TypeId ty)
{
if (auto it = mappedGenerics.find(ty))
return *it;
if (parent)
return parent->getMappedTypeBounds(ty);
LUAU_ASSERT(!"Use containsMappedType before asking for bounds!");
return mappedGenerics[ty];
}
TypePackId* SubtypingEnvironment::getMappedPackBounds(TypePackId tp)
{
if (auto it = mappedGenericPacks.find(tp))
return it;
if (parent)
return parent->getMappedPackBounds(tp);
// This fallback is reachable in valid cases, unlike the final part of getMappedTypeBounds
return nullptr;
}
Subtyping::Subtyping( Subtyping::Subtyping(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena, NotNull<TypeArena> typeArena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter NotNull<InternalErrorReporter> iceReporter
) )
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, arena(typeArena) , arena(typeArena)
, simplifier(simplifier)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, iceReporter(iceReporter) , iceReporter(iceReporter)
{ {
} }
@ -379,7 +447,10 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope
result.isSubtype = false; result.isSubtype = false;
} }
SubtypingResult boundsResult = isCovariantWith(env, lowerBound, upperBound, scope);
SubtypingEnvironment boundsEnv;
boundsEnv.parent = &env;
SubtypingResult boundsResult = isCovariantWith(boundsEnv, lowerBound, upperBound, scope);
boundsResult.reasoning.clear(); boundsResult.reasoning.clear();
result.andAlso(boundsResult); result.andAlso(boundsResult);
@ -442,20 +513,30 @@ struct SeenSetPopper
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy, NotNull<Scope> scope) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy, NotNull<Scope> scope)
{ {
UnifierCounters& counters = normalizer->sharedState->counters;
RecursionCounter rc(&counters.recursionCount);
if (counters.recursionLimit > 0 && counters.recursionLimit < counters.recursionCount)
{
SubtypingResult result;
result.normalizationTooComplex = true;
return result;
}
subTy = follow(subTy); subTy = follow(subTy);
superTy = follow(superTy); superTy = follow(superTy);
if (TypeId* subIt = env.substitutions.find(subTy); subIt && *subIt) if (const TypeId* subIt = env.tryFindSubstitution(subTy); subIt && *subIt)
subTy = *subIt; subTy = *subIt;
if (TypeId* superIt = env.substitutions.find(superTy); superIt && *superIt) if (const TypeId* superIt = env.tryFindSubstitution(superTy); superIt && *superIt)
superTy = *superIt; superTy = *superIt;
SubtypingResult* cachedResult = resultCache.find({subTy, superTy}); const SubtypingResult* cachedResult = resultCache.find({subTy, superTy});
if (cachedResult) if (cachedResult)
return *cachedResult; return *cachedResult;
cachedResult = env.ephemeralCache.find({subTy, superTy}); cachedResult = env.tryFindSubtypingResult({subTy, superTy});
if (cachedResult) if (cachedResult)
return *cachedResult; return *cachedResult;
@ -700,7 +781,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
std::vector<TypeId> headSlice(begin(superHead), begin(superHead) + headSize); std::vector<TypeId> headSlice(begin(superHead), begin(superHead) + headSize);
TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail);
if (TypePackId* other = env.mappedGenericPacks.find(*subTail)) if (TypePackId* other = env.getMappedPackBounds(*subTail))
// TODO: TypePath can't express "slice of a pack + its tail". // TODO: TypePath can't express "slice of a pack + its tail".
results.push_back(isCovariantWith(env, *other, superTailPack, scope).withSubComponent(TypePath::PackField::Tail)); results.push_back(isCovariantWith(env, *other, superTailPack, scope).withSubComponent(TypePath::PackField::Tail));
else else
@ -755,7 +836,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
std::vector<TypeId> headSlice(begin(subHead), begin(subHead) + headSize); std::vector<TypeId> headSlice(begin(subHead), begin(subHead) + headSize);
TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail);
if (TypePackId* other = env.mappedGenericPacks.find(*superTail)) if (TypePackId* other = env.getMappedPackBounds(*superTail))
// TODO: TypePath can't express "slice of a pack + its tail". // TODO: TypePath can't express "slice of a pack + its tail".
results.push_back(isContravariantWith(env, subTailPack, *other, scope).withSuperComponent(TypePath::PackField::Tail)); results.push_back(isContravariantWith(env, subTailPack, *other, scope).withSuperComponent(TypePath::PackField::Tail));
else else
@ -778,7 +859,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
else else
return SubtypingResult{false} return SubtypingResult{false}
.withSuperComponent(TypePath::PackField::Tail) .withSuperComponent(TypePath::PackField::Tail)
.withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}); .withError({scope->location, UnexpectedTypePackInSubtyping{FFlag::LuauSubtypingFixTailPack ? *superTail : *subTail}});
} }
else else
return {false}; return {false};
@ -1316,6 +1397,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt, NotNull<Scope> scope) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt, NotNull<Scope> scope)
{ {
return isCovariantWith(env, subMt->table, superMt->table, scope) return isCovariantWith(env, subMt->table, superMt->table, scope)
.withBothComponent(TypePath::TypeField::Table)
.andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable, scope).withBothComponent(TypePath::TypeField::Metatable)); .andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable, scope).withBothComponent(TypePath::TypeField::Metatable));
} }
@ -1389,6 +1471,19 @@ SubtypingResult Subtyping::isCovariantWith(
result.orElse( result.orElse(
isContravariantWith(env, subFunction->argTypes, superFunction->argTypes, scope).withBothComponent(TypePath::PackField::Arguments) isContravariantWith(env, subFunction->argTypes, superFunction->argTypes, scope).withBothComponent(TypePath::PackField::Arguments)
); );
// If subtyping failed in the argument packs, we should check if there's a hidden variadic tail and try ignoring it.
// This might cause subtyping correctly because the sub type here may not have a hidden variadic tail or equivalent.
if (!result.isSubtype)
{
auto [arguments, tail] = flatten(superFunction->argTypes);
if (auto variadic = get<VariadicTypePack>(tail); variadic && variadic->hidden)
{
result.orElse(isContravariantWith(env, subFunction->argTypes, arena->addTypePack(TypePack{arguments}), scope)
.withBothComponent(TypePath::PackField::Arguments));
}
}
} }
result.andAlso(isCovariantWith(env, subFunction->retTypes, superFunction->retTypes, scope).withBothComponent(TypePath::PackField::Returns)); result.andAlso(isCovariantWith(env, subFunction->retTypes, superFunction->retTypes, scope).withBothComponent(TypePath::PackField::Returns));
@ -1688,6 +1783,9 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe
if (!get<GenericType>(subTy)) if (!get<GenericType>(subTy))
return false; return false;
if (!env.mappedGenerics.find(subTy) && env.containsMappedType(subTy))
iceReporter->ice("attempting to modify bounds of a potentially visited generic");
env.mappedGenerics[subTy].upperBound.insert(superTy); env.mappedGenerics[subTy].upperBound.insert(superTy);
} }
else else
@ -1695,6 +1793,9 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe
if (!get<GenericType>(superTy)) if (!get<GenericType>(superTy))
return false; return false;
if (!env.mappedGenerics.find(superTy) && env.containsMappedType(superTy))
iceReporter->ice("attempting to modify bounds of a potentially visited generic");
env.mappedGenerics[superTy].lowerBound.insert(subTy); env.mappedGenerics[superTy].lowerBound.insert(subTy);
} }
@ -1740,7 +1841,7 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePac
if (!get<GenericTypePack>(subTp)) if (!get<GenericTypePack>(subTp))
return false; return false;
if (TypePackId* m = env.mappedGenericPacks.find(subTp)) if (TypePackId* m = env.getMappedPackBounds(subTp))
return *m == superTp; return *m == superTp;
env.mappedGenericPacks[subTp] = superTp; env.mappedGenericPacks[subTp] = superTp;
@ -1761,7 +1862,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse)
std::pair<TypeId, ErrorVec> Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull<Scope> scope) std::pair<TypeId, ErrorVec> Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull<Scope> scope)
{ {
TypeFunctionContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}}; TypeFunctionContext context{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}};
TypeId function = arena->addType(*functionInstance); TypeId function = arena->addType(*functionInstance);
FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true);
ErrorVec errors; ErrorVec errors;

View file

@ -4,6 +4,7 @@
#include "Luau/Common.h" #include "Luau/Common.h"
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauSymbolEquality)
namespace Luau namespace Luau
{ {
@ -14,7 +15,7 @@ bool Symbol::operator==(const Symbol& rhs) const
return local == rhs.local; return local == rhs.local;
else if (global.value) else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
else if (FFlag::LuauSolverV2) else if (FFlag::LuauSolverV2 || FFlag::LuauSymbolEquality)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else else
return false; return false;

View file

@ -6,19 +6,15 @@
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h" #include "Luau/Unifier2.h"
LUAU_FASTFLAGVARIABLE(LuauDontInPlaceMutateTableType)
LUAU_FASTFLAGVARIABLE(LuauAllowNonSharedTableTypesInLiteral)
namespace Luau namespace Luau
{ {
static bool isLiteral(const AstExpr* expr)
{
return (
expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() ||
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>()
);
}
// A fast approximation of subTy <: superTy // A fast approximation of subTy <: superTy
static bool fastIsSubtype(TypeId subTy, TypeId superTy) static bool fastIsSubtype(TypeId subTy, TypeId superTy)
{ {
@ -243,6 +239,8 @@ TypeId matchLiteralType(
return exprType; return exprType;
} }
DenseHashSet<AstExprConstantString*> keysToDelete{nullptr};
for (const AstExprTable::Item& item : exprTable->items) for (const AstExprTable::Item& item : exprTable->items)
{ {
if (isRecord(item)) if (isRecord(item))
@ -254,8 +252,19 @@ TypeId matchLiteralType(
Property& prop = it->second; Property& prop = it->second;
if (FFlag::LuauAllowNonSharedTableTypesInLiteral)
{
// If we encounter a duplcate property, we may have already
// set it to be read-only. If that's the case, the only thing
// that will definitely crash is trying to access a write
// only property.
LUAU_ASSERT(!prop.isWriteOnly());
}
else
{
// Table literals always initially result in shared read-write types // Table literals always initially result in shared read-write types
LUAU_ASSERT(prop.isShared()); LUAU_ASSERT(prop.isShared());
}
TypeId propTy = *prop.readTy; TypeId propTy = *prop.readTy;
auto it2 = expectedTableTy->props.find(keyStr); auto it2 = expectedTableTy->props.find(keyStr);
@ -287,6 +296,9 @@ TypeId matchLiteralType(
else else
tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType}; tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType};
if (FFlag::LuauDontInPlaceMutateTableType)
keysToDelete.insert(item.key->as<AstExprConstantString>());
else
tableTy->props.erase(keyStr); tableTy->props.erase(keyStr);
} }
@ -381,15 +393,11 @@ TypeId matchLiteralType(
const TypeId* keyTy = astTypes->find(item.key); const TypeId* keyTy = astTypes->find(item.key);
LUAU_ASSERT(keyTy); LUAU_ASSERT(keyTy);
TypeId tKey = follow(*keyTy); TypeId tKey = follow(*keyTy);
if (get<BlockedType>(tKey)) LUAU_ASSERT(!is<BlockedType>(tKey));
toBlock.push_back(tKey);
const TypeId* propTy = astTypes->find(item.value); const TypeId* propTy = astTypes->find(item.value);
LUAU_ASSERT(propTy); LUAU_ASSERT(propTy);
TypeId tProp = follow(*propTy); TypeId tProp = follow(*propTy);
if (get<BlockedType>(tProp)) LUAU_ASSERT(!is<BlockedType>(tProp));
toBlock.push_back(tProp);
// Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings) // Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings)
if (!item.key->as<AstExprConstantString>() && expectedTableTy->indexer) if (!item.key->as<AstExprConstantString>() && expectedTableTy->indexer)
(*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType;
@ -398,6 +406,16 @@ TypeId matchLiteralType(
LUAU_ASSERT(!"Unexpected"); LUAU_ASSERT(!"Unexpected");
} }
if (FFlag::LuauDontInPlaceMutateTableType)
{
for (const auto& key : keysToDelete)
{
const AstArray<char>& s = key->value;
std::string keyStr{s.data, s.data + s.size};
tableTy->props.erase(keyStr);
}
}
// Keys that the expectedType says we should have, but that aren't // Keys that the expectedType says we should have, but that aren't
// specified by the AST fragment. // specified by the AST fragment.
// //

View file

@ -269,6 +269,12 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if constexpr (std::is_same_v<T, NoRefineType>)
{
formatAppend(result, "NoRefineType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, UnknownType>) else if constexpr (std::is_same_v<T, UnknownType>)
{ {
formatAppend(result, "UnknownType %d", index); formatAppend(result, "UnknownType %d", index);
@ -414,7 +420,7 @@ void StateDot::visitChildren(TypePackId tp, int index)
finishNodeLabel(tp); finishNodeLabel(tp);
finishNode(); finishNode();
} }
else if (get<Unifiable::Error>(tp)) else if (get<ErrorTypePack>(tp))
{ {
formatAppend(result, "ErrorTypePack %d", index); formatAppend(result, "ErrorTypePack %d", index);
finishNodeLabel(tp); finishNodeLabel(tp);

View file

@ -20,6 +20,7 @@
#include <string> #include <string>
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauSyntheticErrors)
/* /*
* Enables increasing levels of verbosity for Luau type names when stringifying. * Enables increasing levels of verbosity for Luau type names when stringifying.
@ -38,7 +39,7 @@ LUAU_FASTFLAG(LuauSolverV2)
* 3: Suffix free/generic types with their scope pointer, if present. * 3: Suffix free/generic types with their scope pointer, if present.
*/ */
LUAU_FASTINTVARIABLE(DebugLuauVerboseTypeNames, 0) LUAU_FASTINTVARIABLE(DebugLuauVerboseTypeNames, 0)
LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort, false) LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort)
namespace Luau namespace Luau
{ {
@ -856,6 +857,11 @@ struct TypeStringifier
state.emit("any"); state.emit("any");
} }
void operator()(TypeId, const NoRefineType&)
{
state.emit("*no-refine*");
}
void operator()(TypeId, const UnionType& uv) void operator()(TypeId, const UnionType& uv)
{ {
if (state.hasSeen(&uv)) if (state.hasSeen(&uv))
@ -865,6 +871,8 @@ struct TypeStringifier
return; return;
} }
LUAU_ASSERT(uv.options.size() > 1);
bool optional = false; bool optional = false;
bool hasNonNilDisjunct = false; bool hasNonNilDisjunct = false;
@ -873,7 +881,7 @@ struct TypeStringifier
{ {
el = follow(el); el = follow(el);
if (isNil(el)) if (state.opts.useQuestionMarks && isNil(el))
{ {
optional = true; optional = true;
continue; continue;
@ -991,6 +999,14 @@ struct TypeStringifier
void operator()(TypeId, const ErrorType& tv) void operator()(TypeId, const ErrorType& tv)
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSyntheticErrors && tv.synthetic)
{
state.emit("*error-type<");
stringify(*tv.synthetic);
state.emit(">*");
}
else
state.emit("*error-type*"); state.emit("*error-type*");
} }
@ -1040,6 +1056,7 @@ struct TypeStringifier
state.emit(tfitv.userFuncName->value); state.emit(tfitv.userFuncName->value);
else else
state.emit(tfitv.function->name); state.emit(tfitv.function->name);
state.emit("<"); state.emit("<");
bool comma = false; bool comma = false;
@ -1165,9 +1182,17 @@ struct TypePackStringifier
state.unsee(&tp); state.unsee(&tp);
} }
void operator()(TypePackId, const Unifiable::Error& error) void operator()(TypePackId, const ErrorTypePack& error)
{ {
state.result.error = true; state.result.error = true;
if (FFlag::LuauSyntheticErrors && error.synthetic)
{
state.emit("*");
stringify(*error.synthetic);
state.emit("*");
}
else
state.emit("*error-type*"); state.emit("*error-type*");
} }
@ -1840,6 +1865,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
} }
else if constexpr (std::is_same_v<T, EqualityConstraint>) else if constexpr (std::is_same_v<T, EqualityConstraint>)
return "equality: " + tos(c.resultType) + " ~ " + tos(c.assignmentType); return "equality: " + tos(c.resultType) + " ~ " + tos(c.assignmentType);
else if constexpr (std::is_same_v<T, TableCheckConstraint>)
return "table_check " + tos(c.expectedType) + " :> " + tos(c.exprType);
else else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch"); static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
}; };

File diff suppressed because it is too large Load diff

View file

@ -93,8 +93,8 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull<TypeArena> arena)
if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead)
{ {
TypeId leftTy = arena->addType((*leftRep)->pending); TypeId leftTy = arena->addType((*leftRep)->pending.clone());
TypeId rightTy = arena->addType(rightRep->pending); TypeId rightTy = arena->addType(rightRep->pending.clone());
typeVarChanges[ty]->pending.ty = IntersectionType{{leftTy, rightTy}}; typeVarChanges[ty]->pending.ty = IntersectionType{{leftTy, rightTy}};
} }
else else
@ -170,8 +170,8 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull<TypeArena> arena)
if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead)
{ {
TypeId leftTy = arena->addType((*leftRep)->pending); TypeId leftTy = arena->addType((*leftRep)->pending.clone());
TypeId rightTy = arena->addType(rightRep->pending); TypeId rightTy = arena->addType(rightRep->pending.clone());
if (follow(leftTy) == follow(rightTy)) if (follow(leftTy) == follow(rightTy))
typeVarChanges[ty] = std::move(rightRep); typeVarChanges[ty] = std::move(rightRep);
@ -217,7 +217,7 @@ TxnLog TxnLog::inverse()
for (auto& [ty, _rep] : typeVarChanges) for (auto& [ty, _rep] : typeVarChanges)
{ {
if (!_rep->dead) if (!_rep->dead)
inversed.typeVarChanges[ty] = std::make_unique<PendingType>(*ty); inversed.typeVarChanges[ty] = std::make_unique<PendingType>(ty->clone());
} }
for (auto& [tp, _rep] : typePackChanges) for (auto& [tp, _rep] : typePackChanges)
@ -292,7 +292,7 @@ PendingType* TxnLog::queue(TypeId ty)
auto& pending = typeVarChanges[ty]; auto& pending = typeVarChanges[ty];
if (!pending || (*pending).dead) if (!pending || (*pending).dead)
{ {
pending = std::make_unique<PendingType>(*ty); pending = std::make_unique<PendingType>(ty->clone());
pending->pending.owningArena = nullptr; pending->pending.owningArena = nullptr;
} }

View file

@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauFreeTypesMustHaveBounds)
namespace Luau namespace Luau
{ {
@ -478,24 +479,12 @@ bool hasLength(TypeId ty, DenseHashSet<TypeId>& seen, int* recursionCount)
return false; return false;
} }
FreeType::FreeType(TypeLevel level) // New constructors
FreeType::FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound)
: index(Unifiable::freshIndex()) : index(Unifiable::freshIndex())
, level(level) , level(level)
, scope(nullptr) , lowerBound(lowerBound)
{ , upperBound(upperBound)
}
FreeType::FreeType(Scope* scope)
: index(Unifiable::freshIndex())
, level{}
, scope(scope)
{
}
FreeType::FreeType(Scope* scope, TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
{ {
} }
@ -507,6 +496,40 @@ FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound)
{ {
} }
FreeType::FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
, lowerBound(lowerBound)
, upperBound(upperBound)
{
}
// Old constructors
FreeType::FreeType(TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(nullptr)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
FreeType::FreeType(Scope* scope)
: index(Unifiable::freshIndex())
, level{}
, scope(scope)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
FreeType::FreeType(Scope* scope, TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
GenericType::GenericType() GenericType::GenericType()
: index(Unifiable::freshIndex()) : index(Unifiable::freshIndex())
, name("g" + std::to_string(index)) , name("g" + std::to_string(index))
@ -554,12 +577,12 @@ BlockedType::BlockedType()
{ {
} }
Constraint* BlockedType::getOwner() const const Constraint* BlockedType::getOwner() const
{ {
return owner; return owner;
} }
void BlockedType::setOwner(Constraint* newOwner) void BlockedType::setOwner(const Constraint* newOwner)
{ {
LUAU_ASSERT(owner == nullptr); LUAU_ASSERT(owner == nullptr);
@ -569,7 +592,7 @@ void BlockedType::setOwner(Constraint* newOwner)
owner = newOwner; owner = newOwner;
} }
void BlockedType::replaceOwner(Constraint* newOwner) void BlockedType::replaceOwner(const Constraint* newOwner)
{ {
owner = newOwner; owner = newOwner;
} }
@ -999,6 +1022,11 @@ Type& Type::operator=(const Type& rhs)
return *this; return *this;
} }
Type Type::clone() const
{
return *this;
}
TypeId makeFunction( TypeId makeFunction(
TypeArena& arena, TypeArena& arena,
std::optional<TypeId> selfType, std::optional<TypeId> selfType,
@ -1030,6 +1058,7 @@ BuiltinTypes::BuiltinTypes()
, unknownType(arena->addType(Type{UnknownType{}, /*persistent*/ true})) , unknownType(arena->addType(Type{UnknownType{}, /*persistent*/ true}))
, neverType(arena->addType(Type{NeverType{}, /*persistent*/ true})) , neverType(arena->addType(Type{NeverType{}, /*persistent*/ true}))
, errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true})) , errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true}))
, noRefineType(arena->addType(Type{NoRefineType{}, /*persistent*/ true}))
, falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true})) , falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true}))
, truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true}))
, optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true})) , optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true}))
@ -1039,7 +1068,7 @@ BuiltinTypes::BuiltinTypes()
, unknownTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{unknownType}, /*persistent*/ true})) , unknownTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{unknownType}, /*persistent*/ true}))
, neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true}))
, uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true}))
, errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) , errorTypePack(arena->addTypePack(TypePackVar{ErrorTypePack{}, /*persistent*/ true}))
{ {
freeze(*arena); freeze(*arena);
} }

View file

@ -2,7 +2,8 @@
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false); LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena);
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau namespace Luau
{ {
@ -22,7 +23,34 @@ TypeId TypeArena::addTV(Type&& tv)
return allocated; return allocated;
} }
TypeId TypeArena::freshType(TypeLevel level) TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{level, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, Scope* scope)
{
TypeId allocated = types.allocate(FreeType{scope, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, Scope* scope, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{scope, level, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType_DEPRECATED(TypeLevel level)
{ {
TypeId allocated = types.allocate(FreeType{level}); TypeId allocated = types.allocate(FreeType{level});
@ -31,7 +59,7 @@ TypeId TypeArena::freshType(TypeLevel level)
return allocated; return allocated;
} }
TypeId TypeArena::freshType(Scope* scope) TypeId TypeArena::freshType_DEPRECATED(Scope* scope)
{ {
TypeId allocated = types.allocate(FreeType{scope}); TypeId allocated = types.allocate(FreeType{scope});
@ -40,7 +68,7 @@ TypeId TypeArena::freshType(Scope* scope)
return allocated; return allocated;
} }
TypeId TypeArena::freshType(Scope* scope, TypeLevel level) TypeId TypeArena::freshType_DEPRECATED(Scope* scope, TypeLevel level)
{ {
TypeId allocated = types.allocate(FreeType{scope, level}); TypeId allocated = types.allocate(FreeType{scope, level});

View file

@ -145,6 +145,12 @@ public:
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"), std::nullopt, Location()); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"), std::nullopt, Location());
} }
AstType* operator()(const NoRefineType&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*no-refine*"), std::nullopt, Location());
}
AstType* operator()(const TableType& ttv) AstType* operator()(const TableType& ttv)
{ {
RecursionCounter counter(&count); RecursionCounter counter(&count);
@ -255,24 +261,24 @@ public:
if (hasSeen(&ftv)) if (hasSeen(&ftv))
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>"), std::nullopt, Location()); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>"), std::nullopt, Location());
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
generics.size = ftv.generics.size(); generics.size = ftv.generics.size();
generics.data = static_cast<AstGenericType*>(allocator->allocate(sizeof(AstGenericType) * generics.size)); generics.data = static_cast<AstGenericType**>(allocator->allocate(sizeof(AstGenericType) * generics.size));
size_t numGenerics = 0; size_t numGenerics = 0;
for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it)
{ {
if (auto gtv = get<GenericType>(*it)) if (auto gtv = get<GenericType>(*it))
generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; generics.data[numGenerics++] = allocator->alloc<AstGenericType>(Location(), AstName(gtv->name.c_str()), nullptr);
} }
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
genericPacks.size = ftv.genericPacks.size(); genericPacks.size = ftv.genericPacks.size();
genericPacks.data = static_cast<AstGenericTypePack*>(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); genericPacks.data = static_cast<AstGenericTypePack**>(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size));
size_t numGenericPacks = 0; size_t numGenericPacks = 0;
for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it)
{ {
if (auto gtv = get<GenericTypePack>(*it)) if (auto gtv = get<GenericTypePack>(*it))
genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; genericPacks.data[numGenericPacks++] = allocator->alloc<AstGenericTypePack>(Location(), AstName(gtv->name.c_str()), nullptr);
} }
AstArray<AstType*> argTypes; AstArray<AstType*> argTypes;
@ -323,7 +329,7 @@ public:
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation} Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}
); );
} }
AstType* operator()(const Unifiable::Error&) AstType* operator()(const ErrorType&)
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>"), std::nullopt, Location()); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>"), std::nullopt, Location());
} }
@ -380,8 +386,12 @@ public:
} }
AstType* operator()(const NegationType& ntv) AstType* operator()(const NegationType& ntv)
{ {
// FIXME: do the same thing we do with ErrorType AstArray<AstTypeOrPack> params;
throw InternalCompilerError("Cannot convert NegationType into AstNode"); params.size = 1;
params.data = static_cast<AstTypeOrPack*>(allocator->allocate(sizeof(AstType*)));
params.data[0] = AstTypeOrPack{Luau::visit(*this, ntv.ty->ty), nullptr};
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("negate"), std::nullopt, Location(), true, params);
} }
AstType* operator()(const TypeFunctionInstanceType& tfit) AstType* operator()(const TypeFunctionInstanceType& tfit)
{ {
@ -452,7 +462,7 @@ public:
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("free")); return allocator->alloc<AstTypePackGeneric>(Location(), AstName("free"));
} }
AstTypePack* operator()(const Unifiable::Error&) const AstTypePack* operator()(const ErrorTypePack&) const
{ {
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("Unifiable<Error>")); return allocator->alloc<AstTypePackGeneric>(Location(), AstName("Unifiable<Error>"));
} }

View file

@ -7,7 +7,6 @@
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Instantiation.h" #include "Luau/Instantiation.h"
#include "Luau/Metamethods.h" #include "Luau/Metamethods.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
@ -27,11 +26,11 @@
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
#include <algorithm> #include <algorithm>
#include <iostream>
#include <ostream>
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau namespace Luau
{ {
@ -173,7 +172,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor
DenseHashSet<TypeId> mentionedFunctions{nullptr}; DenseHashSet<TypeId> mentionedFunctions{nullptr};
DenseHashSet<TypePackId> mentionedFunctionPacks{nullptr}; DenseHashSet<TypePackId> mentionedFunctionPacks{nullptr};
InternalTypeFunctionFinder(std::vector<TypeId>& declStack) explicit InternalTypeFunctionFinder(std::vector<TypeId>& declStack)
{ {
TypeFunctionFinder f; TypeFunctionFinder f;
for (TypeId fn : declStack) for (TypeId fn : declStack)
@ -266,6 +265,8 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor
void check( void check(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
DcrLogger* logger, DcrLogger* logger,
@ -275,7 +276,7 @@ void check(
{ {
LUAU_TIMETRACE_SCOPE("check", "Typechecking"); LUAU_TIMETRACE_SCOPE("check", "Typechecking");
TypeChecker2 typeChecker{builtinTypes, unifierState, limits, logger, &sourceModule, module}; TypeChecker2 typeChecker{builtinTypes, simplifier, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module};
typeChecker.visit(sourceModule.root); typeChecker.visit(sourceModule.root);
@ -292,6 +293,8 @@ void check(
TypeChecker2::TypeChecker2( TypeChecker2::TypeChecker2(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
DcrLogger* logger, DcrLogger* logger,
@ -299,13 +302,15 @@ TypeChecker2::TypeChecker2(
Module* module Module* module
) )
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime)
, logger(logger) , logger(logger)
, limits(limits) , limits(limits)
, ice(unifierState->iceHandler) , ice(unifierState->iceHandler)
, sourceModule(sourceModule) , sourceModule(sourceModule)
, module(module) , module(module)
, normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true}
, _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{unifierState->iceHandler}} , _subtyping{builtinTypes, NotNull{&module->internalTypes}, simplifier, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}}
, subtyping(&_subtyping) , subtyping(&_subtyping)
{ {
} }
@ -483,10 +488,13 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l
return instance; return instance;
seenTypeFunctionInstances.insert(instance); seenTypeFunctionInstances.insert(instance);
ErrorVec errors = reduceTypeFunctions( ErrorVec errors =
reduceTypeFunctions(
instance, instance,
location, location,
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, TypeFunctionContext{
NotNull{&module->internalTypes}, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits
},
true true
) )
.errors; .errors;
@ -495,7 +503,7 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l
return instance; return instance;
} }
TypePackId TypeChecker2::lookupPack(AstExpr* expr) TypePackId TypeChecker2::lookupPack(AstExpr* expr) const
{ {
// If a type isn't in the type graph, it probably means that a recursion limit was exceeded. // If a type isn't in the type graph, it probably means that a recursion limit was exceeded.
// We'll just return anyType in these cases. Typechecking against any is very fast and this // We'll just return anyType in these cases. Typechecking against any is very fast and this
@ -545,7 +553,7 @@ TypeId TypeChecker2::lookupAnnotation(AstType* annotation)
return checkForTypeFunctionInhabitance(follow(*ty), annotation->location); return checkForTypeFunctionInhabitance(follow(*ty), annotation->location);
} }
std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) const
{ {
TypePackId* tp = module->astResolvedTypePacks.find(annotation); TypePackId* tp = module->astResolvedTypePacks.find(annotation);
if (tp != nullptr) if (tp != nullptr)
@ -553,7 +561,7 @@ std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annota
return {}; return {};
} }
TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) const
{ {
if (TypeId* ty = module->astExpectedTypes.find(expr)) if (TypeId* ty = module->astExpectedTypes.find(expr))
return follow(*ty); return follow(*ty);
@ -561,7 +569,7 @@ TypeId TypeChecker2::lookupExpectedType(AstExpr* expr)
return builtinTypes->anyType; return builtinTypes->anyType;
} }
TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) const
{ {
if (TypeId* ty = module->astExpectedTypes.find(expr)) if (TypeId* ty = module->astExpectedTypes.find(expr))
return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt}); return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt});
@ -585,7 +593,7 @@ TypePackId TypeChecker2::reconstructPack(AstArray<AstExpr*> exprs, TypeArena& ar
return arena.addTypePack(TypePack{head, tail}); return arena.addTypePack(TypePack{head, tail});
} }
Scope* TypeChecker2::findInnermostScope(Location location) Scope* TypeChecker2::findInnermostScope(Location location) const
{ {
Scope* bestScope = module->getModuleScope().get(); Scope* bestScope = module->getModuleScope().get();
@ -1008,7 +1016,8 @@ void TypeChecker2::visit(AstStatForIn* forInStatement)
{ {
reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location); reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location);
} }
else if (std::optional<TypeId> iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) else if (std::optional<TypeId> iterMmTy =
findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location))
{ {
Instantiation instantiation{TxnLog::empty(), &arena, builtinTypes, TypeLevel{}, scope}; Instantiation instantiation{TxnLog::empty(), &arena, builtinTypes, TypeLevel{}, scope};
@ -1193,8 +1202,6 @@ void TypeChecker2::visit(AstStatTypeAlias* stat)
void TypeChecker2::visit(AstStatTypeFunction* stat) void TypeChecker2::visit(AstStatTypeFunction* stat)
{ {
// TODO: add type checking for user-defined type functions // TODO: add type checking for user-defined type functions
reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}});
} }
void TypeChecker2::visit(AstTypeList types) void TypeChecker2::visit(AstTypeList types)
@ -1345,8 +1352,18 @@ void TypeChecker2::visit(AstExprGlobal* expr)
{ {
NotNull<Scope> scope = stack.back(); NotNull<Scope> scope = stack.back();
if (!scope->lookup(expr->name)) if (!scope->lookup(expr->name))
{
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
} }
else
{
if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value))
{
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
warnedGlobals.insert(expr->name.value);
}
}
}
void TypeChecker2::visit(AstExprVarargs* expr) void TypeChecker2::visit(AstExprVarargs* expr)
{ {
@ -1433,9 +1450,10 @@ void TypeChecker2::visitCall(AstExprCall* call)
TypePackId argsTp = module->internalTypes.addTypePack(args); TypePackId argsTp = module->internalTypes.addTypePack(args);
if (auto ftv = get<FunctionType>(follow(*originalCallTy))) if (auto ftv = get<FunctionType>(follow(*originalCallTy)))
{ {
if (ftv->dcrMagicTypeCheck) if (ftv->magic)
{ {
ftv->dcrMagicTypeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope}); bool usedMagic = ftv->magic->typeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope});
if (usedMagic)
return; return;
} }
} }
@ -1444,7 +1462,9 @@ void TypeChecker2::visitCall(AstExprCall* call)
OverloadResolver resolver{ OverloadResolver resolver{
builtinTypes, builtinTypes,
NotNull{&module->internalTypes}, NotNull{&module->internalTypes},
simplifier,
NotNull{&normalizer}, NotNull{&normalizer},
typeFunctionRuntime,
NotNull{stack.back()}, NotNull{stack.back()},
ice, ice,
limits, limits,
@ -1540,7 +1560,7 @@ void TypeChecker2::visit(AstExprCall* call)
visitCall(call); visitCall(call);
} }
std::optional<TypeId> TypeChecker2::tryStripUnionFromNil(TypeId ty) std::optional<TypeId> TypeChecker2::tryStripUnionFromNil(TypeId ty) const
{ {
if (const UnionType* utv = get<UnionType>(ty)) if (const UnionType* utv = get<UnionType>(ty))
{ {
@ -1618,8 +1638,7 @@ void TypeChecker2::indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const M
indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType); indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType);
else else
{ {
LUAU_ASSERT(tt || get<PrimitiveType>(follow(metaTable->table))); // CLI-122161: We're not handling unions correctly (probably).
reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location);
} }
} }
@ -1826,11 +1845,10 @@ void TypeChecker2::visit(AstExprFunction* fn)
void TypeChecker2::visit(AstExprTable* expr) void TypeChecker2::visit(AstExprTable* expr)
{ {
// TODO!
for (const AstExprTable::Item& item : expr->items) for (const AstExprTable::Item& item : expr->items)
{ {
if (item.key) if (item.key)
visit(item.key, ValueContext::LValue); visit(item.key, ValueContext::RValue);
visit(item.value, ValueContext::RValue); visit(item.value, ValueContext::RValue);
} }
} }
@ -2078,7 +2096,10 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey)
} }
else else
{ {
expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})}); expectedRets = module->internalTypes.addTypePack(
{FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, scope, TypeLevel{})
: module->internalTypes.freshType_DEPRECATED(scope, TypeLevel{})}
);
} }
TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets)); TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets));
@ -2330,7 +2351,8 @@ TypeId TypeChecker2::flattenPack(TypePackId pack)
return *fst; return *fst;
else if (auto ftp = get<FreeTypePack>(pack)) else if (auto ftp = get<FreeTypePack>(pack))
{ {
TypeId result = module->internalTypes.addType(FreeType{ftp->scope}); TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, ftp->scope)
: module->internalTypes.addType(FreeType{ftp->scope});
TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope});
TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack)); TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack));
@ -2339,7 +2361,7 @@ TypeId TypeChecker2::flattenPack(TypePackId pack)
return result; return result;
} }
else if (get<Unifiable::Error>(pack)) else if (get<ErrorTypePack>(pack))
return builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType();
else if (finite(pack) && size(pack) == 0) else if (finite(pack) && size(pack) == 0)
return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil` return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil`
@ -2347,30 +2369,30 @@ TypeId TypeChecker2::flattenPack(TypePackId pack)
ice->ice("flattenPack got a weird pack!"); ice->ice("flattenPack got a weird pack!");
} }
void TypeChecker2::visitGenerics(AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks) void TypeChecker2::visitGenerics(AstArray<AstGenericType*> generics, AstArray<AstGenericTypePack*> genericPacks)
{ {
DenseHashSet<AstName> seen{AstName{}}; DenseHashSet<AstName> seen{AstName{}};
for (const auto& g : generics) for (const auto* g : generics)
{ {
if (seen.contains(g.name)) if (seen.contains(g->name))
reportError(DuplicateGenericParameter{g.name.value}, g.location); reportError(DuplicateGenericParameter{g->name.value}, g->location);
else else
seen.insert(g.name); seen.insert(g->name);
if (g.defaultValue) if (g->defaultValue)
visit(g.defaultValue); visit(g->defaultValue);
} }
for (const auto& g : genericPacks) for (const auto* g : genericPacks)
{ {
if (seen.contains(g.name)) if (seen.contains(g->name))
reportError(DuplicateGenericParameter{g.name.value}, g.location); reportError(DuplicateGenericParameter{g->name.value}, g->location);
else else
seen.insert(g.name); seen.insert(g->name);
if (g.defaultValue) if (g->defaultValue)
visit(g.defaultValue); visit(g->defaultValue);
} }
} }
@ -2392,6 +2414,8 @@ void TypeChecker2::visit(AstType* ty)
return visit(t); return visit(t);
else if (auto t = ty->as<AstTypeIntersection>()) else if (auto t = ty->as<AstTypeIntersection>())
return visit(t); return visit(t);
else if (auto t = ty->as<AstTypeGroup>())
return visit(t->type);
} }
void TypeChecker2::visit(AstTypeReference* ty) void TypeChecker2::visit(AstTypeReference* ty)
@ -3012,10 +3036,8 @@ PropertyType TypeChecker2::hasIndexTypeFromType(
if (tt->indexer) if (tt->indexer)
{ {
TypeId indexType = follow(tt->indexer->indexType); TypeId indexType = follow(tt->indexer->indexType);
if (isPrim(indexType, PrimitiveType::String)) TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}});
return {NormalizationResult::True, {tt->indexer->indexResultType}}; if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, simplifier, *ice))
// If the indexer looks like { [any] : _} - the prop lookup should be allowed!
else if (get<AnyType>(indexType) || get<UnknownType>(indexType))
return {NormalizationResult::True, {tt->indexer->indexResultType}}; return {NormalizationResult::True, {tt->indexer->indexResultType}};
} }

File diff suppressed because it is too large Load diff

View file

@ -3,6 +3,7 @@
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/TypeFunction.h" #include "Luau/TypeFunction.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -23,18 +23,18 @@
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes)
LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165)
LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000)
LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000)
LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300)
LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) LUAU_FASTFLAGVARIABLE(LuauOldSolverCreatesChildScopePointers)
LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections, false) LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau namespace Luau
{ {
@ -265,11 +265,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr parentScope = environmentScope.value_or(globalScope);
ScopePtr moduleScope = std::make_shared<Scope>(parentScope); ScopePtr moduleScope = std::make_shared<Scope>(parentScope);
if (module.cyclic)
moduleScope->returnType = addTypePack(TypePack{{anyType}, std::nullopt});
else
moduleScope->returnType = freshTypePack(moduleScope); moduleScope->returnType = freshTypePack(moduleScope);
moduleScope->varargPack = anyTypePack; moduleScope->varargPack = anyTypePack;
currentModule->scopes.push_back(std::make_pair(module.root->location, moduleScope)); currentModule->scopes.push_back(std::make_pair(module.root->location, moduleScope));
@ -767,8 +763,12 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& state
struct Demoter : Substitution struct Demoter : Substitution
{ {
Demoter(TypeArena* arena) TypeArena* arena = nullptr;
NotNull<BuiltinTypes> builtins;
Demoter(TypeArena* arena, NotNull<BuiltinTypes> builtins)
: Substitution(TxnLog::empty(), arena) : Substitution(TxnLog::empty(), arena)
, arena(arena)
, builtins(builtins)
{ {
} }
@ -794,7 +794,8 @@ struct Demoter : Substitution
{ {
auto ftv = get<FreeType>(ty); auto ftv = get<FreeType>(ty);
LUAU_ASSERT(ftv); LUAU_ASSERT(ftv);
return addType(FreeType{demotedLevel(ftv->level)}); return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtins, demotedLevel(ftv->level))
: addType(FreeType{demotedLevel(ftv->level)});
} }
TypePackId clean(TypePackId tp) override TypePackId clean(TypePackId tp) override
@ -841,7 +842,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& retur
} }
} }
Demoter demoter{&currentModule->internalTypes}; Demoter demoter{&currentModule->internalTypes, builtinTypes};
demoter.demote(expectedTypes); demoter.demote(expectedTypes);
TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type;
@ -958,7 +959,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assig
else if (auto tail = valueIter.tail()) else if (auto tail = valueIter.tail())
{ {
TypePackId tailPack = follow(*tail); TypePackId tailPack = follow(*tail);
if (get<Unifiable::Error>(tailPack)) if (get<ErrorTypePack>(tailPack))
right = errorRecoveryType(scope); right = errorRecoveryType(scope);
else if (auto vtp = get<VariadicTypePack>(tailPack)) else if (auto vtp = get<VariadicTypePack>(tailPack))
right = vtp->ty; right = vtp->ty;
@ -1238,7 +1239,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
iterTy = freshType(scope); iterTy = freshType(scope);
unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location); unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location);
} }
else if (get<Unifiable::Error>(callRetPack) || !first(callRetPack)) else if (get<ErrorTypePack>(callRetPack) || !first(callRetPack))
{ {
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(errorRecoveryType(scope), var, scope, forin.location); unify(errorRecoveryType(scope), var, scope, forin.location);
@ -1284,20 +1285,11 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
for (size_t i = 2; i < varTypes.size(); ++i) for (size_t i = 2; i < varTypes.size(); ++i)
unify(nilType, varTypes[i], scope, forin.location); unify(nilType, varTypes[i], scope, forin.location);
} }
else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties) else
{ {
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(unknownType, var, scope, forin.location); unify(unknownType, var, scope, forin.location);
} }
else
{
TypeId varTy = errorRecoveryType(loopScope);
for (TypeId var : varTypes)
unify(varTy, var, scope, forin.location);
reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"});
}
return check(loopScope, *forin.body); return check(loopScope, *forin.body);
} }
@ -1975,7 +1967,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
*asMutable(varargPack) = TypePack{{head}, tail}; *asMutable(varargPack) = TypePack{{head}, tail};
return WithPredicate{head}; return WithPredicate{head};
} }
if (get<ErrorType>(varargPack)) if (get<ErrorTypePack>(varargPack))
return WithPredicate{errorRecoveryType(scope)}; return WithPredicate{errorRecoveryType(scope)};
else if (auto vtp = get<VariadicTypePack>(varargPack)) else if (auto vtp = get<VariadicTypePack>(varargPack))
return WithPredicate{vtp->ty}; return WithPredicate{vtp->ty};
@ -2005,7 +1997,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
unify(pack, retPack, scope, expr.location); unify(pack, retPack, scope, expr.location);
return {head, std::move(result.predicates)}; return {head, std::move(result.predicates)};
} }
if (get<Unifiable::Error>(retPack)) if (get<ErrorTypePack>(retPack))
return {errorRecoveryType(scope), std::move(result.predicates)}; return {errorRecoveryType(scope), std::move(result.predicates)};
else if (auto vtp = get<VariadicTypePack>(retPack)) else if (auto vtp = get<VariadicTypePack>(retPack))
return {vtp->ty, std::move(result.predicates)}; return {vtp->ty, std::move(result.predicates)};
@ -2804,8 +2796,6 @@ TypeId TypeChecker::checkRelationalOperation(
{ {
reportErrors(state.errors); reportErrors(state.errors);
if (FFlag::LuauRemoveBadRelationalOperatorWarning)
{
// The original version of this check also produced this error when we had a union type. // The original version of this check also produced this error when we had a union type.
// However, the old solver does not readily have the ability to discern if the union is comparable. // However, the old solver does not readily have the ability to discern if the union is comparable.
// This is the case when the lhs is e.g. a union of singletons and the rhs is the combined type. // This is the case when the lhs is e.g. a union of singletons and the rhs is the combined type.
@ -2820,19 +2810,6 @@ TypeId TypeChecker::checkRelationalOperation(
} }
); );
} }
}
else
{
if (!isEquality && state.errors.empty() && (get<UnionType>(leftType) || isBoolean(leftType)))
{
reportError(
expr.location,
GenericError{
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
}
}
return booleanType; return booleanType;
} }
@ -2896,7 +2873,7 @@ TypeId TypeChecker::checkRelationalOperation(
std::optional<TypeId> metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); std::optional<TypeId> metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true);
if (metamethod) if (metamethod)
{ {
if (const FunctionType* ftv = get<FunctionType>(*metamethod)) if (const FunctionType* ftv = get<FunctionType>(follow(*metamethod)))
{ {
if (isEquality) if (isEquality)
{ {
@ -3507,7 +3484,6 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
} }
} }
if (FFlag::LuauAcceptIndexingTableUnionsIntersections)
{ {
// We're going to have a whole vector. // We're going to have a whole vector.
std::vector<TableType*> tableTypes{}; std::vector<TableType*> tableTypes{};
@ -3658,57 +3634,6 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
return addType(IntersectionType{{resultTypes.begin(), resultTypes.end()}}); return addType(IntersectionType{{resultTypes.begin(), resultTypes.end()}});
} }
else
{
TableType* exprTable = getMutableTableType(exprType);
if (!exprTable)
{
reportError(TypeError{expr.expr->location, NotATable{exprType}});
return errorRecoveryType(scope);
}
if (value)
{
const auto& it = exprTable->props.find(value->value.data);
if (it != exprTable->props.end())
{
return it->second.type();
}
else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)
{
TypeId resultType = freshType(scope);
Property& property = exprTable->props[value->value.data];
property.setType(resultType);
property.location = expr.index->location;
return resultType;
}
}
if (exprTable->indexer)
{
const TableIndexer& indexer = *exprTable->indexer;
unify(indexType, indexer.indexType, scope, expr.index->location);
return indexer.indexResultType;
}
else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)
{
TypeId indexerType = freshType(exprTable->level);
unify(indexType, indexerType, scope, expr.location);
TypeId indexResultType = freshType(exprTable->level);
exprTable->indexer = TableIndexer{anyIfNonstrict(indexerType), anyIfNonstrict(indexResultType)};
return indexResultType;
}
else
{
/*
* If we use [] indexing to fetch a property from a sealed table that
* has no indexer, we have no idea if it will work so we just return any
* and hope for the best.
*/
return anyType;
}
}
} }
// Answers the question: "Can I define another function with this name?" // Answers the question: "Can I define another function with this name?"
@ -4163,7 +4088,7 @@ void TypeChecker::checkArgumentList(
if (argIter.tail()) if (argIter.tail())
{ {
TypePackId tail = *argIter.tail(); TypePackId tail = *argIter.tail();
if (state.log.getMutable<Unifiable::Error>(tail)) if (state.log.getMutable<ErrorTypePack>(tail))
{ {
// Unify remaining parameters so we don't leave any free-types hanging around. // Unify remaining parameters so we don't leave any free-types hanging around.
while (paramIter != endIter) while (paramIter != endIter)
@ -4248,7 +4173,7 @@ void TypeChecker::checkArgumentList(
} }
TypePackId tail = state.log.follow(*paramIter.tail()); TypePackId tail = state.log.follow(*paramIter.tail());
if (state.log.getMutable<Unifiable::Error>(tail)) if (state.log.getMutable<ErrorTypePack>(tail))
{ {
// Function is variadic. Ok. // Function is variadic. Ok.
return; return;
@ -4384,7 +4309,7 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
WithPredicate<TypePackId> argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); WithPredicate<TypePackId> argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes);
TypePackId argPack = argListResult.type; TypePackId argPack = argListResult.type;
if (get<Unifiable::Error>(argPack)) if (get<ErrorTypePack>(argPack))
return WithPredicate{errorRecoveryTypePack(scope)}; return WithPredicate{errorRecoveryTypePack(scope)};
TypePack* args = nullptr; TypePack* args = nullptr;
@ -4490,7 +4415,7 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
} }
} }
Demoter demoter{&currentModule->internalTypes}; Demoter demoter{&currentModule->internalTypes, builtinTypes};
demoter.demote(expectedTypes); demoter.demote(expectedTypes);
return expectedTypes; return expectedTypes;
@ -4588,10 +4513,10 @@ std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(
// When this function type has magic functions and did return something, we select that overload instead. // When this function type has magic functions and did return something, we select that overload instead.
// TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution. // TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution.
if (ftv->magicFunction) if (ftv->magic)
{ {
// TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magicFunction(*this, scope, expr, argListResult)) if (std::optional<WithPredicate<TypePackId>> ret = ftv->magic->handleOldSolver(*this, scope, expr, argListResult))
return std::make_unique<WithPredicate<TypePackId>>(std::move(*ret)); return std::make_unique<WithPredicate<TypePackId>>(std::move(*ret));
} }
@ -4974,7 +4899,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
TypePackId modulePack = module->returnType; TypePackId modulePack = module->returnType;
if (get<Unifiable::Error>(modulePack)) if (get<ErrorTypePack>(modulePack))
return errorRecoveryType(scope); return errorRecoveryType(scope);
std::optional<TypeId> moduleType = first(modulePack); std::optional<TypeId> moduleType = first(modulePack);
@ -5063,17 +4988,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c
{ {
// First try unifying with the original uninstantiated type // First try unifying with the original uninstantiated type
// but if that fails, try the instantiated one. // but if that fails, try the instantiated one.
Unifier child = state.makeChildUnifier(); std::unique_ptr<Unifier> child = state.makeChildUnifier();
child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); child->tryUnify(subTy, superTy, /*isFunctionCall*/ false);
if (!child.errors.empty()) if (!child->errors.empty())
{ {
TypeId instantiated = instantiate(scope, subTy, state.location, &child.log); TypeId instantiated = instantiate(scope, subTy, state.location, &child->log);
if (subTy == instantiated) if (subTy == instantiated)
{ {
// Instantiating the argument made no difference, so just report any child errors // Instantiating the argument made no difference, so just report any child errors
state.log.concat(std::move(child.log)); state.log.concat(std::move(child->log));
state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); state.errors.insert(state.errors.end(), child->errors.begin(), child->errors.end());
} }
else else
{ {
@ -5082,7 +5007,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c
} }
else else
{ {
state.log.concat(std::move(child.log)); state.log.concat(std::move(child->log));
} }
} }
} }
@ -5287,6 +5212,13 @@ LUAU_NOINLINE void TypeChecker::reportErrorCodeTooComplex(const Location& locati
ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel) ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel)
{ {
ScopePtr scope = std::make_shared<Scope>(parent, subLevel); ScopePtr scope = std::make_shared<Scope>(parent, subLevel);
if (FFlag::LuauOldSolverCreatesChildScopePointers)
{
scope->location = location;
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope)); currentModule->scopes.push_back(std::make_pair(location, scope));
return scope; return scope;
} }
@ -5297,6 +5229,12 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio
ScopePtr scope = std::make_shared<Scope>(parent); ScopePtr scope = std::make_shared<Scope>(parent);
scope->level = parent->level; scope->level = parent->level;
scope->varargPack = parent->varargPack; scope->varargPack = parent->varargPack;
if (FFlag::LuauOldSolverCreatesChildScopePointers)
{
scope->location = location;
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope)); currentModule->scopes.push_back(std::make_pair(location, scope));
return scope; return scope;
@ -5342,7 +5280,8 @@ TypeId TypeChecker::freshType(const ScopePtr& scope)
TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::freshType(TypeLevel level)
{ {
return currentModule->internalTypes.addType(Type(FreeType(level))); return FFlag::LuauFreeTypesMustHaveBounds ? currentModule->internalTypes.freshType(builtinTypes, level)
: currentModule->internalTypes.addType(Type(FreeType(level)));
} }
TypeId TypeChecker::singletonType(bool value) TypeId TypeChecker::singletonType(bool value)
@ -5787,6 +5726,12 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
} }
else if (const auto& un = annotation.as<AstTypeUnion>()) else if (const auto& un = annotation.as<AstTypeUnion>())
{ {
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (un->types.size == 1)
return resolveType(scope, *un->types.data[0]);
}
std::vector<TypeId> types; std::vector<TypeId> types;
for (AstType* ann : un->types) for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann)); types.push_back(resolveType(scope, *ann));
@ -5795,12 +5740,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
} }
else if (const auto& un = annotation.as<AstTypeIntersection>()) else if (const auto& un = annotation.as<AstTypeIntersection>())
{ {
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (un->types.size == 1)
return resolveType(scope, *un->types.data[0]);
}
std::vector<TypeId> types; std::vector<TypeId> types;
for (AstType* ann : un->types) for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann)); types.push_back(resolveType(scope, *ann));
return addType(IntersectionType{types}); return addType(IntersectionType{types});
} }
else if (const auto& g = annotation.as<AstTypeGroup>())
{
return resolveType(scope, *g->type);
}
else if (const auto& tsb = annotation.as<AstTypeSingletonBool>()) else if (const auto& tsb = annotation.as<AstTypeSingletonBool>())
{ {
return singletonType(tsb->value); return singletonType(tsb->value);
@ -5958,8 +5913,8 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(
const ScopePtr& scope, const ScopePtr& scope,
std::optional<TypeLevel> levelOpt, std::optional<TypeLevel> levelOpt,
const AstNode& node, const AstNode& node,
const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericType*>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames, const AstArray<AstGenericTypePack*>& genericPackNames,
bool useCache bool useCache
) )
{ {
@ -5969,14 +5924,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(
std::vector<GenericTypeDefinition> generics; std::vector<GenericTypeDefinition> generics;
for (const AstGenericType& generic : genericNames) for (const AstGenericType* generic : genericNames)
{ {
std::optional<TypeId> defaultValue; std::optional<TypeId> defaultValue;
if (generic.defaultValue) if (generic->defaultValue)
defaultValue = resolveType(scope, *generic.defaultValue); defaultValue = resolveType(scope, *generic->defaultValue);
Name n = generic.name.value; Name n = generic->name.value;
// These generics are the only thing that will ever be added to scope, so we can be certain that // These generics are the only thing that will ever be added to scope, so we can be certain that
// a collision can only occur when two generic types have the same name. // a collision can only occur when two generic types have the same name.
@ -6005,14 +5960,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(
std::vector<GenericTypePackDefinition> genericPacks; std::vector<GenericTypePackDefinition> genericPacks;
for (const AstGenericTypePack& genericPack : genericPackNames) for (const AstGenericTypePack* genericPack : genericPackNames)
{ {
std::optional<TypePackId> defaultValue; std::optional<TypePackId> defaultValue;
if (genericPack.defaultValue) if (genericPack->defaultValue)
defaultValue = resolveTypePack(scope, *genericPack.defaultValue); defaultValue = resolveTypePack(scope, *genericPack->defaultValue);
Name n = genericPack.name.value; Name n = genericPack->name.value;
// These generics are the only thing that will ever be added to scope, so we can be certain that // These generics are the only thing that will ever be added to scope, so we can be certain that
// a collision can only occur when two generic types have the same name. // a collision can only occur when two generic types have the same name.
@ -6418,7 +6373,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
} }
// We're only interested in the root class of any classes. // We're only interested in the root class of any classes.
if (auto ctv = get<ClassType>(type); !ctv || ctv->parent != builtinTypes->classType) if (auto ctv = get<ClassType>(type); !ctv || (ctv->parent != builtinTypes->classType && !hasTag(type, kTypeofRootTag)))
return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
// This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA.

View file

@ -14,7 +14,7 @@
#include <type_traits> #include <type_traits>
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAGVARIABLE(LuauDisableNewSolverAssertsInMixedMode);
// Maximum number of steps to follow when traversing a path. May not always // Maximum number of steps to follow when traversing a path. May not always
// equate to the number of components in a path, depending on the traversal // equate to the number of components in a path, depending on the traversal
// logic. // logic.
@ -156,6 +156,7 @@ Path PathBuilder::build()
PathBuilder& PathBuilder::readProp(std::string name) PathBuilder& PathBuilder::readProp(std::string name)
{ {
if (!FFlag::LuauDisableNewSolverAssertsInMixedMode)
LUAU_ASSERT(FFlag::LuauSolverV2); LUAU_ASSERT(FFlag::LuauSolverV2);
components.push_back(Property{std::move(name), true}); components.push_back(Property{std::move(name), true});
return *this; return *this;
@ -163,6 +164,7 @@ PathBuilder& PathBuilder::readProp(std::string name)
PathBuilder& PathBuilder::writeProp(std::string name) PathBuilder& PathBuilder::writeProp(std::string name)
{ {
if (!FFlag::LuauDisableNewSolverAssertsInMixedMode)
LUAU_ASSERT(FFlag::LuauSolverV2); LUAU_ASSERT(FFlag::LuauSolverV2);
components.push_back(Property{std::move(name), false}); components.push_back(Property{std::move(name), false});
return *this; return *this;
@ -415,6 +417,14 @@ struct TraversalState
switch (field) switch (field)
{ {
case TypePath::TypeField::Table:
if (auto mt = get<MetatableType>(current))
{
updateCurrent(mt->table);
return true;
}
return false;
case TypePath::TypeField::Metatable: case TypePath::TypeField::Metatable:
if (auto currentType = get<TypeId>(current)) if (auto currentType = get<TypeId>(current))
{ {
@ -561,6 +571,9 @@ std::string toString(const TypePath::Path& path, bool prefixDot)
switch (c) switch (c)
{ {
case TypePath::TypeField::Table:
result << "table";
break;
case TypePath::TypeField::Metatable: case TypePath::TypeField::Metatable:
result << "metatable"; result << "metatable";
break; break;

View file

@ -5,12 +5,16 @@
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include <algorithm> #include <algorithm>
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete);
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope);
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
LUAU_FASTFLAG(LuauDisableNewSolverAssertsInMixedMode)
namespace Luau namespace Luau
{ {
@ -317,9 +321,11 @@ TypePack extendTypePack(
{ {
FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType}; FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType};
t = arena.addType(ft); t = arena.addType(ft);
if (FFlag::LuauTrackInteriorFreeTypesOnScope)
trackInteriorFreeType(ftp->scope, t);
} }
else else
t = arena.freshType(ftp->scope); t = FFlag::LuauFreeTypesMustHaveBounds ? arena.freshType(builtinTypes, ftp->scope) : arena.freshType_DEPRECATED(ftp->scope);
} }
newPack.head.push_back(t); newPack.head.push_back(t);
@ -331,7 +337,7 @@ TypePack extendTypePack(
return result; return result;
} }
else if (const Unifiable::Error* etp = getMutable<Unifiable::Error>(pack)) else if (auto etp = getMutable<ErrorTypePack>(pack))
{ {
while (result.head.size() < length) while (result.head.size() < length)
result.head.push_back(builtinTypes->errorRecoveryType()); result.head.push_back(builtinTypes->errorRecoveryType());
@ -426,7 +432,7 @@ TypeId stripNil(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId ty)
ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypeId ty) ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypeId ty)
{ {
LUAU_ASSERT(FFlag::LuauSolverV2); LUAU_ASSERT(FFlag::LuauSolverV2 || FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
std::shared_ptr<const NormalizedType> normType = normalizer->normalize(ty); std::shared_ptr<const NormalizedType> normType = normalizer->normalize(ty);
if (!normType) if (!normType)
@ -479,4 +485,87 @@ ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypePackId
return result; return result;
} }
bool isLiteral(const AstExpr* expr)
{
return (
expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() ||
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>()
);
}
/**
* Visitor which, given an expression and a mapping from expression to TypeId,
* determines if there are any literal expressions that contain blocked types.
* This is used for bi-directional inference: we want to "apply" a type from
* a function argument or a type annotation to a literal.
*/
class BlockedTypeInLiteralVisitor : public AstVisitor
{
public:
explicit BlockedTypeInLiteralVisitor(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes, NotNull<std::vector<TypeId>> toBlock)
: astTypes_{astTypes}
, toBlock_{toBlock}
{
}
bool visit(AstNode*) override
{
return false;
}
bool visit(AstExpr* e) override
{
auto ty = astTypes_->find(e);
if (ty && (get<BlockedType>(follow(*ty)) != nullptr))
{
toBlock_->push_back(*ty);
}
return isLiteral(e) || e->is<AstExprGroup>();
}
private:
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes_;
NotNull<std::vector<TypeId>> toBlock_;
};
std::vector<TypeId> findBlockedTypesIn(AstExprTable* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes)
{
std::vector<TypeId> toBlock;
BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}};
expr->visit(&v);
return toBlock;
}
std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes)
{
std::vector<TypeId> toBlock;
BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}};
for (auto arg : expr->args)
{
if (isLiteral(arg) || arg->is<AstExprGroup>())
{
arg->visit(&v);
}
}
return toBlock;
}
void trackInteriorFreeType(Scope* scope, TypeId ty)
{
if (FFlag::LuauDisableNewSolverAssertsInMixedMode)
LUAU_ASSERT(FFlag::LuauTrackInteriorFreeTypesOnScope);
else
LUAU_ASSERT(FFlag::LuauSolverV2 && FFlag::LuauTrackInteriorFreeTypesOnScope);
for (; scope; scope = scope->parent.get())
{
if (scope->interiorFreeTypes)
{
scope->interiorFreeTypes->push_back(ty);
return;
}
}
// There should at least be *one* generalization constraint per module
// where `interiorFreeTypes` is present, which would be the one made
// by ConstraintGenerator::visitModuleRoot.
LUAU_ASSERT(!"No scopes in parent chain had a present `interiorFreeTypes` member.");
}
} // namespace Luau } // namespace Luau

View file

@ -1,5 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Unifiable.h" #include "Luau/Unifiable.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypePack.h"
namespace Luau namespace Luau
{ {
@ -13,12 +15,17 @@ int freshIndex()
return ++nextIndex; return ++nextIndex;
} }
Error::Error() template<typename Id>
Error<Id>::Error()
: index(++nextIndex) : index(++nextIndex)
{ {
} }
int Error::nextIndex = 0; template<typename Id>
int Error<Id>::nextIndex = 0;
template struct Error<TypeId>;
template struct Error<TypePackId>;
} // namespace Unifiable } // namespace Unifiable
} // namespace Luau } // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -908,7 +908,7 @@ OccursCheckResult Unifier2::occursCheck(DenseHashSet<TypePackId>& seen, TypePack
RecursionLimiter _ra(&recursionCount, recursionLimit); RecursionLimiter _ra(&recursionCount, recursionLimit);
while (!getMutable<Unifiable::Error>(haystack)) while (!getMutable<ErrorTypePack>(haystack))
{ {
if (needle == haystack) if (needle == haystack)
return OccursCheckResult::Fail; return OccursCheckResult::Fail;

View file

@ -0,0 +1,48 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Location.h"
#include "Luau/DenseHash.h"
#include "Luau/Common.h"
#include <vector>
namespace Luau
{
class Allocator
{
public:
Allocator();
Allocator(Allocator&&);
Allocator& operator=(Allocator&&) = delete;
~Allocator();
void* allocate(size_t size);
template<typename T, typename... Args>
T* alloc(Args&&... args)
{
static_assert(std::is_trivially_destructible<T>::value, "Objects allocated with this allocator will never have their destructors run!");
T* t = static_cast<T*>(allocate(sizeof(T)));
new (t) T(std::forward<Args>(args)...);
return t;
}
private:
struct Page
{
Page* next;
alignas(8) char data[8192];
};
Page* root;
size_t offset;
};
} // namespace Luau

View file

@ -120,20 +120,6 @@ struct AstTypeList
using AstArgumentName = std::pair<AstName, Location>; // TODO: remove and replace when we get a common struct for this pair instead of AstName using AstArgumentName = std::pair<AstName, Location>; // TODO: remove and replace when we get a common struct for this pair instead of AstName
struct AstGenericType
{
AstName name;
Location location;
AstType* defaultValue = nullptr;
};
struct AstGenericTypePack
{
AstName name;
Location location;
AstTypePack* defaultValue = nullptr;
};
extern int gAstRttiIndex; extern int gAstRttiIndex;
template<typename T> template<typename T>
@ -253,6 +239,32 @@ public:
bool hasSemicolon; bool hasSemicolon;
}; };
class AstGenericType : public AstNode
{
public:
LUAU_RTTI(AstGenericType)
explicit AstGenericType(const Location& location, AstName name, AstType* defaultValue = nullptr);
void visit(AstVisitor* visitor) override;
AstName name;
AstType* defaultValue = nullptr;
};
class AstGenericTypePack : public AstNode
{
public:
LUAU_RTTI(AstGenericTypePack)
explicit AstGenericTypePack(const Location& location, AstName name, AstTypePack* defaultValue = nullptr);
void visit(AstVisitor* visitor) override;
AstName name;
AstTypePack* defaultValue = nullptr;
};
class AstExprGroup : public AstExpr class AstExprGroup : public AstExpr
{ {
public: public:
@ -316,16 +328,18 @@ public:
enum QuoteStyle enum QuoteStyle
{ {
Quoted, QuotedSimple,
QuotedRaw,
Unquoted Unquoted
}; };
AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle = Quoted); AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
bool isQuoted() const;
AstArray<char> value; AstArray<char> value;
QuoteStyle quoteStyle = Quoted; QuoteStyle quoteStyle;
}; };
class AstExprLocal : public AstExpr class AstExprLocal : public AstExpr
@ -422,8 +436,8 @@ public:
AstExprFunction( AstExprFunction(
const Location& location, const Location& location,
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
AstLocal* self, AstLocal* self,
const AstArray<AstLocal*>& args, const AstArray<AstLocal*>& args,
bool vararg, bool vararg,
@ -441,8 +455,8 @@ public:
bool hasNativeAttribute() const; bool hasNativeAttribute() const;
AstArray<AstAttr*> attributes; AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
AstLocal* self; AstLocal* self;
AstArray<AstLocal*> args; AstArray<AstLocal*> args;
std::optional<AstTypeList> returnAnnotation; std::optional<AstTypeList> returnAnnotation;
@ -855,8 +869,8 @@ public:
const Location& location, const Location& location,
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
AstType* type, AstType* type,
bool exported bool exported
); );
@ -865,8 +879,8 @@ public:
AstName name; AstName name;
Location nameLocation; Location nameLocation;
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
AstType* type; AstType* type;
bool exported; bool exported;
}; };
@ -876,13 +890,14 @@ class AstStatTypeFunction : public AstStat
public: public:
LUAU_RTTI(AstStatTypeFunction); LUAU_RTTI(AstStatTypeFunction);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body); AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body, bool exported);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
AstName name; AstName name;
Location nameLocation; Location nameLocation;
AstExprFunction* body; AstExprFunction* body;
bool exported;
}; };
class AstStatDeclareGlobal : public AstStat class AstStatDeclareGlobal : public AstStat
@ -908,8 +923,8 @@ public:
const Location& location, const Location& location,
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& params, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, const AstArray<AstArgumentName>& paramNames,
bool vararg, bool vararg,
@ -922,8 +937,8 @@ public:
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& params, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, const AstArray<AstArgumentName>& paramNames,
bool vararg, bool vararg,
@ -939,8 +954,8 @@ public:
AstArray<AstAttr*> attributes; AstArray<AstAttr*> attributes;
AstName name; AstName name;
Location nameLocation; Location nameLocation;
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
AstTypeList params; AstTypeList params;
AstArray<AstArgumentName> paramNames; AstArray<AstArgumentName> paramNames;
bool vararg = false; bool vararg = false;
@ -1071,8 +1086,8 @@ public:
AstTypeFunction( AstTypeFunction(
const Location& location, const Location& location,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& argTypes, const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes const AstTypeList& returnTypes
@ -1081,8 +1096,8 @@ public:
AstTypeFunction( AstTypeFunction(
const Location& location, const Location& location,
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& argTypes, const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes const AstTypeList& returnTypes
@ -1093,8 +1108,8 @@ public:
bool isCheckedFunction() const; bool isCheckedFunction() const;
AstArray<AstAttr*> attributes; AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
AstTypeList argTypes; AstTypeList argTypes;
AstArray<std::optional<AstArgumentName>> argNames; AstArray<std::optional<AstArgumentName>> argNames;
AstTypeList returnTypes; AstTypeList returnTypes;
@ -1201,6 +1216,18 @@ public:
const AstArray<char> value; const AstArray<char> value;
}; };
class AstTypeGroup : public AstType
{
public:
LUAU_RTTI(AstTypeGroup)
explicit AstTypeGroup(const Location& location, AstType* type);
void visit(AstVisitor* visitor) override;
AstType* type;
};
class AstTypePack : public AstNode class AstTypePack : public AstNode
{ {
public: public:
@ -1261,6 +1288,16 @@ public:
return visit(static_cast<AstNode*>(node)); return visit(static_cast<AstNode*>(node));
} }
virtual bool visit(class AstGenericType* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstGenericTypePack* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstExpr* node) virtual bool visit(class AstExpr* node)
{ {
return visit(static_cast<AstNode*>(node)); return visit(static_cast<AstNode*>(node));
@ -1467,6 +1504,10 @@ public:
{ {
return visit(static_cast<AstType*>(node)); return visit(static_cast<AstType*>(node));
} }
virtual bool visit(class AstTypeGroup* node)
{
return visit(static_cast<AstType*>(node));
}
virtual bool visit(class AstTypeError* node) virtual bool visit(class AstTypeError* node)
{ {
return visit(static_cast<AstType*>(node)); return visit(static_cast<AstType*>(node));
@ -1490,6 +1531,7 @@ public:
} }
}; };
bool isLValue(const AstExpr*);
AstName getIdentifier(AstExpr*); AstName getIdentifier(AstExpr*);
Location getLocation(const AstTypeList& typeList); Location getLocation(const AstTypeList& typeList);

Some files were not shown because too many files have changed in this diff Show more